---
license: mit
tags:
- diffusion
- text-to-image
- quickdraw
- pytorch
- clip
- ddpm
language:
- en
datasets:
- Xenova/quickdraw-small
---

# Text-Conditional QuickDraw Diffusion Model

A text-conditional diffusion model for generating Google QuickDraw-style sketches from text prompts. This model uses DDPM (Denoising Diffusion Probabilistic Models) with CLIP text encoding and classifier-free guidance to generate 64x64 grayscale sketches.

## Model Description

This is a U-Net-based diffusion model that generates sketches conditioned on text prompts. It uses:
- **CLIP text encoder** (`openai/clip-vit-base-patch32`) for text conditioning
- **DDPM** for the diffusion process (1000 timesteps)
- **Classifier-free guidance** for improved text-image alignment
- The **Google QuickDraw** dataset for training

## Model Details

- **Model Type**: Text-conditional DDPM diffusion model
- **Architecture**: U-Net with cross-attention for text conditioning
- **Image Size**: 64x64 grayscale
- **Base Channels**: 256
- **Text Encoder**: CLIP ViT-B/32 (frozen)
- **Training Epochs**: 100
- **Diffusion Timesteps**: 1000
- **Guidance Scale**: 5.0 (default)

### Training Configuration

- **Dataset**: Xenova/quickdraw-small (5 classes)
- **Batch Size**: 128 (32 per GPU × 4 GPUs)
- **Learning Rate**: 1e-4
- **CFG Drop Probability**: 0.15 (see the training sketch below)
- **Optimizer**: Adam
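
The training script is not included in this card, so the loop below is only a minimal sketch of how the 0.15 CFG drop probability is typically applied during DDPM training. `dataloader` and `scheduler.add_noise` are assumed/hypothetical helpers; `model`, `scheduler`, and `text_encoder` are as constructed in the Usage section.

```python
import torch
import torch.nn.functional as F

# Sketch only: `model`, `scheduler`, `text_encoder`, and `dataloader` are assumed
# to exist (see Usage); `scheduler.add_noise` is a hypothetical forward-process helper.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
null_embedding = text_encoder("")  # assumption: empty prompt serves as the null condition
cfg_drop_prob = 0.15

for images, text_embeddings in dataloader:  # (B, 1, 64, 64), (B, 512)
    # Randomly replace ~15% of text embeddings with the null embedding so the
    # model also learns an unconditional prediction (required for CFG at sampling)
    drop = torch.rand(images.size(0), device=images.device) < cfg_drop_prob
    text_embeddings = torch.where(drop[:, None], null_embedding, text_embeddings)

    # Standard DDPM objective: predict the noise injected at a random timestep
    t = torch.randint(0, 1000, (images.size(0),), device=images.device)
    noise = torch.randn_like(images)
    noisy = scheduler.add_noise(images, noise, t)
    loss = F.mse_loss(model(noisy, t, text_embeddings), noise)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
```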

## Usage

### Installation

```bash
pip install torch torchvision transformers diffusers datasets matplotlib pillow tqdm
```

### Generate Images

```python
import torch
from model import TextConditionedUNet
from scheduler import SimpleDDPMScheduler
from text_encoder import CLIPTextEncoder

# Load checkpoint
checkpoint_path = "text_diffusion_final_epoch_100.pt"
checkpoint = torch.load(checkpoint_path)

# Initialize model
model = TextConditionedUNet(text_dim=512).cuda()
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()

# Initialize text encoder
text_encoder = CLIPTextEncoder(model_name="openai/clip-vit-base-patch32", freeze=True).cuda()
text_encoder.eval()

# Generate samples
scheduler = SimpleDDPMScheduler(1000)
prompt = "a drawing of a cat"
num_samples = 4
guidance_scale = 5.0

with torch.no_grad():
    text_embedding = text_encoder(prompt)
    text_embeddings = text_embedding.repeat(num_samples, 1)

    shape = (num_samples, 1, 64, 64)
    samples = scheduler.sample_text(model, shape, text_embeddings, 'cuda', guidance_scale)
```
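
To write the results to disk, a short follow-up sketch — assuming `samples` comes back as a `(num_samples, 1, 64, 64)` tensor in [-1, 1], the usual DDPM convention, though this card does not state the output range:

```python
from torchvision.utils import save_image

# Assumes samples are in [-1, 1]; rescale to [0, 1] before saving as a 2x2 grid
save_image((samples.clamp(-1, 1) + 1) / 2, "samples.png", nrow=2)
```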

### Command Line Usage

```bash
# Generate samples
python generate.py --checkpoint text_diffusion_final_epoch_100.pt \
    --prompt "a drawing of a fire truck" \
    --num-samples 4 \
    --guidance-scale 5.0

# Visualize denoising process
python visualize_generation.py --checkpoint text_diffusion_final_epoch_100.pt \
    --prompt "a drawing of a cat" \
    --num-steps 10
```

## Example Prompts

Try these prompts for best results:
- "a drawing of a cat"
- "a drawing of a fire truck"
- "a drawing of an airplane"
- "a drawing of a house"
- "a drawing of a tree"

**Note**: The model is trained on a limited set of QuickDraw classes, so it works best with simple object descriptions in the format "a drawing of a [object]".

## Classifier-Free Guidance

The model supports classifier-free guidance to improve text-image alignment:
- `guidance_scale = 1.0`: No guidance (pure conditional generation)
- `guidance_scale = 3.0-7.0`: Recommended range (default: 5.0)
- Higher values: Stronger adherence to text prompt (may reduce diversity)
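
Concretely, guided sampling combines an unconditional and a conditional noise prediction at every denoising step: ε = ε_uncond + s · (ε_cond − ε_uncond). The actual logic lives in `SimpleDDPMScheduler.sample_text`, which is not shown here; the function below is an illustrative sketch with assumed model signatures.

```python
import torch

def guided_noise(model, x_t, t, text_emb, null_emb, guidance_scale):
    """One classifier-free-guidance prediction (sketch; signatures assumed)."""
    eps_cond = model(x_t, t, text_emb)      # conditioned on the prompt embedding
    eps_uncond = model(x_t, t, null_emb)    # conditioned on the null embedding
    # guidance_scale = 1.0 recovers the pure conditional prediction
    return eps_uncond + guidance_scale * (eps_cond - eps_uncond)
```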

## Model Architecture

### U-Net Structure
```
Input: (batch, 1, 64, 64)
├── Down Block 1: 1 → 256 channels
├── Down Block 2: 256 → 512 channels
├── Down Block 3: 512 → 512 channels
├── Middle Block: 512 channels
├── Up Block 3: 1024 → 512 channels (with skip connections)
├── Up Block 2: 768 → 256 channels (with skip connections)
└── Up Block 1: 512 → 1 channel (with skip connections)
Output: (batch, 1, 64, 64) - predicted noise
```

### Text Conditioning
- Text prompts encoded via CLIP ViT-B/32
- 512-dimensional text embeddings
- Injected into the U-Net via cross-attention (see the sketch below)
- Classifier-free guidance with 15% dropout during training
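
The `TextConditionedUNet` layers are not reproduced in this card, so the block below is only an illustrative sketch of this flavor of cross-attention conditioning: flattened image features act as queries, and the single 512-d CLIP embedding supplies the keys and values.

```python
import torch
import torch.nn as nn

class CrossAttentionBlock(nn.Module):
    """Illustrative sketch: image features attend to one text embedding."""
    def __init__(self, channels: int, text_dim: int = 512, heads: int = 8):
        super().__init__()
        self.attn = nn.MultiheadAttention(channels, heads, batch_first=True)
        self.to_kv = nn.Linear(text_dim, channels)

    def forward(self, x: torch.Tensor, text_emb: torch.Tensor) -> torch.Tensor:
        b, c, h, w = x.shape
        q = x.flatten(2).transpose(1, 2)        # (B, H*W, C) image queries
        kv = self.to_kv(text_emb).unsqueeze(1)  # (B, 1, C) text keys/values
        out, _ = self.attn(q, kv, kv)
        return x + out.transpose(1, 2).reshape(b, c, h, w)  # residual add
```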

## Training Details

- **Framework**: PyTorch 2.0+
- **Hardware**: 4x NVIDIA GPUs
- **Training Duration**: 100 epochs
- **Dataset**: Google QuickDraw sketches (5 classes)
- **Noise Schedule**: Linear (β from 0.0001 to 0.02)
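
The linear schedule fixes how much noise is injected at each of the 1000 timesteps; in closed form, x_t = √(ᾱ_t)·x_0 + √(1−ᾱ_t)·ε. A short sketch using the standard DDPM definitions (`SimpleDDPMScheduler` presumably computes the same quantities internally):

```python
import torch

# Standard DDPM linear schedule: β grows from 1e-4 to 0.02 over 1000 steps
betas = torch.linspace(1e-4, 0.02, 1000)
alphas = 1.0 - betas
alphas_cumprod = torch.cumprod(alphas, dim=0)  # ᾱ_t

def add_noise(x0, noise, t):
    """Forward process: x_t = sqrt(ᾱ_t) * x_0 + sqrt(1 - ᾱ_t) * ε."""
    a = alphas_cumprod[t].view(-1, 1, 1, 1)
    return a.sqrt() * x0 + (1 - a).sqrt() * noise
```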

## Limitations

- Limited to 64x64 resolution
- Grayscale output only
- Best performance on simple objects from QuickDraw classes
- May not generalize well to complex or out-of-distribution prompts

## Citation

If you use this model, please cite:

```bibtex
@misc{quickdraw-text-diffusion,
  title={Text-Conditional QuickDraw Diffusion Model},
  author={Your Name},
  year={2024},
  howpublished={\url{https://huggingface.co/YOUR_USERNAME/quickdraw-text-diffusion}}
}
```

## License

MIT License

## Acknowledgments

- Google QuickDraw dataset
- OpenAI CLIP
- DDPM paper: "Denoising Diffusion Probabilistic Models" (Ho et al., 2020)
- Classifier-free guidance: "Classifier-Free Diffusion Guidance" (Ho & Salimans, 2022)