|
|
import torch |
|
|
import torch.nn.functional as F |
|
|
import math |
|
|
|
|
|
|
|
|
""" |
|
|
This scheduler has 3 main responsibilities: |
|
|
|
|
|
1. Setup (init) - Pre-compute noise schedule |
|
|
2. Training (q_sample) - Add noise to images |
|
|
3. Generation (p_sample_text + sample_text) - Remove noise step-by-step
|
|
|
|
|
""" |
|
|
|
|
|
|
|
|
class SimpleDDPMScheduler:
    """Linear-beta DDPM noise scheduler with text-conditioned, CFG-capable sampling.

    All schedule tensors are 1-D with ``num_timesteps`` entries (one per
    diffusion step); the module-level ``extract`` helper looks up the entries
    for a batch of timesteps and reshapes them for broadcasting.
    """

    def __init__(self, num_timesteps=1000, beta_start=0.0001, beta_end=0.02):
        """Pre-compute the noise schedule and its derived coefficients.

        Args:
            num_timesteps: Number of diffusion steps T.
            beta_start: Forward-process variance of the first step.
            beta_end: Forward-process variance of the last step.
        """
        self.num_timesteps = num_timesteps

        # beta_t grows linearly; alpha_t = 1 - beta_t; alpha_bar_t = prod_{s<=t} alpha_s.
        self.betas = torch.linspace(beta_start, beta_end, num_timesteps)
        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
        # alpha_bar_{t-1}, shifted right by one with a leading 1.0 for t = 0.
        self.alphas_cumprod_prev = F.pad(self.alphas_cumprod[:-1], (1, 0), value=1.0)

        # Coefficients of the closed-form forward process q(x_t | x_0).
        self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod)
        self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - self.alphas_cumprod)

        # 1 / sqrt(alpha_t) for the reverse-step mean; computed once here rather
        # than on every sampling step.
        self.sqrt_recip_alphas = 1.0 / torch.sqrt(self.alphas)

        # Variance of the posterior q(x_{t-1} | x_t, x_0). It is exactly 0 at
        # t = 0 because alphas_cumprod_prev[0] == 1.
        self.posterior_variance = (
            self.betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
        )

    def q_sample(self, x_start, t, noise=None):
        """Add noise to the clean images according to the noise schedule.

        Computes x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise,
        so training examples can be drawn at any timestep of the forward process.

        Args:
            x_start: Clean samples x_0; the first dimension is the batch.
            t: 1-D long tensor of timesteps, one per batch element.
            noise: Optional noise shaped like ``x_start``; standard Gaussian
                noise is drawn when None.

        Returns:
            The noised samples x_t, same shape as ``x_start``.
        """
        if noise is None:
            noise = torch.randn_like(x_start)

        sqrt_alphas_cumprod_t = extract(self.sqrt_alphas_cumprod, t, x_start.shape)
        sqrt_one_minus_alphas_cumprod_t = extract(
            self.sqrt_one_minus_alphas_cumprod, t, x_start.shape
        )

        return sqrt_alphas_cumprod_t * x_start + sqrt_one_minus_alphas_cumprod_t * noise

    def p_sample_text(self, model, x, t, text_embeddings, guidance_scale=1.0):
        """Sample x_{t-1} from x_t using the model with text conditioning and CFG.

        Args:
            model: Noise-prediction network, called as ``model(x, t, embeddings)``.
            x: Current noisy samples x_t; the first dimension is the batch.
            t: 1-D long tensor of current timesteps, one per batch element.
            text_embeddings: Text embeddings for conditioning.
            guidance_scale: Classifier-free guidance scale (1.0 = no guidance,
                higher = stronger). Values <= 1.0 skip the unconditional pass.

        Returns:
            Samples x_{t-1}, same shape as ``x``. Batch elements whose timestep
            is 0 receive the posterior mean with no added noise.
        """
        predicted_noise = model(x, t, text_embeddings)

        if guidance_scale > 1.0:
            # Classifier-free guidance with an all-zeros embedding as the
            # unconditional input. NOTE(review): this presumably matches how the
            # model saw "dropped" conditioning during training; confirm.
            uncond_noise = model(x, t, torch.zeros_like(text_embeddings))
            predicted_noise = uncond_noise + guidance_scale * (predicted_noise - uncond_noise)

        betas_t = extract(self.betas, t, x.shape)
        sqrt_one_minus_alphas_cumprod_t = extract(
            self.sqrt_one_minus_alphas_cumprod, t, x.shape
        )
        sqrt_recip_alphas_t = extract(self.sqrt_recip_alphas, t, x.shape)

        # Mean: 1/sqrt(alpha_t) * (x_t - beta_t / sqrt(1 - alpha_bar_t) * eps_pred).
        model_mean = sqrt_recip_alphas_t * (
            x - betas_t * predicted_noise / sqrt_one_minus_alphas_cumprod_t
        )

        # Decide per element (not from t[0] alone) whether to add noise, so a
        # batch with mixed timesteps is handled correctly.
        nonzero = t != 0
        if not nonzero.any():
            return model_mean

        posterior_variance_t = extract(self.posterior_variance, t, x.shape)
        nonzero_mask = nonzero.to(x.dtype).reshape(-1, *((1,) * (x.dim() - 1)))
        noise = torch.randn_like(x)
        return model_mean + nonzero_mask * torch.sqrt(posterior_variance_t) * noise

    def sample_text(self, model, shape, text_embeddings, device="cuda", guidance_scale=1.0):
        """Generate samples using DDPM sampling with text conditioning and CFG.

        The loop runs under ``torch.no_grad()``. Otherwise, when the model's
        parameters require grad, autograd would keep the graph of every one of
        the ``num_timesteps`` model calls alive through ``img``.

        Args:
            model: The diffusion model.
            shape: Output shape (B, C, H, W).
            text_embeddings: Text embeddings for conditioning.
            device: Device to use.
            guidance_scale: Classifier-free guidance scale (1.0 = no guidance,
                3.0-7.0 typical).

        Returns:
            The final samples, a tensor of ``shape`` on ``device``.
        """
        with torch.no_grad():
            batch_size = shape[0]
            img = torch.randn(shape, device=device)

            for i in reversed(range(self.num_timesteps)):
                t = torch.full((batch_size,), i, device=device, dtype=torch.long)
                img = self.p_sample_text(model, img, t, text_embeddings, guidance_scale)
                # Every intermediate sample is clamped to [-2, 2].
                # NOTE(review): presumably a loose stability bound around a
                # [-1, 1] data range; confirm the intended range.
                img = torch.clamp(img, -2.0, 2.0)

        return img
|
|
|
|
|
|
|
|
def extract(a, t, x_shape):
    """Extract coefficients from ``a`` at timesteps ``t``, reshaped to broadcast against x.

    Args:
        a: 1-D tensor of schedule values indexed by timestep.
        t: 1-D long tensor of timesteps, one per batch element.
        x_shape: Shape of the tensor the result will be combined with; its
            first dimension is the batch.

    Returns:
        Tensor of shape ``(batch, 1, ..., 1)`` with ``len(x_shape)`` dimensions,
        holding ``a[t]``, placed on ``t``'s device.
    """
    batch_size = t.shape[0]
    # Gather on a's own device so this works whether the schedule lives on the
    # CPU or a GPU, then move the result to where t (and the data) live.
    out = a.gather(-1, t.to(a.device))
    return out.reshape(batch_size, *((1,) * (len(x_shape) - 1))).to(t.device)
|
|
|