File size: 13,899 Bytes
7977c7d
 
 
5dc2ed8
7977c7d
5dc2ed8
1ec7405
7977c7d
 
1ec7405
7977c7d
ee1a8a3
8951fba
 
4b33d4d
 
 
ded097b
4b33d4d
 
7977c7d
 
0df948f
18fc263
7977c7d
8951fba
 
4b33d4d
d18b34d
00d412c
4b33d4d
 
 
 
 
 
7977c7d
 
18fc263
 
1ec7405
8951fba
5dc2ed8
1ec7405
 
 
5dc2ed8
7977c7d
 
 
18fc263
8951fba
3318356
18fc263
7977c7d
8951fba
7977c7d
 
 
 
 
 
 
 
 
 
 
8951fba
7977c7d
 
 
 
 
d9dbe7c
7977c7d
4b33d4d
 
8951fba
7977c7d
 
 
5dc2ed8
 
4b33d4d
1ec7405
8951fba
4b33d4d
18fc263
ded097b
5dc2ed8
1ec7405
 
 
 
 
7977c7d
 
1ec7405
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18fc263
1ec7405
 
 
 
 
93b9242
1ec7405
8951fba
5dc2ed8
1ec7405
4b33d4d
5dc2ed8
4b33d4d
18fc263
 
5dc2ed8
a504116
 
5dc2ed8
 
1ec7405
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
00d412c
1ec7405
 
 
 
 
 
 
 
 
 
 
18fc263
1ec7405
 
 
 
 
 
 
5dc2ed8
1ec7405
8951fba
5dc2ed8
1ec7405
 
 
 
d18b34d
 
7977c7d
bc94c66
5dc2ed8
1ec7405
5dc2ed8
 
18fc263
5dc2ed8
 
 
 
1ec7405
 
 
 
5dc2ed8
18fc263
bc94c66
5dc2ed8
 
 
1ec7405
5dc2ed8
1ec7405
 
 
5dc2ed8
 
1ec7405
 
 
 
 
 
 
5dc2ed8
1ec7405
 
 
 
 
 
 
5dc2ed8
 
1ec7405
 
 
 
 
 
 
 
 
 
 
 
 
5dc2ed8
 
 
 
 
 
 
 
 
1ec7405
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5dc2ed8
 
 
 
 
 
 
1ec7405
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5dc2ed8
1ec7405
5dc2ed8
1ec7405
 
 
 
 
5dc2ed8
 
bc94c66
5dc2ed8
 
 
 
 
 
1ec7405
 
 
 
 
 
 
 
 
5dc2ed8
1ec7405
 
 
5dc2ed8
 
 
 
1ec7405
5dc2ed8
1ec7405
5dc2ed8
1ec7405
5dc2ed8
 
bc94c66
7977c7d
 
 
bc94c66
7977c7d
18fc263
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
"""
Gradio demo for LexiMind multi-task NLP model.

Showcases the model's capabilities across three tasks:
- Summarization: Generates concise summaries of input text
- Emotion Detection: Multi-label emotion classification
- Topic Classification: Categorizes text into topics

Author: Oliver Perrin
Date: 2025-12-05
"""

from __future__ import annotations

import json
import sys
from pathlib import Path

import gradio as gr

# --------------- Path Setup ---------------

SCRIPT_DIR = Path(__file__).resolve().parent
PROJECT_ROOT = SCRIPT_DIR.parent

# Put the project root at the front of the import path (once) so the
# project-local ``src`` package imported below resolves when this file is
# executed directly.
if not any(entry == str(PROJECT_ROOT) for entry in sys.path):
    sys.path.insert(0, str(PROJECT_ROOT))

from huggingface_hub import hf_hub_download

from src.inference.factory import create_inference_pipeline
from src.utils.logging import configure_logging, get_logger

# Set up project-wide logging (src.utils.logging) first, then obtain this
# module's logger; analyze() and the metrics loader report failures through it.
configure_logging()
logger = get_logger(__name__)

# --------------- Constants ---------------

# Generated artifacts (metrics JSON, plots) live under <project root>/outputs.
OUTPUTS_DIR = PROJECT_ROOT.joinpath("outputs")
EVAL_REPORT_PATH = OUTPUTS_DIR.joinpath("evaluation_report.json")
TRAINING_HISTORY_PATH = OUTPUTS_DIR.joinpath("training_history.json")

