Spaces: Running on Zero
| import spaces | |
| import os | |
| import pickle | |
| from time import perf_counter | |
| import tempfile | |
| import cv2 | |
| import gradio as gr | |
| import numpy as np | |
| import torch | |
| from PIL import Image | |
| from diffusers import AutoPipelineForInpainting, AutoencoderTiny, LCMScheduler | |
| from utils.drag import bi_warp | |
# Names exported by `from <module> import *` (the public helpers of this module).
__all__ = [
    'clear_all',
    'resize',
    'visualize_user_drag',
    'preview_out_image',
    'inpaint',
    'add_point',
    'undo_point',
    'clear_point',
]

# Shared inpainting pipeline; stays None until get_pipeline() builds it on first use.
pipe = None
# UI functions
def clear_all(length):
    """Return a fresh UI state for resetting every input.

    The result holds one empty ``length`` x ``length`` ``gr.Image`` repeated
    three times, followed by an empty list, the value 5 (presumably the
    default ``inpaint_ks``; confirm against the caller) and ``None``.
    """
    blank = gr.Image(value=None, height=length, width=length)
    return (blank, blank, blank, [], 5, None)
def resize(canvas, gen_length, canvas_length):
    """Resize the canvas image so that its longer edge is ``gen_length`` pixels.

    Both sides of the resized image are multiples of 8 (the condition that
    ``preview_out_image`` checks for). The display size keeps the aspect ratio,
    with the longer edge equal to ``canvas_length``. Only the image is kept;
    any drawn mask is discarded.

    Returns:
        The same ``gr.Image`` update repeated three times. When there is no
        canvas, or no image inside it, three empty ``canvas_length`` x
        ``canvas_length`` images are returned.
    """
    # `canvas` may be a bare numpy array, whose truth value is ambiguous, so
    # compare against None instead of using `not canvas`.
    if canvas is None:
        return (gr.Image(value=None, width=canvas_length, height=canvas_length),) * 3
    image, _ = process_canvas(canvas)
    if image is None:
        return (gr.Image(value=None, width=canvas_length, height=canvas_length),) * 3

    aspect_ratio = image.shape[1] / image.shape[0]  # width / height

    def snap8(x):
        # Nearest multiple of 8, but never 0 (cv2.resize rejects empty sizes).
        return max(8, round(x / 8) * 8)

    if aspect_ratio >= 1:  # landscape or square: width is the longer edge
        new_dims = (gen_length, snap8(gen_length / aspect_ratio))
        canvas_dims = (canvas_length, max(1, round(canvas_length / aspect_ratio)))
    else:  # portrait: height is the longer edge
        new_dims = (snap8(gen_length * aspect_ratio), gen_length)
        canvas_dims = (max(1, round(canvas_length * aspect_ratio)), canvas_length)

    # cv2.resize takes the target size as (width, height).
    resized = cv2.resize(image, new_dims)
    return (gr.Image(value=resized, width=canvas_dims[0], height=canvas_dims[1]),) * 3
def _to_rgb(image):
    """Return a C-contiguous copy of *image* with exactly 3 channels when possible.

    - 2-D and (H, W, 1) images are replicated into three channels.
    - 4-channel images have their last (alpha) channel dropped.
    - Other shapes are returned as-is; callers validate the shape themselves.
    """
    image = np.array(image)  # always a copy, so the caller's data is never mutated
    if image.ndim == 3 and image.shape[2] == 1:
        image = image[:, :, 0]
    if image.ndim == 2:
        image = np.repeat(image[:, :, np.newaxis], 3, axis=2)
    elif image.ndim == 3 and image.shape[2] == 4:
        image = image[:, :, :3]
    # Slicing off alpha yields a strided view; OpenCV drawing needs contiguous memory.
    return np.ascontiguousarray(image)


def _layer_to_mask(layer):
    """Return a uint8 0/1 mask of the painted pixels of one ImageEditor layer."""
    if layer.ndim == 2:
        return (layer > 0).astype(np.uint8)
    if layer.shape[2] == 4:
        # RGBA layer: a stroke is any non-transparent pixel, whatever its colour.
        # Testing only the red channel misses strokes whose R is 0 (e.g. black).
        return (layer[:, :, 3] > 0).astype(np.uint8)
    return (layer[:, :, 0] > 0).astype(np.uint8)


