"""CLIP Text Encoder for text-conditional diffusion."""

import torch
import torch.nn as nn
from transformers import CLIPTextModel, CLIPTokenizer


class CLIPTextEncoder(nn.Module):
    """Wrapper around the CLIP text encoder for diffusion conditioning.

    CLIP is trained to map images and text into a shared embedding space,
    which makes its text embeddings a natural conditioning signal for
    text-to-image diffusion models.
    """

    def __init__(self, model_name="openai/clip-vit-base-patch32", freeze=True):
        super().__init__()
        self.tokenizer = CLIPTokenizer.from_pretrained(model_name)
        self.text_model = CLIPTextModel.from_pretrained(model_name)

        if freeze:
            # Freeze the pretrained encoder so it is not updated while the
            # diffusion model trains.
            for param in self.text_model.parameters():
                param.requires_grad = False

        self.embedding_dim = self.text_model.config.hidden_size

    def forward(self, text_prompts):
        """Encode text prompts to embeddings.

        Args:
            text_prompts: List of strings or a single string.

        Returns:
            Text embeddings of shape [batch_size, embedding_dim].
        """
        if isinstance(text_prompts, str):
            text_prompts = [text_prompts]

        # CLIP's text encoder has a fixed context length of 77 tokens.
        tokens = self.tokenizer(
            text_prompts,
            padding=True,
            truncation=True,
            max_length=77,
            return_tensors="pt",
        ).to(self.text_model.device)

        # Skip gradient computation unless the encoder is in training mode.
        with torch.set_grad_enabled(self.text_model.training):
            outputs = self.text_model(**tokens)
            # pooler_output is the EOS-token embedding, one vector per
            # prompt; for per-token cross-attention conditioning, use
            # outputs.last_hidden_state instead.
            embeddings = outputs.pooler_output

        return embeddings

    def encode_batch(self, text_prompts):
        """Convenience method for batch encoding."""
        return self.forward(text_prompts)

    @property
    def device(self):
        """Device the underlying text model lives on."""
        return self.text_model.device
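

# A minimal sketch of one common way to consume the pooled CLIP embedding
# in a diffusion model: project it to the UNet's timestep-embedding width
# and add it to the timestep embedding. The class name and `time_embed_dim`
# are illustrative assumptions, not a fixed API; cross-attention over
# `last_hidden_state` is the other common pathway.
class TextConditioningHead(nn.Module):
    """Projects pooled CLIP text embeddings onto the timestep embedding."""

    def __init__(self, text_dim, time_embed_dim):
        super().__init__()
        self.proj = nn.Sequential(
            nn.Linear(text_dim, time_embed_dim),
            nn.SiLU(),
            nn.Linear(time_embed_dim, time_embed_dim),
        )

    def forward(self, text_emb, time_emb):
        # Additive conditioning: both tensors are [batch, time_embed_dim].
        return time_emb + self.proj(text_emb)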
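

# Minimal usage sketch, assuming the pretrained weights are available from
# the Hugging Face hub; the prompts and printed shapes are illustrative.
if __name__ == "__main__":
    encoder = CLIPTextEncoder(freeze=True)

    emb = encoder(["a photo of a cat", "an oil painting of a lighthouse"])
    # For "openai/clip-vit-base-patch32" the hidden size is 512, so this
    # prints torch.Size([2, 512]).
    print(emb.shape)

    # For classifier-free guidance, the unconditional embedding is usually
    # obtained by encoding the empty string.
    uncond = encoder("")
    print(uncond.shape)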