from typing import Optional

import numpy as np
import torch
from PIL import Image
from tqdm import tqdm
from diffusers import DDPMScheduler
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import retrieve_timesteps

from pipeline import Zero123PlusPipeline
from utils import add_white_bg, load_z123_pipe


class VAEProcessor:
    """A helper class to handle encoding and decoding images with the VAE."""

    def __init__(self, pipeline: Zero123PlusPipeline):
        self.pipe = pipeline
        self.image_processor = pipeline.image_processor
        self.vae = pipeline.vae
        # Zero123Plus normalization constants: latents are shifted and scaled
        # after VAE encoding, and images are rescaled before it.
        self.latent_shift_factor = 0.22
        self.latent_scale_factor = 0.75
        self.image_scale_factor = 0.5 / 0.8

    def encode(self, image: Image.Image) -> torch.Tensor:
        """Encodes a PIL image into the normalized latent space."""
        image_tensor = self.image_processor.preprocess(image).to(self.vae.device).half()
        with torch.autocast("cuda"), torch.inference_mode():
            image_tensor *= self.image_scale_factor
            latents = self.vae.encode(image_tensor).latent_dist.mode()
            latents *= self.vae.config.scaling_factor
            return (latents - self.latent_shift_factor) * self.latent_scale_factor

    def decode(self, latents: torch.Tensor) -> list[Image.Image]:
        """Decodes latents back into a list of post-processed PIL images."""
        with torch.autocast("cuda"), torch.inference_mode():
            denorm = latents / self.latent_scale_factor + self.latent_shift_factor
            image = self.vae.decode(denorm / self.vae.config.scaling_factor, return_dict=False)[0]
            image /= self.image_scale_factor
            return self.image_processor.postprocess(image)
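
# Note: encode() and decode() are (approximate) inverses, so a quick sanity
# check is to round-trip an image (hypothetical path, assuming a loaded
# Zero123Plus pipeline):
#
#   proc = VAEProcessor(pipeline)
#   latents = proc.encode(Image.open("mv_grid.png"))
#   roundtrip = proc.decode(latents)[0]  # PIL image, close to the input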


class EditAwareDenoiser:
    """Encapsulates the entire Edit-Aware Denoising process."""

    def __init__(self, pipe: Zero123PlusPipeline, scheduler: DDPMScheduler,
                 T_steps: int, src_gs: float, tar_gs: float, n_max: int):
        """Initializes the denoiser with the pipeline and configuration."""
        self.pipe = pipe
        self.scheduler = scheduler
        self.T_steps = T_steps
        self.src_guidance_scale = src_gs
        self.tar_guidance_scale = tar_gs
        self.n_max = n_max

    @staticmethod
    def _mix_cfg(cond: torch.Tensor, uncond: torch.Tensor, cfg: float) -> torch.Tensor:
        """Mixes conditional and unconditional predictions via classifier-free guidance."""
        return uncond + cfg * (cond - uncond)

    def _get_differential_edit_direction(self, t: torch.Tensor, zt_src: torch.Tensor, zt_tar: torch.Tensor) -> torch.Tensor:
        """Computes the differential edit direction (delta v) for a timestep."""
        # Noise both condition latents with the same sample so that the source
        # and target predictions differ only through their conditioning.
        condition_noise = torch.randn_like(self.src_cond_lat)
        noisy_src_cond_lat = self.pipe.scheduler.scale_model_input(
            self.pipe.scheduler.add_noise(self.src_cond_lat, condition_noise, t), t
        )
        vt_src_uncond, vt_src_cond = self._calc_v_zero(self.src_cond_img, zt_src, t, noisy_src_cond_lat)
        vt_src = self._mix_cfg(vt_src_cond, vt_src_uncond, self.src_guidance_scale)

        noisy_tar_cond_lat = self.pipe.scheduler.scale_model_input(
            self.pipe.scheduler.add_noise(self.tar_cond_lat, condition_noise, t), t
        )
        vt_tar_uncond, vt_tar_cond = self._calc_v_zero(self.tar_cond_img, zt_tar, t, noisy_tar_cond_lat)
        vt_tar = self._mix_cfg(vt_tar_cond, vt_tar_uncond, self.tar_guidance_scale)
        return vt_tar - vt_src

    def _propagate_for_timestep(self, zt_edit: torch.Tensor, t: torch.Tensor, dt: torch.Tensor) -> torch.Tensor:
        """Performs a single propagation step for the edit."""
        # Share the forward noise between the source and the edited latents so
        # their predicted velocities differ only by the edit.
        fwd_noise = torch.randn_like(self.x_src)
        zt_src = self.scheduler.scale_model_input(self.scheduler.add_noise(self.x_src, fwd_noise, t), t)
        zt_tar = self.scheduler.scale_model_input(self.scheduler.add_noise(zt_edit, fwd_noise, t), t)
        diff_v = self._get_differential_edit_direction(t, zt_src, zt_tar)
        # Integrate the edit direction in float32 for stability, then cast back.
        zt_edit = zt_edit.to(torch.float32) + dt * diff_v
        return zt_edit.to(diff_v.dtype)

    def _calc_v_zero(self, condition_image: Image.Image, noisy_latent: torch.Tensor, t: torch.Tensor,
                     noised_condition: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Calculates the unconditional and conditional v-prediction from the UNet."""
        # Any guidance_scale > 1 makes the pipeline run the UNet on an
        # (uncond, cond) batch; the value itself is irrelevant because we
        # capture the raw UNet output with a forward hook before CFG mixing.
        DUMMY_GUIDANCE_SCALE = 2
        model_output = {}

        def hook_fn(module, args, output):
            model_output["v_pred"] = output[0]

        hook_handle = self.pipe.unet.register_forward_hook(hook_fn)
        try:
            self.pipe(
                condition_image,
                latents=noisy_latent,
                num_inference_steps=1,
                guidance_scale=DUMMY_GUIDANCE_SCALE,
                timesteps=[t.item()],
                output_type="latent",
                noisy_cond_lat=noised_condition,
            )
        finally:
            hook_handle.remove()
        # The captured batch is [uncond, cond]; split it into the two predictions.
        return model_output["v_pred"].chunk(2)

    def denoise(self, x_src: torch.Tensor, src_cond_img: Image.Image, tar_cond_img: Image.Image) -> torch.Tensor:
        """Public method to run the entire denoising process."""
        self.x_src = x_src
        self.src_cond_img = src_cond_img
        self.tar_cond_img = tar_cond_img

        timesteps, _ = retrieve_timesteps(self.scheduler, self.T_steps, self.x_src.device)
        zt_edit = self.x_src.clone()
        self.src_cond_lat = self.pipe.make_condition_lat(self.src_cond_img, guidance_scale=2.0)
        self.tar_cond_lat = self.pipe.make_condition_lat(self.tar_cond_img, guidance_scale=2.0)

        # Skip the earliest (noisiest) timesteps and propagate only the last n_max.
        start_index = max(0, len(timesteps) - self.n_max)
        for i in tqdm(range(start_index, len(timesteps))):
            t = timesteps[i]
            # Normalized times in [0, 1]; dt is negative since t decreases.
            t_i = t / 1000.0
            t_im1 = timesteps[i + 1] / 1000.0 if i + 1 < len(timesteps) else torch.zeros_like(t_i)
            dt = t_im1 - t_i
            zt_edit = self._propagate_for_timestep(zt_edit, t, dt)
        return zt_edit


def run_editp23(
    src_condition_path: str,
    tgt_condition_path: str,
    original_mv: str,
    save_path: str,
    device_number: int = 0,
    T_steps: int = 50,
    n_max: int = 31,
    src_guidance_scale: float = 3.5,
    tar_guidance_scale: float = 5.0,
    seed: int = 18,
    pipeline: Optional[Zero123PlusPipeline] = None,
) -> None:
    """Main execution function to run the complete editing pipeline."""
    if pipeline is None:
        pipeline = load_z123_pipe(device_number)
    torch.manual_seed(seed)
    np.random.seed(seed)

    vae_processor = VAEProcessor(pipeline)
    src_cond_img = add_white_bg(Image.open(src_condition_path))
    tgt_cond_img = add_white_bg(Image.open(tgt_condition_path))
    mv_src = add_white_bg(Image.open(original_mv))
    x0_src = vae_processor.encode(mv_src)

    denoiser = EditAwareDenoiser(
        pipe=pipeline,
        scheduler=pipeline.scheduler,
        T_steps=T_steps,
        src_gs=src_guidance_scale,
        tar_gs=tar_guidance_scale,
        n_max=n_max,
    )
    x0_tar = denoiser.denoise(x0_src, src_cond_img, tgt_cond_img)
    image_tar = vae_processor.decode(x0_tar)
    image_tar[0].save(save_path)
    print(f"Successfully saved result to {save_path}")
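

if __name__ == "__main__":
    # Minimal usage sketch (the file paths below are hypothetical
    # placeholders): edit a Zero123Plus multi-view grid given the original
    # condition view and an edited target view.
    run_editp23(
        src_condition_path="examples/src_cond.png",
        tgt_condition_path="examples/tgt_cond.png",
        original_mv="examples/original_mv.png",
        save_path="edited_mv.png",
    )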