Spaces:
Running
on
Zero
Running
on
Zero
| """ | |
| Cognitive Proxy - Brain-Steered Language Model | |
| Hugging Face Spaces deployment | |
| Author: Sandro Andric | |
| """ | |
| import gradio as gr | |
| import torch | |
| import torch.nn as nn | |
| import numpy as np | |
| import pickle | |
| import os | |
| from pathlib import Path | |
| from sklearn.decomposition import PCA | |
| from transformers import AutoTokenizer, AutoModelForCausalLM | |
| import plotly.graph_objects as go | |
| import plotly.express as px | |
| import spaces # For ZeroGPU on Hugging Face | |
| # --- CONFIG --- | |
| import os | |
| from pathlib import Path | |
# Directory containing this script. ``__file__`` is not defined in every
# execution context (interactive sessions, ``exec``), and referencing a missing
# name raises NameError instead of evaluating falsy, so look it up defensively
# and fall back to the current working directory.
_script_file = globals().get("__file__")
SCRIPT_DIR = Path(_script_file).parent if _script_file else Path.cwd()

_ATLAS_FILENAME = "final_atlas_256_vocab.pkl"
_ADAPTER_FILENAME = "tinyllama_adapter_direct.pt"

# Fallback to the expected location (relative to the working directory); used
# when the atlas is not found in any of the candidate directories below.
ATLAS_PATH = "results/final_atlas_256_vocab.pkl"
ADAPTER_PATH = "results/tinyllama_adapter_direct.pt"

# Try multiple possible locations for the model files, in priority order.
# The presence of the atlas file decides which directory is used for both files.
for _candidate_dir in (SCRIPT_DIR / "results", SCRIPT_DIR):
    if (_candidate_dir / _ATLAS_FILENAME).exists():
        ATLAS_PATH = str(_candidate_dir / _ATLAS_FILENAME)
        ADAPTER_PATH = str(_candidate_dir / _ADAPTER_FILENAME)
        break

print(f"Atlas path: {ATLAS_PATH}")
print(f"Adapter path: {ADAPTER_PATH}")

