Spaces:
Sleeping
Sleeping
| import os | |
| import pandas as pd | |
| import torch | |
| import gradio as gr | |
| from model import DistMult | |
| from PIL import Image | |
| from torchvision import transforms | |
| import json | |
| from tqdm import tqdm | |
# Per-channel statistics passed to transforms.Normalize after ToTensor().
# The values are the ImageNet mean/std used by torchvision's pretrained
# models; presumably the image encoder inside DistMult expects the same
# statistics -- confirm.
_DEFAULT_IMAGE_TENSOR_NORMALIZATION_MEAN = [
    0.485,  # R
    0.456,  # G
    0.406,  # B
]
_DEFAULT_IMAGE_TENSOR_NORMALIZATION_STD = [
    0.229,  # R
    0.224,  # G
    0.225,  # B
]
def generate_target_list(data, entity2id, *, show_progress=True):
    """Return the distinct entity indices of all image -> id triple tails.

    Selects the rows of ``data`` whose head datatype is ``"image"`` and whose
    tail datatype is ``"id"``. Each tail value is mapped through ``entity2id``,
    and the distinct indices are returned in first-seen order.

    Args:
        data: DataFrame with ``datatype_h``, ``datatype_t`` and ``t`` columns.
        entity2id: Mapping from an integer-valued id string (e.g. ``"42"``)
            to an integer entity index.
        show_progress: If True (the default, matching earlier behaviour),
            iterate under a ``tqdm`` progress bar.

    Returns:
        ``torch.long`` tensor of shape ``(num_categories, 1)``. The shape is
        ``(0, 1)`` when no row matches.

    Raises:
        ValueError: if a tail value cannot be parsed as a number.
        KeyError: if a normalised tail id is missing from ``entity2id``.
    """
    mask = (data["datatype_h"] == "image") & (data["datatype_t"] == "id")
    tails = list(data.loc[mask, "t"])
    if show_progress:
        tails = tqdm(tails)
    # Tail values may arrive as floats or numeric strings, so they are
    # normalised to the integer-string form used as entity2id keys.
    # dict.fromkeys deduplicates in O(n) and keeps first-seen order, which
    # matches the previous list-based O(n^2) version.
    categories = dict.fromkeys(entity2id[str(int(float(item)))] for item in tails)
    return torch.tensor(list(categories), dtype=torch.long).unsqueeze(-1)
# Load the entity vocabulary, the triple dataset and the display names, then
# build the model used by `evaluate`.


def _load_json(path):
    """Read and return the UTF-8 JSON document at ``path``, closing the file."""
    with open(path, 'r', encoding='utf-8') as f:
        return json.load(f)


# id string -> integer entity index (JSON object keys are always strings).
entity2id = _load_json('entity2id_subtree.json')
# Reverse lookup: integer entity index -> id string.
id2entity = {v: k for k, v in entity2id.items()}
datacsv = pd.read_csv('dataset_subtree.csv', low_memory=False)
num_ent_id = len(entity2id)
# Candidate classes: distinct entity indices of the tails of image -> id triples.
target_list = generate_target_list(datacsv, entity2id)
# id string -> display name shown as the predicted label.
overall_id_to_name = _load_json('overall_id_to_name.json')

# NOTE(review): no checkpoint is loaded here. Unless DistMult's constructor
# restores trained weights itself, predictions come from freshly initialised
# parameters -- confirm.
model = DistMult(num_ent_id, target_list, torch.device('cpu'))
model.eval()  # evaluation mode: changes dropout / batch-norm behaviour, not autograd
# Preprocessing for every uploaded image. It is built once at import time
# instead of on every request.
_EVAL_TRANSFORM = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize((448, 448)),
    transforms.ToTensor(),
    transforms.Normalize(_DEFAULT_IMAGE_TENSOR_NORMALIZATION_MEAN, _DEFAULT_IMAGE_TENSOR_NORMALIZATION_STD),
])


def evaluate(img):
    """Predict the species shown in ``img`` and return it as a Gradio label.

    Args:
        img: Image from the Gradio image input. Anything accepted by
            ``transforms.ToPILImage`` works; presumably this is an
            H x W x C uint8 numpy array (TODO confirm).

    Returns:
        ``{species_name: 1.0}``, the dict format rendered by Gradio's
        ``"label"`` output. The confidence is always reported as 1.0.
    """
    h = _EVAL_TRANSFORM(img).unsqueeze(0)  # add batch dim -> (1, C, 448, 448)
    # Relation id 3 is hard-coded. It is presumably the image -> species-id
    # relation of the training triples -- confirm.
    r = torch.tensor([[3]])
    # Inference only: skip building an autograd graph on every request.
    with torch.no_grad():
        outputs = model.forward_ce(h, r, triple_type=('image', 'id'))
    y_pred = outputs.argmax(-1).cpu()
    # Map the winning candidate position back to its entity index, then to
    # its id string and finally to a display name.
    pred_label = target_list[y_pred].item()
    species_label = overall_id_to_name[str(id2entity[pred_label])]
    return {species_label: 1.0}
if __name__ == '__main__':
    # Build and serve the Gradio demo: one uploaded image in, one species
    # label out (rendered by Gradio's "label" output).
    # NOTE(review): `gr.inputs.*` is Gradio's legacy component namespace
    # (deprecated in 3.x, removed in 4.0). If the Space pins Gradio >= 3,
    # prefer `gr.Image(shape=(200, 200))`.
    # NOTE(review): `shape` makes Gradio crop/resize the upload to 200x200.
    # `evaluate` then upscales it to 448x448, so confirm the intermediate
    # downscale is intended.
    image_input = gr.inputs.Image(shape=(200, 200))
    species_model = gr.Interface(
        fn=evaluate,
        inputs=image_input,
        outputs="label",
        title='Species Classification',
        description='Species Classification',
        article='Species Classification',
    )
    species_model.launch(share=True, debug=True)