Spaces:
Sleeping
Sleeping
feat: GRPO fine-tuning pipeline + trained model for aerospace RAG
Browse files- Add Jupyter notebook (agentic-rag-for-aerospace-research.ipynb) with
full GRPO training pipeline: dataset collection, baseline eval, training,
post-training eval, plots, and HF Hub push
- Add training/ package (config, dataset, reward, evaluate modules)
- Add train.py standalone training script
- Add training plots: training_curves.png, baseline_vs_trained.png,
score_distribution.png
- GRPO-trained Qwen2.5-0.5B with LoRA (r=16, alpha=32)
- Baseline: 0.558 -> Trained: 0.586 (+0.028 improvement)
- Model pushed to williyam/agentic-rag-aerospace-grpo on HF Hub
- Update README with training section, results, and plot embeds
- .gitignore +4 -0
- README.md +78 -0
- agentic-rag-for-aerospace-research.ipynb +3 -0
- plots/baseline_vs_trained.png +3 -0
- plots/eval_results.json +84 -0
- plots/score_distribution.png +3 -0
- plots/training_curves.png +3 -0
- train.py +227 -0
- training/__init__.py +1 -0
- training/config.py +40 -0
- training/dataset.py +118 -0
- training/evaluate.py +192 -0
- training/reward.py +87 -0
.gitignore
CHANGED
|
@@ -25,3 +25,7 @@ data/uploads/
|
|
| 25 |
.DS_Store
|
| 26 |
Thumbs.db
|
| 27 |
node_modules/
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 25 |
.DS_Store
|
| 26 |
Thumbs.db
|
| 27 |
node_modules/
|
| 28 |
+
|
| 29 |
+
# Training checkpoints (large files)
|
| 30 |
+
checkpoints/
|
| 31 |
+
.venv-1/
|
README.md
CHANGED
|
@@ -291,6 +291,80 @@ Update `server/app.py` to use your domain config instead of `AerospaceDomainConf
|
|
| 291 |
|
| 292 |
---
|
| 293 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 294 |
## Testing
|
| 295 |
|
| 296 |
```bash
|
|
@@ -344,9 +418,13 @@ agentic-rag-gym/
|
|
| 344 |
├── server/ # FastAPI + Gradio server
|
| 345 |
├── domains/aerospace/ # Aerospace research domain
|
| 346 |
├── domains/legal_research/ # Legal research domain (stub)
|
|
|
|
| 347 |
├── tests/ # Unit & integration tests (102+)
|
| 348 |
├── .github/workflows/ # CI pipeline
|
| 349 |
├── documents/ # Architecture & design docs
|
|
|
|
|
|
|
|
|
|
| 350 |
├── inference.py # Baseline inference script
|
| 351 |
├── openenv.yaml # OpenEnv specification
|
| 352 |
├── Dockerfile # Container definition
|
|
|
|
| 291 |
|
| 292 |
---
|
| 293 |
|
| 294 |
+
## GRPO Fine-Tuning (Reinforcement Learning)
|
| 295 |
+
|
| 296 |
+
We fine-tune **Qwen2.5-0.5B-Instruct** using **Group Relative Policy Optimization (GRPO)** from TRL,
|
| 297 |
+
with LoRA adapters and the **real domain graders** as the reward signal — no proxy rewards.
|
| 298 |
+
|
| 299 |
+
### Training Results
|
| 300 |
+
|
| 301 |
+
| Metric | Baseline | GRPO-Trained | Improvement |
|
| 302 |
+
|--------|----------|-------------|-------------|
|
| 303 |
+
| **Mean Score** | 0.5580 | 0.5860 | **+0.0280** |
|
| 304 |
+
| Propulsion Comparison | 0.508 | 0.562 | +0.053 |
|
| 305 |
+
| Debris Mitigation | 0.633 | 0.689 | +0.056 |
|
| 306 |
+
| Hypersonic Vehicle | 0.482 | 0.521 | +0.039 |
|
| 307 |
+
| Mars EDL | 0.574 | 0.568 | -0.006 |
|
| 308 |
+
| Life Support | 0.592 | 0.590 | -0.002 |
|
| 309 |
+
|
| 310 |
+
### Training Curves
|
| 311 |
+
|
| 312 |
+

|
| 313 |
+
|
| 314 |
+
### Baseline vs. GRPO-Trained
|
| 315 |
+
|
| 316 |
+

|
| 317 |
+
|
| 318 |
+
### Score Distribution
|
| 319 |
+
|
| 320 |
+

