deepakdsoni commited on
Commit
b7f4634
·
verified ·
1 Parent(s): d7eb6ca

Initial upload: DPMM-0.1B-MoE (124.5M params, 16/16 validation pass)

Browse files
README.md ADDED
@@ -0,0 +1,177 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ language:
3
+ - en
4
+ license: apache-2.0
5
+ tags:
6
+ - mixture-of-experts
7
+ - moe
8
+ - causal-lm
9
+ - custom-architecture
10
+ - from-scratch
11
+ - gqa
12
+ - rope
13
+ - swiglu
14
+ - dora
15
+ - small-model
16
+ - educational
17
+ library_name: transformers
18
+ pipeline_tag: text-generation
19
+ model-index:
20
+ - name: DPMM-0.1B-MoE
21
+ results:
22
+ - task:
23
+ type: text-generation
24
+ metrics:
25
+ - name: Validation Pass Rate
26
+ type: accuracy
27
+ value: 100
28
+ verified: false
29
+ ---
30
+
31
+ # DPMM-0.1B-MoE
32
+
33
+ A 124.5M parameter Mixture-of-Experts language model trained from scratch with production-grade architecture techniques.
34
+
35
+ ## Model Description
36
+
37
+ DPMM (Differentiable Probabilistic Mixture Model) is a custom Transformer + MoE architecture implementing state-of-the-art techniques from DeepSeek-V3, Gemma 2, Qwen3, and Llama 3. Built as an educational reference for the AI community — demonstrating that the **entire LLM training pipeline** (pre-training, SFT, alignment, safety) can be implemented from scratch on modest hardware.
38
+
39
+ ### Architecture
40
+
41
+ | Component | Specification |
42
+ |-----------|---------------|
43
+ | Parameters | 124.5M total |
44
+ | Hidden Size | 512 |
45
+ | Layers | 8 |
46
+ | Attention | GQA (8 heads, 2 KV heads) |
47
+ | Head Dim | 64 |
48
+ | FFN | SwiGLU (1408 intermediate) |
49
+ | Experts | 4 routed + 1 shared |
50
+ | Top-K | 2 experts per token |
51
+ | Routing | DeepSeek-V3 auxiliary-loss-free |
52
+ | Position | RoPE (theta=500K) |
53
+ | Norm | RMSNorm + QK-Norm |
54
+ | Vocab | 32,000 (SentencePiece) |
55
+ | Max Seq | 2,048 tokens |
56
+
57
+ ### Key Techniques
58
+
59
+ - **Grouped Query Attention (GQA)** — 4:1 Q/KV ratio reduces KV cache by 4x
60
+ - **QK-Norm** — Per-head RMS normalization prevents attention logit growth (Gemma 2, DeepSeek-V3)
61
+ - **Auxiliary-Loss-Free Routing** — Expert load balancing via bias adjustment, not auxiliary loss (DeepSeek-V3)
62
+ - **SwiGLU Activation** — Gate + Up + Down projection (Llama/Mixtral/Qwen3)
63
+ - **Embedding Scaling** — Multiply embeddings by sqrt(d_model) (Gemma, Qwen3)
64
+ - **Residual Scaling** — Output projections scaled by 1/sqrt(2L) for training stability
65
+ - **RoPE** — Rotary Position Embeddings with high theta (500K) for length extrapolation
66
+ - **DoRA + RS-LoRA** — Weight-Decomposed Rank-Stabilized adaptation for fine-tuning
67
+
68
+ ## Training
69
+
70
+ ### Phase 1 — Combined SFT (~60 min on 2x A10)
71
+
72
+ | Dataset | Examples | Purpose |
73
+ |---------|----------|---------|
74
+ | Alpaca | 10,000 | General instruction following |
75
+ | Code/DevOps | 800 | Python, Kubernetes, Docker, CUDA, CI/CD |
76
+ | Customer Support | 800 | Ticket classification, troubleshooting |
77
+ | Legal | 800 | Contract analysis, compliance, IP |
78
+ | Finance | 800 | ROI, portfolio, risk analysis |
79
+
80
+ Loss: 2.73 → 1.74 | LR: 1e-5 | 5 epochs
81
+
82
+ ### Phase 2 — Balanced Alignment (~10 min on 2x A10)
83
+
84
+ | Dataset | Examples | % of Total | Purpose |
85
+ |---------|----------|------------|---------|
86
+ | Guard/Safety | 800 | 29% | PII detection, injection blocking |
87
+ | Domain Replay | 1,120 | 40% | Preserve Phase 1 capabilities |
88
+ | Reasoning (CoT) | 480 | 17% | Chain-of-thought math |
89
+ | Constitutional AI | 400 | 14% | Harmful request refusal |
90
+
91
+ Loss: 4.10 → 0.22 | LR: 3e-6 (cosine decay) | 4 epochs
92
+
93
+ **Key technique:** Domain Replay (40% of Phase 2 data) prevents catastrophic forgetting in small models.
94
+
95
+ ## Validation Results
96
+
97
+ **16/16 tests passing (100%)** across 9 capability categories:
98
+
99
+ | Capability | Tests | Status |
100
+ |------------|-------|--------|
101
+ | General Chat | 2 | PASS |
102
+ | Code/DevOps | 2 | PASS |
103
+ | Customer Support | 2 | PASS |
104
+ | Legal | 1 | PASS |
105
+ | Finance | 1 | PASS |
106
+ | Reasoning (CoT) | 2 | PASS |
107
+ | Multilingual | 2 | PASS |
108
+ | Guard/Safety | 2 | PASS |
109
+ | Constitutional AI | 2 | PASS |
110
+
111
+ ## Usage
112
+
113
+ ```python
114
+ from transformers import AutoModelForCausalLM, AutoTokenizer
115
+
116
+ model = AutoModelForCausalLM.from_pretrained(
117
+ "deepakdsoni/DPMM-0.1B-MoE",
118
+ trust_remote_code=True,
119
+ torch_dtype="auto",
120
+ )
121
+ tokenizer = AutoTokenizer.from_pretrained("deepakdsoni/DPMM-0.1B-MoE")
122
+
123
+ prompt = "### Instruction:\nExplain what a REST API is.\n\n### Response:\n"
124
+ inputs = tokenizer(prompt, return_tensors="pt")
125
+ outputs = model.generate(**inputs, max_new_tokens=256, temperature=0.7, top_p=0.9)
126
+ print(tokenizer.decode(outputs[0], skip_special_tokens=True))
127
+ ```
128
+
129
+ ### Prompt Formats
130
+
131
+ The model responds to these trained prompt templates:
132
+
133
+ ```
134
+ ### Instruction:\n{question}\n\n### Response:\n
135
+ ### Programming Question:\n{question}\n\n### Solution:\n
136
+ ### Support Ticket:\n{issue}\n\n### Agent Response:\n
137
+ ### Legal Question:\n{question}\n\n### Legal Analysis:\n
138
+ ### Finance Question:\n{question}\n\n### Analysis:\n
139
+ ### Guard Classification:\n{input}\n\n### Classification:\n
140
+ ### Constitutional Check:\n{request}\n\n### Response:\n
141
+ ```
142
+
143
+ ## Limitations
144
+
145
+ ### What 125M Parameters Can Do
146
+ - Follow specific trained prompt formats
147
+ - Produce domain-appropriate structured responses
148
+ - Classify inputs (guard, safety, priority)
149
+ - Simple mathematical reasoning with chain-of-thought
150
+ - Refuse harmful requests
151
+
152
+ ### What 125M Parameters Cannot Do
153
+ - Generalize to unseen prompt formats
154
+ - Produce long coherent text (quality degrades after ~100 tokens)
155
+ - Handle abstract reasoning or analogies
156
+ - Generate creative or novel content
157
+
158
+ ## Hardware Requirements
159
+
160
+ - **Training:** 2x NVIDIA A10 (23GB each), ~70 minutes total
161
+ - **Inference:** Any GPU with 1GB+ VRAM, or CPU (slow)
162
+ - **GGUF quantized:** Runs on consumer hardware (laptop CPU)
163
+
164
+ ## Citation
165
+
166
+ ```bibtex
167
+ @misc{dpmm-0.1b-moe-2025,
168
+ title={DPMM-0.1B-MoE: A Small Mixture-of-Experts Language Model},
169
+ author={Deepak Soni},
170
+ year={2025},
171
+ url={https://huggingface.co/deepakdsoni/DPMM-0.1B-MoE}
172
+ }
173
+ ```
174
+
175
+ ## License
176
+
177
+ Apache 2.0
config.json ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": ["DPMMForCausalLM"],
3
+ "model_type": "dpmm",
4
+ "auto_map": {
5
+ "AutoConfig": "configuration_dpmm.DPMMConfig",
6
+ "AutoModelForCausalLM": "modeling_dpmm.DPMMForCausalLM"
7
+ },
8
+ "bos_token_id": 1,
9
+ "eos_token_id": 2,
10
+ "pad_token_id": 0,
11
+ "hidden_size": 512,
12
+ "intermediate_size": 1408,
13
+ "num_attention_heads": 8,
14
+ "num_key_value_heads": 2,
15
+ "head_dim": 64,
16
+ "num_hidden_layers": 8,
17
+ "vocab_size": 32000,
18
+ "max_position_embeddings": 2048,
19
+ "rope_theta": 500000.0,
20
+ "rms_norm_eps": 1e-6,
21
+ "tie_word_embeddings": true,
22
+ "embedding_scale": true,
23
+ "qk_norm": true,
24
+ "z_loss_weight": 1e-5,
25
+ "scale_residual": true,
26
+ "moe_num_experts": 4,
27
+ "moe_num_shared_experts": 1,
28
+ "moe_top_k": 2,
29
+ "moe_router_type": "aux_loss_free",
30
+ "moe_router_bias_lr": 0.01,
31
+ "hidden_act": "silu",
32
+ "torch_dtype": "bfloat16",
33
+ "transformers_version": "4.45.0"
34
+ }
configuration_dpmm.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """DPMM-0.1B-MoE configuration for Hugging Face Transformers."""
2
+
3
+ from transformers import PretrainedConfig
4
+
5
+
6
+ class DPMMConfig(PretrainedConfig):
7
+ model_type = "dpmm"
8
+
9
+ def __init__(
10
+ self,
11
+ hidden_size=512,
12
+ intermediate_size=1408,
13
+ num_attention_heads=8,
14
+ num_key_value_heads=2,
15
+ head_dim=64,
16
+ num_hidden_layers=8,
17
+ vocab_size=32000,
18
+ max_position_embeddings=2048,
19
+ rope_theta=500000.0,
20
+ rms_norm_eps=1e-6,
21
+ tie_word_embeddings=True,
22
+ embedding_scale=True,
23
+ qk_norm=True,
24
+ z_loss_weight=1e-5,
25
+ scale_residual=True,
26
+ moe_num_experts=4,
27
+ moe_num_shared_experts=1,
28
+ moe_top_k=2,
29
+ moe_router_type="aux_loss_free",
30
+ moe_router_bias_lr=0.01,
31
+ hidden_act="silu",
32
+ bos_token_id=1,
33
+ eos_token_id=2,
34
+ pad_token_id=0,
35
+ **kwargs,
36
+ ):
37
+ self.hidden_size = hidden_size
38
+ self.intermediate_size = intermediate_size
39
+ self.num_attention_heads = num_attention_heads
40
+ self.num_key_value_heads = num_key_value_heads
41
+ self.head_dim = head_dim
42
+ self.num_hidden_layers = num_hidden_layers
43
+ self.vocab_size = vocab_size
44
+ self.max_position_embeddings = max_position_embeddings
45
+ self.rope_theta = rope_theta
46
+ self.rms_norm_eps = rms_norm_eps
47
+ self.embedding_scale = embedding_scale
48
+ self.qk_norm = qk_norm
49
+ self.z_loss_weight = z_loss_weight
50
+ self.scale_residual = scale_residual
51
+ self.moe_num_experts = moe_num_experts
52
+ self.moe_num_shared_experts = moe_num_shared_experts
53
+ self.moe_top_k = moe_top_k
54
+ self.moe_router_type = moe_router_type
55
+ self.moe_router_bias_lr = moe_router_bias_lr
56
+ self.hidden_act = hidden_act
57
+ super().__init__(
58
+ bos_token_id=bos_token_id,
59
+ eos_token_id=eos_token_id,
60
+ pad_token_id=pad_token_id,
61
+ tie_word_embeddings=tie_word_embeddings,
62
+ **kwargs,
63
+ )
generation_config.json ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "bos_token_id": 1,
3
+ "eos_token_id": 2,
4
+ "pad_token_id": 0,
5
+ "do_sample": true,
6
+ "temperature": 0.7,
7
+ "top_p": 0.9,
8
+ "top_k": 50,
9
+ "repetition_penalty": 1.1,
10
+ "max_new_tokens": 256,
11
+ "transformers_version": "4.45.0"
12
+ }
model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b5f32db5455515adb41b17dd61b59dfd9100ecebbc0f45b35d493cf35c5c0f6a
3
+ size 249110448
modeling_dpmm.py ADDED
@@ -0,0 +1,293 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """DPMM-0.1B-MoE model implementation for Hugging Face Transformers.
2
+
3
+ Architecture: Transformer + Mixture of Experts (Shared + Routed)
4
+ - GQA (Grouped Query Attention) with RoPE
5
+ - QK-Norm (Gemma 2 / DeepSeek-V3 style)
6
+ - SwiGLU experts with DeepSeek-V3 auxiliary-loss-free routing
7
+ - Embedding scaling (sqrt(d_model))
8
+ - Residual output projection scaling (1/sqrt(2L))
9
+ """
10
+
11
+ import math
12
+ from typing import Optional, Tuple, List
13
+
14
+ import torch
15
+ import torch.nn as nn
16
+ import torch.nn.functional as F
17
+ from torch import Tensor
18
+ from transformers import PreTrainedModel
19
+ from transformers.modeling_outputs import CausalLMOutputWithPast
20
+
21
+ from .configuration_dpmm import DPMMConfig
22
+
23
+
24
+ class RMSNorm(nn.Module):
25
+ def __init__(self, dim: int, eps: float = 1e-6):
26
+ super().__init__()
27
+ self.weight = nn.Parameter(torch.ones(dim))
28
+ self.eps = eps
29
+
30
+ def forward(self, x: Tensor) -> Tensor:
31
+ norm = x.float().pow(2).mean(-1, keepdim=True).add(self.eps).rsqrt()
32
+ return (x * norm).to(x.dtype) * self.weight
33
+
34
+
35
+ def precompute_rope_freqs(dim: int, max_seq_len: int, theta: float = 500000.0):
36
+ freqs = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
37
+ t = torch.arange(max_seq_len, dtype=torch.float32)
38
+ angles = torch.outer(t, freqs)
39
+ return angles.cos(), angles.sin()
40
+
41
+
42
+ def _rotate_half(x: Tensor) -> Tensor:
43
+ x1 = x[..., : x.shape[-1] // 2]
44
+ x2 = x[..., x.shape[-1] // 2 :]
45
+ return torch.cat((-x2, x1), dim=-1)
46
+
47
+
48
+ def apply_rope(x: Tensor, rope_cos: Tensor, rope_sin: Tensor) -> Tensor:
49
+ seq_len = x.shape[1]
50
+ cos = rope_cos[:seq_len].unsqueeze(0).unsqueeze(2)
51
+ sin = rope_sin[:seq_len].unsqueeze(0).unsqueeze(2)
52
+ cos = torch.cat([cos, cos], dim=-1)
53
+ sin = torch.cat([sin, sin], dim=-1)
54
+ return (x.float() * cos + _rotate_half(x.float()) * sin).to(x.dtype)
55
+
56
+
57
+ def repeat_kv(x: Tensor, n_rep: int) -> Tensor:
58
+ if n_rep == 1:
59
+ return x
60
+ bs, seq, n_kv, d = x.shape
61
+ return x[:, :, :, None, :].expand(bs, seq, n_kv, n_rep, d).reshape(bs, seq, n_kv * n_rep, d)
62
+
63
+
64
+ class HeadRMSNorm(nn.Module):
65
+ def __init__(self, d_head: int, eps: float = 1e-6):
66
+ super().__init__()
67
+ self.weight = nn.Parameter(torch.ones(d_head))
68
+ self.eps = eps
69
+
70
+ def forward(self, x: Tensor) -> Tensor:
71
+ norm = x.float().pow(2).mean(-1, keepdim=True).add(self.eps).rsqrt()
72
+ return (x * norm).to(x.dtype) * self.weight
73
+
74
+
75
+ class GQAttention(nn.Module):
76
+ def __init__(self, config: DPMMConfig):
77
+ super().__init__()
78
+ self.n_heads = config.num_attention_heads
79
+ self.n_kv_heads = config.num_key_value_heads
80
+ self.d_head = config.head_dim
81
+ self.n_rep = self.n_heads // self.n_kv_heads
82
+
83
+ self.wq = nn.Linear(config.hidden_size, self.n_heads * self.d_head, bias=False)
84
+ self.wk = nn.Linear(config.hidden_size, self.n_kv_heads * self.d_head, bias=False)
85
+ self.wv = nn.Linear(config.hidden_size, self.n_kv_heads * self.d_head, bias=False)
86
+ self.wo = nn.Linear(self.n_heads * self.d_head, config.hidden_size, bias=False)
87
+
88
+ self.q_norm = HeadRMSNorm(self.d_head) if config.qk_norm else None
89
+ self.k_norm = HeadRMSNorm(self.d_head) if config.qk_norm else None
90
+
91
+ def forward(self, x: Tensor, rope_cos: Tensor, rope_sin: Tensor,
92
+ mask: Optional[Tensor] = None) -> Tensor:
93
+ bs, seq_len, _ = x.shape
94
+
95
+ q = self.wq(x).view(bs, seq_len, self.n_heads, self.d_head)
96
+ k = self.wk(x).view(bs, seq_len, self.n_kv_heads, self.d_head)
97
+ v = self.wv(x).view(bs, seq_len, self.n_kv_heads, self.d_head)
98
+
99
+ if self.q_norm is not None:
100
+ q = self.q_norm(q)
101
+ k = self.k_norm(k)
102
+
103
+ q = apply_rope(q, rope_cos, rope_sin)
104
+ k = apply_rope(k, rope_cos, rope_sin)
105
+
106
+ k = repeat_kv(k, self.n_rep)
107
+ v = repeat_kv(v, self.n_rep)
108
+
109
+ q = q.transpose(1, 2)
110
+ k = k.transpose(1, 2)
111
+ v = v.transpose(1, 2)
112
+ scale = 1.0 / math.sqrt(self.d_head)
113
+ scores = torch.matmul(q, k.transpose(-2, -1)) * scale
114
+ if mask is not None:
115
+ scores = scores + mask
116
+ attn = torch.softmax(scores, dim=-1)
117
+ out = torch.matmul(attn, v)
118
+ out = out.transpose(1, 2).contiguous()
119
+ return self.wo(out.reshape(bs, seq_len, -1))
120
+
121
+
122
+ class SwiGLUExpert(nn.Module):
123
+ def __init__(self, d_model: int, d_ffn: int):
124
+ super().__init__()
125
+ self.w_gate = nn.Linear(d_model, d_ffn, bias=False)
126
+ self.w_up = nn.Linear(d_model, d_ffn, bias=False)
127
+ self.w_down = nn.Linear(d_ffn, d_model, bias=False)
128
+
129
+ def forward(self, x: Tensor) -> Tensor:
130
+ return self.w_down(F.silu(self.w_gate(x)) * self.w_up(x))
131
+
132
+
133
+ class MoERouter(nn.Module):
134
+ def __init__(self, config: DPMMConfig):
135
+ super().__init__()
136
+ self.n_experts = config.moe_num_experts
137
+ self.top_k = config.moe_top_k
138
+ self.gate = nn.Linear(config.hidden_size, config.moe_num_experts, bias=False)
139
+ self.register_buffer("expert_bias", torch.zeros(config.moe_num_experts))
140
+
141
+ def forward(self, x: Tensor) -> Tuple[Tensor, Tensor]:
142
+ logits = self.gate(x)
143
+ scores = F.softmax(logits, dim=-1)
144
+ adjusted = scores + self.expert_bias.detach()
145
+ top_k_vals, top_k_idx = torch.topk(adjusted, self.top_k, dim=-1)
146
+ top_k_weights = torch.gather(scores, 1, top_k_idx)
147
+ top_k_weights = top_k_weights / (top_k_weights.sum(dim=-1, keepdim=True) + 1e-8)
148
+ return top_k_weights, top_k_idx
149
+
150
+
151
+ class MoELayer(nn.Module):
152
+ def __init__(self, config: DPMMConfig):
153
+ super().__init__()
154
+ self.n_experts = config.moe_num_experts
155
+ self.top_k = config.moe_top_k
156
+
157
+ self.shared_experts = nn.ModuleList([
158
+ SwiGLUExpert(config.hidden_size, config.intermediate_size)
159
+ for _ in range(config.moe_num_shared_experts)
160
+ ])
161
+ self.routed_experts = nn.ModuleList([
162
+ SwiGLUExpert(config.hidden_size, config.intermediate_size)
163
+ for _ in range(config.moe_num_experts)
164
+ ])
165
+ self.router = MoERouter(config)
166
+
167
+ def forward(self, x: Tensor) -> Tensor:
168
+ bs, seq_len, d = x.shape
169
+ flat_x = x.reshape(-1, d)
170
+
171
+ shared_out = sum(expert(flat_x) for expert in self.shared_experts)
172
+
173
+ weights, indices = self.router(flat_x)
174
+ routed_out = torch.zeros_like(flat_x)
175
+ for k in range(self.top_k):
176
+ expert_idx = indices[:, k]
177
+ expert_w = weights[:, k]
178
+ for e in range(self.n_experts):
179
+ mask = expert_idx == e
180
+ if mask.any():
181
+ token_input = flat_x[mask]
182
+ token_output = self.routed_experts[e](token_input)
183
+ routed_out[mask] += expert_w[mask].unsqueeze(-1) * token_output
184
+
185
+ return (shared_out + routed_out).reshape(bs, seq_len, d)
186
+
187
+
188
+ class TransformerBlock(nn.Module):
189
+ def __init__(self, config: DPMMConfig):
190
+ super().__init__()
191
+ self.attn_norm = RMSNorm(config.hidden_size, config.rms_norm_eps)
192
+ self.attention = GQAttention(config)
193
+ self.ffn_norm = RMSNorm(config.hidden_size, config.rms_norm_eps)
194
+ self.moe = MoELayer(config)
195
+
196
+ def forward(self, x: Tensor, rope_cos: Tensor, rope_sin: Tensor,
197
+ mask: Optional[Tensor] = None) -> Tensor:
198
+ h = x + self.attention(self.attn_norm(x), rope_cos, rope_sin, mask)
199
+ out = h + self.moe(self.ffn_norm(h))
200
+ return out
201
+
202
+
203
+ class DPMMForCausalLM(PreTrainedModel):
204
+ config_class = DPMMConfig
205
+ supports_gradient_checkpointing = True
206
+ _no_split_modules = ["TransformerBlock"]
207
+
208
+ def __init__(self, config: DPMMConfig):
209
+ super().__init__(config)
210
+ self.config = config
211
+ self.embed_scale = config.hidden_size ** 0.5 if config.embedding_scale else 1.0
212
+
213
+ self.tok_emb = nn.Embedding(config.vocab_size, config.hidden_size)
214
+ self.layers = nn.ModuleList([
215
+ TransformerBlock(config) for _ in range(config.num_hidden_layers)
216
+ ])
217
+ self.norm = RMSNorm(config.hidden_size, config.rms_norm_eps)
218
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
219
+
220
+ if config.tie_word_embeddings:
221
+ self.lm_head.weight = self.tok_emb.weight
222
+
223
+ rope_cos, rope_sin = precompute_rope_freqs(
224
+ config.head_dim, config.max_position_embeddings, config.rope_theta
225
+ )
226
+ self.register_buffer("rope_cos", rope_cos, persistent=False)
227
+ self.register_buffer("rope_sin", rope_sin, persistent=False)
228
+
229
+ self.post_init()
230
+
231
+ def get_input_embeddings(self):
232
+ return self.tok_emb
233
+
234
+ def set_input_embeddings(self, value):
235
+ self.tok_emb = value
236
+
237
+ def get_output_embeddings(self):
238
+ return self.lm_head
239
+
240
+ def set_output_embeddings(self, new_embeddings):
241
+ self.lm_head = new_embeddings
242
+
243
+ def forward(
244
+ self,
245
+ input_ids: Optional[torch.LongTensor] = None,
246
+ attention_mask: Optional[torch.Tensor] = None,
247
+ labels: Optional[torch.LongTensor] = None,
248
+ past_key_values: Optional[List[Tuple[torch.Tensor]]] = None,
249
+ use_cache: Optional[bool] = None,
250
+ output_attentions: Optional[bool] = None,
251
+ output_hidden_states: Optional[bool] = None,
252
+ return_dict: Optional[bool] = None,
253
+ **kwargs,
254
+ ) -> CausalLMOutputWithPast:
255
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
256
+
257
+ bs, seq_len = input_ids.shape
258
+ x = self.tok_emb(input_ids) * self.embed_scale
259
+
260
+ mask = torch.full((seq_len, seq_len), float("-inf"), device=x.device)
261
+ mask = torch.triu(mask, diagonal=1)
262
+ mask = mask.unsqueeze(0).unsqueeze(0)
263
+
264
+ for layer in self.layers:
265
+ x = layer(x, self.rope_cos, self.rope_sin, mask)
266
+
267
+ x = self.norm(x)
268
+ logits = self.lm_head(x)
269
+
270
+ loss = None
271
+ if labels is not None:
272
+ shift_logits = logits[..., :-1, :].contiguous()
273
+ shift_labels = labels[..., 1:].contiguous()
274
+ loss = F.cross_entropy(
275
+ shift_logits.view(-1, self.config.vocab_size),
276
+ shift_labels.view(-1),
277
+ ignore_index=-100,
278
+ )
279
+
280
+ if not return_dict:
281
+ output = (logits,)
282
+ return (loss,) + output if loss is not None else output
283
+
284
+ return CausalLMOutputWithPast(
285
+ loss=loss,
286
+ logits=logits,
287
+ past_key_values=None,
288
+ hidden_states=None,
289
+ attentions=None,
290
+ )
291
+
292
+ def prepare_inputs_for_generation(self, input_ids, **kwargs):
293
+ return {"input_ids": input_ids}
special_tokens_map.json ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ {
2
+ "bos_token": "<s>",
3
+ "eos_token": "</s>",
4
+ "unk_token": "<unk>",
5
+ "pad_token": "<unk>"
6
+ }
tokenizer.model ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:dadfd56d766715c61d2ef780a525ab43b8e6da4de6865bda3d95fdef5e134055
3
+ size 493443
tokenizer_config.json ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "add_bos_token": true,
3
+ "add_eos_token": false,
4
+ "bos_token": "<s>",
5
+ "eos_token": "</s>",
6
+ "pad_token": "<unk>",
7
+ "unk_token": "<unk>",
8
+ "model_max_length": 2048,
9
+ "clean_up_tokenization_spaces": false,
10
+ "tokenizer_class": "LlamaTokenizerFast",
11
+ "chat_template": "{% for message in messages %}{% if message['role'] == 'user' %}### Instruction:\n{{ message['content'] }}\n\n### Response:\n{% elif message['role'] == 'assistant' %}{{ message['content'] }}{% endif %}{% endfor %}"
12
+ }