MODEL_ID = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
| # --- ADAPTER CLASS --- | |
class TinyLlamaAdapterDirect(nn.Module):
    """Feed-forward adapter mapping a hidden-state vector to a flat output vector.

    The network is three Linear -> LayerNorm -> GELU stages (the first two
    followed by Dropout(0.1)) and a final Linear projection. Layers live in
    ``self.net`` in a fixed order, so state-dict keys are ``net.0`` .. ``net.11``.

    Args:
        input_dim: Size of the last dimension of the input.
        hidden_dim: Width of the first two stages; the third stage uses
            ``hidden_dim // 2``.
        output_dim: Size of the last dimension of the output.
    """

    def __init__(self, input_dim=2048, hidden_dim=1024, output_dim=65536):
        super().__init__()
        half_dim = hidden_dim // 2
        # (in_features, out_features, followed_by_dropout) for each stage
        stage_specs = (
            (input_dim, hidden_dim, True),
            (hidden_dim, hidden_dim, True),
            (hidden_dim, half_dim, False),
        )
        layers = []
        for fan_in, fan_out, with_dropout in stage_specs:
            layers.append(nn.Linear(fan_in, fan_out))
            layers.append(nn.LayerNorm(fan_out))
            layers.append(nn.GELU())
            if with_dropout:
                layers.append(nn.Dropout(0.1))
        layers.append(nn.Linear(half_dim, output_dim))
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through the stacked layers and return the result."""
        projected = self.net(x)
        return projected
# Global system cache
system = None


def _fallback_atlas():
    """Return a placeholder atlas of two random 256x256 matrices."""
    return {'word1': np.random.randn(256, 256), 'word2': np.random.randn(256, 256)}


def _load_atlas(path):
    """Load the word atlas pickled at ``path``.

    Returns a dict with at least two entries. When the file is missing, is not
    a usable dict, or holds fewer than two entries, a random placeholder atlas
    is returned instead so the app can still start.
    """
    if os.path.exists(path):
        print(f"Loading atlas from {path}")
        # NOTE(security): pickle.load can execute arbitrary code. Only point
        # this at atlas files shipped with the app, never at untrusted input.
        with open(path, 'rb') as f:
            data = pickle.load(f)
        if isinstance(data, dict):
            print(f"Atlas data keys: {list(data.keys())[:5]}")
            if 'means' in data:
                atlas = data['means']
                print(f"Using 'means' key, got {len(atlas) if isinstance(atlas, dict) else 'not a dict'} items")
            else:
                atlas = data
                print(f"Using data directly, got {len(atlas) if isinstance(atlas, dict) else 'not a dict'} items")
        else:
            atlas = data
            print(f"Atlas is not a dict, type: {type(data)}")
    else:
        print(f"Atlas file not found at {path}")
        atlas = {}

    # Check the type before truthiness: `not atlas` raises for objects such as
    # multi-element numpy arrays, which would crash instead of falling back.
    if not isinstance(atlas, dict) or not atlas:
        print("Warning: Atlas is empty or invalid, using fallback")
        atlas = _fallback_atlas()

    print(f"Loaded atlas with {len(atlas)} words")
    if len(atlas) < 2:
        print(f"Warning: Not enough words in atlas ({len(atlas)}), using fallback")
        atlas = _fallback_atlas()
    return atlas


def _build_plv_matrix(atlas):
    """Stack the atlas values into a 2-D array with one flattened row per entry.

    Falls back to a random (10, 65536) matrix when the entries cannot be
    stacked or the result is too small for PCA.
    """
    try:
        # Flatten every entry so 256x256 matrices and already-flat arrays are
        # handled the same way.
        plv_matrix = np.stack([np.asarray(value).ravel() for value in atlas.values()])
    except ValueError:
        # np.stack rejects entries whose flattened sizes differ.
        print("Warning: Atlas entries have inconsistent shapes, using fallback")
        return np.random.randn(10, 65536)
    if plv_matrix.shape[0] < 2 or plv_matrix.shape[1] < 2:
        print(f"Warning: Invalid PLV matrix shape {plv_matrix.shape}, using fallback")
        return np.random.randn(10, 65536)
    return plv_matrix


def load_system():
    """Load (once) and return the shared model/tokenizer/adapter/axis bundle.

    The result is cached in the module-level ``system`` variable; later calls
    return the cached dict without reloading anything.

    Returns:
        dict with keys ``'model'``, ``'tokenizer'``, ``'adapter'``, ``'axis'``
        (unit-norm first PCA component of the atlas matrix, float32 tensor),
        ``'global_mean'`` (column mean of the atlas matrix, float32 tensor)
        and ``'device'``.

    NOTE(review): the file imports ``spaces`` for ZeroGPU, but nothing visible
    here is decorated with ``spaces.GPU`` — confirm GPU allocation behaves as
    intended on the deployment target.
    """
    global system
    if system is not None:
        return system

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
    tokenizer.pad_token = tokenizer.eos_token

    # Use float32 for CPU, float16 for GPU
    dtype = torch.float16 if torch.cuda.is_available() else torch.float32
    try:
        # Try new parameter name first
        model = AutoModelForCausalLM.from_pretrained(MODEL_ID, dtype=dtype).to(device)
    except TypeError:
        # Fall back to old parameter name
        model = AutoModelForCausalLM.from_pretrained(MODEL_ID, torch_dtype=dtype).to(device)
    model.eval()

    adapter = TinyLlamaAdapterDirect().to(device).to(dtype)
    if os.path.exists(ADAPTER_PATH):
        adapter.load_state_dict(torch.load(ADAPTER_PATH, map_location=device, weights_only=True))
    else:
        # Without trained weights the adapter output is meaningless; say so
        # instead of silently continuing.
        print(f"Warning: Adapter weights not found at {ADAPTER_PATH}, using randomly initialized adapter")
    adapter.eval()

    atlas = _load_atlas(ATLAS_PATH)
    plv_matrix = _build_plv_matrix(atlas)

    # PCA cannot extract more components than samples - 1 or than features.
    n_components = min(10, plv_matrix.shape[0] - 1, plv_matrix.shape[1])
    pca = PCA(n_components=n_components)
    pca.fit(plv_matrix)
    pc1_axis = pca.components_[0]
    pc1_axis = pc1_axis / np.linalg.norm(pc1_axis)
    global_mean = plv_matrix.mean(axis=0)

    system = {
        'model': model,
        'tokenizer': tokenizer,
        'adapter': adapter,
        'axis': torch.tensor(pc1_axis, dtype=torch.float32).to(device),
        'global_mean': torch.tensor(global_mean, dtype=torch.float32).to(device),
        'device': device
    }
    return system
def generate_variants(prompt, scenario, max_tokens):
    """Generate three completions of ``prompt``: steered negative, unsteered, steered positive.

    Args:
        prompt: User prompt text.
        scenario: ``"Educational"`` and ``"Technical writing"`` wrap the prompt
            in the chat template and steer with strength 5.0; anything else
            uses the raw prompt with strength 3.0.
        max_tokens: Maximum number of new tokens per completion (coerced to int).

    Returns:
        Tuple ``(negative_alpha_text, baseline_text, positive_alpha_text)``.
    """
    ctx = load_system()
    model = ctx['model']
    tokenizer = ctx['tokenizer']
    adapter = ctx['adapter']

    # UI sliders may deliver a float; range() needs an int.
    max_tokens = int(max_tokens)

    if scenario in ("Educational", "Technical writing"):
        prompt_formatted = f"<|user|>\n{prompt}\n<|assistant|>\n"
        alpha_strength = 5.0
    else:
        prompt_formatted = prompt
        alpha_strength = 3.0

    # Loop-invariant work: tokenize once, resolve adapter dtype and axis once.
    prompt_ids = tokenizer(prompt_formatted, return_tensors='pt').to(ctx['device']).input_ids
    adapter_dtype = next(adapter.parameters()).dtype
    axis = ctx['axis'].to(adapter_dtype)

    outputs = []
    for alpha in (-alpha_strength, 0, alpha_strength):
        generated_ids = prompt_ids.clone()
        for _ in range(max_tokens):
            # The hidden state is detached before steering, so no autograd
            # graph through the language model is needed.
            with torch.no_grad():
                outputs_model = model(generated_ids, output_hidden_states=True)
            # Ensure proper dtype for adapter
            hidden = outputs_model.hidden_states[-1][:, -1, :].to(adapter_dtype)

            if alpha != 0:
                # Move the hidden state by `alpha` along the unit-norm gradient
                # of the adapter output's projection onto the axis.
                hidden = hidden.detach().requires_grad_(True)
                plv_pred = adapter(hidden)
                score = torch.sum(plv_pred * axis)
                grad = torch.autograd.grad(score, hidden, retain_graph=False)[0]
                grad = grad / (grad.norm() + 1e-8)
                hidden = hidden.detach() + alpha * grad.detach()

            with torch.no_grad():
                # NOTE(review): in Hugging Face Llama implementations the last
                # entry of hidden_states appears to already include the final
                # norm — confirm that applying model.norm again is intended.
                logits = model.lm_head(model.model.norm(hidden))
                probs = torch.softmax(logits / 0.8, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1)
                generated_ids = torch.cat([generated_ids, next_token], dim=-1)
            if next_token.item() == tokenizer.eos_token_id:
                break

        text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
        # Keep only the assistant's reply when the chat template was used.
        if "<|assistant|>" in text:
            text = text.split("<|assistant|>")[-1].strip()
        outputs.append(text)

    return outputs[0], outputs[1], outputs[2]
def analyze_text(text):
    """Analyze text and return score with visualization.

    Runs ``text`` through the language model, feeds the last token's final
    hidden state to the adapter, and projects the mean-centered adapter output
    onto the stored axis.

    Returns:
        Tuple ``(gauge_figure, interpretation_markdown, score, heatmap_figure)``.
    """
    ctx = load_system()
    adapter = ctx['adapter']

    with torch.no_grad():
        inputs = ctx['tokenizer'](text, return_tensors='pt').to(ctx['device'])
        out = ctx['model'](**inputs, output_hidden_states=True)
        last_hidden = out.hidden_states[-1][0, -1, :]
        # Ensure proper dtype for adapter
        adapter_dtype = next(adapter.parameters()).dtype
        last_hidden = last_hidden.to(adapter_dtype)
        plv_pred = adapter(last_hidden.unsqueeze(0))
        # The axis and global mean are stored as float32. Upcast the adapter
        # output instead of downcasting them, so the projection is not
        # computed in half precision on GPU.
        plv_flat = plv_pred[0].float()
        plv_centered = plv_flat - ctx['global_mean']
        score = (plv_centered * ctx['axis']).sum().item()

    # Create minimal gauge like Streamlit; widen the range if the score
    # falls outside the default [-300, 300].
    gauge_min = min(-300, score - 50)
    gauge_max = max(300, score + 50)
    fig = go.Figure(go.Indicator(
        mode="number+gauge",
        value=score,
        gauge={
            'shape': "angular",
            'axis': {'range': [gauge_min, gauge_max], 'tickwidth': 0.5, 'tickcolor': '#ccc'},
            'bar': {'color': "#333", 'thickness': 0.15},
            'bgcolor': "white",
            'borderwidth': 1,
            'bordercolor': "#e0e0e0",
            'steps': [
                {'range': [gauge_min, -5], 'color': "#e8f5e9"},
                {'range': [-5, 5], 'color': "#fafafa"},
                {'range': [5, gauge_max], 'color': "#fff3e0"}
            ],
        },
        number={'font': {'size': 36, 'color': '#000'}}
    ))
    fig.update_layout(
        height=300,
        width=400,
        margin={'l': 30, 'r': 30, 't': 50, 'b': 30},
        paper_bgcolor='white',
        font={'color': '#666'}
    )

    # Two trailing spaces before the newline make a Markdown hard line break.
    if score > 5:
        interpretation = "**Syntactic dominance**  \nText patterns match brain activity during grammatical processing"
    elif score < -5:
        interpretation = "**Semantic dominance**  \nText patterns match brain activity during meaning comprehension"
    else:
        interpretation = "**Balanced**  \nMixed patterns - both structure and meaning equally present"

    # Create PLV matrix heatmap (reshape to 256x256)
    plv_np = plv_flat.cpu().numpy()
    plv_matrix = plv_np[:65536].reshape(256, 256)
    fig_plv = px.imshow(
        plv_matrix,
        color_continuous_scale='Viridis',
        aspect='auto'
    )
    fig_plv.update_layout(
        coloraxis_showscale=True,
        coloraxis=dict(
            colorbar=dict(
                thickness=10,
                len=0.7,
                title=dict(text="Synchrony", side="right"),
                tickfont=dict(size=10)
            )
        ),
        margin={'l': 0, 'r': 40, 't': 10, 'b': 0},
        height=300
    )
    fig_plv.update_xaxes(visible=False)
    fig_plv.update_yaxes(visible=False)

    return fig, interpretation, score, fig_plv
def generate_steered(prompt, alpha, max_tokens):
    """Generate with custom steering.

    Args:
        prompt: Raw prompt text (no chat template is applied).
        alpha: Steering strength; ``0`` disables steering, and the sign selects
            the direction along the stored axis.
        max_tokens: Maximum number of new tokens (coerced to int).

    Returns:
        The decoded text: prompt plus generated continuation.
    """
    ctx = load_system()
    model = ctx['model']
    tokenizer = ctx['tokenizer']
    adapter = ctx['adapter']

    # UI sliders may deliver a float; range() needs an int.
    max_tokens = int(max_tokens)

    # Loop-invariant: resolve adapter dtype and cast the axis once.
    adapter_dtype = next(adapter.parameters()).dtype
    axis = ctx['axis'].to(adapter_dtype)

    inputs = tokenizer(prompt, return_tensors='pt').to(ctx['device'])
    generated_ids = inputs.input_ids.clone()

    for _ in range(max_tokens):
        # The hidden state is detached before steering, so no autograd graph
        # through the language model is needed.
        with torch.no_grad():
            outputs_model = model(generated_ids, output_hidden_states=True)
        # Ensure proper dtype for adapter
        hidden = outputs_model.hidden_states[-1][:, -1, :].to(adapter_dtype)

        if alpha != 0:
            # Move the hidden state by `alpha` along the unit-norm gradient of
            # the adapter output's projection onto the axis.
            hidden = hidden.detach().requires_grad_(True)
            plv_pred = adapter(hidden)
            score = torch.sum(plv_pred * axis)
            grad = torch.autograd.grad(score, hidden, retain_graph=False)[0]
            grad = grad / (grad.norm() + 1e-8)
            hidden = hidden.detach() + alpha * grad.detach()

        with torch.no_grad():
            # NOTE(review): in Hugging Face Llama implementations the last
            # entry of hidden_states appears to already include the final
            # norm — confirm that applying model.norm again is intended.
            logits = model.lm_head(model.model.norm(hidden))
            probs = torch.softmax(logits / 0.8, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)
            generated_ids = torch.cat([generated_ids, next_token], dim=-1)
        if next_token.item() == tokenizer.eos_token_id:
            break

    return tokenizer.decode(generated_ids[0], skip_special_tokens=True)
# Custom CSS to match Streamlit minimal design.
# Passed to the app as `css=custom_css` in the `demo.launch(...)` call.
# The string is runtime content and is kept exactly as-is.
# - The Google Fonts @import on the first line is inside a CSS comment, so
#   'Inter' is only used if it is otherwise available; the font-family rule
#   lists system fallbacks.
# - NOTE(review): `.dark { display: none !important; }` hides every element
#   carrying the `dark` class. Presumably this targets a dark-mode control;
#   confirm it cannot match a page-level container when dark mode is active.
custom_css = """
/* @import url('https://fonts.googleapis.com/css2?family=Inter:wght@300;400;500;600&display=swap'); */
/* Global font */
.gradio-container, .gradio-container * {
font-family: 'Inter', -apple-system, BlinkMacSystemFont, 'Segoe UI', sans-serif !important;
}
/* Clean header */
.main-header {
font-size: 14px;
font-weight: 300;
letter-spacing: 2px;
text-transform: uppercase;
color: #666;
margin-bottom: 8px;
}
.main-title {
font-size: 48px;
font-weight: 300;
line-height: 1.1;
letter-spacing: -1px;
margin-bottom: 16px;
}
.subtitle {
font-size: 18px;
font-weight: 300;
color: #666;
line-height: 1.6;
}
/* Clean tabs like Streamlit */
.tabs {
border-bottom: 1px solid #e0e0e0 !important;
}
.tab-nav button {
background: none !important;
border: none !important;
border-bottom: 2px solid transparent !important;
color: #666 !important;
font-weight: 400 !important;
font-size: 14px !important;
padding: 8px 16px !important;
text-transform: none !important;
}
.tab-nav button.selected {
color: #000 !important;
border-bottom-color: #000 !important;
}
/* Minimal buttons */
button.primary {
background: white !important;
border: 1px solid #000 !important;
color: #000 !important;
font-weight: 400 !important;
padding: 10px 20px !important;
transition: all 0.2s !important;
}
button.primary:hover {
background: #000 !important;
color: white !important;
}
/* Clean textboxes */
textarea, input[type="text"] {
border: 1px solid #e0e0e0 !important;
border-radius: 0 !important;
font-size: 14px !important;
}
/* Section titles */
.section-title {
font-size: 11px;
font-weight: 500;
letter-spacing: 1.5px;
text-transform: uppercase;
color: #999;
margin: 24px 0 16px 0;
}
/* Value labels */
.value-label {
font-size: 12px;
color: #999;
margin-bottom: 4px;
}
/* Remove gradio branding */
footer { display: none !important; }
.dark { display: none !important; }
"""
# Create interface
# Builds the Gradio Blocks app: a header, a "How this works" accordion, three
# tabs (Compare / Inspect / Steer) wired to generate_variants, analyze_text and
# generate_steered, and a footer.
# NOTE(review): the nesting of the `with` blocks below was reconstructed from a
# copy with indentation stripped — confirm the row/column layout against the
# deployed app. Multi-line string bodies are kept flush-left as in that copy.
with gr.Blocks(title="Cognitive Proxy") as demo:
    # Header
    gr.HTML("""