|
| 321 |
+
|
| 322 |
+
### Run Training (Notebook)
|
| 323 |
+
|
| 324 |
+
The primary training interface is the Jupyter notebook:
|
| 325 |
+
|
| 326 |
+
```bash
|
| 327 |
+
jupyter notebook agentic-rag-for-aerospace-research.ipynb
|
| 328 |
+
```
|
| 329 |
+
|
| 330 |
+
### Run Training (Script)
|
| 331 |
+
|
| 332 |
+
For headless/CI environments:
|
| 333 |
+
|
| 334 |
+
```bash
|
| 335 |
+
python train.py
|
| 336 |
+
```
|
| 337 |
+
|
| 338 |
+
### Configuration
|
| 339 |
+
|
| 340 |
+
| Parameter | Value |
|
| 341 |
+
|-----------|-------|
|
| 342 |
+
| Base Model | `Qwen/Qwen2.5-0.5B-Instruct` |
|
| 343 |
+
| Method | GRPO (Group Relative Policy Optimization) |
|
| 344 |
+
| LoRA | r=16, α=32, targets=q/k/v/o_proj |
|
| 345 |
+
| Optimizer | AdamW (torch) |
|
| 346 |
+
| Learning Rate | 5e-6 |
|
| 347 |
+
| Epochs | 2 |
|
| 348 |
+
| Group Size (G) | 4 |
|
| 349 |
+
| Max Completion | 512 tokens |
|
| 350 |
+
| Hardware | Apple M1 Pro (MPS) |
|
| 351 |
+
| Training Time | ~116 min |
|
| 352 |
+
|
| 353 |
+
### Fine-Tuned Model
|
| 354 |
+
|
| 355 |
+
The GRPO-trained model is available on Hugging Face:
|
| 356 |
+
**[williyam/agentic-rag-aerospace-grpo](https://huggingface.co/williyam/agentic-rag-aerospace-grpo)**
|
| 357 |
+
|
| 358 |
+
```python
|
| 359 |
+
from peft import AutoPeftModelForCausalLM
|
| 360 |
+
from transformers import AutoTokenizer
|
| 361 |
+
|
| 362 |
+
model = AutoPeftModelForCausalLM.from_pretrained("williyam/agentic-rag-aerospace-grpo")
|
| 363 |
+
tokenizer = AutoTokenizer.from_pretrained("williyam/agentic-rag-aerospace-grpo")
|
| 364 |
+
```
|
| 365 |
+
|
| 366 |
+
---
|
| 367 |
+
|
| 368 |
## Testing
|
| 369 |
|
| 370 |
```bash
|
|
|
|
| 418 |
├── server/ # FastAPI + Gradio server
|
| 419 |
├── domains/aerospace/ # Aerospace research domain
|
| 420 |
├── domains/legal_research/ # Legal research domain (stub)
|
| 421 |
+
├── training/ # GRPO training package
|
| 422 |
├── tests/ # Unit & integration tests (102+)
|
| 423 |
├── .github/workflows/ # CI pipeline
|
| 424 |
├── documents/ # Architecture & design docs
|
| 425 |
+
├── plots/ # Training curves & evaluation plots
|
| 426 |
+
├── agentic-rag-for-aerospace-research.ipynb # GRPO training notebook
|
| 427 |
+
├── train.py # Standalone training script
|
| 428 |
├── inference.py # Baseline inference script
|
| 429 |
├── openenv.yaml # OpenEnv specification
|
| 430 |
├── Dockerfile # Container definition
|
agentic-rag-for-aerospace-research.ipynb
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:e8b90675048b0211601e7e8283ad1a2c5e41ff8f6fe5e0f5620b1be16d30993e
|
| 3 |
+
size 10531022
|
plots/baseline_vs_trained.png
ADDED
|
Git LFS Details
|
plots/eval_results.json
ADDED
|
@@ -0,0 +1,84 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"model": "Qwen/Qwen2.5-0.5B-Instruct",
|
| 3 |
+
"finetuned": "williyam/agentic-rag-aerospace-grpo",
|
| 4 |
+
"training_time_min": 116.41443806489309,
|
| 5 |
+
"baseline": [
|
| 6 |
+
{
|
| 7 |
+
"task_id": "aero_hard_hypersonic_vehicle",
|
| 8 |
+
"task_name": "Reusable Hypersonic Space Access Vehicle",
|
| 9 |
+
"difficulty": "hard",
|
| 10 |
+
"answer": "### Conceptual Design Overview\n\n#### Introduction\nReusability of a hypersonic space access vehicle (SAV) with a combined-cycle propulsion (CCP) system is crucial for reducing operational costs, increasing payload capacity, and enhancing safety. The proposed SAV utilizes a combination of scramjet, turbine, and rocket engines, designed to achieve high-speed, long-duration hypersonic travel. The propulsion mode transitions between scramjets and rockets, allowing for efficient use of both technologi",
|
| 11 |
+
"score": 0.48219047619047617
|
| 12 |
+
},
|
| 13 |
+
{
|
| 14 |
+
"task_id": "aero_easy_propulsion_comparison",
|
| 15 |
+
"task_name": "Compare Propulsion Technologies",
|
| 16 |
+
"difficulty": "easy",
|
| 17 |
+
"answer": "### Comparison of Propulsion Technologies for a Mars Transit Mission\n\n#### Specific Impulse and Thrust Levels\n\n**Ion Propulsion (IP):**\n- **Specific Impulse:** 4,190 seconds (NASA)\n- **Thrust Level:** Very Low (0.5 N for NEXT)\n\n**Nuclear Thermal Propulsion (NTP):**\n- **Specific Impulse:** 850-1,000 seconds (NASA)\n- **Thrust Level:** Extremely High (approximately 1,000 seconds)\n\n**Comparison:**\n- **Specific Impulse:** IP > NTP\n- **Thrust Level:** IP < NTP\n\n**Impact on Mission Duration:**\n- **IP:*",
|
| 18 |
+
"score": 0.5082857142857143
|
| 19 |
+
},
|
| 20 |
+
{
|
| 21 |
+
"task_id": "aero_medium_mars_edl",
|
| 22 |
+
"task_name": "Mars EDL Architecture Design",
|
| 23 |
+
"difficulty": "medium",
|
| 24 |
+
"answer": "### Overview of the EDL Architecture Design\n\n#### 1. **Entry Phase**\n - **Aerocapture**: \n - **Objective**: To capture the spacecraft into the Martian atmosphere.\n - **Data**: The spacecraft must be designed to achieve a high velocity relative to the Martian atmosphere to avoid atmospheric drag. This typically involves using a combination of parachutes, airbags, and other aerodynamic devices.\n - **Quantitative Evidence**: The aerocapture trajectory should be optimized to minimize e",
|
| 25 |
+
"score": 0.5742857142857143
|
| 26 |
+
},
|
| 27 |
+
{
|
| 28 |
+
"task_id": "aero_medium_life_support",
|
| 29 |
+
"task_name": "Deep Space Life Support Design",
|
| 30 |
+
"difficulty": "medium",
|
| 31 |
+
"answer": "### Deep Space Life Support Design for a 6-Person Crew on a 2.5-Year Mars Mission\n\n#### Overview\nThis design integrates physicochemical and bioregenerative systems to maximize closure ratios and address various aspects of life support on a 2.5-year Mars mission. The goal is to ensure the survival and health of the crew while minimizing resource consumption and environmental impact.\n\n#### System Components\n1. **Physicochemical Life Support (PLS)**\n - **Thermophilic Bacteria for Waste Liquefacti",
|
| 32 |
+
"score": 0.5917142857142857
|
| 33 |
+
},
|
| 34 |
+
{
|
| 35 |
+
"task_id": "aero_easy_debris_mitigation",
|
| 36 |
+
"task_name": "Space Debris Mitigation Strategies",
|
| 37 |
+
"difficulty": "easy",
|
| 38 |
+
"answer": "### Space Debris Mitigation Strategies\n\n#### Current Challenges and Active Debris Removal Technologies\n\n**Current Challenges:**\n1. **Kessler Syndrome Risk:** Over 36,500 objects >10 cm in LEO reach critical densities, leading to significant collisions that threaten orbit stability.\n2. **Tracking Capabilities:** Limited ground-based and orbital tracking capabilities, especially for larger objects like satellites and asteroids.\n3. **Proposed Removal Methods:**\n - **Electrodynamic Tethers:** Gene",
|
| 39 |
+
"score": 0.6334285714285715
|
| 40 |
+
}
|
| 41 |
+
],
|
| 42 |
+
"trained": [
|
| 43 |
+
{
|
| 44 |
+
"task_id": "aero_hard_hypersonic_vehicle",
|
| 45 |
+
"task_name": "Reusable Hypersonic Space Access Vehicle",
|
| 46 |
+
"difficulty": "hard",
|
| 47 |
+
"answer": "### Conceptual Design Overview\n\n#### Introduction\nReusability of a hypersonic space access vehicle (SAV) with a combined-cycle propulsion (CCP) system is crucial for reducing operational costs, increasing payload capacity, and enhancing safety. The proposed SAV utilizes a combination of scramjet, turbine, and rocket propulsion modes, aiming to achieve high-speed, long-duration flights within the hypersonic regime. This approach addresses key aerothermodynamic challenges such as shock-boundary la",
|
| 48 |
+
"score": 0.5214285714285715
|
| 49 |
+
},
|
| 50 |
+
{
|
| 51 |
+
"task_id": "aero_easy_propulsion_comparison",
|
| 52 |
+
"task_name": "Compare Propulsion Technologies",
|
| 53 |
+
"difficulty": "easy",
|
| 54 |
+
"answer": "### Comparison of Propulsion Technologies for a Mars Transit Mission\n\n#### Specific Impulse\n**Ion Propulsion (IP):**\n- **Specific Impulse:** 4,190 seconds\n- **Thrust Level:** Extremely Low (0.5 N)\n- **Impact on Mission Duration:** Minimal impact on mission duration as IP is designed for short-duration missions.\n- **Technology Readiness:** High due to its ability to achieve high specific impulses without significant technological hurdles.\n\n**Nuclear Thermal Propulsion (NTP):**\n- **Specific Impuls",
|
| 55 |
+
"score": 0.5616190476190477
|
| 56 |
+
},
|
| 57 |
+
{
|
| 58 |
+
"task_id": "aero_medium_mars_edl",
|
| 59 |
+
"task_name": "Mars EDL Architecture Design",
|
| 60 |
+
"difficulty": "medium",
|
| 61 |
+
"answer": "### Overview of the EDL Architecture Design\n\n#### Introduction\nThe EDL architecture design addresses the challenges posed by the Mars atmospheric environment, including the high peak heat flux and the need for effective aerodynamic deceleration of heavy payloads. This design integrates multiple phases into a single, cohesive approach to ensure efficient and safe mission execution. The proposed architecture includes:\n\n1. **Aerocapture**: The first phase involves capturing the payload using aerobr",
|
| 62 |
+
"score": 0.5678095238095239
|
| 63 |
+
},
|
| 64 |
+
{
|
| 65 |
+
"task_id": "aero_medium_life_support",
|
| 66 |
+
"task_name": "Deep Space Life Support Design",
|
| 67 |
+
"difficulty": "medium",
|
| 68 |
+
"answer": "### Overview of the Integrated Life Support System Design\n\n#### Introduction\nDeep space exploration presents unique challenges due to the extreme conditions encountered, including low gravity, radiation exposure, and the presence of hazardous materials like water and carbon dioxide. To ensure long-term survival and minimize environmental impact, a sophisticated life support system must be designed. This design integrates various biological processes and energy sources to maximize closure ratios ",
|
| 69 |
+
"score": 0.5897142857142857
|
| 70 |
+
},
|
| 71 |
+
{
|
| 72 |
+
"task_id": "aero_easy_debris_mitigation",
|
| 73 |
+
"task_name": "Space Debris Mitigation Strategies",
|
| 74 |
+
"difficulty": "easy",
|
| 75 |
+
"answer": "**Task: Space Debris Mitigation Strategies**\n\n**Current Challenges and Evaluation of Active Debris Removal Technologies**\n\nActive debris removal (ADR) technologies have been developed to mitigate the effects of space debris on orbiting satellites and other spacecraft. These technologies aim to prevent collisions that could lead to catastrophic damage to spacecraft and disrupt satellite communications. However, the effectiveness of these technologies varies widely depending on several factors suc",
|
| 76 |
+
"score": 0.6894285714285716
|
| 77 |
+
}
|
| 78 |
+
],
|
| 79 |
+
"summary": {
|
| 80 |
+
"baseline_mean": 0.5579809523809524,
|
| 81 |
+
"trained_mean": 0.5860000000000001,
|
| 82 |
+
"improvement": 0.02801904761904772
|
| 83 |
+
}
|
| 84 |
+
}
|
plots/score_distribution.png
ADDED
|
Git LFS Details
|
plots/training_curves.png
ADDED
|
Git LFS Details
|
train.py
ADDED
|
@@ -0,0 +1,227 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
train.py — GRPO Fine-Tuning for Agentic RAG Gym (Aerospace Domain)
|
| 4 |
+
====================================================================
|
| 5 |
+
|
| 6 |
+
End-to-end training script that:
|
| 7 |
+
1. Connects to the live gym environment to collect prompts
|
| 8 |
+
2. Loads Qwen2.5-0.5B-Instruct with LoRA
|
| 9 |
+
3. Trains with GRPO (Group Relative Policy Optimization) via TRL
|
| 10 |
+
4. Evaluates baseline vs. trained model with domain graders
|
| 11 |
+
5. Generates publication-quality plots
|
| 12 |
+
6. Pushes the fine-tuned model to Hugging Face Hub
|
| 13 |
+
|
| 14 |
+
Usage:
|
| 15 |
+
# Start the environment first:
|
| 16 |
+
python main.py &
|
| 17 |
+
|
| 18 |
+
# Then run training:
|
| 19 |
+
python train.py
|
| 20 |
+
|
| 21 |
+
Environment variables (loaded from .env):
|
| 22 |
+
HF_TOKEN Hugging Face token (for model push)
|
| 23 |
+
HF_USERNAME Hugging Face username (default: williyam)
|
| 24 |
+
ENV_URL Gym environment URL (default: http://localhost:7860)
|
| 25 |
+
BASE_MODEL_ID Base model (default: Qwen/Qwen2.5-0.5B-Instruct)
|
| 26 |
+
"""
|
| 27 |
+
|
| 28 |
+
from __future__ import annotations
|
| 29 |
+
|
| 30 |
+
import sys
|
| 31 |
+
import time
|
| 32 |
+
|
| 33 |
+
import numpy as np
|
| 34 |
+
import torch
|
| 35 |
+
from dotenv import load_dotenv
|
| 36 |
+
from peft import LoraConfig, TaskType
|
| 37 |
+
|
| 38 |
+
load_dotenv()
|
| 39 |
+
|
| 40 |
+
from training.config import (
|
| 41 |
+
BASE_MODEL_ID,
|
| 42 |
+
CHECKPOINTS_DIR,
|
| 43 |
+
FINETUNED_MODEL_ID,
|
| 44 |
+
HF_TOKEN,
|
| 45 |
+
PLOTS_DIR,
|
| 46 |
+
)
|
| 47 |
+
from training.dataset import SYSTEM_PROMPT, build_grpo_dataset
|
| 48 |
+
from training.evaluate import (
|
| 49 |
+
evaluate_model_on_tasks,
|
| 50 |
+
plot_baseline_vs_trained,
|
| 51 |
+
plot_reward_distribution,
|
| 52 |
+
plot_training_curves,
|
| 53 |
+
save_eval_results,
|
| 54 |
+
)
|
| 55 |
+
from training.reward import grade_answer_sync
|
| 56 |
+
|
| 57 |
+
# ── Device ─────────────────────────────────────────────────────────────
|
| 58 |
+
if torch.backends.mps.is_available():
|
| 59 |
+
DEVICE = "mps"
|
| 60 |
+
elif torch.cuda.is_available():
|
| 61 |
+
DEVICE = "cuda"
|
| 62 |
+
else:
|
| 63 |
+
DEVICE = "cpu"
|
| 64 |
+
print(f"Device: {DEVICE} | PyTorch {torch.__version__}")
|
| 65 |
+
|
| 66 |
+
# ── Build dataset ──────────────────────────────────────────────────────
|
| 67 |
+
print("\n[1/6] Collecting prompts from environment...")
|
| 68 |
+
dataset = build_grpo_dataset(num_per_task=8, seed=42)
|
| 69 |
+
|
| 70 |
+
task_ids_map = {}
|
| 71 |
+
for row in dataset:
|
| 72 |
+
task_ids_map[row["task_id"]] = row["task_name"]
|
| 73 |
+
|
| 74 |
+
# Format for TRL: prompt column must be list of message dicts
|
| 75 |
+
def format_for_trl(example):
|
| 76 |
+
return {
|
| 77 |
+
"prompt": [
|
| 78 |
+
{"role": "system", "content": SYSTEM_PROMPT},
|
| 79 |
+
{"role": "user", "content": example["prompt"]},
|
| 80 |
+
],
|
| 81 |
+
}
|
| 82 |
+
|
| 83 |
+
train_dataset = dataset.map(format_for_trl, remove_columns=["task_name", "difficulty"])
|
| 84 |
+
|
| 85 |
+
# ── Load model ─────────────────────────────────────────────────────────
|
| 86 |
+
print(f"\n[2/6] Loading model: {BASE_MODEL_ID}")
|
| 87 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 88 |
+
|
| 89 |
+
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_ID, trust_remote_code=True, padding_side="left")
|
| 90 |
+
if tokenizer.pad_token is None:
|
| 91 |
+
tokenizer.pad_token = tokenizer.eos_token
|
| 92 |
+
|
| 93 |
+
model = AutoModelForCausalLM.from_pretrained(
|
| 94 |
+
BASE_MODEL_ID, torch_dtype=torch.float32, trust_remote_code=True,
|
| 95 |
+
)
|
| 96 |
+
|
| 97 |
+
peft_config = LoraConfig(
|
| 98 |
+
task_type=TaskType.CAUSAL_LM,
|
| 99 |
+
r=16, lora_alpha=32, lora_dropout=0.05,
|
| 100 |
+
target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
|
| 101 |
+
bias="none",
|
| 102 |
+
)
|
| 103 |
+
|
| 104 |
+
# ── Baseline evaluation ───────────────────────────────────────────────
|
| 105 |
+
print("\n[3/6] Evaluating baseline (before training)...")
|
| 106 |
+
from peft import get_peft_model
|
| 107 |
+
eval_model = get_peft_model(model.to(DEVICE), peft_config)
|
| 108 |
+
eval_model.print_trainable_parameters()
|
| 109 |
+
|
| 110 |
+
eval_prompts = []
|
| 111 |
+
seen = set()
|
| 112 |
+
for row in dataset:
|
| 113 |
+
if row["task_id"] not in seen:
|
| 114 |
+
eval_prompts.append(row)
|
| 115 |
+
seen.add(row["task_id"])
|
| 116 |
+
|
| 117 |
+
baseline_results = evaluate_model_on_tasks(
|
| 118 |
+
eval_model, tokenizer, eval_prompts, max_new_tokens=512, temperature=0.1,
|
| 119 |
+
)
|
| 120 |
+
baseline_mean = np.mean([r["score"] for r in baseline_results])
|
| 121 |
+
print(f" Baseline mean score: {baseline_mean:.4f}")
|
| 122 |
+
del eval_model
|
| 123 |
+
model = model.to("cpu")
|
| 124 |
+
|
| 125 |
+
# ── Reward function ────────────────────────────────────────────────────
|
| 126 |
+
def reward_fn(completions, **kwargs):
|
| 127 |
+
"""Score completions using domain graders."""
|
| 128 |
+
rewards = []
|
| 129 |
+
prompts = kwargs.get("prompts", [])
|
| 130 |
+
for i, completion in enumerate(completions):
|
| 131 |
+
text = completion[0]["content"] if isinstance(completion, list) else str(completion)
|
| 132 |
+
text = text.strip()
|
| 133 |
+
if len(text) < 10:
|
| 134 |
+
rewards.append(0.01)
|
| 135 |
+
continue
|
| 136 |
+
task_id = None
|
| 137 |
+
if i < len(prompts):
|
| 138 |
+
p = prompts[i]
|
| 139 |
+
if isinstance(p, list):
|
| 140 |
+
p = " ".join(m.get("content", "") for m in p)
|
| 141 |
+
for tid, name in task_ids_map.items():
|
| 142 |
+
if name in str(p):
|
| 143 |
+
task_id = tid
|
| 144 |
+
break
|
| 145 |
+
if task_id is None:
|
| 146 |
+
task_id = list(task_ids_map.keys())[0]
|
| 147 |
+
try:
|
| 148 |
+
rewards.append(float(grade_answer_sync(task_id, text)))
|
| 149 |
+
except Exception:
|
| 150 |
+
rewards.append(0.01)
|
| 151 |
+
return rewards
|
| 152 |
+
|
| 153 |
+
# ── GRPO Training ──────────────────────────────────────────────────────
|
| 154 |
+
print("\n[4/6] Starting GRPO training...")
|
| 155 |
+
from trl import GRPOConfig, GRPOTrainer
|
| 156 |
+
|
| 157 |
+
training_args = GRPOConfig(
|
| 158 |
+
output_dir=str(CHECKPOINTS_DIR / "grpo"),
|
| 159 |
+
num_train_epochs=2,
|
| 160 |
+
per_device_train_batch_size=1,
|
| 161 |
+
gradient_accumulation_steps=4,
|
| 162 |
+
learning_rate=5e-6,
|
| 163 |
+
warmup_ratio=0.1,
|
| 164 |
+
max_grad_norm=1.0,
|
| 165 |
+
logging_steps=1,
|
| 166 |
+
save_steps=50,
|
| 167 |
+
save_total_limit=2,
|
| 168 |
+
bf16=False,
|
| 169 |
+
fp16=False,
|
| 170 |
+
seed=42,
|
| 171 |
+
remove_unused_columns=False,
|
| 172 |
+
num_generations=4,
|
| 173 |
+
max_completion_length=512,
|
| 174 |
+
temperature=0.7,
|
| 175 |
+
use_vllm=False,
|
| 176 |
+
report_to="none",
|
| 177 |
+
optim="adamw_torch",
|
| 178 |
+
gradient_checkpointing=True,
|
| 179 |
+
log_completions=True,
|
| 180 |
+
num_completions_to_print=1,
|
| 181 |
+
)
|
| 182 |
+
|
| 183 |
+
trainer = GRPOTrainer(
|
| 184 |
+
model=BASE_MODEL_ID,
|
| 185 |
+
reward_funcs=reward_fn,
|
| 186 |
+
args=training_args,
|
| 187 |
+
train_dataset=train_dataset,
|
| 188 |
+
peft_config=peft_config,
|
| 189 |
+
processing_class=tokenizer,
|
| 190 |
+
)
|
| 191 |
+
|
| 192 |
+
t0 = time.time()
|
| 193 |
+
trainer.train()
|
| 194 |
+
elapsed = time.time() - t0
|
| 195 |
+
print(f"\nTraining completed in {elapsed/60:.1f} min")
|
| 196 |
+
|
| 197 |
+
# ── Post-training evaluation ──────────────────────────────────────────
|
| 198 |
+
print("\n[5/6] Evaluating trained model...")
|
| 199 |
+
trained_results = evaluate_model_on_tasks(
|
| 200 |
+
trainer.model, tokenizer, eval_prompts, max_new_tokens=512, temperature=0.1,
|
| 201 |
+
)
|
| 202 |
+
trained_mean = np.mean([r["score"] for r in trained_results])
|
| 203 |
+
print(f" Trained mean score: {trained_mean:.4f}")
|
| 204 |
+
print(f" Improvement: {trained_mean - baseline_mean:+.4f}")
|
| 205 |
+
|
| 206 |
+
# ── Plots + save ──────────────────────────────────────────────────────
|
| 207 |
+
print("\n[6/6] Generating plots and saving...")
|
| 208 |
+
log_history = trainer.state.log_history if hasattr(trainer, "state") else []
|
| 209 |
+
plot_training_curves(log_history)
|
| 210 |
+
plot_baseline_vs_trained(baseline_results, trained_results)
|
| 211 |
+
plot_reward_distribution(baseline_results, trained_results)
|
| 212 |
+
save_eval_results(baseline_results, trained_results)
|
| 213 |
+
|
| 214 |
+
# Push
|
| 215 |
+
if HF_TOKEN:
|
| 216 |
+
print(f"\nPushing to HF Hub: {FINETUNED_MODEL_ID}")
|
| 217 |
+
trainer.model.push_to_hub(FINETUNED_MODEL_ID, token=HF_TOKEN, private=False)
|
| 218 |
+
tokenizer.push_to_hub(FINETUNED_MODEL_ID, token=HF_TOKEN, private=False)
|
| 219 |
+
print("Model pushed successfully")
|
| 220 |
+
|
| 221 |
+
print(f"\n{'='*50}")
|
| 222 |
+
print(f" Base model: {BASE_MODEL_ID}")
|
| 223 |
+
print(f" Training time: {elapsed/60:.1f} min")
|
| 224 |
+
print(f" Baseline score: {baseline_mean:.4f}")
|
| 225 |
+
print(f" Trained score: {trained_mean:.4f}")
|
| 226 |
+
print(f" Delta: {trained_mean - baseline_mean:+.4f}")
|
| 227 |
+
print(f"{'='*50}")
|
training/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
"""Training utilities for Agentic RAG Gym GRPO fine-tuning."""
|
training/config.py
ADDED
|
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Configuration for GRPO training of Agentic RAG Gym.
|
| 3 |
+
All secrets and tunables are loaded from environment variables / .env file.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
from __future__ import annotations
|
| 7 |
+
|
| 8 |
+
import os
|
| 9 |
+
from pathlib import Path
|
| 10 |
+
|
| 11 |
+
from dotenv import load_dotenv
|
| 12 |
+
|
| 13 |
+
load_dotenv()
|
| 14 |
+
|
| 15 |
+
# ---------------------------------------------------------------------------
|
| 16 |
+
# Paths
|
| 17 |
+
# ---------------------------------------------------------------------------
|
| 18 |
+
PROJECT_ROOT = Path(__file__).resolve().parent.parent
|
| 19 |
+
PLOTS_DIR = PROJECT_ROOT / "plots"
|
| 20 |
+
PLOTS_DIR.mkdir(exist_ok=True)
|
| 21 |
+
CHECKPOINTS_DIR = PROJECT_ROOT / "checkpoints"
|
| 22 |
+
CHECKPOINTS_DIR.mkdir(exist_ok=True)
|
| 23 |
+
|
| 24 |
+
# ---------------------------------------------------------------------------
|
| 25 |
+
# Environment / secrets
|
| 26 |
+
# ---------------------------------------------------------------------------
|
| 27 |
+
HF_TOKEN: str = os.getenv("HF_TOKEN", "")
|
| 28 |
+
HF_USERNAME: str = os.getenv("HF_USERNAME", "williyam")
|
| 29 |
+
|
| 30 |
+
# Environment server (our FastAPI gym)
|
| 31 |
+
ENV_URL: str = os.getenv("ENV_URL", "http://localhost:7860")
|
| 32 |
+
|
| 33 |
+
# ---------------------------------------------------------------------------
|
| 34 |
+
# Model configuration
|
| 35 |
+
# ---------------------------------------------------------------------------
|
| 36 |
+
BASE_MODEL_ID: str = os.getenv("BASE_MODEL_ID", "Qwen/Qwen2.5-0.5B-Instruct")
|
| 37 |
+
FINETUNED_MODEL_ID: str = os.getenv(
|
| 38 |
+
"FINETUNED_MODEL_ID",
|
| 39 |
+
f"{HF_USERNAME}/agentic-rag-aerospace-grpo",
|
| 40 |
+
)
|
training/dataset.py
ADDED
|
@@ -0,0 +1,118 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Dataset builder for GRPO training.
|
| 3 |
+
|
| 4 |
+
Connects to the live Agentic RAG Gym environment to build training prompts.
|
| 5 |
+
Each prompt = system instruction + task description + retrieved documents.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
from __future__ import annotations
|
| 9 |
+
|
| 10 |
+
import asyncio
|
| 11 |
+
import re
|
| 12 |
+
from typing import Any, Dict, List
|
| 13 |
+
|
| 14 |
+
import httpx
|
| 15 |
+
from datasets import Dataset
|
| 16 |
+
|
| 17 |
+
from training.config import ENV_URL
|
| 18 |
+
|
| 19 |
+
SYSTEM_PROMPT = (
|
| 20 |
+
"You are an expert aerospace research analyst with deep knowledge of "
|
| 21 |
+
"propulsion systems, orbital mechanics, materials science, thermal protection, "
|
| 22 |
+
"life support systems, and space mission design. When analyzing aerospace topics:\n"
|
| 23 |
+
"- Cite specific data points and numerical values from provided documents\n"
|
| 24 |
+
"- Structure your analysis with clear sections\n"
|
| 25 |
+
"- Compare alternatives with quantitative evidence\n"
|
| 26 |
+
"- Provide actionable recommendations grounded in engineering constraints\n"
|
| 27 |
+
)
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def _format_prompt(task: Dict[str, Any], docs: List[Dict[str, Any]]) -> str:
|
| 31 |
+
"""Build a user message from a task + retrieved docs."""
|
| 32 |
+
doc_text = ""
|
| 33 |
+
for i, doc in enumerate(docs, 1):
|
| 34 |
+
src = doc.get("source", "unknown")
|
| 35 |
+
doc_text += f"\n[Document {i} -- {src}]\n{doc['content']}\n"
|
| 36 |
+
|
| 37 |
+
return (
|
| 38 |
+
f"## Task: {task['name']}\n\n"
|
| 39 |
+
f"{task['description']}\n\n"
|
| 40 |
+
f"### Retrieved Reference Documents\n{doc_text}\n"
|
| 41 |
+
"### Instructions\n"
|
| 42 |
+
"Provide a comprehensive, well-structured answer to the task above. "
|
| 43 |
+
"Cite specific data from the reference documents. "
|
| 44 |
+
"Include quantitative evidence and clear recommendations."
|
| 45 |
+
)
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
async def _fetch_tasks(client: httpx.AsyncClient) -> List[Dict[str, Any]]:
|
| 49 |
+
resp = await client.get(f"{ENV_URL}/tasks")
|
| 50 |
+
resp.raise_for_status()
|
| 51 |
+
return resp.json()["tasks"]
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def _query_variants(task: Dict[str, Any]) -> List[str]:
|
| 55 |
+
"""Generate diverse retrieval queries from a task description."""
|
| 56 |
+
desc = task["description"]
|
| 57 |
+
sentences = [s.strip() for s in re.split(r'[.!?]+', desc) if len(s.strip()) > 20]
|
| 58 |
+
variants = [desc]
|
| 59 |
+
variants.extend(sentences[:4])
|
| 60 |
+
name_words = task["name"].lower()
|
| 61 |
+
variants.append(name_words)
|
| 62 |
+
return variants
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
async def _collect_one_prompt(
|
| 66 |
+
client: httpx.AsyncClient,
|
| 67 |
+
task: Dict[str, Any],
|
| 68 |
+
query: str,
|
| 69 |
+
) -> Dict[str, Any] | None:
|
| 70 |
+
"""Reset env, retrieve docs, build a prompt."""
|
| 71 |
+
resp = await client.post(f"{ENV_URL}/reset", json={"task_id": task["task_id"]})
|
| 72 |
+
if resp.status_code != 200:
|
| 73 |
+
return None
|
| 74 |
+
|
| 75 |
+
resp = await client.post(
|
| 76 |
+
f"{ENV_URL}/step", json={"type": "retrieve", "query": query}
|
| 77 |
+
)
|
| 78 |
+
if resp.status_code != 200:
|
| 79 |
+
return None
|
| 80 |
+
|
| 81 |
+
docs = resp.json()["observation"]["retrieved_docs"]
|
| 82 |
+
user_msg = _format_prompt(task, docs)
|
| 83 |
+
return {
|
| 84 |
+
"task_id": task["task_id"],
|
| 85 |
+
"task_name": task["name"],
|
| 86 |
+
"difficulty": task.get("difficulty", "easy"),
|
| 87 |
+
"prompt": user_msg,
|
| 88 |
+
}
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
async def _build_dataset_async(num_per_task: int = 8) -> List[Dict[str, Any]]:
|
| 92 |
+
records: List[Dict[str, Any]] = []
|
| 93 |
+
async with httpx.AsyncClient(timeout=120.0) as client:
|
| 94 |
+
tasks = await _fetch_tasks(client)
|
| 95 |
+
for task in tasks:
|
| 96 |
+
variants = _query_variants(task)
|
| 97 |
+
for i in range(num_per_task):
|
| 98 |
+
query = variants[i % len(variants)]
|
| 99 |
+
rec = await _collect_one_prompt(client, task, query)
|
| 100 |
+
if rec:
|
| 101 |
+
records.append(rec)
|
| 102 |
+
return records
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
def build_grpo_dataset(num_per_task: int = 8, seed: int = 42) -> Dataset:
|
| 106 |
+
"""
|
| 107 |
+
Build a HuggingFace Dataset of prompts for GRPO training.
|
| 108 |
+
|
| 109 |
+
Each row: prompt (str), task_id, task_name, difficulty.
|
| 110 |
+
Prompts are collected from the LIVE environment (reset + retrieve).
|
| 111 |
+
"""
|
| 112 |
+
records = asyncio.run(_build_dataset_async(num_per_task))
|
| 113 |
+
if not records:
|
| 114 |
+
raise RuntimeError(f"No prompts collected. Is the environment running at {ENV_URL}?")
|
| 115 |
+
ds = Dataset.from_list(records)
|
| 116 |
+
ds = ds.shuffle(seed=seed)
|
| 117 |
+
print(f"Built GRPO dataset: {len(ds)} prompts across {len(ds.unique('task_id'))} tasks")
|
| 118 |
+
return ds
|
training/evaluate.py
ADDED
|
@@ -0,0 +1,192 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Evaluation & plotting utilities for GRPO training.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
from __future__ import annotations
|
| 6 |
+
|
| 7 |
+
import json
|
| 8 |
+
from collections import defaultdict
|
| 9 |
+
from pathlib import Path
|
| 10 |
+
from typing import Any, Dict, List
|
| 11 |
+
|
| 12 |
+
import matplotlib
|
| 13 |
+
matplotlib.use("Agg")
|
| 14 |
+
import matplotlib.pyplot as plt
|
| 15 |
+
import numpy as np
|
| 16 |
+
|
| 17 |
+
from training.config import PLOTS_DIR
|
| 18 |
+
from training.reward import grade_answer_sync
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
# ── Model evaluation ────────────────────────────────────────────────────
|
| 22 |
+
|
| 23 |
+
def evaluate_model_on_tasks(
|
| 24 |
+
model,
|
| 25 |
+
tokenizer,
|
| 26 |
+
prompts: List[Dict[str, Any]],
|
| 27 |
+
max_new_tokens: int = 512,
|
| 28 |
+
temperature: float = 0.1,
|
| 29 |
+
) -> List[Dict[str, Any]]:
|
| 30 |
+
"""Generate answers for each prompt and grade them with domain graders."""
|
| 31 |
+
import torch
|
| 32 |
+
results: List[Dict[str, Any]] = []
|
| 33 |
+
device = next(model.parameters()).device
|
| 34 |
+
|
| 35 |
+
for item in prompts:
|
| 36 |
+
messages = [
|
| 37 |
+
{"role": "system", "content": "You are an expert aerospace research analyst. Provide comprehensive answers citing specific data."},
|
| 38 |
+
{"role": "user", "content": item["prompt"]},
|
| 39 |
+
]
|
| 40 |
+
text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
|
| 41 |
+
inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=2048)
|
| 42 |
+
inputs = {k: v.to(device) for k, v in inputs.items()}
|
| 43 |
+
|
| 44 |
+
with torch.no_grad():
|
| 45 |
+
output_ids = model.generate(
|
| 46 |
+
**inputs,
|
| 47 |
+
max_new_tokens=max_new_tokens,
|
| 48 |
+
temperature=max(temperature, 0.01),
|
| 49 |
+
do_sample=temperature > 0,
|
| 50 |
+
top_p=0.9,
|
| 51 |
+
pad_token_id=tokenizer.pad_token_id or tokenizer.eos_token_id,
|
| 52 |
+
)
|
| 53 |
+
|
| 54 |
+
generated = output_ids[0][inputs["input_ids"].shape[1]:]
|
| 55 |
+
answer = tokenizer.decode(generated, skip_special_tokens=True).strip()
|
| 56 |
+
score = grade_answer_sync(item["task_id"], answer)
|
| 57 |
+
|
| 58 |
+
results.append({
|
| 59 |
+
"task_id": item["task_id"],
|
| 60 |
+
"task_name": item["task_name"],
|
| 61 |
+
"difficulty": item["difficulty"],
|
| 62 |
+
"answer": answer[:500],
|
| 63 |
+
"score": score,
|
| 64 |
+
})
|
| 65 |
+
print(f" [{item['task_id']}] score={score:.3f} len={len(answer)}")
|
| 66 |
+
return results
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
# ── Plotting ─────────────────────────────────────────────────────────────
|
| 70 |
+
|
| 71 |
+
def plot_training_curves(log_history: List[Dict[str, Any]], out_dir: Path = PLOTS_DIR) -> Path:
|
| 72 |
+
"""Plot training loss and reward curves. Returns path to saved figure."""
|
| 73 |
+
steps = [e["step"] for e in log_history if "loss" in e]
|
| 74 |
+
losses = [e["loss"] for e in log_history if "loss" in e]
|
| 75 |
+
reward_steps = [e["step"] for e in log_history if "reward" in e or "reward/mean" in e]
|
| 76 |
+
rewards = [e.get("reward/mean", e.get("reward")) for e in log_history if "reward" in e or "reward/mean" in e]
|
| 77 |
+
|
| 78 |
+
fig, axes = plt.subplots(1, 2, figsize=(14, 5))
|
| 79 |
+
|
| 80 |
+
ax = axes[0]
|
| 81 |
+
if steps and losses:
|
| 82 |
+
ax.plot(steps, losses, color="#D4AF37", linewidth=2)
|
| 83 |
+
ax.set_xlabel("Training Step", fontsize=12)
|
| 84 |
+
ax.set_ylabel("Loss", fontsize=12)
|
| 85 |
+
ax.set_title("GRPO Training Loss", fontsize=14, fontweight="bold")
|
| 86 |
+
ax.grid(True, alpha=0.3)
|
| 87 |
+
|
| 88 |
+
ax = axes[1]
|
| 89 |
+
if reward_steps and rewards:
|
| 90 |
+
ax.plot(reward_steps, rewards, color="#4CAF50", linewidth=2)
|
| 91 |
+
ax.set_xlabel("Training Step", fontsize=12)
|
| 92 |
+
ax.set_ylabel("Mean Reward (grader score)", fontsize=12)
|
| 93 |
+
ax.set_title("GRPO Mean Reward", fontsize=14, fontweight="bold")
|
| 94 |
+
ax.grid(True, alpha=0.3)
|
| 95 |
+
|
| 96 |
+
plt.tight_layout()
|
| 97 |
+
path = out_dir / "training_curves.png"
|
| 98 |
+
fig.savefig(path, dpi=150, bbox_inches="tight")
|
| 99 |
+
plt.close(fig)
|
| 100 |
+
print(f"Saved training curves -> {path}")
|
| 101 |
+
return path
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
def plot_baseline_vs_trained(
|
| 105 |
+
baseline_results: List[Dict[str, Any]],
|
| 106 |
+
trained_results: List[Dict[str, Any]],
|
| 107 |
+
out_dir: Path = PLOTS_DIR,
|
| 108 |
+
) -> Path:
|
| 109 |
+
"""Bar chart comparing baseline vs trained scores per task."""
|
| 110 |
+
def _agg(results):
|
| 111 |
+
sums: Dict[str, List[float]] = defaultdict(list)
|
| 112 |
+
for r in results:
|
| 113 |
+
sums[r["task_id"]].append(r["score"])
|
| 114 |
+
return {k: float(np.mean(v)) for k, v in sums.items()}
|
| 115 |
+
|
| 116 |
+
baseline_agg = _agg(baseline_results)
|
| 117 |
+
trained_agg = _agg(trained_results)
|
| 118 |
+
tasks = sorted(set(baseline_agg) | set(trained_agg))
|
| 119 |
+
short_names = [t.replace("aero_", "").replace("_", " ").title() for t in tasks]
|
| 120 |
+
|
| 121 |
+
x = np.arange(len(tasks))
|
| 122 |
+
width = 0.35
|
| 123 |
+
fig, ax = plt.subplots(figsize=(12, 6))
|
| 124 |
+
bars1 = ax.bar(x - width / 2, [baseline_agg.get(t, 0) for t in tasks],
|
| 125 |
+
width, label="Baseline (untrained)", color="#8B0000", alpha=0.85, edgecolor="black")
|
| 126 |
+
bars2 = ax.bar(x + width / 2, [trained_agg.get(t, 0) for t in tasks],
|
| 127 |
+
width, label="GRPO-trained", color="#D4AF37", alpha=0.85, edgecolor="black")
|
| 128 |
+
|
| 129 |
+
ax.set_xlabel("Task", fontsize=12)
|
| 130 |
+
ax.set_ylabel("Grader Score (0-1)", fontsize=12)
|
| 131 |
+
ax.set_title("Baseline vs. GRPO-Trained Model - Task Scores", fontsize=14, fontweight="bold")
|
| 132 |
+
ax.set_xticks(x)
|
| 133 |
+
ax.set_xticklabels(short_names, rotation=20, ha="right", fontsize=10)
|
| 134 |
+
ax.legend(fontsize=11)
|
| 135 |
+
ax.set_ylim(0, 1.0)
|
| 136 |
+
ax.grid(axis="y", alpha=0.3)
|
| 137 |
+
for bar in list(bars1) + list(bars2):
|
| 138 |
+
ax.text(bar.get_x() + bar.get_width() / 2, bar.get_height() + 0.02,
|
| 139 |
+
f"{bar.get_height():.2f}", ha="center", fontsize=9)
|
| 140 |
+
|
| 141 |
+
plt.tight_layout()
|
| 142 |
+
path = out_dir / "baseline_vs_trained.png"
|
| 143 |
+
fig.savefig(path, dpi=150, bbox_inches="tight")
|
| 144 |
+
plt.close(fig)
|
| 145 |
+
print(f"Saved comparison chart -> {path}")
|
| 146 |
+
return path
|
| 147 |
+
|
| 148 |
+
|
| 149 |
+
def plot_reward_distribution(
|
| 150 |
+
baseline_results: List[Dict[str, Any]],
|
| 151 |
+
trained_results: List[Dict[str, Any]],
|
| 152 |
+
out_dir: Path = PLOTS_DIR,
|
| 153 |
+
) -> Path:
|
| 154 |
+
"""Histogram of score distributions."""
|
| 155 |
+
fig, ax = plt.subplots(figsize=(10, 5))
|
| 156 |
+
bins = np.linspace(0, 1, 21)
|
| 157 |
+
ax.hist([r["score"] for r in baseline_results], bins=bins, alpha=0.6,
|
| 158 |
+
label="Baseline", color="#8B0000", edgecolor="black")
|
| 159 |
+
ax.hist([r["score"] for r in trained_results], bins=bins, alpha=0.6,
|
| 160 |
+
label="GRPO-trained", color="#D4AF37", edgecolor="black")
|
| 161 |
+
ax.set_xlabel("Grader Score", fontsize=12)
|
| 162 |
+
ax.set_ylabel("Frequency", fontsize=12)
|
| 163 |
+
ax.set_title("Score Distribution - Baseline vs. GRPO-Trained", fontsize=14, fontweight="bold")
|
| 164 |
+
ax.legend(fontsize=11)
|
| 165 |
+
ax.grid(axis="y", alpha=0.3)
|
| 166 |
+
plt.tight_layout()
|
| 167 |
+
path = out_dir / "score_distribution.png"
|
| 168 |
+
fig.savefig(path, dpi=150, bbox_inches="tight")
|
| 169 |
+
plt.close(fig)
|
| 170 |
+
print(f"Saved distribution plot -> {path}")
|
| 171 |
+
return path
|
| 172 |
+
|
| 173 |
+
|
| 174 |
+
def save_eval_results(
|
| 175 |
+
baseline_results: List[Dict[str, Any]],
|
| 176 |
+
trained_results: List[Dict[str, Any]],
|
| 177 |
+
out_dir: Path = PLOTS_DIR,
|
| 178 |
+
) -> Path:
|
| 179 |
+
"""Save evaluation results as JSON."""
|
| 180 |
+
data = {
|
| 181 |
+
"baseline": baseline_results,
|
| 182 |
+
"trained": trained_results,
|
| 183 |
+
"summary": {
|
| 184 |
+
"baseline_mean": float(np.mean([r["score"] for r in baseline_results])) if baseline_results else 0,
|
| 185 |
+
"trained_mean": float(np.mean([r["score"] for r in trained_results])) if trained_results else 0,
|
| 186 |
+
"improvement": float(np.mean([r["score"] for r in trained_results]) - np.mean([r["score"] for r in baseline_results])) if baseline_results and trained_results else 0,
|
| 187 |
+
},
|
| 188 |
+
}
|
| 189 |
+
path = out_dir / "eval_results.json"
|
| 190 |
+
path.write_text(json.dumps(data, indent=2))
|
| 191 |
+
print(f"Saved eval results -> {path}")
|
| 192 |
+
return path
|
training/reward.py
ADDED
|
@@ -0,0 +1,87 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Reward functions for GRPO training.
|
| 3 |
+
|
| 4 |
+
Uses the real domain graders from the Agentic RAG Gym so the RL signal
|
| 5 |
+
matches the actual evaluation rubric — not a proxy.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
from __future__ import annotations
|
| 9 |
+
|
| 10 |
+
import asyncio
|
| 11 |
+
from datetime import datetime, timezone
|
| 12 |
+
from typing import Any, Dict, List
|
| 13 |
+
|
| 14 |
+
from domains.aerospace.graders import GRADER_REGISTRY
|
| 15 |
+
from rag_master.models import EpisodeState, StepRecord, Trajectory
|
| 16 |
+
from rag_master.rewards import _SCORE_MIN
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def _make_dummy_state(task_id: str, answer: str) -> EpisodeState:
|
| 20 |
+
"""Minimal EpisodeState for offline grading."""
|
| 21 |
+
from domains.aerospace.config import AerospaceDomainConfig
|
| 22 |
+
|
| 23 |
+
domain = AerospaceDomainConfig()
|
| 24 |
+
tasks = {t.task_id: t for t in domain.get_tasks()}
|
| 25 |
+
task = tasks.get(task_id)
|
| 26 |
+
if task is None:
|
| 27 |
+
raise ValueError(f"Unknown task_id: {task_id}")
|
| 28 |
+
|
| 29 |
+
return EpisodeState(
|
| 30 |
+
episode_id="grpo-eval",
|
| 31 |
+
task=task,
|
| 32 |
+
current_step=5,
|
| 33 |
+
query_history=["query"],
|
| 34 |
+
retrieved_docs=[],
|
| 35 |
+
agent_messages=[],
|
| 36 |
+
generated_answer=answer,
|
| 37 |
+
intermediate_rewards=[0.5] * 5,
|
| 38 |
+
done=True,
|
| 39 |
+
info={},
|
| 40 |
+
)
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def _make_dummy_trajectory(task_id: str) -> Trajectory:
|
| 44 |
+
"""Minimal trajectory with a good action sequence for process scoring."""
|
| 45 |
+
now = datetime.now(timezone.utc)
|
| 46 |
+
steps = [
|
| 47 |
+
StepRecord(step_index=0, action_type="plan", action_payload={},
|
| 48 |
+
observation_summary="planned", intermediate_reward=0.5,
|
| 49 |
+
reasoning_trace="Planning approach.", timestamp=now),
|
| 50 |
+
StepRecord(step_index=1, action_type="retrieve", action_payload={},
|
| 51 |
+
observation_summary="retrieved", intermediate_reward=0.6,
|
| 52 |
+
reasoning_trace="Retrieving.", timestamp=now),
|
| 53 |
+
StepRecord(step_index=2, action_type="reason", action_payload={},
|
| 54 |
+
observation_summary="reasoned", intermediate_reward=0.5,
|
| 55 |
+
reasoning_trace="Analyzing because data is relevant.", timestamp=now),
|
| 56 |
+
StepRecord(step_index=3, action_type="answer", action_payload={},
|
| 57 |
+
observation_summary="answered", intermediate_reward=0.6,
|
| 58 |
+
reasoning_trace="Final answer.", timestamp=now),
|
| 59 |
+
StepRecord(step_index=4, action_type="verify", action_payload={},
|
| 60 |
+
observation_summary="verified", intermediate_reward=0.5,
|
| 61 |
+
reasoning_trace="Verifying.", timestamp=now),
|
| 62 |
+
]
|
| 63 |
+
return Trajectory(
|
| 64 |
+
episode_id="grpo-eval", task_id=task_id, steps=steps,
|
| 65 |
+
total_reward=0.0, final_score=0.0, completed=True, metadata={},
|
| 66 |
+
)
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
def grade_answer_sync(task_id: str, answer: str) -> float:
|
| 70 |
+
"""Grade a single answer using the domain grader (synchronous)."""
|
| 71 |
+
grader_cls = GRADER_REGISTRY.get(task_id)
|
| 72 |
+
if grader_cls is None:
|
| 73 |
+
return float(_SCORE_MIN)
|
| 74 |
+
grader = grader_cls()
|
| 75 |
+
state = _make_dummy_state(task_id, answer)
|
| 76 |
+
trajectory = _make_dummy_trajectory(task_id)
|
| 77 |
+
|
| 78 |
+
try:
|
| 79 |
+
loop = asyncio.get_running_loop()
|
| 80 |
+
except RuntimeError:
|
| 81 |
+
loop = None
|
| 82 |
+
|
| 83 |
+
if loop and loop.is_running():
|
| 84 |
+
import concurrent.futures
|
| 85 |
+
with concurrent.futures.ThreadPoolExecutor() as pool:
|
| 86 |
+
return pool.submit(asyncio.run, grader.grade(state, trajectory)).result()
|
| 87 |
+
return asyncio.run(grader.grade(state, trajectory))
|