def process_canvas(canvas):
    """Split a Gradio canvas value into an RGB image and a binary mask.

    Accepted inputs:
      * ImageEditor dict ``{'background': ..., 'layers': [...]}``: the mask is
        the union of the painted pixels of all numpy layers.
      * Legacy sketch dict ``{'image': ..., 'mask': ...}``: the mask is
        "first channel > 0" (a 2-D mask is used directly).
      * Anything ``np.array`` accepts (ndarray, PIL image): treated as the
        image, with an all-zero mask.

    Returns:
        ``(image, mask)``. ``image`` is C-contiguous and keeps its dtype; it is
        (H, W, 3) for 2-D, 1-, 3- and 4-channel inputs. ``mask`` is an (H, W)
        uint8 array of 0/1. Returns ``(None, None)`` when there is no canvas,
        no image in it, or the dict layout is not recognised.
    """
    if canvas is None:
        return None, None

    if not isinstance(canvas, dict):
        image = _to_rgb(canvas)
        return image, np.zeros(image.shape[:2], dtype=np.uint8)

    if 'background' in canvas and 'layers' in canvas:
        # Gradio ImageEditor format.
        if canvas['background'] is None:
            return None, None
        image = _to_rgb(canvas['background'])
        mask = np.zeros(image.shape[:2], dtype=np.uint8)
        for layer in canvas['layers'] or []:
            if isinstance(layer, np.ndarray) and layer.ndim >= 2:
                mask |= _layer_to_mask(layer)
        return image, mask

    if 'image' in canvas and 'mask' in canvas:
        # Legacy sketch format.
        if canvas['image'] is None:
            return None, None
        image = _to_rgb(canvas['image'])
        raw_mask = canvas['mask']
        if raw_mask is None:
            return image, np.zeros(image.shape[:2], dtype=np.uint8)
        raw_mask = np.asarray(raw_mask)
        channel = raw_mask if raw_mask.ndim == 2 else raw_mask[:, :, 0]
        return image, (channel > 0).astype(np.uint8)

    # Unrecognised dict layout.
    return None, None
# Point manipulation functions
def add_point(canvas, points, inpaint_ks, evt: gr.SelectData):
    """Append the clicked pixel (``evt.index``) to ``points`` and redraw the overlay.

    ``points`` is modified in place. ``inpaint_ks`` is unused here. With no
    canvas, nothing is recorded and None is returned.
    """
    if canvas is not None:
        points.append(evt.index)
        return visualize_user_drag(canvas, points)
    return None
def undo_point(canvas, points, inpaint_ks):
    """Drop the most recent control point (in place) and redraw the overlay.

    With no canvas, ``points`` is left untouched and None is returned.
    ``inpaint_ks`` is unused here.
    """
    if canvas is None:
        return None
    if points:
        del points[-1]
    return visualize_user_drag(canvas, points)
def clear_point(canvas, points, inpaint_ks):
    """Empty ``points`` in place and redraw the overlay without any points.

    With no canvas, ``points`` is left untouched and None is returned.
    ``inpaint_ks`` is unused here.
    """
    if canvas is None:
        return None
    del points[:]
    return visualize_user_drag(canvas, points)
# Visualization tools
def visualize_user_drag(canvas, points):
    """Draw the mask overlay and the drag control points on the canvas image.

    Args:
        canvas: Canvas value understood by ``process_canvas``.
        points: Flat list of (x, y) pixel positions. The 1st, 3rd, ... entries
            are drag start points and the 2nd, 4th, ... entries are end points.

    Returns:
        An (H, W, 3) uint8 image, or None when there is no usable 3-channel image.
    """
    if canvas is None:
        return None
    image, mask = process_canvas(canvas)
    if image is None:
        return None

    if image.dtype != np.uint8:
        # Float images in [0, 1] are rescaled; anything else is cast directly.
        image = (image * 255).astype(np.uint8) if image.max() <= 1.0 else image.astype(np.uint8)
    if image.ndim != 3 or image.shape[2] != 3:
        return None
    # OpenCV draws in place and rejects strided views (e.g. an RGBA image with
    # the alpha channel sliced off), so make sure the buffer is contiguous.
    image = np.ascontiguousarray(image)

    # Tint masked pixels red at 30% opacity.
    if np.any(mask == 1):
        tinted = image.copy()
        tinted[mask == 1] = [255, 0, 0]
        image = cv2.addWeighted(tinted, 0.3, image, 0.7, 0)
    # Outline the masked region in white.
    if np.any(mask > 0):
        contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        cv2.drawContours(image, contours, -1, (255, 255, 255), 2)

    # Start points in red, end points in blue (RGB order), and a white arrow
    # from each start point to the following end point.
    prev_point = None
    for idx, point in enumerate(points, 1):
        # OpenCV wants integer pixel tuples; normalise once for circle and arrow alike.
        point = (int(point[0]), int(point[1]))
        if idx % 2 == 0:
            cv2.circle(image, point, 10, (0, 0, 255), -1)  # end point
            if prev_point is not None:
                cv2.arrowedLine(image, prev_point, point, (255, 255, 255), 4, tipLength=0.5)
        else:
            cv2.circle(image, point, 10, (255, 0, 0), -1)  # start point
        prev_point = point
    return image