# Example inputs used to pre-fill the demo's text box (markets, science, sports).
SAMPLE_TEXTS = [
    (
        "Global markets tumbled today as investors reacted to rising inflation "
        "concerns. The Federal Reserve hinted at potential interest rate hikes, "
        "sending shockwaves through technology and banking sectors. Analysts "
        "predict continued volatility as economic uncertainty persists."
    ),
    (
        "Scientists at MIT have developed a breakthrough quantum computing chip "
        "that operates at room temperature. This advancement could revolutionize "
        "drug discovery, cryptography, and artificial intelligence. The research "
        "team published their findings in Nature."
    ),
    (
        "The championship game ended in dramatic fashion as the underdog team "
        "scored in the final seconds to secure victory. Fans rushed the field in "
        "celebration, marking the team's first title in 25 years."
    ),
]

# --------------- Pipeline Management ---------------

# Process-wide cache for the inference pipeline; populated by get_pipeline().
_pipeline = None


def get_pipeline():
    """Return the shared inference pipeline, constructing it on first use.

    On the first call, if ``checkpoints/best.pt`` does not exist it is downloaded
    from the Hugging Face Hub repo ``OliverPerrin/LexiMind-Model``; the pipeline
    is then built via ``create_inference_pipeline`` and cached in ``_pipeline``.
    Later calls return the cached object without touching the filesystem.
    """
    global _pipeline
    if _pipeline is None:
        # NOTE(review): these paths resolve against the current working
        # directory, whereas OUTPUTS_DIR is anchored at PROJECT_ROOT; confirm the
        # demo is always launched from the project root.
        ckpt_dir = "checkpoints"
        ckpt_file = "checkpoints/best.pt"

        if not Path(ckpt_file).exists():
            Path(ckpt_file).parent.mkdir(parents=True, exist_ok=True)
            hf_hub_download(
                repo_id="OliverPerrin/LexiMind-Model",
                filename="best.pt",
                local_dir=ckpt_dir,
            )

        # Only the first element of the returned pair is needed here.
        _pipeline, _unused = create_inference_pipeline(
            tokenizer_dir="artifacts/hf_tokenizer/",
            checkpoint_path=ckpt_file,
            labels_path="artifacts/labels.json",
            model_config_path="configs/model/base.yaml",
        )
    return _pipeline


# --------------- Core Functions ---------------


# Emoji shown next to each predicted emotion, keyed by lower-cased label name.
EMOTION_EMOJI: dict[str, str] = {
    "joy": "😊",
    "love": "❤️",
    "anger": "😠",
    "fear": "😨",
    "sadness": "😢",
    "surprise": "😲",
    "neutral": "😐",
    "admiration": "🤩",
    "amusement": "😄",
    "annoyance": "😤",
    "approval": "👍",
    "caring": "🤗",
    "confusion": "😕",
    "curiosity": "🤔",
    "desire": "😍",
    "disappointment": "😞",
    "disapproval": "👎",
    "disgust": "🤢",
    "embarrassment": "😳",
    "excitement": "🎉",
    "gratitude": "🙏",
    "grief": "😭",
    "nervousness": "😬",
    "optimism": "🌟",
    "pride": "🦁",
    "realization": "💡",
    "relief": "😌",
    "remorse": "😔",
}


def _format_emotions(labels: list[str], scores: list[float], limit: int = 5) -> str:
    """Render up to ``limit`` highest-scoring emotions as Markdown.

    Args:
        labels: Predicted emotion label names.
        scores: Confidence per label, aligned with ``labels``.
        limit: Maximum number of emotions to display.

    Returns:
        One ``<emoji> **Label** (NN%)`` entry per paragraph, or a placeholder
        message when ``labels`` is empty.
    """
    if not labels:
        return "😐 No strong emotions detected"
    # Rank by score so the strongest emotions are shown even if the pipeline
    # returns labels in class order; sorted() is stable, so already-ranked input
    # keeps its order.
    ranked = sorted(zip(labels, scores), key=lambda pair: pair[1], reverse=True)[:limit]
    parts = [
        f"{EMOTION_EMOJI.get(label.lower(), '•')} **{label.title()}** ({score:.0%})"
        for label, score in ranked
    ]
    # Blank-line separation: gr.Markdown does not treat a single "\n" as a line
    # break by default, which would run all entries together on one line.
    return "\n\n".join(parts)


