Spaces:
Running
on
Zero
Running
on
Zero
| import logging | |
| from typing import Any, Dict, Optional, Tuple | |
| import cv2 | |
| import numpy as np | |
| from PIL import Image | |
# Module-level logger for this file; its own level is set to INFO so that
# INFO-and-above records are not filtered out by this logger itself.
logger = logging.getLogger(name=__name__)
logger.setLevel(level=logging.INFO)
class InpaintingBlender:
    """
    Handles mask processing, prompt enhancement, and result blending for inpainting.

    This class encapsulates all pre-processing and post-processing operations
    needed for inpainting, separate from the main generation pipeline.

    Attributes:
        min_mask_coverage: Minimum fraction of white (inpaint) mask pixels for
            ``prepare_mask`` to report the mask as valid.
        max_mask_coverage: Maximum fraction of white (inpaint) mask pixels for
            ``prepare_mask`` to report the mask as valid.

    Example:
        >>> blender = InpaintingBlender()
        >>> processed_mask, info = blender.prepare_mask(mask, (512, 512), feather_radius=8)
        >>> enhanced_prompt, negative = blender.enhance_prompt_for_inpainting("a flower", image, mask)
        >>> result = blender.blend_result(original, generated, mask)
    """

    # Grayscale value (0-255) above which a mask pixel counts as "inpaint"
    # (white). Pixels at or below it are treated as preserved context.
    _MASK_THRESHOLD = 127

    def __init__(
        self,
        min_mask_coverage: float = 0.01,
        max_mask_coverage: float = 0.95
    ):
        """
        Initialize the InpaintingBlender.

        Parameters
        ----------
        min_mask_coverage : float
            Minimum mask coverage as a fraction (default: 0.01, i.e. 1%)
        max_mask_coverage : float
            Maximum mask coverage as a fraction (default: 0.95, i.e. 95%)
        """
        self.min_mask_coverage = min_mask_coverage
        self.max_mask_coverage = max_mask_coverage
        logger.info("InpaintingBlender initialized")

    @staticmethod
    def _to_uint8(arr: np.ndarray) -> np.ndarray:
        """Round, clip to [0, 255] and cast a float array to uint8.

        Plain ``astype(np.uint8)`` truncates, so a value such as 199.99998
        (float error on an integral 200) would become 199.
        """
        return np.clip(np.round(arr), 0, 255).astype(np.uint8)

    @staticmethod
    def _srgb_to_linear(img: np.ndarray) -> np.ndarray:
        """Map sRGB-encoded values in [0, 255] to linear-light values in [0, 1]."""
        img_norm = img / 255.0
        return np.where(
            img_norm <= 0.04045,
            img_norm / 12.92,
            np.power((img_norm + 0.055) / 1.055, 2.4),
        )

    @staticmethod
    def _linear_to_srgb(img: np.ndarray) -> np.ndarray:
        """Map linear-light values to sRGB-encoded values in [0, 1] (input clipped to [0, 1])."""
        img_clipped = np.clip(img, 0, 1)
        return np.where(
            img_clipped <= 0.0031308,
            12.92 * img_clipped,
            1.055 * np.power(img_clipped, 1 / 2.4) - 0.055,
        )

    def prepare_mask(
        self,
        mask: Image.Image,
        target_size: Tuple[int, int],
        feather_radius: int = 8
    ) -> Tuple[Image.Image, Dict[str, Any]]:
        """
        Prepare and validate mask for inpainting.

        The mask is converted to grayscale, resized to ``target_size`` and,
        when ``feather_radius > 0``, feathered with a Gaussian blur. Coverage
        is measured on the un-feathered mask as the fraction of pixels with
        value > 127.

        Parameters
        ----------
        mask : PIL.Image
            Input mask (white = inpaint area)
        target_size : tuple
            Target (width, height) to match input image
        feather_radius : int
            Feathering radius in pixels; values <= 0 disable feathering

        Returns
        -------
        tuple
            (processed_mask, validation_info). ``validation_info`` has the keys
            ``coverage``, ``white_pixels``, ``total_pixels``, ``feather_radius``,
            ``valid`` and ``warning``.

        Notes
        -----
        Coverage outside ``[min_mask_coverage, max_mask_coverage]`` does not
        raise. It is reported via ``validation_info["valid"] = False`` together
        with a user-facing ``validation_info["warning"]``.
        """
        # Convert to grayscale
        if mask.mode != 'L':
            mask = mask.convert('L')

        # Resize to match target
        if mask.size != target_size:
            mask = mask.resize(target_size, Image.LANCZOS)

        mask_array = np.array(mask)

        # Calculate coverage (guard against a zero-area target size)
        total_pixels = mask_array.size
        white_pixels = int(np.count_nonzero(mask_array > self._MASK_THRESHOLD))
        coverage = white_pixels / total_pixels if total_pixels else 0.0

        validation_info = {
            "coverage": coverage,
            "white_pixels": white_pixels,
            "total_pixels": total_pixels,
            "feather_radius": feather_radius,
            "valid": True,
            "warning": ""
        }

        # Validate coverage
        if coverage < self.min_mask_coverage:
            validation_info["valid"] = False
            validation_info["warning"] = (
                f"Mask coverage too low ({coverage:.1%}). "
                f"Please select a larger area to inpaint."
            )
            logger.warning(f"Mask coverage {coverage:.1%} below minimum {self.min_mask_coverage:.1%}")
        elif coverage > self.max_mask_coverage:
            validation_info["valid"] = False
            validation_info["warning"] = (
                f"Mask coverage too high ({coverage:.1%}). "
                f"Consider using background generation instead."
            )
            logger.warning(f"Mask coverage {coverage:.1%} above maximum {self.max_mask_coverage:.1%}")

        # Apply feathering (odd kernel size 2r+1, sigma r/2)
        if feather_radius > 0:
            ksize = feather_radius * 2 + 1
            mask_array = cv2.GaussianBlur(mask_array, (ksize, ksize), feather_radius / 2)
            logger.debug(f"Applied {feather_radius}px feathering to mask")

        # A 2-D uint8 array maps to mode 'L' automatically; passing ``mode``
        # explicitly is deprecated in recent Pillow releases.
        processed_mask = Image.fromarray(mask_array)
        return processed_mask, validation_info

    def enhance_prompt_for_inpainting(
        self,
        prompt: str,
        image: Image.Image,
        mask: Image.Image
    ) -> Tuple[str, str]:
        """
        Enhance prompt based on non-masked region analysis.

        Analyzes the surrounding context (pixels with mask value <= 127) and
        appends lighting, color-temperature and saturation descriptors derived
        from median Lab / HSV statistics, followed by a quality suffix.

        Parameters
        ----------
        prompt : str
            User-provided prompt
        image : PIL.Image
            Original image
        mask : PIL.Image
            Inpainting mask (white = inpaint area). Resized to the image size
            if the sizes differ.

        Returns
        -------
        tuple
            (enhanced_prompt, negative_prompt)
        """
        logger.info("Enhancing prompt for inpainting context...")

        rgb_image = image.convert('RGB')
        gray_mask = mask.convert('L')
        # Sizes may legitimately differ at this point (validate_inputs only
        # warns); align them so the boolean index below matches the image.
        if gray_mask.size != rgb_image.size:
            gray_mask = gray_mask.resize(rgb_image.size, Image.LANCZOS)

        img_array = np.array(rgb_image)
        mask_array = np.array(gray_mask)

        # Context = exact complement of the "white" (> 127) inpaint area
        non_masked = mask_array <= self._MASK_THRESHOLD

        if not np.any(non_masked):
            # No context available
            enhanced_prompt = f"{prompt}, high quality, detailed, photorealistic"
            negative_prompt = self._get_inpainting_negative_prompt()
            return enhanced_prompt, negative_prompt

        # Extract context pixels as an (N, 1, 3) "image" for cvtColor
        context_pixels = img_array[non_masked].reshape(-1, 1, 3)

        # 8-bit OpenCV Lab: L scaled to 0-255, a/b offset so 128 is neutral
        context_lab = cv2.cvtColor(context_pixels, cv2.COLOR_RGB2LAB).reshape(-1, 3)

        # Use robust statistics (median) to avoid outlier influence
        median_l = np.median(context_lab[:, 0])
        median_b = np.median(context_lab[:, 2])

        # Analyze lighting conditions
        lighting_descriptors = []
        if median_l > 170:
            lighting_descriptors.append("bright")
        elif median_l > 130:
            lighting_descriptors.append("well-lit")
        elif median_l > 80:
            lighting_descriptors.append("moderate lighting")
        else:
            lighting_descriptors.append("dim lighting")

        # Analyze color temperature (b channel: blue below 128, yellow above)
        if median_b > 140:
            lighting_descriptors.append("warm golden tones")
        elif median_b > 120:
            lighting_descriptors.append("warm afternoon light")
        elif median_b < 110:
            lighting_descriptors.append("cool neutral tones")

        # Calculate saturation from context (8-bit HSV: S in 0-255)
        hsv = cv2.cvtColor(context_pixels, cv2.COLOR_RGB2HSV)
        median_saturation = np.median(hsv[:, :, 1])

        if median_saturation > 150:
            lighting_descriptors.append("vibrant colors")
        elif median_saturation < 80:
            lighting_descriptors.append("subtle muted colors")

        # Build enhanced prompt
        lighting_desc = ", ".join(lighting_descriptors) if lighting_descriptors else ""
        quality_suffix = "high quality, detailed, photorealistic, seamless integration"

        if lighting_desc:
            enhanced_prompt = f"{prompt}, {lighting_desc}, {quality_suffix}"
        else:
            enhanced_prompt = f"{prompt}, {quality_suffix}"

        negative_prompt = self._get_inpainting_negative_prompt()

        logger.info(f"Enhanced prompt with context: {lighting_desc}")
        return enhanced_prompt, negative_prompt

    def _get_inpainting_negative_prompt(self) -> str:
        """Get standard negative prompt for inpainting."""
        return (
            "inconsistent lighting, wrong perspective, mismatched colors, "
            "visible seams, blending artifacts, color bleeding, "
            "blurry, low quality, distorted, deformed, "
            "harsh edges, unnatural transition"
        )

    def blend_result(
        self,
        original: Image.Image,
        generated: Image.Image,
        mask: Image.Image
    ) -> Image.Image:
        """
        Blend generated content with original image.

        Generated colors are first shifted toward the original near the mask
        boundary, then the two images are alpha-composited in linear light
        using a soft-edged blend mask.

        Parameters
        ----------
        original : PIL.Image
            Original image
        generated : PIL.Image
            Generated inpainted image (resized to the original size if needed)
        mask : PIL.Image
            Blending mask (white = use generated; resized if needed)

        Returns
        -------
        PIL.Image
            Blended RGB result. Pixels where the blend mask is 0 are returned
            unchanged from ``original``.
        """
        logger.info("Blending inpainting result with color matching...")

        # Grayscale first (as in prepare_mask) so resizing works on an L image
        mask = mask.convert('L')

        # Ensure same size
        if generated.size != original.size:
            generated = generated.resize(original.size, Image.LANCZOS)
        if mask.size != original.size:
            mask = mask.resize(original.size, Image.LANCZOS)

        orig_array = np.array(original.convert('RGB')).astype(np.float32)
        gen_array = np.array(generated.convert('RGB')).astype(np.float32)
        mask_array = np.array(mask).astype(np.float32) / 255.0

        # Color matching uses the un-softened mask for accurate boundary detection
        gen_array = self._match_colors_at_boundary(orig_array, gen_array, mask_array)

        # Soften edges ONLY for blending (not for generation): full generated
        # coverage inside the mask, smooth transition at its edges.
        blend_mask = self._create_blend_mask(mask_array)

        orig_linear = self._srgb_to_linear(orig_array)
        gen_linear = self._srgb_to_linear(gen_array)

        # Alpha blending in linear space
        alpha = blend_mask[:, :, np.newaxis]
        result_linear = gen_linear * alpha + orig_linear * (1.0 - alpha)

        # Round (not truncate) on the way back: truncation after the
        # sRGB<->linear round trip darkens many untouched pixels by one level.
        result_array = self._to_uint8(self._linear_to_srgb(result_linear) * 255.0)

        logger.debug("Blending completed with color matching")
        return Image.fromarray(result_array)

    def _match_colors_at_boundary(
        self,
        original: np.ndarray,
        generated: np.ndarray,
        mask: np.ndarray
    ) -> np.ndarray:
        """
        Match colors of generated content to original at the boundary.

        Compares mean Lab colors in a band just outside the mask (original
        side) with a band just inside it (generated side), and shifts the
        generated pixels where ``mask > 0.3`` by a fraction of the difference:
        70% for L (lightness), 50% for a and b (color).

        Parameters
        ----------
        original : np.ndarray
            Original image array (float32, 0-255)
        generated : np.ndarray
            Generated image array (float32, 0-255)
        mask : np.ndarray
            Mask array (float32, 0-1)

        Returns
        -------
        np.ndarray
            Color-matched generated image. Pixels outside the corrected region
            are returned unchanged. If no usable boundary is found, the
            ``generated`` input is returned as-is.
        """
        mask_binary = (mask > 0.5).astype(np.uint8) * 255

        # Boundary bands = dilated mask minus mask (outer), mask minus eroded (inner)
        kernel_size = 25  # Pixels to sample around boundary
        kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (kernel_size, kernel_size))
        dilated = cv2.dilate(mask_binary, kernel, iterations=1)
        eroded = cv2.erode(mask_binary, kernel, iterations=1)

        # Outer boundary (original side) / inner boundary (generated side)
        outer_boundary = (dilated > 0) & (mask_binary == 0)
        inner_boundary = (mask_binary > 0) & (eroded == 0)

        if not np.any(outer_boundary) or not np.any(inner_boundary):
            logger.debug("No boundary region found, skipping color matching")
            return generated

        orig_lab = cv2.cvtColor(self._to_uint8(original), cv2.COLOR_RGB2LAB).astype(np.float32)
        gen_lab = cv2.cvtColor(self._to_uint8(generated), cv2.COLOR_RGB2LAB).astype(np.float32)

        orig_boundary_pixels = orig_lab[outer_boundary]
        gen_boundary_pixels = gen_lab[inner_boundary]

        if len(orig_boundary_pixels) < 10 or len(gen_boundary_pixels) < 10:
            logger.debug("Not enough boundary pixels, skipping color matching")
            return generated

        # Partial mean shift per Lab channel: 70% for L, 50% for a and b
        strengths = np.array([0.7, 0.5, 0.5], dtype=np.float32)
        correction = (orig_boundary_pixels.mean(axis=0) - gen_boundary_pixels.mean(axis=0)) * strengths
        l_correction, a_correction, b_correction = (float(c) for c in correction)

        logger.debug(f"Color correction: L={l_correction:.1f}, a={a_correction:.1f}, b={b_correction:.1f}")

        # Apply correction to most of the masked region
        mask_region = mask > 0.3
        corrected_lab = gen_lab.copy()
        corrected_lab[mask_region] = np.clip(corrected_lab[mask_region] + correction, 0, 255)

        corrected_rgb = cv2.cvtColor(
            self._to_uint8(corrected_lab),
            cv2.COLOR_LAB2RGB
        ).astype(np.float32)

        # Copy back only the corrected region. A uint8 RGB->Lab->RGB round
        # trip is lossy, so pixels outside it keep their generated values.
        result = generated.copy()
        result[mask_region] = corrected_rgb[mask_region]

        logger.info("Applied boundary color matching")
        return result

    def _create_blend_mask(self, mask: np.ndarray) -> np.ndarray:
        """
        Create a blend mask with softened edges for natural compositing.

        Pixels deep inside the mask (eroded value > 128) keep their raw mask
        value. All other pixels take a Gaussian-blurred version of the mask.
        A final light blur then smooths the seam between the two.

        Parameters
        ----------
        mask : np.ndarray
            Original mask array (float32, 0-1)

        Returns
        -------
        np.ndarray
            Blend mask (float32, 0-1) with soft edges and a solid interior
        """
        # Round back to 0-255; truncation would turn e.g. 199.99998 into 199
        mask_uint8 = self._to_uint8(mask * 255.0)

        # Eroded version identifies the solid interior
        kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (15, 15))
        eroded = cv2.erode(mask_uint8, kernel, iterations=1)

        # Blurred version provides the edge transition
        blurred = cv2.GaussianBlur(mask_uint8, (15, 15), 4)

        # Interior (eroded > 128): raw mask value; elsewhere: blurred transition
        result = np.where(eroded > 128, mask_uint8, blurred)

        # Final light smoothing
        result = cv2.GaussianBlur(result, (5, 5), 1)

        blend_mask = result.astype(np.float32) / 255.0

        logger.debug("Created blend mask with soft edges and solid interior")
        return blend_mask

    def validate_inputs(
        self,
        image: Image.Image,
        mask: Image.Image
    ) -> Tuple[bool, str]:
        """
        Validate image and mask inputs before processing.

        A size mismatch between image and mask is not an error. It is only
        logged, because the other methods resize the mask as needed.

        Parameters
        ----------
        image : PIL.Image
            Input image
        mask : PIL.Image
            Input mask

        Returns
        -------
        tuple
            (is_valid, error_message); error_message is "" when valid
        """
        if image is None:
            return False, "No image provided"
        if mask is None:
            return False, "No mask provided"

        if image.size != mask.size:
            logger.warning(f"Image size {image.size} != mask size {mask.size}, will resize")

        return True, ""