def preview_out_image(canvas, points, inpaint_ks):
    """Apply the drag warp to the canvas image and preview the regions to inpaint.

    Returns:
        ``(image, inpaint_mask)``. On success, ``image`` has handle pixels
        copied to their targets and the regions to inpaint covered by a white
        grid with black lines every 10 px. ``inpaint_mask`` is uint8 with 255
        marking those regions.
        ``inpaint_mask`` is None when:
          * fewer than 2 points are given;
          * the image is not 3-channel;
          * the sizes do not match the resize() output;
          * ``bi_warp`` fails.
        ``(None, None)`` when there is no image.
    """
    if canvas is None:
        return None, None
    image, mask = process_canvas(canvas)
    if image is None:
        return None, None

    if image.dtype != np.uint8:
        if image.max() <= 1.0:
            image = (image * 255).astype(np.uint8)
        else:
            image = image.astype(np.uint8)
    if image.ndim != 3 or image.shape[2] != 3:
        return image, None
    if len(points) < 2:
        return image, None

    # The warp expects the output of resize(): every side divisible by 8 and a
    # longer edge of exactly 512 px, for both the image and the mask.
    divisible = all(side % 8 == 0 for side in mask.shape + image.shape[:2])
    long_edges = [max(arr.shape[:2] if arr.ndim > 2 else arr.shape) for arr in (image, mask)]
    right_size = all(edge == 512 for edge in long_edges)
    if not divisible or not right_size:
        gr.Warning('Click Resize Image Button first.')
        return image, None

    try:
        handle_pts, target_pts, inpaint_mask = bi_warp(mask, points, inpaint_ks)
        # Point arrays are (x, y), so index rows with column 1 and columns with column 0.
        image[target_pts[:, 1], target_pts[:, 0]] = image[handle_pts[:, 1], handle_pts[:, 0]]
        # White canvas with black lines every 10 px marks the regions to inpaint.
        grid = np.full_like(mask, 255)
        grid[::10] = 0
        grid[:, ::10] = 0
        covered = inpaint_mask[..., np.newaxis] == 1
        image = np.where(covered, grid[..., np.newaxis], image)
        return image, (inpaint_mask * 255).astype(np.uint8)
    except Exception as e:
        gr.Warning(f"Preview failed: {str(e)}")
        return image, None
# Inpaint tools
def setup_pipeline(device='cuda', model_version='v1-5'):
    """Build an LCM-LoRA Stable Diffusion inpainting pipeline with a tiny VAE.

    Args:
        device: Target device. Forced to 'cpu' when CUDA is unavailable.
            Weights are fp16 on non-CPU devices and fp32 on the CPU.
        model_version: 'v1-5' or 'xl'. Selects the base model, the LCM-LoRA and
            the tiny autoencoder.

    Returns:
        The pipeline on the resolved device. Empty-prompt embeddings are cached
        on ``pipe.cached_prompt_embeds``, plus
        ``pipe.cached_pooled_prompt_embeds`` for 'xl'.

    Raises:
        KeyError: If ``model_version`` is not a known configuration.
    """
    # NOTE(review): the runwayml/* repositories were removed from the Hugging Face
    # Hub in 2024. If loading fails, the mirror
    # "stable-diffusion-v1-5/stable-diffusion-inpainting" may be needed. Confirm
    # before switching.
    MODEL_CONFIGS = {
        'v1-5': ('runwayml/stable-diffusion-inpainting', 'latent-consistency/lcm-lora-sdv1-5', 'madebyollin/taesd'),
        'xl': ('diffusers/stable-diffusion-xl-1.0-inpainting-0.1', 'latent-consistency/lcm-lora-sdxl', 'madebyollin/taesdxl'),
    }
    model_id, lora_id, vae_id = MODEL_CONFIGS[model_version]

    if not torch.cuda.is_available():
        device = 'cpu'
    # Pick precision from the device actually used. Previously a CPU pipeline on
    # a machine with a GPU was loaded in fp16, which is slow or unsupported on CPU.
    use_half = torch.device(device).type != 'cpu'
    torch_dtype = torch.float16 if use_half else torch.float32
    variant = "fp16" if use_half else None

    gr.Info('Loading inpainting pipeline...')
    pipe = AutoPipelineForInpainting.from_pretrained(
        model_id,
        torch_dtype=torch_dtype,
        variant=variant,
        safety_checker=None,
    )
    # LCM scheduler plus LCM-LoRA (fused into the weights) for few-step sampling.
    pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
    pipe.load_lora_weights(lora_id)
    pipe.fuse_lora()
    # Replace the full VAE with the tiny autoencoder.
    pipe.vae = AutoencoderTiny.from_pretrained(vae_id, torch_dtype=torch_dtype)
    pipe = pipe.to(device)

    # Cache the empty-prompt embeddings so inference can skip the text encoder.
    if model_version == 'v1-5':
        pipe.cached_prompt_embeds = pipe.encode_prompt(
            '', device=device, num_images_per_prompt=1,
            do_classifier_free_guidance=False)[0]
    else:
        # SDXL encode_prompt returns (embeds, negative, pooled, negative_pooled).
        pipe.cached_prompt_embeds, pipe.cached_pooled_prompt_embeds = pipe.encode_prompt(
            '', device=device, num_images_per_prompt=1,
            do_classifier_free_guidance=False)[0::2]
    return pipe
