AI-nthusiast commited on
Commit
20f2285
·
1 Parent(s): f61308a

Add PLV matrix heatmap to Inspect tab

Browse files
Files changed (1) hide show
  1. app.py +38 -4
app.py CHANGED
@@ -14,6 +14,7 @@ from pathlib import Path
14
  from sklearn.decomposition import PCA
15
  from transformers import AutoTokenizer, AutoModelForCausalLM
16
  import plotly.graph_objects as go
 
17
  import spaces # For ZeroGPU on Hugging Face
18
 
19
  # --- CONFIG ---
@@ -255,7 +256,32 @@ def analyze_text(text):
255
  else:
256
  interpretation = "**Balanced** \nMixed patterns - both structure and meaning equally present"
257
 
258
- return fig, interpretation, score
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
259
 
260
  @spaces.GPU(duration=60)
261
  def generate_steered(prompt, alpha, max_tokens):
@@ -523,14 +549,22 @@ with gr.Blocks(title="Cognitive Proxy") as demo:
523
  - **Range** typically -300 to +300
524
  """)
525
 
 
 
 
 
 
 
 
 
526
  def analyze_text_wrapper(text):
527
- fig, interp, _ = analyze_text(text) # Ignore the score
528
- return fig, interp
529
 
530
  analyze_btn.click(
531
  analyze_text_wrapper,
532
  inputs=[text_input],
533
- outputs=[gauge_plot, interpretation]
534
  )
535
 
536
  # Steer Tab
 
14
  from sklearn.decomposition import PCA
15
  from transformers import AutoTokenizer, AutoModelForCausalLM
16
  import plotly.graph_objects as go
17
+ import plotly.express as px
18
  import spaces # For ZeroGPU on Hugging Face
19
 
20
  # --- CONFIG ---
 
256
  else:
257
  interpretation = "**Balanced** \nMixed patterns - both structure and meaning equally present"
258
 
259
+ # Create PLV matrix heatmap (reshape to 256x256)
260
+ plv_np = plv_pred[0].cpu().numpy()
261
+ plv_matrix = plv_np[:65536].reshape(256, 256)
262
+
263
+ fig_plv = px.imshow(
264
+ plv_matrix,
265
+ color_continuous_scale='Viridis',
266
+ aspect='auto'
267
+ )
268
+ fig_plv.update_layout(
269
+ coloraxis_showscale=True,
270
+ coloraxis=dict(
271
+ colorbar=dict(
272
+ thickness=10,
273
+ len=0.7,
274
+ title=dict(text="Synchrony", side="right"),
275
+ tickfont=dict(size=10)
276
+ )
277
+ ),
278
+ margin={'l': 0, 'r': 40, 't': 10, 'b': 0},
279
+ height=300
280
+ )
281
+ fig_plv.update_xaxes(visible=False)
282
+ fig_plv.update_yaxes(visible=False)
283
+
284
+ return fig, interpretation, score, fig_plv
285
 
286
  @spaces.GPU(duration=60)
287
  def generate_steered(prompt, alpha, max_tokens):
 
549
  - **Range** typically -300 to +300
550
  """)
551
 
552
+ with gr.Accordion("View brain connectivity pattern", open=False):
553
+ gr.Markdown("""
554
+ Phase-Locking Value (PLV) shows how synchronized different brain regions are.
555
+ Brighter colors = stronger synchronization between sensor pairs.
556
+ Each pixel represents connectivity between two of 256 MEG sensors.
557
+ """)
558
+ plv_plot = gr.Plot(label="")
559
+
560
  def analyze_text_wrapper(text):
561
+ fig, interp, _, fig_plv = analyze_text(text)
562
+ return fig, interp, fig_plv
563
 
564
  analyze_btn.click(
565
  analyze_text_wrapper,
566
  inputs=[text_input],
567
+ outputs=[gauge_plot, interpretation, plv_plot]
568
  )
569
 
570
  # Steer Tab