"""Gradio demo for X-Portrait portrait animation.

Downloads the pretrained checkpoint at import time, then serves a Blocks UI
where the user provides a source image, a driving video and (optionally) a
"best frame" index, and runs inference on a ZeroGPU worker.
"""

import os
import subprocess
import tempfile
import time
from pathlib import Path

import cv2
import gradio as gr
import spaces
from huggingface_hub import hf_hub_download

from core.test_xportrait import run_inference

# Detect whether we are running on the original shared Space; the shared UI
# clamps inputs (trimmed video, fewer steps) to fit the ZeroGPU time budget.
SPACE_ID = os.getenv("SPACE_ID", "")
is_shared_ui = "fffiloni/X-Portrait" in SPACE_ID

CHECKPOINT_DIR = Path("checkpoint")
CHECKPOINT_DIR.mkdir(exist_ok=True)

# Fetched once at import; hf_hub_download is a cache hit on later restarts.
MODEL_PATH = hf_hub_download(
    repo_id="fffiloni/X-Portrait",
    filename="model_state-415001.th",
    local_dir=str(CHECKPOINT_DIR),
)
MODEL_CONFIG = "config/cldm_v15_appearance_pose_local_mm.yaml"


def trim_video_ffmpeg(video_path, max_duration=2):
    """Re-encode the first *max_duration* seconds of *video_path* to a temp mp4.

    Returns the path of the trimmed file as a string.
    Raises subprocess.CalledProcessError if ffmpeg fails.
    """
    tmp_dir = Path(tempfile.mkdtemp(prefix="xportrait_trim_"))
    output_path = tmp_dir / "trimmed.mp4"
    cmd = [
        "ffmpeg", "-y",
        "-i", video_path,
        "-t", str(max_duration),
        "-c:v", "libx264", "-preset", "veryfast",
        "-c:a", "aac",
        str(output_path),
    ]
    subprocess.run(cmd, check=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
    return str(output_path)


def sample_preview_frames(video_path, max_frames=12):
    """Sample up to *max_frames* evenly spaced frames for the preview gallery.

    Returns ``(frame_data, frame_indices)`` where *frame_data* is a list of
    ``(jpeg_path, label)`` tuples suitable for ``gr.Gallery`` and
    *frame_indices* maps each gallery position back to its frame number in the
    video.  Raises ValueError if the video cannot be opened.
    """
    tmp_dir = Path(tempfile.mkdtemp(prefix="xportrait_preview_"))
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        raise ValueError(f"Cannot open video file: {video_path}")

    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) or 0
    if total_frames <= 0:
        cap.release()
        return [], []

    step = max(1, total_frames // max_frames)
    frame_data = []
    frame_indices = []
    for i in range(max_frames):
        frame_idx = min(i * step, total_frames - 1)
        # BUG FIX: short videos map several gallery slots onto the same final
        # frame via the min() clamp above; skip duplicates instead of showing
        # repeated thumbnails.
        if frame_idx in frame_indices:
            continue
        cap.set(cv2.CAP_PROP_POS_FRAMES, frame_idx)
        ret, frame = cap.read()
        if not ret:
            break
        label = f"{frame_idx:04d}"
        frame_path = tmp_dir / f"frame_{label}.jpg"
        cv2.imwrite(str(frame_path), frame, [cv2.IMWRITE_JPEG_QUALITY, 85])
        frame_data.append((str(frame_path), label))
        frame_indices.append(frame_idx)
    cap.release()
    return frame_data, frame_indices


def load_driving_video(video_path):
    """Handle a driving-video upload: optionally trim it, then build previews.

    Returns updates for (video, gallery, frame-index state, accordion,
    best-frame number) in the order the .upload() wiring expects.
    """
    if not video_path:
        return None, [], [], gr.update(open=False), gr.update(value=-1)
    # On the shared Space the driving video is capped at 2 seconds.
    processed_video = trim_video_ffmpeg(video_path, 2) if is_shared_ui else video_path
    frames_data, frame_indices = sample_preview_frames(processed_video, max_frames=12)
    return (
        processed_video,
        frames_data,
        frame_indices,
        gr.update(open=True),
        gr.update(value=-1),  # reset "best frame" to automatic detection
    )


def on_select_frame(evt: gr.SelectData, preview_frame_indices):
    """Map a gallery click back to the underlying video frame number.

    Returns gr.update() (no change) for any malformed/out-of-range selection.
    """
    if not preview_frame_indices:
        return gr.update()
    selected_gallery_index = evt.index
    # Some gradio versions report a gallery selection as a (row, col) tuple.
    if isinstance(selected_gallery_index, tuple):
        selected_gallery_index = selected_gallery_index[0]
    if not isinstance(selected_gallery_index, int):
        return gr.update()
    if selected_gallery_index < 0 or selected_gallery_index >= len(preview_frame_indices):
        return gr.update()
    return int(preview_frame_indices[selected_gallery_index])


def convert_video_to_h264_aac_ffmpeg(video_path):
    """Transcode *video_path* to browser-friendly H.264/AAC with faststart.

    Writes ``<stem>_converted.mp4`` next to the input and returns its path.
    Raises subprocess.CalledProcessError if ffmpeg fails.
    """
    input_path = Path(video_path)
    output_path = input_path.with_name(f"{input_path.stem}_converted.mp4")
    cmd = [
        "ffmpeg", "-y",
        "-i", str(input_path),
        "-c:v", "libx264", "-preset", "veryfast",
        "-pix_fmt", "yuv420p",
        "-c:a", "aac",
        "-movflags", "+faststart",
        str(output_path),
    ]
    subprocess.run(cmd, check=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
    return str(output_path)


@spaces.GPU(duration=180)
def run_xportrait(
    source_image,
    driving_video,
    seed,
    uc_scale,
    best_frame,
    out_frames,
    num_mix,
    ddim_steps,
    progress=gr.Progress(track_tqdm=True),
):
    """Run X-Portrait inference; returns (status_message, output_video_path)."""
    start_time = time.perf_counter()
    if not source_image or not driving_video:
        return "Please provide a source image and a driving video.", None

    # gr.Number delivers floats; normalize to the types run_inference expects.
    seed = int(seed)
    best_frame = int(best_frame)
    out_frames = int(out_frames)
    num_mix = int(num_mix)
    ddim_steps = int(ddim_steps)
    uc_scale = float(uc_scale)

    if is_shared_ui:
        # Clamp everything so a single run fits within the GPU time budget.
        ddim_steps = min(ddim_steps, 16)
        num_mix = min(num_mix, 2)
        out_frames = 16 if out_frames <= 0 else min(out_frames, 16)
        skip = 2
        best_frame_search_stride = 2
    else:
        skip = 1
        best_frame_search_stride = 1

    output_dir = Path(tempfile.mkdtemp(prefix="xportrait_out_"))
    try:
        result_path = run_inference(
            model_config=MODEL_CONFIG,
            output_dir=str(output_dir),
            resume_dir=MODEL_PATH,
            seed=seed,
            uc_scale=uc_scale,
            source_image=source_image,
            driving_video=driving_video,
            best_frame=best_frame,
            out_frames=out_frames,
            num_mix=num_mix,
            ddim_steps=ddim_steps,
            skip=skip,
            target_resolution=512,
            use_fp16=True,
            compile_model=False,
            num_drivings=16,
            eta=0.0,
            best_frame_search_stride=best_frame_search_stride,
        )
        final_vid = convert_video_to_h264_aac_ffmpeg(result_path)
        return f"Output video saved at: {final_vid}", final_vid
    except Exception as e:
        # UI boundary: surface the error to the user instead of crashing the
        # worker; the traceback is visible in the Space logs.
        return f"An error occurred: {e}", None
    finally:
        elapsed = time.perf_counter() - start_time
        print(f"[LOG] Execution time: {elapsed:.2f} seconds")


css = """
div#frames-gallery{
    overflow: scroll!important;
}
"""

# Pre-populate the gallery with frames from the bundled example video.
example_frame_data, example_frame_indices = sample_preview_frames(
    "./assets/driving_video.mp4", max_frames=12
)

# BUG FIX: `css` was previously passed to demo.launch(), which does not accept
# a css argument; custom CSS must be supplied to the gr.Blocks() constructor.
with gr.Blocks(css=css) as demo:
    preview_frame_indices_state = gr.State(value=example_frame_indices)
    with gr.Column(elem_id="col-container"):
        gr.Markdown("# X-Portrait: Expressive Portrait Animation with Hierarchical Motion Attention")
        gr.Markdown(
            "On this shared UI, driving video input is trimmed to 2 seconds and a faster preset is used to stay within ZeroGPU limits. Duplicate this Space for full controls."
        )
        # NOTE(review): this HTML payload appears truncated in this revision
        # (no markup survived) — confirm against the upstream Space.
        gr.HTML(
            """
"""
        )
        with gr.Row():
            with gr.Column():
                with gr.Row():
                    source_image = gr.Image(label="Source Image", type="filepath")
                    driving_video = gr.Video(label="Driving Video")
                with gr.Group():
                    with gr.Row():
                        best_frame = gr.Number(
                            value=-1,
                            label="Best Frame",
                            info="Click a frame below to auto-fill this value, or set it manually. Use -1 for automatic detection.",
                        )
                        out_frames = gr.Number(
                            value=-1,
                            label="Out Frames",
                            info="Number of generation frames",
                        )
                    with gr.Accordion("Driving video Frames", open=False) as frames_gallery_panel:
                        driving_frames = gr.Gallery(
                            value=example_frame_data,
                            show_label=True,
                            columns=6,
                            height=380,
                            elem_id="frames-gallery",
                        )
                with gr.Row():
                    seed = gr.Number(value=999, label="Seed")
                    uc_scale = gr.Number(value=5, label="UC Scale")
                with gr.Row():
                    num_mix = gr.Number(value=2 if is_shared_ui else 4, label="Number of Mix")
                    ddim_steps = gr.Number(value=16 if is_shared_ui else 30, label="DDIM Steps")
                submit_btn = gr.Button("Submit")
            with gr.Column():
                video_output = gr.Video(label="Output Video")
                status = gr.Textbox(label="Status")
                gr.Examples(
                    examples=[
                        ["./assets/source_image.png", "./assets/driving_video.mp4", "./assets/inference_result.mp4"],
                    ],
                    inputs=[source_image, driving_video, video_output],
                )
        # NOTE(review): the badge markup (img/anchor tags) appears lost in
        # this revision — only the alt-text survives; confirm upstream.
        gr.HTML(
            """
\"Duplicate \"Follow
"""
        )

    driving_video.upload(
        fn=load_driving_video,
        inputs=[driving_video],
        outputs=[driving_video, driving_frames, preview_frame_indices_state, frames_gallery_panel, best_frame],
        queue=False,
    )
    driving_frames.select(
        fn=on_select_frame,
        inputs=[preview_frame_indices_state],
        outputs=[best_frame],
        queue=False,
    )
    submit_btn.click(
        fn=run_xportrait,
        inputs=[source_image, driving_video, seed, uc_scale, best_frame, out_frames, num_mix, ddim_steps],
        outputs=[status, video_output],
        concurrency_limit=1,
        concurrency_id="gpu_queue",
        show_progress="minimal",
    )

demo.queue(default_concurrency_limit=1)
demo.launch(ssr_mode=False)