# SceneWeaver / inpainting_blender.py
import logging
from typing import Any, Dict, Tuple
import cv2
import numpy as np
from PIL import Image
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
class InpaintingBlender:
"""
Handles mask processing, prompt enhancement, and result blending for inpainting.
This class encapsulates all pre-processing and post-processing operations
needed for inpainting, separate from the main generation pipeline.
Attributes:
min_mask_coverage: Minimum mask coverage threshold
max_mask_coverage: Maximum mask coverage threshold
Example:
>>> blender = InpaintingBlender()
>>> processed_mask, info = blender.prepare_mask(mask, (512, 512), feather_radius=8)
        >>> enhanced_prompt, negative = blender.enhance_prompt_for_inpainting("a flower", image, mask)
>>> result = blender.blend_result(original, generated, mask)
"""
def __init__(
self,
min_mask_coverage: float = 0.01,
max_mask_coverage: float = 0.95
):
"""
Initialize the InpaintingBlender.
Parameters
----------
min_mask_coverage : float
Minimum mask coverage (default: 1%)
max_mask_coverage : float
Maximum mask coverage (default: 95%)
"""
self.min_mask_coverage = min_mask_coverage
self.max_mask_coverage = max_mask_coverage
logger.info("InpaintingBlender initialized")
def prepare_mask(
self,
mask: Image.Image,
target_size: Tuple[int, int],
feather_radius: int = 8
) -> Tuple[Image.Image, Dict[str, Any]]:
"""
Prepare and validate mask for inpainting.
Parameters
----------
mask : PIL.Image
Input mask (white = inpaint area)
target_size : tuple
Target (width, height) to match input image
feather_radius : int
Feathering radius in pixels
Returns
-------
        tuple
            (processed_mask, validation_info). Out-of-range coverage does not
            raise; it is reported through validation_info["valid"] and
            validation_info["warning"] so callers can decide how to proceed.
"""
# Convert to grayscale
if mask.mode != 'L':
mask = mask.convert('L')
# Resize to match target
if mask.size != target_size:
mask = mask.resize(target_size, Image.LANCZOS)
# Convert to array for processing
mask_array = np.array(mask)
# Calculate coverage
total_pixels = mask_array.size
white_pixels = np.count_nonzero(mask_array > 127)
coverage = white_pixels / total_pixels
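        # Coverage is measured on the binarized mask (threshold 127) before
        # feathering, so the reported ratio reflects the raw selection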
validation_info = {
"coverage": coverage,
"white_pixels": white_pixels,
"total_pixels": total_pixels,
"feather_radius": feather_radius,
"valid": True,
"warning": ""
}
# Validate coverage
if coverage < self.min_mask_coverage:
validation_info["valid"] = False
validation_info["warning"] = (
f"Mask coverage too low ({coverage:.1%}). "
f"Please select a larger area to inpaint."
)
logger.warning(f"Mask coverage {coverage:.1%} below minimum {self.min_mask_coverage:.1%}")
elif coverage > self.max_mask_coverage:
validation_info["valid"] = False
validation_info["warning"] = (
f"Mask coverage too high ({coverage:.1%}). "
f"Consider using background generation instead."
)
logger.warning(f"Mask coverage {coverage:.1%} above maximum {self.max_mask_coverage:.1%}")
# Apply feathering
if feather_radius > 0:
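            # GaussianBlur requires an odd kernel size; 2r+1 guarantees that,
            # and sigma = r/2 keeps most of the blur within the feather radius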
mask_array = cv2.GaussianBlur(
mask_array,
(feather_radius * 2 + 1, feather_radius * 2 + 1),
feather_radius / 2
)
logger.debug(f"Applied {feather_radius}px feathering to mask")
processed_mask = Image.fromarray(mask_array, mode='L')
return processed_mask, validation_info
def enhance_prompt_for_inpainting(
self,
prompt: str,
image: Image.Image,
mask: Image.Image
) -> Tuple[str, str]:
"""
Enhance prompt based on non-masked region analysis.
Analyzes the surrounding context to generate appropriate
lighting and color descriptors.
Parameters
----------
prompt : str
User-provided prompt
image : PIL.Image
Original image
mask : PIL.Image
Inpainting mask
Returns
-------
tuple
(enhanced_prompt, negative_prompt)
"""
logger.info("Enhancing prompt for inpainting context...")
# Convert to arrays
img_array = np.array(image.convert('RGB'))
mask_array = np.array(mask.convert('L'))
# Analyze non-masked regions
non_masked = mask_array < 127
if not np.any(non_masked):
# No context available
enhanced_prompt = f"{prompt}, high quality, detailed, photorealistic"
negative_prompt = self._get_inpainting_negative_prompt()
return enhanced_prompt, negative_prompt
# Extract context pixels
context_pixels = img_array[non_masked]
# Convert to Lab for analysis
context_lab = cv2.cvtColor(
context_pixels.reshape(-1, 1, 3),
cv2.COLOR_RGB2LAB
).reshape(-1, 3)
# Use robust statistics (median) to avoid outlier influence
median_l = np.median(context_lab[:, 0])
median_b = np.median(context_lab[:, 2])
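        # OpenCV's 8-bit Lab encoding scales L to 0-255 and offsets a/b so
        # that 128 is neutral, which is why the thresholds below sit in the
        # 80-170 range rather than the nominal 0-100 / signed a,b Lab ranges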
# Analyze lighting conditions
lighting_descriptors = []
if median_l > 170:
lighting_descriptors.append("bright")
elif median_l > 130:
lighting_descriptors.append("well-lit")
elif median_l > 80:
lighting_descriptors.append("moderate lighting")
else:
lighting_descriptors.append("dim lighting")
        # Analyze color temperature (b channel: below 128 leans blue, above leans yellow)
if median_b > 140:
lighting_descriptors.append("warm golden tones")
elif median_b > 120:
lighting_descriptors.append("warm afternoon light")
elif median_b < 110:
lighting_descriptors.append("cool neutral tones")
# Calculate saturation from context
hsv = cv2.cvtColor(context_pixels.reshape(-1, 1, 3), cv2.COLOR_RGB2HSV)
median_saturation = np.median(hsv[:, :, 1])
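        # OpenCV 8-bit HSV saturation spans 0-255, hence the 80/150 thresholds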
if median_saturation > 150:
lighting_descriptors.append("vibrant colors")
elif median_saturation < 80:
lighting_descriptors.append("subtle muted colors")
# Build enhanced prompt
lighting_desc = ", ".join(lighting_descriptors) if lighting_descriptors else ""
quality_suffix = "high quality, detailed, photorealistic, seamless integration"
if lighting_desc:
enhanced_prompt = f"{prompt}, {lighting_desc}, {quality_suffix}"
else:
enhanced_prompt = f"{prompt}, {quality_suffix}"
negative_prompt = self._get_inpainting_negative_prompt()
logger.info(f"Enhanced prompt with context: {lighting_desc}")
return enhanced_prompt, negative_prompt
def _get_inpainting_negative_prompt(self) -> str:
"""Get standard negative prompt for inpainting."""
return (
"inconsistent lighting, wrong perspective, mismatched colors, "
"visible seams, blending artifacts, color bleeding, "
"blurry, low quality, distorted, deformed, "
"harsh edges, unnatural transition"
)
def blend_result(
self,
original: Image.Image,
generated: Image.Image,
mask: Image.Image
) -> Image.Image:
"""
Blend generated content with original image.
Uses color matching and linear color space blending for seamless results.
Parameters
----------
original : PIL.Image
Original image
generated : PIL.Image
Generated inpainted image
mask : PIL.Image
Blending mask (white = use generated)
Returns
-------
PIL.Image
Blended result
"""
logger.info("Blending inpainting result with color matching...")
# Ensure same size
if generated.size != original.size:
generated = generated.resize(original.size, Image.LANCZOS)
if mask.size != original.size:
mask = mask.resize(original.size, Image.LANCZOS)
# Convert to arrays
orig_array = np.array(original.convert('RGB')).astype(np.float32)
gen_array = np.array(generated.convert('RGB')).astype(np.float32)
mask_array = np.array(mask.convert('L')).astype(np.float32) / 255.0
# Apply color matching to generated region (use original mask for accurate boundary detection)
gen_array = self._match_colors_at_boundary(orig_array, gen_array, mask_array)
# Create blend mask: soften edges ONLY for blending (not for generation)
# This ensures full generation coverage while smooth blending at edges
blend_mask = self._create_blend_mask(mask_array)
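        # Blend in linear light: interpolating gamma-encoded sRGB values
        # darkens the transition zone, while linear-space alpha compositing
        # preserves perceived brightness across the seam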
# sRGB to linear conversion
def srgb_to_linear(img: np.ndarray) -> np.ndarray:
img_norm = img / 255.0
return np.where(
img_norm <= 0.04045,
img_norm / 12.92,
np.power((img_norm + 0.055) / 1.055, 2.4)
)
def linear_to_srgb(img: np.ndarray) -> np.ndarray:
img_clipped = np.clip(img, 0, 1)
return np.where(
img_clipped <= 0.0031308,
12.92 * img_clipped,
1.055 * np.power(img_clipped, 1/2.4) - 0.055
)
# Convert to linear space
orig_linear = srgb_to_linear(orig_array)
gen_linear = srgb_to_linear(gen_array)
# Alpha blending in linear space using the blend mask (with softened edges)
alpha = blend_mask[:, :, np.newaxis]
result_linear = gen_linear * alpha + orig_linear * (1 - alpha)
# Convert back to sRGB
result_srgb = linear_to_srgb(result_linear)
result_array = (result_srgb * 255).astype(np.uint8)
logger.debug("Blending completed with color matching")
return Image.fromarray(result_array)
def _match_colors_at_boundary(
self,
original: np.ndarray,
generated: np.ndarray,
mask: np.ndarray
) -> np.ndarray:
"""
Match colors of generated content to original at the boundary.
Uses histogram matching in Lab color space for natural blending.
Parameters
----------
original : np.ndarray
Original image array (float32, 0-255)
generated : np.ndarray
Generated image array (float32, 0-255)
mask : np.ndarray
Mask array (float32, 0-1)
Returns
-------
np.ndarray
Color-matched generated image
"""
# Create boundary region mask (dilated mask - eroded mask)
mask_binary = (mask > 0.5).astype(np.uint8) * 255
        # Create a narrow band around the boundary for color sampling; a 25px
        # elliptical kernel shifts the edge ~12px outward/inward, giving
        # ~12px-wide sampling bands on each side of the boundary
        kernel_size = 25
kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (kernel_size, kernel_size))
dilated = cv2.dilate(mask_binary, kernel, iterations=1)
eroded = cv2.erode(mask_binary, kernel, iterations=1)
# Outer boundary (original side)
outer_boundary = (dilated > 0) & (mask_binary == 0)
# Inner boundary (generated side)
inner_boundary = (mask_binary > 0) & (eroded == 0)
if not np.any(outer_boundary) or not np.any(inner_boundary):
logger.debug("No boundary region found, skipping color matching")
return generated
# Convert to Lab color space
orig_lab = cv2.cvtColor(original.astype(np.uint8), cv2.COLOR_RGB2LAB).astype(np.float32)
gen_lab = cv2.cvtColor(generated.astype(np.uint8), cv2.COLOR_RGB2LAB).astype(np.float32)
# Sample colors from boundary regions
orig_boundary_pixels = orig_lab[outer_boundary]
gen_boundary_pixels = gen_lab[inner_boundary]
if len(orig_boundary_pixels) < 10 or len(gen_boundary_pixels) < 10:
logger.debug("Not enough boundary pixels, skipping color matching")
return generated
        # Calculate boundary statistics (only the means are used; a full
        # Reinhard-style transfer would also match standard deviations, but a
        # partial mean shift is less likely to amplify noise)
        orig_mean = np.mean(orig_boundary_pixels, axis=0)
        gen_mean = np.mean(gen_boundary_pixels, axis=0)
# Calculate correction factors
# Only correct L (lightness) and a,b (color) channels
l_correction = (orig_mean[0] - gen_mean[0]) * 0.7 # 70% correction for lightness
a_correction = (orig_mean[1] - gen_mean[1]) * 0.5 # 50% correction for color
b_correction = (orig_mean[2] - gen_mean[2]) * 0.5
logger.debug(f"Color correction: L={l_correction:.1f}, a={a_correction:.1f}, b={b_correction:.1f}")
# Apply correction to masked region only
corrected_lab = gen_lab.copy()
mask_region = mask > 0.3 # Apply to most of masked region
corrected_lab[mask_region, 0] = np.clip(
corrected_lab[mask_region, 0] + l_correction, 0, 255
)
corrected_lab[mask_region, 1] = np.clip(
corrected_lab[mask_region, 1] + a_correction, 0, 255
)
corrected_lab[mask_region, 2] = np.clip(
corrected_lab[mask_region, 2] + b_correction, 0, 255
)
# Convert back to RGB
corrected_rgb = cv2.cvtColor(
corrected_lab.astype(np.uint8),
cv2.COLOR_LAB2RGB
).astype(np.float32)
logger.info("Applied boundary color matching")
return corrected_rgb
def _create_blend_mask(self, mask: np.ndarray) -> np.ndarray:
"""
Create a blend mask with softened edges for natural compositing.
The mask interior stays fully opaque (1.0) while only the edges
get a smooth transition. This preserves full generated content
while blending naturally at boundaries.
Parameters
----------
mask : np.ndarray
Original mask array (float32, 0-1)
Returns
-------
np.ndarray
Blend mask with soft edges but solid interior
"""
# Convert to uint8 for morphological operations
mask_uint8 = (mask * 255).astype(np.uint8)
# Create eroded version (solid interior)
kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (15, 15))
eroded = cv2.erode(mask_uint8, kernel, iterations=1)
# Create smooth transition zone at edges only
# Blur the original mask for edge softness
blurred = cv2.GaussianBlur(mask_uint8, (15, 15), 4)
        # Combine: where the eroded mask is still solid (above half intensity),
        # keep the original mask values; elsewhere use the blurred transition
        result = np.where(eroded > 128, mask_uint8, blurred)
# Final light smoothing
result = cv2.GaussianBlur(result, (5, 5), 1)
# Convert back to float
blend_mask = result.astype(np.float32) / 255.0
logger.debug("Created blend mask with soft edges and solid interior")
return blend_mask
def validate_inputs(
self,
image: Image.Image,
mask: Image.Image
) -> Tuple[bool, str]:
"""
Validate image and mask inputs before processing.
Parameters
----------
image : PIL.Image
Input image
mask : PIL.Image
Input mask
Returns
-------
tuple
(is_valid, error_message)
"""
if image is None:
return False, "No image provided"
if mask is None:
return False, "No mask provided"
# Check sizes match
if image.size != mask.size:
# Will be resized later, so just log a warning
logger.warning(f"Image size {image.size} != mask size {mask.size}, will resize")
return True, ""
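

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the library API). The inpainting model
# call is a hypothetical placeholder (`pipe`); only InpaintingBlender's own
# methods are exercised, on synthetic inputs, so this runs standalone.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    # Synthetic 512x512 inputs: a flat gray image and a centered square mask
    image = Image.fromarray(np.full((512, 512, 3), 128, dtype=np.uint8))
    mask_arr = np.zeros((512, 512), dtype=np.uint8)
    mask_arr[192:320, 192:320] = 255  # ~6.3% coverage, inside the valid range
    mask = Image.fromarray(mask_arr, mode="L")

    blender = InpaintingBlender()
    ok, err = blender.validate_inputs(image, mask)
    assert ok, err

    processed_mask, info = blender.prepare_mask(mask, image.size, feather_radius=8)
    print(f"coverage={info['coverage']:.1%}, valid={info['valid']}")

    enhanced, negative = blender.enhance_prompt_for_inpainting("a flower", image, mask)
    print(f"prompt: {enhanced}")

    # In a real pipeline the generated image would come from a diffusion
    # inpainting model, e.g. (hypothetical):
    #   generated = pipe(prompt=enhanced, negative_prompt=negative,
    #                    image=image, mask_image=processed_mask).images[0]
    # Here the original is reused so the blending path can be demonstrated.
    generated = image.copy()
    result = blender.blend_result(image, generated, processed_mask)
    result.save("blended_result.png")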