OliverPerrin committed
Commit 00d412c · 1 Parent(s): a504116

Fixed Gradio Summarization Issue

requirements-dev.txt CHANGED
@@ -7,4 +7,5 @@ flake8>=6.0.0
  mypy>=1.4.0
  jupyter>=1.0.0
  ipywidgets>=8.0.0
- pre-commit>=3.4.0
+ pre-commit>=3.4.0
+ rouge-score>=0.1.2
requirements.txt CHANGED
@@ -11,4 +11,5 @@ datasets>=4.4.0
  gradio>=4.0.0
  seaborn
  pytest
- matplotlib
+ matplotlib
+ rouge-score>=0.1.2
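The `rouge-score>=0.1.2` pin added to both requirements files backs the new evaluation script and the demo's Metrics tab. A minimal sketch of the library API as `scripts/eval_rouge.py` uses it (the reference and prediction strings are placeholders):

```python
# Minimal rouge_score usage matching scripts/eval_rouge.py; the strings are placeholders.
from rouge_score import rouge_scorer

scorer = rouge_scorer.RougeScorer(["rouge1", "rouge2", "rougeL"], use_stemmer=True)
scores = scorer.score(
    "the cat sat on the mat",        # reference summary
    "a cat was sitting on the mat",  # model prediction
)
for name, score in scores.items():
    print(name, round(score.precision, 3), round(score.recall, 3), round(score.fmeasure, 3))
```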
scripts/demo_gradio.py CHANGED
@@ -24,6 +24,8 @@ PROJECT_ROOT = Path(__file__).resolve().parent.parent
  if str(PROJECT_ROOT) not in sys.path:
      sys.path.insert(0, str(PROJECT_ROOT))
 
+ ROUGE_REPORT_PATH = PROJECT_ROOT / "outputs" / "rouge_validation.json"
+
  from src.inference.factory import create_inference_pipeline
  from src.inference.pipeline import EmotionPrediction, InferencePipeline, TopicPrediction
  from src.utils.logging import configure_logging, get_logger
@@ -358,6 +360,39 @@ def generate_fallback_summary(text: str, max_chars: int = 320) -> str:
      return " ".join(fragments)
 
 
+ def load_rouge_metrics():
+     columns = ["metric", "precision", "recall", "fmeasure"]
+     empty = pd.DataFrame(columns=columns)
+     if not ROUGE_REPORT_PATH.exists():
+         return empty, {"error": f"ROUGE report not found at {ROUGE_REPORT_PATH}"}
+
+     try:
+         with ROUGE_REPORT_PATH.open("r", encoding="utf-8") as handle:
+             report = json.load(handle)
+     except Exception as exc: # pragma: no cover - surfaced in UI
+         logger.error("Failed to read ROUGE report: %s", exc, exc_info=True)
+         return empty, {"error": f"Unable to parse report: {exc}", "report_path": str(ROUGE_REPORT_PATH)}
+
+     rows: list[dict[str, object]] = []
+     for metric_name, components in report.get("metrics", {}).items():
+         rows.append(
+             {
+                 "metric": metric_name,
+                 "precision": round(float(components.get("precision", 0.0)), 4),
+                 "recall": round(float(components.get("recall", 0.0)), 4),
+                 "fmeasure": round(float(components.get("fmeasure", 0.0)), 4),
+             }
+         )
+
+     table = pd.DataFrame(rows, columns=columns) if rows else empty
+     metadata = {
+         "num_examples": report.get("num_examples"),
+         "config": report.get("config"),
+         "report_path": str(ROUGE_REPORT_PATH),
+     }
+     return table, metadata
+
+
  SAMPLE_TEXT = (
      "Artificial intelligence is rapidly transforming the technology landscape. "
      "Machine learning algorithms are now capable of processing vast amounts of data, "
@@ -380,6 +415,8 @@ def create_interface() -> gr.Blocks:
      """
  )
 
+ initial_metrics, initial_metrics_meta = load_rouge_metrics()
+
  with gr.Row():
      with gr.Column(scale=1):
          input_text = gr.Textbox(
@@ -417,11 +454,25 @@
      columns=2,
      height=400,
      interactive=False,
+     type="filepath"
  )
  gr.Markdown(
      "These PNGs come from the visualization-focused tests in `tests/test_models` and are consumed as-is."
  )
  refresh_visuals = gr.Button("Refresh Visuals")
+ with gr.TabItem("Metrics"):
+     rouge_table = gr.Dataframe(
+         value=initial_metrics,
+         headers=["metric", "precision", "recall", "fmeasure"],
+         datatype=["str", "number", "number", "number"],
+         interactive=False,
+         label="ROUGE Scores",
+     )
+     rouge_meta = gr.JSON(
+         value=initial_metrics_meta,
+         label="ROUGE Run Metadata",
+     )
+     refresh_metrics = gr.Button("Refresh Metrics")
  gr.Markdown("### Download Results")
  download_btn = gr.DownloadButton("Download JSON", visible=False)
 
@@ -432,6 +483,7 @@
      outputs=[summary_output, emotion_output, topic_output, attention_output, download_btn],
  )
  refresh_visuals.click(fn=load_visualization_gallery, inputs=None, outputs=visuals)
+ refresh_metrics.click(fn=load_rouge_metrics, inputs=None, outputs=[rouge_table, rouge_meta])
  return demo
 
 
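For reference, a minimal sketch of the report shape `load_rouge_metrics()` expects at `outputs/rouge_validation.json`; the structure mirrors the report written by `scripts/eval_rouge.py` below, and the numbers are placeholders:

```python
# Placeholder report illustrating the keys load_rouge_metrics() reads:
# "metrics" -> per-metric precision/recall/fmeasure, plus "num_examples" and "config".
import json
from pathlib import Path

sample_report = {
    "num_examples": 200,
    "metrics": {
        "rouge1": {"precision": 0.41, "recall": 0.38, "fmeasure": 0.39},
        "rougeL": {"precision": 0.33, "recall": 0.30, "fmeasure": 0.31},
    },
    "config": {"checkpoint": "checkpoints/best.pt", "max_length": 128},
}

out_path = Path("outputs") / "rouge_validation.json"
out_path.parent.mkdir(parents=True, exist_ok=True)
out_path.write_text(json.dumps(sample_report, indent=2), encoding="utf-8")
```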
scripts/eval_rouge.py ADDED
@@ -0,0 +1,183 @@
+ """Utility script to evaluate LexiMind summaries with ROUGE."""
+ from __future__ import annotations
+
+ import argparse
+ import json
+ from collections import defaultdict
+ from pathlib import Path
+ from statistics import fmean
+ from typing import Dict, Iterable, List, Sequence, Tuple
+
+ import sys
+
+ from rouge_score import rouge_scorer
+ from tqdm import tqdm
+
+ PROJECT_ROOT = Path(__file__).resolve().parent.parent
+ if str(PROJECT_ROOT) not in sys.path:
+     sys.path.insert(0, str(PROJECT_ROOT))
+
+ from src.inference.factory import create_inference_pipeline
+
+
+ def parse_args() -> argparse.Namespace:
+     parser = argparse.ArgumentParser(description="Evaluate LexiMind summaries with ROUGE metrics.")
+     parser.add_argument("data", type=Path, help="Path to JSONL file with source text and gold summaries.")
+     parser.add_argument("checkpoint", type=Path, help="Path to the trained checkpoint (e.g., checkpoints/best.pt).")
+     parser.add_argument("labels", type=Path, help="Path to label metadata (e.g., artifacts/labels.json).")
+     parser.add_argument(
+         "--tokenizer-dir",
+         type=Path,
+         default=Path("artifacts/hf_tokenizer"),
+         help="Directory containing the saved tokenizer artifacts.",
+     )
+     parser.add_argument(
+         "--model-config",
+         type=Path,
+         default=None,
+         help="Optional YAML config describing the model architecture.",
+     )
+     parser.add_argument("--device", type=str, default="cpu", help="Device to run inference on (cpu or cuda).")
+     parser.add_argument("--batch-size", type=int, default=8, help="Number of samples per inference batch.")
+     parser.add_argument(
+         "--max-samples",
+         type=int,
+         default=None,
+         help="If provided, limit evaluation to the first N samples for quick smoke tests.",
+     )
+     parser.add_argument(
+         "--max-length",
+         type=int,
+         default=128,
+         help="Maximum length to pass into the summarization head during generation.",
+     )
+     parser.add_argument(
+         "--metrics",
+         type=str,
+         nargs="+",
+         default=("rouge1", "rouge2", "rougeL"),
+         help="ROUGE metrics to compute.",
+     )
+     parser.add_argument(
+         "--source-field",
+         type=str,
+         default="source",
+         help="Field name containing the input document in the JSONL examples.",
+     )
+     parser.add_argument(
+         "--target-field",
+         type=str,
+         default="summary",
+         help="Field name containing the reference summary in the JSONL examples.",
+     )
+     parser.add_argument(
+         "--no-stemmer",
+         action="store_true",
+         help="Disable Porter stemming inside the ROUGE scorer (defaults to enabled).",
+     )
+     parser.add_argument(
+         "--output",
+         type=Path,
+         default=None,
+         help="Optional path to save a JSON report with aggregate metrics and sample counts.",
+     )
+     return parser.parse_args()
+
+
+ def load_examples(
+     path: Path,
+     source_field: str,
+     target_field: str,
+     max_samples: int | None,
+ ) -> List[Tuple[str, str]]:
+     examples: List[Tuple[str, str]] = []
+     with path.open("r", encoding="utf-8") as handle:
+         for line in handle:
+             line = line.strip()
+             if not line:
+                 continue
+             record = json.loads(line)
+             try:
+                 source = str(record[source_field])
+                 target = str(record[target_field])
+             except KeyError as exc: # pragma: no cover - invalid data surface at runtime
+                 raise KeyError(f"Missing field in record: {exc} (available keys: {list(record)})") from exc
+             examples.append((source, target))
+             if max_samples is not None and len(examples) >= max_samples:
+                 break
+     if not examples:
+         raise ValueError(f"No examples loaded from {path}")
+     return examples
+
+
+ def batched(items: Sequence[Tuple[str, str]], batch_size: int) -> Iterable[Sequence[Tuple[str, str]]]:
+     for start in range(0, len(items), batch_size):
+         yield items[start : start + batch_size]
+
+
+ def aggregate_scores(raw_scores: Dict[str, Dict[str, List[float]]]) -> Dict[str, Dict[str, float]]:
+     aggregated: Dict[str, Dict[str, float]] = {}
+     for metric, components in raw_scores.items():
+         aggregated[metric] = {
+             component: (fmean(values) if values else 0.0) for component, values in components.items()
+         }
+     return aggregated
+
+
+ def main() -> None:
+     args = parse_args()
+
+     pipeline, _ = create_inference_pipeline(
+         checkpoint_path=args.checkpoint,
+         labels_path=args.labels,
+         tokenizer_dir=args.tokenizer_dir,
+         model_config_path=args.model_config,
+         device=args.device,
+         summary_max_length=args.max_length,
+     )
+
+     examples = load_examples(args.data, args.source_field, args.target_field, args.max_samples)
+     scorer = rouge_scorer.RougeScorer(list(args.metrics), use_stemmer=not args.no_stemmer)
+
+     score_store: Dict[str, Dict[str, List[float]]] = defaultdict(lambda: defaultdict(list))
+
+     for batch in tqdm(
+         list(batched(examples, args.batch_size)),
+         desc="Evaluating",
+         total=(len(examples) + args.batch_size - 1) // args.batch_size,
+     ):
+         documents = [item[0] for item in batch]
+         references = [item[1] for item in batch]
+         predictions = pipeline.summarize(documents, max_length=args.max_length)
+
+         for reference, prediction in zip(references, predictions):
+             scores = scorer.score(reference, prediction)
+             for metric_name, score in scores.items():
+                 score_store[metric_name]["precision"].append(score.precision)
+                 score_store[metric_name]["recall"].append(score.recall)
+                 score_store[metric_name]["fmeasure"].append(score.fmeasure)
+
+     aggregated = aggregate_scores(score_store)
+     report = {
+         "num_examples": len(examples),
+         "metrics": aggregated,
+         "config": {
+             "data": str(args.data),
+             "checkpoint": str(args.checkpoint),
+             "tokenizer_dir": str(args.tokenizer_dir),
+             "metrics": list(args.metrics),
+             "max_length": args.max_length,
+             "batch_size": args.batch_size,
+             "device": args.device,
+         },
+     }
+
+     print(json.dumps(report, indent=2))
+     if args.output:
+         args.output.parent.mkdir(parents=True, exist_ok=True)
+         with args.output.open("w", encoding="utf-8") as handle:
+             json.dump(report, handle, ensure_ascii=False, indent=2)
+
+
+ if __name__ == "__main__":
+     main()
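To tie the pieces together, a hedged sketch of one way to exercise the new script end to end: the JSONL field names follow the script's `--source-field`/`--target-field` defaults, while the data path and document strings below are illustrative only.

```python
# Illustrative smoke test: write a tiny JSONL set in the format load_examples() expects,
# then point scripts/eval_rouge.py at it and at the report path the Gradio demo reads.
import json
from pathlib import Path

rows = [
    {"source": "Long input document text ...", "summary": "Reference summary ..."},
    {"source": "Another document ...", "summary": "Another reference ..."},
]
eval_path = Path("data/rouge_smoke.jsonl")  # hypothetical location
eval_path.parent.mkdir(parents=True, exist_ok=True)
with eval_path.open("w", encoding="utf-8") as handle:
    for row in rows:
        handle.write(json.dumps(row) + "\n")

# Then run, for example:
#   python scripts/eval_rouge.py data/rouge_smoke.jsonl checkpoints/best.pt artifacts/labels.json \
#       --max-samples 2 --output outputs/rouge_validation.json
```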