def get_pipeline():
    """Return the shared inpainting pipeline, building it on the first call.

    The v1-5 pipeline is stored in the module-level ``pipe``. It is placed on
    CUDA when available and on the CPU otherwise. ``setup_pipeline`` already
    caches the empty-prompt embeddings with these same arguments, so they are
    not recomputed here.
    """
    global pipe
    if pipe is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        pipe = setup_pipeline(device=device, model_version='v1-5')
    return pipe
def inpaint(image, inpaint_mask):
    """Regenerate the masked regions of ``image`` with the lazily loaded LCM pipeline.

    Args:
        image: (H, W, C) uint8 image array. It is converted to RGB, so any alpha
            channel is dropped.
        inpaint_mask: uint8 mask where 255 marks pixels to regenerate. A
            multi-channel mask is reduced to its first channel.

    Returns:
        ``None`` if ``image`` is None, and ``image`` itself if ``inpaint_mask``
        is None. Otherwise a uint8 RGB array at the mask's size, snapped to
        multiples of 8. Masked pixels come from the diffusion output and all
        other pixels from the (resized) input. If the pipeline call fails, a
        Gradio warning is shown and the (resized) input is returned.
    """
    if image is None:
        return None
    if inpaint_mask is None:
        return image

    pipe = get_pipeline()  # lazy loading: the first call builds the model
    pipe_id = 'xl' if 'xl' in pipe.config._name_or_path else 'v1-5'
    inpaint_strength = 0.99 if pipe_id == 'xl' else 1.0

    inpaint_mask = np.asarray(inpaint_mask)
    if inpaint_mask.ndim == 3:
        # NOTE(review): assumes all channels of a multi-channel mask carry the same values.
        inpaint_mask = inpaint_mask[:, :, 0]

    # RGB so the final composite matches the 3-channel pipeline output.
    image_pil = Image.fromarray(image).convert('RGB')
    inpaint_mask_pil = Image.fromarray(inpaint_mask)

    # The mask decides the working size, and both sides must be multiples of 8.
    width, height = inpaint_mask_pil.size
    if width % 8 != 0 or height % 8 != 0:
        width, height = max(8, round(width / 8) * 8), max(8, round(height / 8) * 8)
        inpaint_mask_pil = inpaint_mask_pil.resize((width, height), Image.NEAREST)
    if image_pil.size != (width, height):
        image_pil = image_pil.resize((width, height))
    image = np.array(image_pil)
    inpaint_mask = np.array(inpaint_mask_pil)

    common_params = {
        'image': image_pil,
        'mask_image': inpaint_mask_pil,
        'height': height,
        'width': width,
        'guidance_scale': 1.0,
        'num_inference_steps': 8,
        'strength': inpaint_strength,
        'output_type': 'np',
    }

    try:
        if pipe_id == 'v1-5':
            inpainted = pipe(
                prompt_embeds=pipe.cached_prompt_embeds,
                **common_params,
            ).images[0]
        else:
            inpainted = pipe(
                prompt_embeds=pipe.cached_prompt_embeds,
                pooled_prompt_embeds=pipe.cached_pooled_prompt_embeds,
                **common_params,
            ).images[0]
    except Exception as e:
        gr.Warning(f"Inpainting failed: {str(e)}")
        return image

    # Composite the result: generated pixels inside the mask, input pixels outside.
    # - Threshold at the midpoint, as the pipeline's mask processor binarizes at 0.5.
    #   The old `/ 255` cast kept the original pixels wherever the mask was below 255.
    # - Round rather than truncate when converting [0, 1] floats to uint8.
    generated = np.clip(np.rint(inpainted * 255), 0, 255).astype(np.uint8)
    inside = (inpaint_mask > 127)[:, :, np.newaxis]
    return np.where(inside, generated, image)