"""Virtual try-on demo.

MediaPipe PoseLandmarker utilities (pose-keypoint drawing, naive torso
clothing warp) plus a Gradio front-end that forwards a human/garment
image pair to the remote "franciszzj/Leffa" Space for generation.
"""

import os

import cv2
import gradio as gr
import mediapipe as mp
import numpy as np
from PIL import Image
from gradio_client import Client, handle_file
from mediapipe.tasks import python
from mediapipe.tasks.python import vision

# Paths (resolved relative to this file so the app works from any CWD)
BASE_DIR = os.path.dirname(__file__)
EXAMPLE_PATH = os.path.join(BASE_DIR, "example")
MODEL_PATH = os.path.join(BASE_DIR, "pose_landmarker.task")

# Example images — sorted so the gallery order is deterministic across runs
garm_list = sorted(os.listdir(os.path.join(EXAMPLE_PATH, "cloth")))
garm_list_path = [os.path.join(EXAMPLE_PATH, "cloth", garm) for garm in garm_list]

human_list = sorted(os.listdir(os.path.join(EXAMPLE_PATH, "human")))
human_list_path = [os.path.join(EXAMPLE_PATH, "human", human) for human in human_list]

# Validate model file early so the error is clear
if not os.path.exists(MODEL_PATH):
    raise FileNotFoundError(
        f"No se encontró el archivo del modelo: {MODEL_PATH}. "
        "Debes subir 'pose_landmarker.task' junto a app.py."
    )

# MediaPipe PoseLandmarker setup (single shared detector, IMAGE mode)
base_options = python.BaseOptions(model_asset_path=MODEL_PATH)
options = vision.PoseLandmarkerOptions(
    base_options=base_options,
    running_mode=vision.RunningMode.IMAGE,
)
pose_landmarker = vision.PoseLandmarker.create_from_options(options)

# Landmark indices used by MediaPipe Pose
POSE_IDX = {
    "left_shoulder": 11,
    "right_shoulder": 12,
    "left_hip": 23,
    "right_hip": 24,
}


def get_pose_result(image_bgr):
    """Run the shared PoseLandmarker on a BGR (OpenCV) image.

    Args:
        image_bgr: HxWx3 uint8 array in BGR channel order.

    Returns:
        The PoseLandmarker detection result.
    """
    # MediaPipe expects RGB; OpenCV delivers BGR.
    image_rgb = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)
    mp_image = mp.Image(image_format=mp.ImageFormat.SRGB, data=image_rgb)
    return pose_landmarker.detect(mp_image)


def detect_pose(image):
    """Detect pose keypoints and draw the torso landmarks on the image.

    Draws in place on `image` (green dots + white labels for the four
    POSE_IDX landmarks) and returns the same array.

    Args:
        image: HxWx3 uint8 BGR image.

    Returns:
        The annotated input image.
    """
    result = get_pose_result(image)
    if result.pose_landmarks:
        # Only the first detected person is annotated.
        landmarks = result.pose_landmarks[0]
        height, width, _ = image.shape
        for name, index in POSE_IDX.items():
            lm = landmarks[index]
            # Landmarks are normalized [0, 1]; scale to pixel coordinates.
            x, y = int(lm.x * width), int(lm.y * height)
            cv2.circle(image, (x, y), 5, (0, 255, 0), -1)
            cv2.putText(
                image,
                name,
                (x + 5, y - 5),
                cv2.FONT_HERSHEY_SIMPLEX,
                0.5,
                (255, 255, 255),
                1,
            )
    return image


