"""
Cognitive Proxy - Brain-Steered Language Model
Hugging Face Spaces deployment
Author: Sandro Andric
"""

import gradio as gr
import torch
import torch.nn as nn
import numpy as np
import pickle
import os
from pathlib import Path
from sklearn.decomposition import PCA
from transformers import AutoTokenizer, AutoModelForCausalLM
import plotly.graph_objects as go
import plotly.express as px
import spaces  # For ZeroGPU on Hugging Face

# --- CONFIG ---

# Directory of this script (falls back to the working directory when
# __file__ is undefined, e.g. in an interactive session)
SCRIPT_DIR = Path(__file__).parent if "__file__" in globals() else Path.cwd()

# Try multiple possible locations for the model files
if (SCRIPT_DIR / "results" / "final_atlas_256_vocab.pkl").exists():
    ATLAS_PATH = str(SCRIPT_DIR / "results" / "final_atlas_256_vocab.pkl")
    ADAPTER_PATH = str(SCRIPT_DIR / "results" / "tinyllama_adapter_direct.pt")
elif (SCRIPT_DIR / "final_atlas_256_vocab.pkl").exists():
    ATLAS_PATH = str(SCRIPT_DIR / "final_atlas_256_vocab.pkl")
    ADAPTER_PATH = str(SCRIPT_DIR / "tinyllama_adapter_direct.pt")
else:
    # Fallback to expected location
    ATLAS_PATH = "results/final_atlas_256_vocab.pkl"
    ADAPTER_PATH = "results/tinyllama_adapter_direct.pt"

print(f"Atlas path: {ATLAS_PATH}")
print(f"Adapter path: {ADAPTER_PATH}")

MODEL_ID = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"

# --- ADAPTER CLASS ---
class TinyLlamaAdapterDirect(nn.Module):
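    """MLP mapping a TinyLlama hidden state (2048-dim) to a flattened 256x256
    phase-locking-value (PLV) matrix (65,536 values)."""
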
    def __init__(self, input_dim=2048, hidden_dim=1024, output_dim=65536):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.LayerNorm(hidden_dim // 2),
            nn.GELU(),
            nn.Linear(hidden_dim // 2, output_dim),
        )

    def forward(self, x):
        return self.net(x)

# Global system cache
system = None

def load_system():
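    """Load the tokenizer, model, adapter, and brain atlas once and cache them."""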
    global system
    if system is not None:
        return system

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
    tokenizer.pad_token = tokenizer.eos_token

    # Use float32 for CPU, float16 for GPU
    dtype = torch.float16 if torch.cuda.is_available() else torch.float32
    try:
        # Try new parameter name first
        model = AutoModelForCausalLM.from_pretrained(MODEL_ID, dtype=dtype).to(device)
    except TypeError:
        # Fall back to old parameter name
        model = AutoModelForCausalLM.from_pretrained(MODEL_ID, torch_dtype=dtype).to(device)
    model.eval()

    adapter = TinyLlamaAdapterDirect().to(device).to(dtype)
    if os.path.exists(ADAPTER_PATH):
        adapter.load_state_dict(torch.load(ADAPTER_PATH, map_location=device, weights_only=True))
    adapter.eval()

    if os.path.exists(ATLAS_PATH):
        print(f"Loading atlas from {ATLAS_PATH}")
        with open(ATLAS_PATH, 'rb') as f:
            data = pickle.load(f)
            if isinstance(data, dict):
                print(f"Atlas data keys: {list(data.keys())[:5]}")
                if 'means' in data:
                    atlas = data['means']
                    print(f"Using 'means' key, got {len(atlas) if isinstance(atlas, dict) else 'not a dict'} items")
                else:
                    atlas = data
                    print(f"Using data directly, got {len(atlas) if isinstance(atlas, dict) else 'not a dict'} items")
            else:
                atlas = data
                print(f"Atlas is not a dict, type: {type(data)}")
    else:
        print(f"Atlas file not found at {ATLAS_PATH}")
        atlas = {}

    # Ensure atlas is valid
    if not atlas or not isinstance(atlas, dict):
        print(f"Warning: Atlas is empty or invalid, using fallback")
        atlas = {'word1': np.random.randn(256, 256), 'word2': np.random.randn(256, 256)}

    words = list(atlas.keys())
    print(f"Loaded atlas with {len(words)} words")
    if len(words) < 2:
        print(f"Warning: Not enough words in atlas ({len(words)}), using fallback")
        atlas = {'word1': np.random.randn(256, 256), 'word2': np.random.randn(256, 256)}
        words = list(atlas.keys())

    # Handle both 256x256 and flat arrays
    first_val = np.array(atlas[words[0]])
    if first_val.shape == (256, 256):
        plv_matrix = np.array([np.array(atlas[w]).flatten() for w in words])
    else:
        plv_matrix = np.array([np.array(atlas[w]) for w in words])

    # Ensure matrix is 2D
    if len(plv_matrix.shape) == 1 or plv_matrix.shape[0] < 2:
        print(f"Warning: Invalid PLV matrix shape {plv_matrix.shape}, using fallback")
        plv_matrix = np.random.randn(10, 65536)

    # PC1 of the stacked word-level PLV vectors defines the steering axis
    # (unit-normalized); the atlas mean is kept for centering projections
    pca = PCA(n_components=min(10, plv_matrix.shape[0] - 1))
    pca.fit(plv_matrix)
    pc1_axis = pca.components_[0]
    pc1_axis = pc1_axis / np.linalg.norm(pc1_axis)
    global_mean = plv_matrix.mean(axis=0)

    # Cache everything needed at inference time; axis and global_mean are
    # stored as float32 and cast to the adapter's dtype at the call sites
    system = {
        'model': model,
        'tokenizer': tokenizer,
        'adapter': adapter,
        'axis': torch.tensor(pc1_axis, dtype=torch.float32).to(device),
        'global_mean': torch.tensor(global_mean, dtype=torch.float32).to(device),
        'device': device
    }
    return system

@spaces.GPU(duration=60)
def generate_variants(prompt, scenario, max_tokens):
    """Generate all three variants"""
    sys = load_system()

    if scenario == "Educational":
        prompt_formatted = f"<|user|>\n{prompt}\n<|assistant|>\n"
        alpha_strength = 5.0
    elif scenario == "Technical writing":
        prompt_formatted = f"<|user|>\n{prompt}\n<|assistant|>\n"
        alpha_strength = 5.0
    else:
        prompt_formatted = prompt
        alpha_strength = 3.0

    outputs = []
    for alpha in [-alpha_strength, 0, alpha_strength]:  # semantic, baseline, syntactic
        inputs = sys['tokenizer'](prompt_formatted, return_tensors='pt').to(sys['device'])
        generated_ids = inputs.input_ids.clone()

        # Token-by-token decoding: each step re-runs the full sequence (no KV
        # cache) so the final hidden state can be steered before sampling
        for _ in range(max_tokens):
            outputs_model = sys['model'](generated_ids, output_hidden_states=True)
            hidden = outputs_model.hidden_states[-1][:, -1, :]

            # Ensure proper dtype for adapter
            adapter_dtype = next(sys['adapter'].parameters()).dtype
            hidden = hidden.to(adapter_dtype)

            if alpha != 0:
                # One normalized gradient step on the hidden state, moving the
                # adapter's predicted PLV pattern along the brain-derived PC1
                # axis (alpha < 0 -> semantic, alpha > 0 -> syntactic)
                hidden = hidden.detach().requires_grad_(True)
                plv_pred = sys['adapter'](hidden)
                score = torch.sum(plv_pred * sys['axis'].to(adapter_dtype))
                grad = torch.autograd.grad(score, hidden, retain_graph=False)[0]
                grad = grad / (grad.norm() + 1e-8)
                hidden = hidden.detach() + alpha * grad.detach()

            with torch.no_grad():
                # Apply the model's final norm and LM head to the (possibly
                # steered) hidden state, then sample at temperature 0.8
                logits = sys['model'].lm_head(sys['model'].model.norm(hidden))
                probs = torch.softmax(logits / 0.8, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1)
                generated_ids = torch.cat([generated_ids, next_token], dim=-1)
                if next_token.item() == sys['tokenizer'].eos_token_id:
                    break

        text = sys['tokenizer'].decode(generated_ids[0], skip_special_tokens=True)
        if "<|assistant|>" in text:
            text = text.split("<|assistant|>")[-1].strip()
        outputs.append(text)

    return outputs[0], outputs[1], outputs[2]

@spaces.GPU(duration=30)
def analyze_text(text):
    """Analyze text and return score with visualization"""
    sys = load_system()

    # Score = projection of the mean-centered predicted PLV pattern onto the
    # PC1 axis (positive -> syntactic, negative -> semantic)
    with torch.no_grad():
        inputs = sys['tokenizer'](text, return_tensors='pt').to(sys['device'])
        out = sys['model'](**inputs, output_hidden_states=True)
        last_hidden = out.hidden_states[-1][0, -1, :]
        # Ensure proper dtype for adapter
        adapter_dtype = next(sys['adapter'].parameters()).dtype
        last_hidden = last_hidden.to(adapter_dtype)
        plv_pred = sys['adapter'](last_hidden.unsqueeze(0))
        plv_flat = plv_pred[0]
        plv_centered = plv_flat - sys['global_mean'].to(adapter_dtype)
        score = (plv_centered * sys['axis'].to(adapter_dtype)).sum().item()

    # Create minimal gauge like Streamlit
    gauge_min = min(-300, score - 50)
    gauge_max = max(300, score + 50)

    fig = go.Figure(go.Indicator(
        mode="number+gauge",
        value=score,
        gauge={
            'shape': "angular",
            'axis': {'range': [gauge_min, gauge_max], 'tickwidth': 0.5, 'tickcolor': '#ccc'},
            'bar': {'color': "#333", 'thickness': 0.15},
            'bgcolor': "white",
            'borderwidth': 1,
            'bordercolor': "#e0e0e0",
            'steps': [
                {'range': [gauge_min, -5], 'color': "#e8f5e9"},
                {'range': [-5, 5], 'color': "#fafafa"},
                {'range': [5, gauge_max], 'color': "#fff3e0"}
            ],
        },
        number={'font': {'size': 36, 'color': '#000'}}
    ))

    fig.update_layout(
        height=300,
        width=400,
        margin={'l': 30, 'r': 30, 't': 50, 'b': 30},
        paper_bgcolor='white',
        font={'color': '#666'}
    )

    if score > 5:
        interpretation = "**Syntactic dominance**  \nText patterns match brain activity during grammatical processing"
    elif score < -5:
        interpretation = "**Semantic dominance**  \nText patterns match brain activity during meaning comprehension"
    else:
        interpretation = "**Balanced**  \nMixed patterns - both structure and meaning equally present"

    # Create PLV matrix heatmap (reshape to 256x256)
    plv_np = plv_pred[0].cpu().numpy()
    plv_matrix = plv_np[:65536].reshape(256, 256)
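    # Each pixel is the predicted synchrony between one pair of the 256 MEG sensors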

    fig_plv = px.imshow(
        plv_matrix,
        color_continuous_scale='Viridis',
        aspect='auto'
    )
    fig_plv.update_layout(
        coloraxis_showscale=True,
        coloraxis=dict(
            colorbar=dict(
                thickness=10,
                len=0.7,
                title=dict(text="Synchrony", side="right"),
                tickfont=dict(size=10)
            )
        ),
        margin={'l': 0, 'r': 40, 't': 10, 'b': 0},
        height=300
    )
    fig_plv.update_xaxes(visible=False)
    fig_plv.update_yaxes(visible=False)

    return fig, interpretation, score, fig_plv

@spaces.GPU(duration=60)
def generate_steered(prompt, alpha, max_tokens):
    """Generate with custom steering"""
    sys = load_system()

    inputs = sys['tokenizer'](prompt, return_tensors='pt').to(sys['device'])
    generated_ids = inputs.input_ids.clone()

    for _ in range(max_tokens):
        outputs_model = sys['model'](generated_ids, output_hidden_states=True)
        hidden = outputs_model.hidden_states[-1][:, -1, :]

        # Ensure proper dtype for adapter
        adapter_dtype = next(sys['adapter'].parameters()).dtype
        hidden = hidden.to(adapter_dtype)

        if alpha != 0:
            # Same steering as in generate_variants: one normalized gradient
            # step along the brain-derived PC1 axis
            hidden = hidden.detach().requires_grad_(True)
            plv_pred = sys['adapter'](hidden)
            score = torch.sum(plv_pred * sys['axis'].to(adapter_dtype))
            grad = torch.autograd.grad(score, hidden, retain_graph=False)[0]
            grad = grad / (grad.norm() + 1e-8)
            hidden = hidden.detach() + alpha * grad.detach()

        with torch.no_grad():
            logits = sys['model'].lm_head(sys['model'].model.norm(hidden))
            probs = torch.softmax(logits / 0.8, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)
            generated_ids = torch.cat([generated_ids, next_token], dim=-1)
            if next_token.item() == sys['tokenizer'].eos_token_id:
                break

    return sys['tokenizer'].decode(generated_ids[0], skip_special_tokens=True)

# Custom CSS to match Streamlit minimal design
custom_css = """
/* @import url('https://fonts.googleapis.com/css2?family=Inter:wght@300;400;500;600&display=swap'); */

/* Global font */
.gradio-container, .gradio-container * {
    font-family: 'Inter', -apple-system, BlinkMacSystemFont, 'Segoe UI', sans-serif !important;
}

/* Clean header */
.main-header {
    font-size: 14px;
    font-weight: 300;
    letter-spacing: 2px;
    text-transform: uppercase;
    color: #666;
    margin-bottom: 8px;
}

.main-title {
    font-size: 48px;
    font-weight: 300;
    line-height: 1.1;
    letter-spacing: -1px;
    margin-bottom: 16px;
}

.subtitle {
    font-size: 18px;
    font-weight: 300;
    color: #666;
    line-height: 1.6;
}

/* Clean tabs like Streamlit */
.tabs {
    border-bottom: 1px solid #e0e0e0 !important;
}

.tab-nav button {
    background: none !important;
    border: none !important;
    border-bottom: 2px solid transparent !important;
    color: #666 !important;
    font-weight: 400 !important;
    font-size: 14px !important;
    padding: 8px 16px !important;
    text-transform: none !important;
}

.tab-nav button.selected {
    color: #000 !important;
    border-bottom-color: #000 !important;
}

/* Minimal buttons */
button.primary {
    background: white !important;
    border: 1px solid #000 !important;
    color: #000 !important;
    font-weight: 400 !important;
    padding: 10px 20px !important;
    transition: all 0.2s !important;
}

button.primary:hover {
    background: #000 !important;
    color: white !important;
}

/* Clean textboxes */
textarea, input[type="text"] {
    border: 1px solid #e0e0e0 !important;
    border-radius: 0 !important;
    font-size: 14px !important;
}

/* Section titles */
.section-title {
    font-size: 11px;
    font-weight: 500;
    letter-spacing: 1.5px;
    text-transform: uppercase;
    color: #999;
    margin: 24px 0 16px 0;
}

/* Value labels */
.value-label {
    font-size: 12px;
    color: #999;
    margin-bottom: 4px;
}

/* Remove gradio branding */
footer { display: none !important; }
"""

# Create interface
with gr.Blocks(title="Cognitive Proxy") as demo:

    # Header
    gr.HTML("""
    <div>
        <div class="main-header">Neural Language Interface</div>
        <div class="main-title">Cognitive Proxy</div>
        <div class="subtitle">Steering language models through brain-derived coordinate spaces.<br>
        Using MEG phase-locking patterns from 21 subjects as control geometry.</div>
        <div style="color: #999; font-size: 13px; margin-top: 16px;">Sandro Andric</div>
        <div style="color: #999; font-size: 11px; margin-top: 8px;">Demo model: TinyLlama-1.1B-Chat</div>
    </div>
    """)

    # How it works expander
    with gr.Accordion("How this works", open=False):
        gr.Markdown("""
        **What makes this special:** This AI is controlled by real human brain data.
        We recorded brain activity from 21 people listening to stories, discovered how their brains organize language,
        and now use those patterns to steer what the AI generates.

        **Try this:**
        1. Start with the **Compare** tab and choose **Educational**
        2. Click "Generate all variants" to see three versions side by side
        3. Notice how the left (concrete) version uses analogies while the right (abstract) uses logic
        4. The difference comes from steering along brain axes discovered from MEG recordings

        **The science:** Different brain regions activate for grammar vs meaning.
        We project the AI's internal states into this brain coordinate system and steer along the axis that separates the two.
        """)

    with gr.Tabs():
        # Compare Tab
        with gr.TabItem("Compare"):
            gr.HTML('<div class="section-title">Comparative Analysis</div>')

            gr.Markdown("""
            See how brain steering affects AI output. Try **Educational** to see the difference between
            abstract explanations vs concrete analogies, or **Technical writing** to compare formal vs friendly tones.
            All controlled by brain patterns from 21 human subjects.
            """)

            with gr.Row():
                scenario = gr.Dropdown(
                    choices=["Educational", "Technical writing", "Free form"],
                    value="Educational",
                    label="Scenario",
                    container=False
                )

            prompt = gr.Textbox(
                value="Explain quantum entanglement in simple terms.",
                label="",
                placeholder="Enter your prompt...",
                lines=4
            )

            with gr.Row():
                max_tokens = gr.Slider(20, 150, 80, label="Max tokens", container=False)
                generate_btn = gr.Button("Generate all variants", variant="primary")

            gr.HTML('<div style="margin-top: 24px;"></div>')

            with gr.Row():
                with gr.Column():
                    gr.HTML('<div class="value-label">Concrete / Analogies</div>')
                    output_semantic = gr.Textbox(
                        label="",
                        lines=10,
                        interactive=False,
                        container=False
                    )
                    gr.Markdown("*Steered toward meaning patterns*", elem_classes=["caption"])

                with gr.Column():
                    gr.HTML('<div class="value-label">Baseline</div>')
                    output_baseline = gr.Textbox(
                        label="",
                        lines=10,
                        interactive=False,
                        container=False
                    )
                    gr.Markdown("*No brain steering*", elem_classes=["caption"])

                with gr.Column():
                    gr.HTML('<div class="value-label">Abstract / Logical</div>')
                    output_syntactic = gr.Textbox(
                        label="",
                        lines=10,
                        interactive=False,
                        container=False
                    )
                    gr.Markdown("*Steered toward structure patterns*", elem_classes=["caption"])

            generate_btn.click(
                generate_variants,
                inputs=[prompt, scenario, max_tokens],
                outputs=[output_semantic, output_baseline, output_syntactic]
            )

        # Inspect Tab
        with gr.TabItem("Inspect"):
            gr.HTML('<div class="section-title">Brain Space Projection</div>')

            gr.Markdown("""
            Enter any text to see how it aligns with brain patterns. The meter shows whether your text
            matches brain activity associated with grammar/structure (positive) or meaning/content (negative).
            """)

            with gr.Row():
                with gr.Column():
                    text_input = gr.Textbox(
                        value="The scientist discovered",
                        label="",
                        placeholder="Enter text to analyze...",
                        lines=6
                    )
                    analyze_btn = gr.Button("Project", variant="primary")

                with gr.Column():
                    gauge_plot = gr.Plot(label="")
                    interpretation = gr.Markdown("")

                    with gr.Accordion("What the number means", open=False):
                        gr.Markdown("""
                        - **Negative values (green)** = semantic/meaning focus
                        - **Positive values (amber)** = syntactic/grammar focus
                        - **Larger magnitude** = stronger pattern
                        - **Range** typically -300 to +300
                        """)

                    with gr.Accordion("View brain connectivity pattern", open=False):
                        gr.Markdown("""
                        Phase-Locking Value (PLV) shows how synchronized different brain regions are.
                        Brighter colors = stronger synchronization between sensor pairs.
                        Each pixel represents connectivity between two of 256 MEG sensors.
                        """)
                        plv_plot = gr.Plot(label="")

            def analyze_text_wrapper(text):
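                # Drop the raw score; the UI shows only the gauge, the
                # interpretation text, and the PLV heatmap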
                fig, interp, _, fig_plv = analyze_text(text)
                return fig, interp, fig_plv

            analyze_btn.click(
                analyze_text_wrapper,
                inputs=[text_input],
                outputs=[gauge_plot, interpretation, plv_plot]
            )

        # Steer Tab
        with gr.TabItem("Steer"):
            gr.HTML('<div class="section-title">Neural Steering</div>')

            with gr.Row():
                with gr.Column(scale=2):
                    prompt_steer = gr.Textbox(
                        value="The scientist discovered",
                        label="",
                        placeholder="Enter prompt...",
                        lines=5
                    )

                with gr.Column(scale=1):
                    gr.HTML('<div class="value-label">Tokens</div>')
                    tokens_steer = gr.Slider(20, 150, 60, label="", container=False)

                    gr.HTML('<div class="value-label">Alpha</div>')
                    alpha_steer = gr.Slider(-5.0, 5.0, 0.0, step=0.5, label="", container=False)
                    gr.Markdown("*negative → semantic | positive → syntactic*", elem_classes=["caption"])

                    steer_btn = gr.Button("Generate", variant="primary")

            gr.HTML('<div class="section-title">Output</div>')
            output_steer = gr.Textbox(label="", lines=8, interactive=False, container=False)

            steer_btn.click(
                generate_steered,
                inputs=[prompt_steer, alpha_steer, tokens_steer],
                outputs=[output_steer]
            )

    # Footer
    gr.HTML("""
    <div style="text-align: center; color: #999; font-size: 12px; padding: 40px 0 20px 0; border-top: 1px solid #e0e0e0; margin-top: 40px;">
    © 2025 Sandro Andric | <a href="https://ainthusiast.com" style="color: #999;">Ainthusiast.com</a>
    </div>
    """)

if __name__ == "__main__":
    demo.launch(ssr_mode=False)