"""
Model loading and registration logic for OpenSight Deepfake Detection Playground.
Handles ONNX, HuggingFace, and Gradio API model registration and metadata.
"""
import logging

import numpy as np
import torch
from PIL import Image
from transformers import AutoFeatureExtractor, AutoModelForImageClassification

from utils.onnx_helpers import infer_onnx_model, postprocess_onnx_output, preprocess_onnx_input
from utils.onnx_model_loader import load_onnx_model_and_preprocessor, get_onnx_model_from_cache
from utils.registry import register_model, MODEL_REGISTRY, ModelEntry
from utils.utils import (
    infer_gradio_api,
    postprocess_gradio_api,
    postprocess_logits,
    preprocess_gradio_api,
    preprocess_resize_256,
)
|
|
|
|
| |
# Model key -> HuggingFace Hub repository id.
# Insertion order matters: register_all_models iterates this mapping in order.
# Paths containing "ONNX" are routed to the ONNX wrapper by register_all_models.
MODEL_PATHS = {
    "model_1": "LPX55/detection-model-1-ONNX",
    "model_2": "LPX55/detection-model-2-ONNX",
    "model_3": "LPX55/detection-model-3-ONNX",
    "model_4": "cmckinle/sdxl-flux-detector_v1.1",
    "model_5": "LPX55/detection-model-5-ONNX",
    "model_6": "LPX55/detection-model-6-ONNX",
    "model_7": "LPX55/detection-model-7-ONNX",
    "model_8": "aiwithoutborders-xyz/CommunityForensics-DeepfakeDet-ViT",
}

# Model key -> class labels exactly as each upstream model names them.
# NOTE(review): presumably listed in output-index order; model_5 puts the
# "real" label first, unlike the others. Confirm against each model config.
CLASS_NAMES = {
    "model_1": ["artificial", "real"],
    "model_2": ["AI Image", "Real Image"],
    "model_3": ["artificial", "human"],
    "model_4": ["AI", "Real"],
    "model_5": ["Realism", "Deepfake"],
    "model_6": ["ai_gen", "human"],
    "model_7": ["Fake", "Real"],
    "model_8": ["Fake", "Real"],
}

# Shared cache passed to get_onnx_model_from_cache by every ONNXModelWrapper.
_onnx_model_cache = {}
|
|
def register_model_with_metadata(model_id, model, preprocess, postprocess, class_names,
                                 display_name, contributor, model_path,
                                 architecture=None, dataset=None):
    """Wrap a model plus its metadata in a ``ModelEntry`` and store it in ``MODEL_REGISTRY``.

    Any entry already registered under ``model_id`` is replaced.

    Args:
        model_id: Registry key, e.g. ``"model_1"``.
        model: Inference callable.
        preprocess: Input preprocessing callable.
        postprocess: Output postprocessing callable.
        class_names: Class labels for the model.
        display_name: Human-readable name.
        contributor: Credited author of the model.
        model_path: Hub path / identifier the model came from.
        architecture: Optional architecture label.
        dataset: Optional training-dataset label.
    """
    metadata = {
        "display_name": display_name,
        "contributor": contributor,
        "model_path": model_path,
        "architecture": architecture,
        "dataset": dataset,
    }
    MODEL_REGISTRY[model_id] = ModelEntry(model, preprocess, postprocess, class_names, **metadata)