<div>
<div class="main-header">Neural Language Interface</div>
<div class="main-title">Cognitive Proxy</div>
<div class="subtitle">Steering language models through brain-derived coordinate spaces.<br>
Using MEG phase-locking patterns from 21 subjects as control geometry.</div>
<div style="color: #999; font-size: 13px; margin-top: 16px;">Sandro Andric</div>
<div style="color: #999; font-size: 11px; margin-top: 8px;">Demo model: TinyLlama-1.1B-Chat</div>
</div>
""")
    # How it works expander
    with gr.Accordion("How this works", open=False):
        gr.Markdown("""
**What makes this special:** This AI is controlled by real human brain data.
We recorded brain activity from 21 people listening to stories, discovered how their brains organize language,
and now use those patterns to steer what the AI generates.
**Try this:**
1. Start with the **Compare** tab and choose **Educational**
2. Click "Generate all variants" to see three versions side by side
3. Notice how the left (concrete) version uses analogies while the right (abstract) uses logic
4. The difference comes from steering along brain axes discovered from MEG recordings
**The science:** Different brain regions activate for grammar vs meaning.
We project the AI's internal states into this brain coordinate system and steer along the axis.
""")
    with gr.Tabs():
        # Compare Tab
        # Three side-by-side outputs from generate_variants, which returns them
        # in the order: negative alpha, no steering, positive alpha.
        with gr.TabItem("Compare"):
            gr.HTML('<div class="section-title">Comparative Analysis</div>')
            gr.Markdown("""