def analyze(text: str) -> tuple[str, str, str]:
    """Run summarization, emotion detection and topic classification on ``text``.

    Args:
        text: Raw user input.

    Returns:
        ``(summary, emotions_markdown, topic_markdown)``. For empty/whitespace
        input, a prompt message and two empty strings. If any model step raises,
        ``("Error: <message>", "", "")`` after logging the traceback.
    """
    if not text or not text.strip():
        return "Please enter text above to analyze.", "", ""

    try:
        pipe = get_pipeline()

        summary = pipe.summarize([text], max_length=128)[0].strip()
        if not summary:
            summary = "(Unable to generate summary)"

        # Threshold deliberately set low (0.3) so more emotions are reported.
        emotions = pipe.predict_emotions([text], threshold=0.3)[0]
        topic = pipe.predict_topics([text])[0]

        emotion_str = _format_emotions(emotions.labels, emotions.scores)
        topic_str = f"**{topic.label}**\n\nConfidence: {topic.confidence:.0%}"
        return summary, emotion_str, topic_str

    except Exception as e:
        # UI boundary: surface the failure in the summary box rather than
        # breaking the Gradio request; the full traceback goes to the log.
        logger.exception("Analysis failed: %s", e)
        return f"Error: {e}", "", ""


# Report rows that aggregate over all classes rather than describing one topic
# (key names match scikit-learn's classification_report(output_dict=True)).
_AGGREGATE_REPORT_ROWS = frozenset({"macro avg", "weighted avg", "micro avg"})
_VAL_EPOCH_PREFIX = "val_epoch_"


def _read_json_object(path: Path) -> dict:
    """Load a top-level JSON object from ``path``.

    Returns ``{}`` if the file is missing, unreadable, not valid JSON, or not a
    JSON object. Anything other than absence is logged, since the demo treats
    these metrics files as optional.
    """
    if not path.exists():
        return {}
    try:
        with open(path, encoding="utf-8") as f:
            data = json.load(f)
    except (OSError, ValueError) as err:  # JSONDecodeError/UnicodeDecodeError are ValueErrors
        logger.warning("Could not read metrics file %s: %s", path, err)
        return {}
    if not isinstance(data, dict):
        logger.warning("Ignoring metrics file %s: expected a JSON object", path)
        return {}
    return data


def _sub_dict(mapping: dict, key: str) -> dict:
    """Return ``mapping[key]`` when it is a dict, otherwise ``{}``."""
    value = mapping.get(key)
    return value if isinstance(value, dict) else {}


def _as_float(value: object) -> float:
    """Coerce a JSON metric value to float; missing/non-numeric values become 0.0."""
    try:
        return float(value)
    except (TypeError, ValueError):
        return 0.0


def _latest_val_epoch(history: dict) -> tuple[int | None, dict]:
    """Find the highest-numbered ``val_epoch_<N>`` entry in a training history.

    Returns:
        ``(N, metrics)`` for the latest validation epoch, or ``(None, {})`` when
        the history contains no such entry.
    """
    latest_epoch: int | None = None
    latest_metrics: dict = {}
    for key, metrics in history.items():
        if not isinstance(key, str) or not key.startswith(_VAL_EPOCH_PREFIX):
            continue
        suffix = key[len(_VAL_EPOCH_PREFIX):]
        if not suffix.isdecimal() or not isinstance(metrics, dict):
            continue
        epoch = int(suffix)
        if latest_epoch is None or epoch > latest_epoch:
            latest_epoch, latest_metrics = epoch, metrics
    return latest_epoch, latest_metrics


