# jamesaasher's picture
# Upload model.py with huggingface_hub
# d6d864b verified
"""Text-conditional U-Net for diffusion."""
import torch
import torch.nn as nn
import math
import config
class TimeEmbedding(nn.Module):
    """Sinusoidal timestep embedding.

    Maps a batch of scalar timesteps to ``dim``-dimensional vectors:
    the first half of each vector holds sines, the second half cosines,
    at geometrically spaced frequencies (Transformer-style encoding).
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, t):
        n_freqs = self.dim // 2
        # Log-spaced frequencies from 1 down to 1/10000.
        step = math.log(10000) / (n_freqs - 1)
        freqs = torch.exp(torch.arange(n_freqs, device=t.device) * -step)
        angles = t[:, None] * freqs[None, :]
        return torch.cat([angles.sin(), angles.cos()], dim=1)
class ResBlock(nn.Module):
    """Pre-activation residual block with time and optional text conditioning.

    Both conditioning vectors are linearly projected to ``out_ch`` and
    injected between the two convolutions as spatially-constant channel
    biases. When the channel count changes, a 1x1 convolution aligns the
    skip path; otherwise the input passes through unchanged.
    """

    def __init__(self, in_ch, out_ch, time_dim, text_dim=None):
        super().__init__()
        self.time_mlp = nn.Linear(time_dim, out_ch)
        # Text projection is only built when a text dimension is given.
        self.text_mlp = nn.Linear(text_dim, out_ch) if text_dim else None
        self.conv1 = nn.Conv2d(in_ch, out_ch, 3, padding=1)
        self.conv2 = nn.Conv2d(out_ch, out_ch, 3, padding=1)
        self.norm1 = nn.GroupNorm(min(8, in_ch), in_ch)
        self.norm2 = nn.GroupNorm(min(8, out_ch), out_ch)
        self.act = nn.SiLU()
        self.skip = nn.Conv2d(in_ch, out_ch, 1) if in_ch != out_ch else nn.Identity()

    def forward(self, x, t_emb, text_emb=None):
        out = self.conv1(self.act(self.norm1(x)))
        # Add conditioning as per-channel biases broadcast over H and W.
        out = out + self.time_mlp(t_emb)[:, :, None, None]
        if self.text_mlp is not None and text_emb is not None:
            out = out + self.text_mlp(text_emb)[:, :, None, None]
        out = self.conv2(self.act(self.norm2(out)))
        return out + self.skip(x)
class TextConditionedUNet(nn.Module):
    """Three-level U-Net over single-channel images with CLIP text conditioning.

    Each resolution level is one ResBlock; downsampling is 2x max-pool and
    upsampling is 2x nearest-neighbor. Decoder blocks consume the upsampled
    features concatenated channel-wise with the matching encoder skip.
    """

    def __init__(self, text_dim=512):
        super().__init__()
        self.text_dim = text_dim
        # Timestep embedding followed by a small MLP refinement.
        self.time_emb = TimeEmbedding(config.TIME_DIM)
        self.time_mlp = nn.Sequential(
            nn.Linear(config.TIME_DIM, config.TIME_DIM),
            nn.SiLU(),
            nn.Linear(config.TIME_DIM, config.TIME_DIM),
        )
        # Projection of the CLIP embedding into the conditioning space.
        self.text_proj = nn.Sequential(
            nn.Linear(text_dim, text_dim),
            nn.SiLU(),
            nn.Linear(text_dim, text_dim),
        )
        ch, td = config.CHANNELS, config.TIME_DIM
        # Encoder: 1 -> ch -> 2ch -> 4ch channels, halving resolution between levels.
        self.down1 = ResBlock(1, ch, td, text_dim)
        self.down2 = ResBlock(ch, ch * 2, td, text_dim)
        self.down3 = ResBlock(ch * 2, ch * 4, td, text_dim)
        # Bottleneck at 1/8 of the input resolution.
        self.mid = ResBlock(ch * 4, ch * 4, td, text_dim)
        # Decoder: input channels double from the skip concatenation.
        self.up3 = ResBlock(ch * 8, ch * 2, td, text_dim)
        self.up2 = ResBlock(ch * 4, ch, td, text_dim)
        self.up1 = ResBlock(ch * 2, ch, td, text_dim)
        self.out = nn.Conv2d(ch, 1, 1)
        self.pool = nn.MaxPool2d(2)
        self.upsample = nn.Upsample(scale_factor=2, mode='nearest')

    def forward(self, x, t, text_emb):
        """Run the U-Net.

        Args:
            x: [B, 1, H, W] noisy images; H and W must be divisible by 8
               so the three pool/upsample pairs round-trip exactly.
            t: [B] timesteps.
            text_emb: [B, text_dim] CLIP text embeddings.

        Returns:
            [B, 1, H, W] tensor with the same spatial shape as ``x``.
        """
        t_emb = self.time_mlp(self.time_emb(t))
        text_emb = self.text_proj(text_emb)
        # Encoder, keeping each level's activations for skip connections.
        skip1 = self.down1(x, t_emb, text_emb)
        skip2 = self.down2(self.pool(skip1), t_emb, text_emb)
        skip3 = self.down3(self.pool(skip2), t_emb, text_emb)
        h = self.mid(self.pool(skip3), t_emb, text_emb)
        # Decoder: upsample, concatenate the matching skip, refine.
        for block, skip in ((self.up3, skip3), (self.up2, skip2), (self.up1, skip1)):
            h = block(torch.cat([self.upsample(h), skip], dim=1), t_emb, text_emb)
        return self.out(h)
if __name__ == "__main__":
    # Smoke-test: one forward pass on a tiny random batch.
    print("Testing Text-Conditioned U-Net...")
    net = TextConditionedUNet(text_dim=512)
    n_samples = 2
    images = torch.randn(n_samples, 1, 64, 64)
    steps = torch.randint(0, 1000, (n_samples,))
    captions = torch.randn(n_samples, 512)
    prediction = net(images, steps, captions)
    print(f"Input shape: {images.shape}")
    print(f"Output shape: {prediction.shape}")
    print(f"✅ Model test passed!")