Spaces:
Sleeping
Sleeping
| import os | |
| import re | |
| import torch | |
| import gradio as gr | |
| from transformers import AutoTokenizer, AutoFeatureExtractor, VisionEncoderDecoderModel | |
# Pattern marking where generated text should be cut: the first run of two or
# more consecutive full stops (e.g. ".." or "...."). Everything from there on
# is discarded.
regex_pattern = "[.]{2,}"
# Compiled once at import time instead of on every post_process call.
_truncate_re = re.compile(regex_pattern)


def post_process(text):
    """Clean up a decoded model output for display.

    Strips surrounding whitespace and then drops everything from the first
    run of two or more full stops onward. Whitespace left at the cut point is
    stripped as well.

    Args:
        text: The decoded string. Non-string values (e.g. ``None``) are
            returned unchanged rather than raising.

    Returns:
        The truncated, stripped string, or ``text`` itself if it is not a str.
    """
    # Previously a broad ``except Exception`` caught the AttributeError that
    # non-string input raises and printed it; guard explicitly instead.
    if not isinstance(text, str):
        return text
    head = _truncate_re.split(text.strip(), maxsplit=1)[0]
    # e.g. "A plot .. junk" -> "A plot " before this final strip.
    return head.strip()
def set_example_image(example: list):
    """Return the first value of a clicked example sample.

    Args:
        example: The sample passed by the examples component; only its first
            element is used.

    Returns:
        ``example[0]``, which is routed to the input image component.
    """
    first_value = example[0]
    return first_value
def predict(image, max_length=64, num_beams=4):
    """Generate a plot description for a poster image.

    Uses the module-level ``feature_extractor``, ``model``, ``tokenizer`` and
    ``device``.

    Args:
        image: The poster from the ``gr.Image(type='numpy')`` input, or
            ``None`` when no image has been provided.
        max_length: Maximum generated sequence length, passed to
            ``model.generate``.
        num_beams: Beam-search width, passed to ``model.generate``.

    Returns:
        The post-processed text of the first generated sequence, or an empty
        string when ``image`` is ``None``.
    """
    # Clicking Submit with an empty image component sends None, which the
    # feature extractor cannot process; clear the output instead of crashing.
    if image is None:
        return ""
    pixel_values = feature_extractor(images=image, return_tensors="pt").pixel_values
    pixel_values = pixel_values.to(device)
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        output_ids = model.generate(
            pixel_values,
            max_length=max_length,
            num_beams=num_beams,
            return_dict_in_generate=True,
        ).sequences
    preds = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
    # A single image is submitted, so only the first decoded sequence is used.
    return post_process(preds[0])
model_name_or_path = "deepklarity/poster2plot"
# Use the first CUDA GPU when available, otherwise run on CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Load model.
model = VisionEncoderDecoderModel.from_pretrained(model_name_or_path)
model.to(device)
print("Loaded model")

# The image preprocessor matches the checkpoint's encoder.
feature_extractor = AutoFeatureExtractor.from_pretrained(model.encoder.name_or_path)
print("Loaded feature_extractor")

# The tokenizer matches the checkpoint's decoder.
tokenizer = AutoTokenizer.from_pretrained(model.decoder.name_or_path, use_fast=True)
if model.decoder.name_or_path == "gpt2":
    # GPT-2's tokenizer ships without a pad token; reuse EOS for padding.
    tokenizer.pad_token = tokenizer.eos_token
print("Loaded tokenizer")

# Example posters: the top-level files of ./examples (subdirectories are not
# scanned; a missing directory yields no examples). Sorted so the gallery
# order is stable across runs (os.walk order is arbitrary), and dotfiles such
# as .gitkeep/.DS_Store are skipped because they are not images.
examples = [
    [f"examples/{filename}"]
    for filename in sorted(next(os.walk('examples'), (None, None, []))[2])
    if not filename.startswith(".")
]
print(f"Loaded {len(examples)} example images")
# Build the Gradio UI. The layout is defined purely by the nesting of the
# Column/Row context managers, so statement order here matters.
with gr.Blocks() as poster2plot:
    with gr.Column():
        # Title row.
        with gr.Row():
            gr.Markdown("# Poster2Plot: Upload a Movie/T.V show poster to generate a plot")
        # Main row: image input + submit button on the left, generated plot on the right.
        with gr.Row():
            with gr.Column():
                with gr.Row():
                    # type='numpy' delivers the uploaded image to predict() as a NumPy array.
                    input_image = gr.Image(label='Input Image', type='numpy')
                with gr.Row():
                    submit_button = gr.Button(value="Submit", variant='primary')
            with gr.Column():
                plot = gr.Textbox(label="Plot")
        # Clickable gallery of the example posters collected earlier in `examples`.
        with gr.Row():
            example_images = gr.Dataset(components=[input_image], samples=examples)
        with gr.Row():
            gr.Markdown("Made by: [dk-crazydiv](https://twitter.com/kartik_godawat) and [dsr](https://twitter.com/dsr_ai)")
    # Event wiring (registered inside the Blocks context).
    # Submit: run predict() on the current image and show the result in the Plot textbox.
    submit_button.click(fn=predict, inputs=[input_image], outputs=[plot])
    # Example click: set_example_image() returns element [0] of the clicked
    # sample (presumably the example's file path) and loads it into input_image.
    example_images.click(fn=set_example_image, inputs=[example_images], outputs=[input_image])
# Start the Gradio app.
poster2plot.launch()