""" Cognitive Proxy - Brain-Steered Language Model Hugging Face Spaces deployment Author: Sandro Andric """ import gradio as gr import torch import torch.nn as nn import numpy as np import pickle import os from pathlib import Path from sklearn.decomposition import PCA from transformers import AutoTokenizer, AutoModelForCausalLM import plotly.graph_objects as go import plotly.express as px import spaces # For ZeroGPU on Hugging Face # --- CONFIG --- import os from pathlib import Path # Get the directory of this script SCRIPT_DIR = Path(__file__).parent if __file__ else Path.cwd() # Try multiple possible locations for the model files if (SCRIPT_DIR / "results" / "final_atlas_256_vocab.pkl").exists(): ATLAS_PATH = str(SCRIPT_DIR / "results" / "final_atlas_256_vocab.pkl") ADAPTER_PATH = str(SCRIPT_DIR / "results" / "tinyllama_adapter_direct.pt") elif (SCRIPT_DIR / "final_atlas_256_vocab.pkl").exists(): ATLAS_PATH = str(SCRIPT_DIR / "final_atlas_256_vocab.pkl") ADAPTER_PATH = str(SCRIPT_DIR / "tinyllama_adapter_direct.pt") else: # Fallback to expected location ATLAS_PATH = "results/final_atlas_256_vocab.pkl" ADAPTER_PATH = "results/tinyllama_adapter_direct.pt" print(f"Atlas path: {ATLAS_PATH}") print(f"Adapter path: {ADAPTER_PATH}") MODEL_ID = "TinyLlama/TinyLlama-1.1B-Chat-v1.0" # --- ADAPTER CLASS --- class TinyLlamaAdapterDirect(nn.Module): def __init__(self, input_dim=2048, hidden_dim=1024, output_dim=65536): super().__init__() self.net = nn.Sequential( nn.Linear(input_dim, hidden_dim), nn.LayerNorm(hidden_dim), nn.GELU(), nn.Dropout(0.1), nn.Linear(hidden_dim, hidden_dim), nn.LayerNorm(hidden_dim), nn.GELU(), nn.Dropout(0.1), nn.Linear(hidden_dim, hidden_dim // 2), nn.LayerNorm(hidden_dim // 2), nn.GELU(), nn.Linear(hidden_dim // 2, output_dim), ) def forward(self, x): return self.net(x) # Global system cache system = None def load_system(): global system if system is not None: return system device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') tokenizer = AutoTokenizer.from_pretrained(MODEL_ID) tokenizer.pad_token = tokenizer.eos_token # Use float32 for CPU, float16 for GPU dtype = torch.float16 if torch.cuda.is_available() else torch.float32 try: # Try new parameter name first model = AutoModelForCausalLM.from_pretrained(MODEL_ID, dtype=dtype).to(device) except TypeError: # Fall back to old parameter name model = AutoModelForCausalLM.from_pretrained(MODEL_ID, torch_dtype=dtype).to(device) model.eval() adapter = TinyLlamaAdapterDirect().to(device).to(dtype) if os.path.exists(ADAPTER_PATH): adapter.load_state_dict(torch.load(ADAPTER_PATH, map_location=device, weights_only=True)) adapter.eval() if os.path.exists(ATLAS_PATH): print(f"Loading atlas from {ATLAS_PATH}") with open(ATLAS_PATH, 'rb') as f: data = pickle.load(f) if isinstance(data, dict): print(f"Atlas data keys: {list(data.keys())[:5]}") if 'means' in data: atlas = data['means'] print(f"Using 'means' key, got {len(atlas) if isinstance(atlas, dict) else 'not a dict'} items") else: atlas = data print(f"Using data directly, got {len(atlas) if isinstance(atlas, dict) else 'not a dict'} items") else: atlas = data print(f"Atlas is not a dict, type: {type(data)}") else: print(f"Atlas file not found at {ATLAS_PATH}") atlas = {} # Ensure atlas is valid if not atlas or not isinstance(atlas, dict): print(f"Warning: Atlas is empty or invalid, using fallback") atlas = {'word1': np.random.randn(256, 256), 'word2': np.random.randn(256, 256)} words = list(atlas.keys()) print(f"Loaded atlas with {len(words)} words") if len(words) < 
2: print(f"Warning: Not enough words in atlas ({len(words)}), using fallback") atlas = {'word1': np.random.randn(256, 256), 'word2': np.random.randn(256, 256)} words = list(atlas.keys()) # Handle both 256x256 and flat arrays first_val = np.array(atlas[words[0]]) if first_val.shape == (256, 256): plv_matrix = np.array([np.array(atlas[w]).flatten() for w in words]) else: plv_matrix = np.array([np.array(atlas[w]) for w in words]) # Ensure matrix is 2D if len(plv_matrix.shape) == 1 or plv_matrix.shape[0] < 2: print(f"Warning: Invalid PLV matrix shape {plv_matrix.shape}, using fallback") plv_matrix = np.random.randn(10, 65536) pca = PCA(n_components=min(10, plv_matrix.shape[0] - 1)) pca.fit(plv_matrix) pc1_axis = pca.components_[0] pc1_axis = pc1_axis / np.linalg.norm(pc1_axis) global_mean = plv_matrix.mean(axis=0) system = { 'model': model, 'tokenizer': tokenizer, 'adapter': adapter, 'axis': torch.tensor(pc1_axis, dtype=torch.float32).to(device), 'global_mean': torch.tensor(global_mean, dtype=torch.float32).to(device), 'device': device } return system @spaces.GPU(duration=60) def generate_variants(prompt, scenario, max_tokens): """Generate all three variants""" sys = load_system() if scenario == "Educational": prompt_formatted = f"<|user|>\n{prompt}\n<|assistant|>\n" alpha_strength = 5.0 elif scenario == "Technical writing": prompt_formatted = f"<|user|>\n{prompt}\n<|assistant|>\n" alpha_strength = 5.0 else: prompt_formatted = prompt alpha_strength = 3.0 outputs = [] for alpha in [-alpha_strength, 0, alpha_strength]: inputs = sys['tokenizer'](prompt_formatted, return_tensors='pt').to(sys['device']) generated_ids = inputs.input_ids.clone() for _ in range(max_tokens): outputs_model = sys['model'](generated_ids, output_hidden_states=True) hidden = outputs_model.hidden_states[-1][:, -1, :] # Ensure proper dtype for adapter adapter_dtype = next(sys['adapter'].parameters()).dtype hidden = hidden.to(adapter_dtype) if alpha != 0: hidden = hidden.detach().requires_grad_(True) plv_pred = sys['adapter'](hidden) score = torch.sum(plv_pred * sys['axis'].to(adapter_dtype)) grad = torch.autograd.grad(score, hidden, retain_graph=False)[0] grad = grad / (grad.norm() + 1e-8) hidden = hidden.detach() + alpha * grad.detach() with torch.no_grad(): logits = sys['model'].lm_head(sys['model'].model.norm(hidden)) probs = torch.softmax(logits / 0.8, dim=-1) next_token = torch.multinomial(probs, num_samples=1) generated_ids = torch.cat([generated_ids, next_token], dim=-1) if next_token.item() == sys['tokenizer'].eos_token_id: break text = sys['tokenizer'].decode(generated_ids[0], skip_special_tokens=True) if "<|assistant|>" in text: text = text.split("<|assistant|>")[-1].strip() outputs.append(text) return outputs[0], outputs[1], outputs[2] @spaces.GPU(duration=30) def analyze_text(text): """Analyze text and return score with visualization""" sys = load_system() with torch.no_grad(): inputs = sys['tokenizer'](text, return_tensors='pt').to(sys['device']) out = sys['model'](**inputs, output_hidden_states=True) last_hidden = out.hidden_states[-1][0, -1, :] # Ensure proper dtype for adapter adapter_dtype = next(sys['adapter'].parameters()).dtype last_hidden = last_hidden.to(adapter_dtype) plv_pred = sys['adapter'](last_hidden.unsqueeze(0)) plv_flat = plv_pred[0] plv_centered = plv_flat - sys['global_mean'].to(adapter_dtype) score = (plv_centered * sys['axis'].to(adapter_dtype)).sum().item() # Create minimal gauge like Streamlit gauge_min = min(-300, score - 50) gauge_max = max(300, score + 50) fig = go.Figure(go.Indicator( 
mode="number+gauge", value=score, gauge={ 'shape': "angular", 'axis': {'range': [gauge_min, gauge_max], 'tickwidth': 0.5, 'tickcolor': '#ccc'}, 'bar': {'color': "#333", 'thickness': 0.15}, 'bgcolor': "white", 'borderwidth': 1, 'bordercolor': "#e0e0e0", 'steps': [ {'range': [gauge_min, -5], 'color': "#e8f5e9"}, {'range': [-5, 5], 'color': "#fafafa"}, {'range': [5, gauge_max], 'color': "#fff3e0"} ], }, number={'font': {'size': 36, 'color': '#000'}} )) fig.update_layout( height=300, width=400, margin={'l': 30, 'r': 30, 't': 50, 'b': 30}, paper_bgcolor='white', font={'color': '#666'} ) if score > 5: interpretation = "**Syntactic dominance** \nText patterns match brain activity during grammatical processing" elif score < -5: interpretation = "**Semantic dominance** \nText patterns match brain activity during meaning comprehension" else: interpretation = "**Balanced** \nMixed patterns - both structure and meaning equally present" # Create PLV matrix heatmap (reshape to 256x256) plv_np = plv_pred[0].cpu().numpy() plv_matrix = plv_np[:65536].reshape(256, 256) fig_plv = px.imshow( plv_matrix, color_continuous_scale='Viridis', aspect='auto' ) fig_plv.update_layout( coloraxis_showscale=True, coloraxis=dict( colorbar=dict( thickness=10, len=0.7, title=dict(text="Synchrony", side="right"), tickfont=dict(size=10) ) ), margin={'l': 0, 'r': 40, 't': 10, 'b': 0}, height=300 ) fig_plv.update_xaxes(visible=False) fig_plv.update_yaxes(visible=False) return fig, interpretation, score, fig_plv @spaces.GPU(duration=60) def generate_steered(prompt, alpha, max_tokens): """Generate with custom steering""" sys = load_system() inputs = sys['tokenizer'](prompt, return_tensors='pt').to(sys['device']) generated_ids = inputs.input_ids.clone() for _ in range(max_tokens): outputs_model = sys['model'](generated_ids, output_hidden_states=True) hidden = outputs_model.hidden_states[-1][:, -1, :] # Ensure proper dtype for adapter adapter_dtype = next(sys['adapter'].parameters()).dtype hidden = hidden.to(adapter_dtype) if alpha != 0: hidden = hidden.detach().requires_grad_(True) plv_pred = sys['adapter'](hidden) score = torch.sum(plv_pred * sys['axis'].to(adapter_dtype)) grad = torch.autograd.grad(score, hidden, retain_graph=False)[0] grad = grad / (grad.norm() + 1e-8) hidden = hidden.detach() + alpha * grad.detach() with torch.no_grad(): logits = sys['model'].lm_head(sys['model'].model.norm(hidden)) probs = torch.softmax(logits / 0.8, dim=-1) next_token = torch.multinomial(probs, num_samples=1) generated_ids = torch.cat([generated_ids, next_token], dim=-1) if next_token.item() == sys['tokenizer'].eos_token_id: break return sys['tokenizer'].decode(generated_ids[0], skip_special_tokens=True) # Custom CSS to match Streamlit minimal design custom_css = """ /* @import url('https://fonts.googleapis.com/css2?family=Inter:wght@300;400;500;600&display=swap'); */ /* Global font */ .gradio-container, .gradio-container * { font-family: 'Inter', -apple-system, BlinkMacSystemFont, 'Segoe UI', sans-serif !important; } /* Clean header */ .main-header { font-size: 14px; font-weight: 300; letter-spacing: 2px; text-transform: uppercase; color: #666; margin-bottom: 8px; } .main-title { font-size: 48px; font-weight: 300; line-height: 1.1; letter-spacing: -1px; margin-bottom: 16px; } .subtitle { font-size: 18px; font-weight: 300; color: #666; line-height: 1.6; } /* Clean tabs like Streamlit */ .tabs { border-bottom: 1px solid #e0e0e0 !important; } .tab-nav button { background: none !important; border: none !important; border-bottom: 2px solid 

# Create interface
with gr.Blocks(title="Cognitive Proxy", css=custom_css) as demo:
    # Header
    gr.HTML("""