|
|
| class ONNXModelWrapper: |
| def __init__(self, hf_model_id): |
| self.hf_model_id = hf_model_id |
| self._session = None |
| self._preprocessor_config = None |
| self._model_config = None |
|
|
| def load(self): |
| if self._session is None: |
| self._session, self._preprocessor_config, self._model_config = get_onnx_model_from_cache( |
| self.hf_model_id, _onnx_model_cache, load_onnx_model_and_preprocessor |
| ) |
|
|
| def __call__(self, image_np): |
| self.load() |
| return infer_onnx_model(self.hf_model_id, image_np, self._model_config) |
|
|
| def preprocess(self, image: Image.Image): |
| self.load() |
| return preprocess_onnx_input(image, self._preprocessor_config) |
|
|
| def postprocess(self, onnx_output: dict, class_names_from_registry: list): |
| self.load() |
| return postprocess_onnx_output(onnx_output, self._model_config) |
|
|
| |
|
|
| def register_all_models(MODEL_PATHS, CLASS_NAMES, device, infer_onnx_model, preprocess_onnx_input, postprocess_onnx_output): |
| for model_key, hf_model_path in MODEL_PATHS.items(): |
| model_num = model_key.replace("model_", "").upper() |
| contributor = "Unknown" |
| architecture = "Unknown" |
| dataset = "TBA" |
| current_class_names = CLASS_NAMES.get(model_key, []) |
| if "ONNX" in hf_model_path: |
| onnx_wrapper_instance = ONNXModelWrapper(hf_model_path) |
| if model_key == "model_1": |
| contributor = "haywoodsloan" |
| architecture = "SwinV2" |
| dataset = "Mixed" |
| elif model_key == "model_2": |
| contributor = "Heem2" |
| architecture = "ViT" |
| dataset = "Mixed" |
| elif model_key == "model_3": |
| contributor = "Organika" |
| architecture = "VIT" |
| dataset = "SDXL" |
| elif model_key == "model_5": |
| contributor = "prithivMLmods" |
| architecture = "VIT" |
| elif model_key == "model_6": |
| contributor = "ideepankarsharma2003" |
| architecture = "SWINv1" |
| dataset = "SDXL, Midjourney" |
| elif model_key == "model_7": |
| contributor = "date3k2" |
| architecture = "VIT" |
| display_name_parts = [model_num] |
| if architecture and architecture not in ["Unknown"]: |
| display_name_parts.append(architecture) |
| if dataset and dataset not in ["TBA"]: |
| display_name_parts.append(dataset) |
| display_name = "-".join(display_name_parts) + "_ONNX" |
| register_model_with_metadata( |
| model_id=model_key, |
| model=onnx_wrapper_instance, |
| preprocess=onnx_wrapper_instance.preprocess, |
| postprocess=onnx_wrapper_instance.postprocess, |
| class_names=current_class_names, |
| display_name=display_name, |
| contributor=contributor, |
| model_path=hf_model_path, |
| architecture=architecture, |
| dataset=dataset |
| ) |
| elif model_key == "model_8": |
| contributor = "aiwithoutborders-xyz" |
| architecture = "ViT" |
| dataset = "Massive" |
| display_name_parts = [model_num] |
| if architecture and architecture not in ["Unknown"]: |
| display_name_parts.append(architecture) |
| if dataset and dataset not in ["TBA"]: |
| display_name_parts.append(dataset) |
| display_name = "-".join(display_name_parts) |
| register_model_with_metadata( |
| model_id=model_key, |
| model=infer_gradio_api, |
| preprocess=preprocess_gradio_api, |
| postprocess=postprocess_gradio_api, |
| class_names=current_class_names, |
| display_name=display_name, |
| contributor=contributor, |
| model_path=hf_model_path, |
| architecture=architecture, |
| dataset=dataset |
| ) |
| elif model_key == "model_4": |
| contributor = "cmckinle" |
| architecture = "VIT" |
| dataset = "SDXL, FLUX" |
| display_name_parts = [model_num] |
| if architecture and architecture not in ["Unknown"]: |
| display_name_parts.append(architecture) |
| if dataset and dataset not in ["TBA"]: |
| display_name_parts.append(dataset) |
| display_name = "-".join(display_name_parts) |
| current_processor = AutoFeatureExtractor.from_pretrained(hf_model_path, device=device) |
| model_instance = AutoModelForImageClassification.from_pretrained(hf_model_path).to(device) |
| preprocess_func = preprocess_resize_256 |
| postprocess_func = postprocess_logits |
| def custom_infer(image, processor_local=current_processor, model_local=model_instance): |
| inputs = processor_local(image, return_tensors="pt").to(device) |
| with torch.no_grad(): |
| outputs = model_local(**inputs) |
| return outputs |
| model_instance = custom_infer |
| register_model_with_metadata( |
| model_id=model_key, |
| model=model_instance, |
| preprocess=preprocess_func, |
| postprocess=postprocess_func, |
| class_names=current_class_names, |
| display_name=display_name, |
| contributor=contributor, |
| model_path=hf_model_path, |
| architecture=architecture, |
| dataset=dataset |
| ) |
| else: |
| pass |
|
|