def load_metrics(
    eval_report_path: Path | None = None,
    training_history_path: Path | None = None,
) -> str:
    """Build the Markdown shown on the Metrics tab.

    Reads the evaluation report and training history JSON files (both optional)
    and renders a training-results table, an evaluation-results table and a
    per-topic precision/recall/F1 table. Missing values are shown as 0.

    Args:
        eval_report_path: Evaluation report to read; defaults to ``EVAL_REPORT_PATH``.
        training_history_path: Training history to read; defaults to
            ``TRAINING_HISTORY_PATH``.

    Returns:
        The Markdown string.
    """
    eval_metrics = _read_json_object(
        EVAL_REPORT_PATH if eval_report_path is None else Path(eval_report_path)
    )
    train_metrics = _read_json_object(
        TRAINING_HISTORY_PATH if training_history_path is None else Path(training_history_path)
    )

    # "Final" training scores come from the latest recorded validation epoch.
    epoch, val_final = _latest_val_epoch(train_metrics)
    if epoch is None:
        epochs_note = ""
    else:
        epochs_note = f" ({epoch} Epoch{'' if epoch == 1 else 's'})"

    topic_eval = _sub_dict(eval_metrics, "topic")
    emotion_eval = _sub_dict(eval_metrics, "emotion")
    summ_eval = _sub_dict(eval_metrics, "summarization")

    md = """
## 📈 Model Performance

### Training Results{epochs_note}

| Task | Metric | Final Score |
|------|--------|-------------|
| **Topic Classification** | Accuracy | **{topic_acc:.1%}** |
| **Emotion Detection** | F1 (training) | {emo_f1:.1%} |
| **Summarization** | ROUGE-like | {rouge:.1%} |

### Evaluation Results

| Metric | Value |
|--------|-------|
| Topic Accuracy | **{eval_topic:.1%}** |
| Emotion F1 (macro) | {eval_emo:.1%} |
| ROUGE-like | {eval_rouge:.1%} |
| BLEU | {eval_bleu:.3f} |

---

### Topic Classification Details

| Category | Precision | Recall | F1 |
|----------|-----------|--------|-----|
""".format(
        epochs_note=epochs_note,
        topic_acc=_as_float(val_final.get("topic_accuracy")),
        emo_f1=_as_float(val_final.get("emotion_f1")),
        rouge=_as_float(val_final.get("summarization_rouge_like")),
        eval_topic=_as_float(topic_eval.get("accuracy")),
        eval_emo=_as_float(emotion_eval.get("f1_macro")),
        eval_rouge=_as_float(summ_eval.get("rouge_like")),
        eval_bleu=_as_float(summ_eval.get("bleu")),
    )

    # Per-class rows; aggregate rows and scalar entries (e.g. "accuracy") are skipped.
    report = _sub_dict(topic_eval, "classification_report")
    rows = [
        f"| {category} | {_as_float(scores.get('precision')):.1%} "
        f"| {_as_float(scores.get('recall')):.1%} | {_as_float(scores.get('f1-score')):.1%} |\n"
        for category, scores in report.items()
        if category not in _AGGREGATE_REPORT_ROWS and isinstance(scores, dict)
    ]
    return md + "".join(rows)


def get_viz_path(filename: str) -> str | None:
    """Locate ``filename`` under OUTPUTS_DIR.

    Returns the path as a string when it exists, so the caller can skip
    rendering an image that was never generated; otherwise ``None``.
    """
    candidate = OUTPUTS_DIR / filename
    if not candidate.exists():
        return None
    return str(candidate)


# --------------- Gradio Interface ---------------

