OliverPerrin committed on
Commit
1ec7405
·
1 Parent(s): 701dfc6

Update LexiMind: improved training, model architecture, and evaluation

README.md CHANGED
@@ -8,7 +8,7 @@ app_file: scripts/demo_gradio.py
8
  pinned: false
9
  ---
10
 
11
- # LexiMind: A Multi-Task NLP Model
12
 
13
  LexiMind is a state-of-the-art Natural Language Processing model designed for complex document understanding. It features a **custom-built Transformer architecture** initialized with weights from Google's **FLAN-T5**, combining the flexibility of from-scratch implementation with the power of modern pre-trained models.
14
 
@@ -18,32 +18,37 @@ This project is built with industry-standard MLOps practices, including configur
18
 
19
  ## Core Features
20
 
21
- * **Abstractive Summarization:** Generates concise, coherent summaries of long-form text using encoder-decoder attention.
22
- * **Emotion Classification:** Identifies emotions (Joy, Sadness, Anger, Fear, Love, Surprise) conveyed in a document.
23
- * **Topic Clustering:** Classifies documents into thematic categories (World, Sports, Business, Sci/Tech).
24
 
25
  ## Model Architecture
26
 
27
  LexiMind implements a **from-scratch Transformer** with modern architectural choices:
28
 
29
  ### Custom Transformer Features
30
- - **Pre-Layer Normalization (Pre-LN):** RMSNorm applied before each sublayer for stable training
31
- - **FlashAttention:** Via PyTorch 2.0's `scaled_dot_product_attention` for efficient computation
32
- - **Learned Positional Embeddings:** Trainable position representations
33
- - **Multi-Head Attention:** 12 heads with 768-dimensional representations
34
- - **RMSNorm:** Modern normalization without bias (more efficient than LayerNorm)
 
35
 
36
  ### Pre-trained Weight Initialization
 
37
  The model loads weights from **Google's FLAN-T5-base**, which provides:
38
- - Strong language understanding from instruction-tuning
39
- - Excellent performance on summarization and classification tasks
40
- - Encoder-decoder architecture matching our custom implementation
 
41
 
42
  ### Multi-Task Learning
 
43
  A shared encoder-decoder backbone with task-specific heads:
44
- - **Summarization Head:** Language modeling head with weight tying
45
- - **Emotion Head:** Mean-pooled classification with dropout
46
- - **Topic Head:** Mean-pooled classification with dropout
 
47
 
48
  ## Technical Specifications
49
 
@@ -64,29 +69,32 @@ A shared encoder-decoder backbone with task-specific heads:
64
 
65
  ### Prerequisites
66
 
67
- * Python 3.10+
68
- * Poetry for dependency management
69
- * Docker (for containerized deployment)
70
- * An NVIDIA GPU with CUDA support (for training and accelerated inference)
71
 
72
  ### Installation
73
 
74
- 1. **Clone the repository:**
75
- ```bash
76
- git clone https://github.com/OliverPerrin/LexiMind.git
77
- cd LexiMind
78
- ```
 
 
 
79
 
80
- 2. **Install dependencies:**
81
- ```bash
82
- poetry install
83
- ```
84
 
85
- 3. **Download and preprocess data:**
86
- ```bash
87
- poetry run python scripts/download_data.py
88
- poetry run python scripts/preprocess_data.py
89
- ```
 
90
 
91
  ## Usage
92
 
@@ -95,12 +103,13 @@ A shared encoder-decoder backbone with task-specific heads:
95
  All training and model parameters are managed via Hydra. Configurations are located in the `configs/` directory.
96
 
97
  Available configurations:
98
- - `model=base` - FLAN-T5-base (default, 12 layers)
99
- - `model=small` - Smaller model for testing (no pretrained weights)
100
- - `model=large` - FLAN-T5-large (24 layers, requires more VRAM)
101
- - `training=dev` - Quick development run
102
- - `training=medium` - Balanced training (~2-3 hours on RTX 4070)
103
- - `training=full` - Full training run
 
104
 
105
  ### Training
106
 
@@ -116,6 +125,9 @@ poetry run python scripts/train.py training=medium
116
 
117
  # Override parameters
118
  poetry run python scripts/train.py training.optimizer.lr=5e-5
 
 
 
119
  ```
120
 
121
  Experiments are automatically tracked with MLflow. View results with `mlflow ui`.
@@ -148,7 +160,7 @@ docker run -p 7860:7860 leximind
148
 
149
  ## Project Structure
150
 
151
- ```
152
  ├── configs/ # Hydra configuration files
153
  │ ├── model/ # Model architectures (base, small, large)
154
  │ ├── training/ # Training configs (dev, medium, full)
@@ -169,22 +181,33 @@ docker run -p 7860:7860 leximind
169
 
170
  ## Code Quality
171
 
172
- * **Ruff:** Fast linting and formatting
173
- * **MyPy:** Static type checking
174
- * **Pre-commit hooks:** Automated quality checks
 
175
 
176
  ```bash
 
177
  poetry run pre-commit install
 
 
 
 
 
 
 
 
 
178
  ```
179
 
180
  ## Performance Optimizations
181
 
182
- - **torch.compile:** JIT compilation with Inductor backend
183
- - **Mixed Precision:** bfloat16 training on Ampere/Ada GPUs
184
- - **TF32:** Enabled for RTX 30xx/40xx series
185
- - **KV-Cache:** Efficient autoregressive decoding
186
- - **FlashAttention:** Memory-efficient attention via SDPA
187
 
188
  ## License
189
 
190
- MIT License - see [LICENSE](LICENSE) for details.
 
8
  pinned: false
9
  ---
10
 
11
+ ## LexiMind: A Multi-Task NLP Model
12
 
13
  LexiMind is a state-of-the-art Natural Language Processing model designed for complex document understanding. It features a **custom-built Transformer architecture** initialized with weights from Google's **FLAN-T5**, combining the flexibility of from-scratch implementation with the power of modern pre-trained models.
14
 
 
18
 
19
  ## Core Features
20
 
21
+ * **Abstractive Summarization:** Generates concise, coherent summaries of long-form text using encoder-decoder attention.
22
+ * **Emotion Classification:** Identifies emotions (Joy, Sadness, Anger, Fear, Love, Surprise) conveyed in a document.
23
+ * **Topic Classification:** Classifies documents into thematic categories (World, Sports, Business, Sci/Tech).
24
 
25
  ## Model Architecture
26
 
27
  LexiMind implements a **from-scratch Transformer** with modern architectural choices:
28
 
29
  ### Custom Transformer Features
30
+
31
+ * **Pre-Layer Normalization (Pre-LN):** RMSNorm applied before each sublayer for stable training
32
+ * **FlashAttention:** Via PyTorch 2.0's `scaled_dot_product_attention` for efficient computation
33
+ * **Learned Positional Embeddings:** Trainable position representations
34
+ * **Multi-Head Attention:** 12 heads with 768-dimensional representations
35
+ * **RMSNorm:** Modern normalization without bias (more efficient than LayerNorm; see the sketch below)
36
 
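The Pre-LN/RMSNorm combination above can be summarized with a minimal, illustrative PyTorch sketch (not the repository's actual module, which may differ in details):

```python
import torch
from torch import nn


class RMSNorm(nn.Module):
    """Scale-only normalization: no mean subtraction and no bias term."""

    def __init__(self, dim: int, eps: float = 1e-6) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Divide by the root-mean-square over the feature dimension, then rescale.
        rms = x.pow(2).mean(dim=-1, keepdim=True).add(self.eps).rsqrt()
        return self.weight * (x * rms)
```

In a Pre-LN block the norm runs before each sublayer, i.e. `x = x + sublayer(norm(x))`, which is what keeps gradients well-behaved during fine-tuning.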
37
  ### Pre-trained Weight Initialization
38
+
39
  The model loads weights from **Google's FLAN-T5-base**, which provides:
40
+
41
+ * Strong language understanding from instruction-tuning
42
+ * Excellent performance on summarization and classification tasks
43
+ * Encoder-decoder architecture matching our custom implementation
44
 
45
  ### Multi-Task Learning
46
+
47
  A shared encoder-decoder backbone with task-specific heads (a pooling-head sketch follows this list):
48
+
49
+ * **Summarization Head:** Language modeling head with weight tying
50
+ * **Emotion Head:** Mean-pooled classification with dropout
51
+ * **Topic Head:** Mean-pooled classification with dropout
52
 
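As referenced above, the emotion and topic heads are mean-pooled classifiers. A hedged sketch of that pattern (names and layout are illustrative, not the repository's exact code):

```python
import torch
from torch import nn


class MeanPooledHead(nn.Module):
    """Masked mean pooling over encoder states, then dropout and a linear classifier."""

    def __init__(self, d_model: int, num_labels: int, dropout: float = 0.1) -> None:
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        self.classifier = nn.Linear(d_model, num_labels)

    def forward(self, hidden: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        # hidden: (batch, seq, d_model); attention_mask: (batch, seq), 1 for real tokens.
        mask = attention_mask.unsqueeze(-1).to(hidden.dtype)
        pooled = (hidden * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1.0)
        return self.classifier(self.dropout(pooled))
```

The emotion head would pair this with `BCEWithLogitsLoss` (multi-label) and the topic head with `CrossEntropyLoss` (single-label).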
53
  ## Technical Specifications
54
 
 
69
 
70
  ### Prerequisites
71
 
72
+ * Python 3.10+
73
+ * Poetry for dependency management
74
+ * Docker (for containerized deployment)
75
+ * An NVIDIA GPU with CUDA support (for training and accelerated inference)
76
 
77
  ### Installation
78
 
79
+ 1. **Clone the repository:**
80
+
81
+ ```bash
82
+ git clone https://github.com/OliverPerrin/LexiMind.git
83
+ cd LexiMind
84
+ ```
85
+
86
+ 2. **Install dependencies:**
87
 
88
+ ```bash
89
+ poetry install
90
+ ```
 
91
 
92
+ 3. **Download and preprocess data:**
93
+
94
+ ```bash
95
+ poetry run python scripts/download_data.py
96
+ poetry run python scripts/preprocess_data.py
97
+ ```
98
 
99
  ## Usage
100
 
 
103
  All training and model parameters are managed via Hydra. Configurations are located in the `configs/` directory.
104
 
105
  Available configurations:
106
+
107
+ * `model=base` - FLAN-T5-base (default, 12 layers)
108
+ * `model=small` - Smaller model for testing (no pretrained weights)
109
+ * `model=large` - FLAN-T5-large (24 layers, requires more VRAM)
110
+ * `training=dev` - Quick development run
111
+ * `training=medium` - Balanced training (~2-3 hours on RTX 4070)
112
+ * `training=full` - Full training run
113
 
114
  ### Training
115
 
 
125
 
126
  # Override parameters
127
  poetry run python scripts/train.py training.optimizer.lr=5e-5
128
+
129
+ # Resume from a checkpoint
130
+ poetry run python scripts/train.py training=full resume_from=checkpoints/epoch_5.pt
131
  ```
132
 
133
  Experiments are automatically tracked with MLflow. View results with `mlflow ui`.
 
160
 
161
  ## Project Structure
162
 
163
+ ```text
164
  ├── configs/ # Hydra configuration files
165
  │ ├── model/ # Model architectures (base, small, large)
166
  │ ├── training/ # Training configs (dev, medium, full)
 
181
 
182
  ## Code Quality
183
 
184
+ * **Ruff:** Fast linting and formatting
185
+ * **MyPy:** Static type checking
186
+ * **Pytest:** Full test suite covering data, models, and training
187
+ * **Pre-commit hooks:** Automated quality checks
188
 
189
  ```bash
190
+ # Install hooks
191
  poetry run pre-commit install
192
+
193
+ # Lint
194
+ poetry run ruff check .
195
+
196
+ # Type check
197
+ poetry run mypy .
198
+
199
+ # Tests
200
+ poetry run pytest
201
  ```
202
 
203
  ## Performance Optimizations
204
 
205
+ * **torch.compile:** JIT compilation with Inductor backend (see the setup sketch below)
206
+ * **Mixed Precision:** bfloat16 training on Ampere/Ada GPUs
207
+ * **TF32:** Enabled for RTX 30xx/40xx series
208
+ * **KV-Cache:** Efficient autoregressive decoding
209
+ * **FlashAttention:** Memory-efficient attention via SDPA
210
 
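A minimal sketch of how these switches are typically flipped in PyTorch 2.x; the dummy model and tensor below are placeholders, not LexiMind's trainer code:

```python
import torch
from torch import nn

# TF32 matmuls on RTX 30xx/40xx-class GPUs.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

model = nn.Linear(768, 768)  # placeholder for the real encoder-decoder
if torch.cuda.is_available():
    model = torch.compile(model.cuda(), backend="inductor")  # JIT via Inductor
    x = torch.randn(8, 768, device="cuda")
    # bfloat16 autocast for the forward pass; parameters and gradients stay fp32.
    with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
        out = model(x)
    out.float().pow(2).mean().backward()
```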
211
  ## License
212
 
213
+ MIT License - see [LICENSE](LICENSE) for details.
configs/config.yaml CHANGED
@@ -14,5 +14,6 @@ hydra:
14
  checkpoint_out: "checkpoints/best.pt"
15
  labels_out: "artifacts/labels.json"
16
  history_out: "outputs/training_history.json"
 
17
  device: "cuda"
18
  seed: 17
 
14
  checkpoint_out: "checkpoints/best.pt"
15
  labels_out: "artifacts/labels.json"
16
  history_out: "outputs/training_history.json"
17
+ resume_from: null
18
  device: "cuda"
19
  seed: 17
configs/model/base.yaml CHANGED
@@ -1,8 +1,10 @@
1
  # FLAN-T5-base architecture
2
- # 6 encoder layers, 6 decoder layers, 768 hidden dim
3
  d_model: 768
4
- num_encoder_layers: 6 # T5-base has 6 layers
5
- num_decoder_layers: 6 # T5-base has 6 layers
 
 
6
  num_attention_heads: 12
7
  ffn_dim: 2048 # T5 uses d_ff = 2048 for base model
8
  dropout: 0.1
@@ -10,3 +12,5 @@ activation: gated-gelu # T5/FLAN-T5 uses gated-gelu (GELU activation with gatin
10
  use_pretrained: true
11
  pretrained_model_name: google/flan-t5-base
12
  use_relative_position_bias: true # T5 uses relative position bias instead of absolute embeddings
 
 
 
1
  # FLAN-T5-base architecture
2
+ # 12 encoder layers, 12 decoder layers, 768 hidden dim
3
  d_model: 768
4
+ # Align vocab with FLAN-T5 padded size to avoid weight truncation
5
+ vocab_size: 32128
6
+ num_encoder_layers: 12 # T5-base has 12 layers
7
+ num_decoder_layers: 12 # T5-base has 12 layers
8
  num_attention_heads: 12
9
  ffn_dim: 2048 # T5 uses d_ff = 2048 for base model
10
  dropout: 0.1
 
12
  use_pretrained: true
13
  pretrained_model_name: google/flan-t5-base
14
  use_relative_position_bias: true # T5 uses relative position bias instead of absolute embeddings
15
+ gradient_checkpointing: false
16
+
configs/training/dev.yaml CHANGED
@@ -1,35 +1,45 @@
1
  # Development/Testing Configuration for FLAN-T5-base
2
  # Fast iteration for debugging and testing changes
3
- # Training time: ~3-5 minutes on RTX 4070 12GB
 
4
  # Use: python scripts/train.py training=dev
5
 
6
  dataloader:
7
- batch_size: 14
8
  shuffle: true
9
- num_workers: 6
10
  pin_memory: true
11
  persistent_workers: true
12
- prefetch_factor: 4
13
 
14
  optimizer:
15
  name: adamw
16
- lr: 2.0e-5
17
  weight_decay: 0.01
18
- eps: 1.0e-6
 
19
 
20
  scheduler:
21
  name: cosine
22
- warmup_steps: 50
23
 
24
  trainer:
25
- max_epochs: 1
26
  gradient_clip_norm: 1.0
27
- gradient_accumulation_steps: 4
28
  validation_max_length: 128
29
  label_smoothing: 0.1
30
  task_weights:
31
  summarization: 1.0
32
- emotion: 1.0
33
- topic: 1.0
34
- max_train_samples: 1000
35
- max_val_samples: 100
 
 
 
 
 
 
 
 
 
1
  # Development/Testing Configuration for FLAN-T5-base
2
  # Fast iteration for debugging and testing changes
3
+ # VRAM Usage: ~8-9GB peak (12GB available)
4
+ # Training time: ~10-15 minutes on RTX 4070 12GB
5
  # Use: python scripts/train.py training=dev
6
 
7
  dataloader:
8
+ batch_size: 5 # Conservative for 12GB VRAM
9
  shuffle: true
10
+ num_workers: 4
11
  pin_memory: true
12
  persistent_workers: true
13
+ prefetch_factor: 2
14
 
15
  optimizer:
16
  name: adamw
17
+ lr: 5.0e-5 # Higher LR for faster convergence in dev
18
  weight_decay: 0.01
19
+ eps: 1.0e-8
20
+ betas: [0.9, 0.999]
21
 
22
  scheduler:
23
  name: cosine
24
+ warmup_steps: 100 # ~2% of training steps for smoother start
25
 
26
  trainer:
27
+ max_epochs: 3
28
  gradient_clip_norm: 1.0
29
+ gradient_accumulation_steps: 12 # Effective batch: 60 (5*12)
30
  validation_max_length: 128
31
  label_smoothing: 0.1
32
  task_weights:
33
  summarization: 1.0
34
+ emotion: 0.5
35
+ topic: 0.5
36
+ max_train_samples: 3000 # 3k samples for better validation
37
+ max_val_samples: 300
38
+ early_stopping_patience: 5 # Stop if no improvement
39
+ log_grad_norm_frequency: 100
40
+
41
+ # Disable compile for faster startup in dev
42
+ compile_encoder: false
43
+ compile_decoder: false
44
+
45
+ tokenizer_max_length: 512
configs/training/full.yaml CHANGED
@@ -1,33 +1,44 @@
1
  # Full Training Configuration for FLAN-T5-base
2
- # Complete training run on all data
3
- # Training time: ~4-6 hours on RTX 4070 12GB with inductor
 
4
  # Use: python scripts/train.py training=full
5
 
6
  dataloader:
7
- batch_size: 14
8
  shuffle: true
9
- num_workers: 6
10
  pin_memory: true
11
  persistent_workers: true
12
- prefetch_factor: 4
13
 
14
  optimizer:
15
  name: adamw
16
- lr: 2.0e-5
17
  weight_decay: 0.01
18
  eps: 1.0e-6
 
19
 
20
  scheduler:
21
  name: cosine
22
- warmup_steps: 1000
23
 
24
  trainer:
25
- max_epochs: 3
26
  gradient_clip_norm: 1.0
27
- gradient_accumulation_steps: 3 # Effective batch = 42
28
  validation_max_length: 128
29
  label_smoothing: 0.1
30
  task_weights:
31
- summarization: 1.0
32
  emotion: 1.0
33
- topic: 1.0
 
 
 
 
 
 
 
 
 
 
1
  # Full Training Configuration for FLAN-T5-base
2
+ # Complete training run on all available data
3
+ # VRAM Usage: ~10-11GB peak (12GB available)
4
+ # Training time: ~3-4 hours on RTX 4070 12GB with torch.compile
5
  # Use: python scripts/train.py training=full
6
 
7
  dataloader:
8
+ batch_size: 6 # Conservative for 12GB VRAM with torch.compile overhead
9
  shuffle: true
10
+ num_workers: 4
11
  pin_memory: true
12
  persistent_workers: true
13
+ prefetch_factor: 2
14
 
15
  optimizer:
16
  name: adamw
17
+ lr: 3.0e-5 # Higher LR with larger effective batch
18
  weight_decay: 0.01
19
  eps: 1.0e-6
20
+ betas: [0.9, 0.999]
21
 
22
  scheduler:
23
  name: cosine
24
+ warmup_steps: 1000 # ~1% warmup for stability
25
 
26
  trainer:
27
+ max_epochs: 8 # More epochs for full dataset
28
  gradient_clip_norm: 1.0
29
+ gradient_accumulation_steps: 16 # Effective batch: 96 (6*16)
30
  validation_max_length: 128
31
  label_smoothing: 0.1
32
  task_weights:
33
+ summarization: 1.5 # Prioritize summarization quality
34
  emotion: 1.0
35
+ topic: 0.8
36
+ # No max_samples - use full dataset
37
+ early_stopping_patience: 3 # Stop if plateaus
38
+ log_grad_norm_frequency: 100
39
+
40
+ # Enable torch.compile for maximum speed
41
+ compile_encoder: true
42
+ compile_decoder: true
43
+
44
+ tokenizer_max_length: 512
configs/training/medium.yaml CHANGED
@@ -1,35 +1,45 @@
1
  # Medium Configuration for FLAN-T5-base
2
  # Balanced approach - good results in reasonable time
3
- # Training time: ~1.5-2 hours on RTX 4070 12GB with inductor
 
4
  # Use: python scripts/train.py training=medium
5
 
6
  dataloader:
7
- batch_size: 14
8
  shuffle: true
9
- num_workers: 6
10
  pin_memory: true
11
  persistent_workers: true
12
- prefetch_factor: 4
13
 
14
  optimizer:
15
  name: adamw
16
- lr: 3.0e-5
17
  weight_decay: 0.01
18
  eps: 1.0e-6
 
19
 
20
  scheduler:
21
  name: cosine
22
- warmup_steps: 300
23
 
24
  trainer:
25
- max_epochs: 3
26
  gradient_clip_norm: 1.0
27
- gradient_accumulation_steps: 3 # Effective batch = 42
28
  validation_max_length: 128
29
  label_smoothing: 0.1
30
  task_weights:
31
- summarization: 1.0
32
- emotion: 1.0
33
- topic: 1.0
34
- max_train_samples: 50000
35
- max_val_samples: 5000
 
 
 
 
 
 
 
 
 
1
  # Medium Configuration for FLAN-T5-base
2
  # Balanced approach - good results in reasonable time
3
+ # VRAM Usage: ~9-10GB peak (12GB available)
4
+ # Training time: ~45-60 minutes on RTX 4070 12GB with torch.compile
5
  # Use: python scripts/train.py training=medium
6
 
7
  dataloader:
8
+ batch_size: 6 # Conservative for 12GB VRAM with torch.compile
9
  shuffle: true
10
+ num_workers: 4
11
  pin_memory: true
12
  persistent_workers: true
13
+ prefetch_factor: 2
14
 
15
  optimizer:
16
  name: adamw
17
+ lr: 3.0e-5 # Balanced LR for quality
18
  weight_decay: 0.01
19
  eps: 1.0e-6
20
+ betas: [0.9, 0.999]
21
 
22
  scheduler:
23
  name: cosine
24
+ warmup_steps: 500 # ~2% warmup for 25k steps
25
 
26
  trainer:
27
+ max_epochs: 5 # More epochs for better convergence
28
  gradient_clip_norm: 1.0
29
+ gradient_accumulation_steps: 12 # Effective batch: 72 (6*12)
30
  validation_max_length: 128
31
  label_smoothing: 0.1
32
  task_weights:
33
+ summarization: 1.2 # Slightly prioritize summarization
34
+ emotion: 0.8
35
+ topic: 0.8
36
+ max_train_samples: 25000 # 25k samples - good balance
37
+ max_val_samples: 2500
38
+ early_stopping_patience: 3
39
+ log_grad_norm_frequency: 100
40
+
41
+ # Enable torch.compile for 1.5-2x speedup
42
+ compile_encoder: true
43
+ compile_decoder: true
44
+
45
+ tokenizer_max_length: 512
docs/api.md DELETED
@@ -1,79 +0,0 @@
1
- # API & CLI Documentation
2
-
3
- ## FastAPI Service
4
- The FastAPI application is defined in `src/api/app.py` and wires routes from
5
- `src/api/routes.py`. All dependencies resolve through `src/api/dependencies.py`, which lazily constructs the shared inference pipeline.
6
-
7
- ### POST `/summarize`
8
- - **Request Body** (`SummaryRequest`):
9
- ```json
10
- {
11
- "text": "Your input document"
12
- }
13
- ```
14
- - **Response** (`SummaryResponse`):
15
- ```json
16
- {
17
- "summary": "Generated abstractive summary",
18
- "emotion_labels": ["joy", "surprise"],
19
- "emotion_scores": [0.91, 0.63],
20
- "topic": "news",
21
- "topic_confidence": 0.82
22
- }
23
- ```
24
- - **Behaviour:**
25
- 1. Text is preprocessed through `TextPreprocessor` (with optional sklearn transformer if configured).
26
- 2. The multitask model generates a summary via greedy decoding.
27
- 3. Emotion and topic heads produce logits which are converted to probabilities and mapped to
28
- human-readable labels using `artifacts/labels.json`.
29
- 4. Results are returned as structured JSON suitable for a future Gradio interface.
30
-
31
- ### Error Handling
32
- - If the checkpoint or label metadata is missing, the dependency raises an HTTP 503 error with
33
- an explanatory message.
34
- - Validation errors (missing `text`) are handled automatically by FastAPI/Pydantic.
35
-
36
- ## Command-Line Interface
37
- `scripts/inference.py` provides a CLI that mirrors the API behaviour.
38
-
39
- ### Usage
40
- ```bash
41
- python scripts/inference.py "Document to analyse" \
42
- --checkpoint checkpoints/best.pt \
43
- --labels artifacts/labels.json \
44
- --tokenizer artifacts/hf_tokenizer \
45
- --model-config configs/model/base.yaml \
46
- --device cpu
47
- ```
48
-
49
- Options:
50
- - `text` – zero or more positional arguments. If omitted, use `--file` to point to a newline
51
- delimited text file.
52
- - `--file` – optional path containing one text per line.
53
- - `--checkpoint` – path to the trained model weights.
54
- - `--labels` – JSON containing emotion/topic vocabularies (defaults to `artifacts/labels.json`).
55
- - `--tokenizer` – optional tokenizer directory; defaults to the exported artifact if present.
56
- - `--model-config` – YAML describing the architecture.
57
- - `--device` – `cpu` or `cuda`. Passing `cuda` attempts to run inference on GPU.
58
- - `--summary-max-length` – overrides the default maximum generation length.
59
-
60
- ### Output
61
- The CLI prints a JSON array where each entry contains the original text, summary, emotion labels
62
- with scores, and topic prediction. This format is identical to the REST response, facilitating
63
- integration tests and future Gradio UI rendering.
64
-
65
- ## Future Gradio UI
66
- - The planned UI will call the same inference pipeline and display results interactively.
67
- - Given the response schema, the UI can show:
68
- - Generated summary text.
69
- - Emotion chips with probability bars.
70
- - Topic confidence gauges.
71
- - Placeholder panel for attention heatmaps and explanations.
72
- - Once implemented, documentation updates will add a `docs/ui.md` section and screenshots under
73
- `docs/images/`.
74
-
75
- ## Testing
76
- - `tests/test_api/test_routes.py` stubs the pipeline to ensure response fields and dependency
77
- overrides behave as expected.
78
- - `tests/test_inference/test_pipeline.py` validates pipeline methods end-to-end with dummy models,
79
- guaranteeing API and CLI consumers receive consistent payload shapes.
docs/architecture.md CHANGED
@@ -1,6 +1,7 @@
1
  # LexiMind Architecture
2
 
3
  ## Overview
 
4
  LexiMind couples a from-scratch Transformer implementation with a modern data and inference stack. The project consists of three major layers:
5
 
6
  1. **Data & Preprocessing** – lightweight text cleaning built on top of scikit-learn
@@ -15,6 +16,7 @@ LexiMind couples a from-scratch Transformer implementation with a modern data an
15
  The custom Transformer is designed with **modern architectural choices** while maintaining compatibility with pre-trained weights from Google's **FLAN-T5**.
16
 
17
  ### Architecture Highlights
 
18
  - **Pre-Layer Normalization (Pre-LN):** RMSNorm applied *before* each sublayer for stable training
19
  - **RMSNorm:** More efficient than LayerNorm (no mean computation, no bias parameters)
20
  - **FlashAttention:** Via PyTorch 2.0's `F.scaled_dot_product_attention` for O(N) memory
@@ -22,7 +24,9 @@ The custom Transformer is designed with **modern architectural choices** while m
22
  - **Multi-Head Attention:** 12 heads with optional LoRA adapters and RoPE support
23
 
24
  ### Weight Loading from FLAN-T5
 
25
  The `factory.py` module loads weights from FLAN-T5-base, which uses a compatible Pre-LN architecture:
 
26
  - **Token embeddings:** Shared between encoder and decoder
27
  - **Attention projections:** Q, K, V, O weights (bias initialized to zero since T5 has no attention bias)
28
  - **FFN weights:** `wi_1` → `linear1`, `wo` → `linear2` (T5 uses gated FFN; we use the up/down projections)
@@ -32,6 +36,7 @@ The `factory.py` module loads weights from FLAN-T5-base, which uses a compatible
32
  **Note:** T5 uses *relative position bias* computed in attention, not absolute embeddings. Our learned positional embeddings are randomly initialized and train quickly during fine-tuning.
33
 
34
  ### File Structure
 
35
  - `src/models/encoder.py` – TransformerEncoder with Pre-LN RMSNorm blocks
36
  - `src/models/decoder.py` – TransformerDecoder with KV-cache for efficient generation
37
  - `src/models/attention.py` – Multi-Head Attention with FlashAttention, LoRA, and RoPE support
@@ -40,16 +45,19 @@ The `factory.py` module loads weights from FLAN-T5-base, which uses a compatible
40
  - `src/models/factory.py` – Builds models and loads FLAN-T5 weights
41
 
42
  ## Data, Tokenization, and Preprocessing
 
43
  - `src/data/tokenization.py` wraps `AutoTokenizer` (configured for FLAN-T5) to provide tensor-aware batching and helper utilities for decoder input shifting.
44
  - `src/data/preprocessing.py` introduces `TextPreprocessor`, layering a `BasicTextCleaner` with optional scikit-learn transformers.
45
  - `src/data/dataset.py` and `src/data/dataloader.py` define strongly typed dataset containers and collators.
46
 
47
  ### T5 Tokenizer Differences
 
48
  - **Vocab size:** 32,128 tokens (SentencePiece)
49
  - **Special tokens:** pad=0, eos=1 (no explicit BOS; decoder starts with pad token)
50
  - **Subword tokenization:** Unigram-based (vs BART's BPE)
51
 
52
  ## Training Pipeline
 
53
  - `src/training/trainer.py` coordinates multi-task optimization with:
54
  - Mixed precision training (bfloat16 on Ampere/Ada GPUs)
55
  - Gradient accumulation for larger effective batch sizes
@@ -58,12 +66,14 @@ The `factory.py` module loads weights from FLAN-T5-base, which uses a compatible
58
  - Metrics in `src/training/metrics.py` include accuracy, multi-label F1, and ROUGE-like overlap
59
 
60
  ## Inference & Serving
 
61
  - `src/inference/pipeline.py` exposes summarization, emotion, and topic predictions with shared pre-processing, generation, and thresholding logic.
62
  - `src/inference/factory.py` rebuilds the full pipeline using the exported tokenizer artifact
63
  - The CLI (`scripts/inference.py`) drives the pipeline from the command line
64
  - Gradio demo (`scripts/demo_gradio.py`) provides a web interface
65
 
66
  ## Key Decisions
 
67
  - **Custom Transformer + Pre-trained Weights:** Building from scratch demonstrates deep understanding while leveraging FLAN-T5's language knowledge
68
  - **Pre-LN RMSNorm:** Modern architecture used by LLaMA, T5 v1.1, and other 2023-2025 models
69
  - **Tokenizer Artifact Preference:** Inference favors `artifacts/hf_tokenizer` for reproducibility
 
1
  # LexiMind Architecture
2
 
3
  ## Overview
4
+
5
  LexiMind couples a from-scratch Transformer implementation with a modern data and inference stack. The project consists of three major layers:
6
 
7
  1. **Data & Preprocessing** – lightweight text cleaning built on top of scikit-learn
 
16
  The custom Transformer is designed with **modern architectural choices** while maintaining compatibility with pre-trained weights from Google's **FLAN-T5**.
17
 
18
  ### Architecture Highlights
19
+
20
  - **Pre-Layer Normalization (Pre-LN):** RMSNorm applied *before* each sublayer for stable training
21
  - **RMSNorm:** More efficient than LayerNorm (no mean computation, no bias parameters)
22
  - **FlashAttention:** Via PyTorch 2.0's `F.scaled_dot_product_attention` for O(N) memory
 
24
  - **Multi-Head Attention:** 12 heads with optional LoRA adapters and RoPE support
25
 
26
  ### Weight Loading from FLAN-T5
27
+
28
  The `factory.py` module loads weights from FLAN-T5-base, which uses a compatible Pre-LN architecture (a simplified mapping sketch follows the list below):
29
+
30
  - **Token embeddings:** Shared between encoder and decoder
31
  - **Attention projections:** Q, K, V, O weights (bias initialized to zero since T5 has no attention bias)
32
  - **FFN weights:** `wi_1` → `linear1`, `wo` → `linear2` (T5 uses gated FFN; we use the up/down projections)
 
36
  **Note:** T5 uses *relative position bias* computed in attention, not absolute embeddings. Our learned positional embeddings are randomly initialized and train quickly during fine-tuning.
37
 
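A simplified sketch of the mapping described above, using the standard Hugging Face state-dict keys for `google/flan-t5-base`; the custom-side attribute names in the trailing comment are hypothetical:

```python
from transformers import T5ForConditionalGeneration

t5 = T5ForConditionalGeneration.from_pretrained("google/flan-t5-base")
sd = t5.state_dict()

# Shared token embeddings, used by both encoder and decoder in T5.
shared_emb = sd["shared.weight"]  # shape (32128, 768)

# Self-attention projections for encoder layer 0 (T5 stores these without bias).
q0 = sd["encoder.block.0.layer.0.SelfAttention.q.weight"]
o0 = sd["encoder.block.0.layer.0.SelfAttention.o.weight"]

# Gated FFN: wi_0/wi_1 are the gate/up projections, wo is the down projection.
up0 = sd["encoder.block.0.layer.1.DenseReluDense.wi_1.weight"]   # -> linear1
down0 = sd["encoder.block.0.layer.1.DenseReluDense.wo.weight"]   # -> linear2

# Copying into the custom modules would then look roughly like:
#   custom_encoder.layers[0].attn.q_proj.weight.data.copy_(q0)
# (q_proj and friends are illustrative names; see src/models/factory.py for the real mapping.)
```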
38
  ### File Structure
39
+
40
  - `src/models/encoder.py` – TransformerEncoder with Pre-LN RMSNorm blocks
41
  - `src/models/decoder.py` – TransformerDecoder with KV-cache for efficient generation (toy cache sketch below)
42
  - `src/models/attention.py` – Multi-Head Attention with FlashAttention, LoRA, and RoPE support
 
45
  - `src/models/factory.py` – Builds models and loads FLAN-T5 weights
46
 
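To make the KV-cache idea concrete, here is a toy single-head version (illustrative only; the real decoder caches per-layer, per-head tensors):

```python
import torch

# At each decode step, append the new key/value and attend over everything cached so far,
# so past keys/values are never recomputed.
d = 8
cache_k: list[torch.Tensor] = []
cache_v: list[torch.Tensor] = []


def step(x_t: torch.Tensor, wq, wk, wv) -> torch.Tensor:
    q = x_t @ wq
    cache_k.append(x_t @ wk)
    cache_v.append(x_t @ wv)
    K = torch.stack(cache_k)  # (t, d): keys for all positions generated so far
    V = torch.stack(cache_v)
    attn = torch.softmax(q @ K.T / d**0.5, dim=-1)
    return attn @ V


wq, wk, wv = (torch.randn(d, d) for _ in range(3))
outputs = [step(torch.randn(d), wq, wk, wv) for _ in range(4)]  # 4 decode steps
```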
47
  ## Data, Tokenization, and Preprocessing
48
+
49
  - `src/data/tokenization.py` wraps `AutoTokenizer` (configured for FLAN-T5) to provide tensor-aware batching and helper utilities for decoder input shifting.
50
  - `src/data/preprocessing.py` introduces `TextPreprocessor`, layering a `BasicTextCleaner` with optional scikit-learn transformers.
51
  - `src/data/dataset.py` and `src/data/dataloader.py` define strongly typed dataset containers and collators.
52
 
53
  ### T5 Tokenizer Differences
54
+
55
  - **Vocab size:** 32,128 tokens (SentencePiece)
56
  - **Special tokens:** pad=0, eos=1 (no explicit BOS; decoder starts with pad token)
57
  - **Subword tokenization:** Unigram-based (vs BART's BPE)
58
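A small, self-contained illustration of these conventions; `Tokenizer.prepare_decoder_inputs` in `src/data/tokenization.py` presumably performs an equivalent shift:

```python
import torch
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("google/flan-t5-base")
assert tok.pad_token_id == 0 and tok.eos_token_id == 1

labels = tok("a short summary", return_tensors="pt").input_ids  # ends with </s> (id 1)

# No BOS token in T5: decoder inputs are the labels shifted right,
# with the pad token (id 0) doubling as the decoder start token.
decoder_input_ids = torch.cat([torch.zeros_like(labels[:, :1]), labels[:, :-1]], dim=-1)
```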
 
59
  ## Training Pipeline
60
+
61
  - `src/training/trainer.py` coordinates multi-task optimization with:
62
  - Mixed precision training (bfloat16 on Ampere/Ada GPUs)
63
  - Gradient accumulation for larger effective batch sizes
 
66
  - Metrics in `src/training/metrics.py` include accuracy, multi-label F1, and ROUGE-like overlap
67
 
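A condensed sketch of the accumulation-plus-mixed-precision pattern noted above; the real `Trainer` adds per-task losses, weighting, scheduling, and logging, and the model and data here are stand-ins:

```python
import torch
from torch import nn

device = "cuda" if torch.cuda.is_available() else "cpu"
model = nn.Linear(768, 768).to(device)             # stand-in for the multitask model
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-5, weight_decay=0.01)
loader = [torch.randn(6, 768) for _ in range(24)]  # dummy batches
accum_steps = 12                                   # effective batch = 6 * 12

for step, batch in enumerate(loader):
    with torch.autocast(device_type=device, dtype=torch.bfloat16):
        loss = model(batch.to(device)).pow(2).mean() / accum_steps
    loss.backward()
    if (step + 1) % accum_steps == 0:
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        optimizer.zero_grad(set_to_none=True)
```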
68
  ## Inference & Serving
69
+
70
  - `src/inference/pipeline.py` exposes summarization, emotion, and topic predictions with shared pre-processing, generation, and thresholding logic (a thresholding sketch follows this list).
71
  - `src/inference/factory.py` rebuilds the full pipeline using the exported tokenizer artifact
72
  - The CLI (`scripts/inference.py`) drives the pipeline from the command line
73
  - Gradio demo (`scripts/demo_gradio.py`) provides a web interface
74
 
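A self-contained sketch of the multi-label thresholding step the emotion path performs over sigmoid scores (label names, logits, and the threshold here are illustrative, not values from the repository):

```python
import torch

label_names = ["joy", "sadness", "anger", "fear", "love", "surprise"]  # illustrative
logits = torch.tensor([2.1, -1.3, -0.4, -2.0, 1.5, -0.7])

probs = torch.sigmoid(logits)
threshold = 0.5
predicted = [(name, float(p)) for name, p in zip(label_names, probs) if p >= threshold]
print(predicted)  # [("joy", 0.89...), ("love", 0.81...)]
```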
75
  ## Key Decisions
76
+
77
  - **Custom Transformer + Pre-trained Weights:** Building from scratch demonstrates deep understanding while leveraging FLAN-T5's language knowledge
78
  - **Pre-LN RMSNorm:** Modern architecture used by LLaMA, T5 v1.1, and other 2023-2025 models
79
  - **Tokenizer Artifact Preference:** Inference favors `artifacts/hf_tokenizer` for reproducibility
docs/training.md DELETED
@@ -1,80 +0,0 @@
1
- # Training Procedure
2
-
3
- ## Data Sources
4
- - **Summarization** – expects JSONL files with `source` and `summary` fields under
5
- `data/processed/summarization`.
6
- - **Emotion Classification** – multi-label samples loaded from JSONL files with
7
- `text` and `emotions` arrays. The dataset owns a `MultiLabelBinarizer` for consistent encoding.
8
- - **Topic Classification** – single-label categorical samples with `text` and `topic` fields, encoded via `LabelEncoder`.
9
-
10
- Paths and tokenizer defaults are configured in `configs/data/datasets.yaml`. The tokenizer section chooses the Hugging Face backbone (`google/flan-t5-base` by default) and maximum length. Gutenberg book downloads are controlled via the `downloads.books` list (each entry includes `name`, `url`, and `output`).
11
-
12
- ## Dataloaders & Collators
13
- - `SummarizationCollator` encodes encoder/decoder inputs, prepares decoder input IDs via `Tokenizer.prepare_decoder_inputs`, and masks padding tokens with `-100` for loss computation. Note: FLAN-T5 uses `pad_token_id=0` and `decoder_start_token_id=0`.
14
- - `EmotionCollator` applies the dataset's `MultiLabelBinarizer`, returning dense float tensors suitable for `BCEWithLogitsLoss`.
15
- - `TopicCollator` emits integer class IDs via the dataset's `LabelEncoder` for `CrossEntropyLoss`.
16
-
17
- These collators keep all tokenization centralized, reducing duplication and making it easy to swap in additional sklearn transformations through `TextPreprocessor` should we wish to extend cleaning or normalization.
18
-
19
- ## Model Assembly
20
- - `src/models/factory.build_multitask_model` rebuilds the encoder, decoder, and heads from the tokenizer metadata and YAML config. This factory is used both during training and inference to eliminate drift between environments.
21
- - Pretrained weights are loaded from FLAN-T5 using `_load_t5_weights()`, which transfers:
22
- - Shared token embeddings (with proper scaling)
23
- - Attention projections (q, k, v, o) for all encoder/decoder layers
24
- - FFN weights (wi_0, wi_1 for gated activation, wo for output)
25
- - Layer normalization parameters (mapped from T5's RMSNorm)
26
- - The model wraps:
27
- - Transformer encoder/decoder stacks with **Pre-LN RMSNorm** architecture.
28
- - LM head tied to decoder embeddings for summarization.
29
- - Mean-pooled classification heads for emotion and topic tasks.
30
-
31
- ## Optimisation Loop
32
- - `src/training/trainer.Trainer` orchestrates multi-task training.
33
- - Cross-entropy is used for summarization (seq2seq logits vs. shifted labels).
34
- - `BCEWithLogitsLoss` handles multi-label emotions.
35
- - `CrossEntropyLoss` handles topic classification.
36
- - Gradient clipping ensures stability, and per-task weights can be configured via
37
- `TrainerConfig.task_weights` to balance gradients if needed.
38
- - Metrics tracked per task:
39
- - **Summarization** – ROUGE-like overlap metric (`training.metrics.rouge_like`).
40
- - **Emotion** – micro F1 score for multi-label predictions.
41
- - **Topic** – categorical accuracy.
42
-
43
- ## Checkpoints & Artifacts
44
- - `src/utils/io.save_state` stores model weights; checkpoints live under `checkpoints/`.
45
- - `artifacts/labels.json` captures the ordered emotion/topic vocabularies immediately after
46
- training. This file is required for inference so class indices map back to human-readable labels.
47
- - The tokenizer is exported to `artifacts/hf_tokenizer/` for reproducible vocabularies using `scripts/export_tokenizer.py`.
48
-
49
- ## Running Training
50
- 1. Ensure processed datasets are available (see `data/processed/` structure).
51
- 2. Export the FLAN-T5 tokenizer: `python scripts/export_tokenizer.py`
52
- 3. Choose a configuration (e.g., `configs/training/dev.yaml`) for hyperparameters and data splits.
53
- 4. Instantiate the tokenizer via `TokenizerConfig` and build datasets/dataloaders.
54
- 5. Use `build_multitask_model` to construct the model with FLAN-T5 weights, create an optimizer, and run
55
- `Trainer.fit(train_loaders, val_loaders)`.
56
- 6. Save checkpoints and update `artifacts/labels.json` with the dataset label order.
57
-
58
- ```bash
59
- # Quick start
60
- python scripts/export_tokenizer.py # Export FLAN-T5 tokenizer
61
- python scripts/train.py training=dev # Run dev training (2 epochs)
62
- python scripts/train.py training=medium # Run medium training (5 epochs)
63
- python scripts/train.py training=full # Run full training (10 epochs)
64
- ```
65
-
66
- ## Why FLAN-T5?
67
- LexiMind's custom Transformer uses **Pre-LN (normalization before sublayers)** with **RMSNorm**. This modern architecture choice provides:
68
- - Better gradient flow during training
69
- - Improved training stability
70
- - Faster convergence
71
-
72
- FLAN-T5 uses the same Pre-LN RMSNorm architecture, making weight transfer straightforward. Previously used BART (Post-LN LayerNorm) had a fundamental architectural mismatch that caused training issues.
73
-
74
- > **Note:** T5's relative position bias is NOT transferred. The model uses learned positional encodings which train from scratch. This is fine since positional information is task-specific.
75
-
76
- ## Future Enhancements
77
- - Integrate curriculum scheduling or task-balanced sampling once empirical results dictate.
78
- - Capture attention maps during training to support visualization in the planned Gradio UI.
79
- - Leverage the optional `sklearn_transformer` hook in `TextPreprocessor` for lemmatization or domain-specific normalization when datasets require it.
80
- - Experiment with FLAN-T5-large for improved performance on longer sequences.
 
outputs/training_history.json CHANGED
@@ -1,59 +1,59 @@
1
  {
2
- "train_epoch_1": {
3
- "summarization_loss": 5.035701327740604,
4
- "summarization_rouge_like": 0.16390742100245742,
5
- "emotion_loss": 0.21049204537547025,
6
- "emotion_f1": 0.002655381929628719,
7
- "topic_loss": 1.176912516972419,
8
- "topic_accuracy": 0.6581478229164939,
9
- "total_loss": 6.423106049642868,
10
- "epoch": 1.0
11
  },
12
- "val_epoch_1": {
13
- "summarization_loss": 4.6882993674363105,
14
- "summarization_rouge_like": 0.19405199466966144,
15
- "emotion_loss": 0.15183634538985658,
16
- "emotion_f1": 0.0016098967067287486,
17
- "topic_loss": 0.8788343331143526,
18
- "topic_accuracy": 0.7251652262328394,
19
- "epoch": 1.0
20
  },
21
- "train_epoch_2": {
22
- "summarization_loss": 4.561023824777751,
23
- "summarization_rouge_like": 0.20945581532076613,
24
- "emotion_loss": 0.14958151845580364,
25
- "emotion_f1": 0.008022325540815077,
26
- "topic_loss": 0.8585619787599033,
27
- "topic_accuracy": 0.7299605100316837,
28
- "total_loss": 5.569167470253677,
29
- "epoch": 2.0
30
  },
31
- "val_epoch_2": {
32
- "summarization_loss": 4.335443331423179,
33
- "summarization_rouge_like": 0.2383154143354784,
34
- "emotion_loss": 0.1478777239331147,
35
- "emotion_f1": 0.010150822387259202,
36
- "topic_loss": 0.841049696600522,
37
- "topic_accuracy": 0.7359938993390932,
38
- "epoch": 2.0
39
  },
40
- "train_epoch_3": {
41
- "summarization_loss": 4.332563984521343,
42
- "summarization_rouge_like": 0.24358268281949097,
43
- "emotion_loss": 0.14520242059475916,
44
- "emotion_f1": 0.026584760984350638,
45
- "topic_loss": 0.8084657974773926,
46
- "topic_accuracy": 0.7434995609372882,
47
- "total_loss": 5.286232347914138,
48
- "epoch": 3.0
49
  },
50
- "val_epoch_3": {
51
- "summarization_loss": 4.0994785383502785,
52
- "summarization_rouge_like": 0.2839536633314319,
53
- "emotion_loss": 0.14214695994858215,
54
- "emotion_f1": 0.028164719230763854,
55
- "topic_loss": 0.8218616072552484,
56
- "topic_accuracy": 0.7413319776309091,
57
- "epoch": 3.0
58
  }
59
  }
 
1
  {
2
+ "train_epoch_6": {
3
+ "summarization_loss": 3.2071112584752606,
4
+ "summarization_rouge_like": 0.41666206128984185,
5
+ "emotion_loss": 0.13381094067425187,
6
+ "emotion_f1": 0.1527181073975268,
7
+ "topic_loss": 0.6847172836312407,
8
+ "topic_accuracy": 0.7834830254758819,
9
+ "total_loss": 5.492251664781721,
10
+ "epoch": 6.0
11
  },
12
+ "val_epoch_6": {
13
+ "summarization_loss": 2.988837990901862,
14
+ "summarization_rouge_like": 0.4475286348323649,
15
+ "emotion_loss": 0.1262940275061054,
16
+ "emotion_f1": 0.19359053170564663,
17
+ "topic_loss": 0.7910004459155627,
18
+ "topic_accuracy": 0.754854122191724,
19
+ "epoch": 6.0
20
  },
21
+ "train_epoch_7": {
22
+ "summarization_loss": 3.184010818695097,
23
+ "summarization_rouge_like": 0.41903763419721,
24
+ "emotion_loss": 0.12498181367997213,
25
+ "emotion_f1": 0.2043521878681856,
26
+ "topic_loss": 0.6483695249464139,
27
+ "topic_accuracy": 0.796684177822936,
28
+ "total_loss": 5.419693668500609,
29
+ "epoch": 7.0
30
  },
31
+ "val_epoch_7": {
32
+ "summarization_loss": 2.985372142407835,
33
+ "summarization_rouge_like": 0.44758863369550994,
34
+ "emotion_loss": 0.1185748163268729,
35
+ "emotion_f1": 0.2514045691051182,
36
+ "topic_loss": 0.7817700606483663,
37
+ "topic_accuracy": 0.7554132357426027,
38
+ "epoch": 7.0
39
  },
40
+ "train_epoch_8": {
41
+ "summarization_loss": 3.171688149997974,
42
+ "summarization_rouge_like": 0.4206951155149097,
43
+ "emotion_loss": 0.12107599671589805,
44
+ "emotion_f1": 0.2286830931525678,
45
+ "topic_loss": 0.6216138880150013,
46
+ "topic_accuracy": 0.8049539626051729,
47
+ "total_loss": 5.375899340986727,
48
+ "epoch": 8.0
49
  },
50
+ "val_epoch_8": {
51
+ "summarization_loss": 2.984391659270994,
52
+ "summarization_rouge_like": 0.44770155741256373,
53
+ "emotion_loss": 0.11704520378562873,
54
+ "emotion_f1": 0.26809326239605075,
55
+ "topic_loss": 0.7841400383105634,
56
+ "topic_accuracy": 0.7546508081732227,
57
+ "epoch": 8.0
58
  }
59
  }
scripts/demo_gradio.py CHANGED
@@ -4,10 +4,10 @@ Gradio demo for LexiMind multi-task NLP model.
4
  Showcases the model's capabilities across three tasks:
5
  - Summarization: Generates concise summaries of input text
6
  - Emotion Detection: Multi-label emotion classification
7
- - Topic Classification: Categorizes text into news topics
8
 
9
  Author: Oliver Perrin
10
- Date: 2025-12-04
11
  """
12
 
13
  from __future__ import annotations
@@ -38,24 +38,12 @@ logger = get_logger(__name__)
38
 
39
  OUTPUTS_DIR = PROJECT_ROOT / "outputs"
40
  EVAL_REPORT_PATH = OUTPUTS_DIR / "evaluation_report.json"
 
41
 
42
  SAMPLE_TEXTS = [
43
- (
44
- "Artificial intelligence is rapidly transforming technology. "
45
- "Machine learning algorithms process vast amounts of data, identifying "
46
- "patterns with unprecedented accuracy. From healthcare to finance, AI is "
47
- "revolutionizing industries worldwide."
48
- ),
49
- (
50
- "The team's incredible comeback in the final quarter left fans in tears of joy. "
51
- "After trailing by 20 points, they scored three consecutive touchdowns to secure "
52
- "their first championship victory in over a decade."
53
- ),
54
- (
55
- "Global markets tumbled today as investors reacted to rising inflation concerns. "
56
- "The Federal Reserve hinted at potential interest rate hikes, sending shockwaves "
57
- "through technology and banking sectors."
58
- ),
59
  ]
60
 
61
  # --------------- Pipeline Management ---------------
@@ -94,27 +82,62 @@ def get_pipeline():
94
  def analyze(text: str) -> tuple[str, str, str]:
95
  """Run all three tasks and return formatted results."""
96
  if not text or not text.strip():
97
- return "Enter text above", "", ""
98
 
99
  try:
100
  pipe = get_pipeline()
101
 
102
  # Run tasks
103
- summary = pipe.summarize([text], max_length=128)[0].strip() or "(empty)"
104
- emotions = pipe.predict_emotions([text], threshold=0.5)[0]
 
 
 
105
  topic = pipe.predict_topics([text])[0]
106
 
107
- # Format emotions
108
  if emotions.labels:
109
- emotion_str = " • ".join(
110
- f"**{lbl}** ({score:.0%})"
111
- for lbl, score in zip(emotions.labels, emotions.scores, strict=True)
112
- )
 
113
  else:
114
- emotion_str = "No strong emotions detected"
115
 
116
  # Format topic
117
- topic_str = f"**{topic.label}** ({topic.confidence:.0%})"
118
 
119
  return summary, emotion_str, topic_str
120
 
@@ -125,75 +148,138 @@ def analyze(text: str) -> tuple[str, str, str]:
125
 
126
  def load_metrics() -> str:
127
  """Load evaluation metrics and format as markdown."""
128
- if not EVAL_REPORT_PATH.exists():
129
- return "No evaluation report found."
130
 
131
- try:
132
- with open(EVAL_REPORT_PATH) as f:
133
- r = json.load(f)
 
 
 
 
 
 
 
 
134
 
135
- return f"""
136
- ### Overall Performance
 
 
 
 
 
137
 
138
- | Task | Metric | Score |
139
- |------|--------|-------|
140
- | **Emotion** | F1 Macro | **{r["emotion"]["f1_macro"]:.1%}** |
141
- | **Topic** | Accuracy | **{r["topic"]["accuracy"]:.1%}** |
142
- | **Summarization** | ROUGE-Like | {r["summarization"]["rouge_like"]:.1%} |
143
- | **Summarization** | BLEU | {r["summarization"]["bleu"]:.1%} |
144
 
145
- ### Topic Classification (per-class)
146
 
147
- | Category | Precision | Recall | F1 |
148
- |----------|-----------|--------|-----|
149
- | Business | {r["topic"]["classification_report"]["Business"]["precision"]:.1%} | {r["topic"]["classification_report"]["Business"]["recall"]:.1%} | {r["topic"]["classification_report"]["Business"]["f1-score"]:.1%} |
150
- | Sci/Tech | {r["topic"]["classification_report"]["Sci/Tech"]["precision"]:.1%} | {r["topic"]["classification_report"]["Sci/Tech"]["recall"]:.1%} | {r["topic"]["classification_report"]["Sci/Tech"]["f1-score"]:.1%} |
151
- | Sports | {r["topic"]["classification_report"]["Sports"]["precision"]:.1%} | {r["topic"]["classification_report"]["Sports"]["recall"]:.1%} | {r["topic"]["classification_report"]["Sports"]["f1-score"]:.1%} |
152
- | World | {r["topic"]["classification_report"]["World"]["precision"]:.1%} | {r["topic"]["classification_report"]["World"]["recall"]:.1%} | {r["topic"]["classification_report"]["World"]["f1-score"]:.1%} |
153
- """
154
- except Exception as e:
155
- return f"Error loading metrics: {e}"
156
 
157
 
158
  # --------------- Gradio Interface ---------------
159
 
160
  with gr.Blocks(
161
- title="LexiMind Demo",
162
  theme=gr.themes.Soft(),
163
- css=".output-box { min-height: 80px; }",
164
  ) as demo:
165
  gr.Markdown(
166
  """
167
  # 🧠 LexiMind
168
  ### Multi-Task Transformer for Document Analysis
169
 
170
- A custom encoder-decoder Transformer trained on summarization, emotion detection,
171
- and topic classification. Built from scratch with PyTorch.
 
 
172
  """
173
  )
174
 
175
  # --------------- Try It Tab ---------------
176
  with gr.Tab("🚀 Try It"):
177
  with gr.Row():
178
- with gr.Column(scale=2):
179
  text_input = gr.Textbox(
180
- label="Input Text",
181
- lines=5,
182
- placeholder="Enter text to analyze...",
183
  value=SAMPLE_TEXTS[0],
184
  )
 
 
 
 
 
 
 
185
  with gr.Row():
186
- analyze_btn = gr.Button("Analyze", variant="primary", scale=2)
187
- gr.Examples(
188
- examples=[[t] for t in SAMPLE_TEXTS],
189
- inputs=text_input,
190
- label="Examples",
191
- )
 
192
 
193
  with gr.Column(scale=2):
194
- summary_out = gr.Textbox(label="📝 Summary", lines=3, elem_classes="output-box")
195
- emotion_out = gr.Markdown(label="😊 Emotions")
196
- topic_out = gr.Markdown(label="📂 Topic")
 
 
 
 
 
 
 
 
 
 
197
 
198
  analyze_btn.click(
199
  fn=analyze,
@@ -203,9 +289,35 @@ with gr.Blocks(
203
 
204
  # --------------- Metrics Tab ---------------
205
  with gr.Tab("📊 Metrics"):
206
- gr.Markdown(load_metrics())
207
- gr.Markdown("### Confusion Matrix")
208
- gr.Image(str(OUTPUTS_DIR / "topic_confusion_matrix.png"), label="Topic Classification")
 
 
209
 
210
  # --------------- Architecture Tab ---------------
211
  with gr.Tab("🔧 Architecture"):
@@ -213,28 +325,34 @@ with gr.Blocks(
213
  """
214
  ### Model Architecture
215
 
216
- - **Base**: Custom Transformer (encoder-decoder)
217
- - **Initialized from**: FLAN-T5-base weights
218
- - **Encoder**: 6 layers, 768 hidden dim, 12 attention heads
219
- - **Decoder**: 6 layers with cross-attention
220
- - **Task Heads**: Classification heads for emotion/topic
 
 
 
 
 
 
 
 
 
 
 
 
 
221
 
222
- ### Training
223
 
224
- - **Optimizer**: AdamW with cosine LR schedule
225
- - **Mixed Precision**: bfloat16 with TF32
226
- - **Compilation**: torch.compile with inductor backend
 
 
227
  """
228
  )
229
- with gr.Row():
230
- gr.Image(
231
- str(OUTPUTS_DIR / "attention_visualization.png"),
232
- label="Self-Attention Pattern",
233
- )
234
- gr.Image(
235
- str(OUTPUTS_DIR / "positional_encoding_heatmap.png"),
236
- label="Positional Encodings",
237
- )
238
 
239
  # --------------- About Tab ---------------
240
  with gr.Tab("ℹ️ About"):
@@ -242,22 +360,28 @@ with gr.Blocks(
242
  """
243
  ### About LexiMind
244
 
245
- LexiMind is a multi-task NLP model designed to demonstrate end-to-end
246
- machine learning engineering skills:
 
 
 
 
 
 
 
247
 
248
- - **Custom Transformer** implementation from scratch
249
- - **Multi-task learning** with shared encoder
250
- - **Production-ready** inference pipeline
251
- - **Comprehensive evaluation** with multiple metrics
252
 
253
  ### Links
254
 
255
  - 🔗 [GitHub Repository](https://github.com/OliverPerrin/LexiMind)
256
- - 🤗 [HuggingFace Space](https://huggingface.co/spaces/OliverPerrin/LexiMind)
257
 
258
- ### Author
259
 
260
- **Oliver Perrin** - Machine Learning Engineer
261
  """
262
  )
263
 
 
4
  Showcases the model's capabilities across three tasks:
5
  - Summarization: Generates concise summaries of input text
6
  - Emotion Detection: Multi-label emotion classification
7
+ - Topic Classification: Categorizes text into topics
8
 
9
  Author: Oliver Perrin
10
+ Date: 2025-12-05
11
  """
12
 
13
  from __future__ import annotations
 
38
 
39
  OUTPUTS_DIR = PROJECT_ROOT / "outputs"
40
  EVAL_REPORT_PATH = OUTPUTS_DIR / "evaluation_report.json"
41
+ TRAINING_HISTORY_PATH = OUTPUTS_DIR / "training_history.json"
42
 
43
  SAMPLE_TEXTS = [
44
+ "Global markets tumbled today as investors reacted to rising inflation concerns. The Federal Reserve hinted at potential interest rate hikes, sending shockwaves through technology and banking sectors. Analysts predict continued volatility as economic uncertainty persists.",
45
+ "Scientists at MIT have developed a breakthrough quantum computing chip that operates at room temperature. This advancement could revolutionize drug discovery, cryptography, and artificial intelligence. The research team published their findings in Nature.",
46
+ "The championship game ended in dramatic fashion as the underdog team scored in the final seconds to secure victory. Fans rushed the field in celebration, marking the team's first title in 25 years.",
 
 
 
 
 
 
 
 
 
 
 
 
 
47
  ]
48
 
49
  # --------------- Pipeline Management ---------------
 
82
  def analyze(text: str) -> tuple[str, str, str]:
83
  """Run all three tasks and return formatted results."""
84
  if not text or not text.strip():
85
+ return "Please enter text above to analyze.", "", ""
86
 
87
  try:
88
  pipe = get_pipeline()
89
 
90
  # Run tasks
91
+ summary = pipe.summarize([text], max_length=128)[0].strip()
92
+ if not summary:
93
+ summary = "(Unable to generate summary)"
94
+
95
+ emotions = pipe.predict_emotions([text], threshold=0.3)[0] # Lower threshold
96
  topic = pipe.predict_topics([text])[0]
97
 
98
+ # Format emotions with emoji
99
+ emotion_emoji = {
100
+ "joy": "😊",
101
+ "love": "❤️",
102
+ "anger": "😠",
103
+ "fear": "😨",
104
+ "sadness": "😢",
105
+ "surprise": "😲",
106
+ "neutral": "😐",
107
+ "admiration": "🤩",
108
+ "amusement": "😄",
109
+ "annoyance": "😤",
110
+ "approval": "👍",
111
+ "caring": "🤗",
112
+ "confusion": "😕",
113
+ "curiosity": "🤔",
114
+ "desire": "😍",
115
+ "disappointment": "😞",
116
+ "disapproval": "👎",
117
+ "disgust": "🤢",
118
+ "embarrassment": "😳",
119
+ "excitement": "🎉",
120
+ "gratitude": "🙏",
121
+ "grief": "😭",
122
+ "nervousness": "��",
123
+ "optimism": "🌟",
124
+ "pride": "🦁",
125
+ "realization": "💡",
126
+ "relief": "😌",
127
+ "remorse": "😔",
128
+ }
129
+
130
  if emotions.labels:
131
+ emotion_parts = []
132
+ for lbl, score in zip(emotions.labels[:5], emotions.scores[:5], strict=False):
133
+ emoji = emotion_emoji.get(lbl.lower(), "•")
134
+ emotion_parts.append(f"{emoji} **{lbl.title()}** ({score:.0%})")
135
+ emotion_str = "\n".join(emotion_parts)
136
  else:
137
+ emotion_str = "😐 No strong emotions detected"
138
 
139
  # Format topic
140
+ topic_str = f"**{topic.label}**\n\nConfidence: {topic.confidence:.0%}"
141
 
142
  return summary, emotion_str, topic_str
143
 
 
148
 
149
  def load_metrics() -> str:
150
  """Load evaluation metrics and format as markdown."""
151
+ # Load evaluation report
152
+ eval_metrics = {}
153
+ if EVAL_REPORT_PATH.exists():
154
+ try:
155
+ with open(EVAL_REPORT_PATH) as f:
156
+ eval_metrics = json.load(f)
157
+ except Exception:
158
+ pass
159
+
160
+ # Load training history
161
+ train_metrics = {}
162
+ if TRAINING_HISTORY_PATH.exists():
163
+ try:
164
+ with open(TRAINING_HISTORY_PATH) as f:
165
+ train_metrics = json.load(f)
166
+ except Exception:
167
+ pass
168
+
169
+ # Get final validation metrics
170
+ val_keys = [k for k in train_metrics if k.startswith("val_epoch_")]
+ val_final = train_metrics[max(val_keys, key=lambda k: float(k.split("_")[-1]))] if val_keys else {}
171
+
172
+ md = """
173
+ ## 📈 Model Performance
174
+
175
+ ### Training Results (Final Epoch)
176
+
177
+ | Task | Metric | Final Score |
178
+ |------|--------|-------------|
179
+ | **Topic Classification** | Accuracy | **{topic_acc:.1%}** |
180
+ | **Emotion Detection** | F1 (validation) | {emo_f1:.1%} |
181
+ | **Summarization** | ROUGE-like | {rouge:.1%} |
182
+
183
+ ### Evaluation Results
184
+
185
+ | Metric | Value |
186
+ |--------|-------|
187
+ | Topic Accuracy | **{eval_topic:.1%}** |
188
+ | Emotion F1 (macro) | {eval_emo:.1%} |
189
+ | ROUGE-like | {eval_rouge:.1%} |
190
+ | BLEU | {eval_bleu:.3f} |
191
+
192
+ ---
193
+
194
+ ### Topic Classification Details
195
 
196
+ | Category | Precision | Recall | F1 |
197
+ |----------|-----------|--------|-----|
198
+ """.format(
199
+ topic_acc=val_final.get("topic_accuracy", 0),
200
+ emo_f1=val_final.get("emotion_f1", 0),
201
+ rouge=val_final.get("summarization_rouge_like", 0),
202
+ eval_topic=eval_metrics.get("topic", {}).get("accuracy", 0),
203
+ eval_emo=eval_metrics.get("emotion", {}).get("f1_macro", 0),
204
+ eval_rouge=eval_metrics.get("summarization", {}).get("rouge_like", 0),
205
+ eval_bleu=eval_metrics.get("summarization", {}).get("bleu", 0),
206
+ )
207
 
208
+ # Add per-class metrics
209
+ topic_report = eval_metrics.get("topic", {}).get("classification_report", {})
210
+ for cat, metrics in topic_report.items():
211
+ if cat in ["macro avg", "weighted avg", "micro avg"]:
212
+ continue
213
+ if isinstance(metrics, dict):
214
+ md += f"| {cat} | {metrics.get('precision', 0):.1%} | {metrics.get('recall', 0):.1%} | {metrics.get('f1-score', 0):.1%} |\n"
215
 
216
+ return md
 
 
 
 
 
217
 
 
218
 
219
+ def get_viz_path(filename: str) -> str | None:
220
+ """Get visualization path if file exists."""
221
+ path = OUTPUTS_DIR / filename
222
+ return str(path) if path.exists() else None
 
 
 
 
 
223
 
224
 
225
  # --------------- Gradio Interface ---------------
226
 
227
  with gr.Blocks(
228
+ title="LexiMind - Multi-Task NLP",
229
  theme=gr.themes.Soft(),
 
230
  ) as demo:
231
  gr.Markdown(
232
  """
233
  # 🧠 LexiMind
234
  ### Multi-Task Transformer for Document Analysis
235
 
236
+ A custom encoder-decoder Transformer trained on **summarization**, **emotion detection** (28 classes),
237
+ and **topic classification** (10 categories). Built from scratch with PyTorch.
238
+
239
+ > ⚠️ **Note**: Summarization is experimental - the model works best on news-style articles.
240
  """
241
  )
242
 
243
  # --------------- Try It Tab ---------------
244
  with gr.Tab("🚀 Try It"):
245
  with gr.Row():
246
+ with gr.Column(scale=3):
247
  text_input = gr.Textbox(
248
+ label="📝 Input Text",
249
+ lines=6,
250
+ placeholder="Enter or paste text to analyze (works best with news articles)...",
251
  value=SAMPLE_TEXTS[0],
252
  )
253
+ analyze_btn = gr.Button(
254
+ "🔍 Analyze",
255
+ variant="primary",
256
+ size="sm",
257
+ )
258
+
259
+ gr.Markdown("**Sample Texts** (click to use):")
260
  with gr.Row():
261
+ sample1_btn = gr.Button("📰 Markets", size="sm", variant="secondary")
262
+ sample2_btn = gr.Button("🔬 Science", size="sm", variant="secondary")
263
+ sample3_btn = gr.Button("🏆 Sports", size="sm", variant="secondary")
264
+
265
+ sample1_btn.click(fn=lambda: SAMPLE_TEXTS[0], outputs=text_input)
266
+ sample2_btn.click(fn=lambda: SAMPLE_TEXTS[1], outputs=text_input)
267
+ sample3_btn.click(fn=lambda: SAMPLE_TEXTS[2], outputs=text_input)
268
 
269
  with gr.Column(scale=2):
270
+ gr.Markdown("### Results")
271
+ summary_out = gr.Textbox(
272
+ label="📝 Summary",
273
+ lines=3,
274
+ interactive=False,
275
+ )
276
+ with gr.Row():
277
+ with gr.Column():
278
+ gr.Markdown("**😊 Emotions**")
279
+ emotion_out = gr.Markdown(value="*Run analysis*")
280
+ with gr.Column():
281
+ gr.Markdown("**📂 Topic**")
282
+ topic_out = gr.Markdown(value="*Run analysis*")
283
 
284
  analyze_btn.click(
285
  fn=analyze,
 
289
 
290
  # --------------- Metrics Tab ---------------
291
  with gr.Tab("📊 Metrics"):
292
+ with gr.Row():
293
+ with gr.Column(scale=2):
294
+ gr.Markdown(load_metrics())
295
+ with gr.Column(scale=1):
296
+ confusion_path = get_viz_path("topic_confusion_matrix.png")
297
+ if confusion_path:
298
+ gr.Image(confusion_path, label="Confusion Matrix", show_label=True)
299
+
300
+ # --------------- Visualizations Tab ---------------
301
+ with gr.Tab("🎨 Visualizations"):
302
+ gr.Markdown("### Model Internals")
303
+
304
+ with gr.Row():
305
+ attn_path = get_viz_path("attention_visualization.png")
306
+ if attn_path:
307
+ gr.Image(attn_path, label="Self-Attention Pattern")
308
+
309
+ pos_path = get_viz_path("positional_encoding_heatmap.png")
310
+ if pos_path:
311
+ gr.Image(pos_path, label="Positional Encodings")
312
+
313
+ with gr.Row():
314
+ multi_path = get_viz_path("multihead_attention_visualization.png")
315
+ if multi_path:
316
+ gr.Image(multi_path, label="Multi-Head Attention")
317
+
318
+ single_path = get_viz_path("single_vs_multihead.png")
319
+ if single_path:
320
+ gr.Image(single_path, label="Single vs Multi-Head Comparison")
321
 
322
  # --------------- Architecture Tab ---------------
323
  with gr.Tab("🔧 Architecture"):
 
325
  """
326
  ### Model Architecture
327
 
328
+ | Component | Configuration |
329
+ |-----------|---------------|
330
+ | **Base** | Custom Transformer (encoder-decoder) |
331
+ | **Initialization** | FLAN-T5-base weights |
332
+ | **Encoder** | 6 layers, 768 hidden dim, 12 heads |
333
+ | **Decoder** | 6 layers with cross-attention |
334
+ | **Activation** | Gated-GELU |
335
+ | **Position** | Relative position bias |
336
+
337
+ ### Training Configuration
338
+
339
+ | Setting | Value |
340
+ |---------|-------|
341
+ | **Optimizer** | AdamW (lr=2e-5, wd=0.01) |
342
+ | **Scheduler** | Cosine with 1000 warmup steps |
343
+ | **Batch Size** | 14 × 3 accumulation = 42 effective |
344
+ | **Precision** | TF32 (Ampere GPU) |
345
+ | **Compilation** | torch.compile (inductor) |
346
 
347
+ ### Datasets
348
 
349
+ | Task | Dataset | Size |
350
+ |------|---------|------|
351
+ | **Summarization** | CNN/DailyMail + BookSum | ~110K |
352
+ | **Emotion** | GoEmotions | ~43K (28 labels) |
353
+ | **Topic** | Yahoo Answers | ~200K (10 classes) |
354
  """
355
  )
 
356
 
357
  # --------------- About Tab ---------------
358
  with gr.Tab("ℹ️ About"):
 
360
  """
361
  ### About LexiMind
362
 
363
+ LexiMind is a **portfolio project** demonstrating end-to-end machine learning engineering:
364
+
365
+ ✅ Custom Transformer implementation from scratch
366
+ ✅ Multi-task learning with shared encoder
367
+ ✅ Production-ready inference pipeline
368
+ ✅ Comprehensive evaluation and visualization
369
+ ✅ CI/CD with GitHub Actions
370
+
371
+ ### Known Limitations
372
 
373
+ - **Summarization** quality is limited (needs more training epochs)
374
+ - **Emotion detection** has low F1 due to class imbalance in GoEmotions
375
+ - Best results on **news-style text** (training domain)
 
376
 
377
  ### Links
378
 
379
  - 🔗 [GitHub Repository](https://github.com/OliverPerrin/LexiMind)
380
+ - 🤗 [Model on HuggingFace](https://huggingface.co/OliverPerrin/LexiMind-Model)
381
 
382
+ ---
383
 
384
+ **Built by Oliver Perrin** | December 2025
385
  """
386
  )
387
 
scripts/download_data.py CHANGED
@@ -85,6 +85,59 @@ TOPIC_LABELS = [
85
  # --------------- Utility Functions ---------------
86
 
87
 
88
  def _write_jsonl(records: list[dict], destination: Path, desc: str = "Writing") -> None:
89
  """Write records to JSONL file with progress bar."""
90
  destination.parent.mkdir(parents=True, exist_ok=True)
 
85
  # --------------- Utility Functions ---------------
86
 
87
 
88
+ def _normalize_label(label: object, label_names: list[str]) -> str:
89
+ """Convert a label index or raw value into a string name.
90
+
91
+ - Valid integer indices are mapped to label_names.
92
+ - Everything else is stringified for robustness.
93
+ """
94
+
95
+ if isinstance(label, int) and 0 <= label < len(label_names):
96
+ return label_names[label]
97
+ return str(label)
98
+
99
+
100
+ def _emotion_records(dataset_split: Any, label_names: list[str]) -> list[dict[str, object]]:
101
+ """Yield emotion records with resilient label handling."""
102
+
103
+ records: list[dict[str, object]] = []
104
+ for row in dataset_split:
105
+ text = str(getattr(row, "text", None) or row.get("text", ""))
106
+ # Explicit None checks so a valid label index of 0 is not dropped as falsy
+ raw_labels = getattr(row, "label", None)
+ if raw_labels is None:
+     raw_labels = row.get("label", row.get("labels", []))
107
+
108
+ # Normalize to list
109
+ if isinstance(raw_labels, list):
110
+ label_values = raw_labels
111
+ elif raw_labels is None:
112
+ label_values = []
113
+ else:
114
+ label_values = [raw_labels]
115
+
116
+ emotions = [_normalize_label(lbl, label_names) for lbl in label_values]
117
+ if text:
118
+ records.append({"text": text, "emotions": emotions})
119
+ return records
120
+
121
+
122
+ def _topic_records(dataset_split: Any, label_names: list[str]) -> list[dict[str, object]]:
123
+ """Yield topic records with resilient label handling."""
124
+
125
+ records: list[dict[str, object]] = []
126
+ for row in dataset_split:
127
+ text = str(getattr(row, "text", None) or row.get("text", ""))
128
+ # Explicit None check so a valid label index of 0 is not dropped as falsy
+ raw_label = getattr(row, "label", None)
+ if raw_label is None:
+     raw_label = row.get("label", row.get("topic"))
129
+
130
+ if isinstance(raw_label, list):
131
+ label_value = raw_label[0] if raw_label else ""
132
+ else:
133
+ label_value = raw_label
134
+
135
+ topic = _normalize_label(label_value, label_names) if label_value is not None else ""
136
+ if text:
137
+ records.append({"text": text, "topic": topic})
138
+ return records
139
+
140
+
141
  def _write_jsonl(records: list[dict], destination: Path, desc: str = "Writing") -> None:
142
  """Write records to JSONL file with progress bar."""
143
  destination.parent.mkdir(parents=True, exist_ok=True)
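For orientation, a minimal sketch of how the new label-normalization helper behaves. It calls `_normalize_label` exactly as defined above; the three-entry `label_names` list is illustrative only, not the full GoEmotions label set:

```python
# _normalize_label maps valid integer indices to names and stringifies everything else.
label_names = ["sadness", "joy", "love"]  # illustrative subset (assumption)

assert _normalize_label(1, label_names) == "joy"          # in-range index -> name
assert _normalize_label(7, label_names) == "7"            # out-of-range index -> stringified
assert _normalize_label("anger", label_names) == "anger"  # already a string -> unchanged
```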
scripts/process_books.py ADDED
@@ -0,0 +1,231 @@
1
+ """
2
+ Process book collection with LexiMind model.
3
+
4
+ Analyzes each book to generate:
5
+ - Overall topic classification
6
+ - Dominant emotions
7
+ - Concise summary
8
+
9
+ Results are saved to data/processed/books/library.json for future use.
10
+
11
+ Author: Oliver Perrin
12
+ Date: December 2025
13
+ """
14
+
15
+ from __future__ import annotations
16
+
17
+ import json
18
+ import sys
19
+ from pathlib import Path
20
+
21
+ PROJECT_ROOT = Path(__file__).resolve().parents[1]
22
+ if str(PROJECT_ROOT) not in sys.path:
23
+ sys.path.insert(0, str(PROJECT_ROOT))
24
+
25
+ from src.inference.factory import create_inference_pipeline
26
+ from src.utils.logging import configure_logging, get_logger
27
+
28
+ configure_logging()
29
+ logger = get_logger(__name__)
30
+
31
+ # --------------- Configuration ---------------
32
+
33
+ BOOKS_DIR = PROJECT_ROOT / "data" / "raw" / "books"
34
+ OUTPUT_PATH = PROJECT_ROOT / "data" / "processed" / "books" / "library.json"
35
+
36
+ # Chunk books into manageable sections for analysis
37
+ MAX_CHUNK_LENGTH = 1000 # characters per chunk
38
+ MAX_CHUNKS = 5 # analyze first N chunks to get representative sample
39
+
40
+
41
+ # --------------- Book Processing ---------------
42
+
43
+
44
+ def clean_text(text: str) -> str:
45
+ """Clean and normalize book text."""
46
+ # Remove Project Gutenberg headers/footers (common patterns)
47
+ lines = text.split("\n")
48
+ start_idx = 0
49
+ end_idx = len(lines)
50
+
51
+ for i, line in enumerate(lines):
52
+ if "START OF" in line.upper() and "PROJECT GUTENBERG" in line.upper():
53
+ start_idx = i + 1
54
+ break
55
+
56
+ for i in range(len(lines) - 1, -1, -1):
57
+ if "END OF" in lines[i].upper() and "PROJECT GUTENBERG" in lines[i].upper():
58
+ end_idx = i
59
+ break
60
+
61
+ text = "\n".join(lines[start_idx:end_idx])
62
+
63
+ # Basic cleanup
64
+ text = text.strip()
65
+ text = " ".join(text.split()) # normalize whitespace
66
+
67
+ return text
68
+
69
+
70
+ def chunk_text(text: str, chunk_size: int = MAX_CHUNK_LENGTH) -> list[str]:
71
+ """Split text into chunks for analysis."""
72
+ words = text.split()
73
+ chunks = []
74
+ current_chunk = []
75
+ current_length = 0
76
+
77
+ for word in words:
78
+ current_chunk.append(word)
79
+ current_length += len(word) + 1 # +1 for space
80
+
81
+ if current_length >= chunk_size:
82
+ chunks.append(" ".join(current_chunk))
83
+ current_chunk = []
84
+ current_length = 0
85
+
86
+ if current_chunk:
87
+ chunks.append(" ".join(current_chunk))
88
+
89
+ return chunks
90
+
91
+
92
+ def process_book(book_path: Path, pipeline) -> dict:
93
+ """Analyze a single book and return metadata."""
94
+ logger.info(f"Processing {book_path.name}...")
95
+
96
+ # Read and clean
97
+ try:
98
+ text = book_path.read_text(encoding="utf-8", errors="ignore")
99
+ except Exception as exc:
100
+ logger.error(f"Failed to read {book_path.name}: {exc}")
101
+ return {}
102
+
103
+ text = clean_text(text)
104
+
105
+ if not text or len(text) < 100:
106
+ logger.warning(f"Skipping {book_path.name} - insufficient content")
107
+ return {}
108
+
109
+ # Chunk and sample
110
+ chunks = chunk_text(text)
111
+ sample_chunks = chunks[: min(MAX_CHUNKS, len(chunks))]
112
+
113
+ logger.info(f" Analyzing {len(sample_chunks)} chunks (of {len(chunks)} total)...")
114
+
115
+ # Run inference on chunks
116
+ try:
117
+ topics = pipeline.predict_topics(sample_chunks)
118
+ emotions = pipeline.predict_emotions(sample_chunks, threshold=0.3)
119
+ summaries = pipeline.summarize(sample_chunks, max_length=64)
120
+
121
+ # Aggregate results
122
+ # Topic: most common prediction
123
+ topic_counts: dict[str, int] = {}
124
+ for t in topics:
125
+ topic_counts[t.label] = topic_counts.get(t.label, 0) + 1
126
+ dominant_topic = max(topic_counts.items(), key=lambda x: x[1])[0]
127
+
128
+ # Emotion: aggregate top emotions
129
+ all_emotions: dict[str, list[float]] = {}
130
+ for emotion in emotions:
131
+ for label, score in zip(emotion.labels, emotion.scores, strict=False):
132
+ if label not in all_emotions:
133
+ all_emotions[label] = []
134
+ all_emotions[label].append(score)
135
+
136
+ # Average scores and take top 3
137
+ emotion_scores = {
138
+ label: sum(scores) / len(scores) for label, scores in all_emotions.items()
139
+ }
140
+ top_emotions = sorted(emotion_scores.items(), key=lambda x: x[1], reverse=True)[:3]
141
+
142
+ # Summary: combine first few chunk summaries
143
+ combined_summary = " ".join(summaries[:3])
144
+
145
+ result: dict[str, object] = {
146
+ "title": book_path.stem.replace("_", " ").title(),
147
+ "filename": book_path.name,
148
+ "topic": dominant_topic,
149
+ "emotions": [{"label": label, "score": float(score)} for label, score in top_emotions],
150
+ "summary": combined_summary,
151
+ "word_count": len(text.split()),
152
+ "chunks_analyzed": len(sample_chunks),
153
+ }
154
+
155
+ logger.info(
156
+ f" ✓ {result['title']}: {result['topic']} | "
157
+ f"{', '.join(str(e['label']) for e in result['emotions'][:2] if isinstance(e, dict))}" # type: ignore[index]
158
+ )
159
+
160
+ return result
161
+
162
+ except Exception as exc:
163
+ logger.error(f"Analysis failed for {book_path.name}: {exc}", exc_info=True)
164
+ return {}
165
+
166
+
167
+ # --------------- Main ---------------
168
+
169
+
170
+ def main():
171
+ """Process all books and save library."""
172
+ logger.info("Loading inference pipeline...")
173
+
174
+ pipeline, label_metadata = create_inference_pipeline(
175
+ tokenizer_dir="artifacts/hf_tokenizer/",
176
+ checkpoint_path="checkpoints/best.pt",
177
+ labels_path="artifacts/labels.json",
178
+ )
179
+
180
+ logger.info("Finding books...")
181
+ book_files = sorted(BOOKS_DIR.glob("*.txt"))
182
+
183
+ if not book_files:
184
+ logger.error(f"No books found in {BOOKS_DIR}")
185
+ return
186
+
187
+ logger.info(f"Found {len(book_files)} books")
188
+
189
+ # Process each book
190
+ library = []
191
+ for book_path in book_files:
192
+ result = process_book(book_path, pipeline)
193
+ if result:
194
+ library.append(result)
195
+
196
+ # Save results
197
+ OUTPUT_PATH.parent.mkdir(parents=True, exist_ok=True)
198
+ with open(OUTPUT_PATH, "w") as f:
199
+ json.dump(
200
+ {
201
+ "books": library,
202
+ "metadata": {
203
+ "total_books": len(library),
204
+ "chunk_size": MAX_CHUNK_LENGTH,
205
+ "chunks_per_book": MAX_CHUNKS,
206
+ },
207
+ },
208
+ f,
209
+ indent=2,
210
+ )
211
+
212
+ logger.info(f"\n✓ Library saved to {OUTPUT_PATH}")
213
+ logger.info(f" Processed {len(library)} books")
214
+
215
+ # Print summary
216
+ print("\n" + "=" * 60)
217
+ print("BOOK LIBRARY SUMMARY")
218
+ print("=" * 60)
219
+
220
+ for book in library:
221
+ print(f"\n📚 {book['title']}")
222
+ print(f" Topic: {book['topic']}")
223
+ emotions_str = ", ".join(f"{e['label']} ({e['score']:.0%})" for e in book["emotions"])
224
+ print(f" Emotions: {emotions_str}")
225
+ print(f" Summary: {book['summary'][:100]}...")
226
+
227
+ print("\n" + "=" * 60)
228
+
229
+
230
+ if __name__ == "__main__":
231
+ main()
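As a quick sanity check of the chunking logic above: `chunk_size` is measured in characters, but splits always fall on word boundaries, so the final chunk may be short. A small worked example using `chunk_text` as defined in this script:

```python
# Chunks accumulate whole words until roughly chunk_size characters are reached.
text = "one two three four five six"
print(chunk_text(text, chunk_size=10))
# -> ['one two three', 'four five', 'six']
# "one two three" crosses 10 characters, then "four five", leaving "six" on its own.
```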
scripts/train.py CHANGED
@@ -13,6 +13,7 @@ from __future__ import annotations
13
  import json
14
  import logging
15
  import os
 
16
  import sys
17
  import time
18
  import warnings
@@ -51,7 +52,7 @@ from src.data.tokenization import Tokenizer, TokenizerConfig
51
  from src.models.factory import ModelConfig, build_multitask_model
52
  from src.training.trainer import Trainer, TrainerConfig
53
  from src.training.utils import set_seed
54
- from src.utils.io import save_state
55
  from src.utils.labels import LabelMetadata, save_label_metadata
56
 
57
  # --------------- Data Loading ---------------
@@ -93,12 +94,13 @@ def limit_samples(splits: Dict[str, list], cfg: DictConfig) -> None:
93
 
94
 
95
  def compile_model(model: torch.nn.Module) -> torch.nn.Module:
96
- """Compile model with inductor backend (default mode, no CUDA graphs)."""
 
97
  from src.training.safe_compile import apply_safe_config, compile_model_safe
98
 
99
  # Apply safe configuration first
100
  apply_safe_config()
101
- # Compile with default mode (inductor without CUDA graphs)
102
  return compile_model_safe(model, mode="default")
103
 
104
 
@@ -148,10 +150,12 @@ def main(cfg: DictConfig) -> None:
148
  # --------------- Tokenizer & Datasets ---------------
149
 
150
  tok_cfg = data_cfg.get("tokenizer", {})
 
 
151
  tokenizer = Tokenizer(
152
  TokenizerConfig(
153
  pretrained_model_name=tok_cfg.get("pretrained_model_name", "google/flan-t5-base"),
154
- max_length=int(tok_cfg.get("max_length", 512)),
155
  lower=bool(tok_cfg.get("lower", False)),
156
  )
157
  )
@@ -238,6 +242,7 @@ def main(cfg: DictConfig) -> None:
238
  device = torch.device(cfg.device)
239
  model_cfg = ModelConfig(
240
  d_model=cfg.model.d_model,
 
241
  num_encoder_layers=cfg.model.num_encoder_layers,
242
  num_decoder_layers=cfg.model.num_decoder_layers,
243
  num_attention_heads=cfg.model.num_attention_heads,
@@ -255,12 +260,41 @@ def main(cfg: DictConfig) -> None:
255
  config=model_cfg,
256
  ).to(device)
257
 
258
  # Compile encoder/decoder for faster training (skip heads - small overhead)
259
- if model.encoder is not None:
 
 
260
  from src.models.encoder import TransformerEncoder
261
 
262
  model.encoder = cast(TransformerEncoder, compile_model(model.encoder))
263
- if model.decoder is not None:
264
  from src.models.decoder import TransformerDecoder
265
 
266
  model.decoder = cast(TransformerDecoder, compile_model(model.decoder))
@@ -268,21 +302,30 @@ def main(cfg: DictConfig) -> None:
268
  # --------------- Optimizer & Trainer ---------------
269
 
270
  opt_cfg = cfg.training.get("optimizer", {})
 
271
  optimizer = torch.optim.AdamW(
272
  model.parameters(),
273
  lr=float(opt_cfg.get("lr", 3e-5)),
274
  weight_decay=float(opt_cfg.get("weight_decay", 0.01)),
275
  )
276
 
277
  trainer = Trainer(
278
  model=model,
279
  optimizer=optimizer,
280
  config=TrainerConfig(
281
- max_epochs=int(trainer_cfg.get("max_epochs", 1)),
282
  gradient_clip_norm=float(trainer_cfg.get("gradient_clip_norm", 1.0)),
283
  task_weights=trainer_cfg.get("task_weights"),
284
  label_smoothing=float(trainer_cfg.get("label_smoothing", 0.0)),
285
  gradient_accumulation_steps=int(trainer_cfg.get("gradient_accumulation_steps", 1)),
 
 
286
  ),
287
  device=device,
288
  tokenizer=tokenizer,
@@ -298,7 +341,12 @@ def main(cfg: DictConfig) -> None:
298
  save_state(model, str(path))
299
 
300
  print("\nStarting training...")
301
- history = trainer.fit(train_loaders, val_loaders, checkpoint_callback=save_checkpoint)
 
 
302
 
303
  # --------------- Save Outputs ---------------
304
 
 
13
  import json
14
  import logging
15
  import os
16
+ import re
17
  import sys
18
  import time
19
  import warnings
 
52
  from src.models.factory import ModelConfig, build_multitask_model
53
  from src.training.trainer import Trainer, TrainerConfig
54
  from src.training.utils import set_seed
55
+ from src.utils.io import load_state, save_state
56
  from src.utils.labels import LabelMetadata, save_label_metadata
57
 
58
  # --------------- Data Loading ---------------
 
94
 
95
 
96
  def compile_model(model: torch.nn.Module) -> torch.nn.Module:
97
+ """Compile model with inductor backend (optimized for speed)."""
98
+ print(f" -> Enabling torch.compile for {model.__class__.__name__}...")
99
  from src.training.safe_compile import apply_safe_config, compile_model_safe
100
 
101
  # Apply safe configuration first
102
  apply_safe_config()
103
+ # Compile with default mode (inductor) - most stable
104
  return compile_model_safe(model, mode="default")
105
 
106
 
 
150
  # --------------- Tokenizer & Datasets ---------------
151
 
152
  tok_cfg = data_cfg.get("tokenizer", {})
153
+ # Allow training overrides for max_length to run shorter dev sweeps
154
+ override_max_len = cfg.training.get("tokenizer_max_length")
155
  tokenizer = Tokenizer(
156
  TokenizerConfig(
157
  pretrained_model_name=tok_cfg.get("pretrained_model_name", "google/flan-t5-base"),
158
+ max_length=int(override_max_len or tok_cfg.get("max_length", 512)),
159
  lower=bool(tok_cfg.get("lower", False)),
160
  )
161
  )
 
242
  device = torch.device(cfg.device)
243
  model_cfg = ModelConfig(
244
  d_model=cfg.model.d_model,
245
+ vocab_size=getattr(cfg.model, "vocab_size", None), # Override tokenizer vocab if specified
246
  num_encoder_layers=cfg.model.num_encoder_layers,
247
  num_decoder_layers=cfg.model.num_decoder_layers,
248
  num_attention_heads=cfg.model.num_attention_heads,
 
260
  config=model_cfg,
261
  ).to(device)
262
 
263
+ # If Training Crashes: Resume from checkpoint if provided (load before compile to avoid key mismatches)
264
+ start_epoch = 1
265
+ resume_path = cfg.get("resume_from")
266
+ if resume_path:
267
+ ckpt_path = Path(resume_path)
268
+ if ckpt_path.exists():
269
+ print(f"\n↩Resuming from checkpoint: {ckpt_path}")
270
+ load_state(model, str(ckpt_path))
271
+ # Parse epoch number robustly from filename (e.g., epoch_5.pt)
272
+ epoch_num = None
273
+ try:
274
+ # Prefer stem (no suffix); fallback to any digit sequence in name
275
+ digits = re.findall(r"\d+", ckpt_path.stem)
276
+ if digits:
277
+ epoch_num = int(digits[-1])
278
+ except Exception:
279
+ epoch_num = None
280
+
281
+ if epoch_num is not None:
282
+ start_epoch = epoch_num + 1
283
+ print(f" -> Starting from epoch {start_epoch}")
284
+ else:
285
+ print(" -> Could not parse epoch number; starting from epoch 1")
286
+ start_epoch = 1
287
+ else:
288
+ print(f"⚠ Resume checkpoint not found: {ckpt_path}. Starting from scratch.")
289
+
290
  # Compile encoder/decoder for faster training (skip heads - small overhead)
291
+ compile_encoder = bool(cfg.training.get("compile_encoder", True))
292
+ compile_decoder = bool(cfg.training.get("compile_decoder", True))
293
+ if compile_encoder and model.encoder is not None:
294
  from src.models.encoder import TransformerEncoder
295
 
296
  model.encoder = cast(TransformerEncoder, compile_model(model.encoder))
297
+ if compile_decoder and model.decoder is not None:
298
  from src.models.decoder import TransformerDecoder
299
 
300
  model.decoder = cast(TransformerDecoder, compile_model(model.decoder))
 
302
  # --------------- Optimizer & Trainer ---------------
303
 
304
  opt_cfg = cfg.training.get("optimizer", {})
305
+ sched_cfg = cfg.training.get("scheduler", {})
306
  optimizer = torch.optim.AdamW(
307
  model.parameters(),
308
  lr=float(opt_cfg.get("lr", 3e-5)),
309
  weight_decay=float(opt_cfg.get("weight_decay", 0.01)),
310
  )
311
 
312
+ # Clamp start_epoch to max_epochs to avoid empty loop
313
+ max_epochs = int(trainer_cfg.get("max_epochs", 1))
314
+ if start_epoch > max_epochs:
315
+ print(f"⚠ resume_from points past max_epochs ({max_epochs}); nothing to train. Setting start_epoch to {max_epochs}")
316
+ start_epoch = max_epochs
317
+
318
  trainer = Trainer(
319
  model=model,
320
  optimizer=optimizer,
321
  config=TrainerConfig(
322
+ max_epochs=max_epochs,
323
  gradient_clip_norm=float(trainer_cfg.get("gradient_clip_norm", 1.0)),
324
  task_weights=trainer_cfg.get("task_weights"),
325
  label_smoothing=float(trainer_cfg.get("label_smoothing", 0.0)),
326
  gradient_accumulation_steps=int(trainer_cfg.get("gradient_accumulation_steps", 1)),
327
+ scheduler_type=str(sched_cfg.get("name", "constant")),
328
+ warmup_steps=int(sched_cfg.get("warmup_steps", 0)),
329
  ),
330
  device=device,
331
  tokenizer=tokenizer,
 
341
  save_state(model, str(path))
342
 
343
  print("\nStarting training...")
344
+ history = trainer.fit(
345
+ train_loaders,
346
+ val_loaders,
347
+ checkpoint_callback=save_checkpoint,
348
+ start_epoch=start_epoch,
349
+ )
350
 
351
  # --------------- Save Outputs ---------------
352
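The Trainer now receives `scheduler_type` and `warmup_steps`, but the schedule construction itself happens inside `Trainer` and is not part of this diff. A minimal sketch of the linear-warmup + cosine-decay schedule those settings typically describe; the `LambdaLR` wiring and the `total_steps` value below are assumptions for illustration, not the project's actual implementation:

```python
import math

import torch


def warmup_cosine_lambda(warmup_steps: int, total_steps: int):
    """LambdaLR multiplier: linear warmup to 1.0, then cosine decay towards 0."""
    def lr_lambda(step: int) -> float:
        if step < warmup_steps:
            return step / max(1, warmup_steps)
        progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
        return 0.5 * (1.0 + math.cos(math.pi * min(1.0, progress)))
    return lr_lambda


# Illustrative wiring only; the real AdamW optimizer is built in scripts/train.py above.
model = torch.nn.Linear(8, 8)
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5, weight_decay=0.01)
scheduler = torch.optim.lr_scheduler.LambdaLR(
    optimizer, lr_lambda=warmup_cosine_lambda(warmup_steps=1000, total_steps=20_000)
)
```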
 
scripts/visualize_training.py ADDED
@@ -0,0 +1,341 @@
1
+ """
2
+ Visualize training metrics from MLflow runs.
3
+
4
+ Generates plots showing:
5
+ - Loss curves (training/validation)
6
+ - Task-specific metrics over time
7
+ - Learning rate schedule
8
+ - Training speed analysis
9
+
10
+ Author: Oliver Perrin
11
+ Date: December 2025
12
+ """
13
+
14
+ from __future__ import annotations
15
+
16
+ import json
17
+ import sys
18
+ from pathlib import Path
19
+
20
+ import matplotlib.pyplot as plt
21
+ import mlflow
22
+ import mlflow.tracking
23
+ import seaborn as sns
24
+
25
+ PROJECT_ROOT = Path(__file__).resolve().parents[1]
26
+ if str(PROJECT_ROOT) not in sys.path:
27
+ sys.path.insert(0, str(PROJECT_ROOT))
28
+
29
+ from src.utils.logging import configure_logging, get_logger
30
+
31
+ configure_logging()
32
+ logger = get_logger(__name__)
33
+
34
+ # Configure plotting style
35
+ sns.set_style("whitegrid")
36
+ plt.rcParams["figure.figsize"] = (12, 8)
37
+ plt.rcParams["figure.dpi"] = 100
38
+
39
+ OUTPUTS_DIR = PROJECT_ROOT / "outputs"
40
+ MLRUNS_DIR = PROJECT_ROOT / "mlruns"
41
+
42
+
43
+ def load_training_history() -> dict[str, object] | None:
44
+ """Load training history from JSON if available."""
45
+ history_path = OUTPUTS_DIR / "training_history.json"
46
+ if history_path.exists():
47
+ with open(history_path) as f:
48
+ data: dict[str, object] = json.load(f)
49
+ return data
50
+ return None
51
+
52
+
53
+ def get_latest_run():
54
+ """Get the most recent MLflow run."""
55
+ mlflow.set_tracking_uri(f"file://{MLRUNS_DIR}")
56
+ client = mlflow.tracking.MlflowClient()
57
+
58
+ # Get the experiment (LexiMind)
59
+ experiment = client.get_experiment_by_name("LexiMind")
60
+ if not experiment:
61
+ logger.error("No 'LexiMind' experiment found")
62
+ return None
63
+
64
+ # Get all runs, sorted by start time
65
+ runs = client.search_runs(
66
+ experiment_ids=[experiment.experiment_id],
67
+ order_by=["start_time DESC"],
68
+ max_results=1,
69
+ )
70
+
71
+ if not runs:
72
+ logger.error("No runs found in experiment")
73
+ return None
74
+
75
+ return runs[0]
76
+
77
+
78
+ def plot_loss_curves(run):
79
+ """Plot training and validation loss over time."""
80
+ client = mlflow.tracking.MlflowClient()
81
+
82
+ # Get metrics
83
+ train_loss = client.get_metric_history(run.info.run_id, "train_total_loss")
84
+ val_loss = client.get_metric_history(run.info.run_id, "val_total_loss")
85
+
86
+ fig, ax = plt.subplots(figsize=(12, 6))
87
+
88
+ if not train_loss:
89
+ # Create placeholder plot
90
+ ax.text(
91
+ 0.5,
92
+ 0.5,
93
+ "No training data yet\n\nWaiting for first epoch to complete...",
94
+ ha="center",
95
+ va="center",
96
+ fontsize=14,
97
+ color="gray",
98
+ )
99
+ ax.set_xlim(0, 1)
100
+ ax.set_ylim(0, 1)
101
+ else:
102
+ # Extract steps and values
103
+ train_steps = [m.step for m in train_loss]
104
+ train_values = [m.value for m in train_loss]
105
+
106
+ ax.plot(train_steps, train_values, label="Training Loss", linewidth=2, alpha=0.8)
107
+
108
+ if val_loss:
109
+ val_steps = [m.step for m in val_loss]
110
+ val_values = [m.value for m in val_loss]
111
+ ax.plot(val_steps, val_values, label="Validation Loss", linewidth=2, alpha=0.8)
112
+
113
+ ax.legend(fontsize=11)
114
+
115
+ ax.set_xlabel("Epoch", fontsize=12)
116
+ ax.set_ylabel("Loss", fontsize=12)
117
+ ax.set_title("Training Progress: Total Loss", fontsize=14, fontweight="bold")
118
+ ax.grid(True, alpha=0.3)
119
+
120
+ plt.tight_layout()
121
+ output_path = OUTPUTS_DIR / "training_loss_curve.png"
122
+ plt.savefig(output_path, dpi=150, bbox_inches="tight")
123
+ logger.info(f"✓ Saved loss curve to {output_path}")
124
+ plt.close()
125
+
126
+
127
+ def plot_task_metrics(run):
128
+ """Plot metrics for each task."""
129
+ client = mlflow.tracking.MlflowClient()
130
+
131
+ fig, axes = plt.subplots(2, 2, figsize=(14, 10))
132
+ fig.suptitle("Task-Specific Training Metrics", fontsize=16, fontweight="bold")
133
+
134
+ # Summarization
135
+ ax = axes[0, 0]
136
+ train_sum = client.get_metric_history(run.info.run_id, "train_summarization_loss")
137
+ val_sum = client.get_metric_history(run.info.run_id, "val_summarization_loss")
138
+
139
+ if train_sum:
140
+ ax.plot(
141
+ [m.step for m in train_sum], [m.value for m in train_sum], label="Train", linewidth=2
142
+ )
143
+ if val_sum:
144
+ ax.plot([m.step for m in val_sum], [m.value for m in val_sum], label="Val", linewidth=2)
145
+ ax.set_title("Summarization Loss", fontweight="bold")
146
+ ax.set_xlabel("Epoch")
147
+ ax.set_ylabel("Loss")
148
+ ax.legend()
149
+ ax.grid(True, alpha=0.3)
150
+
151
+ # Emotion
152
+ ax = axes[0, 1]
153
+ train_emo = client.get_metric_history(run.info.run_id, "train_emotion_loss")
154
+ val_emo = client.get_metric_history(run.info.run_id, "val_emotion_loss")
155
+ train_f1 = client.get_metric_history(run.info.run_id, "train_emotion_f1")
156
+ val_f1 = client.get_metric_history(run.info.run_id, "val_emotion_f1")
157
+
158
+ if train_emo:
159
+ ax.plot(
160
+ [m.step for m in train_emo],
161
+ [m.value for m in train_emo],
162
+ label="Train Loss",
163
+ linewidth=2,
164
+ )
165
+ if val_emo:
166
+ ax.plot(
167
+ [m.step for m in val_emo], [m.value for m in val_emo], label="Val Loss", linewidth=2
168
+ )
169
+
170
+ ax2 = ax.twinx()
171
+ if train_f1:
172
+ ax2.plot(
173
+ [m.step for m in train_f1],
174
+ [m.value for m in train_f1],
175
+ label="Train F1",
176
+ linewidth=2,
177
+ linestyle="--",
178
+ alpha=0.7,
179
+ )
180
+ if val_f1:
181
+ ax2.plot(
182
+ [m.step for m in val_f1],
183
+ [m.value for m in val_f1],
184
+ label="Val F1",
185
+ linewidth=2,
186
+ linestyle="--",
187
+ alpha=0.7,
188
+ )
189
+
190
+ ax.set_title("Emotion Detection", fontweight="bold")
191
+ ax.set_xlabel("Epoch")
192
+ ax.set_ylabel("Loss")
193
+ ax2.set_ylabel("F1 Score")
194
+ ax.legend(loc="upper left")
195
+ ax2.legend(loc="upper right")
196
+ ax.grid(True, alpha=0.3)
197
+
198
+ # Topic
199
+ ax = axes[1, 0]
200
+ train_topic = client.get_metric_history(run.info.run_id, "train_topic_loss")
201
+ val_topic = client.get_metric_history(run.info.run_id, "val_topic_loss")
202
+ train_acc = client.get_metric_history(run.info.run_id, "train_topic_accuracy")
203
+ val_acc = client.get_metric_history(run.info.run_id, "val_topic_accuracy")
204
+
205
+ if train_topic:
206
+ ax.plot(
207
+ [m.step for m in train_topic],
208
+ [m.value for m in train_topic],
209
+ label="Train Loss",
210
+ linewidth=2,
211
+ )
212
+ if val_topic:
213
+ ax.plot(
214
+ [m.step for m in val_topic], [m.value for m in val_topic], label="Val Loss", linewidth=2
215
+ )
216
+
217
+ ax2 = ax.twinx()
218
+ if train_acc:
219
+ ax2.plot(
220
+ [m.step for m in train_acc],
221
+ [m.value for m in train_acc],
222
+ label="Train Acc",
223
+ linewidth=2,
224
+ linestyle="--",
225
+ alpha=0.7,
226
+ )
227
+ if val_acc:
228
+ ax2.plot(
229
+ [m.step for m in val_acc],
230
+ [m.value for m in val_acc],
231
+ label="Val Acc",
232
+ linewidth=2,
233
+ linestyle="--",
234
+ alpha=0.7,
235
+ )
236
+
237
+ ax.set_title("Topic Classification", fontweight="bold")
238
+ ax.set_xlabel("Epoch")
239
+ ax.set_ylabel("Loss")
240
+ ax2.set_ylabel("Accuracy")
241
+ ax.legend(loc="upper left")
242
+ ax2.legend(loc="upper right")
243
+ ax.grid(True, alpha=0.3)
244
+
245
+ # Summary statistics
246
+ ax = axes[1, 1]
247
+ ax.axis("off")
248
+
249
+ # Get final metrics
250
+ summary_text = "Final Metrics (Last Epoch)\n" + "=" * 35 + "\n\n"
251
+
252
+ if val_topic and val_acc:
253
+ summary_text += f"Topic Accuracy: {val_acc[-1].value:.1%}\n"
254
+ if val_emo and val_f1:
255
+ summary_text += f"Emotion F1: {val_f1[-1].value:.1%}\n"
256
+ if val_sum:
257
+ summary_text += f"Summarization Loss: {val_sum[-1].value:.3f}\n"
258
+
259
+ ax.text(0.1, 0.5, summary_text, fontsize=12, family="monospace", verticalalignment="center")
260
+
261
+ plt.tight_layout()
262
+ output_path = OUTPUTS_DIR / "task_metrics.png"
263
+ plt.savefig(output_path, dpi=150, bbox_inches="tight")
264
+ logger.info(f"✓ Saved task metrics to {output_path}")
265
+ plt.close()
266
+
267
+
268
+ def plot_learning_rate(run):
269
+ """Plot learning rate schedule if available."""
270
+ client = mlflow.tracking.MlflowClient()
271
+ lr_metrics = client.get_metric_history(run.info.run_id, "learning_rate")
272
+
273
+ fig, ax = plt.subplots(figsize=(12, 5))
274
+
275
+ if not lr_metrics:
276
+ # Create placeholder
277
+ ax.text(
278
+ 0.5,
279
+ 0.5,
280
+ "No learning rate data yet\n\n(Will be logged in future training runs)",
281
+ ha="center",
282
+ va="center",
283
+ fontsize=14,
284
+ color="gray",
285
+ )
286
+ ax.set_xlim(0, 1)
287
+ ax.set_ylim(0, 1)
288
+ else:
289
+ steps = [m.step for m in lr_metrics]
290
+ values = [m.value for m in lr_metrics]
291
+
292
+ ax.plot(steps, values, linewidth=2, color="darkblue")
293
+
294
+ # Mark warmup region
295
+ warmup_steps = 1000 # From config
296
+ if warmup_steps < max(steps):
297
+ ax.axvline(warmup_steps, color="red", linestyle="--", alpha=0.5, label="Warmup End")
298
+ ax.legend()
299
+
300
+ ax.set_xlabel("Step", fontsize=12)
301
+ ax.set_ylabel("Learning Rate", fontsize=12)
302
+ ax.set_title("Learning Rate Schedule (Cosine with Warmup)", fontsize=14, fontweight="bold")
303
+ ax.grid(True, alpha=0.3)
304
+
305
+ plt.tight_layout()
306
+ output_path = OUTPUTS_DIR / "learning_rate_schedule.png"
307
+ plt.savefig(output_path, dpi=150, bbox_inches="tight")
308
+ logger.info(f"✓ Saved LR schedule to {output_path}")
309
+ plt.close()
310
+
311
+
312
+ def main():
313
+ """Generate all training visualizations."""
314
+ logger.info("Loading MLflow data...")
315
+
316
+ run = get_latest_run()
317
+ if not run:
318
+ logger.error("No training run found. Make sure training has started.")
319
+ return
320
+
321
+ logger.info(f"Analyzing run: {run.info.run_id}")
322
+
323
+ OUTPUTS_DIR.mkdir(parents=True, exist_ok=True)
324
+
325
+ logger.info("Generating visualizations...")
326
+
327
+ plot_loss_curves(run)
328
+ plot_task_metrics(run)
329
+ plot_learning_rate(run)
330
+
331
+ logger.info("\n" + "=" * 60)
332
+ logger.info("✓ All visualizations saved to outputs/")
333
+ logger.info("=" * 60)
334
+ logger.info(" - training_loss_curve.png")
335
+ logger.info(" - task_metrics.png")
336
+ logger.info(" - learning_rate_schedule.png")
337
+ logger.info("=" * 60)
338
+
339
+
340
+ if __name__ == "__main__":
341
+ main()
src/data/dataloader.py CHANGED
@@ -48,13 +48,16 @@ class SummarizationCollator:
48
  src_enc = self.tokenizer.batch_encode(sources, max_length=self.max_source_length)
49
  tgt_enc = self.tokenizer.batch_encode(targets, max_length=self.max_target_length)
50
 
51
- # Shift targets: tgt_ids = [BOS, A, B], labels = [A, B, EOS]
52
  ids = tgt_enc["input_ids"]
53
  mask = tgt_enc["attention_mask"]
54
 
55
- tgt_ids = ids[:, :-1]
56
- labels = ids[:, 1:].clone()
57
- labels[mask[:, 1:] == 0] = -100 # Mask padding in loss
 
 
59
  return {
60
  "src_ids": src_enc["input_ids"],
 
48
  src_enc = self.tokenizer.batch_encode(sources, max_length=self.max_source_length)
49
  tgt_enc = self.tokenizer.batch_encode(targets, max_length=self.max_target_length)
50
 
 
51
  ids = tgt_enc["input_ids"]
52
  mask = tgt_enc["attention_mask"]
53
 
54
+ # Create labels for loss: mask padding with -100
55
+ labels = ids.clone()
56
+ labels[mask == 0] = -100
57
+
58
+ # Create decoder inputs from original ids (no -100)
59
+ # prepare_decoder_inputs shifts right and adds BOS
60
+ tgt_ids = self.tokenizer.prepare_decoder_inputs(ids)
61
 
62
  return {
63
  "src_ids": src_enc["input_ids"],
src/inference/pipeline.py CHANGED
@@ -69,6 +69,7 @@ class InferenceConfig:
69
 
70
  summary_max_length: int = 128
71
  summary_repetition_penalty: float = 1.2 # Penalize repeated tokens
 
72
  emotion_threshold: float = 0.5
73
  device: str | None = None
74
 
@@ -164,6 +165,8 @@ class InferencePipeline:
164
 
165
  # Decode and format summaries
166
  raw_summaries = self.tokenizer.decode_batch(generated.tolist())
 
 
167
  return [_format_summary(s) for s in raw_summaries]
168
 
169
  # --------------- Emotion ---------------
 
69
 
70
  summary_max_length: int = 128
71
  summary_repetition_penalty: float = 1.2 # Penalize repeated tokens
72
+ summary_formatting: bool = True # Apply text cleanup/formatting to generated summaries
73
  emotion_threshold: float = 0.5
74
  device: str | None = None
75
 
 
165
 
166
  # Decode and format summaries
167
  raw_summaries = self.tokenizer.decode_batch(generated.tolist())
168
+ if not self.config.summary_formatting:
169
+ return raw_summaries
170
  return [_format_summary(s) for s in raw_summaries]
171
 
172
  # --------------- Emotion ---------------
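A short usage sketch for the new `summary_formatting` switch: setting it to False returns the decoder output verbatim, which is useful when debugging generation. Only the config construction is shown; how the config object is passed to the pipeline (factory argument vs. constructor) is not visible in this diff:

```python
from src.inference.pipeline import InferenceConfig

# Raw decoding for debugging: skip the _format_summary cleanup step.
config = InferenceConfig(
    summary_max_length=128,
    summary_repetition_penalty=1.2,
    summary_formatting=False,  # return generated text without post-formatting
    emotion_threshold=0.5,
)
```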
src/inference/postprocessing.py DELETED
@@ -1,14 +0,0 @@
1
- """
2
- Output postprocessing utilities for LexiMind.
3
-
4
- Provides text cleaning helpers for model outputs.
5
-
6
- Author: Oliver Perrin
7
- Date: December 2025
8
- """
9
-
10
- from typing import List
11
-
12
-
13
- def strip_whitespace(texts: List[str]) -> List[str]:
14
- return [text.strip() for text in texts]
 
 
src/models/decoder.py CHANGED
@@ -18,10 +18,12 @@ from typing import Any, Dict, List, Literal, Optional, Tuple, Union, cast
18
 
19
  import torch
20
  import torch.nn as nn
 
21
 
22
  from .attention import MultiHeadAttention, T5RelativePositionBias
23
  from .feedforward import FeedForward
24
  from .positional_encoding import LearnedPositionalEncoding, PositionalEncoding
 
25
 
26
 
27
  def create_causal_mask(seq_len: int, device: Optional[torch.device] = None) -> torch.Tensor:
@@ -77,9 +79,9 @@ class TransformerDecoderLayer(nn.Module):
77
  quantization=quantization,
78
  )
79
 
80
- self.norm1 = nn.RMSNorm(d_model)
81
- self.norm2 = nn.RMSNorm(d_model)
82
- self.norm3 = nn.RMSNorm(d_model)
83
 
84
  self.dropout1 = nn.Dropout(dropout)
85
  self.dropout2 = nn.Dropout(dropout)
@@ -189,6 +191,7 @@ class TransformerDecoder(nn.Module):
189
  use_learned_pos_enc: bool = False,
190
  activation: Literal["gelu", "relu", "swiglu", "gated-gelu"] = "gated-gelu",
191
  use_relative_position_bias: bool = False, # T5-style relative position bias
 
192
  ):
193
  super().__init__()
194
  self.vocab_size = vocab_size
@@ -196,8 +199,10 @@ class TransformerDecoder(nn.Module):
196
  self.pad_token_id = pad_token_id
197
  self.num_heads = num_heads
198
  self.use_relative_position_bias = use_relative_position_bias
 
199
 
200
  self.embedding = nn.Embedding(vocab_size, d_model, padding_idx=pad_token_id)
 
201
 
202
  # Positional encoding (disabled when using relative position bias for T5)
203
  self.self_relative_position_bias: Optional[T5RelativePositionBias] = None
@@ -238,8 +243,8 @@ class TransformerDecoder(nn.Module):
238
  ]
239
  )
240
 
241
- self.final_norm = nn.RMSNorm(d_model)
242
- self.output_projection = nn.Linear(d_model, vocab_size)
243
  self.input_dropout = nn.Dropout(dropout)
244
 
245
  def _build_padding_mask_from_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
@@ -252,6 +257,18 @@ class TransformerDecoder(nn.Module):
252
  """
253
  assert self.pad_token_id is not None, "pad_token_id must be set to build mask from ids"
254
  pad_mask = input_ids != self.pad_token_id # (B, T)
 
255
  attn_mask = pad_mask.unsqueeze(1) & pad_mask.unsqueeze(2) # (B, T, T)
256
  return attn_mask
257
 
@@ -263,7 +280,7 @@ class TransformerDecoder(nn.Module):
263
  memory_mask: Optional[torch.Tensor] = None,
264
  collect_attn: bool = False,
265
  skip_padding_mask: bool = False, # Set True during generation to avoid masking start token
266
- ) -> Union[torch.Tensor, Tuple[torch.Tensor, List[Dict[str, torch.Tensor]]]]:
267
  """
268
  Args:
269
  inputs: (B, T) token ids or (B, T, d_model) embeddings
@@ -304,6 +321,12 @@ class TransformerDecoder(nn.Module):
304
  else:
305
  # Ensure boolean and device alignment; accept (B, T, T) or (B,1,T,T) or (1,1,T,T)
306
  tgt_mask = tgt_mask.to(dtype=torch.bool, device=x.device)
 
307
 
308
  # Normalize memory_mask dtype/device and expand simple shapes
309
  if memory_mask is not None:
@@ -313,7 +336,7 @@ class TransformerDecoder(nn.Module):
313
  elif memory_mask.dim() == 3: # (B, T, S) -> (B, 1, T, S)
314
  memory_mask = memory_mask.unsqueeze(1)
315
 
316
- attn_list: List[Dict[str, torch.Tensor]] = []
317
 
318
  # Compute relative position biases (T5-style)
319
  # Note: T5 uses relative position bias for self-attention but NOT for cross-attention
@@ -328,19 +351,37 @@ class TransformerDecoder(nn.Module):
328
 
329
  # Pass through decoder layers
330
  for layer in self.layers:
331
- x, attn = layer(
332
- x,
333
- memory,
334
- tgt_mask=tgt_mask,
335
- memory_mask=memory_mask,
336
- collect_attn=collect_attn,
337
- self_attn_position_bias=self_position_bias,
338
- cross_attn_position_bias=cross_position_bias,
339
- )
 
340
  if collect_attn:
341
  attn_list.append(attn)
342
 
343
  x = self.final_norm(x)
 
344
  logits = self.output_projection(x) # (B, T, vocab)
345
 
346
  if collect_attn:
 
18
 
19
  import torch
20
  import torch.nn as nn
21
+ from torch.utils.checkpoint import checkpoint
22
 
23
  from .attention import MultiHeadAttention, T5RelativePositionBias
24
  from .feedforward import FeedForward
25
  from .positional_encoding import LearnedPositionalEncoding, PositionalEncoding
26
+ from .t5_layer_norm import T5LayerNorm
27
 
28
 
29
  def create_causal_mask(seq_len: int, device: Optional[torch.device] = None) -> torch.Tensor:
 
79
  quantization=quantization,
80
  )
81
 
82
+ self.norm1 = T5LayerNorm(d_model)
83
+ self.norm2 = T5LayerNorm(d_model)
84
+ self.norm3 = T5LayerNorm(d_model)
85
 
86
  self.dropout1 = nn.Dropout(dropout)
87
  self.dropout2 = nn.Dropout(dropout)
 
191
  use_learned_pos_enc: bool = False,
192
  activation: Literal["gelu", "relu", "swiglu", "gated-gelu"] = "gated-gelu",
193
  use_relative_position_bias: bool = False, # T5-style relative position bias
194
+ gradient_checkpointing: bool = False,
195
  ):
196
  super().__init__()
197
  self.vocab_size = vocab_size
 
199
  self.pad_token_id = pad_token_id
200
  self.num_heads = num_heads
201
  self.use_relative_position_bias = use_relative_position_bias
202
+ self.gradient_checkpointing = gradient_checkpointing
203
 
204
  self.embedding = nn.Embedding(vocab_size, d_model, padding_idx=pad_token_id)
205
+ # Note: T5 does NOT scale logits (scaling factor removed)
206
 
207
  # Positional encoding (disabled when using relative position bias for T5)
208
  self.self_relative_position_bias: Optional[T5RelativePositionBias] = None
 
243
  ]
244
  )
245
 
246
+ self.final_norm = T5LayerNorm(d_model)
247
+ self.output_projection = nn.Linear(d_model, vocab_size, bias=False) # T5 has no bias
248
  self.input_dropout = nn.Dropout(dropout)
249
 
250
  def _build_padding_mask_from_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
 
257
  """
258
  assert self.pad_token_id is not None, "pad_token_id must be set to build mask from ids"
259
  pad_mask = input_ids != self.pad_token_id # (B, T)
260
+
261
+ # Always allow attending to the first token (BOS), even if it is pad_token_id
262
+ # Avoid in-place mutation for better torch.compile compatibility
263
+ if pad_mask.size(1) > 0:
264
+ # Combine: pad_mask OR (column == 0), built from a column-index tensor (no in-place ops)
269
+ col_indices = torch.arange(pad_mask.size(1), device=pad_mask.device).unsqueeze(0)
270
+ pad_mask = pad_mask | (col_indices == 0)
271
+
272
  attn_mask = pad_mask.unsqueeze(1) & pad_mask.unsqueeze(2) # (B, T, T)
273
  return attn_mask
274
 
 
280
  memory_mask: Optional[torch.Tensor] = None,
281
  collect_attn: bool = False,
282
  skip_padding_mask: bool = False, # Set True during generation to avoid masking start token
283
+ ) -> Union[torch.Tensor, Tuple[torch.Tensor, List[Dict[str, Optional[torch.Tensor]]]]]:
284
  """
285
  Args:
286
  inputs: (B, T) token ids or (B, T, d_model) embeddings
 
321
  else:
322
  # Ensure boolean and device alignment; accept (B, T, T) or (B,1,T,T) or (1,1,T,T)
323
  tgt_mask = tgt_mask.to(dtype=torch.bool, device=x.device)
324
+ # If tgt_mask is just causal (T, T), expand it
325
+ if tgt_mask.dim() == 2:
326
+ tgt_mask = tgt_mask.unsqueeze(0).unsqueeze(0)
327
+ elif tgt_mask.dim() == 3:
328
+ tgt_mask = tgt_mask.unsqueeze(1)
329
+
330
 
331
  # Normalize memory_mask dtype/device and expand simple shapes
332
  if memory_mask is not None:
 
336
  elif memory_mask.dim() == 3: # (B, T, S) -> (B, 1, T, S)
337
  memory_mask = memory_mask.unsqueeze(1)
338
 
339
+ attn_list: List[Dict[str, Optional[torch.Tensor]]] = []
340
 
341
  # Compute relative position biases (T5-style)
342
  # Note: T5 uses relative position bias for self-attention but NOT for cross-attention
 
351
 
352
  # Pass through decoder layers
353
  for layer in self.layers:
354
+ if self.gradient_checkpointing and self.training:
355
+ # Gradient checkpointing requires the inputs to require grad
356
+ def create_custom_forward(module):
357
+ def custom_forward(*inputs):
358
+ return module(
+     *inputs,
+     tgt_mask=tgt_mask,
+     memory_mask=memory_mask,
+     collect_attn=collect_attn,
+     self_attn_position_bias=self_position_bias,
+     cross_attn_position_bias=cross_position_bias,
+ )
359
+ return custom_forward
360
+
361
+ x, attn = cast(
362
+ Tuple[torch.Tensor, Dict[str, Optional[torch.Tensor]]],
363
+ checkpoint(
364
+ create_custom_forward(layer),
365
+ x,
366
+ memory,
367
+ use_reentrant=False,
368
+ ),
369
+ )
370
+ else:
371
+ x, attn = layer(
372
+ x,
373
+ memory,
374
+ tgt_mask=tgt_mask,
375
+ memory_mask=memory_mask,
376
+ collect_attn=collect_attn,
377
+ self_attn_position_bias=self_position_bias,
378
+ cross_attn_position_bias=cross_position_bias,
379
+ )
380
  if collect_attn:
381
  attn_list.append(attn)
382
 
383
  x = self.final_norm(x)
384
+ # T5 does NOT scale logits - direct projection to vocabulary
385
  logits = self.output_projection(x) # (B, T, vocab)
386
 
387
  if collect_attn:
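Both the decoder above and the encoder below now use `T5LayerNorm` from `src/models/t5_layer_norm.py`, which is not included in this diff. For orientation, a T5-style layer norm is an RMS normalization with a learned scale, no bias, and no mean centering; the sketch below is an assumption about that module's shape, not its actual source:

```python
import torch
import torch.nn as nn


class T5LayerNormSketch(nn.Module):
    """RMS-style norm: scale by 1/sqrt(mean(x^2)); learned weight, no bias, no mean subtraction."""

    def __init__(self, d_model: int, eps: float = 1e-6) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.ones(d_model))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Variance computed in float32 for numerical stability, as T5 implementations typically do.
        variance = x.to(torch.float32).pow(2).mean(dim=-1, keepdim=True)
        x = x * torch.rsqrt(variance + self.eps)
        return self.weight * x.to(self.weight.dtype)
```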
src/models/encoder.py CHANGED
@@ -13,15 +13,17 @@ Author: Oliver Perrin
13
  Date: 2025-10-23
14
  """
15
 
16
- from typing import List, Literal, Optional, Tuple, Union
17
 
18
  import torch
19
  import torch.nn as nn
 
20
 
21
  # Encoder implementation
22
  from .attention import MultiHeadAttention, T5RelativePositionBias
23
  from .feedforward import FeedForward
24
  from .positional_encoding import LearnedPositionalEncoding, PositionalEncoding
 
25
 
26
 
27
  class TransformerEncoderLayer(nn.Module):
@@ -65,8 +67,8 @@ class TransformerEncoderLayer(nn.Module):
65
  quantization=quantization,
66
  )
67
 
68
- self.norm1 = nn.RMSNorm(d_model)
69
- self.norm2 = nn.RMSNorm(d_model)
70
 
71
  self.dropout1 = nn.Dropout(dropout)
72
  self.dropout2 = nn.Dropout(dropout)
@@ -153,12 +155,14 @@ class TransformerEncoder(nn.Module):
153
  use_learned_pos_enc: bool = False,
154
  activation: Literal["gelu", "relu", "swiglu", "gated-gelu"] = "gated-gelu",
155
  use_relative_position_bias: bool = False, # T5-style relative position bias
 
156
  ):
157
  super().__init__()
158
  self.vocab_size = vocab_size
159
  self.d_model = d_model
160
  self.pad_token_id = pad_token_id
161
  self.use_relative_position_bias = use_relative_position_bias
 
162
 
163
  # Token embedding (only used if forward receives token ids)
164
  self.embedding = nn.Embedding(vocab_size, d_model, padding_idx=pad_token_id)
@@ -201,8 +205,8 @@ class TransformerEncoder(nn.Module):
201
  ]
202
  )
203
 
204
- # Final RMSNorm for Pre-LN stacks (recommended)
205
- self.final_norm = nn.RMSNorm(d_model)
206
 
207
  # Dropout applied after embedding + positional encoding (paper uses this)
208
  self.input_dropout = nn.Dropout(dropout)
@@ -282,7 +286,25 @@ class TransformerEncoder(nn.Module):
282
 
283
  # Pass through each encoder layer (optionally collect attn)
284
  for layer in self.layers:
285
- x, attn = layer(x, mask=mask, collect_attn=collect_attn, position_bias=position_bias)
 
 
286
  if collect_attn:
287
  attn_weights_per_layer.append(attn)
288
 
 
13
  Date: 2025-10-23
14
  """
15
 
16
+ from typing import List, Literal, Optional, Tuple, Union, cast
17
 
18
  import torch
19
  import torch.nn as nn
20
+ from torch.utils.checkpoint import checkpoint
21
 
22
  # Encoder implementation
23
  from .attention import MultiHeadAttention, T5RelativePositionBias
24
  from .feedforward import FeedForward
25
  from .positional_encoding import LearnedPositionalEncoding, PositionalEncoding
26
+ from .t5_layer_norm import T5LayerNorm
27
 
28
 
29
  class TransformerEncoderLayer(nn.Module):
 
67
  quantization=quantization,
68
  )
69
 
70
+ self.norm1 = T5LayerNorm(d_model)
71
+ self.norm2 = T5LayerNorm(d_model)
72
 
73
  self.dropout1 = nn.Dropout(dropout)
74
  self.dropout2 = nn.Dropout(dropout)
 
155
  use_learned_pos_enc: bool = False,
156
  activation: Literal["gelu", "relu", "swiglu", "gated-gelu"] = "gated-gelu",
157
  use_relative_position_bias: bool = False, # T5-style relative position bias
158
+ gradient_checkpointing: bool = False,
159
  ):
160
  super().__init__()
161
  self.vocab_size = vocab_size
162
  self.d_model = d_model
163
  self.pad_token_id = pad_token_id
164
  self.use_relative_position_bias = use_relative_position_bias
165
+ self.gradient_checkpointing = gradient_checkpointing
166
 
167
  # Token embedding (only used if forward receives token ids)
168
  self.embedding = nn.Embedding(vocab_size, d_model, padding_idx=pad_token_id)
 
205
  ]
206
  )
207
 
208
+ # Final T5LayerNorm for Pre-LN stacks
209
+ self.final_norm = T5LayerNorm(d_model)
210
 
211
  # Dropout applied after embedding + positional encoding (paper uses this)
212
  self.input_dropout = nn.Dropout(dropout)
 
286
 
287
  # Pass through each encoder layer (optionally collect attn)
288
  for layer in self.layers:
289
+ if self.gradient_checkpointing and self.training:
290
+ # Gradient checkpointing requires the inputs to require grad
291
+ # We use a lambda to pass keyword arguments
292
+ def create_custom_forward(module):
293
+ def custom_forward(*inputs):
294
+ return module(*inputs, mask=mask, collect_attn=collect_attn, position_bias=position_bias)
295
+ return custom_forward
296
+
297
+ x, attn = cast(
298
+ Tuple[torch.Tensor, Optional[torch.Tensor]],
299
+ checkpoint(
300
+ create_custom_forward(layer),
301
+ x,
302
+ use_reentrant=False,
303
+ ),
304
+ )
305
+ else:
306
+ x, attn = layer(x, mask=mask, collect_attn=collect_attn, position_bias=position_bias)
307
+
308
  if collect_attn:
309
  attn_weights_per_layer.append(attn)
310
 
src/models/factory.py CHANGED
@@ -14,15 +14,15 @@ from __future__ import annotations
14
 
15
  from dataclasses import dataclass
16
  from pathlib import Path
17
- from typing import Literal, Optional, cast
18
 
19
  import torch
20
  from transformers import T5ForConditionalGeneration
21
 
22
  from ..data.tokenization import Tokenizer
23
  from ..utils.config import load_yaml
24
- from .decoder import TransformerDecoder
25
- from .encoder import TransformerEncoder
26
  from .heads import ClassificationHead, LMHead
27
  from .multitask import MultiTaskModel
28
 
@@ -35,6 +35,7 @@ class ModelConfig:
35
  """Configuration describing the transformer architecture."""
36
 
37
  d_model: int = 768
 
38
  num_encoder_layers: int = 12
39
  num_decoder_layers: int = 12
40
  num_attention_heads: int = 12
@@ -50,6 +51,7 @@ class ModelConfig:
50
  use_relative_position_bias: bool = (
51
  False # T5-style relative position bias (use True for T5/FLAN-T5)
52
  )
 
53
 
54
  def __post_init__(self):
55
  if self.d_model % self.num_attention_heads != 0:
@@ -77,6 +79,7 @@ def load_model_config(path: Optional[str | Path]) -> ModelConfig:
77
  data = load_yaml(str(path)).data
78
  return ModelConfig(
79
  d_model=int(data.get("d_model", 512)),
 
80
  num_encoder_layers=int(data.get("num_encoder_layers", 6)),
81
  num_decoder_layers=int(data.get("num_decoder_layers", 6)),
82
  num_attention_heads=int(data.get("num_attention_heads", 8)),
@@ -88,6 +91,7 @@ def load_model_config(path: Optional[str | Path]) -> ModelConfig:
88
  use_learned_pos_enc=bool(data.get("use_learned_pos_enc", True)),
89
  activation=str(data.get("activation", "gelu")),
90
  use_relative_position_bias=bool(data.get("use_relative_position_bias", False)),
 
91
  )
92
 
93
 
@@ -107,11 +111,10 @@ def _load_pretrained_weights(
107
  -> We zero-initialize the bias terms
108
  """
109
  print(f"Loading pretrained weights from {model_name}...")
110
- t5 = T5ForConditionalGeneration.from_pretrained(model_name)
111
 
112
  # Load shared embeddings (T5 uses shared embeddings for encoder and decoder)
113
  # Note: T5's vocab is padded to multiple of 128 for efficiency (32100 -> 32128)
114
- # Our model uses the tokenizer's actual vocab size, so we only copy the valid tokens
115
  print("Transferring shared token embeddings...")
116
  shared_embeddings = t5.shared.weight.data
117
  our_vocab_size = encoder.embedding.weight.size(0)
@@ -124,6 +127,19 @@ def _load_pretrained_weights(
124
  print(f" Copying first {min_vocab} token embeddings...")
125
  encoder.embedding.weight.data[:min_vocab].copy_(shared_embeddings[:min_vocab])
126
127
  else:
128
  encoder.embedding.weight.data.copy_(shared_embeddings)
129
  decoder.embedding.weight.data.copy_(shared_embeddings)
@@ -136,11 +152,13 @@ def _load_pretrained_weights(
136
  print("Transferring encoder weights...")
137
  t5_encoder = t5.encoder
138
 
139
- for custom_layer, t5_layer in zip(encoder.layers, t5_encoder.block, strict=False):
140
- t5_self_attn = t5_layer.layer[0].SelfAttention
141
- t5_ffn = t5_layer.layer[1].DenseReluDense
142
- t5_norm1 = t5_layer.layer[0].layer_norm
143
- t5_norm2 = t5_layer.layer[1].layer_norm
 
 
144
 
145
  # Self-attention (T5 has no bias in attention projections)
146
  custom_layer.self_attn.W_Q.weight.data.copy_(t5_self_attn.q.weight.data)
@@ -190,7 +208,7 @@ def _load_pretrained_weights(
190
  if hasattr(encoder, "relative_position_bias") and encoder.relative_position_bias is not None:
191
  print("Transferring encoder relative position bias...")
192
  t5_enc_rel_bias = (
193
- t5_encoder.block[0].layer[0].SelfAttention.relative_attention_bias.weight.data
194
  )
195
  encoder.relative_position_bias.relative_attention_bias.weight.data.copy_(t5_enc_rel_bias)
196
 
@@ -198,13 +216,15 @@ def _load_pretrained_weights(
198
  print("Transferring decoder weights...")
199
  t5_decoder = t5.decoder
200
 
201
- for custom_layer, t5_layer in zip(decoder.layers, t5_decoder.block, strict=False):
202
- t5_self_attn = t5_layer.layer[0].SelfAttention
203
- t5_cross_attn = t5_layer.layer[1].EncDecAttention
204
- t5_ffn = t5_layer.layer[2].DenseReluDense
205
- t5_norm1 = t5_layer.layer[0].layer_norm
206
- t5_norm2 = t5_layer.layer[1].layer_norm
207
- t5_norm3 = t5_layer.layer[2].layer_norm
 
 
208
 
209
  # Self-attention
210
  custom_layer.self_attn.W_Q.weight.data.copy_(t5_self_attn.q.weight.data)
@@ -265,7 +285,7 @@ def _load_pretrained_weights(
265
  ):
266
  print("Transferring decoder self-attention relative position bias...")
267
  t5_dec_self_rel_bias = (
268
- t5_decoder.block[0].layer[0].SelfAttention.relative_attention_bias.weight.data
269
  )
270
  decoder.self_relative_position_bias.relative_attention_bias.weight.data.copy_(
271
  t5_dec_self_rel_bias
@@ -278,7 +298,7 @@ def _load_pretrained_weights(
278
  print("Transferring decoder cross-attention relative position bias...")
279
  # Cross-attention relative position bias is in EncDecAttention of first block
280
  t5_dec_cross_rel_bias = (
281
- t5_decoder.block[0].layer[1].EncDecAttention.relative_attention_bias.weight.data
282
  )
283
  decoder.cross_relative_position_bias.relative_attention_bias.weight.data.copy_(
284
  t5_dec_cross_rel_bias
@@ -367,9 +387,9 @@ def _load_llama_weights(
367
  num_layers = min(len(encoder.layers), len(llama.model.layers))
368
 
369
  for i in range(num_layers):
370
- llama_layer = llama.model.layers[i]
371
- enc_layer = encoder.layers[i]
372
- dec_layer = decoder.layers[i]
373
 
374
  # --- Self-Attention ---
375
  # Llama: q_proj, k_proj, v_proj, o_proj
@@ -460,15 +480,19 @@ def build_multitask_model(
460
  if hasattr(tokenizer, "config") and hasattr(tokenizer.config, "max_length"):
461
  max_len = tokenizer.config.max_length
462
  elif hasattr(tokenizer, "model_max_length"):
463
- max_len = tokenizer.model_max_length
464
  else:
465
  max_len = 512 # Default fallback
466
 
467
  # Cast activation to the literal type for mypy
468
  activation = cast(ActivationType, cfg.activation)
469
 
 
470
  encoder = TransformerEncoder(
471
- vocab_size=tokenizer.vocab_size,
472
  d_model=cfg.d_model,
473
  num_layers=cfg.num_encoder_layers,
474
  num_heads=cfg.num_attention_heads,
@@ -480,9 +504,10 @@ def build_multitask_model(
480
  use_learned_pos_enc=cfg.use_learned_pos_enc,
481
  activation=activation,
482
  use_relative_position_bias=cfg.use_relative_position_bias,
 
483
  )
484
  decoder = TransformerDecoder(
485
- vocab_size=tokenizer.vocab_size,
486
  d_model=cfg.d_model,
487
  num_layers=cfg.num_decoder_layers,
488
  num_heads=cfg.num_attention_heads,
@@ -494,6 +519,7 @@ def build_multitask_model(
494
  use_learned_pos_enc=cfg.use_learned_pos_enc,
495
  activation=activation,
496
  use_relative_position_bias=cfg.use_relative_position_bias,
 
497
  )
498
 
499
  # Load pretrained weights if requested (but allow override for inference)
@@ -513,12 +539,14 @@ def build_multitask_model(
513
  )
514
  _load_pretrained_weights(encoder, decoder, cfg.pretrained_model_name)
515
 
516
  model = MultiTaskModel(encoder=encoder, decoder=decoder, decoder_outputs_logits=True)
517
  model.add_head(
518
  "summarization",
519
- LMHead(
520
- d_model=cfg.d_model, vocab_size=tokenizer.vocab_size, tie_embedding=decoder.embedding
521
- ),
522
  )
523
  model.add_head(
524
  "emotion",
 
14
 
15
  from dataclasses import dataclass
16
  from pathlib import Path
17
+ from typing import Any, Literal, Optional, cast
18
 
19
  import torch
20
  from transformers import T5ForConditionalGeneration
21
 
22
  from ..data.tokenization import Tokenizer
23
  from ..utils.config import load_yaml
24
+ from .decoder import TransformerDecoder, TransformerDecoderLayer
25
+ from .encoder import TransformerEncoder, TransformerEncoderLayer
26
  from .heads import ClassificationHead, LMHead
27
  from .multitask import MultiTaskModel
28
 
 
35
  """Configuration describing the transformer architecture."""
36
 
37
  d_model: int = 768
38
+ vocab_size: Optional[int] = None # Override tokenizer vocab size (e.g., 32128 for FLAN-T5)
39
  num_encoder_layers: int = 12
40
  num_decoder_layers: int = 12
41
  num_attention_heads: int = 12
 
51
  use_relative_position_bias: bool = (
52
  False # T5-style relative position bias (use True for T5/FLAN-T5)
53
  )
54
+ gradient_checkpointing: bool = False
55
 
56
  def __post_init__(self):
57
  if self.d_model % self.num_attention_heads != 0:
 
79
  data = load_yaml(str(path)).data
80
  return ModelConfig(
81
  d_model=int(data.get("d_model", 512)),
82
+ vocab_size=data.get("vocab_size", None), # Optional vocab size override
83
  num_encoder_layers=int(data.get("num_encoder_layers", 6)),
84
  num_decoder_layers=int(data.get("num_decoder_layers", 6)),
85
  num_attention_heads=int(data.get("num_attention_heads", 8)),
 
91
  use_learned_pos_enc=bool(data.get("use_learned_pos_enc", True)),
92
  activation=str(data.get("activation", "gelu")),
93
  use_relative_position_bias=bool(data.get("use_relative_position_bias", False)),
94
+ gradient_checkpointing=bool(data.get("gradient_checkpointing", False)),
95
  )
96
 
97
 
 
111
  -> We zero-initialize the bias terms
112
  """
113
  print(f"Loading pretrained weights from {model_name}...")
114
+ t5 = T5ForConditionalGeneration.from_pretrained(model_name) # type: ignore[attr-defined]
115
 
116
  # Load shared embeddings (T5 uses shared embeddings for encoder and decoder)
117
  # Note: T5's vocab is padded to multiple of 128 for efficiency (32100 -> 32128)
 
118
  print("Transferring shared token embeddings...")
119
  shared_embeddings = t5.shared.weight.data
120
  our_vocab_size = encoder.embedding.weight.size(0)
 
127
  print(f" Copying first {min_vocab} token embeddings...")
128
  encoder.embedding.weight.data[:min_vocab].copy_(shared_embeddings[:min_vocab])
129
  decoder.embedding.weight.data[:min_vocab].copy_(shared_embeddings[:min_vocab])
130
+
131
+ # Initialize any extra tokens (e.g., tokens 32100-32127) with small random values
132
+ if our_vocab_size > t5_vocab_size:
133
+ print(
134
+ f" Initializing {our_vocab_size - t5_vocab_size} extra padding tokens with small values..."
135
+ )
136
+ # Use small random initialization for stability (mean of existing embeddings ± small noise)
137
+ mean_emb = shared_embeddings.mean(dim=0, keepdim=True)
138
+ encoder.embedding.weight.data[t5_vocab_size:].normal_(mean=0.0, std=0.02)
139
+ encoder.embedding.weight.data[t5_vocab_size:] += mean_emb
140
+ decoder.embedding.weight.data[t5_vocab_size:].copy_(
141
+ encoder.embedding.weight.data[t5_vocab_size:]
142
+ )
143
  else:
144
  encoder.embedding.weight.data.copy_(shared_embeddings)
145
  decoder.embedding.weight.data.copy_(shared_embeddings)
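For context on the block above: FLAN-T5's checkpoint pads its embedding matrix beyond what the tokenizer reports, so up to 28 rows have no pretrained values and are filled with the mean embedding plus small Gaussian noise. A minimal standalone sketch of that initialization idea (toy tensors, not the project's real weights):

```python
import torch

# Illustrative sizes taken from the comments above: FLAN-T5 pads its vocabulary
# to a multiple of 128 (32100 -> 32128), leaving 28 rows without pretrained values.
t5_vocab_size, our_vocab_size, d_model = 32100, 32128, 768
shared = torch.randn(t5_vocab_size, d_model)   # stand-in for t5.shared.weight
embedding = torch.empty(our_vocab_size, d_model)

embedding[:t5_vocab_size] = shared             # copy the pretrained rows
mean_emb = shared.mean(dim=0, keepdim=True)
extra = embedding[t5_vocab_size:]              # the 28 padding rows
extra.normal_(mean=0.0, std=0.02)              # small Gaussian noise...
extra += mean_emb                              # ...centred on the mean embedding
print(extra.shape)                             # torch.Size([28, 768])
```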
 
152
  print("Transferring encoder weights...")
153
  t5_encoder = t5.encoder
154
 
155
+ for custom_layer_untyped, t5_layer in zip(encoder.layers, t5_encoder.block, strict=False):
156
+ custom_layer = cast(TransformerEncoderLayer, custom_layer_untyped)
157
+ t5_block = cast(Any, t5_layer)
158
+ t5_self_attn = t5_block.layer[0].SelfAttention
159
+ t5_ffn = t5_block.layer[1].DenseReluDense
160
+ t5_norm1 = t5_block.layer[0].layer_norm
161
+ t5_norm2 = t5_block.layer[1].layer_norm
162
 
163
  # Self-attention (T5 has no bias in attention projections)
164
  custom_layer.self_attn.W_Q.weight.data.copy_(t5_self_attn.q.weight.data)
 
208
  if hasattr(encoder, "relative_position_bias") and encoder.relative_position_bias is not None:
209
  print("Transferring encoder relative position bias...")
210
  t5_enc_rel_bias = (
211
+ cast(Any, t5_encoder.block[0]).layer[0].SelfAttention.relative_attention_bias.weight.data
212
  )
213
  encoder.relative_position_bias.relative_attention_bias.weight.data.copy_(t5_enc_rel_bias)
214
 
 
216
  print("Transferring decoder weights...")
217
  t5_decoder = t5.decoder
218
 
219
+ for custom_layer_untyped, t5_layer in zip(decoder.layers, t5_decoder.block, strict=False):
220
+ custom_layer = cast(TransformerDecoderLayer, custom_layer_untyped)
221
+ t5_block = cast(Any, t5_layer)
222
+ t5_self_attn = t5_block.layer[0].SelfAttention
223
+ t5_cross_attn = t5_block.layer[1].EncDecAttention
224
+ t5_ffn = t5_block.layer[2].DenseReluDense
225
+ t5_norm1 = t5_block.layer[0].layer_norm
226
+ t5_norm2 = t5_block.layer[1].layer_norm
227
+ t5_norm3 = t5_block.layer[2].layer_norm
228
 
229
  # Self-attention
230
  custom_layer.self_attn.W_Q.weight.data.copy_(t5_self_attn.q.weight.data)
 
285
  ):
286
  print("Transferring decoder self-attention relative position bias...")
287
  t5_dec_self_rel_bias = (
288
+ cast(Any, t5_decoder.block[0]).layer[0].SelfAttention.relative_attention_bias.weight.data
289
  )
290
  decoder.self_relative_position_bias.relative_attention_bias.weight.data.copy_(
291
  t5_dec_self_rel_bias
 
298
  print("Transferring decoder cross-attention relative position bias...")
299
  # Cross-attention relative position bias is in EncDecAttention of first block
300
  t5_dec_cross_rel_bias = (
301
+ cast(Any, t5_decoder.block[0]).layer[1].EncDecAttention.relative_attention_bias.weight.data
302
  )
303
  decoder.cross_relative_position_bias.relative_attention_bias.weight.data.copy_(
304
  t5_dec_cross_rel_bias
 
387
  num_layers = min(len(encoder.layers), len(llama.model.layers))
388
 
389
  for i in range(num_layers):
390
+ llama_layer = cast(Any, llama.model.layers[i])
391
+ enc_layer = cast(TransformerEncoderLayer, encoder.layers[i])
392
+ dec_layer = cast(TransformerDecoderLayer, decoder.layers[i])
393
 
394
  # --- Self-Attention ---
395
  # Llama: q_proj, k_proj, v_proj, o_proj
 
480
  if hasattr(tokenizer, "config") and hasattr(tokenizer.config, "max_length"):
481
  max_len = tokenizer.config.max_length
482
  elif hasattr(tokenizer, "model_max_length"):
483
+ max_len = cast(Any, tokenizer).model_max_length
484
  else:
485
  max_len = 512 # Default fallback
486
 
487
  # Cast activation to the literal type for mypy
488
  activation = cast(ActivationType, cfg.activation)
489
 
490
+ # Use cfg.vocab_size (32128) instead of tokenizer.vocab_size (32100)
491
+ # to match FLAN-T5's padded vocabulary
492
+ vocab_size = cfg.vocab_size if cfg.vocab_size is not None else tokenizer.vocab_size
493
+
494
  encoder = TransformerEncoder(
495
+ vocab_size=vocab_size,
496
  d_model=cfg.d_model,
497
  num_layers=cfg.num_encoder_layers,
498
  num_heads=cfg.num_attention_heads,
 
504
  use_learned_pos_enc=cfg.use_learned_pos_enc,
505
  activation=activation,
506
  use_relative_position_bias=cfg.use_relative_position_bias,
507
+ gradient_checkpointing=cfg.gradient_checkpointing,
508
  )
509
  decoder = TransformerDecoder(
510
+ vocab_size=vocab_size,
511
  d_model=cfg.d_model,
512
  num_layers=cfg.num_decoder_layers,
513
  num_heads=cfg.num_attention_heads,
 
519
  use_learned_pos_enc=cfg.use_learned_pos_enc,
520
  activation=activation,
521
  use_relative_position_bias=cfg.use_relative_position_bias,
522
+ gradient_checkpointing=cfg.gradient_checkpointing,
523
  )
524
 
525
  # Load pretrained weights if requested (but allow override for inference)
 
539
  )
540
  _load_pretrained_weights(encoder, decoder, cfg.pretrained_model_name)
541
 
542
+ # T5 uses separate embeddings and lm_head (tie_word_embeddings=False)
543
+ # Both are initialized from pretrained weights if use_pretrained=True
544
+ # We do NOT tie them here - they remain independent for better flexibility
545
+
546
  model = MultiTaskModel(encoder=encoder, decoder=decoder, decoder_outputs_logits=True)
547
  model.add_head(
548
  "summarization",
549
+ LMHead(d_model=cfg.d_model, vocab_size=vocab_size, tie_embedding=decoder.embedding),
 
 
550
  )
551
  model.add_head(
552
  "emotion",
src/models/t5_layer_norm.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """T5-style Layer Normalization (RMSNorm without mean centering).
2
+
3
+ T5's layer norm is essentially RMSNorm: unlike standard LayerNorm, it never subtracts the mean.
4
+ This is critical for matching T5's behavior.
5
+ """
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+
10
+
11
+ class T5LayerNorm(nn.Module):
12
+ """
13
+ T5-style layer normalization without mean centering.
14
+
15
+ This is RMSNorm: unlike nn.LayerNorm, it does NOT subtract the mean from x.
16
+ Formula: output = x / sqrt(mean(x^2) + eps) * weight
17
+
18
+ Args:
19
+ normalized_shape: Input shape (typically d_model)
20
+ eps: Small constant for numerical stability
21
+ """
22
+
23
+ def __init__(self, normalized_shape: int, eps: float = 1e-6):
24
+ super().__init__()
25
+ self.weight = nn.Parameter(torch.ones(normalized_shape))
26
+ self.variance_epsilon = eps
27
+
28
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
29
+ """
30
+ Args:
31
+ hidden_states: (*, normalized_shape)
32
+
33
+ Returns:
34
+ Normalized tensor of same shape
35
+ """
36
+ # T5 uses variance = mean(x^2), does NOT subtract mean
37
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
38
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
39
+
40
+ # Scale by learned weight (no bias in T5)
41
+ return self.weight * hidden_states
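A small sanity check of the behaviour described in the docstring, assuming the repo root is on `sys.path` (as the tests arrange) so `src.models.t5_layer_norm` is importable; shapes are illustrative:

```python
import torch
import torch.nn as nn

from src.models.t5_layer_norm import T5LayerNorm

x = torch.randn(2, 4, 768) + 3.0             # deliberately non-zero mean
t5_norm = T5LayerNorm(768)                   # weight starts as ones
layer_norm = nn.LayerNorm(768)

y = t5_norm(x)
print(y.pow(2).mean(-1).sqrt().mean())       # ~1.0: output RMS is normalized
print(y.mean(-1).abs().mean())               # clearly non-zero: mean is NOT removed
print(layer_norm(x).mean(-1).abs().mean())   # ~0.0: LayerNorm centres the input
```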
src/training/early_stopping.py ADDED
@@ -0,0 +1,60 @@
1
+ """Early stopping implementation for training.
2
+
3
+ Author: Oliver Perrin
4
+ Date: December 2025
5
+ """
6
+
7
+
8
+ class EarlyStopping:
9
+ """Stop training when validation loss stops improving.
10
+
11
+ Args:
12
+ patience: Number of epochs to wait before stopping
13
+ min_delta: Minimum change to qualify as improvement
14
+ mode: 'min' for loss (lower is better), 'max' for accuracy
15
+ """
16
+
17
+ def __init__(
18
+ self,
19
+ patience: int = 3,
20
+ min_delta: float = 0.001,
21
+ mode: str = "min"
22
+ ):
23
+ self.patience = patience
24
+ self.min_delta = min_delta
25
+ self.mode = mode
26
+ self.counter = 0
27
+ self.best_value = float('inf') if mode == 'min' else float('-inf')
28
+ self.early_stop = False
29
+
30
+ def __call__(self, metric_value: float) -> bool:
31
+ """Check if training should stop.
32
+
33
+ Args:
34
+ metric_value: Current metric value (e.g., validation loss)
35
+
36
+ Returns:
37
+ True if training should stop, False otherwise
38
+ """
39
+ if self.mode == 'min':
40
+ improved = metric_value < (self.best_value - self.min_delta)
41
+ else:
42
+ improved = metric_value > (self.best_value + self.min_delta)
43
+
44
+ if improved:
45
+ self.best_value = metric_value
46
+ self.counter = 0
47
+ return False
48
+
49
+ self.counter += 1
50
+ if self.counter >= self.patience:
51
+ self.early_stop = True
52
+ return True
53
+
54
+ return False
55
+
56
+ def reset(self):
57
+ """Reset early stopping state."""
58
+ self.counter = 0
59
+ self.best_value = float('inf') if self.mode == 'min' else float('-inf')
60
+ self.early_stop = False
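A minimal usage sketch for `EarlyStopping` with made-up validation losses (not from a real run):

```python
from src.training.early_stopping import EarlyStopping

stopper = EarlyStopping(patience=2, min_delta=0.01, mode="min")

val_losses = [0.90, 0.80, 0.79, 0.795, 0.798]   # made-up losses; no real gain after 0.80
for epoch, loss in enumerate(val_losses, start=1):
    if stopper(loss):
        print(f"stopping at epoch {epoch}, best loss {stopper.best_value:.3f}")
        break
```

With `patience=2` and `min_delta=0.01`, the two non-improving epochs after the best loss of 0.80 exhaust the patience and training stops at epoch 4.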
src/training/gradient_monitor.py ADDED
@@ -0,0 +1,102 @@
1
+ """Gradient monitoring utilities.
2
+
3
+ Author: Oliver Perrin
4
+ Date: December 2025
5
+ """
6
+
7
+ from typing import Dict, Optional
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+
12
+
13
+ class GradientMonitor:
14
+ """Monitor gradient statistics during training.
15
+
16
+ Tracks gradient norms and helps detect issues such as vanishing or exploding gradients.
17
+ """
18
+
19
+ def __init__(self, model: nn.Module, log_frequency: int = 100):
20
+ """Initialize gradient monitor.
21
+
22
+ Args:
23
+ model: Model to monitor
24
+ log_frequency: Log gradients every N steps
25
+ """
26
+ self.model = model
27
+ self.log_frequency = log_frequency
28
+ self.step_count = 0
29
+
30
+ def compute_grad_norm(self) -> Dict[str, float]:
31
+ """Compute gradient norm statistics.
32
+
33
+ Returns:
34
+ Dictionary with gradient statistics
35
+ """
36
+ total_norm = 0.0
37
+ max_norm = 0.0
38
+ num_params = 0
39
+
40
+ for p in self.model.parameters():
41
+ if p.grad is not None:
42
+ param_norm = p.grad.data.norm(2).item()
43
+ total_norm += param_norm ** 2
44
+ max_norm = max(max_norm, param_norm)
45
+ num_params += 1
46
+
47
+ total_norm = total_norm ** 0.5
48
+
49
+ return {
50
+ "grad_norm": total_norm,
51
+ "grad_norm_max": max_norm,
52
+ "num_params_with_grad": num_params,
53
+ }
54
+
55
+ def check_gradients(self) -> Dict[str, int]:
56
+ """Check for gradient issues (NaN, Inf, zero).
57
+
58
+ Returns:
59
+ Dictionary with counts of gradient issues
60
+ """
61
+ nan_count = 0
62
+ inf_count = 0
63
+ zero_count = 0
64
+
65
+ for p in self.model.parameters():
66
+ if p.grad is not None:
67
+ if torch.isnan(p.grad).any():
68
+ nan_count += 1
69
+ if torch.isinf(p.grad).any():
70
+ inf_count += 1
71
+ if (p.grad == 0).all():
72
+ zero_count += 1
73
+
74
+ return {
75
+ "nan_grads": nan_count,
76
+ "inf_grads": inf_count,
77
+ "zero_grads": zero_count,
78
+ }
79
+
80
+ def log_gradients(self, step: Optional[int] = None) -> Optional[Dict[str, float]]:
81
+ """Log gradient statistics if it's time.
82
+
83
+ Args:
84
+ step: Current training step (uses internal counter if None)
85
+
86
+ Returns:
87
+ Gradient statistics if logged, None otherwise
88
+ """
89
+ if step is None:
90
+ step = self.step_count
91
+ self.step_count += 1
92
+
93
+ if step % self.log_frequency == 0:
94
+ stats = self.compute_grad_norm()
95
+ issues = self.check_gradients()
96
+
97
+ # Combine stats
98
+ all_stats = {**stats, **issues}
99
+
100
+ return all_stats
101
+
102
+ return None
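A rough sketch of how `GradientMonitor` slots in after `loss.backward()`; the linear model and random batch below are placeholders, not part of LexiMind:

```python
import torch
import torch.nn as nn

from src.training.gradient_monitor import GradientMonitor

model = nn.Linear(16, 4)                        # placeholder model
monitor = GradientMonitor(model, log_frequency=1)

loss = model(torch.randn(8, 16)).pow(2).mean()  # dummy forward pass and loss
loss.backward()

stats = monitor.log_gradients()                 # step 0 -> logged because 0 % 1 == 0
if stats is not None:
    print(stats["grad_norm"], stats["nan_grads"], stats["zero_grads"])
```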
src/training/safe_compile.py CHANGED
@@ -1,86 +1,52 @@
1
- """
2
- Safe torch.compile configuration that prevents NaN issues.
3
 
4
- Author: Oliver Perrin
5
- Date: December 2025
6
- """
7
 
8
  import torch
9
 
10
 
11
  def compile_model_safe(
12
  model: torch.nn.Module,
13
  mode: str = "default",
 
14
  ) -> torch.nn.Module:
 
 
 
15
  """
16
- Compile model with inductor backend and safety guardrails.
17
 
18
- Uses 'default' mode which gives inductor speedups without CUDA graphs.
19
- CUDA graphs (reduce-overhead mode) don't work with dynamic shapes or
20
- shared embeddings like in T5.
21
 
22
- Args:
23
- model: Model to compile
24
- mode: Compilation mode ("default" recommended, avoid "reduce-overhead")
25
 
26
- Returns:
27
- Compiled model (or original if compilation fails)
28
- """
29
- if not torch.cuda.is_available():
30
- print(" CUDA not available, skipping compilation")
31
- return model
32
-
33
- try:
34
- # Configure for stability
35
- torch._dynamo.config.suppress_errors = True
36
- torch._dynamo.config.cache_size_limit = 64 # Allow more graph variations
37
-
38
- # Disable aggressive optimizations that can cause NaNs
39
- if hasattr(torch, "_inductor"):
40
- cfg = torch._inductor.config
41
- if hasattr(cfg, "epilogue_fusion"):
42
- cfg.epilogue_fusion = False
43
- if hasattr(cfg, "coordinate_descent_tuning"):
44
- cfg.coordinate_descent_tuning = False
45
- if hasattr(cfg, "force_fuse_int_mm_with_mul"):
46
- cfg.force_fuse_int_mm_with_mul = False
47
- # Explicitly disable CUDA graphs
48
- if hasattr(cfg, "triton"):
49
- if hasattr(cfg.triton, "cudagraphs"):
50
- cfg.triton.cudagraphs = False
51
- if hasattr(cfg.triton, "max_autotune_gemm"):
52
- cfg.triton.max_autotune_gemm = False
53
-
54
- # Compile with inductor (no CUDA graphs)
55
- compiled = torch.compile(model, mode=mode, fullgraph=False, dynamic=True)
56
- print(f"✓ Compiled with inductor ({mode} mode)")
57
- return compiled
58
-
59
- except Exception as e:
60
- print(f"⚠ Inductor compilation failed: {e}")
61
- print(" Falling back to aot_eager")
62
- try:
63
- return torch.compile(model, backend="aot_eager")
64
- except Exception:
65
- print(" Using uncompiled model")
66
- return model
67
-
68
-
69
- def apply_safe_config():
70
- """Apply safe configuration to torch._inductor before any compilation."""
71
- if hasattr(torch, "_inductor"):
72
- cfg = torch._inductor.config
73
- if hasattr(cfg, "epilogue_fusion"):
74
- cfg.epilogue_fusion = False
75
- if hasattr(cfg, "coordinate_descent_tuning"):
76
- cfg.coordinate_descent_tuning = False
77
- if hasattr(cfg, "triton"):
78
- if hasattr(cfg.triton, "cudagraphs"):
79
- cfg.triton.cudagraphs = False
80
- if hasattr(cfg.triton, "max_autotune_gemm"):
81
- cfg.triton.max_autotune_gemm = False
82
-
83
- # Dynamo config for stability
84
- torch._dynamo.config.suppress_errors = True
85
- torch._dynamo.config.cache_size_limit = 64
86
  print("✓ Applied safe inductor configuration")
 
1
+ """Safe defaults for `torch.compile` to reduce instability in tests and training."""
 
2
 
3
+ from __future__ import annotations
4
+
5
+ from typing import Any
6
 
7
  import torch
8
 
9
 
10
+ def _set_attr(obj: object, name: str, value: Any) -> None:
11
+ """Set attribute on dynamic objects only if it exists (keeps static checkers quiet)."""
12
+
13
+ target = getattr(obj, name, None)
14
+ if target is not None:
15
+ setattr(obj, name, value)
16
+
17
+
18
  def compile_model_safe(
19
  model: torch.nn.Module,
20
  mode: str = "default",
21
+ dynamic: bool | None = None,
22
  ) -> torch.nn.Module:
23
+ """Safely compile model with inductor backend.
24
+
25
+ Parameters mirror `torch.compile` but default to conservative settings.
26
  """
 
27
 
28
+ return torch.compile(model, backend="inductor", mode=mode, dynamic=dynamic)
 
 
29
 
 
 
 
30
 
31
+ def apply_safe_config() -> None:
32
+ """Apply conservative torch._inductor and torch._dynamo settings if present."""
33
+
34
+ inductor = getattr(torch, "_inductor", None)
35
+ cfg = getattr(inductor, "config", None) if inductor is not None else None
36
+
37
+ if cfg is not None:
38
+ _set_attr(cfg, "epilogue_fusion", False)
39
+ _set_attr(cfg, "coordinate_descent_tuning", False)
40
+ triton_cfg = getattr(cfg, "triton", None)
41
+ if triton_cfg is not None:
42
+ _set_attr(triton_cfg, "cudagraphs", False)
43
+ _set_attr(triton_cfg, "max_autotune_gemm", False)
44
+
45
+ dynamo_cfg = getattr(torch, "_dynamo", None)
46
+ if dynamo_cfg is not None:
47
+ dyn_config = getattr(dynamo_cfg, "config", None)
48
+ if dyn_config is not None:
49
+ _set_attr(dyn_config, "suppress_errors", True)
50
+ _set_attr(dyn_config, "cache_size_limit", 64)
51
+
52
  print("✓ Applied safe inductor configuration")
src/training/trainer.py CHANGED
@@ -2,7 +2,7 @@
2
  Multi-task Trainer for LexiMind.
3
 
4
  Handles training across summarization, emotion, and topic heads with mixed-precision,
5
- gradient accumulation, and MLflow logging.
6
 
7
  Author: Oliver Perrin
8
  Date: December 2025
@@ -10,6 +10,7 @@ Date: December 2025
10
 
11
  from __future__ import annotations
12
 
 
13
  import sys
14
  import time
15
  from collections import defaultdict
@@ -19,13 +20,36 @@ from typing import Any, Callable, Dict, List
19
  import mlflow
20
  import torch
21
  import torch.nn.functional as F
 
22
  from torch.utils.data import DataLoader
23
  from tqdm import tqdm
24
 
25
  from ..data.tokenization import Tokenizer
 
 
26
  from .metrics import accuracy, multilabel_f1, rouge_like
27
  from .nan_debugger import NaNDetector
28
 
29
  # --------------- Configuration ---------------
30
 
31
 
@@ -42,6 +66,15 @@ class TrainerConfig:
42
  experiment_name: str = "LexiMind"
43
  run_name: str | None = None
44
  gradient_accumulation_steps: int = 1
45
 
46
 
47
  # --------------- Trainer ---------------
@@ -61,6 +94,8 @@ class Trainer:
61
  self.config = config
62
  self.device = device
63
  self.tokenizer = tokenizer
 
 
64
 
65
  # Task losses
66
  self.emotion_loss = torch.nn.BCEWithLogitsLoss()
@@ -76,6 +111,18 @@ class Trainer:
76
  self.nan_skip_count = 0
77
  self.max_nan_skips = 50
78
 
79
  # Track current step for debugging
80
  self._current_step = 0
81
 
@@ -87,6 +134,46 @@ class Trainer:
87
  torch.backends.cuda.enable_flash_sdp(True)
88
  torch.backends.cuda.enable_mem_efficient_sdp(True)
89
 
90
  # --------------- Training Loop ---------------
91
 
92
  def fit(
@@ -94,17 +181,24 @@ class Trainer:
94
  train_loaders: Dict[str, DataLoader],
95
  val_loaders: Dict[str, DataLoader] | None = None,
96
  checkpoint_callback: Callable | None = None,
 
97
  ) -> Dict[str, Dict[str, float]]:
98
  """Train model across all tasks with progress tracking."""
99
  history: Dict[str, Dict[str, float]] = {}
100
  total_start = time.perf_counter()
101
 
102
  with mlflow.start_run(run_name=self.config.run_name):
103
  self._log_config()
104
 
105
  # Epoch progress bar
106
  epoch_pbar = tqdm(
107
- range(1, self.config.max_epochs + 1),
108
  desc="Training",
109
  unit="epoch",
110
  position=0,
@@ -129,6 +223,15 @@ class Trainer:
129
  if "summarization" in val_loaders:
130
  self._validate_generation(val_loaders["summarization"], epoch)
131
132
  # Checkpoint
133
  if checkpoint_callback:
134
  checkpoint_callback(epoch, self.model, history)
@@ -256,7 +359,19 @@ class Trainer:
256
  return averaged
257
 
258
  def _optimizer_step(self) -> None:
259
260
  # Check gradients for NaN/Inf BEFORE clipping
261
  nan_grad = self.nan_detector.check_gradients(self._current_step)
262
  if nan_grad is not None:
@@ -280,6 +395,14 @@ class Trainer:
280
 
281
  self.optimizer.zero_grad()
282
 
283
  # Check parameters for NaN AFTER update
284
  nan_param = self.nan_detector.check_parameters(self._current_step)
285
  if nan_param is not None:
@@ -287,6 +410,31 @@ class Trainer:
287
  f"NaN in parameter {nan_param} after optimizer step at step {self._current_step}!"
288
  )
289
 
290
  def _get_batch(
291
  self, iterators: Dict, loader: DataLoader, task: str
292
  ) -> Dict[str, torch.Tensor] | None:
@@ -341,6 +489,8 @@ class Trainer:
341
  inputs["src_mask"] = batch["src_mask"]
342
 
343
  logits = self.model.forward("summarization", inputs)
 
 
344
  loss = F.cross_entropy(
345
  logits.view(-1, logits.size(-1)),
346
  batch["labels"].view(-1),
@@ -348,6 +498,11 @@ class Trainer:
348
  label_smoothing=self.config.label_smoothing,
349
  )
350
 
351
  # Quick ROUGE estimate
352
  preds = self.tokenizer.decode_batch(logits.argmax(dim=-1).tolist())
353
  refs = self._decode_labels(batch["labels"])
 
2
  Multi-task Trainer for LexiMind.
3
 
4
  Handles training across summarization, emotion, and topic heads with mixed-precision,
5
+ gradient accumulation, gradient monitoring, early stopping, and MLflow logging.
6
 
7
  Author: Oliver Perrin
8
  Date: December 2025
 
10
 
11
  from __future__ import annotations
12
 
13
+ import math
14
  import sys
15
  import time
16
  from collections import defaultdict
 
20
  import mlflow
21
  import torch
22
  import torch.nn.functional as F
23
+ from torch.optim.lr_scheduler import LambdaLR
24
  from torch.utils.data import DataLoader
25
  from tqdm import tqdm
26
 
27
  from ..data.tokenization import Tokenizer
28
+ from .early_stopping import EarlyStopping
29
+ from .gradient_monitor import GradientMonitor
30
  from .metrics import accuracy, multilabel_f1, rouge_like
31
  from .nan_debugger import NaNDetector
32
 
33
+
34
+ def _get_cosine_schedule_with_warmup(
35
+ optimizer: torch.optim.Optimizer,
36
+ num_warmup_steps: int,
37
+ num_training_steps: int,
38
+ min_lr_ratio: float = 0.1,
39
+ ) -> LambdaLR:
40
+ """Create cosine LR schedule with linear warmup."""
41
+
42
+ def lr_lambda(current_step: int) -> float:
43
+ if current_step < num_warmup_steps:
44
+ return float(current_step) / float(max(1, num_warmup_steps))
45
+ progress = float(current_step - num_warmup_steps) / float(
46
+ max(1, num_training_steps - num_warmup_steps)
47
+ )
48
+ return max(min_lr_ratio, 0.5 * (1.0 + math.cos(math.pi * progress)))
49
+
50
+ return LambdaLR(optimizer, lr_lambda)
51
+
52
+
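To make the schedule's shape concrete, the sketch below prints the multiplier produced by `_get_cosine_schedule_with_warmup` at a few steps (arbitrary warmup/total step counts, chosen only for illustration):

```python
import torch

from src.training.trainer import _get_cosine_schedule_with_warmup

param = torch.nn.Parameter(torch.zeros(1))
opt = torch.optim.SGD([param], lr=1.0)
sched = _get_cosine_schedule_with_warmup(opt, num_warmup_steps=100, num_training_steps=1000)

for step in range(1000):
    opt.step()
    sched.step()
    if step in (0, 49, 99, 549, 999):
        print(step, round(sched.get_last_lr()[0], 3))
# 0 -> 0.01, 49 -> 0.5, 99 -> 1.0 (end of warmup), 549 -> 0.5 (mid-decay), 999 -> 0.1 (min_lr_ratio floor)
```

The learning rate ramps linearly to its peak over the warmup steps, then follows a cosine decay down to the `min_lr_ratio` floor of 0.1.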
53
  # --------------- Configuration ---------------
54
 
55
 
 
66
  experiment_name: str = "LexiMind"
67
  run_name: str | None = None
68
  gradient_accumulation_steps: int = 1
69
+ # Learning rate scheduler
70
+ scheduler_type: str = "cosine" # "cosine", "linear", or "constant"
71
+ warmup_steps: int = 0
72
+ num_training_steps: int = 0 # Set automatically if 0
73
+ # Early stopping
74
+ early_stopping_patience: int | None = None # None = disabled
75
+ early_stopping_min_delta: float = 0.001
76
+ # Gradient monitoring
77
+ log_grad_norm_frequency: int = 100 # Log gradient norms every N steps
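Assuming the remaining `TrainerConfig` fields keep their defaults (only the fields shown in this diff are certain), the new options would be set roughly like this:

```python
from src.training.trainer import TrainerConfig

# Hypothetical construction; other TrainerConfig fields are assumed to have defaults.
config = TrainerConfig(
    scheduler_type="cosine",          # cosine decay after linear warmup
    warmup_steps=500,
    early_stopping_patience=3,        # stop after 3 epochs without improvement
    early_stopping_min_delta=0.001,
    log_grad_norm_frequency=100,      # log gradient norms every 100 optimizer steps
)
```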
78
 
79
 
80
  # --------------- Trainer ---------------
 
94
  self.config = config
95
  self.device = device
96
  self.tokenizer = tokenizer
97
+ self.scheduler: LambdaLR | None = None # Set in fit()
98
+ self.global_step = 0 # Track global step for scheduler
99
 
100
  # Task losses
101
  self.emotion_loss = torch.nn.BCEWithLogitsLoss()
 
111
  self.nan_skip_count = 0
112
  self.max_nan_skips = 50
113
 
114
+ # Gradient monitoring
115
+ self.grad_monitor = GradientMonitor(model, log_frequency=config.log_grad_norm_frequency)
116
+
117
+ # Early stopping
118
+ self.early_stopping: EarlyStopping | None = None
119
+ if config.early_stopping_patience is not None:
120
+ self.early_stopping = EarlyStopping(
121
+ patience=config.early_stopping_patience,
122
+ min_delta=config.early_stopping_min_delta,
123
+ mode="min" # Lower loss is better
124
+ )
125
+
126
  # Track current step for debugging
127
  self._current_step = 0
128
 
 
134
  torch.backends.cuda.enable_flash_sdp(True)
135
  torch.backends.cuda.enable_mem_efficient_sdp(True)
136
 
137
+ def _setup_scheduler(self, train_loaders: Dict[str, DataLoader], start_epoch: int = 1) -> None:
138
+ """Initialize learning rate scheduler based on config."""
139
+ # Calculate steps per epoch once
140
+ max_batches = max(len(loader) for loader in train_loaders.values())
141
+ self.steps_per_epoch = max_batches // max(1, self.config.gradient_accumulation_steps)
142
+
143
+ if self.config.scheduler_type == "constant":
144
+ return # No scheduler needed
145
+
146
+ # Some tests pass a MagicMock optimizer without param_groups; skip scheduler gracefully
147
+ try:
148
+ _ = self.optimizer.param_groups # type: ignore[attr-defined]
149
+ except AttributeError:
150
+ self.scheduler = None
151
+ return
152
+
153
+ # Calculate total training steps
154
+ epochs_remaining = max(0, self.config.max_epochs - (start_epoch - 1))
155
+ num_training_steps = self.config.num_training_steps or (
156
+ self.steps_per_epoch * epochs_remaining
157
+ )
158
+
159
+ warmup_steps = self.config.warmup_steps
160
+ print(
161
+ f"✓ LR Scheduler: {self.config.scheduler_type} with {warmup_steps} warmup steps, {num_training_steps} total steps"
162
+ )
163
+
164
+ if self.config.scheduler_type == "cosine":
165
+ self.scheduler = _get_cosine_schedule_with_warmup(
166
+ self.optimizer, warmup_steps, num_training_steps
167
+ )
168
+ elif self.config.scheduler_type == "linear":
169
+
170
+ def linear_decay(step: int) -> float:
171
+ if step < warmup_steps:
172
+ return float(step) / float(max(1, warmup_steps))
173
+ return max(0.0, 1.0 - (step - warmup_steps) / max(1, num_training_steps - warmup_steps))
174
+
175
+ self.scheduler = LambdaLR(self.optimizer, linear_decay)
176
+
177
  # --------------- Training Loop ---------------
178
 
179
  def fit(
 
181
  train_loaders: Dict[str, DataLoader],
182
  val_loaders: Dict[str, DataLoader] | None = None,
183
  checkpoint_callback: Callable | None = None,
184
+ start_epoch: int = 1,
185
  ) -> Dict[str, Dict[str, float]]:
186
  """Train model across all tasks with progress tracking."""
187
  history: Dict[str, Dict[str, float]] = {}
188
  total_start = time.perf_counter()
189
 
190
+ # Setup LR scheduler
191
+ self._setup_scheduler(train_loaders, start_epoch=start_epoch)
192
+ # Initialize global_step to reflect completed epochs when resuming
193
+ if hasattr(self, "steps_per_epoch"):
194
+ self.global_step = max(0, (start_epoch - 1) * self.steps_per_epoch)
195
+
196
  with mlflow.start_run(run_name=self.config.run_name):
197
  self._log_config()
198
 
199
  # Epoch progress bar
200
  epoch_pbar = tqdm(
201
+ range(start_epoch, self.config.max_epochs + 1),
202
  desc="Training",
203
  unit="epoch",
204
  position=0,
 
223
  if "summarization" in val_loaders:
224
  self._validate_generation(val_loaders["summarization"], epoch)
225
 
226
+ # Early stopping check
227
+ if self.early_stopping is not None:
228
+ val_loss = val_metrics.get("total_loss", val_metrics.get("summarization_loss", float('inf')))
229
+ if self.early_stopping(val_loss):
230
+ tqdm.write(f"\n⚠ Early stopping triggered at epoch {epoch}")
231
+ tqdm.write(f" Best validation loss: {self.early_stopping.best_value:.4f}")
232
+ tqdm.write(f" Patience exhausted ({self.early_stopping.patience} epochs)")
233
+ break
234
+
235
  # Checkpoint
236
  if checkpoint_callback:
237
  checkpoint_callback(epoch, self.model, history)
 
359
  return averaged
360
 
361
  def _optimizer_step(self) -> None:
362
+ """Perform optimizer step with gradient clipping."""
363
+ # Log gradient norms before clipping
364
+ grad_stats = self.grad_monitor.log_gradients(self.global_step)
365
+ if grad_stats is not None:
366
+ tqdm.write(
367
+ f" [Step {self.global_step}] "
368
+ f"Grad norm: {grad_stats['grad_norm']:.4f}, "
369
+ f"Max: {grad_stats['grad_norm_max']:.4f}"
370
+ )
371
+ # Log to MLflow
372
+ for key, val in grad_stats.items():
373
+ mlflow.log_metric(f"grad_{key}", val, step=self.global_step)
374
+
375
  # Check gradients for NaN/Inf BEFORE clipping
376
  nan_grad = self.nan_detector.check_gradients(self._current_step)
377
  if nan_grad is not None:
 
395
 
396
  self.optimizer.zero_grad()
397
 
398
+ # Step the learning rate scheduler
399
+ if self.scheduler is not None:
400
+ self.scheduler.step()
401
+ self.global_step += 1
402
+ # Log learning rate
403
+ current_lr = self.scheduler.get_last_lr()[0]
404
+ mlflow.log_metric("learning_rate", current_lr, step=self.global_step)
405
+
406
  # Check parameters for NaN AFTER update
407
  nan_param = self.nan_detector.check_parameters(self._current_step)
408
  if nan_param is not None:
 
410
  f"NaN in parameter {nan_param} after optimizer step at step {self._current_step}!"
411
  )
412
 
413
+ def _clip_embedding_gradients(self, max_norm: float = 5.0) -> None:
414
+ """Clip embedding gradients only if they exceed threshold.
415
+
416
+ Less aggressive clipping to allow learning while preventing
417
+ overflow with inductor backend + gradient accumulation.
418
+ """
419
+ for name, param in self.model.named_parameters():
420
+ if param.grad is not None and "embedding" in name.lower():
421
+ grad = param.grad
422
+ # Only fix actual NaN/Inf, don't preemptively clip
423
+ if torch.isnan(grad).any() or torch.isinf(grad).any():
424
+ # Count NaNs for monitoring
425
+ nan_count = torch.isnan(grad).sum().item()
426
+ inf_count = torch.isinf(grad).sum().item()
427
+ if nan_count > 0 or inf_count > 0:
428
+ # Replace with zeros only where invalid
429
+ param.grad = torch.where(
430
+ torch.isnan(grad) | torch.isinf(grad), torch.zeros_like(grad), grad
431
+ )
432
+ else:
433
+ # Normal gradient - only clip if extremely large
434
+ grad_norm = param.grad.norm()
435
+ if grad_norm > max_norm:
436
+ param.grad = param.grad * (max_norm / (grad_norm + 1e-6))
437
+
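The rescaling branch above is the standard norm-clipping rule; a tiny standalone check of the arithmetic on a toy gradient:

```python
import torch

max_norm = 5.0
grad = torch.full((4,), 10.0)                    # toy gradient with norm 20
norm = grad.norm()
if norm > max_norm:
    grad = grad * (max_norm / (norm + 1e-6))     # rescale, preserving direction
print(grad.norm())                               # ~5.0
```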
438
  def _get_batch(
439
  self, iterators: Dict, loader: DataLoader, task: str
440
  ) -> Dict[str, torch.Tensor] | None:
 
489
  inputs["src_mask"] = batch["src_mask"]
490
 
491
  logits = self.model.forward("summarization", inputs)
492
+
493
+ # Compute loss with proper masking
494
  loss = F.cross_entropy(
495
  logits.view(-1, logits.size(-1)),
496
  batch["labels"].view(-1),
 
498
  label_smoothing=self.config.label_smoothing,
499
  )
500
 
501
+ # Sanity check logits
502
+ if self.global_step % 100 == 0:
503
+ with torch.no_grad():
504
+ tqdm.write(f" [Step {self.global_step}] Summarization logits: mean={logits.mean().item():.2f}, std={logits.std().item():.2f}, loss={loss.item():.4f}")
505
+
506
  # Quick ROUGE estimate
507
  preds = self.tokenizer.decode_batch(logits.argmax(dim=-1).tolist())
508
  refs = self._decode_labels(batch["labels"])
tests/test_inference/test_pipeline.py CHANGED
@@ -2,11 +2,18 @@
2
 
3
  from __future__ import annotations
4
 
 
 
5
  from pathlib import Path
6
  from typing import cast
7
 
 
8
  import torch
9
 
 
 
 
 
10
  from src.data.tokenization import Tokenizer, TokenizerConfig
11
  from src.inference.pipeline import (
12
  EmotionPrediction,
@@ -16,6 +23,21 @@ from src.inference.pipeline import (
16
  )
17
  from src.utils.labels import LabelMetadata
18
 
 
19
 
20
  def _local_tokenizer_config() -> TokenizerConfig:
21
  root = Path(__file__).resolve().parents[2]
@@ -48,7 +70,7 @@ class DummyDecoder(torch.nn.Module):
48
  device: torch.device,
49
  **kwargs: object,
50
  ) -> torch.Tensor:
51
- seq = self.sequence.to(device)
52
  if seq.numel() > max_len:
53
  seq = seq[:max_len]
54
  batch = memory.size(0)
@@ -70,9 +92,9 @@ class DummyModel(torch.nn.Module):
70
  ) -> torch.Tensor: # pragma: no cover - simple dispatch
71
  batch = inputs["input_ids"].size(0)
72
  if task == "emotion":
73
- return self._emotion_logits.unsqueeze(0).repeat(batch, 1)
74
  if task == "topic":
75
- return self._topic_logits.unsqueeze(0).repeat(batch, 1)
76
  raise KeyError(task)
77
 
78
 
@@ -85,7 +107,7 @@ def _build_pipeline() -> InferencePipeline:
85
  tokenizer=tokenizer,
86
  emotion_labels=metadata.emotion,
87
  topic_labels=metadata.topic,
88
- config=InferenceConfig(summary_max_length=12),
89
  )
90
 
91
 
 
2
 
3
  from __future__ import annotations
4
 
5
+ import sys
6
+ import warnings
7
  from pathlib import Path
8
  from typing import cast
9
 
10
+ import pytest
11
  import torch
12
 
13
+ PROJECT_ROOT = Path(__file__).resolve().parents[2]
14
+ if str(PROJECT_ROOT) not in sys.path:
15
+ sys.path.insert(0, str(PROJECT_ROOT))
16
+
17
  from src.data.tokenization import Tokenizer, TokenizerConfig
18
  from src.inference.pipeline import (
19
  EmotionPrediction,
 
23
  )
24
  from src.utils.labels import LabelMetadata
25
 
26
+ # Silence noisy DeprecationWarnings from underlying tokenizer bindings used in tests
27
+ warnings.filterwarnings("ignore", category=DeprecationWarning)
28
+ warnings.filterwarnings(
29
+ "ignore",
30
+ message=r"builtin type SwigPy.*has no __module__ attribute",
31
+ category=DeprecationWarning,
32
+ )
33
+ warnings.filterwarnings(
34
+ "ignore",
35
+ category=DeprecationWarning,
36
+ module=r"importlib\\._bootstrap",
37
+ )
38
+
39
+ pytestmark = pytest.mark.filterwarnings("ignore::DeprecationWarning")
40
+
41
 
42
  def _local_tokenizer_config() -> TokenizerConfig:
43
  root = Path(__file__).resolve().parents[2]
 
70
  device: torch.device,
71
  **kwargs: object,
72
  ) -> torch.Tensor:
73
+ seq = cast(torch.Tensor, self.sequence).to(device)
74
  if seq.numel() > max_len:
75
  seq = seq[:max_len]
76
  batch = memory.size(0)
 
92
  ) -> torch.Tensor: # pragma: no cover - simple dispatch
93
  batch = inputs["input_ids"].size(0)
94
  if task == "emotion":
95
+ return cast(torch.Tensor, self._emotion_logits).unsqueeze(0).repeat(batch, 1)
96
  if task == "topic":
97
+ return cast(torch.Tensor, self._topic_logits).unsqueeze(0).repeat(batch, 1)
98
  raise KeyError(task)
99
 
100
 
 
107
  tokenizer=tokenizer,
108
  emotion_labels=metadata.emotion,
109
  topic_labels=metadata.topic,
110
+ config=InferenceConfig(summary_max_length=12, summary_formatting=False),
111
  )
112
 
113
 
tests/test_models/test_visualizations.py CHANGED
@@ -34,7 +34,7 @@ def test_attention_visualization():
34
  V = torch.eye(seq_len, d_k).unsqueeze(0) # Identity-like
35
 
36
  # Compute attention
37
- output, weights = attention(Q, K, V, return_attn_weights=True)
38
 
39
  # Plot attention weights
40
  plt.figure(figsize=(8, 6))
 
34
  V = torch.eye(seq_len, d_k).unsqueeze(0) # Identity-like
35
 
36
  # Compute attention
37
+ _output, weights = attention(Q, K, V, return_attn_weights=True)
38
 
39
  # Plot attention weights
40
  plt.figure(figsize=(8, 6))