# quickdraw-text-diffusion/text_encoder.py
"""CLIP Text Encoder for text-conditional diffusion."""
import torch
import torch.nn as nn
from transformers import CLIPTextModel, CLIPTokenizer


class CLIPTextEncoder(nn.Module):
    """Wrapper around CLIP's text encoder for diffusion conditioning.

    CLIP is trained to map images and text into a shared embedding space,
    which makes its text embeddings a natural conditioning signal for
    text-to-image diffusion.
    """

    def __init__(self, model_name="openai/clip-vit-base-patch32", freeze=True):
        super().__init__()
        self.tokenizer = CLIPTokenizer.from_pretrained(model_name)
        self.text_model = CLIPTextModel.from_pretrained(model_name)
        if freeze:
            for param in self.text_model.parameters():
                param.requires_grad = False
            self.text_model.eval()  # keep the frozen encoder in eval mode
        self.embedding_dim = self.text_model.config.hidden_size  # 512 for the base model

    def forward(self, text_prompts):
        """Encode text prompts to embeddings.

        Args:
            text_prompts: A list of strings or a single string.

        Returns:
            Text embeddings of shape [batch_size, embedding_dim].
        """
        if isinstance(text_prompts, str):
            text_prompts = [text_prompts]
        tokens = self.tokenizer(
            text_prompts,
            padding=True,
            truncation=True,
            max_length=77,  # CLIP's maximum context length
            return_tensors="pt",
        ).to(self.text_model.device)
        # Skip building the autograd graph when the encoder is frozen/in eval mode.
        with torch.set_grad_enabled(self.text_model.training):
            outputs = self.text_model(**tokens)
            embeddings = outputs.pooler_output  # [batch_size, embedding_dim]
        return embeddings

    def encode_batch(self, text_prompts):
        """Convenience method for batch encoding."""
        return self.forward(text_prompts)

    @property
    def device(self):
        """Device the underlying CLIP text model lives on."""
        return self.text_model.device
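

# --- Usage sketch (not part of the original module) ---
# A minimal example of how this encoder is typically driven: encode a batch
# of prompts and inspect the output shape. The prompt strings below are
# illustrative placeholders, not names taken from the original repo.
if __name__ == "__main__":
    encoder = CLIPTextEncoder(freeze=True)
    prompts = ["a sketch of a cat", "a sketch of a bicycle"]
    embeddings = encoder(prompts)  # forward() accepts a list or a single string
    print(embeddings.shape)  # torch.Size([2, 512]) for the base model
    # These embeddings would then be fed to the diffusion model as its
    # conditioning signal; the exact wiring (projection, concatenation with
    # the timestep embedding, etc.) depends on the denoiser architecture.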