# Build the Gradio UI at import time. Tabs: Try It (interactive analysis),
# Metrics, Visualizations (images are shown only if present under OUTPUTS_DIR),
# Architecture and About.
with gr.Blocks(
    title="LexiMind - Multi-Task NLP",
    theme=gr.themes.Soft(),
) as demo:
    gr.Markdown(
        """
        # 🧠 LexiMind
        ### Multi-Task Transformer for Document Analysis
        
        A custom encoder-decoder Transformer trained on **summarization**, **emotion detection** (28 classes),
        and **topic classification** (10 categories). Built from scratch with PyTorch.
        
        > ⚠️ **Note**: Summarization is experimental - the model works best on news-style articles.
        """
    )

    # --------------- Try It Tab ---------------
    with gr.Tab("🚀 Try It"):
        with gr.Row():
            with gr.Column(scale=3):
                text_input = gr.Textbox(
                    label="📝 Input Text",
                    lines=6,
                    placeholder="Enter or paste text to analyze (works best with news articles)...",
                    value=SAMPLE_TEXTS[0],
                )
                analyze_btn = gr.Button(
                    "🔍 Analyze",
                    variant="primary",
                    size="sm",
                )

                gr.Markdown("**Sample Texts** (click to use):")
                with gr.Row():
                    sample1_btn = gr.Button("📰 Markets", size="sm", variant="secondary")
                    sample2_btn = gr.Button("🔬 Science", size="sm", variant="secondary")
                    sample3_btn = gr.Button("🏆 Sports", size="sm", variant="secondary")

                # Each sample button replaces the input box with its sample text.
                sample1_btn.click(fn=lambda: SAMPLE_TEXTS[0], outputs=text_input)
                sample2_btn.click(fn=lambda: SAMPLE_TEXTS[1], outputs=text_input)
                sample3_btn.click(fn=lambda: SAMPLE_TEXTS[2], outputs=text_input)

            with gr.Column(scale=2):
                gr.Markdown("### Results")
                summary_out = gr.Textbox(
                    label="📝 Summary",
                    lines=3,
                    interactive=False,
                )
                with gr.Row():
                    with gr.Column():
                        gr.Markdown("**😊 Emotions**")
                        emotion_out = gr.Markdown(value="*Run analysis*")
                    with gr.Column():
                        gr.Markdown("**📂 Topic**")
                        topic_out = gr.Markdown(value="*Run analysis*")

        # analyze() returns (summary, emotions markdown, topic markdown) in this order.
        analyze_btn.click(
            fn=analyze,
            inputs=text_input,
            outputs=[summary_out, emotion_out, topic_out],
        )

    # --------------- Metrics Tab ---------------
    with gr.Tab("📊 Metrics"):
        with gr.Row():
            with gr.Column(scale=2):
                # Metrics are read once, when the UI is built.
                gr.Markdown(load_metrics())
            with gr.Column(scale=1):
                confusion_path = get_viz_path("topic_confusion_matrix.png")
                if confusion_path:
                    gr.Image(confusion_path, label="Confusion Matrix", show_label=True)

    # --------------- Visualizations Tab ---------------
    with gr.Tab("🎨 Visualizations"):
        gr.Markdown("### Model Internals")

        with gr.Row():
            attn_path = get_viz_path("attention_visualization.png")
            if attn_path:
                gr.Image(attn_path, label="Self-Attention Pattern")

            pos_path = get_viz_path("positional_encoding_heatmap.png")
            if pos_path:
                gr.Image(pos_path, label="Positional Encodings")

        with gr.Row():
            multi_path = get_viz_path("multihead_attention_visualization.png")
            if multi_path:
                gr.Image(multi_path, label="Multi-Head Attention")

            single_path = get_viz_path("single_vs_multihead.png")
            if single_path:
                gr.Image(single_path, label="Single vs Multi-Head Comparison")

    # --------------- Architecture Tab ---------------
    with gr.Tab("🔧 Architecture"):
        gr.Markdown(
            """
            ### Model Architecture
            
            | Component | Configuration |
            |-----------|---------------|
            | **Base** | Custom Transformer (encoder-decoder) |
            | **Initialization** | FLAN-T5-base weights |
            | **Encoder** | 6 layers, 768 hidden dim, 12 heads |
            | **Decoder** | 6 layers with cross-attention |
            | **Activation** | Gated-GELU |
            | **Position** | Relative position bias |
            
            ### Training Configuration
            
            | Setting | Value |
            |---------|-------|
            | **Optimizer** | AdamW (lr=2e-5, wd=0.01) |
            | **Scheduler** | Cosine with 1000 warmup steps |
            | **Batch Size** | 14 × 3 accumulation = 42 effective |
            | **Precision** | TF32 (Ampere GPU) |
            | **Compilation** | torch.compile (inductor) |
            
            ### Datasets
            
            | Task | Dataset | Size |
            |------|---------|------|
            | **Summarization** | CNN/DailyMail + BookSum | ~110K |
            | **Emotion** | GoEmotions | ~43K (28 labels) |
            | **Topic** | Yahoo Answers | ~200K (10 classes) |
            """
        )

    # --------------- About Tab ---------------
    with gr.Tab("ℹ️ About"):
        # The trailing double spaces on the checklist lines are Markdown hard line breaks.
        gr.Markdown(
            """
            ### About LexiMind
            
            LexiMind is a **portfolio project** demonstrating end-to-end machine learning engineering:
            
            ✅ Custom Transformer implementation from scratch  
            ✅ Multi-task learning with shared encoder  
            ✅ Production-ready inference pipeline  
            ✅ Comprehensive evaluation and visualization  
            ✅ CI/CD with GitHub Actions  
            
            ### Known Limitations
            
            - **Summarization** quality is limited (needs more training epochs)
            - **Emotion detection** has low F1 due to class imbalance in GoEmotions
            - Best results on **news-style text** (training domain)
            
            ### Links
            
            - 🔗 [GitHub Repository](https://github.com/OliverPerrin/LexiMind)
            - 🤗 [Model on HuggingFace](https://huggingface.co/OliverPerrin/LexiMind-Model)
            
            ---
            
            **Built by Oliver Perrin** | December 2025
            """
        )


# --------------- Entry Point ---------------

if __name__ == "__main__":
    # Build the pipeline eagerly: a checkpoint that cannot be found, downloaded
    # or loaded then fails at startup instead of on the first user request.
    get_pipeline()
    demo.launch(server_port=7860, server_name="0.0.0.0")