|
|
"""Visualize the diffusion generation process - capture images at each timestep.""" |
|
|
import torch |
|
|
import argparse |
|
|
import os |
|
|
import matplotlib.pyplot as plt |
|
|
|
|
|
import config |
|
|
from model import TextConditionedUNet |
|
|
from scheduler import SimpleDDPMScheduler |
|
|
from text_encoder import CLIPTextEncoder |
|
|
from generate import tensor_to_image |
|
|
|
|
|
|
|
|
def sample_with_snapshots(scheduler, model, shape, text_embeddings, device='cuda',
                          guidance_scale=1.0, snapshot_steps=None):
    """Run the reverse-diffusion loop, capturing intermediate images.

    Args:
        scheduler: DDPM-style scheduler exposing ``num_timesteps`` and
            ``p_sample_text(model, x, t, text_embeddings, guidance_scale)``.
        model: denoising network, passed through to the scheduler.
        shape: shape of the initial noise tensor, e.g. ``(B, C, H, W)``.
        text_embeddings: conditioning embeddings forwarded to the scheduler.
        device: torch device string for the initial noise (default ``'cuda'``).
        guidance_scale: classifier-free guidance scale (default 1.0).
        snapshot_steps: timesteps at which to record a copy of the image.
            When ``None``, ~10 evenly spaced steps are used, always
            including ``t=0`` (the final image).

    Returns:
        Tuple ``(final_image, snapshots)`` — the fully denoised tensor and a
        dict mapping timestep -> detached image tensor captured at that step.
    """
    batch = shape[0]
    img = torch.randn(shape, device=device)

    if snapshot_steps is None:
        # BUG FIX: without max(1, ...), num_timesteps < 10 yields a zero
        # range step and range() raises ValueError.
        interval = max(1, scheduler.num_timesteps // 10)
        snapshot_steps = list(range(scheduler.num_timesteps - 1, -1, -interval))
        if 0 not in snapshot_steps:
            snapshot_steps.append(0)

    # Set for O(1) membership tests inside the denoising loop.
    wanted = set(snapshot_steps)
    snapshots = {}

    for i in reversed(range(scheduler.num_timesteps)):
        t = torch.full((batch,), i, device=device, dtype=torch.long)
        img = scheduler.p_sample_text(model, img, t, text_embeddings, guidance_scale)
        # Keep values bounded so intermediate visualizations stay displayable.
        img = torch.clamp(img, -2.0, 2.0)

        if i in wanted:
            snapshots[i] = img.clone().detach()

    return img, snapshots
|
|
|
|
|
|
|
|
def plot_denoising_process(snapshots, prompt, output_path, sample_idx=0):
    """Render the captured snapshots left-to-right (pure noise -> final image).

    Args:
        snapshots: dict mapping timestep -> batched image tensor.
        prompt: text prompt shown in the figure title.
        output_path: where to save the resulting PNG.
        sample_idx: which sample of the batch to plot (default 0).
    """
    steps = sorted(snapshots, reverse=True)
    n = len(steps)

    fig, axes = plt.subplots(1, n, figsize=(2.5 * n, 3))
    if n == 1:
        # plt.subplots returns a bare Axes (not an array) for a single panel.
        axes = [axes]

    fig.suptitle(f'Denoising Process: "{prompt}"', fontsize=12, fontweight='bold')

    for ax, t in zip(axes, steps):
        ax.imshow(tensor_to_image(snapshots[t][sample_idx]), cmap='gray')
        ax.axis('off')
        ax.set_title('Final' if t <= 0 else f't={t}', fontsize=10)

    plt.tight_layout()
    plt.savefig(output_path, dpi=150, bbox_inches='tight')
    plt.close()
|
|
|
|
|
|
|
|
def main():
    """CLI entry point: load a checkpoint, sample one image for the prompt,
    and save a side-by-side strip of the denoising process."""
    parser = argparse.ArgumentParser(description='Visualize denoising process')
    parser.add_argument('--checkpoint', type=str, required=True)
    parser.add_argument('--prompt', type=str, default="a drawing of a cat")
    parser.add_argument('--guidance-scale', type=float, default=config.CFG_GUIDANCE_SCALE)
    parser.add_argument('--num-steps', type=int, default=10,
                        help='Number of snapshots to capture')
    parser.add_argument('--device', type=str, default='cuda')
    args = parser.parse_args()

    # Fall back to CPU gracefully when CUDA was requested but is unavailable.
    if args.device == 'cuda' and not torch.cuda.is_available():
        args.device = 'cpu'

    checkpoint = torch.load(args.checkpoint, map_location=args.device)
    ckpt_config = checkpoint.get('config', {})

    model = TextConditionedUNet(text_dim=ckpt_config.get('text_dim', config.TEXT_DIM)).to(args.device)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()

    text_encoder = CLIPTextEncoder(
        model_name=ckpt_config.get('clip_model', config.CLIP_MODEL), freeze=True
    ).to(args.device)
    text_encoder.eval()

    scheduler = SimpleDDPMScheduler(config.TIMESTEPS)

    # BUG FIX: --num-steps was parsed but never used (the snapshot count was
    # silently hard-coded to ~10). Build an evenly spaced schedule from it,
    # always including t=0 so the final image is captured.
    interval = max(1, scheduler.num_timesteps // args.num_steps)
    snapshot_steps = list(range(scheduler.num_timesteps - 1, -1, -interval))
    if 0 not in snapshot_steps:
        snapshot_steps.append(0)

    with torch.no_grad():
        text_embedding = text_encoder(args.prompt)
        shape = (1, 1, config.IMAGE_SIZE, config.IMAGE_SIZE)

        _, snapshots = sample_with_snapshots(
            scheduler, model, shape, text_embedding, args.device,
            args.guidance_scale, snapshot_steps=snapshot_steps
        )

    os.makedirs("outputs", exist_ok=True)
    # Strip filesystem-unsafe characters from the prompt before using it
    # as a filename; cap at 50 chars.
    safe_prompt = "".join(c if c.isalnum() or c in " _" else "" for c in args.prompt)[:50]
    output_path = f"outputs/denoising_{safe_prompt}.png"

    plot_denoising_process(snapshots, args.prompt, output_path)
    print(f"✅ Saved visualization: {output_path}")
|
|
|
|
|
|
|
|
# Allow running directly as a script (e.g. `python <script> --checkpoint ...`).
if __name__ == "__main__":
    main()
|
|
|