OliverPerrin committed · Commit 18fc263 · 1 Parent(s): 5fde4fb

Simplify Gradio demo to fix schema bug

Files changed (1)
  1. scripts/demo_gradio.py +79 -526
scripts/demo_gradio.py CHANGED
@@ -1,587 +1,140 @@
- """
- Minimal Gradio demo for the LexiMind multitask model.
- Shows raw model outputs without any post-processing tricks.
- """

  from __future__ import annotations

  import json
- import re
  import sys
- from datetime import datetime
  from pathlib import Path
- from tempfile import NamedTemporaryFile
- from typing import Iterable, Sequence

  import gradio as gr
- import matplotlib.pyplot as plt
- import pandas as pd
- import seaborn as sns
- import torch
- from gradio.themes import Soft
- from matplotlib.figure import Figure

- # Make local packages importable when running the script directly
  SCRIPT_DIR = Path(__file__).resolve().parent
-
-
- def guess_project_root(script_dir: Path) -> Path:
-     """Attempt to locate the LexiMind repo root even when deployed under /app."""
-     markers = ("pyproject.toml", "setup.py", "README.md")
-     candidates = [script_dir]
-     candidates.extend(script_dir.parents)
-     candidates.extend(
-         [
-             script_dir / "LexiMind",
-             script_dir.parent / "LexiMind",
-             Path("/LexiMind"),
-         ]
-     )
-
-     seen: set[Path] = set()
-     for candidate in candidates:
-         if candidate in seen:
-             continue
-         seen.add(candidate)
-         if any((candidate / marker).exists() for marker in markers):
-             return candidate
-
-     return script_dir.parent
-
-
- PROJECT_ROOT = guess_project_root(SCRIPT_DIR)
  if str(PROJECT_ROOT) not in sys.path:
      sys.path.insert(0, str(PROJECT_ROOT))

- OUTPUTS_DIR = PROJECT_ROOT / "outputs"
- EVAL_REPORT_PATH = OUTPUTS_DIR / "evaluation_report.json"
- CONFUSION_MATRIX_PATH = OUTPUTS_DIR / "topic_confusion_matrix.png"
-
  from huggingface_hub import hf_hub_download

  from src.inference.factory import create_inference_pipeline
- from src.inference.pipeline import EmotionPrediction, InferencePipeline, TopicPrediction
  from src.utils.logging import configure_logging, get_logger

  configure_logging()
  logger = get_logger(__name__)

- _pipeline: InferencePipeline | None = None

- VISUALIZATION_DIR = OUTPUTS_DIR
- VISUALIZATION_ASSETS: list[tuple[str, str]] = [
-     ("attention_visualization.png", "Attention weights (single head)"),
-     ("multihead_attention_visualization.png", "Multi-head attention comparison"),
-     ("single_vs_multihead.png", "Single vs multi-head attention"),
-     ("positional_encoding_heatmap.png", "Positional encoding heatmap"),
- ]


- def get_pipeline() -> InferencePipeline:
      global _pipeline
      if _pipeline is None:
-         logger.info("Loading inference pipeline ...")
-
-         # Download checkpoint if not found locally
          checkpoint_path = Path("checkpoints/best.pt")
          if not checkpoint_path.exists():
-             logger.info("Checkpoint not found locally. Downloading from Hugging Face Hub...")
-             try:
-                 # Ensure checkpoints directory exists
-                 checkpoint_path.parent.mkdir(parents=True, exist_ok=True)
-
-                 # Download from the model repository
-                 # NOTE: Replace 'OliverPerrin/LexiMind-Model' with your actual model repo ID
-                 downloaded_path = hf_hub_download(
-                     repo_id="OliverPerrin/LexiMind-Model",
-                     filename="best.pt",
-                     local_dir="checkpoints",
-                     local_dir_use_symlinks=False,
-                 )
-                 logger.info(f"Checkpoint downloaded to {downloaded_path}")
-             except Exception as e:
-                 logger.error(f"Failed to download checkpoint: {e}")
-                 # Fallback or re-raise will happen in create_inference_pipeline
-
          _pipeline, _ = create_inference_pipeline(
              tokenizer_dir="artifacts/hf_tokenizer/",
              checkpoint_path="checkpoints/best.pt",
              labels_path="artifacts/labels.json",
          )
-         logger.info("Pipeline loaded")
      return _pipeline


- def count_tokens(text: str) -> str:
-     if not text:
-         return "Tokens: 0"
-     try:
-         pipeline = get_pipeline()
-         return f"Tokens: {len(pipeline.tokenizer.encode(text))}"
-     except Exception as exc:  # pragma: no cover - surfaced in UI
-         logger.error("Token counting failed: %s", exc, exc_info=True)
-         return "Token count unavailable"
-
-
- def predict(text: str):
      if not text or not text.strip():
-         return (
-             "Please enter text to analyze.",
-             None,
-             "No topic prediction available.",
-             None,
-         )

      try:
-         pipeline = get_pipeline()
-         # Fixed max length for simplicity
-         max_len = 128
-         logger.info("Generating summary with max length %s", max_len)

-         summary = pipeline.summarize([text], max_length=max_len)[0].strip()
-         # Use a higher threshold to filter out weak/wrong predictions on out-of-domain text
-         emotions = pipeline.predict_emotions([text], threshold=0.6)[0]
-         topic = pipeline.predict_topics([text])[0]

-         fallback_summary = None
-         summary_notice = ""
-         summary_source = summary
-         if not summary:
-             fallback_summary = generate_fallback_summary(text)
-             summary_source = fallback_summary
-             summary_notice = (
-                 '<p style="color: #b45309; margin-top: 8px;">'
-                 "Model returned an empty summary, so a simple extractive fallback is shown instead."
-                 "</p>"
              )
-
-         summary_html = format_summary(text, summary_source, notice=summary_notice)
-         emotion_plot = create_emotion_plot(emotions)
-         topic_markdown = format_topic(topic)
-         heatmap_source = summary if summary else fallback_summary
-         if heatmap_source:
-             attention_fig = create_attention_heatmap(text, heatmap_source, pipeline)
          else:
-             attention_fig = render_message_figure(
-                 "Attention heatmap unavailable: summary was empty."
-             )
-
-         return summary_html, emotion_plot, topic_markdown, attention_fig
-
-     except Exception as exc:  # pragma: no cover - surfaced in UI
-         logger.error("Prediction error: %s", exc, exc_info=True)
-         return "Prediction failed. Check logs for details.", None, "Error", None
-
-
- def format_summary(original: str, summary: str, *, notice: str = "") -> str:
-     if not summary:
-         summary = "(Model returned an empty summary. Consider retraining the summarization head.)"
-
-     return f"""
- <div style=\"padding: 12px; border-radius: 6px; background-color: #fafafa; color: #222;\">
- <h3 style=\"margin-top: 0; color: #222;\">Original Text</h3>
- <p style=\"background-color: #f0f0f0; padding: 10px; border-radius: 4px; white-space: pre-wrap; color: #222;\">
- {original}
- </p>
- <h3 style=\"color: #222;\">Summary</h3>
- <p style=\"background-color: #e6f3ff; padding: 10px; border-radius: 4px; white-space: pre-wrap; color: #111;\">
- {summary}
- </p>
- {notice}
- <p style=\"margin-top: 12px; color: #6b7280; font-size: 0.9rem;\">
- Outputs are shown exactly as produced by the checkpoint.
- </p>
- </div>
- """.strip()
-
-
- def create_emotion_plot(emotions: EmotionPrediction) -> Figure | None:
-     if not emotions.labels:
-         return render_message_figure("No emotions cleared the model threshold.")
-
-     df = pd.DataFrame({"Emotion": emotions.labels, "Probability": emotions.scores}).sort_values(
-         "Probability", ascending=True
-     )
-     fig, ax = plt.subplots(figsize=(6, 4))
-     colors = sns.color_palette("crest", len(df))
-     bars = ax.barh(df["Emotion"], df["Probability"], color=colors)
-     ax.set_xlabel("Probability")
-     ax.set_title("Emotion Scores")
-     ax.set_xlim(0, 1)
-     for bar in bars:
-         width = bar.get_width()
-         ax.text(width + 0.02, bar.get_y() + bar.get_height() / 2, f"{width:.2%}", va="center")
-     plt.tight_layout()
-     return fig
-
-
- def format_topic(topic: TopicPrediction | dict[str, float | str]) -> str:
-     if isinstance(topic, TopicPrediction):
-         label = topic.label
-         confidence = topic.confidence
-     else:
-         label = str(topic.get("label", "Unknown"))
-         confidence = float(topic.get("score", 0.0))
-     return f"""
-     ### Predicted Topic

-     **{label}**

-     Confidence: {confidence:.2%}
-     """.strip()
-
-
- def _clean_tokens(tokens: Iterable[str]) -> list[str]:
-     cleaned: list[str] = []
-     for token in tokens:
-         item = token.replace("Ġ", " ").replace("▁", " ")
-         cleaned.append(item.strip() if item.strip() else token)
-     return cleaned
-
-
- def create_attention_heatmap(text: str, summary: str, pipeline: InferencePipeline) -> Figure | None:
-     try:
-         batch = pipeline.preprocessor.batch_encode([text])
-         batch = pipeline._batch_to_device(batch)
-         src_ids = batch.input_ids
-         src_mask = batch.attention_mask
-         encoder_mask = (
-             src_mask.unsqueeze(1) & src_mask.unsqueeze(2) if src_mask is not None else None
-         )
-
-         with torch.inference_mode():
-             memory = pipeline.model.encoder(src_ids, mask=encoder_mask)  # type: ignore
-             target_enc = pipeline.tokenizer.batch_encode([summary])
-             target_ids = target_enc["input_ids"].to(pipeline.device)
-             target_mask = target_enc["attention_mask"].to(pipeline.device)
-             target_len = int(target_mask.sum().item())
-             decoder_inputs = pipeline.tokenizer.prepare_decoder_inputs(target_ids)
-             decoder_inputs = decoder_inputs[:, :target_len].to(pipeline.device)
-             target_ids = target_ids[:, :target_len]
-             memory_mask = src_mask.to(pipeline.device) if src_mask is not None else None
-             _, attn_list = pipeline.model.decoder(  # type: ignore
-                 decoder_inputs,
-                 memory,
-                 memory_mask=memory_mask,
-                 collect_attn=True,
-             )
-
-         if not attn_list:
-             return None
-         cross_attn = attn_list[-1]["cross"]
-         attn_matrix = cross_attn.mean(dim=1)[0].detach().cpu().numpy()
-         source_len = batch.lengths[0]
-         attn_matrix = attn_matrix[:target_len, :source_len]
-
-         source_ids = src_ids[0, :source_len].tolist()
-         target_id_list = target_ids[0].tolist()
-
-         special_ids = {
-             pipeline.tokenizer.pad_token_id,
-             pipeline.tokenizer.bos_token_id,
-             pipeline.tokenizer.eos_token_id,
-         }
-         keep_indices = [
-             idx for idx, token_id in enumerate(target_id_list) if token_id not in special_ids
-         ]
-         if not keep_indices:
-             return None

-         pruned_matrix = attn_matrix[keep_indices, :]
-         tokenizer_impl = pipeline.tokenizer.tokenizer
-         convert_tokens = getattr(tokenizer_impl, "convert_ids_to_tokens", None)
-         if convert_tokens is None:
-             return None

-         summary_tokens_raw = convert_tokens([target_id_list[idx] for idx in keep_indices])
-         source_tokens_raw = convert_tokens(source_ids)
-
-         summary_tokens = _clean_tokens(summary_tokens_raw)
-         source_tokens = _clean_tokens(source_tokens_raw)
-
-         # Cap the visualization to prevent massive heatmaps
-         max_tokens = 40
-         if len(summary_tokens) > max_tokens:
-             summary_tokens = summary_tokens[:max_tokens]
-             pruned_matrix = pruned_matrix[:max_tokens, :]
-         if len(source_tokens) > max_tokens:
-             source_tokens = source_tokens[:max_tokens]
-             pruned_matrix = pruned_matrix[:, :max_tokens]
-
-         height = max(4.0, 0.3 * len(summary_tokens))
-         width = max(6.0, 0.3 * len(source_tokens))
-         fig, ax = plt.subplots(figsize=(width, height))
-         sns.heatmap(
-             pruned_matrix,
-             cmap="mako",
-             xticklabels=source_tokens,
-             yticklabels=summary_tokens,
-             ax=ax,
-             cbar_kws={"label": "Attention"},
-         )
-         ax.set_xlabel("Input Tokens")
-         ax.set_ylabel("Summary Tokens")
-         ax.set_title("Cross-Attention (decoder last layer)")
-         ax.tick_params(axis="x", rotation=90)
-         ax.tick_params(axis="y", rotation=0)
-         fig.tight_layout()
-         return fig
-
-     except Exception as exc:
-         logger.error("Unable to build attention heatmap: %s", exc, exc_info=True)
-         return render_message_figure("Unable to render attention heatmap for this example.")
-
-
- def render_message_figure(message: str) -> Figure:
-     fig, ax = plt.subplots(figsize=(6, 2))
-     ax.axis("off")
-     ax.text(0.5, 0.5, message, ha="center", va="center", wrap=True)
-     fig.tight_layout()
-     return fig
-
-
- def prepare_download(
-     text: str,
-     summary: str,
-     emotions: EmotionPrediction | dict[str, Sequence[float] | Sequence[str]],
-     topic: TopicPrediction | dict[str, float | str],
-     *,
-     neural_summary: str | None = None,
-     fallback_summary: str | None = None,
- ) -> str:
-     if isinstance(emotions, EmotionPrediction):
-         emotion_payload = {
-             "labels": list(emotions.labels),
-             "scores": list(emotions.scores),
-         }
-     else:
-         emotion_payload = {
-             "labels": list(emotions.get("labels", [])),
-             "scores": list(emotions.get("scores", [])),
-         }
-
-     if isinstance(topic, TopicPrediction):
-         topic_payload = {"label": topic.label, "confidence": topic.confidence}
-     else:
-         topic_payload = {
-             "label": str(topic.get("label", topic.get("topic", "Unknown"))),
-             "confidence": float(topic.get("confidence", topic.get("score", 0.0))),
-         }
-
-     payload = {
-         "original_text": text,
-         "summary": summary,
-         "neural_summary": neural_summary,
-         "fallback_summary": fallback_summary,
-         "emotions": emotion_payload,
-         "topic": topic_payload,
-     }
-
-     with NamedTemporaryFile("w", delete=False, suffix=".json", encoding="utf-8") as handle:
-         json.dump(payload, handle, ensure_ascii=False, indent=2)
-         return handle.name
-
-
- def load_visualization_gallery() -> tuple[list[str], str]:
-     """Collect visualization images produced by model tests."""
-     items: list[str] = []
-     missing: list[str] = []
-     for filename, _label in VISUALIZATION_ASSETS:
-         path = VISUALIZATION_DIR / filename
-         if path.exists():
-             items.append(str(path))
-         else:
-             missing.append(filename)
-
-     if items:
-         status = f"Loaded {len(items)} visualization(s) from {VISUALIZATION_DIR}."
-     else:
-         status = (
-             "No visualization PNGs found in the outputs/ directory. "
-             "Ensure tests/test_models/* have produced the PNGs and that they are available on this host."
-         )
-
-     if missing:
-         status += f" Missing files: {', '.join(missing)}."
-
-     return items, status
-
-
- def generate_fallback_summary(text: str, max_chars: int = 320) -> str:
-     content = text.strip()
-     if not content:
-         return "(Input text was empty.)"
-
-     sentences = re.split(r"(?<=[.!?])\s+", content)
-     fragments: list[str] = []
-     total = 0
-     for sentence in sentences:
-         if not sentence:
-             continue
-         candidate = sentence if sentence.endswith((".", "!", "?")) else f"{sentence}."
-         if total + len(candidate) > max_chars and fragments:
-             break
-         fragments.append(candidate)
-         total += len(candidate)
-
-     if not fragments:
-         return content[:max_chars]
-     return " ".join(fragments)


- def load_metrics_report_as_markdown() -> tuple[str, str, str | None, str]:
-     """Load metrics and return as Markdown strings to avoid Gradio schema issues."""
      if not EVAL_REPORT_PATH.exists():
-         error_msg = (
-             f"Evaluation report not found at {EVAL_REPORT_PATH}. Run scripts/evaluate.py first."
-         )
-         return error_msg, "", None, error_msg

      try:
-         with EVAL_REPORT_PATH.open("r", encoding="utf-8") as handle:
-             report = json.load(handle)
-     except Exception as exc:
-         logger.error("Failed to read evaluation report: %s", exc, exc_info=True)
-         error_msg = f"Error loading report: {exc}"
-         return error_msg, "", None, error_msg
-
-     # Build overall metrics markdown table
-     summary_md = """| Task | Metric | Value |
- |------|--------|-------|
- | Summarization | ROUGE-Like | {:.4f} |
- | Summarization | BLEU | {:.4f} |
- | Emotion | F1 (Macro) | {:.4f} |
- | Topic | Accuracy | {:.4f} |""".format(
-         report["summarization"]["rouge_like"],
-         report["summarization"]["bleu"],
-         report["emotion"]["f1_macro"],
-         report["topic"]["accuracy"],
-     )
-
-     # Build topic classification report markdown table
-     topic_report = report["topic"]["classification_report"]
-     topic_lines = [
-         "| Label | Precision | Recall | F1-Score | Support |",
-         "|-------|-----------|--------|----------|---------|",
-     ]
-     for label, metrics in topic_report.items():
-         if isinstance(metrics, dict) and "precision" in metrics:
-             topic_lines.append(
-                 f"| {label} | {metrics['precision']:.4f} | {metrics['recall']:.4f} | {metrics['f1-score']:.4f} | {int(metrics.get('support', 0))} |"
-             )
-     topic_md = "\n".join(topic_lines)
-
-     # Confusion Matrix
-     cm_image = str(CONFUSION_MATRIX_PATH) if CONFUSION_MATRIX_PATH.exists() else None
-
-     # Metadata
-     meta_str = f"Split: {report.get('split', 'unknown')}\nLast updated: {datetime.fromtimestamp(EVAL_REPORT_PATH.stat().st_mtime).isoformat()}"
-
-     return summary_md, topic_md, cm_image, meta_str
-
-
- SAMPLE_TEXT = (
-     "Artificial intelligence is rapidly transforming the technology landscape. "
-     "Machine learning algorithms are now capable of processing vast amounts of data, "
-     "identifying patterns, and making predictions with unprecedented accuracy. "
-     "From healthcare diagnostics to financial forecasting, AI applications are "
-     "revolutionizing industries worldwide. However, ethical considerations around "
-     "privacy, bias, and transparency remain critical challenges that must be addressed "
-     "as these technologies continue to evolve."
- )
-
-
- def create_interface() -> gr.Blocks:
-     with gr.Blocks(title="LexiMind Demo", theme=Soft()) as demo:
-         gr.Markdown(
-             """
-             # LexiMind NLP Demo
-
-             This demo streams the raw outputs from the saved LexiMind checkpoint.
-             Results may be noisy; retraining is recommended for production use.
-             """
-         )
-
-         _, initial_visual_status = load_visualization_gallery()
-         summary_md, topic_md, cm_image, metrics_meta = load_metrics_report_as_markdown()
-
-         with gr.Row():
-             with gr.Column(scale=1):
-                 input_text = gr.Textbox(
-                     label="Input Text",
-                     lines=10,
-                     value=SAMPLE_TEXT,
-                     placeholder="Paste or type your text here...",
                  )
-                 token_box = gr.Textbox(label="Token Count", value="Tokens: 0", interactive=False)
-                 analyze_btn = gr.Button("Run Analysis", variant="primary")

-             with gr.Column(scale=2):
-                 with gr.Tabs():
-                     with gr.TabItem("Summary"):
-                         summary_output = gr.HTML(label="Summary")
-                     with gr.TabItem("Emotions"):
-                         emotion_output = gr.Plot(label="Emotion Probabilities")
-                     with gr.TabItem("Topic"):
-                         topic_output = gr.Markdown(label="Topic Prediction")
-                     with gr.TabItem("Attention"):
-                         attention_output = gr.Plot(label="Attention Heatmap")
-                         gr.Markdown("*Shows decoder attention if a summary is available.*")
-                     with gr.TabItem("Model Performance"):
-                         gr.Markdown("### Overall Metrics")
-                         metrics_table = gr.Markdown(value=summary_md)
-                         gr.Markdown("### Topic Classification Report")
-                         topic_table = gr.Markdown(value=topic_md)
-                         gr.Markdown("### Topic Confusion Matrix")
-                         cm_output = gr.Image(value=cm_image, label="Confusion Matrix")
-                         gr.Markdown("### Metadata")
-                         metrics_meta_text = gr.Textbox(
-                             value=metrics_meta,
-                             label="Info",
-                             interactive=False,
-                             lines=2,
-                         )
-                         refresh_metrics = gr.Button("Refresh Metrics")

-                     with gr.TabItem("Model Visuals"):
-                         gr.Markdown(
-                             "These PNGs come from the visualization-focused tests in `tests/test_models` and are consumed as-is."
-                         )
-                         gr.Markdown(initial_visual_status)

-         input_text.change(fn=count_tokens, inputs=[input_text], outputs=[token_box])
-         analyze_btn.click(
-             fn=predict,
-             inputs=[input_text],
-             outputs=[summary_output, emotion_output, topic_output, attention_output],
-         )
-         refresh_metrics.click(
-             fn=load_metrics_report_as_markdown,
-             inputs=None,
-             outputs=[metrics_table, topic_table, cm_output, metrics_meta_text],
-         )
-     return demo

- demo = create_interface()
- app = demo


  if __name__ == "__main__":
-     import os
-
-     # On HuggingFace Spaces, share must be False (they handle routing)
-     # but we need to ensure server binds correctly
-     is_hf_space = os.environ.get("SPACE_ID") is not None
-
-     try:
-         get_pipeline()
-         demo.queue().launch(
-             server_name="0.0.0.0",
-             server_port=7860,
-             share=False,
-             allowed_paths=[str(OUTPUTS_DIR)],
-         )
-     except Exception as exc:  # pragma: no cover - surfaced in console
-         logger.error("Failed to launch demo: %s", exc, exc_info=True)
-         raise

+ """Minimal Gradio demo for LexiMind multitask model."""

  from __future__ import annotations

  import json
  import sys
  from pathlib import Path

  import gradio as gr

  SCRIPT_DIR = Path(__file__).resolve().parent
+ PROJECT_ROOT = SCRIPT_DIR.parent
  if str(PROJECT_ROOT) not in sys.path:
      sys.path.insert(0, str(PROJECT_ROOT))

  from huggingface_hub import hf_hub_download

  from src.inference.factory import create_inference_pipeline
  from src.utils.logging import configure_logging, get_logger

  configure_logging()
  logger = get_logger(__name__)

+ OUTPUTS_DIR = PROJECT_ROOT / "outputs"
+ EVAL_REPORT_PATH = OUTPUTS_DIR / "evaluation_report.json"

+ _pipeline = None


+ def get_pipeline():
      global _pipeline
      if _pipeline is None:
          checkpoint_path = Path("checkpoints/best.pt")
          if not checkpoint_path.exists():
+             checkpoint_path.parent.mkdir(parents=True, exist_ok=True)
+             hf_hub_download(
+                 repo_id="OliverPerrin/LexiMind-Model",
+                 filename="best.pt",
+                 local_dir="checkpoints",
+                 local_dir_use_symlinks=False,
+             )
          _pipeline, _ = create_inference_pipeline(
              tokenizer_dir="artifacts/hf_tokenizer/",
              checkpoint_path="checkpoints/best.pt",
              labels_path="artifacts/labels.json",
          )
      return _pipeline


+ def analyze(text: str) -> str:
+     """Run all three tasks and return results as formatted text."""
      if not text or not text.strip():
+         return "Please enter some text to analyze."

      try:
+         pipe = get_pipeline()

+         # Summarization
+         summary = pipe.summarize([text], max_length=128)[0].strip() or "(empty)"

+         # Emotion detection
+         emotions = pipe.predict_emotions([text], threshold=0.5)[0]
+         if emotions.labels:
+             emotion_str = ", ".join(
+                 f"{lbl} ({score:.1%})"
+                 for lbl, score in zip(emotions.labels, emotions.scores, strict=True)
              )
          else:
+             emotion_str = "No strong emotions detected"

+         # Topic classification
+         topic = pipe.predict_topics([text])[0]
+         topic_str = f"{topic.label} ({topic.confidence:.1%})"

+         return f"""## Summary
+ {summary}

+ ## Detected Emotions
+ {emotion_str}

+ ## Topic
+ {topic_str}
+ """
+     except Exception as e:
+         logger.error("Analysis failed: %s", e, exc_info=True)
+         return f"Error: {e}"


+ def get_metrics() -> str:
+     """Load evaluation metrics as markdown."""
      if not EVAL_REPORT_PATH.exists():
+         return "No evaluation report found. Run `scripts/evaluate.py` first."

      try:
+         with open(EVAL_REPORT_PATH) as f:
+             r = json.load(f)
+
+         lines = [
+             "## Model Performance\n",
+             "| Task | Metric | Score |",
+             "|------|--------|-------|",
+             f"| Summarization | ROUGE-Like | {r['summarization']['rouge_like']:.4f} |",
+             f"| Summarization | BLEU | {r['summarization']['bleu']:.4f} |",
+             f"| Emotion | F1 Macro | {r['emotion']['f1_macro']:.4f} |",
+             f"| Topic | Accuracy | {r['topic']['accuracy']:.4f} |",
+             "",
+             "### Topic Classification Details\n",
+             "| Label | Precision | Recall | F1 |",
+             "|-------|-----------|--------|-----|",
+         ]
+         for k, v in r["topic"]["classification_report"].items():
+             if isinstance(v, dict) and "precision" in v:
+                 lines.append(
+                     f"| {k} | {v['precision']:.3f} | {v['recall']:.3f} | {v['f1-score']:.3f} |"
                  )

+         return "\n".join(lines)
+     except Exception as e:
+         return f"Error loading metrics: {e}"


+ SAMPLE = """Artificial intelligence is rapidly transforming technology. Machine learning algorithms process vast amounts of data, identifying patterns with unprecedented accuracy. From healthcare to finance, AI is revolutionizing industries worldwide. However, ethical considerations around privacy and bias remain critical challenges."""
 
 
 
 
 
 
 
 
 
 
 
123
 
124
+ with gr.Blocks(title="LexiMind Demo") as demo:
125
+ gr.Markdown(
126
+ "# LexiMind NLP Demo\nMulti-task model: summarization, emotion detection, topic classification."
127
+ )
128
 
129
+ with gr.Tab("Analyze"):
130
+ text_input = gr.Textbox(label="Input Text", lines=6, value=SAMPLE)
131
+ analyze_btn = gr.Button("Analyze", variant="primary")
132
+ output = gr.Markdown(label="Results")
133
+ analyze_btn.click(fn=analyze, inputs=text_input, outputs=output)
134
 
135
+ with gr.Tab("Metrics"):
136
+ gr.Markdown(get_metrics())
137
 
138
  if __name__ == "__main__":
139
+ get_pipeline() # Pre-load
140
+ demo.launch(server_name="0.0.0.0", server_port=7860)
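
For a quick smoke test outside the Gradio UI, the same pipeline calls that `analyze` makes can be driven directly from a Python shell. A minimal sketch, assuming the repo root is the working directory and the checkpoint and tokenizer artifacts are in place (or downloadable) as in the script above:

# Hypothetical smoke-test snippet, not part of this commit.
from src.inference.factory import create_inference_pipeline

# Build the pipeline the same way the demo does.
pipeline, _ = create_inference_pipeline(
    tokenizer_dir="artifacts/hf_tokenizer/",
    checkpoint_path="checkpoints/best.pt",
    labels_path="artifacts/labels.json",
)

text = "Artificial intelligence is rapidly transforming technology."
print(pipeline.summarize([text], max_length=128)[0])          # summarization head
emotions = pipeline.predict_emotions([text], threshold=0.5)[0]
print(list(zip(emotions.labels, emotions.scores)))            # emotion head (labels + scores)
topic = pipeline.predict_topics([text])[0]
print(topic.label, topic.confidence)                          # topic head (label + confidence)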