---
license: mit
tags:
- diffusion
- text-to-image
- quickdraw
- pytorch
- clip
- ddpm
language:
- en
datasets:
- Xenova/quickdraw-small
---

# Text-Conditional QuickDraw Diffusion Model

A text-conditional diffusion model for generating Google QuickDraw-style sketches from text prompts. This model uses DDPM (Denoising Diffusion Probabilistic Models) with CLIP text encoding and classifier-free guidance to generate 64x64 grayscale sketches.

## Model Description

This is a U-Net-based diffusion model that generates sketches conditioned on text prompts. It uses:
- **CLIP text encoder** (`openai/clip-vit-base-patch32`) for text conditioning
- **DDPM** for the diffusion process (1000 timesteps)
- **Classifier-free guidance** for improved text-image alignment
- The **Google QuickDraw** dataset for training

## Model Details

- **Model Type**: Text-conditional DDPM diffusion model
- **Architecture**: U-Net with cross-attention for text conditioning
- **Image Size**: 64x64 grayscale
- **Base Channels**: 256
- **Text Encoder**: CLIP ViT-B/32 (frozen)
- **Training Epochs**: 100
- **Diffusion Timesteps**: 1000
- **Guidance Scale**: 5.0 (default)

### Training Configuration

- **Dataset**: Xenova/quickdraw-small (5 classes)
- **Batch Size**: 128 (32 per GPU × 4 GPUs)
- **Learning Rate**: 1e-4
- **CFG Drop Probability**: 0.15
- **Optimizer**: Adam (a sketch of one training step under these settings follows this list)
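
These settings plug into the standard DDPM objective: add noise to a batch of sketches, predict that noise with the U-Net, and minimize an MSE loss, while dropping the text conditioning about 15% of the time. A minimal sketch of one training step under those assumptions; `scheduler.add_noise`, `model`, and the tensor names are illustrative, not the repository's confirmed API:

```python
import torch
import torch.nn.functional as F

optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

def training_step(x0, text_emb):
    batch = x0.shape[0]
    t = torch.randint(0, 1000, (batch,), device=x0.device)  # random timestep per sketch
    noise = torch.randn_like(x0)                             # target noise
    x_t = scheduler.add_noise(x0, noise, t)                  # forward diffusion (assumed helper)
    # Drop the text embedding ~15% of the time so the model also learns an unconditional prediction.
    drop = torch.rand(batch, device=x0.device) < 0.15
    cond = torch.where(drop[:, None], torch.zeros_like(text_emb), text_emb)
    loss = F.mse_loss(model(x_t, t, cond), noise)            # predict the added noise
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()
```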

## Usage

### Installation

```bash
pip install torch torchvision transformers diffusers datasets matplotlib pillow tqdm
```

### Generate Images

```python
import torch
from model import TextConditionedUNet
from scheduler import SimpleDDPMScheduler
from text_encoder import CLIPTextEncoder
from generate import generate_samples

# Load checkpoint
checkpoint_path = "text_diffusion_final_epoch_100.pt"
checkpoint = torch.load(checkpoint_path)

# Initialize model
model = TextConditionedUNet(text_dim=512).cuda()
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()

# Initialize text encoder
text_encoder = CLIPTextEncoder(model_name="openai/clip-vit-base-patch32", freeze=True).cuda()
text_encoder.eval()

# Generate samples
scheduler = SimpleDDPMScheduler(1000)
prompt = "a drawing of a cat"
num_samples = 4
guidance_scale = 5.0

with torch.no_grad():
    text_embedding = text_encoder(prompt)
    text_embeddings = text_embedding.repeat(num_samples, 1)

    shape = (num_samples, 1, 64, 64)
    samples = scheduler.sample_text(model, shape, text_embeddings, 'cuda', guidance_scale)
```
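
The exact layout and value range of `samples` depend on `SimpleDDPMScheduler`; a minimal sketch for saving the results of the example above, assuming the tensor has shape `(num_samples, 1, 64, 64)` with values in `[-1, 1]`:

```python
from torchvision.utils import save_image

# Rescale from the assumed [-1, 1] range to [0, 1] before writing PNGs.
images = (samples.clamp(-1, 1) + 1) / 2
for i, img in enumerate(images):
    save_image(img, f"sample_{i}.png")
```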

### Command Line Usage

```bash
# Generate samples
python generate.py --checkpoint text_diffusion_final_epoch_100.pt \
    --prompt "a drawing of a fire truck" \
    --num-samples 4 \
    --guidance-scale 5.0

# Visualize denoising process
python visualize_generation.py --checkpoint text_diffusion_final_epoch_100.pt \
    --prompt "a drawing of a cat" \
    --num-steps 10
```

## Example Prompts

Try these prompts for best results:
- "a drawing of a cat"
- "a drawing of a fire truck"
- "a drawing of an airplane"
- "a drawing of a house"
- "a drawing of a tree"

**Note**: The model is trained on a limited set of QuickDraw classes, so it works best with simple object descriptions in the format "a drawing of a [object]".

## Classifier-Free Guidance

The model supports classifier-free guidance to improve text-image alignment; a sketch of how the guided prediction is formed follows the list below.
- `guidance_scale = 1.0`: No guidance (pure conditional generation)
- `guidance_scale = 3.0-7.0`: Recommended range (default: 5.0)
- Higher values: Stronger adherence to the text prompt (may reduce diversity)
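
At each denoising step, the guided noise prediction is the standard classifier-free guidance combination of a conditional and an unconditional pass (Ho & Salimans, 2022). A minimal sketch; the function and argument names are illustrative rather than the repository's exact API:

```python
def guided_noise(model, x_t, t, text_emb, null_emb, guidance_scale):
    # Illustrative CFG combination; the scheduler in this repo may do this internally.
    eps_cond = model(x_t, t, text_emb)    # prediction conditioned on the prompt
    eps_uncond = model(x_t, t, null_emb)  # prediction with a null/empty prompt
    # guidance_scale = 1.0 recovers the purely conditional prediction
    return eps_uncond + guidance_scale * (eps_cond - eps_uncond)
```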

## Model Architecture

### U-Net Structure

```
Input: (batch, 1, 64, 64)
├── Down Block 1: 1 → 256 channels
├── Down Block 2: 256 → 512 channels
├── Down Block 3: 512 → 512 channels
├── Middle Block: 512 channels
├── Up Block 3: 1024 → 512 channels (with skip connections)
├── Up Block 2: 768 → 256 channels (with skip connections)
└── Up Block 1: 512 → 1 channel (with skip connections)
Output: (batch, 1, 64, 64) - predicted noise
```

### Text Conditioning

- Text prompts encoded via CLIP ViT-B/32 (a standalone encoding sketch follows this list)
- 512-dimensional text embeddings
- Injected into the U-Net via cross-attention
- Classifier-free guidance with 15% text dropout during training
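
The `CLIPTextEncoder` wrapper shown in the Usage example is part of this repository; the snippet below is only a sketch of what such an encoder produces, obtaining the 512-dimensional embeddings directly with `transformers`:

```python
import torch
from transformers import CLIPTokenizer, CLIPTextModelWithProjection

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")
text_model = CLIPTextModelWithProjection.from_pretrained("openai/clip-vit-base-patch32").eval()

with torch.no_grad():
    inputs = tokenizer(["a drawing of a cat"], padding=True, return_tensors="pt")
    text_emb = text_model(**inputs).text_embeds  # shape (1, 512) for ViT-B/32
```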

## Training Details

- **Framework**: PyTorch 2.0+
- **Hardware**: 4x NVIDIA GPUs
- **Training Length**: 100 epochs
- **Dataset**: Google QuickDraw sketches (5 classes)
- **Noise Schedule**: Linear (β from 0.0001 to 0.02), sketched below
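
This schedule corresponds to the standard DDPM forward process built from a linear β ramp. A minimal sketch under that assumption; `SimpleDDPMScheduler` presumably implements the equivalent, and the helper name here is illustrative:

```python
import torch

# Linear schedule: 1000 timesteps, beta rising from 1e-4 to 0.02.
T = 1000
betas = torch.linspace(1e-4, 0.02, T)
alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)

def add_noise(x0, noise, t):
    """Forward diffusion q(x_t | x_0) used during training (illustrative helper)."""
    a_bar = alphas_cumprod[t].view(-1, 1, 1, 1)  # broadcast over (batch, 1, 64, 64)
    return a_bar.sqrt() * x0 + (1.0 - a_bar).sqrt() * noise
```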

## Limitations

- Limited to 64x64 resolution
- Grayscale output only
- Best performance on simple objects from QuickDraw classes
- May not generalize well to complex or out-of-distribution prompts

## Citation

If you use this model, please cite:

```bibtex
@misc{quickdraw-text-diffusion,
  title={Text-Conditional QuickDraw Diffusion Model},
  author={Your Name},
  year={2024},
  howpublished={\url{https://huggingface.co/YOUR_USERNAME/quickdraw-text-diffusion}}
}
```

## License

MIT License

## Acknowledgments

- Google QuickDraw dataset
- OpenAI CLIP
- DDPM paper: "Denoising Diffusion Probabilistic Models" (Ho et al., 2020)
- Classifier-free guidance: "Classifier-Free Diffusion Guidance" (Ho & Salimans, 2022)