---
license: mit
tags:
- diffusion
- text-to-image
- quickdraw
- pytorch
- clip
- ddpm
language:
- en
datasets:
- Xenova/quickdraw-small
---
# Text-Conditional QuickDraw Diffusion Model
A text-conditional diffusion model for generating Google QuickDraw-style sketches from text prompts. This model uses DDPM (Denoising Diffusion Probabilistic Models) with CLIP text encoding and classifier-free guidance to generate 64x64 grayscale sketches.
## Model Description
This is a U-Net-based diffusion model that generates sketches conditioned on text prompts. It uses:
- **CLIP text encoder** (`openai/clip-vit-base-patch32`) for text conditioning
- **DDPM** for the diffusion process (1000 timesteps)
- **Classifier-free guidance** for improved text-image alignment
- Trained on the **Google QuickDraw** dataset
## Model Details
- **Model Type**: Text-conditional DDPM diffusion model
- **Architecture**: U-Net with cross-attention for text conditioning
- **Image Size**: 64x64 grayscale
- **Base Channels**: 256
- **Text Encoder**: CLIP ViT-B/32 (frozen)
- **Training Epochs**: 100
- **Diffusion Timesteps**: 1000
- **Guidance Scale**: 5.0 (default)
### Training Configuration
- **Dataset**: Xenova/quickdraw-small (5 classes)
- **Batch Size**: 128 (32 per GPU × 4 GPUs)
- **Learning Rate**: 1e-4
- **CFG Drop Probability**: 0.15
- **Optimizer**: Adam
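As a rough illustration of how these settings fit together, a single training step might look like the sketch below. This is not the repository's training code: `scheduler.add_noise`, the model call signature, and the learnable `null_embedding` used for conditioning dropout are all assumptions, based on the interfaces shown in the usage example further down.
```python
import torch
import torch.nn.functional as F

optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
cfg_drop_prob = 0.15  # probability of dropping the text conditioning

def train_step(images, text_embeddings, null_embedding):
    b = images.shape[0]
    t = torch.randint(0, 1000, (b,), device=images.device)  # random timestep per image
    noise = torch.randn_like(images)
    noisy = scheduler.add_noise(images, noise, t)            # forward diffusion to step t
    # Classifier-free guidance training: randomly swap in the null embedding
    drop = torch.rand(b, device=images.device) < cfg_drop_prob
    cond = torch.where(drop[:, None], null_embedding, text_embeddings)
    pred = model(noisy, t, cond)                             # predict the added noise
    loss = F.mse_loss(pred, noise)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()
```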
## Usage
### Installation
```bash
pip install torch torchvision transformers diffusers datasets matplotlib pillow tqdm
```
### Generate Images
```python
import torch

from model import TextConditionedUNet
from scheduler import SimpleDDPMScheduler
from text_encoder import CLIPTextEncoder

# Load the trained checkpoint
checkpoint_path = "text_diffusion_final_epoch_100.pt"
checkpoint = torch.load(checkpoint_path, map_location="cuda")

# Initialize the model and restore its weights
model = TextConditionedUNet(text_dim=512).cuda()
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()

# Initialize the frozen CLIP text encoder
text_encoder = CLIPTextEncoder(model_name="openai/clip-vit-base-patch32", freeze=True).cuda()
text_encoder.eval()

# Generate samples with classifier-free guidance
scheduler = SimpleDDPMScheduler(1000)
prompt = "a drawing of a cat"
num_samples = 4
guidance_scale = 5.0

with torch.no_grad():
    # Encode the prompt and repeat it for each sample
    text_embedding = text_encoder(prompt)
    text_embeddings = text_embedding.repeat(num_samples, 1)
    shape = (num_samples, 1, 64, 64)
    samples = scheduler.sample_text(model, shape, text_embeddings, 'cuda', guidance_scale)
```
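The returned `samples` tensor can then be written to disk. A minimal sketch, assuming the scheduler returns images in [-1, 1] (adjust the rescaling if yours differs):
```python
from torchvision.utils import save_image

# Rescale from [-1, 1] to [0, 1] and save as a grid of PNGs
save_image(samples, "samples.png", nrow=num_samples, normalize=True, value_range=(-1, 1))
```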
### Command Line Usage
```bash
# Generate samples
python generate.py --checkpoint text_diffusion_final_epoch_100.pt \
--prompt "a drawing of a fire truck" \
--num-samples 4 \
--guidance-scale 5.0
# Visualize denoising process
python visualize_generation.py --checkpoint text_diffusion_final_epoch_100.pt \
--prompt "a drawing of a cat" \
--num-steps 10
```
## Example Prompts
Try these prompts for best results:
- "a drawing of a cat"
- "a drawing of a fire truck"
- "a drawing of an airplane"
- "a drawing of a house"
- "a drawing of a tree"
**Note**: The model is trained on a limited set of QuickDraw classes, so it works best with simple object descriptions in the format "a drawing of a [object]".
## Classifier-Free Guidance
The model supports classifier-free guidance to improve text-image alignment:
- `guidance_scale = 1.0`: No guidance (pure conditional generation)
- `guidance_scale` between 3.0 and 7.0: Recommended range (default: 5.0)
- Higher values: Stronger adherence to the text prompt, possibly at the cost of diversity
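Concretely, guided sampling evaluates the model twice at each denoising step and blends the two noise predictions. A minimal sketch (the `null_embedding` and model call signature are assumptions):
```python
# One guidance step inside the sampling loop (names are assumptions)
eps_uncond = model(x_t, t, null_embedding.expand_as(text_embeddings))
eps_cond = model(x_t, t, text_embeddings)
# guidance_scale = 1.0 reduces to the pure conditional prediction
eps = eps_uncond + guidance_scale * (eps_cond - eps_uncond)
```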
## Model Architecture
### U-Net Structure
```
Input: (batch, 1, 64, 64)
├── Down Block 1: 1 → 256 channels
├── Down Block 2: 256 → 512 channels
├── Down Block 3: 512 → 512 channels
├── Middle Block: 512 channels
├── Up Block 3: 1024 → 512 channels (with skip connections)
├── Up Block 2: 768 → 256 channels (with skip connections)
└── Up Block 1: 512 → 1 channel (with skip connections)
Output: (batch, 1, 64, 64) - predicted noise
```
### Text Conditioning
- Text prompts encoded via CLIP ViT-B/32
- 512-dimensional text embeddings
- Injected into the U-Net via cross-attention (see the sketch after this list)
- Classifier-free guidance with 15% dropout during training
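A minimal sketch of what such a cross-attention layer could look like, with the image feature map as queries and the single pooled CLIP embedding as key/value. The actual layer in `model.py` may differ:
```python
import torch
import torch.nn as nn

class TextCrossAttention(nn.Module):
    """Simplified cross-attention: image features attend to one 512-d text embedding."""
    def __init__(self, channels, text_dim=512, num_heads=8):
        super().__init__()
        self.attn = nn.MultiheadAttention(channels, num_heads, batch_first=True)
        self.to_kv = nn.Linear(text_dim, channels)
        self.norm = nn.LayerNorm(channels)

    def forward(self, x, text_emb):
        # x: (batch, channels, H, W); text_emb: (batch, text_dim)
        b, c, h, w = x.shape
        q = x.flatten(2).transpose(1, 2)            # (batch, H*W, channels)
        kv = self.to_kv(text_emb).unsqueeze(1)      # (batch, 1, channels)
        out, _ = self.attn(self.norm(q), kv, kv)    # queries from image, keys/values from text
        return x + out.transpose(1, 2).reshape(b, c, h, w)  # residual connection
```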
## Training Details
- **Framework**: PyTorch 2.0+
- **Hardware**: 4x NVIDIA GPUs
- **Training Duration**: 100 epochs
- **Dataset**: Google QuickDraw sketches (5 classes)
- **Noise Schedule**: Linear (β from 0.0001 to 0.02); see the sketch below
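For reference, the forward-noising quantities implied by this schedule follow standard DDPM bookkeeping. This is a sketch of that bookkeeping, not the repository's `SimpleDDPMScheduler`:
```python
import torch

# Standard DDPM linear schedule over 1000 timesteps
betas = torch.linspace(0.0001, 0.02, 1000)
alphas = 1.0 - betas
alphas_cumprod = torch.cumprod(alphas, dim=0)

def add_noise(x0, noise, t):
    # q(x_t | x_0): sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise
    a = alphas_cumprod[t].view(-1, 1, 1, 1)
    return a.sqrt() * x0 + (1.0 - a).sqrt() * noise
```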
## Limitations
- Limited to 64x64 resolution
- Grayscale output only
- Best performance on simple objects from QuickDraw classes
- May not generalize well to complex or out-of-distribution prompts
## Citation
If you use this model, please cite:
```bibtex
@misc{quickdraw-text-diffusion,
  title        = {Text-Conditional QuickDraw Diffusion Model},
  author       = {Your Name},
  year         = {2024},
  howpublished = {\url{https://huggingface.co/YOUR_USERNAME/quickdraw-text-diffusion}}
}
```
## License
MIT License
## Acknowledgments
- Google QuickDraw dataset
- OpenAI CLIP
- DDPM paper: "Denoising Diffusion Probabilistic Models" (Ho et al., 2020)
- Classifier-free guidance: "Classifier-Free Diffusion Guidance" (Ho & Salimans, 2022)