jamesaasher commited on
Commit
d6d864b
·
verified ·
1 Parent(s): 231be4e

Upload model.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. model.py +137 -0
model.py ADDED
@@ -0,0 +1,137 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Text-conditional U-Net for diffusion."""
2
+ import torch
3
+ import torch.nn as nn
4
+ import math
5
+ import config
6
+
7
+
8
+ class TimeEmbedding(nn.Module):
9
+ """Sinusoidal time embedding."""
10
+
11
+ def __init__(self, dim):
12
+ super().__init__()
13
+ self.dim = dim
14
+
15
+ def forward(self, t):
16
+ half_dim = self.dim // 2
17
+ emb = math.log(10000) / (half_dim - 1)
18
+ emb = torch.exp(torch.arange(half_dim, device=t.device) * -emb)
19
+ emb = t[:, None] * emb[None, :]
20
+ return torch.cat([emb.sin(), emb.cos()], dim=1)
21
+
22
+
23
+ class ResBlock(nn.Module):
24
+ """Residual block with time and text conditioning."""
25
+
26
+ def __init__(self, in_ch, out_ch, time_dim, text_dim=None):
27
+ super().__init__()
28
+ self.time_mlp = nn.Linear(time_dim, out_ch)
29
+ self.text_mlp = nn.Linear(text_dim, out_ch) if text_dim else None
30
+
31
+ self.conv1 = nn.Conv2d(in_ch, out_ch, 3, padding=1)
32
+ self.conv2 = nn.Conv2d(out_ch, out_ch, 3, padding=1)
33
+ self.norm1 = nn.GroupNorm(min(8, in_ch), in_ch)
34
+ self.norm2 = nn.GroupNorm(min(8, out_ch), out_ch)
35
+ self.act = nn.SiLU()
36
+
37
+ self.skip = nn.Conv2d(in_ch, out_ch, 1) if in_ch != out_ch else nn.Identity()
38
+
39
+ def forward(self, x, t_emb, text_emb=None):
40
+ h = self.act(self.norm1(x))
41
+ h = self.conv1(h)
42
+
43
+ # Add time embedding
44
+ h = h + self.time_mlp(t_emb)[:, :, None, None]
45
+
46
+ # Add text embedding
47
+ if self.text_mlp is not None and text_emb is not None:
48
+ h = h + self.text_mlp(text_emb)[:, :, None, None]
49
+
50
+ h = self.act(self.norm2(h))
51
+ h = self.conv2(h)
52
+
53
+ return h + self.skip(x)
54
+
55
+
56
+ class TextConditionedUNet(nn.Module):
57
+ """U-Net with CLIP text conditioning."""
58
+
59
+ def __init__(self, text_dim=512):
60
+ super().__init__()
61
+ self.text_dim = text_dim
62
+
63
+ self.time_emb = TimeEmbedding(config.TIME_DIM)
64
+ self.time_mlp = nn.Sequential(
65
+ nn.Linear(config.TIME_DIM, config.TIME_DIM),
66
+ nn.SiLU(),
67
+ nn.Linear(config.TIME_DIM, config.TIME_DIM)
68
+ )
69
+
70
+ self.text_proj = nn.Sequential(
71
+ nn.Linear(text_dim, text_dim),
72
+ nn.SiLU(),
73
+ nn.Linear(text_dim, text_dim)
74
+ )
75
+
76
+ # Down path
77
+ self.down1 = ResBlock(1, config.CHANNELS, config.TIME_DIM, text_dim)
78
+ self.down2 = ResBlock(config.CHANNELS, config.CHANNELS * 2, config.TIME_DIM, text_dim)
79
+ self.down3 = ResBlock(config.CHANNELS * 2, config.CHANNELS * 4, config.TIME_DIM, text_dim)
80
+
81
+ # Middle
82
+ self.mid = ResBlock(config.CHANNELS * 4, config.CHANNELS * 4, config.TIME_DIM, text_dim)
83
+
84
+ # Up path
85
+ self.up3 = ResBlock(config.CHANNELS * 8, config.CHANNELS * 2, config.TIME_DIM, text_dim)
86
+ self.up2 = ResBlock(config.CHANNELS * 4, config.CHANNELS, config.TIME_DIM, text_dim)
87
+ self.up1 = ResBlock(config.CHANNELS * 2, config.CHANNELS, config.TIME_DIM, text_dim)
88
+
89
+ # Output
90
+ self.out = nn.Conv2d(config.CHANNELS, 1, 1)
91
+
92
+ # Pooling/Upsampling
93
+ self.pool = nn.MaxPool2d(2)
94
+ self.upsample = nn.Upsample(scale_factor=2, mode='nearest')
95
+
96
+ def forward(self, x, t, text_emb):
97
+ """
98
+ Args:
99
+ x: [B, 1, H, W] noisy images
100
+ t: [B] timesteps
101
+ text_emb: [B, text_dim] CLIP text embeddings
102
+ """
103
+ # Embeddings
104
+ t_emb = self.time_mlp(self.time_emb(t))
105
+ text_emb = self.text_proj(text_emb)
106
+
107
+ # Down
108
+ h1 = self.down1(x, t_emb, text_emb)
109
+ h2 = self.down2(self.pool(h1), t_emb, text_emb)
110
+ h3 = self.down3(self.pool(h2), t_emb, text_emb)
111
+
112
+ # Middle
113
+ h = self.mid(self.pool(h3), t_emb, text_emb)
114
+
115
+ # Up
116
+ h = self.up3(torch.cat([self.upsample(h), h3], dim=1), t_emb, text_emb)
117
+ h = self.up2(torch.cat([self.upsample(h), h2], dim=1), t_emb, text_emb)
118
+ h = self.up1(torch.cat([self.upsample(h), h1], dim=1), t_emb, text_emb)
119
+
120
+ return self.out(h)
121
+
122
+
123
+ if __name__ == "__main__":
124
+ # Test model
125
+ print("Testing Text-Conditioned U-Net...")
126
+ model = TextConditionedUNet(text_dim=512)
127
+
128
+ # Test forward pass
129
+ batch_size = 2
130
+ x = torch.randn(batch_size, 1, 64, 64)
131
+ t = torch.randint(0, 1000, (batch_size,))
132
+ text_emb = torch.randn(batch_size, 512)
133
+
134
+ out = model(x, t, text_emb)
135
+ print(f"Input shape: {x.shape}")
136
+ print(f"Output shape: {out.shape}")
137
+ print(f"✅ Model test passed!")