import argparse
import copy
import datetime
import gc
import inspect
import logging
import math
import os
import shutil
import time
from collections import OrderedDict
from dataclasses import dataclass
from pprint import pformat, pprint
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union

import cv2
import h5py
import numpy as np
import pandas as pd
import PIL
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torch import nn
from einops import rearrange, repeat
from omegaconf import OmegaConf, SCMode

from diffusers.models.autoencoder_kl import AutoencoderKL
from diffusers.models.modeling_utils import load_state_dict
from diffusers.utils import BaseOutput, logging
from diffusers.utils.dummy_pt_objects import ConsistencyDecoderVAE
from diffusers.utils.import_utils import is_xformers_available

from mmcm.utils.seed_util import set_all_seed
from mmcm.vision.data.video_dataset import DecordVideoDataset
from mmcm.vision.process.correct_color import hist_match_video_bcthw
from mmcm.vision.process.image_process import (
    batch_dynamic_crop_resize_images,
    batch_dynamic_crop_resize_images_v2,
)
from mmcm.vision.utils.data_type_util import is_video
from mmcm.vision.feature_extractor.controlnet import load_controlnet_model

from ..schedulers import (
    EulerDiscreteScheduler,
    LCMScheduler,
    DDIMScheduler,
    DDPMScheduler,
)
from ..models.unet_3d_condition import UNet3DConditionModel
from .pipeline_controlnet import (
    MusevControlNetPipeline,
    VideoPipelineOutput as PipelineVideoPipelineOutput,
)
from ..utils.util import save_videos_grid_with_opencv
from ..utils.model_util import (
    update_pipeline_basemodel,
    update_pipeline_lora_model,
    update_pipeline_lora_models,
    update_pipeline_model_parameters,
)

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


@dataclass
class VideoPipelineOutput(BaseOutput):
    videos: Union[torch.Tensor, np.ndarray]
    latents: Union[torch.Tensor, np.ndarray]
    videos_mid: Union[torch.Tensor, np.ndarray]
    controlnet_cond: Union[torch.Tensor, np.ndarray]
    generated_videos: Union[torch.Tensor, np.ndarray]

def update_controlnet_processor_params(
    src: Union[Dict, List[Dict]], dst: Union[Dict, List[Dict]]
):
    """Merge dst into a deep copy of src; values in dst override those in src."""
    if isinstance(src, list) and not isinstance(dst, list):
        dst = [dst] * len(src)
    if isinstance(src, list) and isinstance(dst, list):
        return [
            update_controlnet_processor_params(src[i], dst[i]) for i in range(len(src))
        ]
    dct = {} if src is None else copy.deepcopy(src)
    if dst is None:
        dst = {}
    dct.update(dst)
    return dct
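
# Illustrative example of the merge above (made-up values): with
#   src = [{"detect_resolution": 512}, {"include_hand": True}]
#   dst = {"image_resolution": 768}
# update_controlnet_processor_params(src, dst) broadcasts dst over the list and returns
#   [{"detect_resolution": 512, "image_resolution": 768},
#    {"include_hand": True, "image_resolution": 768}]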

class DiffusersPipelinePredictor(object):
    """Wrapper of a diffusers pipeline that provides a generation-function interface. Supports:

    1. text2video: inputs are text, image (optional) and refer_image (optional);
    2. video2video:
        1. use controlnet to control the spatial layout, or
        2. fuse the input video into the noise and denoise (img2img-like).
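
    Example (illustrative sketch, not taken from the project docs; the checkpoint path and the
    already-built ``unet`` (a UNet3DConditionModel instance) are assumptions)::

        predictor = DiffusersPipelinePredictor(
            sd_model_path="./checkpoints/stable-diffusion-v1-5",  # assumed local path
            unet=unet,
            device="cuda",
            dtype=torch.float16,
        )
        videos = predictor.run_pipe_text2video(
            video_length=12,
            prompt="a boy walks on the street",
            height=512,
            width=512,
        )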
| """ | |
| def __init__( | |
| self, | |
| sd_model_path: str, | |
| unet: nn.Module, | |
| controlnet_name: Union[str, List[str]] = None, | |
| controlnet: nn.Module = None, | |
| lora_dict: Dict[str, Dict] = None, | |
| requires_safety_checker: bool = False, | |
| device: str = "cuda", | |
| dtype: torch.dtype = torch.float16, | |
| # controlnet parameters start | |
| need_controlnet_processor: bool = True, | |
| need_controlnet: bool = True, | |
| image_resolution: int = 512, | |
| detect_resolution: int = 512, | |
| include_body: bool = True, | |
| hand_and_face: bool = None, | |
| include_face: bool = False, | |
| include_hand: bool = True, | |
| negative_embedding: List = None, | |
| # controlnet parameters end | |
| enable_xformers_memory_efficient_attention: bool = True, | |
| lcm_lora_dct: Dict = None, | |
| referencenet: nn.Module = None, | |
| ip_adapter_image_proj: nn.Module = None, | |
| vision_clip_extractor: nn.Module = None, | |
| face_emb_extractor: nn.Module = None, | |
| facein_image_proj: nn.Module = None, | |
| ip_adapter_face_emb_extractor: nn.Module = None, | |
| ip_adapter_face_image_proj: nn.Module = None, | |
| vae_model: Optional[Tuple[nn.Module, str]] = None, | |
| pose_guider: Optional[nn.Module] = None, | |
| enable_zero_snr: bool = False, | |
| ) -> None: | |
        self.sd_model_path = sd_model_path
        self.unet = unet
        self.controlnet_name = controlnet_name
        self.controlnet = controlnet
        self.requires_safety_checker = requires_safety_checker
        self.device = device
        self.dtype = dtype
        self.need_controlnet_processor = need_controlnet_processor
        self.need_controlnet = need_controlnet
        self.image_resolution = image_resolution
        self.detect_resolution = detect_resolution
        self.include_body = include_body
        self.hand_and_face = hand_and_face
        self.include_face = include_face
        self.include_hand = include_hand
        self.negative_embedding = negative_embedding
        self.lcm_lora_dct = lcm_lora_dct
| if controlnet is None and controlnet_name is not None: | |
| controlnet, controlnet_processor, processor_params = load_controlnet_model( | |
| controlnet_name, | |
| device=device, | |
| dtype=dtype, | |
| need_controlnet_processor=need_controlnet_processor, | |
| need_controlnet=need_controlnet, | |
| image_resolution=image_resolution, | |
| detect_resolution=detect_resolution, | |
| include_body=include_body, | |
| include_face=include_face, | |
| hand_and_face=hand_and_face, | |
| include_hand=include_hand, | |
| ) | |
| self.controlnet_processor = controlnet_processor | |
| self.controlnet_processor_params = processor_params | |
| logger.debug(f"init controlnet controlnet_name={controlnet_name}") | |
| if controlnet is not None: | |
| controlnet = controlnet.to(device=device, dtype=dtype) | |
| controlnet.eval() | |
| if pose_guider is not None: | |
| pose_guider = pose_guider.to(device=device, dtype=dtype) | |
| pose_guider.eval() | |
| unet.to(device=device, dtype=dtype) | |
| unet.eval() | |
| if referencenet is not None: | |
| referencenet.to(device=device, dtype=dtype) | |
| referencenet.eval() | |
| if ip_adapter_image_proj is not None: | |
| ip_adapter_image_proj.to(device=device, dtype=dtype) | |
| ip_adapter_image_proj.eval() | |
| if vision_clip_extractor is not None: | |
| vision_clip_extractor.to(device=device, dtype=dtype) | |
| vision_clip_extractor.eval() | |
| if face_emb_extractor is not None: | |
| face_emb_extractor.to(device=device, dtype=dtype) | |
| face_emb_extractor.eval() | |
| if facein_image_proj is not None: | |
| facein_image_proj.to(device=device, dtype=dtype) | |
| facein_image_proj.eval() | |
| if isinstance(vae_model, str): | |
| # TODO: poor implementation, to improve | |
| if "consistency" in vae_model: | |
| vae = ConsistencyDecoderVAE.from_pretrained(vae_model) | |
| else: | |
| vae = AutoencoderKL.from_pretrained(vae_model) | |
| elif isinstance(vae_model, nn.Module): | |
| vae = vae_model | |
| else: | |
| vae = None | |
| if vae is not None: | |
| vae.to(device=device, dtype=dtype) | |
| vae.eval() | |
| if ip_adapter_face_emb_extractor is not None: | |
| ip_adapter_face_emb_extractor.to(device=device, dtype=dtype) | |
| ip_adapter_face_emb_extractor.eval() | |
| if ip_adapter_face_image_proj is not None: | |
| ip_adapter_face_image_proj.to(device=device, dtype=dtype) | |
| ip_adapter_face_image_proj.eval() | |
| params = { | |
| "pretrained_model_name_or_path": sd_model_path, | |
| "controlnet": controlnet, | |
| "unet": unet, | |
| "requires_safety_checker": requires_safety_checker, | |
| "torch_dtype": dtype, | |
| "torch_device": device, | |
| "referencenet": referencenet, | |
| "ip_adapter_image_proj": ip_adapter_image_proj, | |
| "vision_clip_extractor": vision_clip_extractor, | |
| "facein_image_proj": facein_image_proj, | |
| "face_emb_extractor": face_emb_extractor, | |
| "ip_adapter_face_emb_extractor": ip_adapter_face_emb_extractor, | |
| "ip_adapter_face_image_proj": ip_adapter_face_image_proj, | |
| "pose_guider": pose_guider, | |
| } | |
| if vae is not None: | |
| params["vae"] = vae | |
| pipeline = MusevControlNetPipeline.from_pretrained(**params) | |
| pipeline = pipeline.to(torch_device=device, torch_dtype=dtype) | |
| logger.debug( | |
| f"init pipeline from sd_model_path={sd_model_path}, device={device}, dtype={dtype}" | |
| ) | |
| if ( | |
| negative_embedding is not None | |
| and pipeline.text_encoder is not None | |
| and pipeline.tokenizer is not None | |
| ): | |
| for neg_emb_path, neg_token in negative_embedding: | |
| pipeline.load_textual_inversion(neg_emb_path, token=neg_token) | |
| # pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config) | |
| # pipe.enable_model_cpu_offload() | |
| if not enable_zero_snr: | |
| pipeline.scheduler = EulerDiscreteScheduler.from_config( | |
| pipeline.scheduler.config | |
| ) | |
            # pipeline.scheduler = DDIMScheduler.from_config(
            #     pipeline.scheduler.config,
            #     # rescale_betas_zero_snr changes the brightness of the generated video, so it is
            #     # not suitable for image2video / vision-condition-image (given first frame) mode
            #     # rescale_betas_zero_snr=True,
            # )
            # pipeline.scheduler = DDPMScheduler.from_config(pipeline.scheduler.config)
| else: | |
            # moore scheduler, just for code testing
| pipeline.scheduler = DDIMScheduler( | |
| beta_start=0.00085, | |
| beta_end=0.012, | |
| beta_schedule="linear", | |
| clip_sample=False, | |
| steps_offset=1, | |
| ### Zero-SNR params | |
| prediction_type="v_prediction", | |
| rescale_betas_zero_snr=True, | |
| timestep_spacing="trailing", | |
| ) | |
| pipeline.enable_vae_slicing() | |
| self.enable_xformers_memory_efficient_attention = ( | |
| enable_xformers_memory_efficient_attention | |
| ) | |
| if enable_xformers_memory_efficient_attention: | |
| if is_xformers_available(): | |
| pipeline.enable_xformers_memory_efficient_attention() | |
| else: | |
| raise ValueError( | |
| "xformers is not available. Make sure it is installed correctly" | |
| ) | |
| self.pipeline = pipeline | |
| self.unload_dict = [] # keep lora state | |
| if lora_dict is not None: | |
| self.load_lora(lora_dict=lora_dict) | |
| logger.debug("load lora {}".format(" ".join(list(lora_dict.keys())))) | |
| if lcm_lora_dct is not None: | |
| self.pipeline.scheduler = LCMScheduler.from_config( | |
| self.pipeline.scheduler.config | |
| ) | |
| self.load_lora(lora_dict=lcm_lora_dct) | |
| logger.debug("load lcm lora {}".format(" ".join(list(lcm_lora_dct.keys())))) | |
| # logger.debug("Unet3Model Parameters") | |
| # logger.debug(pformat(self.__dict__)) | |
| def load_lora( | |
| self, | |
| lora_dict: Dict[str, Dict], | |
| ): | |
| self.pipeline, unload_dict = update_pipeline_lora_models( | |
| self.pipeline, lora_dict, device=self.device | |
| ) | |
| self.unload_dict += unload_dict | |
| def unload_lora(self): | |
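        # Revert LoRA: subtract from each layer the weight delta that was added when the LoRA
        # was loaded (the deltas are recorded in self.unload_dict by load_lora).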
| for layer_data in self.unload_dict: | |
| layer = layer_data["layer"] | |
| added_weight = layer_data["added_weight"] | |
| layer.weight.data -= added_weight | |
| self.unload_dict = [] | |
| gc.collect() | |
| torch.cuda.empty_cache() | |
| def update_unet(self, unet: nn.Module): | |
| self.pipeline.unet = unet.to(device=self.device, dtype=self.dtype) | |
| def update_sd_model(self, model_path: str, text_model_path: str): | |
| self.pipeline = update_pipeline_basemodel( | |
| self.pipeline, | |
| model_path, | |
| text_sd_model_path=text_model_path, | |
| device=self.device, | |
| ) | |
| def update_sd_model_and_unet( | |
| self, lora_sd_path: str, lora_path: str, sd_model_path: str = None | |
| ): | |
| self.pipeline = update_pipeline_model_parameters( | |
| self.pipeline, | |
| model_path=lora_sd_path, | |
| lora_path=lora_path, | |
| text_model_path=sd_model_path, | |
| device=self.device, | |
| ) | |
    def update_controlnet(self, controlnet_name: Union[str, List[str]]):
        self.pipeline.controlnet = load_controlnet_model(controlnet_name).to(
            device=self.device, dtype=self.dtype
        )
| def run_pipe_text2video( | |
| self, | |
| video_length: int, | |
| prompt: Union[str, List[str]] = None, | |
| # b c t h w | |
| height: Optional[int] = None, | |
| width: Optional[int] = None, | |
| video_num_inference_steps: int = 50, | |
| video_guidance_scale: float = 7.5, | |
| video_guidance_scale_end: float = 3.5, | |
| video_guidance_scale_method: str = "linear", | |
| strength: float = 0.8, | |
| video_negative_prompt: Optional[Union[str, List[str]]] = None, | |
| negative_prompt: Optional[Union[str, List[str]]] = None, | |
| num_videos_per_prompt: Optional[int] = 1, | |
| eta: float = 0.0, | |
| generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, | |
| same_seed: Optional[Union[int, List[int]]] = None, | |
| # b c t(1) ho wo | |
| condition_latents: Optional[torch.FloatTensor] = None, | |
| latents: Optional[torch.FloatTensor] = None, | |
| prompt_embeds: Optional[torch.FloatTensor] = None, | |
| negative_prompt_embeds: Optional[torch.FloatTensor] = None, | |
| guidance_scale: float = 7.5, | |
| num_inference_steps: int = 50, | |
| output_type: Optional[str] = "tensor", | |
| return_dict: bool = True, | |
| callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, | |
| callback_steps: int = 1, | |
| cross_attention_kwargs: Optional[Dict[str, Any]] = None, | |
| need_middle_latents: bool = False, | |
| w_ind_noise: float = 0.5, | |
| initial_common_latent: Optional[torch.FloatTensor] = None, | |
| latent_index: torch.LongTensor = None, | |
| vision_condition_latent_index: torch.LongTensor = None, | |
| n_vision_condition: int = 1, | |
| noise_type: str = "random", | |
| max_batch_num: int = 30, | |
| need_img_based_video_noise: bool = False, | |
| condition_images: torch.Tensor = None, | |
| fix_condition_images: bool = False, | |
| redraw_condition_image: bool = False, | |
| img_weight: float = 1e-3, | |
| motion_speed: float = 8.0, | |
| need_hist_match: bool = False, | |
| refer_image: Optional[ | |
| Tuple[np.ndarray, torch.Tensor, List[str], List[np.ndarray]] | |
| ] = None, | |
| ip_adapter_image: Optional[Tuple[torch.Tensor, np.array]] = None, | |
| fixed_refer_image: bool = True, | |
| fixed_ip_adapter_image: bool = True, | |
| redraw_condition_image_with_ipdapter: bool = True, | |
| redraw_condition_image_with_referencenet: bool = True, | |
| refer_face_image: Optional[Tuple[torch.Tensor, np.array]] = None, | |
| fixed_refer_face_image: bool = True, | |
| redraw_condition_image_with_facein: bool = True, | |
| ip_adapter_scale: float = 1.0, | |
| redraw_condition_image_with_ip_adapter_face: bool = True, | |
| facein_scale: float = 1.0, | |
| ip_adapter_face_scale: float = 1.0, | |
| prompt_only_use_image_prompt: bool = False, | |
| # serial_denoise parameter start | |
| record_mid_video_noises: bool = False, | |
| record_mid_video_latents: bool = False, | |
| video_overlap: int = 1, | |
| # serial_denoise parameter end | |
| # parallel_denoise parameter start | |
| context_schedule="uniform", | |
| context_frames=12, | |
| context_stride=1, | |
| context_overlap=4, | |
| context_batch_size=1, | |
| interpolation_factor=1, | |
| # parallel_denoise parameter end | |
| ): | |
| """ | |
| generate long video with end2end mode | |
| 1. prepare vision condition image by assingning, redraw, or generation with text2image module with skip_temporal_layer=True; | |
| 2. use image or latest of vision condition image to generate first shot; | |
| 3. use last n (1) image or last latent of last shot as new vision condition latent to generate next shot | |
| 4. repeat n_batch times between 2 and 3 | |
| 类似img2img pipeline | |
| refer_image和ip_adapter_image的来源: | |
| 1. 输入给定; | |
| 2. 当未输入时,纯text2video生成首帧,并赋值更新refer_image和ip_adapter_image; | |
| 3. 当有输入,但是因为redraw更新了首帧时,也需要赋值更新refer_image和ip_adapter_image; | |
| refer_image和ip_adapter_image的作用: | |
| 1. 当无首帧图像时,用于生成首帧; | |
| 2. 用于生成视频。 | |
| similar to diffusers img2img pipeline. | |
| three ways to prepare refer_image and ip_adapter_image | |
| 1. from input parameter | |
| 2. when input paramter is None, use text2video to generate vis cond image, and use as refer_image and ip_adapter_image too. | |
| 3. given from input paramter, but still redraw, update with redrawn vis cond image. | |
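
        Example (illustrative sketch; ``predictor`` is a DiffusersPipelinePredictor instance and
        the argument values below are assumptions, not recommended settings)::

            out_videos = predictor.run_pipe_text2video(
                video_length=12,
                prompt="a girl, dancing on the beach",
                height=512,
                width=512,
                video_num_inference_steps=30,
                video_guidance_scale=3.5,
                max_batch_num=3,  # number of chained shots
            )
            # out_videos: np.ndarray with layout b c t h w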
| """ | |
| # crop resize images | |
| if condition_images is not None: | |
| logger.debug( | |
| f"center crop resize condition_images={condition_images.shape}, to height={height}, width={width}" | |
| ) | |
| condition_images = batch_dynamic_crop_resize_images_v2( | |
| condition_images, | |
| target_height=height, | |
| target_width=width, | |
| ) | |
| if refer_image is not None: | |
| logger.debug( | |
| f"center crop resize refer_image to height={height}, width={width}" | |
| ) | |
| refer_image = batch_dynamic_crop_resize_images_v2( | |
| refer_image, | |
| target_height=height, | |
| target_width=width, | |
| ) | |
| if ip_adapter_image is not None: | |
| logger.debug( | |
| f"center crop resize ip_adapter_image to height={height}, width={width}" | |
| ) | |
| ip_adapter_image = batch_dynamic_crop_resize_images_v2( | |
| ip_adapter_image, | |
| target_height=height, | |
| target_width=width, | |
| ) | |
| if refer_face_image is not None: | |
| logger.debug( | |
| f"center crop resize refer_face_image to height={height}, width={width}" | |
| ) | |
| refer_face_image = batch_dynamic_crop_resize_images_v2( | |
| refer_face_image, | |
| target_height=height, | |
| target_width=width, | |
| ) | |
| run_video_length = video_length | |
| # generate vision condition frame start | |
| # if condition_images is None, generate with refer_image, ip_adapter_image | |
| # if condition_images not None and need redraw, according to redraw_condition_image_with_ipdapter, redraw_condition_image_with_referencenet, refer_image, ip_adapter_image | |
| if n_vision_condition > 0: | |
| if condition_images is None and condition_latents is None: | |
| logger.debug("run_pipe_text2video, generate first_image") | |
| ( | |
| condition_images, | |
| condition_latents, | |
| _, | |
| _, | |
| _, | |
| ) = self.pipeline( | |
| prompt=prompt, | |
| num_inference_steps=num_inference_steps, | |
| guidance_scale=guidance_scale, | |
| negative_prompt=negative_prompt, | |
| video_length=1, | |
| height=height, | |
| width=width, | |
| return_dict=False, | |
| skip_temporal_layer=True, | |
| output_type="np", | |
| generator=generator, | |
| w_ind_noise=w_ind_noise, | |
| need_img_based_video_noise=need_img_based_video_noise, | |
| refer_image=refer_image | |
| if redraw_condition_image_with_referencenet | |
| else None, | |
| ip_adapter_image=ip_adapter_image | |
| if redraw_condition_image_with_ipdapter | |
| else None, | |
| refer_face_image=refer_face_image | |
| if redraw_condition_image_with_facein | |
| else None, | |
| ip_adapter_scale=ip_adapter_scale, | |
| facein_scale=facein_scale, | |
| ip_adapter_face_scale=ip_adapter_face_scale, | |
| ip_adapter_face_image=refer_face_image | |
| if redraw_condition_image_with_ip_adapter_face | |
| else None, | |
| prompt_only_use_image_prompt=prompt_only_use_image_prompt, | |
| ) | |
| run_video_length = video_length - 1 | |
| elif ( | |
| condition_images is not None | |
| and redraw_condition_image | |
| and condition_latents is None | |
| ): | |
| logger.debug("run_pipe_text2video, redraw first_image") | |
| ( | |
| condition_images, | |
| condition_latents, | |
| _, | |
| _, | |
| _, | |
| ) = self.pipeline( | |
| prompt=prompt, | |
| image=condition_images, | |
| num_inference_steps=num_inference_steps, | |
| guidance_scale=guidance_scale, | |
| negative_prompt=negative_prompt, | |
| strength=strength, | |
| video_length=condition_images.shape[2], | |
| height=height, | |
| width=width, | |
| return_dict=False, | |
| skip_temporal_layer=True, | |
| output_type="np", | |
| generator=generator, | |
| w_ind_noise=w_ind_noise, | |
| need_img_based_video_noise=need_img_based_video_noise, | |
| refer_image=refer_image | |
| if redraw_condition_image_with_referencenet | |
| else None, | |
| ip_adapter_image=ip_adapter_image | |
| if redraw_condition_image_with_ipdapter | |
| else None, | |
| refer_face_image=refer_face_image | |
| if redraw_condition_image_with_facein | |
| else None, | |
| ip_adapter_scale=ip_adapter_scale, | |
| facein_scale=facein_scale, | |
| ip_adapter_face_scale=ip_adapter_face_scale, | |
| ip_adapter_face_image=refer_face_image | |
| if redraw_condition_image_with_ip_adapter_face | |
| else None, | |
| prompt_only_use_image_prompt=prompt_only_use_image_prompt, | |
| ) | |
| else: | |
| condition_images = None | |
| condition_latents = None | |
| # generate vision condition frame end | |
| # refer_image and ip_adapter_image, update mode from 2 and 3 as mentioned above start | |
| if ( | |
| refer_image is not None | |
| and redraw_condition_image | |
| and condition_images is not None | |
| ): | |
| refer_image = condition_images * 255.0 | |
| logger.debug(f"update refer_image because of redraw_condition_image") | |
| elif ( | |
| refer_image is None | |
| and self.pipeline.referencenet is not None | |
| and condition_images is not None | |
| ): | |
| refer_image = condition_images * 255.0 | |
| logger.debug(f"update refer_image because of generate first_image") | |
| # ipadapter_image | |
| if ( | |
| ip_adapter_image is not None | |
| and redraw_condition_image | |
| and condition_images is not None | |
| ): | |
| ip_adapter_image = condition_images * 255.0 | |
| logger.debug(f"update ip_adapter_image because of redraw_condition_image") | |
| elif ( | |
| ip_adapter_image is None | |
| and self.pipeline.ip_adapter_image_proj is not None | |
| and condition_images is not None | |
| ): | |
| ip_adapter_image = condition_images * 255.0 | |
| logger.debug(f"update ip_adapter_image because of generate first_image") | |
| # refer_image and ip_adapter_image, update mode from 2 and 3 as mentioned above end | |
| # refer_face_image, update mode from 2 and 3 as mentioned above start | |
| if ( | |
| refer_face_image is not None | |
| and redraw_condition_image | |
| and condition_images is not None | |
| ): | |
| refer_face_image = condition_images * 255.0 | |
| logger.debug(f"update refer_face_image because of redraw_condition_image") | |
| elif ( | |
| refer_face_image is None | |
| and self.pipeline.facein_image_proj is not None | |
| and condition_images is not None | |
| ): | |
| refer_face_image = condition_images * 255.0 | |
| logger.debug(f"update face_image because of generate first_image") | |
| # refer_face_image, update mode from 2 and 3 as mentioned above end | |
| last_mid_video_noises = None | |
| last_mid_video_latents = None | |
| initial_common_latent = None | |
| out_videos = [] | |
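        # generate the long video shot by shot: the first shot consumes the prepared vision
        # condition; each later shot reuses the last n_vision_condition frames/latents of the
        # previous shot as its condition, and the overlapping frames are dropped before the
        # results are concatenated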
| for i_batch in range(max_batch_num): | |
| logger.debug(f"sd_pipeline_predictor, run_pipe_text2video: {i_batch}") | |
| if max_batch_num is not None and i_batch == max_batch_num: | |
| break | |
| if i_batch == 0: | |
| result_overlap = 0 | |
| else: | |
| if n_vision_condition > 0: | |
| # ignore condition_images if condition_latents is not None in pipeline | |
| if not fix_condition_images: | |
| logger.debug(f"{i_batch}, update condition_latents") | |
| condition_latents = out_latents_batch[ | |
| :, :, -n_vision_condition:, :, : | |
| ] | |
| else: | |
| logger.debug(f"{i_batch}, do not update condition_latents") | |
| result_overlap = n_vision_condition | |
| if not fixed_refer_image and n_vision_condition > 0: | |
| logger.debug("ref_image use last frame of last generated out video") | |
| refer_image = out_batch[:, :, -n_vision_condition:, :, :] * 255.0 | |
| else: | |
| logger.debug("use given fixed ref_image") | |
| if not fixed_ip_adapter_image and n_vision_condition > 0: | |
| logger.debug( | |
| "ip_adapter_image use last frame of last generated out video" | |
| ) | |
| ip_adapter_image = ( | |
| out_batch[:, :, -n_vision_condition:, :, :] * 255.0 | |
| ) | |
| else: | |
| logger.debug("use given fixed ip_adapter_image") | |
| if not fixed_refer_face_image and n_vision_condition > 0: | |
| logger.debug( | |
| "refer_face_image use last frame of last generated out video" | |
| ) | |
| refer_face_image = ( | |
| out_batch[:, :, -n_vision_condition:, :, :] * 255.0 | |
| ) | |
| else: | |
| logger.debug("use given fixed ip_adapter_image") | |
| run_video_length = video_length | |
| if same_seed is not None: | |
| _, generator = set_all_seed(same_seed) | |
| out = self.pipeline( | |
| video_length=run_video_length, # int | |
| prompt=prompt, | |
| num_inference_steps=video_num_inference_steps, | |
| height=height, | |
| width=width, | |
| generator=generator, | |
| condition_images=condition_images, | |
| condition_latents=condition_latents, # b co t(1) ho wo | |
| skip_temporal_layer=False, | |
| output_type="np", | |
| noise_type=noise_type, | |
| negative_prompt=video_negative_prompt, | |
| guidance_scale=video_guidance_scale, | |
| guidance_scale_end=video_guidance_scale_end, | |
| guidance_scale_method=video_guidance_scale_method, | |
| w_ind_noise=w_ind_noise, | |
| need_img_based_video_noise=need_img_based_video_noise, | |
| img_weight=img_weight, | |
| motion_speed=motion_speed, | |
| vision_condition_latent_index=vision_condition_latent_index, | |
| refer_image=refer_image, | |
| ip_adapter_image=ip_adapter_image, | |
| refer_face_image=refer_face_image, | |
| ip_adapter_scale=ip_adapter_scale, | |
| facein_scale=facein_scale, | |
| ip_adapter_face_scale=ip_adapter_face_scale, | |
| ip_adapter_face_image=refer_face_image, | |
| prompt_only_use_image_prompt=prompt_only_use_image_prompt, | |
| initial_common_latent=initial_common_latent, | |
| # serial_denoise parameter start | |
| record_mid_video_noises=record_mid_video_noises, | |
| last_mid_video_noises=last_mid_video_noises, | |
| record_mid_video_latents=record_mid_video_latents, | |
| last_mid_video_latents=last_mid_video_latents, | |
| video_overlap=video_overlap, | |
| # serial_denoise parameter end | |
| # parallel_denoise parameter start | |
| context_schedule=context_schedule, | |
| context_frames=context_frames, | |
| context_stride=context_stride, | |
| context_overlap=context_overlap, | |
| context_batch_size=context_batch_size, | |
| interpolation_factor=interpolation_factor, | |
| # parallel_denoise parameter end | |
| ) | |
| logger.debug( | |
| f"run_pipe_text2video, out.videos.shape, i_batch={i_batch}, videos={out.videos.shape}, result_overlap={result_overlap}" | |
| ) | |
| out_batch = out.videos[:, :, result_overlap:, :, :] | |
| out_latents_batch = out.latents[:, :, result_overlap:, :, :] | |
| out_videos.append(out_batch) | |
| out_videos = np.concatenate(out_videos, axis=2) | |
| if need_hist_match: | |
| out_videos[:, :, 1:, :, :] = hist_match_video_bcthw( | |
| out_videos[:, :, 1:, :, :], out_videos[:, :, :1, :, :], value=255.0 | |
| ) | |
| return out_videos | |
| def run_pipe_with_latent_input( | |
| self, | |
| ): | |
| pass | |
| def run_pipe_middle2video_with_middle(self, middle: Tuple[str, Iterable]): | |
| pass | |
| def run_pipe_video2video( | |
| self, | |
| video: Tuple[str, Iterable], | |
| time_size: int = None, | |
| sample_rate: int = None, | |
| overlap: int = None, | |
| step: int = None, | |
| prompt: Union[str, List[str]] = None, | |
| # b c t h w | |
| height: Optional[int] = None, | |
| width: Optional[int] = None, | |
| num_inference_steps: int = 50, | |
| video_num_inference_steps: int = 50, | |
| guidance_scale: float = 7.5, | |
| video_guidance_scale: float = 7.5, | |
| video_guidance_scale_end: float = 3.5, | |
| video_guidance_scale_method: str = "linear", | |
| video_negative_prompt: Optional[Union[str, List[str]]] = None, | |
| negative_prompt: Optional[Union[str, List[str]]] = None, | |
| num_videos_per_prompt: Optional[int] = 1, | |
| negative_prompt_embeds: Optional[torch.FloatTensor] = None, | |
| eta: float = 0.0, | |
| generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, | |
| controlnet_latents: Union[torch.FloatTensor, np.ndarray] = None, | |
| # b c t(1) hi wi | |
| controlnet_condition_images: Optional[torch.FloatTensor] = None, | |
| # b c t(1) ho wo | |
| controlnet_condition_latents: Optional[torch.FloatTensor] = None, | |
| # b c t(1) ho wo | |
| condition_latents: Optional[torch.FloatTensor] = None, | |
| condition_images: Optional[torch.FloatTensor] = None, | |
| fix_condition_images: bool = False, | |
| latents: Optional[torch.FloatTensor] = None, | |
| prompt_embeds: Optional[torch.FloatTensor] = None, | |
| output_type: Optional[str] = "tensor", | |
| return_dict: bool = True, | |
| callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, | |
| callback_steps: int = 1, | |
| cross_attention_kwargs: Optional[Dict[str, Any]] = None, | |
| controlnet_conditioning_scale: Union[float, List[float]] = 1.0, | |
| guess_mode: bool = False, | |
| control_guidance_start: Union[float, List[float]] = 0.0, | |
| control_guidance_end: Union[float, List[float]] = 1.0, | |
| need_middle_latents: bool = False, | |
| w_ind_noise: float = 0.5, | |
| img_weight: float = 0.001, | |
| initial_common_latent: Optional[torch.FloatTensor] = None, | |
| latent_index: torch.LongTensor = None, | |
| vision_condition_latent_index: torch.LongTensor = None, | |
| noise_type: str = "random", | |
| controlnet_processor_params: Dict = None, | |
| need_return_videos: bool = False, | |
| need_return_condition: bool = False, | |
| max_batch_num: int = 30, | |
| strength: float = 0.8, | |
| video_strength: float = 0.8, | |
| need_video2video: bool = False, | |
| need_img_based_video_noise: bool = False, | |
| need_hist_match: bool = False, | |
| end_to_end: bool = True, | |
| refer_image: Optional[ | |
| Tuple[np.ndarray, torch.Tensor, List[str], List[np.ndarray]] | |
| ] = None, | |
| ip_adapter_image: Optional[Tuple[torch.Tensor, np.array]] = None, | |
| fixed_refer_image: bool = True, | |
| fixed_ip_adapter_image: bool = True, | |
| redraw_condition_image: bool = False, | |
| redraw_condition_image_with_ipdapter: bool = True, | |
| redraw_condition_image_with_referencenet: bool = True, | |
| refer_face_image: Optional[Tuple[torch.Tensor, np.array]] = None, | |
| fixed_refer_face_image: bool = True, | |
| redraw_condition_image_with_facein: bool = True, | |
| ip_adapter_scale: float = 1.0, | |
| facein_scale: float = 1.0, | |
| ip_adapter_face_scale: float = 1.0, | |
| redraw_condition_image_with_ip_adapter_face: bool = False, | |
| n_vision_condition: int = 1, | |
| prompt_only_use_image_prompt: bool = False, | |
| motion_speed: float = 8.0, | |
| # serial_denoise parameter start | |
| record_mid_video_noises: bool = False, | |
| record_mid_video_latents: bool = False, | |
| video_overlap: int = 1, | |
| # serial_denoise parameter end | |
| # parallel_denoise parameter start | |
| context_schedule="uniform", | |
| context_frames=12, | |
| context_stride=1, | |
| context_overlap=4, | |
| context_batch_size=1, | |
| interpolation_factor=1, | |
| # parallel_denoise parameter end | |
        # supports several input types when the video is given as a path
        # TODO: video_has_condition=False is only supported together with video_is_middle=True;
        #       to be refactored later
| video_is_middle: bool = False, | |
| video_has_condition: bool = True, | |
| ): | |
| """ | |
| 类似controlnet text2img pipeline。 输入视频,用视频得到controlnet condition。 | |
| 目前仅支持time_size == step,overlap=0 | |
| 输出视频长度=输入视频长度 | |
| similar to controlnet text2image pipeline, generate video with controlnet condition from given video. | |
| By now, sliding window only support time_size == step, overlap = 0. | |
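
        Example (illustrative sketch; the video path and argument values are assumptions)::

            out_videos, out_condition, videos = predictor.run_pipe_video2video(
                video="./data/source_dance.mp4",
                time_size=12,
                step=12,
                overlap=0,
                sample_rate=1,
                prompt="a girl, dancing, masterpiece, best quality",
                height=512,
                width=512,
                video_num_inference_steps=30,
                video_guidance_scale=3.5,
                need_return_videos=True,
                need_return_condition=True,
            )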
| """ | |
| if isinstance(video, str): | |
| video_reader = DecordVideoDataset( | |
| video, | |
| time_size=time_size, | |
| step=step, | |
| overlap=overlap, | |
| sample_rate=sample_rate, | |
| device="cpu", | |
| data_type="rgb", | |
| channels_order="c t h w", | |
| drop_last=True, | |
| ) | |
| else: | |
| video_reader = video | |
| videos = [] if need_return_videos else None | |
| out_videos = [] | |
| out_condition = ( | |
| [] | |
| if need_return_condition and self.pipeline.controlnet is not None | |
| else None | |
| ) | |
| # crop resize images | |
| if condition_images is not None: | |
| logger.debug( | |
| f"center crop resize condition_images={condition_images.shape}, to height={height}, width={width}" | |
| ) | |
| condition_images = batch_dynamic_crop_resize_images_v2( | |
| condition_images, | |
| target_height=height, | |
| target_width=width, | |
| ) | |
| if refer_image is not None: | |
| logger.debug( | |
| f"center crop resize refer_image to height={height}, width={width}" | |
| ) | |
| refer_image = batch_dynamic_crop_resize_images_v2( | |
| refer_image, | |
| target_height=height, | |
| target_width=width, | |
| ) | |
| if ip_adapter_image is not None: | |
| logger.debug( | |
| f"center crop resize ip_adapter_image to height={height}, width={width}" | |
| ) | |
| ip_adapter_image = batch_dynamic_crop_resize_images_v2( | |
| ip_adapter_image, | |
| target_height=height, | |
| target_width=width, | |
| ) | |
| if refer_face_image is not None: | |
| logger.debug( | |
| f"center crop resize refer_face_image to height={height}, width={width}" | |
| ) | |
| refer_face_image = batch_dynamic_crop_resize_images_v2( | |
| refer_face_image, | |
| target_height=height, | |
| target_width=width, | |
| ) | |
| first_image = None | |
| last_mid_video_noises = None | |
| last_mid_video_latents = None | |
| initial_common_latent = None | |
| # initial_common_latent = torch.randn((1, 4, 1, 112, 64)).to( | |
| # device=self.device, dtype=self.dtype | |
| # ) | |
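        # iterate over the input video with a sliding window; each window is resized, turned into
        # a controlnet condition if needed, and generated shot by shot, chaining on the last
        # frame of the previous shot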
| for i_batch, item in enumerate(video_reader): | |
| logger.debug(f"\n sd_pipeline_predictor, run_pipe_video2video: {i_batch}") | |
| if max_batch_num is not None and i_batch == max_batch_num: | |
| break | |
| # read and prepare video batch | |
| batch = item.data | |
| batch = batch_dynamic_crop_resize_images( | |
| batch, | |
| target_height=height, | |
| target_width=width, | |
| ) | |
| batch = batch[np.newaxis, ...] | |
| batch_size, channel, video_length, video_height, video_width = batch.shape | |
| # extract controlnet middle | |
| if self.pipeline.controlnet is not None: | |
| batch = rearrange(batch, "b c t h w-> (b t) h w c") | |
| controlnet_processor_params = update_controlnet_processor_params( | |
| src=self.controlnet_processor_params, | |
| dst=controlnet_processor_params, | |
| ) | |
| if not video_is_middle: | |
| batch_condition = self.controlnet_processor( | |
| data=batch, | |
| data_channel_order="b h w c", | |
| target_height=height, | |
| target_width=width, | |
| return_type="np", | |
| return_data_channel_order="b c h w", | |
| input_rgb_order="rgb", | |
| processor_params=controlnet_processor_params, | |
| ) | |
| else: | |
                    # TODO: temporary path for visualizing an input sequence that is already
                    # controlnet middle output; should be split into a middle2video entry point
                    # (or exposed as an explicit parameter) later
| batch_condition = rearrange( | |
| copy.deepcopy(batch), " b h w c-> b c h w" | |
| ) | |
                # only when the input video is already a middle (e.g. pose) sequence and
                # condition_images is not aligned with its first frame: regenerate the pose of
                # condition_images and concatenate it onto middle_batch
| if ( | |
| i_batch == 0 | |
| and not video_has_condition | |
| and video_is_middle | |
| and condition_images is not None | |
| ): | |
| condition_images_reshape = rearrange( | |
| condition_images, "b c t h w-> (b t) h w c" | |
| ) | |
| condition_images_condition = self.controlnet_processor( | |
| data=condition_images_reshape, | |
| data_channel_order="b h w c", | |
| target_height=height, | |
| target_width=width, | |
| return_type="np", | |
| return_data_channel_order="b c h w", | |
| input_rgb_order="rgb", | |
| processor_params=controlnet_processor_params, | |
| ) | |
| condition_images_condition = rearrange( | |
| condition_images_condition, | |
| "(b t) c h w-> b c t h w", | |
| b=batch_size, | |
| ) | |
| else: | |
| condition_images_condition = None | |
| if not isinstance(batch_condition, list): | |
| batch_condition = rearrange( | |
| batch_condition, "(b t) c h w-> b c t h w", b=batch_size | |
| ) | |
| if condition_images_condition is not None: | |
| batch_condition = np.concatenate( | |
| [ | |
| condition_images_condition, | |
| batch_condition, | |
| ], | |
| axis=2, | |
| ) | |
                        # batch_condition now has one more frame than batch; replace batch so the
                        # final videos can still be concatenated and saved together
                        # (this only applies when condition_images_condition is not None)
| batch = rearrange(batch_condition, "b c t h w ->(b t) h w c") | |
| else: | |
| batch_condition = [ | |
| rearrange(x, "(b t) c h w-> b c t h w", b=batch_size) | |
| for x in batch_condition | |
| ] | |
| if condition_images_condition is not None: | |
| batch_condition = [ | |
| np.concatenate( | |
| [condition_images_condition, batch_condition_tmp], | |
| axis=2, | |
| ) | |
| for batch_condition_tmp in batch_condition | |
| ] | |
| batch = rearrange(batch, "(b t) h w c -> b c t h w", b=batch_size) | |
| else: | |
| batch_condition = None | |
            # condition images are in [0, 255], latents in [0, 1]
            # generate as many video shots as needed:
            #   the first shot is special because its first frame has to be generated;
            #   each following shot is chained to the previous one, using the last frame of the
            #   previous shot as the first frame of the current shot
            # TODO: the branches are implemented separately for now; to be unified later
| if n_vision_condition == 0: | |
| actual_video_length = video_length | |
| control_image = batch_condition | |
| first_image_controlnet_condition = None | |
| first_image_latents = None | |
| if need_video2video: | |
| video = batch | |
| else: | |
| video = None | |
| result_overlap = 0 | |
| else: | |
| if i_batch == 0: | |
| if self.pipeline.controlnet is not None: | |
| if not isinstance(batch_condition, list): | |
| first_image_controlnet_condition = batch_condition[ | |
| :, :, :1, :, : | |
| ] | |
| else: | |
| first_image_controlnet_condition = [ | |
| x[:, :, :1, :, :] for x in batch_condition | |
| ] | |
| else: | |
| first_image_controlnet_condition = None | |
| if need_video2video: | |
| if condition_images is None: | |
| video = batch[:, :, :1, :, :] | |
| else: | |
| video = condition_images | |
| else: | |
| video = None | |
| if condition_images is not None and not redraw_condition_image: | |
| first_image = condition_images | |
| first_image_latents = None | |
| else: | |
| ( | |
| first_image, | |
| first_image_latents, | |
| _, | |
| _, | |
| _, | |
| ) = self.pipeline( | |
| prompt=prompt, | |
| image=video, | |
| control_image=first_image_controlnet_condition, | |
| num_inference_steps=num_inference_steps, | |
| video_length=1, | |
| height=height, | |
| width=width, | |
| return_dict=False, | |
| skip_temporal_layer=True, | |
| output_type="np", | |
| generator=generator, | |
| negative_prompt=negative_prompt, | |
| controlnet_conditioning_scale=controlnet_conditioning_scale, | |
| control_guidance_start=control_guidance_start, | |
| control_guidance_end=control_guidance_end, | |
| w_ind_noise=w_ind_noise, | |
| strength=strength, | |
| refer_image=refer_image | |
| if redraw_condition_image_with_referencenet | |
| else None, | |
| ip_adapter_image=ip_adapter_image | |
| if redraw_condition_image_with_ipdapter | |
| else None, | |
| refer_face_image=refer_face_image | |
| if redraw_condition_image_with_facein | |
| else None, | |
| ip_adapter_scale=ip_adapter_scale, | |
| facein_scale=facein_scale, | |
| ip_adapter_face_scale=ip_adapter_face_scale, | |
| ip_adapter_face_image=refer_face_image | |
| if redraw_condition_image_with_ip_adapter_face | |
| else None, | |
| prompt_only_use_image_prompt=prompt_only_use_image_prompt, | |
| ) | |
| if refer_image is not None: | |
| refer_image = first_image * 255.0 | |
| if ip_adapter_image is not None: | |
| ip_adapter_image = first_image * 255.0 | |
                    # later inference can use first_image_latents directly, so first_image is no
                    # longer needed
| first_image = None | |
| if self.pipeline.controlnet is not None: | |
| if not isinstance(batch_condition, list): | |
| control_image = batch_condition[:, :, 1:, :, :] | |
| logger.debug(f"control_image={control_image.shape}") | |
| else: | |
| control_image = [x[:, :, 1:, :, :] for x in batch_condition] | |
| else: | |
| control_image = None | |
| actual_video_length = time_size - int(video_has_condition) | |
| if need_video2video: | |
| video = batch[:, :, 1:, :, :] | |
| else: | |
| video = None | |
| result_overlap = 0 | |
| else: | |
| actual_video_length = time_size | |
| if self.pipeline.controlnet is not None: | |
| if not fix_condition_images: | |
| logger.debug( | |
| f"{i_batch}, update first_image_controlnet_condition" | |
| ) | |
| if not isinstance(last_batch_condition, list): | |
| first_image_controlnet_condition = last_batch_condition[ | |
| :, :, -1:, :, : | |
| ] | |
| else: | |
| first_image_controlnet_condition = [ | |
| x[:, :, -1:, :, :] for x in last_batch_condition | |
| ] | |
| else: | |
| logger.debug( | |
| f"{i_batch}, do not update first_image_controlnet_condition" | |
| ) | |
| control_image = batch_condition | |
| else: | |
| control_image = None | |
| first_image_controlnet_condition = None | |
| if not fix_condition_images: | |
| logger.debug(f"{i_batch}, update condition_images") | |
| first_image_latents = out_latents_batch[:, :, -1:, :, :] | |
| else: | |
| logger.debug(f"{i_batch}, do not update condition_images") | |
| if need_video2video: | |
| video = batch | |
| else: | |
| video = None | |
| result_overlap = 1 | |
                    # update refer_image and ip_adapter_image
| if not fixed_refer_image: | |
| logger.debug( | |
| "ref_image use last frame of last generated out video" | |
| ) | |
| refer_image = ( | |
| out_batch[:, :, -n_vision_condition:, :, :] * 255.0 | |
| ) | |
| else: | |
| logger.debug("use given fixed ref_image") | |
| if not fixed_ip_adapter_image: | |
| logger.debug( | |
| "ip_adapter_image use last frame of last generated out video" | |
| ) | |
| ip_adapter_image = ( | |
| out_batch[:, :, -n_vision_condition:, :, :] * 255.0 | |
| ) | |
| else: | |
| logger.debug("use given fixed ip_adapter_image") | |
                    # face image
                    if not fixed_refer_face_image:
                        logger.debug(
                            "refer_face_image use last frame of last generated out video"
                        )
                        refer_face_image = (
                            out_batch[:, :, -n_vision_condition:, :, :] * 255.0
                        )
                    else:
                        logger.debug("use given fixed refer_face_image")
| out = self.pipeline( | |
| video_length=actual_video_length, # int | |
| prompt=prompt, | |
| num_inference_steps=video_num_inference_steps, | |
| height=height, | |
| width=width, | |
| generator=generator, | |
| image=video, | |
| control_image=control_image, # b ci(3) t hi wi | |
| controlnet_condition_images=first_image_controlnet_condition, # b ci(3) t(1) hi wi | |
| # controlnet_condition_images=np.zeros_like( | |
| # first_image_controlnet_condition | |
| # ), # b ci(3) t(1) hi wi | |
| condition_images=first_image, | |
| condition_latents=first_image_latents, # b co t(1) ho wo | |
| skip_temporal_layer=False, | |
| output_type="np", | |
| noise_type=noise_type, | |
| negative_prompt=video_negative_prompt, | |
| need_img_based_video_noise=need_img_based_video_noise, | |
| controlnet_conditioning_scale=controlnet_conditioning_scale, | |
| control_guidance_start=control_guidance_start, | |
| control_guidance_end=control_guidance_end, | |
| w_ind_noise=w_ind_noise, | |
| img_weight=img_weight, | |
| motion_speed=video_reader.sample_rate, | |
| guidance_scale=video_guidance_scale, | |
| guidance_scale_end=video_guidance_scale_end, | |
| guidance_scale_method=video_guidance_scale_method, | |
| strength=video_strength, | |
| refer_image=refer_image, | |
| ip_adapter_image=ip_adapter_image, | |
| refer_face_image=refer_face_image, | |
| ip_adapter_scale=ip_adapter_scale, | |
| facein_scale=facein_scale, | |
| ip_adapter_face_scale=ip_adapter_face_scale, | |
| ip_adapter_face_image=refer_face_image, | |
| prompt_only_use_image_prompt=prompt_only_use_image_prompt, | |
| initial_common_latent=initial_common_latent, | |
| # serial_denoise parameter start | |
| record_mid_video_noises=record_mid_video_noises, | |
| last_mid_video_noises=last_mid_video_noises, | |
| record_mid_video_latents=record_mid_video_latents, | |
| last_mid_video_latents=last_mid_video_latents, | |
| video_overlap=video_overlap, | |
| # serial_denoise parameter end | |
| # parallel_denoise parameter start | |
| context_schedule=context_schedule, | |
| context_frames=context_frames, | |
| context_stride=context_stride, | |
| context_overlap=context_overlap, | |
| context_batch_size=context_batch_size, | |
| interpolation_factor=interpolation_factor, | |
| # parallel_denoise parameter end | |
| ) | |
| last_batch = batch | |
| last_batch_condition = batch_condition | |
| last_mid_video_latents = out.mid_video_latents | |
| last_mid_video_noises = out.mid_video_noises | |
| out_batch = out.videos[:, :, result_overlap:, :, :] | |
| out_latents_batch = out.latents[:, :, result_overlap:, :, :] | |
| out_videos.append(out_batch) | |
| if need_return_videos: | |
| videos.append(batch) | |
| if out_condition is not None: | |
| out_condition.append(batch_condition) | |
| out_videos = np.concatenate(out_videos, axis=2) | |
| if need_return_videos: | |
| videos = np.concatenate(videos, axis=2) | |
| if out_condition is not None: | |
| if not isinstance(out_condition[0], list): | |
| out_condition = np.concatenate(out_condition, axis=2) | |
| else: | |
| out_condition = [ | |
| [out_condition[j][i] for j in range(len(out_condition))] | |
| for i in range(len(out_condition[0])) | |
| ] | |
| out_condition = [np.concatenate(x, axis=2) for x in out_condition] | |
        if need_hist_match:
            out_videos[:, :, 1:, :, :] = hist_match_video_bcthw(
                out_videos[:, :, 1:, :, :], out_videos[:, :, :1, :, :], value=255.0
            )
| return out_videos, out_condition, videos | |