| import contextlib |
| import functools |
| import json |
| import logging |
| import os |
| import time |
| import urllib.request |
|
|
| import gradio as gr |
| import open_clip |
| import PIL.Image |
| import torch |
| import torch.nn.functional as F |
|
|
|
|
# Example metadata (image ids + suggested prompts) published alongside the
# original LiT demo; IMG_URL_FMT is filled with an image id from that JSON.
INFO_URL = 'https://google-research.github.io/vision_transformer/lit/data/images/info.json'
IMG_URL_FMT = 'https://google-research.github.io/vision_transformer/lit/data/images/{}.jpg'
|
|
|
|
@contextlib.contextmanager
def timed(name):
    """Context manager that logs the wall-clock duration of its body.

    Logs at INFO level even when the body raises; the exception itself is
    propagated unchanged.
    """
    start = time.monotonic()
    try:
        yield
    finally:
        elapsed = time.monotonic() - start
        logging.info('Timed %s: %.1f secs', name, elapsed)
|
|
|
|
@functools.cache
def load_model(name='hf-hub:timm/ViT-SO400M-14-SigLIP-384'):
    """Load and cache the open_clip model, preprocess transform and tokenizer.

    Cached with functools.cache, so repeated calls (and the warm-up call at
    startup) reuse the same objects.

    Args:
      name: open_clip pretrained-model identifier.

    Returns:
      Tuple of (model, preprocess, tokenizer).
    """
    # The `timed` context already logs the elapsed time; the original also
    # kept a manual time.time()/logging.info pair that logged the same
    # duration twice — removed as redundant.
    with timed('loading model, preprocess, tokenizer'):
        model, preprocess = open_clip.create_model_from_pretrained(name)
        tokenizer = open_clip.get_tokenizer(name)
    return model, preprocess, tokenizer
|
|
|
|
def generate_answers(image_path, prompts):
    """Score each comma-separated prompt against an image with SigLIP.

    Args:
      image_path: path to an image file on disk.
      prompts: comma-separated prompt strings, e.g. 'a cat, a dog'.

    Returns:
      List of (prompt, probability) tuples, probabilities rounded to 3 digits.
    """
    model, preprocess, tokenizer = load_model()

    # Tolerate missing/extra whitespace around commas and drop empty entries;
    # the original split(', ') silently produced one merged prompt for 'a,b'.
    prompt_list = [p.strip() for p in prompts.split(',') if p.strip()]

    # NOTE(review): torch.cuda.amp.autocast is deprecated in favor of
    # torch.amp.autocast('cuda'); kept as-is for compatibility with the
    # torch version this demo was pinned to — confirm before migrating.
    with torch.no_grad(), torch.cuda.amp.autocast():
        logging.info('Opening image "%s"', image_path)
        with timed(f'opening image "{image_path}"'):
            # Context manager closes the underlying file handle (the original
            # leaked it); .load() forces decoding before the file is closed.
            with PIL.Image.open(image_path) as pil_image:
                pil_image.load()
                image = preprocess(pil_image).unsqueeze(0)
        with timed('image features'):
            image_features = model.encode_image(image)

        with timed('text features'):
            text = tokenizer(prompt_list, context_length=model.context_length)
            text_features = model.encode_text(text)
        image_features = F.normalize(image_features, dim=-1)
        text_features = F.normalize(text_features, dim=-1)

        # SigLIP scores each (image, text) pair with an independent sigmoid
        # (scale + bias), not a softmax over prompts.
        exp, bias = model.logit_scale.exp(), model.logit_bias
        text_probs = torch.sigmoid(image_features @ text_features.T * exp + bias)
    return list(zip(prompt_list, [round(p.item(), 3) for p in text_probs[0]]))
|
|
|
|
def create_app():
    """Build the gradio Blocks demo (not yet launched).

    Fetches the example metadata (image ids + prompts) from INFO_URL and
    wires an image + prompts input through generate_answers.

    Returns:
      The gradio Blocks object; caller launches it.
    """
    # Close the HTTP response deterministically and bound the wait — the
    # original never closed the response object and could hang indefinitely.
    with urllib.request.urlopen(INFO_URL, timeout=30) as resp:
        info = json.load(resp)

    with gr.Blocks() as demo:

        gr.Markdown('Minimal gradio clone of [lit-tuning-demo](https://google-research.github.io/vision_transformer/lit/)')
        gr.Markdown('Using `open_clip` implementation of SigLIP model `timm/ViT-SO400M-14-SigLIP-384`')

        with gr.Row():
            image = gr.Image(label='input_image', type='filepath')
            with gr.Column():
                prompts = gr.Textbox(label='prompts')
                answer = gr.Textbox(label='answer')
                run = gr.Button('Run')

        # Pre-canned (image URL, prompts) pairs from the original LiT demo.
        gr.Examples(
            examples=[
                [IMG_URL_FMT.format(ex['id']), ex['prompts']]
                for ex in info
            ],
            inputs=[image, prompts],
            outputs=[answer],
        )

        run.click(fn=generate_answers, inputs=[image, prompts], outputs=[answer])

    return demo
|
|
|
|
if __name__ == "__main__":


    logging.basicConfig(level=logging.INFO,
                        format='%(asctime)s - %(levelname)s - %(message)s')


    # NOTE(review): this dumps the entire process environment to the logs,
    # which can leak secrets (API keys, tokens) in hosted deployments —
    # confirm this is intentional debugging output before shipping.
    for k, v in os.environ.items():
        logging.info('environ["%s"] = %r', k, v)


    # Warm the functools.cache so the first user request does not pay the
    # model download/initialization cost.
    _ = load_model()


    # queue() enables request queuing so long-running inference calls do not
    # time out the HTTP connection.
    create_app().queue().launch()
|
|