See how brain steering affects AI output. Try **Educational** to see the difference between
abstract explanations vs concrete analogies, or **Technical writing** to compare formal vs friendly tones.
All controlled by brain patterns from 21 human subjects.
""")
            with gr.Row():
                scenario = gr.Dropdown(
                    choices=["Educational", "Technical writing", "Free form"],
                    value="Educational",
                    label="Scenario",
                    container=False
                )
                prompt = gr.Textbox(
                    value="Explain quantum entanglement in simple terms.",
                    label="",
                    placeholder="Enter your prompt...",
                    lines=4
                )
            with gr.Row():
                # Slider positional args: minimum, maximum, initial value.
                max_tokens = gr.Slider(20, 150, 80, label="Max tokens", container=False)
                generate_btn = gr.Button("Generate all variants", variant="primary")
            gr.HTML('<div style="margin-top: 24px;"></div>')
            with gr.Row():
                with gr.Column():
                    gr.HTML('<div class="value-label">Concrete / Analogies</div>')
                    output_semantic = gr.Textbox(
                        label="",
                        lines=10,
                        interactive=False,
                        container=False
                    )
                    gr.Markdown("*Steered toward meaning patterns*", elem_classes=["caption"])
                with gr.Column():
                    gr.HTML('<div class="value-label">Baseline</div>')
                    output_baseline = gr.Textbox(
                        label="",
                        lines=10,
                        interactive=False,
                        container=False
                    )
                    gr.Markdown("*No brain steering*", elem_classes=["caption"])
                with gr.Column():
                    gr.HTML('<div class="value-label">Abstract / Logical</div>')
                    output_syntactic = gr.Textbox(
                        label="",
                        lines=10,
                        interactive=False,
                        container=False
                    )
                    gr.Markdown("*Steered toward structure patterns*", elem_classes=["caption"])
            generate_btn.click(
                generate_variants,
                inputs=[prompt, scenario, max_tokens],
                outputs=[output_semantic, output_baseline, output_syntactic]
            )
        # Inspect Tab
        # Shows the gauge, interpretation text and heatmap from analyze_text.
        with gr.TabItem("Inspect"):
            gr.HTML('<div class="section-title">Brain Space Projection</div>')
            gr.Markdown("""