def align_clothing(body_img, clothing_img):
    """Warp a clothing image onto the torso region of a body image.

    Uses the four POSE_IDX landmarks as the destination quadrilateral for
    a perspective warp of the whole clothing image. If the clothing image
    has an alpha channel it is alpha-blended; otherwise a fixed-weight
    additive blend is used.

    Args:
        body_img: HxWx3 uint8 BGR image of a person.
        clothing_img: uint8 BGR or BGRA clothing image.

    Returns:
        A new image (copy of `body_img`) with the warped clothing applied,
        or an unmodified copy when no pose is detected.
    """
    result = get_pose_result(body_img)
    output = body_img.copy()

    if result.pose_landmarks:
        landmarks = result.pose_landmarks[0]
        h, w, _ = output.shape

        def get_point(landmark_id):
            # Normalized landmark -> integer pixel coordinate.
            lm = landmarks[landmark_id]
            return int(lm.x * w), int(lm.y * h)

        left_shoulder = get_point(POSE_IDX["left_shoulder"])
        right_shoulder = get_point(POSE_IDX["right_shoulder"])
        left_hip = get_point(POSE_IDX["left_hip"])
        right_hip = get_point(POSE_IDX["right_hip"])

        # Destination box (torso region)
        dst_pts = np.array(
            [left_shoulder, right_shoulder, right_hip, left_hip],
            dtype=np.float32,
        )

        # Source box (clothing image corners)
        src_h, src_w = clothing_img.shape[:2]
        src_pts = np.array(
            [[0, 0], [src_w, 0], [src_w, src_h], [0, src_h]],
            dtype=np.float32,
        )

        # Perspective transform
        matrix = cv2.getPerspectiveTransform(src_pts, dst_pts)
        warped_clothing = cv2.warpPerspective(
            clothing_img,
            matrix,
            (w, h),
            borderMode=cv2.BORDER_TRANSPARENT,
        )

        # Alpha blending if clothing has transparency. The ndim check
        # guards against grayscale inputs, which have no channel axis.
        if clothing_img.ndim == 3 and clothing_img.shape[2] == 4:
            alpha = warped_clothing[:, :, 3] / 255.0
            for c in range(3):
                output[:, :, c] = (
                    (1 - alpha) * output[:, :, c]
                    + alpha * warped_clothing[:, :, c]
                )
        else:
            output = cv2.addWeighted(output, 0.8, warped_clothing, 0.5, 0)

    return output


def process_image(human_img_path, garm_img_path):
    """Generate a try-on image via the remote Leffa Space.

    Args:
        human_img_path: Filesystem path to the person image.
        garm_img_path: Filesystem path to the garment image.

    Returns:
        The generated try-on result as a PIL Image.
    """
    client = Client("franciszzj/Leffa")
    result = client.predict(
        src_image_path=handle_file(human_img_path),
        ref_image_path=handle_file(garm_img_path),
        ref_acceleration=False,
        step=30,
        scale=2.5,
        seed=42,
        vt_model_type="viton_hd",
        vt_garment_type="upper_body",
        vt_repaint=False,
        api_name="/leffa_predict_vt",
    )
    # The API returns (generated_image_path, ...); only the image is needed.
    generated_image_path = result[0]
    return Image.open(generated_image_path)


image_blocks = gr.Blocks().queue()
with image_blocks as demo:
    # NOTE(review): the original HTML literals were corrupted in the source;
    # the heading/subtitle markup below is reconstructed from their text.
    gr.HTML("<center><h1>Virtual Try-On</h1></center>")
    gr.HTML("<center><p>Upload an image of a person and an image of a garment</p></center>")

    with gr.Row():
        with gr.Column():
            human_img = gr.Image(type="filepath", label="Human", interactive=True)
            gr.Examples(
                inputs=human_img,
                examples_per_page=10,
                examples=human_list_path,
            )
        with gr.Column():
            garm_img = gr.Image(label="Garment", type="filepath", interactive=True)
            gr.Examples(
                inputs=garm_img,
                examples_per_page=8,
                examples=garm_list_path,
            )
        with gr.Column():
            image_out = gr.Image(label="Processed image", type="pil")

    with gr.Row():
        try_button = gr.Button(value="Try-on", variant="primary")

    try_button.click(fn=process_image, inputs=[human_img, garm_img], outputs=image_out)


if __name__ == "__main__":
    # Launch only when run as a script so importing this module is side-effect
    # light beyond model/UI construction.
    image_blocks.launch(show_error=True)