Spaces:
Running
Running
OliverPerrin
commited on
Commit
·
18fc263
1
Parent(s):
5fde4fb
Simplify Gradio demo to fix schema bug
Browse files- scripts/demo_gradio.py +79 -526
scripts/demo_gradio.py
CHANGED
|
@@ -1,587 +1,140 @@
|
|
| 1 |
-
"""
|
| 2 |
-
Minimal Gradio demo for the LexiMind multitask model.
|
| 3 |
-
Shows raw model outputs without any post-processing tricks.
|
| 4 |
-
"""
|
| 5 |
|
| 6 |
from __future__ import annotations
|
| 7 |
|
| 8 |
import json
|
| 9 |
-
import re
|
| 10 |
import sys
|
| 11 |
-
from datetime import datetime
|
| 12 |
from pathlib import Path
|
| 13 |
-
from tempfile import NamedTemporaryFile
|
| 14 |
-
from typing import Iterable, Sequence
|
| 15 |
|
| 16 |
import gradio as gr
|
| 17 |
-
import matplotlib.pyplot as plt
|
| 18 |
-
import pandas as pd
|
| 19 |
-
import seaborn as sns
|
| 20 |
-
import torch
|
| 21 |
-
from gradio.themes import Soft
|
| 22 |
-
from matplotlib.figure import Figure
|
| 23 |
|
| 24 |
-
# Make local packages importable when running the script directly
|
| 25 |
SCRIPT_DIR = Path(__file__).resolve().parent
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
def guess_project_root(script_dir: Path) -> Path:
|
| 29 |
-
"""Attempt to locate the LexiMind repo root even when deployed under /app."""
|
| 30 |
-
markers = ("pyproject.toml", "setup.py", "README.md")
|
| 31 |
-
candidates = [script_dir]
|
| 32 |
-
candidates.extend(script_dir.parents)
|
| 33 |
-
candidates.extend(
|
| 34 |
-
[
|
| 35 |
-
script_dir / "LexiMind",
|
| 36 |
-
script_dir.parent / "LexiMind",
|
| 37 |
-
Path("/LexiMind"),
|
| 38 |
-
]
|
| 39 |
-
)
|
| 40 |
-
|
| 41 |
-
seen: set[Path] = set()
|
| 42 |
-
for candidate in candidates:
|
| 43 |
-
if candidate in seen:
|
| 44 |
-
continue
|
| 45 |
-
seen.add(candidate)
|
| 46 |
-
if any((candidate / marker).exists() for marker in markers):
|
| 47 |
-
return candidate
|
| 48 |
-
|
| 49 |
-
return script_dir.parent
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
PROJECT_ROOT = guess_project_root(SCRIPT_DIR)
|
| 53 |
if str(PROJECT_ROOT) not in sys.path:
|
| 54 |
sys.path.insert(0, str(PROJECT_ROOT))
|
| 55 |
|
| 56 |
-
OUTPUTS_DIR = PROJECT_ROOT / "outputs"
|
| 57 |
-
EVAL_REPORT_PATH = OUTPUTS_DIR / "evaluation_report.json"
|
| 58 |
-
CONFUSION_MATRIX_PATH = OUTPUTS_DIR / "topic_confusion_matrix.png"
|
| 59 |
-
|
| 60 |
from huggingface_hub import hf_hub_download
|
| 61 |
|
| 62 |
from src.inference.factory import create_inference_pipeline
|
| 63 |
-
from src.inference.pipeline import EmotionPrediction, InferencePipeline, TopicPrediction
|
| 64 |
from src.utils.logging import configure_logging, get_logger
|
| 65 |
|
| 66 |
configure_logging()
|
| 67 |
logger = get_logger(__name__)
|
| 68 |
|
| 69 |
-
|
|
|
|
| 70 |
|
| 71 |
-
|
| 72 |
-
VISUALIZATION_ASSETS: list[tuple[str, str]] = [
|
| 73 |
-
("attention_visualization.png", "Attention weights (single head)"),
|
| 74 |
-
("multihead_attention_visualization.png", "Multi-head attention comparison"),
|
| 75 |
-
("single_vs_multihead.png", "Single vs multi-head attention"),
|
| 76 |
-
("positional_encoding_heatmap.png", "Positional encoding heatmap"),
|
| 77 |
-
]
|
| 78 |
|
| 79 |
|
| 80 |
-
def get_pipeline()
|
| 81 |
global _pipeline
|
| 82 |
if _pipeline is None:
|
| 83 |
-
logger.info("Loading inference pipeline ...")
|
| 84 |
-
|
| 85 |
-
# Download checkpoint if not found locally
|
| 86 |
checkpoint_path = Path("checkpoints/best.pt")
|
| 87 |
if not checkpoint_path.exists():
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
downloaded_path = hf_hub_download(
|
| 96 |
-
repo_id="OliverPerrin/LexiMind-Model",
|
| 97 |
-
filename="best.pt",
|
| 98 |
-
local_dir="checkpoints",
|
| 99 |
-
local_dir_use_symlinks=False,
|
| 100 |
-
)
|
| 101 |
-
logger.info(f"Checkpoint downloaded to {downloaded_path}")
|
| 102 |
-
except Exception as e:
|
| 103 |
-
logger.error(f"Failed to download checkpoint: {e}")
|
| 104 |
-
# Fallback or re-raise will happen in create_inference_pipeline
|
| 105 |
-
|
| 106 |
_pipeline, _ = create_inference_pipeline(
|
| 107 |
tokenizer_dir="artifacts/hf_tokenizer/",
|
| 108 |
checkpoint_path="checkpoints/best.pt",
|
| 109 |
labels_path="artifacts/labels.json",
|
| 110 |
)
|
| 111 |
-
logger.info("Pipeline loaded")
|
| 112 |
return _pipeline
|
| 113 |
|
| 114 |
|
| 115 |
-
def
|
| 116 |
-
|
| 117 |
-
return "Tokens: 0"
|
| 118 |
-
try:
|
| 119 |
-
pipeline = get_pipeline()
|
| 120 |
-
return f"Tokens: {len(pipeline.tokenizer.encode(text))}"
|
| 121 |
-
except Exception as exc: # pragma: no cover - surfaced in UI
|
| 122 |
-
logger.error("Token counting failed: %s", exc, exc_info=True)
|
| 123 |
-
return "Token count unavailable"
|
| 124 |
-
|
| 125 |
-
|
| 126 |
-
def predict(text: str):
|
| 127 |
if not text or not text.strip():
|
| 128 |
-
return
|
| 129 |
-
"Please enter text to analyze.",
|
| 130 |
-
None,
|
| 131 |
-
"No topic prediction available.",
|
| 132 |
-
None,
|
| 133 |
-
)
|
| 134 |
|
| 135 |
try:
|
| 136 |
-
|
| 137 |
-
# Fixed max length for simplicity
|
| 138 |
-
max_len = 128
|
| 139 |
-
logger.info("Generating summary with max length %s", max_len)
|
| 140 |
|
| 141 |
-
|
| 142 |
-
|
| 143 |
-
emotions = pipeline.predict_emotions([text], threshold=0.6)[0]
|
| 144 |
-
topic = pipeline.predict_topics([text])[0]
|
| 145 |
|
| 146 |
-
|
| 147 |
-
|
| 148 |
-
|
| 149 |
-
|
| 150 |
-
|
| 151 |
-
|
| 152 |
-
summary_notice = (
|
| 153 |
-
'<p style="color: #b45309; margin-top: 8px;">'
|
| 154 |
-
"Model returned an empty summary, so a simple extractive fallback is shown instead."
|
| 155 |
-
"</p>"
|
| 156 |
)
|
| 157 |
-
|
| 158 |
-
summary_html = format_summary(text, summary_source, notice=summary_notice)
|
| 159 |
-
emotion_plot = create_emotion_plot(emotions)
|
| 160 |
-
topic_markdown = format_topic(topic)
|
| 161 |
-
heatmap_source = summary if summary else fallback_summary
|
| 162 |
-
if heatmap_source:
|
| 163 |
-
attention_fig = create_attention_heatmap(text, heatmap_source, pipeline)
|
| 164 |
else:
|
| 165 |
-
|
| 166 |
-
"Attention heatmap unavailable: summary was empty."
|
| 167 |
-
)
|
| 168 |
-
|
| 169 |
-
return summary_html, emotion_plot, topic_markdown, attention_fig
|
| 170 |
-
|
| 171 |
-
except Exception as exc: # pragma: no cover - surfaced in UI
|
| 172 |
-
logger.error("Prediction error: %s", exc, exc_info=True)
|
| 173 |
-
return "Prediction failed. Check logs for details.", None, "Error", None
|
| 174 |
-
|
| 175 |
-
|
| 176 |
-
def format_summary(original: str, summary: str, *, notice: str = "") -> str:
|
| 177 |
-
if not summary:
|
| 178 |
-
summary = "(Model returned an empty summary. Consider retraining the summarization head.)"
|
| 179 |
-
|
| 180 |
-
return f"""
|
| 181 |
-
<div style=\"padding: 12px; border-radius: 6px; background-color: #fafafa; color: #222;\">
|
| 182 |
-
<h3 style=\"margin-top: 0; color: #222;\">Original Text</h3>
|
| 183 |
-
<p style=\"background-color: #f0f0f0; padding: 10px; border-radius: 4px; white-space: pre-wrap; color: #222;\">
|
| 184 |
-
{original}
|
| 185 |
-
</p>
|
| 186 |
-
<h3 style=\"color: #222;\">Summary</h3>
|
| 187 |
-
<p style=\"background-color: #e6f3ff; padding: 10px; border-radius: 4px; white-space: pre-wrap; color: #111;\">
|
| 188 |
-
{summary}
|
| 189 |
-
</p>
|
| 190 |
-
{notice}
|
| 191 |
-
<p style=\"margin-top: 12px; color: #6b7280; font-size: 0.9rem;\">
|
| 192 |
-
Outputs are shown exactly as produced by the checkpoint.
|
| 193 |
-
</p>
|
| 194 |
-
</div>
|
| 195 |
-
""".strip()
|
| 196 |
-
|
| 197 |
-
|
| 198 |
-
def create_emotion_plot(emotions: EmotionPrediction) -> Figure | None:
|
| 199 |
-
if not emotions.labels:
|
| 200 |
-
return render_message_figure("No emotions cleared the model threshold.")
|
| 201 |
-
|
| 202 |
-
df = pd.DataFrame({"Emotion": emotions.labels, "Probability": emotions.scores}).sort_values(
|
| 203 |
-
"Probability", ascending=True
|
| 204 |
-
)
|
| 205 |
-
fig, ax = plt.subplots(figsize=(6, 4))
|
| 206 |
-
colors = sns.color_palette("crest", len(df))
|
| 207 |
-
bars = ax.barh(df["Emotion"], df["Probability"], color=colors)
|
| 208 |
-
ax.set_xlabel("Probability")
|
| 209 |
-
ax.set_title("Emotion Scores")
|
| 210 |
-
ax.set_xlim(0, 1)
|
| 211 |
-
for bar in bars:
|
| 212 |
-
width = bar.get_width()
|
| 213 |
-
ax.text(width + 0.02, bar.get_y() + bar.get_height() / 2, f"{width:.2%}", va="center")
|
| 214 |
-
plt.tight_layout()
|
| 215 |
-
return fig
|
| 216 |
-
|
| 217 |
-
|
| 218 |
-
def format_topic(topic: TopicPrediction | dict[str, float | str]) -> str:
|
| 219 |
-
if isinstance(topic, TopicPrediction):
|
| 220 |
-
label = topic.label
|
| 221 |
-
confidence = topic.confidence
|
| 222 |
-
else:
|
| 223 |
-
label = str(topic.get("label", "Unknown"))
|
| 224 |
-
confidence = float(topic.get("score", 0.0))
|
| 225 |
-
return f"""
|
| 226 |
-
### Predicted Topic
|
| 227 |
|
| 228 |
-
|
|
|
|
|
|
|
| 229 |
|
| 230 |
-
|
| 231 |
-
|
| 232 |
-
|
| 233 |
-
|
| 234 |
-
def _clean_tokens(tokens: Iterable[str]) -> list[str]:
|
| 235 |
-
cleaned: list[str] = []
|
| 236 |
-
for token in tokens:
|
| 237 |
-
item = token.replace("Ġ", " ").replace("▁", " ")
|
| 238 |
-
cleaned.append(item.strip() if item.strip() else token)
|
| 239 |
-
return cleaned
|
| 240 |
-
|
| 241 |
-
|
| 242 |
-
def create_attention_heatmap(text: str, summary: str, pipeline: InferencePipeline) -> Figure | None:
|
| 243 |
-
try:
|
| 244 |
-
batch = pipeline.preprocessor.batch_encode([text])
|
| 245 |
-
batch = pipeline._batch_to_device(batch)
|
| 246 |
-
src_ids = batch.input_ids
|
| 247 |
-
src_mask = batch.attention_mask
|
| 248 |
-
encoder_mask = (
|
| 249 |
-
src_mask.unsqueeze(1) & src_mask.unsqueeze(2) if src_mask is not None else None
|
| 250 |
-
)
|
| 251 |
-
|
| 252 |
-
with torch.inference_mode():
|
| 253 |
-
memory = pipeline.model.encoder(src_ids, mask=encoder_mask) # type: ignore
|
| 254 |
-
target_enc = pipeline.tokenizer.batch_encode([summary])
|
| 255 |
-
target_ids = target_enc["input_ids"].to(pipeline.device)
|
| 256 |
-
target_mask = target_enc["attention_mask"].to(pipeline.device)
|
| 257 |
-
target_len = int(target_mask.sum().item())
|
| 258 |
-
decoder_inputs = pipeline.tokenizer.prepare_decoder_inputs(target_ids)
|
| 259 |
-
decoder_inputs = decoder_inputs[:, :target_len].to(pipeline.device)
|
| 260 |
-
target_ids = target_ids[:, :target_len]
|
| 261 |
-
memory_mask = src_mask.to(pipeline.device) if src_mask is not None else None
|
| 262 |
-
_, attn_list = pipeline.model.decoder( # type: ignore
|
| 263 |
-
decoder_inputs,
|
| 264 |
-
memory,
|
| 265 |
-
memory_mask=memory_mask,
|
| 266 |
-
collect_attn=True,
|
| 267 |
-
)
|
| 268 |
-
|
| 269 |
-
if not attn_list:
|
| 270 |
-
return None
|
| 271 |
-
cross_attn = attn_list[-1]["cross"]
|
| 272 |
-
attn_matrix = cross_attn.mean(dim=1)[0].detach().cpu().numpy()
|
| 273 |
-
source_len = batch.lengths[0]
|
| 274 |
-
attn_matrix = attn_matrix[:target_len, :source_len]
|
| 275 |
-
|
| 276 |
-
source_ids = src_ids[0, :source_len].tolist()
|
| 277 |
-
target_id_list = target_ids[0].tolist()
|
| 278 |
-
|
| 279 |
-
special_ids = {
|
| 280 |
-
pipeline.tokenizer.pad_token_id,
|
| 281 |
-
pipeline.tokenizer.bos_token_id,
|
| 282 |
-
pipeline.tokenizer.eos_token_id,
|
| 283 |
-
}
|
| 284 |
-
keep_indices = [
|
| 285 |
-
idx for idx, token_id in enumerate(target_id_list) if token_id not in special_ids
|
| 286 |
-
]
|
| 287 |
-
if not keep_indices:
|
| 288 |
-
return None
|
| 289 |
|
| 290 |
-
|
| 291 |
-
|
| 292 |
-
convert_tokens = getattr(tokenizer_impl, "convert_ids_to_tokens", None)
|
| 293 |
-
if convert_tokens is None:
|
| 294 |
-
return None
|
| 295 |
|
| 296 |
-
|
| 297 |
-
|
| 298 |
-
|
| 299 |
-
|
| 300 |
-
|
| 301 |
-
|
| 302 |
-
# Cap the visualization to prevent massive heatmaps
|
| 303 |
-
max_tokens = 40
|
| 304 |
-
if len(summary_tokens) > max_tokens:
|
| 305 |
-
summary_tokens = summary_tokens[:max_tokens]
|
| 306 |
-
pruned_matrix = pruned_matrix[:max_tokens, :]
|
| 307 |
-
if len(source_tokens) > max_tokens:
|
| 308 |
-
source_tokens = source_tokens[:max_tokens]
|
| 309 |
-
pruned_matrix = pruned_matrix[:, :max_tokens]
|
| 310 |
-
|
| 311 |
-
height = max(4.0, 0.3 * len(summary_tokens))
|
| 312 |
-
width = max(6.0, 0.3 * len(source_tokens))
|
| 313 |
-
fig, ax = plt.subplots(figsize=(width, height))
|
| 314 |
-
sns.heatmap(
|
| 315 |
-
pruned_matrix,
|
| 316 |
-
cmap="mako",
|
| 317 |
-
xticklabels=source_tokens,
|
| 318 |
-
yticklabels=summary_tokens,
|
| 319 |
-
ax=ax,
|
| 320 |
-
cbar_kws={"label": "Attention"},
|
| 321 |
-
)
|
| 322 |
-
ax.set_xlabel("Input Tokens")
|
| 323 |
-
ax.set_ylabel("Summary Tokens")
|
| 324 |
-
ax.set_title("Cross-Attention (decoder last layer)")
|
| 325 |
-
ax.tick_params(axis="x", rotation=90)
|
| 326 |
-
ax.tick_params(axis="y", rotation=0)
|
| 327 |
-
fig.tight_layout()
|
| 328 |
-
return fig
|
| 329 |
-
|
| 330 |
-
except Exception as exc:
|
| 331 |
-
logger.error("Unable to build attention heatmap: %s", exc, exc_info=True)
|
| 332 |
-
return render_message_figure("Unable to render attention heatmap for this example.")
|
| 333 |
-
|
| 334 |
-
|
| 335 |
-
def render_message_figure(message: str) -> Figure:
|
| 336 |
-
fig, ax = plt.subplots(figsize=(6, 2))
|
| 337 |
-
ax.axis("off")
|
| 338 |
-
ax.text(0.5, 0.5, message, ha="center", va="center", wrap=True)
|
| 339 |
-
fig.tight_layout()
|
| 340 |
-
return fig
|
| 341 |
-
|
| 342 |
-
|
| 343 |
-
def prepare_download(
|
| 344 |
-
text: str,
|
| 345 |
-
summary: str,
|
| 346 |
-
emotions: EmotionPrediction | dict[str, Sequence[float] | Sequence[str]],
|
| 347 |
-
topic: TopicPrediction | dict[str, float | str],
|
| 348 |
-
*,
|
| 349 |
-
neural_summary: str | None = None,
|
| 350 |
-
fallback_summary: str | None = None,
|
| 351 |
-
) -> str:
|
| 352 |
-
if isinstance(emotions, EmotionPrediction):
|
| 353 |
-
emotion_payload = {
|
| 354 |
-
"labels": list(emotions.labels),
|
| 355 |
-
"scores": list(emotions.scores),
|
| 356 |
-
}
|
| 357 |
-
else:
|
| 358 |
-
emotion_payload = {
|
| 359 |
-
"labels": list(emotions.get("labels", [])),
|
| 360 |
-
"scores": list(emotions.get("scores", [])),
|
| 361 |
-
}
|
| 362 |
-
|
| 363 |
-
if isinstance(topic, TopicPrediction):
|
| 364 |
-
topic_payload = {"label": topic.label, "confidence": topic.confidence}
|
| 365 |
-
else:
|
| 366 |
-
topic_payload = {
|
| 367 |
-
"label": str(topic.get("label", topic.get("topic", "Unknown"))),
|
| 368 |
-
"confidence": float(topic.get("confidence", topic.get("score", 0.0))),
|
| 369 |
-
}
|
| 370 |
-
|
| 371 |
-
payload = {
|
| 372 |
-
"original_text": text,
|
| 373 |
-
"summary": summary,
|
| 374 |
-
"neural_summary": neural_summary,
|
| 375 |
-
"fallback_summary": fallback_summary,
|
| 376 |
-
"emotions": emotion_payload,
|
| 377 |
-
"topic": topic_payload,
|
| 378 |
-
}
|
| 379 |
-
|
| 380 |
-
with NamedTemporaryFile("w", delete=False, suffix=".json", encoding="utf-8") as handle:
|
| 381 |
-
json.dump(payload, handle, ensure_ascii=False, indent=2)
|
| 382 |
-
return handle.name
|
| 383 |
-
|
| 384 |
-
|
| 385 |
-
def load_visualization_gallery() -> tuple[list[str], str]:
|
| 386 |
-
"""Collect visualization images produced by model tests."""
|
| 387 |
-
items: list[str] = []
|
| 388 |
-
missing: list[str] = []
|
| 389 |
-
for filename, _label in VISUALIZATION_ASSETS:
|
| 390 |
-
path = VISUALIZATION_DIR / filename
|
| 391 |
-
if path.exists():
|
| 392 |
-
items.append(str(path))
|
| 393 |
-
else:
|
| 394 |
-
missing.append(filename)
|
| 395 |
-
|
| 396 |
-
if items:
|
| 397 |
-
status = f"Loaded {len(items)} visualization(s) from {VISUALIZATION_DIR}."
|
| 398 |
-
else:
|
| 399 |
-
status = (
|
| 400 |
-
"No visualization PNGs found in the outputs/ directory. "
|
| 401 |
-
"Ensure tests/test_models/* have produced the PNGs and that they are available on this host."
|
| 402 |
-
)
|
| 403 |
-
|
| 404 |
-
if missing:
|
| 405 |
-
status += f" Missing files: {', '.join(missing)}."
|
| 406 |
-
|
| 407 |
-
return items, status
|
| 408 |
-
|
| 409 |
-
|
| 410 |
-
def generate_fallback_summary(text: str, max_chars: int = 320) -> str:
|
| 411 |
-
content = text.strip()
|
| 412 |
-
if not content:
|
| 413 |
-
return "(Input text was empty.)"
|
| 414 |
-
|
| 415 |
-
sentences = re.split(r"(?<=[.!?])\s+", content)
|
| 416 |
-
fragments: list[str] = []
|
| 417 |
-
total = 0
|
| 418 |
-
for sentence in sentences:
|
| 419 |
-
if not sentence:
|
| 420 |
-
continue
|
| 421 |
-
candidate = sentence if sentence.endswith((".", "!", "?")) else f"{sentence}."
|
| 422 |
-
if total + len(candidate) > max_chars and fragments:
|
| 423 |
-
break
|
| 424 |
-
fragments.append(candidate)
|
| 425 |
-
total += len(candidate)
|
| 426 |
-
|
| 427 |
-
if not fragments:
|
| 428 |
-
return content[:max_chars]
|
| 429 |
-
return " ".join(fragments)
|
| 430 |
|
| 431 |
|
| 432 |
-
def
|
| 433 |
-
"""Load metrics
|
| 434 |
if not EVAL_REPORT_PATH.exists():
|
| 435 |
-
|
| 436 |
-
f"Evaluation report not found at {EVAL_REPORT_PATH}. Run scripts/evaluate.py first."
|
| 437 |
-
)
|
| 438 |
-
return error_msg, "", None, error_msg
|
| 439 |
|
| 440 |
try:
|
| 441 |
-
with
|
| 442 |
-
|
| 443 |
-
|
| 444 |
-
|
| 445 |
-
|
| 446 |
-
|
| 447 |
-
|
| 448 |
-
|
| 449 |
-
|
| 450 |
-
|
| 451 |
-
|
| 452 |
-
|
| 453 |
-
|
| 454 |
-
|
| 455 |
-
|
| 456 |
-
|
| 457 |
-
|
| 458 |
-
|
| 459 |
-
|
| 460 |
-
|
| 461 |
-
# Build topic classification report markdown table
|
| 462 |
-
topic_report = report["topic"]["classification_report"]
|
| 463 |
-
topic_lines = [
|
| 464 |
-
"| Label | Precision | Recall | F1-Score | Support |",
|
| 465 |
-
"|-------|-----------|--------|----------|---------|",
|
| 466 |
-
]
|
| 467 |
-
for label, metrics in topic_report.items():
|
| 468 |
-
if isinstance(metrics, dict) and "precision" in metrics:
|
| 469 |
-
topic_lines.append(
|
| 470 |
-
f"| {label} | {metrics['precision']:.4f} | {metrics['recall']:.4f} | {metrics['f1-score']:.4f} | {int(metrics.get('support', 0))} |"
|
| 471 |
-
)
|
| 472 |
-
topic_md = "\n".join(topic_lines)
|
| 473 |
-
|
| 474 |
-
# Confusion Matrix
|
| 475 |
-
cm_image = str(CONFUSION_MATRIX_PATH) if CONFUSION_MATRIX_PATH.exists() else None
|
| 476 |
-
|
| 477 |
-
# Metadata
|
| 478 |
-
meta_str = f"Split: {report.get('split', 'unknown')}\nLast updated: {datetime.fromtimestamp(EVAL_REPORT_PATH.stat().st_mtime).isoformat()}"
|
| 479 |
-
|
| 480 |
-
return summary_md, topic_md, cm_image, meta_str
|
| 481 |
-
|
| 482 |
-
|
| 483 |
-
SAMPLE_TEXT = (
|
| 484 |
-
"Artificial intelligence is rapidly transforming the technology landscape. "
|
| 485 |
-
"Machine learning algorithms are now capable of processing vast amounts of data, "
|
| 486 |
-
"identifying patterns, and making predictions with unprecedented accuracy. "
|
| 487 |
-
"From healthcare diagnostics to financial forecasting, AI applications are "
|
| 488 |
-
"revolutionizing industries worldwide. However, ethical considerations around "
|
| 489 |
-
"privacy, bias, and transparency remain critical challenges that must be addressed "
|
| 490 |
-
"as these technologies continue to evolve."
|
| 491 |
-
)
|
| 492 |
-
|
| 493 |
-
|
| 494 |
-
def create_interface() -> gr.Blocks:
|
| 495 |
-
with gr.Blocks(title="LexiMind Demo", theme=Soft()) as demo:
|
| 496 |
-
gr.Markdown(
|
| 497 |
-
"""
|
| 498 |
-
# LexiMind NLP Demo
|
| 499 |
-
|
| 500 |
-
This demo streams the raw outputs from the saved LexiMind checkpoint.
|
| 501 |
-
Results may be noisy; retraining is recommended for production use.
|
| 502 |
-
"""
|
| 503 |
-
)
|
| 504 |
-
|
| 505 |
-
_, initial_visual_status = load_visualization_gallery()
|
| 506 |
-
summary_md, topic_md, cm_image, metrics_meta = load_metrics_report_as_markdown()
|
| 507 |
-
|
| 508 |
-
with gr.Row():
|
| 509 |
-
with gr.Column(scale=1):
|
| 510 |
-
input_text = gr.Textbox(
|
| 511 |
-
label="Input Text",
|
| 512 |
-
lines=10,
|
| 513 |
-
value=SAMPLE_TEXT,
|
| 514 |
-
placeholder="Paste or type your text here...",
|
| 515 |
)
|
| 516 |
-
token_box = gr.Textbox(label="Token Count", value="Tokens: 0", interactive=False)
|
| 517 |
-
analyze_btn = gr.Button("Run Analysis", variant="primary")
|
| 518 |
|
| 519 |
-
|
| 520 |
-
|
| 521 |
-
|
| 522 |
-
summary_output = gr.HTML(label="Summary")
|
| 523 |
-
with gr.TabItem("Emotions"):
|
| 524 |
-
emotion_output = gr.Plot(label="Emotion Probabilities")
|
| 525 |
-
with gr.TabItem("Topic"):
|
| 526 |
-
topic_output = gr.Markdown(label="Topic Prediction")
|
| 527 |
-
with gr.TabItem("Attention"):
|
| 528 |
-
attention_output = gr.Plot(label="Attention Heatmap")
|
| 529 |
-
gr.Markdown("*Shows decoder attention if a summary is available.*")
|
| 530 |
-
with gr.TabItem("Model Performance"):
|
| 531 |
-
gr.Markdown("### Overall Metrics")
|
| 532 |
-
metrics_table = gr.Markdown(value=summary_md)
|
| 533 |
-
gr.Markdown("### Topic Classification Report")
|
| 534 |
-
topic_table = gr.Markdown(value=topic_md)
|
| 535 |
-
gr.Markdown("### Topic Confusion Matrix")
|
| 536 |
-
cm_output = gr.Image(value=cm_image, label="Confusion Matrix")
|
| 537 |
-
gr.Markdown("### Metadata")
|
| 538 |
-
metrics_meta_text = gr.Textbox(
|
| 539 |
-
value=metrics_meta,
|
| 540 |
-
label="Info",
|
| 541 |
-
interactive=False,
|
| 542 |
-
lines=2,
|
| 543 |
-
)
|
| 544 |
-
refresh_metrics = gr.Button("Refresh Metrics")
|
| 545 |
|
| 546 |
-
with gr.TabItem("Model Visuals"):
|
| 547 |
-
gr.Markdown(
|
| 548 |
-
"These PNGs come from the visualization-focused tests in `tests/test_models` and are consumed as-is."
|
| 549 |
-
)
|
| 550 |
-
gr.Markdown(initial_visual_status)
|
| 551 |
|
| 552 |
-
|
| 553 |
-
analyze_btn.click(
|
| 554 |
-
fn=predict,
|
| 555 |
-
inputs=[input_text],
|
| 556 |
-
outputs=[summary_output, emotion_output, topic_output, attention_output],
|
| 557 |
-
)
|
| 558 |
-
refresh_metrics.click(
|
| 559 |
-
fn=load_metrics_report_as_markdown,
|
| 560 |
-
inputs=None,
|
| 561 |
-
outputs=[metrics_table, topic_table, cm_output, metrics_meta_text],
|
| 562 |
-
)
|
| 563 |
-
return demo
|
| 564 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 565 |
|
| 566 |
-
|
| 567 |
-
|
|
|
|
|
|
|
|
|
|
| 568 |
|
|
|
|
|
|
|
| 569 |
|
| 570 |
if __name__ == "__main__":
|
| 571 |
-
|
| 572 |
-
|
| 573 |
-
# On HuggingFace Spaces, share must be False (they handle routing)
|
| 574 |
-
# but we need to ensure server binds correctly
|
| 575 |
-
is_hf_space = os.environ.get("SPACE_ID") is not None
|
| 576 |
-
|
| 577 |
-
try:
|
| 578 |
-
get_pipeline()
|
| 579 |
-
demo.queue().launch(
|
| 580 |
-
server_name="0.0.0.0",
|
| 581 |
-
server_port=7860,
|
| 582 |
-
share=False,
|
| 583 |
-
allowed_paths=[str(OUTPUTS_DIR)],
|
| 584 |
-
)
|
| 585 |
-
except Exception as exc: # pragma: no cover - surfaced in console
|
| 586 |
-
logger.error("Failed to launch demo: %s", exc, exc_info=True)
|
| 587 |
-
raise
|
|
|
|
| 1 |
+
"""Minimal Gradio demo for LexiMind multitask model."""
|
|
|
|
|
|
|
|
|
|
| 2 |
|
| 3 |
from __future__ import annotations
|
| 4 |
|
| 5 |
import json
|
|
|
|
| 6 |
import sys
|
|
|
|
| 7 |
from pathlib import Path
|
|
|
|
|
|
|
| 8 |
|
| 9 |
import gradio as gr
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 10 |
|
|
|
|
| 11 |
SCRIPT_DIR = Path(__file__).resolve().parent
|
| 12 |
+
PROJECT_ROOT = SCRIPT_DIR.parent
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 13 |
if str(PROJECT_ROOT) not in sys.path:
|
| 14 |
sys.path.insert(0, str(PROJECT_ROOT))
|
| 15 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 16 |
from huggingface_hub import hf_hub_download
|
| 17 |
|
| 18 |
from src.inference.factory import create_inference_pipeline
|
|
|
|
| 19 |
from src.utils.logging import configure_logging, get_logger
|
| 20 |
|
| 21 |
configure_logging()
|
| 22 |
logger = get_logger(__name__)
|
| 23 |
|
| 24 |
+
OUTPUTS_DIR = PROJECT_ROOT / "outputs"
|
| 25 |
+
EVAL_REPORT_PATH = OUTPUTS_DIR / "evaluation_report.json"
|
| 26 |
|
| 27 |
+
_pipeline = None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 28 |
|
| 29 |
|
| 30 |
+
def get_pipeline():
|
| 31 |
global _pipeline
|
| 32 |
if _pipeline is None:
|
|
|
|
|
|
|
|
|
|
| 33 |
checkpoint_path = Path("checkpoints/best.pt")
|
| 34 |
if not checkpoint_path.exists():
|
| 35 |
+
checkpoint_path.parent.mkdir(parents=True, exist_ok=True)
|
| 36 |
+
hf_hub_download(
|
| 37 |
+
repo_id="OliverPerrin/LexiMind-Model",
|
| 38 |
+
filename="best.pt",
|
| 39 |
+
local_dir="checkpoints",
|
| 40 |
+
local_dir_use_symlinks=False,
|
| 41 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 42 |
_pipeline, _ = create_inference_pipeline(
|
| 43 |
tokenizer_dir="artifacts/hf_tokenizer/",
|
| 44 |
checkpoint_path="checkpoints/best.pt",
|
| 45 |
labels_path="artifacts/labels.json",
|
| 46 |
)
|
|
|
|
| 47 |
return _pipeline
|
| 48 |
|
| 49 |
|
| 50 |
+
def analyze(text: str) -> str:
|
| 51 |
+
"""Run all three tasks and return results as formatted text."""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 52 |
if not text or not text.strip():
|
| 53 |
+
return "Please enter some text to analyze."
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 54 |
|
| 55 |
try:
|
| 56 |
+
pipe = get_pipeline()
|
|
|
|
|
|
|
|
|
|
| 57 |
|
| 58 |
+
# Summarization
|
| 59 |
+
summary = pipe.summarize([text], max_length=128)[0].strip() or "(empty)"
|
|
|
|
|
|
|
| 60 |
|
| 61 |
+
# Emotion detection
|
| 62 |
+
emotions = pipe.predict_emotions([text], threshold=0.5)[0]
|
| 63 |
+
if emotions.labels:
|
| 64 |
+
emotion_str = ", ".join(
|
| 65 |
+
f"{lbl} ({score:.1%})"
|
| 66 |
+
for lbl, score in zip(emotions.labels, emotions.scores, strict=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 67 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 68 |
else:
|
| 69 |
+
emotion_str = "No strong emotions detected"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 70 |
|
| 71 |
+
# Topic classification
|
| 72 |
+
topic = pipe.predict_topics([text])[0]
|
| 73 |
+
topic_str = f"{topic.label} ({topic.confidence:.1%})"
|
| 74 |
|
| 75 |
+
return f"""## Summary
|
| 76 |
+
{summary}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 77 |
|
| 78 |
+
## Detected Emotions
|
| 79 |
+
{emotion_str}
|
|
|
|
|
|
|
|
|
|
| 80 |
|
| 81 |
+
## Topic
|
| 82 |
+
{topic_str}
|
| 83 |
+
"""
|
| 84 |
+
except Exception as e:
|
| 85 |
+
logger.error("Analysis failed: %s", e, exc_info=True)
|
| 86 |
+
return f"Error: {e}"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 87 |
|
| 88 |
|
| 89 |
+
def get_metrics() -> str:
|
| 90 |
+
"""Load evaluation metrics as markdown."""
|
| 91 |
if not EVAL_REPORT_PATH.exists():
|
| 92 |
+
return "No evaluation report found. Run `scripts/evaluate.py` first."
|
|
|
|
|
|
|
|
|
|
| 93 |
|
| 94 |
try:
|
| 95 |
+
with open(EVAL_REPORT_PATH) as f:
|
| 96 |
+
r = json.load(f)
|
| 97 |
+
|
| 98 |
+
lines = [
|
| 99 |
+
"## Model Performance\n",
|
| 100 |
+
"| Task | Metric | Score |",
|
| 101 |
+
"|------|--------|-------|",
|
| 102 |
+
f"| Summarization | ROUGE-Like | {r['summarization']['rouge_like']:.4f} |",
|
| 103 |
+
f"| Summarization | BLEU | {r['summarization']['bleu']:.4f} |",
|
| 104 |
+
f"| Emotion | F1 Macro | {r['emotion']['f1_macro']:.4f} |",
|
| 105 |
+
f"| Topic | Accuracy | {r['topic']['accuracy']:.4f} |",
|
| 106 |
+
"",
|
| 107 |
+
"### Topic Classification Details\n",
|
| 108 |
+
"| Label | Precision | Recall | F1 |",
|
| 109 |
+
"|-------|-----------|--------|-----|",
|
| 110 |
+
]
|
| 111 |
+
for k, v in r["topic"]["classification_report"].items():
|
| 112 |
+
if isinstance(v, dict) and "precision" in v:
|
| 113 |
+
lines.append(
|
| 114 |
+
f"| {k} | {v['precision']:.3f} | {v['recall']:.3f} | {v['f1-score']:.3f} |"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 115 |
)
|
|
|
|
|
|
|
| 116 |
|
| 117 |
+
return "\n".join(lines)
|
| 118 |
+
except Exception as e:
|
| 119 |
+
return f"Error loading metrics: {e}"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 120 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 121 |
|
| 122 |
+
SAMPLE = """Artificial intelligence is rapidly transforming technology. Machine learning algorithms process vast amounts of data, identifying patterns with unprecedented accuracy. From healthcare to finance, AI is revolutionizing industries worldwide. However, ethical considerations around privacy and bias remain critical challenges."""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 123 |
|
| 124 |
+
with gr.Blocks(title="LexiMind Demo") as demo:
|
| 125 |
+
gr.Markdown(
|
| 126 |
+
"# LexiMind NLP Demo\nMulti-task model: summarization, emotion detection, topic classification."
|
| 127 |
+
)
|
| 128 |
|
| 129 |
+
with gr.Tab("Analyze"):
|
| 130 |
+
text_input = gr.Textbox(label="Input Text", lines=6, value=SAMPLE)
|
| 131 |
+
analyze_btn = gr.Button("Analyze", variant="primary")
|
| 132 |
+
output = gr.Markdown(label="Results")
|
| 133 |
+
analyze_btn.click(fn=analyze, inputs=text_input, outputs=output)
|
| 134 |
|
| 135 |
+
with gr.Tab("Metrics"):
|
| 136 |
+
gr.Markdown(get_metrics())
|
| 137 |
|
| 138 |
if __name__ == "__main__":
|
| 139 |
+
get_pipeline() # Pre-load
|
| 140 |
+
demo.launch(server_name="0.0.0.0", server_port=7860)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|