Enter any text to see how it aligns with brain patterns. The meter shows whether your text
activates brain regions associated with grammar/structure (positive) or meaning/content (negative).
""")
            with gr.Row():
                with gr.Column():
                    text_input = gr.Textbox(
                        value="The scientist discovered",
                        label="",
                        placeholder="Enter text to analyze...",
                        lines=6
                    )
                    analyze_btn = gr.Button("Project", variant="primary")
                with gr.Column():
                    gauge_plot = gr.Plot(label="")
                    interpretation = gr.Markdown("")
                    with gr.Accordion("What the number means", open=False):
                        gr.Markdown("""
- **Negative values (green)** = semantic/meaning focus
- **Positive values (amber)** = syntactic/grammar focus
- **Larger magnitude** = stronger pattern
- **Range** typically -300 to +300
""")
                    with gr.Accordion("View brain connectivity pattern", open=False):
                        gr.Markdown("""
Phase-Locking Value (PLV) shows how synchronized different brain regions are.
Brighter colors = stronger synchronization between sensor pairs.
Each pixel represents connectivity between two of 256 MEG sensors.
""")
                        plv_plot = gr.Plot(label="")

            # analyze_text returns four values; the raw score has no output
            # component here, so it is dropped.
            def analyze_text_wrapper(text):
                fig, interp, _, fig_plv = analyze_text(text)
                return fig, interp, fig_plv

            analyze_btn.click(
                analyze_text_wrapper,
                inputs=[text_input],
                outputs=[gauge_plot, interpretation, plv_plot]
            )
        # Steer Tab
        # Free-form generation with a user-chosen alpha via generate_steered.
        with gr.TabItem("Steer"):
            gr.HTML('<div class="section-title">Neural Steering</div>')
            with gr.Row():
                with gr.Column(scale=2):
                    prompt_steer = gr.Textbox(
                        value="The scientist discovered",
                        label="",
                        placeholder="Enter prompt...",
                        lines=5
                    )
                with gr.Column(scale=1):
                    gr.HTML('<div class="value-label">Tokens</div>')
                    tokens_steer = gr.Slider(20, 150, 60, label="", container=False)
                    gr.HTML('<div class="value-label">Alpha</div>')
                    # Slider positional args: minimum, maximum, initial value, step.
                    alpha_steer = gr.Slider(-5.0, 5.0, 0.0, 0.5, label="", container=False)
                    gr.Markdown("*negative → semantic | positive → syntactic*", elem_classes=["caption"])
            steer_btn = gr.Button("Generate", variant="primary")
            gr.HTML('<div class="section-title">Output</div>')
            output_steer = gr.Textbox(label="", lines=8, interactive=False, container=False)
            # Input order matches generate_steered(prompt, alpha, max_tokens).
            steer_btn.click(
                generate_steered,
                inputs=[prompt_steer, alpha_steer, tokens_steer],
                outputs=[output_steer]
            )
    # Footer
    gr.HTML("""
<div style="text-align: center; color: #999; font-size: 12px; padding: 40px 0 20px 0; border-top: 1px solid #e0e0e0; margin-top: 40px;">
© 2025 Sandro Andric | <a href="https://ainthusiast.com" style="color: #999;">Ainthusiast.com</a>
</div>
""")
# Start the app: gray base theme with square corners, the custom stylesheet
# defined above, and server-side rendering switched off.
_minimal_theme = gr.themes.Base(
    primary_hue="gray",
    neutral_hue="gray",
    text_size="md",
    spacing_size="lg",
    radius_size="none",
)
demo.launch(theme=_minimal_theme, css=custom_css, ssr_mode=False)