|
|
import torchvision.transforms as T |
|
|
|
|
|
|
|
|
|
|
|
def get_image_transform(
    image_size: int,
    center_crop: bool = False,
    interpolation: T.InterpolationMode = T.InterpolationMode.BILINEAR
):
    """Build the image preprocessing pipeline as a ``T.Compose``.

    The pipeline resizes the input, converts it to RGB, turns it into a
    tensor and normalizes every channel with mean 0.5 and std 0.5.

    Args:
        image_size: Target size passed to the resize/crop transforms.
        center_crop: If true, apply ``T.Resize(image_size)`` followed by
            ``T.CenterCrop(image_size)``. Otherwise resize directly to
            ``(image_size, image_size)``.
        interpolation: Interpolation mode forwarded to ``T.Resize``.

    Returns:
        A ``T.Compose`` holding the assembled transforms.
    """
    # Function-scope import keeps this block self-contained.
    from operator import methodcaller

    if center_crop:
        crop = [
            T.Resize(image_size, interpolation=interpolation),
            T.CenterCrop(image_size),
        ]
    else:
        crop = [
            T.Resize((image_size, image_size), interpolation=interpolation),
        ]

    # A lambda cannot be pickled, which makes the whole pipeline
    # unpicklable (e.g. when a dataset holding it is sent to worker
    # processes started with "spawn"). methodcaller is picklable and
    # performs the same ``x.convert("RGB")`` call.
    to_rgb = methodcaller("convert", "RGB")

    return T.Compose(crop + [
        T.Lambda(to_rgb),
        T.ToTensor(),
        T.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5], inplace=True),
    ])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import torchvision.transforms as T |
|
|
|
|
|
from PIL import Image |
|
|
|
|
|
def get_image_transform(
    image_size: int,
    center_crop: bool = False,
    interpolation: T.InterpolationMode = T.InterpolationMode.BILINEAR
):
    """Build the image preprocessing pipeline as a ``T.Compose``.

    The pipeline resizes the input, converts it to RGB, turns it into a
    tensor and normalizes every channel with mean 0.5 and std 0.5.

    Args:
        image_size: Target size passed to the resize/crop transforms.
        center_crop: If true, apply ``T.Resize(image_size)`` followed by
            ``T.CenterCrop(image_size)``. Otherwise resize directly to
            ``(image_size, image_size)``.
        interpolation: Interpolation mode forwarded to ``T.Resize``.

    Returns:
        A ``T.Compose`` holding the assembled transforms.
    """
    if center_crop:
        crop = [
            T.Resize(image_size, interpolation=interpolation),
            T.CenterCrop(image_size),
        ]
    else:
        crop = [
            T.Resize((image_size, image_size), interpolation=interpolation),
        ]

    return T.Compose(crop + [
        # Use the module-level helper instead of a lambda: lambdas cannot
        # be pickled, which would make the whole pipeline unpicklable
        # (e.g. for worker processes started with "spawn").
        T.Lambda(_convert_to_rgb),
        T.ToTensor(),
        T.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5], inplace=True),
    ])
|
|
|
|
|
|
|
|
def _convert_to_rgb(image: Image.Image) -> Image.Image:
    """Return the result of ``image.convert("RGB")``.

    Defined as a named module-level function so it can be handed to
    ``T.Lambda`` by reference.
    """
    rgb_image = image.convert("RGB")
    return rgb_image
|
|
|
|
|
|
|
|
def get_image_transform_fix(
    image_size: int,
    center_crop: bool = False,
    interpolation: T.InterpolationMode = T.InterpolationMode.BILINEAR
):
    """Assemble the image preprocessing pipeline using a named RGB helper.

    Args:
        image_size: Target size passed to the resize/crop transforms.
        center_crop: When true, ``T.Resize(image_size)`` is followed by
            ``T.CenterCrop(image_size)``; when false the image is resized
            straight to ``(image_size, image_size)``.
        interpolation: Interpolation mode forwarded to ``T.Resize``.

    Returns:
        A ``T.Compose`` of the resize step(s), RGB conversion via
        ``_convert_to_rgb``, ``T.ToTensor`` and a 0.5/0.5 ``T.Normalize``.
    """
    # Geometry stage: either a direct square resize or resize + crop.
    if not center_crop:
        steps = [T.Resize((image_size, image_size), interpolation=interpolation)]
    else:
        steps = [
            T.Resize(image_size, interpolation=interpolation),
            T.CenterCrop(image_size),
        ]

    # Colour / tensor stage, appended in the order it must run.
    steps.append(T.Lambda(_convert_to_rgb))
    steps.append(T.ToTensor())
    steps.append(
        T.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5], inplace=True)
    )

    return T.Compose(steps)
|
|
|
|
|
def get_text_tokenizer(context_length: int):
    """Create a ``SimpleTokenizer`` for the given context length.

    Args:
        context_length: Forwarded as the ``context_length`` keyword
            argument of ``SimpleTokenizer``.

    Returns:
        The newly constructed ``SimpleTokenizer`` instance.
    """
    # NOTE(review): SimpleTokenizer is not imported in the visible part of
    # this file; confirm it is in scope at module level.
    tokenizer = SimpleTokenizer(context_length=context_length)
    return tokenizer