williyam committed on
Commit
7e07285
·
1 Parent(s): 5c82f3d

feat: GRPO fine-tuning pipeline + trained model for aerospace RAG

- Add Jupyter notebook (agentic-rag-for-aerospace-research.ipynb) with
full GRPO training pipeline: dataset collection, baseline eval, training,
post-training eval, plots, and HF Hub push
- Add training/ package (config, dataset, reward, evaluate modules)
- Add train.py standalone training script
- Add training plots: training_curves.png, baseline_vs_trained.png,
score_distribution.png
- GRPO-trained Qwen2.5-0.5B with LoRA (r=16, alpha=32)
- Baseline: 0.558 -> Trained: 0.586 (+0.028 improvement)
- Model pushed to williyam/agentic-rag-aerospace-grpo on HF Hub
- Update README with training section, results, and plot embeds

.gitignore CHANGED
@@ -25,3 +25,7 @@ data/uploads/
  .DS_Store
  Thumbs.db
  node_modules/
+
+ # Training checkpoints (large files)
+ checkpoints/
+ .venv-1/
README.md CHANGED
@@ -291,6 +291,80 @@ Update `server/app.py` to use your domain config instead of `AerospaceDomainConf
 
  ---
 
+ ## GRPO Fine-Tuning (Reinforcement Learning)
+
+ We fine-tune **Qwen2.5-0.5B-Instruct** using **Group Relative Policy Optimization (GRPO)** from TRL,
+ with LoRA adapters and the **real domain graders** as the reward signal — no proxy rewards.
+
+ ### Training Results
+
+ | Metric | Baseline | GRPO-Trained | Improvement |
+ |--------|----------|-------------|-------------|
+ | **Mean Score** | 0.5580 | 0.5860 | **+0.0280** |
+ | Propulsion Comparison | 0.508 | 0.562 | +0.053 |
+ | Debris Mitigation | 0.633 | 0.689 | +0.056 |
+ | Hypersonic Vehicle | 0.482 | 0.521 | +0.039 |
+ | Mars EDL | 0.574 | 0.568 | -0.006 |
+ | Life Support | 0.592 | 0.590 | -0.002 |
+
+ ### Training Curves
+
+ ![Training Curves](plots/training_curves.png)
+
+ ### Baseline vs. GRPO-Trained
+
+ ![Baseline vs Trained](plots/baseline_vs_trained.png)
+
+ ### Score Distribution
+
+ ![Score Distribution](plots/score_distribution.png)
+
+ ### Run Training (Notebook)
+
+ The primary training interface is the Jupyter notebook:
+
+ ```bash
+ jupyter notebook agentic-rag-for-aerospace-research.ipynb
+ ```
+
+ ### Run Training (Script)
+
+ For headless/CI environments:
+
+ ```bash
+ python train.py
+ ```
+
+ ### Configuration
+
+ | Parameter | Value |
+ |-----------|-------|
+ | Base Model | `Qwen/Qwen2.5-0.5B-Instruct` |
+ | Method | GRPO (Group Relative Policy Optimization) |
+ | LoRA | r=16, α=32, targets=q/k/v/o_proj |
+ | Optimizer | AdamW (torch) |
+ | Learning Rate | 5e-6 |
+ | Epochs | 2 |
+ | Group Size (G) | 4 |
+ | Max Completion | 512 tokens |
+ | Hardware | Apple M1 Pro (MPS) |
+ | Training Time | ~116 min |
+
+ ### Fine-Tuned Model
+
+ The GRPO-trained model is available on Hugging Face:
+ **[williyam/agentic-rag-aerospace-grpo](https://huggingface.co/williyam/agentic-rag-aerospace-grpo)**
+
+ ```python
+ from peft import AutoPeftModelForCausalLM
+ from transformers import AutoTokenizer
+
+ model = AutoPeftModelForCausalLM.from_pretrained("williyam/agentic-rag-aerospace-grpo")
+ tokenizer = AutoTokenizer.from_pretrained("williyam/agentic-rag-aerospace-grpo")
+ ```
+
+ ---
+
  ## Testing
 
  ```bash
@@ -344,9 +418,13 @@ agentic-rag-gym/
  ├── server/ # FastAPI + Gradio server
  ├── domains/aerospace/ # Aerospace research domain
  ├── domains/legal_research/ # Legal research domain (stub)
+ ├── training/ # GRPO training package
  ├── tests/ # Unit & integration tests (102+)
  ├── .github/workflows/ # CI pipeline
  ├── documents/ # Architecture & design docs
+ ├── plots/ # Training curves & evaluation plots
+ ├── agentic-rag-for-aerospace-research.ipynb # GRPO training notebook
+ ├── train.py # Standalone training script
  ├── inference.py # Baseline inference script
  ├── openenv.yaml # OpenEnv specification
  ├── Dockerfile # Container definition
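The LoRA row of the README's Configuration table (r=16, α=32, targets q/k/v/o_proj) implies a fairly small number of trainable weights. The sketch below estimates that count. The projection shapes and layer count are assumptions about Qwen2.5-0.5B (hidden size 896, key/value width 128, 24 layers) and are not stated anywhere in this commit; the helper name `lora_param_count` is illustrative only.

```python
from typing import Dict, Tuple

# ASSUMPTION: (in_features, out_features) of each targeted projection in
# Qwen2.5-0.5B. These values are not taken from the commit.
ASSUMED_SHAPES: Dict[str, Tuple[int, int]] = {
    "q_proj": (896, 896),
    "k_proj": (896, 128),
    "v_proj": (896, 128),
    "o_proj": (896, 896),
}
ASSUMED_NUM_LAYERS = 24  # ASSUMPTION: number of transformer blocks


def lora_param_count(shapes: Dict[str, Tuple[int, int]], rank: int, num_layers: int) -> int:
    """Count LoRA weights: each adapted matrix adds A (rank x in) and B (out x rank)."""
    per_layer = sum(rank * (n_in + n_out) for n_in, n_out in shapes.values())
    return per_layer * num_layers


def lora_scaling(alpha: int, rank: int) -> float:
    """Standard LoRA scaling factor applied to the low-rank update."""
    return alpha / rank


if __name__ == "__main__":
    print("trainable LoRA params:", lora_param_count(ASSUMED_SHAPES, 16, ASSUMED_NUM_LAYERS))
    print("scaling (alpha / r):", lora_scaling(32, 16))
```

The estimate can be compared against what `print_trainable_parameters()` reports in `train.py`; a mismatch would suggest the assumed shapes are wrong for the model revision in use.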
agentic-rag-for-aerospace-research.ipynb ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:e8b90675048b0211601e7e8283ad1a2c5e41ff8f6fe5e0f5620b1be16d30993e
+ size 10531022
plots/baseline_vs_trained.png ADDED

Git LFS Details

  • SHA256: 62a092c9597ce70ea59b6dd87462ba53552c0f31504f3be615fedf03a0f219e3
  • Pointer size: 130 Bytes
  • Size of remote file: 85.1 kB
plots/eval_results.json ADDED
@@ -0,0 +1,84 @@
+ {
+   "model": "Qwen/Qwen2.5-0.5B-Instruct",
+   "finetuned": "williyam/agentic-rag-aerospace-grpo",
+   "training_time_min": 116.41443806489309,
+   "baseline": [
+     {
+       "task_id": "aero_hard_hypersonic_vehicle",
+       "task_name": "Reusable Hypersonic Space Access Vehicle",
+       "difficulty": "hard",
+       "answer": "### Conceptual Design Overview\n\n#### Introduction\nReusability of a hypersonic space access vehicle (SAV) with a combined-cycle propulsion (CCP) system is crucial for reducing operational costs, increasing payload capacity, and enhancing safety. The proposed SAV utilizes a combination of scramjet, turbine, and rocket engines, designed to achieve high-speed, long-duration hypersonic travel. The propulsion mode transitions between scramjets and rockets, allowing for efficient use of both technologi",
+       "score": 0.48219047619047617
+     },
+     {
+       "task_id": "aero_easy_propulsion_comparison",
+       "task_name": "Compare Propulsion Technologies",
+       "difficulty": "easy",
+       "answer": "### Comparison of Propulsion Technologies for a Mars Transit Mission\n\n#### Specific Impulse and Thrust Levels\n\n**Ion Propulsion (IP):**\n- **Specific Impulse:** 4,190 seconds (NASA)\n- **Thrust Level:** Very Low (0.5 N for NEXT)\n\n**Nuclear Thermal Propulsion (NTP):**\n- **Specific Impulse:** 850-1,000 seconds (NASA)\n- **Thrust Level:** Extremely High (approximately 1,000 seconds)\n\n**Comparison:**\n- **Specific Impulse:** IP > NTP\n- **Thrust Level:** IP < NTP\n\n**Impact on Mission Duration:**\n- **IP:*",
+       "score": 0.5082857142857143
+     },
+     {
+       "task_id": "aero_medium_mars_edl",
+       "task_name": "Mars EDL Architecture Design",
+       "difficulty": "medium",
+       "answer": "### Overview of the EDL Architecture Design\n\n#### 1. **Entry Phase**\n - **Aerocapture**: \n - **Objective**: To capture the spacecraft into the Martian atmosphere.\n - **Data**: The spacecraft must be designed to achieve a high velocity relative to the Martian atmosphere to avoid atmospheric drag. This typically involves using a combination of parachutes, airbags, and other aerodynamic devices.\n - **Quantitative Evidence**: The aerocapture trajectory should be optimized to minimize e",
+       "score": 0.5742857142857143
+     },
+     {
+       "task_id": "aero_medium_life_support",
+       "task_name": "Deep Space Life Support Design",
+       "difficulty": "medium",
+       "answer": "### Deep Space Life Support Design for a 6-Person Crew on a 2.5-Year Mars Mission\n\n#### Overview\nThis design integrates physicochemical and bioregenerative systems to maximize closure ratios and address various aspects of life support on a 2.5-year Mars mission. The goal is to ensure the survival and health of the crew while minimizing resource consumption and environmental impact.\n\n#### System Components\n1. **Physicochemical Life Support (PLS)**\n - **Thermophilic Bacteria for Waste Liquefacti",
+       "score": 0.5917142857142857
+     },
+     {
+       "task_id": "aero_easy_debris_mitigation",
+       "task_name": "Space Debris Mitigation Strategies",
+       "difficulty": "easy",
+       "answer": "### Space Debris Mitigation Strategies\n\n#### Current Challenges and Active Debris Removal Technologies\n\n**Current Challenges:**\n1. **Kessler Syndrome Risk:** Over 36,500 objects >10 cm in LEO reach critical densities, leading to significant collisions that threaten orbit stability.\n2. **Tracking Capabilities:** Limited ground-based and orbital tracking capabilities, especially for larger objects like satellites and asteroids.\n3. **Proposed Removal Methods:**\n - **Electrodynamic Tethers:** Gene",
+       "score": 0.6334285714285715
+     }
+   ],
+   "trained": [
+     {
+       "task_id": "aero_hard_hypersonic_vehicle",
+       "task_name": "Reusable Hypersonic Space Access Vehicle",
+       "difficulty": "hard",
+       "answer": "### Conceptual Design Overview\n\n#### Introduction\nReusability of a hypersonic space access vehicle (SAV) with a combined-cycle propulsion (CCP) system is crucial for reducing operational costs, increasing payload capacity, and enhancing safety. The proposed SAV utilizes a combination of scramjet, turbine, and rocket propulsion modes, aiming to achieve high-speed, long-duration flights within the hypersonic regime. This approach addresses key aerothermodynamic challenges such as shock-boundary la",
+       "score": 0.5214285714285715
+     },
+     {
+       "task_id": "aero_easy_propulsion_comparison",
+       "task_name": "Compare Propulsion Technologies",
+       "difficulty": "easy",
+       "answer": "### Comparison of Propulsion Technologies for a Mars Transit Mission\n\n#### Specific Impulse\n**Ion Propulsion (IP):**\n- **Specific Impulse:** 4,190 seconds\n- **Thrust Level:** Extremely Low (0.5 N)\n- **Impact on Mission Duration:** Minimal impact on mission duration as IP is designed for short-duration missions.\n- **Technology Readiness:** High due to its ability to achieve high specific impulses without significant technological hurdles.\n\n**Nuclear Thermal Propulsion (NTP):**\n- **Specific Impuls",
+       "score": 0.5616190476190477
+     },
+     {
+       "task_id": "aero_medium_mars_edl",
+       "task_name": "Mars EDL Architecture Design",
+       "difficulty": "medium",
+       "answer": "### Overview of the EDL Architecture Design\n\n#### Introduction\nThe EDL architecture design addresses the challenges posed by the Mars atmospheric environment, including the high peak heat flux and the need for effective aerodynamic deceleration of heavy payloads. This design integrates multiple phases into a single, cohesive approach to ensure efficient and safe mission execution. The proposed architecture includes:\n\n1. **Aerocapture**: The first phase involves capturing the payload using aerobr",
+       "score": 0.5678095238095239
+     },
+     {
+       "task_id": "aero_medium_life_support",
+       "task_name": "Deep Space Life Support Design",
+       "difficulty": "medium",
+       "answer": "### Overview of the Integrated Life Support System Design\n\n#### Introduction\nDeep space exploration presents unique challenges due to the extreme conditions encountered, including low gravity, radiation exposure, and the presence of hazardous materials like water and carbon dioxide. To ensure long-term survival and minimize environmental impact, a sophisticated life support system must be designed. This design integrates various biological processes and energy sources to maximize closure ratios ",
+       "score": 0.5897142857142857
+     },
+     {
+       "task_id": "aero_easy_debris_mitigation",
+       "task_name": "Space Debris Mitigation Strategies",
+       "difficulty": "easy",
+       "answer": "**Task: Space Debris Mitigation Strategies**\n\n**Current Challenges and Evaluation of Active Debris Removal Technologies**\n\nActive debris removal (ADR) technologies have been developed to mitigate the effects of space debris on orbiting satellites and other spacecraft. These technologies aim to prevent collisions that could lead to catastrophic damage to spacecraft and disrupt satellite communications. However, the effectiveness of these technologies varies widely depending on several factors suc",
+       "score": 0.6894285714285716
+     }
+   ],
+   "summary": {
+     "baseline_mean": 0.5579809523809524,
+     "trained_mean": 0.5860000000000001,
+     "improvement": 0.02801904761904772
+   }
+ }
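The `summary` block in `plots/eval_results.json` can be cross-checked against the per-task scores. The following is a minimal sketch, assuming the `baseline` / `trained` / `summary` layout shown in the file; the helper names `summarize` and `per_task_delta` are illustrative and do not exist in the repository. The scores are copied from the JSON, with the answer text omitted.

```python
from statistics import fmean
from typing import Dict, List

# Scores copied from plots/eval_results.json (answers omitted for brevity).
EXAMPLE: Dict[str, List[dict]] = {
    "baseline": [
        {"task_id": "aero_hard_hypersonic_vehicle", "score": 0.48219047619047617},
        {"task_id": "aero_easy_propulsion_comparison", "score": 0.5082857142857143},
        {"task_id": "aero_medium_mars_edl", "score": 0.5742857142857143},
        {"task_id": "aero_medium_life_support", "score": 0.5917142857142857},
        {"task_id": "aero_easy_debris_mitigation", "score": 0.6334285714285715},
    ],
    "trained": [
        {"task_id": "aero_hard_hypersonic_vehicle", "score": 0.5214285714285715},
        {"task_id": "aero_easy_propulsion_comparison", "score": 0.5616190476190477},
        {"task_id": "aero_medium_mars_edl", "score": 0.5678095238095239},
        {"task_id": "aero_medium_life_support", "score": 0.5897142857142857},
        {"task_id": "aero_easy_debris_mitigation", "score": 0.6894285714285716},
    ],
}


def summarize(results: Dict[str, List[dict]]) -> Dict[str, float]:
    """Recompute the summary means from the per-task score lists."""
    baseline_mean = fmean(r["score"] for r in results["baseline"])
    trained_mean = fmean(r["score"] for r in results["trained"])
    return {
        "baseline_mean": baseline_mean,
        "trained_mean": trained_mean,
        "improvement": trained_mean - baseline_mean,
    }


def per_task_delta(results: Dict[str, List[dict]]) -> Dict[str, float]:
    """Map each task_id to (trained score - baseline score)."""
    base = {r["task_id"]: r["score"] for r in results["baseline"]}
    return {r["task_id"]: r["score"] - base[r["task_id"]] for r in results["trained"]}


if __name__ == "__main__":
    print(summarize(EXAMPLE))
    for task_id, delta in sorted(per_task_delta(EXAMPLE).items()):
        print(f"{task_id}: {delta:+.3f}")
```

Each task contributes a single graded answer, so the per-task deltas are the same quantities the README table reports, rounded to three decimals.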
plots/score_distribution.png ADDED

Git LFS Details

  • SHA256: d692edf8c4b2b23fbb12efc4ac98580e4f0b1cd526c22472a59625247e613778
  • Pointer size: 130 Bytes
  • Size of remote file: 36.9 kB
plots/training_curves.png ADDED

Git LFS Details

  • SHA256: 7d2645fe02ed361c359ca922be193ba33723e89d23162b260d1452eed27cc361
  • Pointer size: 131 Bytes
  • Size of remote file: 124 kB
train.py ADDED
@@ -0,0 +1,227 @@
+ #!/usr/bin/env python3
+ """
+ train.py — GRPO Fine-Tuning for Agentic RAG Gym (Aerospace Domain)
+ ====================================================================
+
+ End-to-end training script that:
+   1. Connects to the live gym environment to collect prompts
+   2. Loads Qwen2.5-0.5B-Instruct with LoRA
+   3. Trains with GRPO (Group Relative Policy Optimization) via TRL
+   4. Evaluates baseline vs. trained model with domain graders
+   5. Generates publication-quality plots
+   6. Pushes the fine-tuned model to Hugging Face Hub
+
+ Usage:
+     # Start the environment first:
+     python main.py &
+
+     # Then run training:
+     python train.py
+
+ Environment variables (loaded from .env):
+     HF_TOKEN        Hugging Face token (for model push)
+     HF_USERNAME     Hugging Face username (default: williyam)
+     ENV_URL         Gym environment URL (default: http://localhost:7860)
+     BASE_MODEL_ID   Base model (default: Qwen/Qwen2.5-0.5B-Instruct)
+ """
+
+ from __future__ import annotations
+
+ import sys
+ import time
+
+ import numpy as np
+ import torch
+ from dotenv import load_dotenv
+ from peft import LoraConfig, TaskType
+
+ load_dotenv()
+
+ from training.config import (
+     BASE_MODEL_ID,
+     CHECKPOINTS_DIR,
+     FINETUNED_MODEL_ID,
+     HF_TOKEN,
+     PLOTS_DIR,
+ )
+ from training.dataset import SYSTEM_PROMPT, build_grpo_dataset
+ from training.evaluate import (
+     evaluate_model_on_tasks,
+     plot_baseline_vs_trained,
+     plot_reward_distribution,
+     plot_training_curves,
+     save_eval_results,
+ )
+ from training.reward import grade_answer_sync
+
+ # ── Device ─────────────────────────────────────────────────────────────
+ if torch.backends.mps.is_available():
+     DEVICE = "mps"
+ elif torch.cuda.is_available():
+     DEVICE = "cuda"
+ else:
+     DEVICE = "cpu"
+ print(f"Device: {DEVICE} | PyTorch {torch.__version__}")
+
+ # ── Build dataset ──────────────────────────────────────────────────────
+ print("\n[1/6] Collecting prompts from environment...")
+ dataset = build_grpo_dataset(num_per_task=8, seed=42)
+
+ task_ids_map = {}
+ for row in dataset:
+     task_ids_map[row["task_id"]] = row["task_name"]
+
+ # Format for TRL: prompt column must be list of message dicts
+ def format_for_trl(example):
+     return {
+         "prompt": [
+             {"role": "system", "content": SYSTEM_PROMPT},
+             {"role": "user", "content": example["prompt"]},
+         ],
+     }
+
+ train_dataset = dataset.map(format_for_trl, remove_columns=["task_name", "difficulty"])
+
+ # ── Load model ─────────────────────────────────────────────────────────
+ print(f"\n[2/6] Loading model: {BASE_MODEL_ID}")
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+
+ tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_ID, trust_remote_code=True, padding_side="left")
+ if tokenizer.pad_token is None:
+     tokenizer.pad_token = tokenizer.eos_token
+
+ model = AutoModelForCausalLM.from_pretrained(
+     BASE_MODEL_ID, torch_dtype=torch.float32, trust_remote_code=True,
+ )
+
+ peft_config = LoraConfig(
+     task_type=TaskType.CAUSAL_LM,
+     r=16, lora_alpha=32, lora_dropout=0.05,
+     target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
+     bias="none",
+ )
+
+ # ── Baseline evaluation ───────────────────────────────────────────────
+ print("\n[3/6] Evaluating baseline (before training)...")
+ from peft import get_peft_model
+ eval_model = get_peft_model(model.to(DEVICE), peft_config)
+ eval_model.print_trainable_parameters()
+
+ eval_prompts = []
+ seen = set()
+ for row in dataset:
+     if row["task_id"] not in seen:
+         eval_prompts.append(row)
+         seen.add(row["task_id"])
+
+ baseline_results = evaluate_model_on_tasks(
+     eval_model, tokenizer, eval_prompts, max_new_tokens=512, temperature=0.1,
+ )
+ baseline_mean = np.mean([r["score"] for r in baseline_results])
+ print(f" Baseline mean score: {baseline_mean:.4f}")
+ del eval_model
+ model = model.to("cpu")
+
+ # ── Reward function ────────────────────────────────────────────────────
+ def reward_fn(completions, **kwargs):
+     """Score completions using domain graders."""
+     rewards = []
+     prompts = kwargs.get("prompts", [])
+     for i, completion in enumerate(completions):
+         text = completion[0]["content"] if isinstance(completion, list) else str(completion)
+         text = text.strip()
+         if len(text) < 10:
+             rewards.append(0.01)
+             continue
+         task_id = None
+         if i < len(prompts):
+             p = prompts[i]
+             if isinstance(p, list):
+                 p = " ".join(m.get("content", "") for m in p)
+             for tid, name in task_ids_map.items():
+                 if name in str(p):
+                     task_id = tid
+                     break
+         if task_id is None:
+             task_id = list(task_ids_map.keys())[0]
+         try:
+             rewards.append(float(grade_answer_sync(task_id, text)))
+         except Exception:
+             rewards.append(0.01)
+     return rewards
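`reward_fn` recovers the task for each completion by searching the prompt text for a known task name, and it falls back to the first task in `task_ids_map` when nothing matches. The lookup step can be sketched on its own as follows. `resolve_task_id` is a hypothetical helper, not a function in the repository, and it returns `None` on a miss so that a failed match is visible rather than silently graded against another task.

```python
from typing import Dict, List, Optional, Union

Message = Dict[str, str]


def resolve_task_id(
    prompt: Union[str, List[Message]],
    task_names: Dict[str, str],
) -> Optional[str]:
    """Return the task_id whose task name occurs in the prompt, else None."""
    if isinstance(prompt, list):
        # Conversational prompts arrive as a list of {"role", "content"} dicts.
        prompt = " ".join(m.get("content", "") for m in prompt)
    for task_id, name in task_names.items():
        if name in str(prompt):
            return task_id
    return None


# Task names taken from eval_results.json; the prompt text is a made-up example.
TASK_NAMES = {
    "aero_medium_mars_edl": "Mars EDL Architecture Design",
    "aero_easy_debris_mitigation": "Space Debris Mitigation Strategies",
}
CHAT_PROMPT = [
    {"role": "system", "content": "You are an expert aerospace research analyst."},
    {"role": "user", "content": "## Task: Mars EDL Architecture Design\n\nDesign an EDL system."},
]

if __name__ == "__main__":
    print(resolve_task_id(CHAT_PROMPT, TASK_NAMES))
    print(resolve_task_id("## Task: Something else", TASK_NAMES))
```

Because `format_for_trl` leaves the `task_id` column in the dataset and `remove_unused_columns=False` is set, a TRL version that forwards extra dataset columns to reward functions as keyword arguments (as its documentation describes) would also make `kwargs["task_id"]` available, which avoids the string match entirely. That is an assumption to verify against the installed TRL release.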
+
+ # ── GRPO Training ──────────────────────────────────────────────────────
+ print("\n[4/6] Starting GRPO training...")
+ from trl import GRPOConfig, GRPOTrainer
+
+ training_args = GRPOConfig(
+     output_dir=str(CHECKPOINTS_DIR / "grpo"),
+     num_train_epochs=2,
+     per_device_train_batch_size=1,
+     gradient_accumulation_steps=4,
+     learning_rate=5e-6,
+     warmup_ratio=0.1,
+     max_grad_norm=1.0,
+     logging_steps=1,
+     save_steps=50,
+     save_total_limit=2,
+     bf16=False,
+     fp16=False,
+     seed=42,
+     remove_unused_columns=False,
+     num_generations=4,
+     max_completion_length=512,
+     temperature=0.7,
+     use_vllm=False,
+     report_to="none",
+     optim="adamw_torch",
+     gradient_checkpointing=True,
+     log_completions=True,
+     num_completions_to_print=1,
+ )
+
+ trainer = GRPOTrainer(
+     model=BASE_MODEL_ID,
+     reward_funcs=reward_fn,
+     args=training_args,
+     train_dataset=train_dataset,
+     peft_config=peft_config,
+     processing_class=tokenizer,
+ )
+
+ t0 = time.time()
+ trainer.train()
+ elapsed = time.time() - t0
+ print(f"\nTraining completed in {elapsed/60:.1f} min")
+
+ # ── Post-training evaluation ──────────────────────────────────────────
+ print("\n[5/6] Evaluating trained model...")
+ trained_results = evaluate_model_on_tasks(
+     trainer.model, tokenizer, eval_prompts, max_new_tokens=512, temperature=0.1,
+ )
+ trained_mean = np.mean([r["score"] for r in trained_results])
+ print(f" Trained mean score: {trained_mean:.4f}")
+ print(f" Improvement: {trained_mean - baseline_mean:+.4f}")
+
+ # ── Plots + save ──────────────────────────────────────────────────────
+ print("\n[6/6] Generating plots and saving...")
+ log_history = trainer.state.log_history if hasattr(trainer, "state") else []
+ plot_training_curves(log_history)
+ plot_baseline_vs_trained(baseline_results, trained_results)
+ plot_reward_distribution(baseline_results, trained_results)
+ save_eval_results(baseline_results, trained_results)
+
+ # Push
+ if HF_TOKEN:
+     print(f"\nPushing to HF Hub: {FINETUNED_MODEL_ID}")
+     trainer.model.push_to_hub(FINETUNED_MODEL_ID, token=HF_TOKEN, private=False)
+     tokenizer.push_to_hub(FINETUNED_MODEL_ID, token=HF_TOKEN, private=False)
+     print("Model pushed successfully")
+
+ print(f"\n{'='*50}")
+ print(f" Base model: {BASE_MODEL_ID}")
+ print(f" Training time: {elapsed/60:.1f} min")
+ print(f" Baseline score: {baseline_mean:.4f}")
+ print(f" Trained score: {trained_mean:.4f}")
+ print(f" Delta: {trained_mean - baseline_mean:+.4f}")
+ print(f"{'='*50}")
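With `num_generations=4`, GRPO samples four completions per prompt and scores each one relative to the other three, rather than against a learned value function. The core idea can be sketched in a few lines. This illustrates the group-relative normalisation only, under the assumption of a population standard deviation and a small epsilon. It is not TRL's implementation, whose estimator and scaling options may differ between versions, and `group_relative_advantages` is a made-up name.

```python
from statistics import fmean, pstdev
from typing import List


def group_relative_advantages(rewards: List[float], eps: float = 1e-4) -> List[float]:
    """Normalise one prompt's G rewards against the group's own mean and spread."""
    mean = fmean(rewards)
    std = pstdev(rewards)  # ASSUMPTION: population std; libraries may use another estimator
    return [(r - mean) / (std + eps) for r in rewards]


if __name__ == "__main__":
    # Hypothetical grader scores for the four completions of a single prompt.
    group = [0.50, 0.55, 0.60, 0.55]
    for reward, adv in zip(group, group_relative_advantages(group)):
        print(f"reward={reward:.2f} advantage={adv:+.3f}")
```

One consequence is that a group whose four completions all receive the same grader score yields zero advantage everywhere, so that prompt contributes no learning signal for the step.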
training/__init__.py ADDED
@@ -0,0 +1 @@
+ """Training utilities for Agentic RAG Gym GRPO fine-tuning."""
training/config.py ADDED
@@ -0,0 +1,40 @@
+ """
+ Configuration for GRPO training of Agentic RAG Gym.
+ All secrets and tunables are loaded from environment variables / .env file.
+ """
+
+ from __future__ import annotations
+
+ import os
+ from pathlib import Path
+
+ from dotenv import load_dotenv
+
+ load_dotenv()
+
+ # ---------------------------------------------------------------------------
+ # Paths
+ # ---------------------------------------------------------------------------
+ PROJECT_ROOT = Path(__file__).resolve().parent.parent
+ PLOTS_DIR = PROJECT_ROOT / "plots"
+ PLOTS_DIR.mkdir(exist_ok=True)
+ CHECKPOINTS_DIR = PROJECT_ROOT / "checkpoints"
+ CHECKPOINTS_DIR.mkdir(exist_ok=True)
+
+ # ---------------------------------------------------------------------------
+ # Environment / secrets
+ # ---------------------------------------------------------------------------
+ HF_TOKEN: str = os.getenv("HF_TOKEN", "")
+ HF_USERNAME: str = os.getenv("HF_USERNAME", "williyam")
+
+ # Environment server (our FastAPI gym)
+ ENV_URL: str = os.getenv("ENV_URL", "http://localhost:7860")
+
+ # ---------------------------------------------------------------------------
+ # Model configuration
+ # ---------------------------------------------------------------------------
+ BASE_MODEL_ID: str = os.getenv("BASE_MODEL_ID", "Qwen/Qwen2.5-0.5B-Instruct")
+ FINETUNED_MODEL_ID: str = os.getenv(
+     "FINETUNED_MODEL_ID",
+     f"{HF_USERNAME}/agentic-rag-aerospace-grpo",
+ )
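The default for `FINETUNED_MODEL_ID` is derived from `HF_USERNAME`, so changing only the username also changes the push target, while an explicit `FINETUNED_MODEL_ID` takes precedence. The sketch below mirrors that resolution order with a plain mapping in place of the process environment. `resolve_model_ids` is an illustrative helper and is not part of `training/config.py`.

```python
from typing import Dict, Mapping


def resolve_model_ids(env: Mapping[str, str]) -> Dict[str, str]:
    """Mirror the lookup order used in training/config.py for the model IDs."""
    username = env.get("HF_USERNAME", "williyam")
    return {
        "base": env.get("BASE_MODEL_ID", "Qwen/Qwen2.5-0.5B-Instruct"),
        # The explicit variable wins; otherwise the ID is built from the username.
        "finetuned": env.get("FINETUNED_MODEL_ID", f"{username}/agentic-rag-aerospace-grpo"),
    }


if __name__ == "__main__":
    print(resolve_model_ids({}))
    print(resolve_model_ids({"HF_USERNAME": "alice"}))
    print(resolve_model_ids({"HF_USERNAME": "alice", "FINETUNED_MODEL_ID": "org/custom"}))
```

Passing the environment in as an argument keeps the function easy to test; the real module reads `os.environ` once at import time after `load_dotenv()`.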
training/dataset.py ADDED
@@ -0,0 +1,118 @@
+ """
+ Dataset builder for GRPO training.
+
+ Connects to the live Agentic RAG Gym environment to build training prompts.
+ Each prompt = system instruction + task description + retrieved documents.
+ """
+
+ from __future__ import annotations
+
+ import asyncio
+ import re
+ from typing import Any, Dict, List
+
+ import httpx
+ from datasets import Dataset
+
+ from training.config import ENV_URL
+
+ SYSTEM_PROMPT = (
+     "You are an expert aerospace research analyst with deep knowledge of "
+     "propulsion systems, orbital mechanics, materials science, thermal protection, "
+     "life support systems, and space mission design. When analyzing aerospace topics:\n"
+     "- Cite specific data points and numerical values from provided documents\n"
+     "- Structure your analysis with clear sections\n"
+     "- Compare alternatives with quantitative evidence\n"
+     "- Provide actionable recommendations grounded in engineering constraints\n"
+ )
+
+
+ def _format_prompt(task: Dict[str, Any], docs: List[Dict[str, Any]]) -> str:
+     """Build a user message from a task + retrieved docs."""
+     doc_text = ""
+     for i, doc in enumerate(docs, 1):
+         src = doc.get("source", "unknown")
+         doc_text += f"\n[Document {i} -- {src}]\n{doc['content']}\n"
+
+     return (
+         f"## Task: {task['name']}\n\n"
+         f"{task['description']}\n\n"
+         f"### Retrieved Reference Documents\n{doc_text}\n"
+         "### Instructions\n"
+         "Provide a comprehensive, well-structured answer to the task above. "
+         "Cite specific data from the reference documents. "
+         "Include quantitative evidence and clear recommendations."
+     )
+
+
+ async def _fetch_tasks(client: httpx.AsyncClient) -> List[Dict[str, Any]]:
+     resp = await client.get(f"{ENV_URL}/tasks")
+     resp.raise_for_status()
+     return resp.json()["tasks"]
+
+
+ def _query_variants(task: Dict[str, Any]) -> List[str]:
+     """Generate diverse retrieval queries from a task description."""
+     desc = task["description"]
+     sentences = [s.strip() for s in re.split(r'[.!?]+', desc) if len(s.strip()) > 20]
+     variants = [desc]
+     variants.extend(sentences[:4])
+     name_words = task["name"].lower()
+     variants.append(name_words)
+     return variants
+
+
+ async def _collect_one_prompt(
+     client: httpx.AsyncClient,
+     task: Dict[str, Any],
+     query: str,
+ ) -> Dict[str, Any] | None:
+     """Reset env, retrieve docs, build a prompt."""
+     resp = await client.post(f"{ENV_URL}/reset", json={"task_id": task["task_id"]})
+     if resp.status_code != 200:
+         return None
+
+     resp = await client.post(
+         f"{ENV_URL}/step", json={"type": "retrieve", "query": query}
+     )
+     if resp.status_code != 200:
+         return None
+
+     docs = resp.json()["observation"]["retrieved_docs"]
+     user_msg = _format_prompt(task, docs)
+     return {
+         "task_id": task["task_id"],
+         "task_name": task["name"],
+         "difficulty": task.get("difficulty", "easy"),
+         "prompt": user_msg,
+     }
+
+
+ async def _build_dataset_async(num_per_task: int = 8) -> List[Dict[str, Any]]:
+     records: List[Dict[str, Any]] = []
+     async with httpx.AsyncClient(timeout=120.0) as client:
+         tasks = await _fetch_tasks(client)
+         for task in tasks:
+             variants = _query_variants(task)
+             for i in range(num_per_task):
+                 query = variants[i % len(variants)]
+                 rec = await _collect_one_prompt(client, task, query)
+                 if rec:
+                     records.append(rec)
+     return records
+
+
+ def build_grpo_dataset(num_per_task: int = 8, seed: int = 42) -> Dataset:
+     """
+     Build a HuggingFace Dataset of prompts for GRPO training.
+
+     Each row: prompt (str), task_id, task_name, difficulty.
+     Prompts are collected from the LIVE environment (reset + retrieve).
+     """
+     records = asyncio.run(_build_dataset_async(num_per_task))
+     if not records:
+         raise RuntimeError(f"No prompts collected. Is the environment running at {ENV_URL}?")
+     ds = Dataset.from_list(records)
+     ds = ds.shuffle(seed=seed)
+     print(f"Built GRPO dataset: {len(ds)} prompts across {len(ds.unique('task_id'))} tasks")
+     return ds
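`_query_variants` yields at most six queries per task (the full description, up to four sentences, and the lower-cased name), and `_build_dataset_async` cycles through them with `i % len(variants)`. With `num_per_task=8`, some retrieval queries are therefore issued more than once. The standalone sketch below reproduces that cycling without the HTTP calls; the task dictionary is a made-up example and `queries_for_task` is a hypothetical helper.

```python
import re
from typing import Dict, List


def query_variants(task: Dict[str, str]) -> List[str]:
    """Same construction as _query_variants: description, up to 4 sentences, name."""
    desc = task["description"]
    sentences = [s.strip() for s in re.split(r"[.!?]+", desc) if len(s.strip()) > 20]
    return [desc, *sentences[:4], task["name"].lower()]


def queries_for_task(task: Dict[str, str], num_per_task: int = 8) -> List[str]:
    """Cycle through the variants the way the dataset builder's inner loop does."""
    variants = query_variants(task)
    return [variants[i % len(variants)] for i in range(num_per_task)]


# Hypothetical task; real tasks come from the environment's /tasks endpoint.
EXAMPLE_TASK = {
    "name": "Compare Propulsion Technologies",
    "description": (
        "Compare ion propulsion and nuclear thermal propulsion for a Mars transit. "
        "Discuss specific impulse and thrust."
    ),
}

if __name__ == "__main__":
    queries = queries_for_task(EXAMPLE_TASK)
    print(f"{len(queries)} queries, {len(set(queries))} distinct")
```

Whether repeated queries produce identical prompts depends on how deterministic the environment's retrieval step is, which this commit does not show.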
training/evaluate.py ADDED
@@ -0,0 +1,192 @@
+ """
+ Evaluation & plotting utilities for GRPO training.
+ """
+
+ from __future__ import annotations
+
+ import json
+ from collections import defaultdict
+ from pathlib import Path
+ from typing import Any, Dict, List
+
+ import matplotlib
+ matplotlib.use("Agg")
+ import matplotlib.pyplot as plt
+ import numpy as np
+
+ from training.config import PLOTS_DIR
+ from training.reward import grade_answer_sync
+
+
+ # ── Model evaluation ────────────────────────────────────────────────────
+
+ def evaluate_model_on_tasks(
+     model,
+     tokenizer,
+     prompts: List[Dict[str, Any]],
+     max_new_tokens: int = 512,
+     temperature: float = 0.1,
+ ) -> List[Dict[str, Any]]:
+     """Generate answers for each prompt and grade them with domain graders."""
+     import torch
+     results: List[Dict[str, Any]] = []
+     device = next(model.parameters()).device
+
+     for item in prompts:
+         messages = [
+             {"role": "system", "content": "You are an expert aerospace research analyst. Provide comprehensive answers citing specific data."},
+             {"role": "user", "content": item["prompt"]},
+         ]
+         text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+         inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=2048)
+         inputs = {k: v.to(device) for k, v in inputs.items()}
+
+         with torch.no_grad():
+             output_ids = model.generate(
+                 **inputs,
+                 max_new_tokens=max_new_tokens,
+                 temperature=max(temperature, 0.01),
+                 do_sample=temperature > 0,
+                 top_p=0.9,
+                 pad_token_id=tokenizer.pad_token_id or tokenizer.eos_token_id,
+             )
+
+         generated = output_ids[0][inputs["input_ids"].shape[1]:]
+         answer = tokenizer.decode(generated, skip_special_tokens=True).strip()
+         score = grade_answer_sync(item["task_id"], answer)
+
+         results.append({
+             "task_id": item["task_id"],
+             "task_name": item["task_name"],
+             "difficulty": item["difficulty"],
+             "answer": answer[:500],
+             "score": score,
+         })
+         print(f" [{item['task_id']}] score={score:.3f} len={len(answer)}")
+     return results
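`evaluate_model_on_tasks` derives its decoding mode from a single `temperature` argument: any positive value enables sampling, zero switches to greedy decoding, and the temperature passed to `generate` is clamped to at least 0.01. The sketch below isolates that mapping so it can be inspected without loading a model. `generation_settings` is a hypothetical helper that only mirrors the keyword arguments used above.

```python
from typing import Any, Dict


def generation_settings(temperature: float, max_new_tokens: int = 512) -> Dict[str, Any]:
    """Mirror the sampling-related kwargs that evaluate_model_on_tasks passes to generate()."""
    return {
        "max_new_tokens": max_new_tokens,
        "temperature": max(temperature, 0.01),  # clamp so the value is never zero
        "do_sample": temperature > 0,           # zero temperature means greedy decoding
        "top_p": 0.9,
    }


if __name__ == "__main__":
    print(generation_settings(0.1))  # the setting train.py uses for both evaluations
    print(generation_settings(0.0))  # greedy: do_sample is False
```

Since `train.py` evaluates with `temperature=0.1`, both the baseline and the trained scores come from sampled rather than greedy generations.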
+
+
+ # ── Plotting ─────────────────────────────────────────────────────────────
+
+ def plot_training_curves(log_history: List[Dict[str, Any]], out_dir: Path = PLOTS_DIR) -> Path:
+     """Plot training loss and reward curves. Returns path to saved figure."""
+     steps = [e["step"] for e in log_history if "loss" in e]
+     losses = [e["loss"] for e in log_history if "loss" in e]
+     reward_steps = [e["step"] for e in log_history if "reward" in e or "reward/mean" in e]
+     rewards = [e.get("reward/mean", e.get("reward")) for e in log_history if "reward" in e or "reward/mean" in e]
+
+     fig, axes = plt.subplots(1, 2, figsize=(14, 5))
+
+     ax = axes[0]
+     if steps and losses:
+         ax.plot(steps, losses, color="#D4AF37", linewidth=2)
+     ax.set_xlabel("Training Step", fontsize=12)
+     ax.set_ylabel("Loss", fontsize=12)
+     ax.set_title("GRPO Training Loss", fontsize=14, fontweight="bold")
+     ax.grid(True, alpha=0.3)
+
+     ax = axes[1]
+     if reward_steps and rewards:
+         ax.plot(reward_steps, rewards, color="#4CAF50", linewidth=2)
+     ax.set_xlabel("Training Step", fontsize=12)
+     ax.set_ylabel("Mean Reward (grader score)", fontsize=12)
+     ax.set_title("GRPO Mean Reward", fontsize=14, fontweight="bold")
+     ax.grid(True, alpha=0.3)
+
+     plt.tight_layout()
+     path = out_dir / "training_curves.png"
+     fig.savefig(path, dpi=150, bbox_inches="tight")
+     plt.close(fig)
+     print(f"Saved training curves -> {path}")
+     return path
102
+
103
+
+ def plot_baseline_vs_trained(
+     baseline_results: List[Dict[str, Any]],
+     trained_results: List[Dict[str, Any]],
+     out_dir: Path = PLOTS_DIR,
+ ) -> Path:
+     """Bar chart comparing baseline vs trained scores per task."""
+     def _agg(results):
+         # Mean score per task_id.
+         sums: Dict[str, List[float]] = defaultdict(list)
+         for r in results:
+             sums[r["task_id"]].append(r["score"])
+         return {k: float(np.mean(v)) for k, v in sums.items()}
+
+     baseline_agg = _agg(baseline_results)
+     trained_agg = _agg(trained_results)
+     tasks = sorted(set(baseline_agg) | set(trained_agg))
+     short_names = [t.replace("aero_", "").replace("_", " ").title() for t in tasks]
+
+     x = np.arange(len(tasks))
+     width = 0.35
+     fig, ax = plt.subplots(figsize=(12, 6))
+     bars1 = ax.bar(x - width / 2, [baseline_agg.get(t, 0) for t in tasks],
+                    width, label="Baseline (untrained)", color="#8B0000", alpha=0.85, edgecolor="black")
+     bars2 = ax.bar(x + width / 2, [trained_agg.get(t, 0) for t in tasks],
+                    width, label="GRPO-trained", color="#D4AF37", alpha=0.85, edgecolor="black")
+
+     ax.set_xlabel("Task", fontsize=12)
+     ax.set_ylabel("Grader Score (0-1)", fontsize=12)
+     ax.set_title("Baseline vs. GRPO-Trained Model - Task Scores", fontsize=14, fontweight="bold")
+     ax.set_xticks(x)
+     ax.set_xticklabels(short_names, rotation=20, ha="right", fontsize=10)
+     ax.legend(fontsize=11)
+     ax.set_ylim(0, 1.0)
+     ax.grid(axis="y", alpha=0.3)
+     # Annotate each bar with its value.
+     for bar in list(bars1) + list(bars2):
+         ax.text(bar.get_x() + bar.get_width() / 2, bar.get_height() + 0.02,
+                 f"{bar.get_height():.2f}", ha="center", fontsize=9)
+
+     plt.tight_layout()
+     # Create the output directory if needed so savefig() cannot fail on a missing path.
+     out_dir.mkdir(parents=True, exist_ok=True)
+     path = out_dir / "baseline_vs_trained.png"
+     fig.savefig(path, dpi=150, bbox_inches="tight")
+     plt.close(fig)
+     print(f"Saved comparison chart -> {path}")
+     return path
+
+
+ def plot_reward_distribution(
+     baseline_results: List[Dict[str, Any]],
+     trained_results: List[Dict[str, Any]],
+     out_dir: Path = PLOTS_DIR,
+ ) -> Path:
+     """Histogram of score distributions."""
+     fig, ax = plt.subplots(figsize=(10, 5))
+     # 20 equal-width bins over the grader's 0-1 score range.
+     bins = np.linspace(0, 1, 21)
+     ax.hist([r["score"] for r in baseline_results], bins=bins, alpha=0.6,
+             label="Baseline", color="#8B0000", edgecolor="black")
+     ax.hist([r["score"] for r in trained_results], bins=bins, alpha=0.6,
+             label="GRPO-trained", color="#D4AF37", edgecolor="black")
+     ax.set_xlabel("Grader Score", fontsize=12)
+     ax.set_ylabel("Frequency", fontsize=12)
+     ax.set_title("Score Distribution - Baseline vs. GRPO-Trained", fontsize=14, fontweight="bold")
+     ax.legend(fontsize=11)
+     ax.grid(axis="y", alpha=0.3)
+     plt.tight_layout()
+     # Create the output directory if needed so savefig() cannot fail on a missing path.
+     out_dir.mkdir(parents=True, exist_ok=True)
+     path = out_dir / "score_distribution.png"
+     fig.savefig(path, dpi=150, bbox_inches="tight")
+     plt.close(fig)
+     print(f"Saved distribution plot -> {path}")
+     return path
+
+
+ def save_eval_results(
+     baseline_results: List[Dict[str, Any]],
+     trained_results: List[Dict[str, Any]],
+     out_dir: Path = PLOTS_DIR,
+ ) -> Path:
+     """Save evaluation results as JSON."""
+     data = {
+         "baseline": baseline_results,
+         "trained": trained_results,
+         "summary": {
+             "baseline_mean": float(np.mean([r["score"] for r in baseline_results])) if baseline_results else 0,
+             "trained_mean": float(np.mean([r["score"] for r in trained_results])) if trained_results else 0,
+             "improvement": float(np.mean([r["score"] for r in trained_results]) - np.mean([r["score"] for r in baseline_results])) if baseline_results and trained_results else 0,
+         },
+     }
+     # Create the output directory if needed so write_text() cannot fail on a missing path.
+     out_dir.mkdir(parents=True, exist_ok=True)
+     path = out_dir / "eval_results.json"
+     path.write_text(json.dumps(data, indent=2))
+     print(f"Saved eval results -> {path}")
+     return path
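
Review note: the `summary` block written by `save_eval_results` reduces to three numbers (baseline mean, trained mean, and their difference). The sketch below recomputes them with the standard library only, which can be handy for checking `eval_results.json` by hand. The helper name `summarize` and the sample scores are hypothetical and are not part of this commit.

```python
from statistics import mean
from typing import Any, Dict, List


def summarize(baseline: List[Dict[str, Any]], trained: List[Dict[str, Any]]) -> Dict[str, float]:
    """Hypothetical helper mirroring the 'summary' dict in eval_results.json."""
    baseline_mean = mean(r["score"] for r in baseline) if baseline else 0.0
    trained_mean = mean(r["score"] for r in trained) if trained else 0.0
    # The improvement is only meaningful when both runs produced results.
    improvement = trained_mean - baseline_mean if baseline and trained else 0.0
    return {
        "baseline_mean": baseline_mean,
        "trained_mean": trained_mean,
        "improvement": improvement,
    }


# Made-up scores for illustration only; these are not results from this commit.
baseline = [{"score": 0.25}, {"score": 0.75}]
trained = [{"score": 0.5}, {"score": 1.0}]
summary = summarize(baseline, trained)
print(summary)
```

Computing each mean once, as above, also avoids repeating the list comprehensions inside the `improvement` expression.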
training/reward.py ADDED
@@ -0,0 +1,87 @@
+ """
+ Reward functions for GRPO training.
+
+ Uses the real domain graders from the Agentic RAG Gym so the RL signal
+ matches the actual evaluation rubric — not a proxy.
+ """
+
+ from __future__ import annotations
+
+ import asyncio
+ from datetime import datetime, timezone
+ from typing import Any, Dict, List
+
+ from domains.aerospace.graders import GRADER_REGISTRY
+ from rag_master.models import EpisodeState, StepRecord, Trajectory
+ from rag_master.rewards import _SCORE_MIN
+
+
+ def _make_dummy_state(task_id: str, answer: str) -> EpisodeState:
+     """Minimal EpisodeState for offline grading."""
+     from domains.aerospace.config import AerospaceDomainConfig
+
+     domain = AerospaceDomainConfig()
+     tasks = {t.task_id: t for t in domain.get_tasks()}
+     task = tasks.get(task_id)
+     if task is None:
+         raise ValueError(f"Unknown task_id: {task_id}")
+
+     return EpisodeState(
+         episode_id="grpo-eval",
+         task=task,
+         current_step=5,
+         query_history=["query"],
+         retrieved_docs=[],
+         agent_messages=[],
+         generated_answer=answer,
+         intermediate_rewards=[0.5] * 5,
+         done=True,
+         info={},
+     )
+
+
+ def _make_dummy_trajectory(task_id: str) -> Trajectory:
+     """Minimal trajectory with a good action sequence for process scoring."""
+     now = datetime.now(timezone.utc)
+     steps = [
+         StepRecord(step_index=0, action_type="plan", action_payload={},
+                    observation_summary="planned", intermediate_reward=0.5,
+                    reasoning_trace="Planning approach.", timestamp=now),
+         StepRecord(step_index=1, action_type="retrieve", action_payload={},
+                    observation_summary="retrieved", intermediate_reward=0.6,
+                    reasoning_trace="Retrieving.", timestamp=now),
+         StepRecord(step_index=2, action_type="reason", action_payload={},
+                    observation_summary="reasoned", intermediate_reward=0.5,
+                    reasoning_trace="Analyzing because data is relevant.", timestamp=now),
+         StepRecord(step_index=3, action_type="answer", action_payload={},
+                    observation_summary="answered", intermediate_reward=0.6,
+                    reasoning_trace="Final answer.", timestamp=now),
+         StepRecord(step_index=4, action_type="verify", action_payload={},
+                    observation_summary="verified", intermediate_reward=0.5,
+                    reasoning_trace="Verifying.", timestamp=now),
+     ]
+     return Trajectory(
+         episode_id="grpo-eval", task_id=task_id, steps=steps,
+         total_reward=0.0, final_score=0.0, completed=True, metadata={},
+     )
+
+
+ def grade_answer_sync(task_id: str, answer: str) -> float:
+     """Grade a single answer using the domain grader (synchronous)."""
+     grader_cls = GRADER_REGISTRY.get(task_id)
+     if grader_cls is None:
+         # No grader registered for this task: fall back to the minimum score.
+         return float(_SCORE_MIN)
+     grader = grader_cls()
+     state = _make_dummy_state(task_id, answer)
+     trajectory = _make_dummy_trajectory(task_id)
+
+     try:
+         loop = asyncio.get_running_loop()
+     except RuntimeError:
+         loop = None
+
+     if loop and loop.is_running():
+         # asyncio.run() cannot be called from a thread with a running event loop
+         # (e.g. inside Jupyter), so run the coroutine in a worker thread instead.
+         import concurrent.futures
+         with concurrent.futures.ThreadPoolExecutor() as pool:
+             return pool.submit(asyncio.run, grader.grade(state, trajectory)).result()
+     return asyncio.run(grader.grade(state, trajectory))
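
Review note: `grade_answer_sync` bridges an async grader into synchronous code, including the case where an event loop is already running (as in a notebook). The pattern can be exercised in isolation with the sketch below. `fake_grade` and `run_coro_sync` are hypothetical stand-ins, not names from this repository, and the length-based scoring is only a placeholder for a real grader.

```python
import asyncio
import concurrent.futures


async def fake_grade(answer: str) -> float:
    """Hypothetical async grader: scores by answer length, capped at 1.0."""
    await asyncio.sleep(0)
    return min(len(answer) / 10.0, 1.0)


def run_coro_sync(coro_fn, *args):
    """Run an async callable to completion from plain or loop-hosted code."""
    try:
        asyncio.get_running_loop()
    except RuntimeError:
        # No loop is running in this thread, so asyncio.run() is safe here.
        return asyncio.run(coro_fn(*args))
    # A loop is already running: start a fresh loop in a worker thread.
    with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
        return pool.submit(lambda: asyncio.run(coro_fn(*args))).result()


async def called_from_async() -> float:
    # Blocks the running loop until the worker thread returns.
    return run_coro_sync(fake_grade, "hello")


print(run_coro_sync(fake_grade, "hello"))
print(asyncio.run(called_from_async()))
```

In this sketch the coroutine is created inside the worker thread rather than in the caller. Both approaches run the coroutine on the worker's loop; creating it there simply keeps the coroutine object and the loop that executes it in the same thread.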