diff --git a/diffsynth/models/wan_video_vae.py b/diffsynth/models/wan_video_vae.py index df23076..137fd28 100644 --- a/diffsynth/models/wan_video_vae.py +++ b/diffsynth/models/wan_video_vae.py @@ -774,18 +774,11 @@ class WanVideoVAE(nn.Module): def decode(self, hidden_states, device, tiled=False, tile_size=(34, 34), tile_stride=(18, 16)): - hidden_states = [hidden_state.to("cpu") for hidden_state in hidden_states] - videos = [] - for hidden_state in hidden_states: - hidden_state = hidden_state.unsqueeze(0) - if tiled: - video = self.tiled_decode(hidden_state, device, tile_size, tile_stride) - else: - video = self.single_decode(hidden_state, device) - video = video.squeeze(0) - videos.append(video) - videos = torch.stack(videos) - return videos + if tiled: + video = self.tiled_decode(hidden_states, device, tile_size, tile_stride) + else: + video = self.single_decode(hidden_states, device) + return video @staticmethod diff --git a/diffsynth/pipelines/wan_video_new.py b/diffsynth/pipelines/wan_video_new.py new file mode 100644 index 0000000..dcc4485 --- /dev/null +++ b/diffsynth/pipelines/wan_video_new.py @@ -0,0 +1,981 @@ +import torch, warnings, glob +import numpy as np +from PIL import Image +from einops import repeat, reduce +from typing import Optional, Union +from dataclasses import dataclass +from modelscope import snapshot_download + + +import types +from ..models import ModelManager +from ..models.wan_video_dit import WanModel +from ..models.wan_video_text_encoder import WanTextEncoder +from ..models.wan_video_vae import WanVideoVAE +from ..models.wan_video_image_encoder import WanImageEncoder +from ..models.wan_video_vace import VaceWanModel +from ..schedulers.flow_match import FlowMatchScheduler +from .base import BasePipeline +from ..prompters import WanPrompter +import torch, os +from einops import rearrange +import numpy as np +from PIL import Image +from tqdm import tqdm +from typing import Optional + +from ..vram_management import enable_vram_management, AutoWrappedModule, AutoWrappedLinear +from ..models.wan_video_text_encoder import T5RelativeEmbedding, T5LayerNorm +from ..models.wan_video_dit import RMSNorm, sinusoidal_embedding_1d +from ..models.wan_video_vae import RMS_norm, CausalConv3d, Upsample +from ..models.wan_video_motion_controller import WanMotionControllerModel + + + +class BasePipeline(torch.nn.Module): + + def __init__( + self, + device="cuda", torch_dtype=torch.float16, + height_division_factor=64, width_division_factor=64, + time_division_factor=None, time_division_remainder=None, + ): + super().__init__() + # The device and torch_dtype is used for the storage of intermediate variables, not models. + self.device = device + self.torch_dtype = torch_dtype + # The following parameters are used for shape check. + self.height_division_factor = height_division_factor + self.width_division_factor = width_division_factor + self.time_division_factor = time_division_factor + self.time_division_remainder = time_division_remainder + self.vram_management_enabled = False + + + def check_resize_height_width(self, height, width, num_frames=None): + # Shape check + if height % self.height_division_factor != 0: + height = (height + self.height_division_factor - 1) // self.height_division_factor * self.height_division_factor + print(f"height % {self.height_division_factor} != 0. We round it up to {height}.") + if width % self.width_division_factor != 0: + width = (width + self.width_division_factor - 1) // self.width_division_factor * self.width_division_factor + print(f"width % {self.width_division_factor} != 0. We round it up to {width}.") + if num_frames is None: + return height, width + else: + if num_frames % self.time_division_factor != self.time_division_remainder: + num_frames = (num_frames + self.time_division_factor - 1) // self.time_division_factor * self.time_division_factor + self.time_division_remainder + print(f"num_frames % {self.time_division_factor} != {self.time_division_remainder}. We round it up to {num_frames}.") + return height, width, num_frames + + + def preprocess_image(self, image, torch_dtype=None, device=None, pattern="B C H W", min_value=-1, max_value=1): + # Transform a PIL.Image to torch.Tensor + image = torch.Tensor(np.array(image, dtype=np.float32)) + image = image.to(dtype=torch_dtype or self.torch_dtype, device=device or self.device) + image = image * ((max_value - min_value) / 255) + min_value + image = repeat(image, f"H W C -> {pattern}", **({"B": 1} if "B" in pattern else {})) + return image + + + def preprocess_video(self, video, torch_dtype=None, device=None, pattern="B C T H W", min_value=-1, max_value=1): + # Transform a list of PIL.Image to torch.Tensor + video = [self.preprocess_image(image, torch_dtype=torch_dtype, device=device, min_value=min_value, max_value=max_value) for image in video] + video = torch.stack(video, dim=pattern.index("T") // 2) + return video + + + def vae_output_to_image(self, vae_output, pattern="B C H W", min_value=-1, max_value=1): + # Transform a torch.Tensor to PIL.Image + if pattern != "H W C": + vae_output = reduce(vae_output, f"{pattern} -> H W C", reduction="mean") + image = ((vae_output - min_value) * (255 / (max_value - min_value))).clip(0, 255) + image = image.to(device="cpu", dtype=torch.uint8) + image = Image.fromarray(image.numpy()) + return image + + + def vae_output_to_video(self, vae_output, pattern="B C T H W", min_value=-1, max_value=1): + # Transform a torch.Tensor to list of PIL.Image + if pattern != "T H W C": + vae_output = reduce(vae_output, f"{pattern} -> T H W C", reduction="mean") + video = [self.vae_output_to_image(image, pattern="H W C", min_value=min_value, max_value=max_value) for image in vae_output] + return video + + + def load_models_to_device(self, model_names=[]): + if self.vram_management_enabled: + # offload models + for name, model in self.named_children(): + if name not in model_names: + if hasattr(model, "vram_management_enabled") and model.vram_management_enabled: + for module in model.modules(): + if hasattr(module, "offload"): + module.offload() + else: + model.cpu() + torch.cuda.empty_cache() + # onload models + for name, model in self.named_children(): + if name in model_names: + if hasattr(model, "vram_management_enabled") and model.vram_management_enabled: + for module in model.modules(): + if hasattr(module, "onload"): + module.onload() + else: + model.to(self.device) + + + def generate_noise(self, shape, seed=None, rand_device="cpu", rand_torch_dtype=torch.float32, device=None, torch_dtype=None): + # Initialize Gaussian noise + generator = None if seed is None else torch.Generator(rand_device).manual_seed(seed) + noise = torch.randn(shape, generator=generator, device=rand_device, dtype=rand_torch_dtype) + noise = noise.to(dtype=torch_dtype or self.torch_dtype, device=device or self.device) + return noise + + + def enable_cpu_offload(self): + warnings.warn("enable_cpu_offload is deprecated. This feature is automatically enabled if offload_device != device") + + + +@dataclass +class ModelConfig: + path: Union[str, list[str]] = None + model_id: str = None + origin_file_pattern: Union[str, list[str]] = None + download_resource: str = "ModelScope" + offload_device: Optional[Union[str, torch.device]] = None + quantization_dtype: Optional[torch.dtype] = None + + def download_if_necessary(self, local_model_path="./models", skip_download=False): + if self.path is None: + if self.model_id is None or self.origin_file_pattern is None: + raise ValueError(f"""No valid model files. Please use `ModelConfig(path="xxx")` or `ModelConfig(model_id="xxx/yyy", origin_file_pattern="zzz")`.""") + if not skip_download: + snapshot_download( + self.model_id, + local_dir=os.path.join(local_model_path, self.model_id), + allow_file_pattern=self.origin_file_pattern, + local_files_only=False + ) + self.path = glob.glob(os.path.join(local_model_path, self.model_id, self.origin_file_pattern)) + if isinstance(self.path, list) and len(self.path) == 1: + self.path = self.path[0] + + + +class WanVideoPipeline(BasePipeline): + + def __init__(self, device="cuda", torch_dtype=torch.bfloat16, tokenizer_path=None): + super().__init__( + device=device, torch_dtype=torch_dtype, + height_division_factor=16, width_division_factor=16, time_division_factor=4, time_division_remainder=1 + ) + self.scheduler = FlowMatchScheduler(shift=5, sigma_min=0.0, extra_one_step=True) + self.prompter = WanPrompter(tokenizer_path=tokenizer_path) + self.text_encoder: WanTextEncoder = None + self.image_encoder: WanImageEncoder = None + self.dit: WanModel = None + self.vae: WanVideoVAE = None + self.motion_controller: WanMotionControllerModel = None + self.vace: VaceWanModel = None + self.in_iteration_models = ("dit", "motion_controller", "vace") + self.unit_runner = PipelineUnitRunner() + self.units = [ + WanVideoUnit_ShapeChecker(), + WanVideoUnit_NoiseInitializer(), + WanVideoUnit_InputVideoEmbedder(), + WanVideoUnit_PromptEmbedder(), + WanVideoUnit_ImageEmbedder(), + WanVideoUnit_FunReference(), + WanVideoUnit_FunControl(), + WanVideoUnit_SpeedControl(), + WanVideoUnit_VACE(), + WanVideoUnit_TeaCache(), + WanVideoUnit_CfgMerger(), + ] + + + def enable_vram_management(self, num_persistent_param_in_dit=None): + self.vram_management_enabled = True + if self.text_encoder is not None: + dtype = next(iter(self.text_encoder.parameters())).dtype + enable_vram_management( + self.text_encoder, + module_map = { + torch.nn.Linear: AutoWrappedLinear, + torch.nn.Embedding: AutoWrappedModule, + T5RelativeEmbedding: AutoWrappedModule, + T5LayerNorm: AutoWrappedModule, + }, + module_config = dict( + offload_dtype=dtype, + offload_device="cpu", + onload_dtype=dtype, + onload_device="cpu", + computation_dtype=self.torch_dtype, + computation_device=self.device, + ), + ) + if self.dit is not None: + dtype = next(iter(self.dit.parameters())).dtype + enable_vram_management( + self.dit, + module_map = { + torch.nn.Linear: AutoWrappedLinear, + torch.nn.Conv3d: AutoWrappedModule, + torch.nn.LayerNorm: AutoWrappedModule, + RMSNorm: AutoWrappedModule, + torch.nn.Conv2d: AutoWrappedModule, + }, + module_config = dict( + offload_dtype=dtype, + offload_device="cpu", + onload_dtype=dtype, + onload_device=self.device, + computation_dtype=self.torch_dtype, + computation_device=self.device, + ), + max_num_param=num_persistent_param_in_dit, + overflow_module_config = dict( + offload_dtype=dtype, + offload_device="cpu", + onload_dtype=dtype, + onload_device="cpu", + computation_dtype=self.torch_dtype, + computation_device=self.device, + ), + ) + if self.vae is not None: + dtype = next(iter(self.vae.parameters())).dtype + enable_vram_management( + self.vae, + module_map = { + torch.nn.Linear: AutoWrappedLinear, + torch.nn.Conv2d: AutoWrappedModule, + RMS_norm: AutoWrappedModule, + CausalConv3d: AutoWrappedModule, + Upsample: AutoWrappedModule, + torch.nn.SiLU: AutoWrappedModule, + torch.nn.Dropout: AutoWrappedModule, + }, + module_config = dict( + offload_dtype=dtype, + offload_device="cpu", + onload_dtype=dtype, + onload_device=self.device, + computation_dtype=self.torch_dtype, + computation_device=self.device, + ), + ) + if self.image_encoder is not None: + dtype = next(iter(self.image_encoder.parameters())).dtype + enable_vram_management( + self.image_encoder, + module_map = { + torch.nn.Linear: AutoWrappedLinear, + torch.nn.Conv2d: AutoWrappedModule, + torch.nn.LayerNorm: AutoWrappedModule, + }, + module_config = dict( + offload_dtype=dtype, + offload_device="cpu", + onload_dtype=dtype, + onload_device="cpu", + computation_dtype=dtype, + computation_device=self.device, + ), + ) + if self.motion_controller is not None: + dtype = next(iter(self.motion_controller.parameters())).dtype + enable_vram_management( + self.motion_controller, + module_map = { + torch.nn.Linear: AutoWrappedLinear, + }, + module_config = dict( + offload_dtype=dtype, + offload_device="cpu", + onload_dtype=dtype, + onload_device="cpu", + computation_dtype=dtype, + computation_device=self.device, + ), + ) + if self.vace is not None: + enable_vram_management( + self.vace, + module_map = { + torch.nn.Linear: AutoWrappedLinear, + torch.nn.Conv3d: AutoWrappedModule, + torch.nn.LayerNorm: AutoWrappedModule, + RMSNorm: AutoWrappedModule, + }, + module_config = dict( + offload_dtype=dtype, + offload_device="cpu", + onload_dtype=dtype, + onload_device=self.device, + computation_dtype=self.torch_dtype, + computation_device=self.device, + ), + ) + + + @staticmethod + def from_pretrained( + torch_dtype: torch.dtype = torch.bfloat16, + device: Union[str, torch.device] = "cuda", + model_configs: list[ModelConfig] = [], + tokenizer_config: ModelConfig = ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="google/*"), + local_model_path: str = "./models", + skip_download: bool = False + ): + # Download and load models + model_manager = ModelManager() + for model_config in model_configs: + model_config.download_if_necessary(local_model_path, skip_download=skip_download) + model_manager.load_model( + model_config.path, + device=model_config.offload_device or device, + torch_dtype=model_config.quantization_dtype or torch_dtype + ) + + # Initialize pipeline + pipe = WanVideoPipeline(device=device, torch_dtype=torch_dtype) + pipe.text_encoder = model_manager.fetch_model("wan_video_text_encoder") + pipe.dit = model_manager.fetch_model("wan_video_dit") + pipe.vae = model_manager.fetch_model("wan_video_vae") + pipe.image_encoder = model_manager.fetch_model("wan_video_image_encoder") + pipe.motion_controller = model_manager.fetch_model("wan_video_motion_controller") + pipe.vace = model_manager.fetch_model("wan_video_vace") + + # Initialize tokenizer + tokenizer_config.download_if_necessary(local_model_path, skip_download=skip_download) + pipe.prompter.fetch_models(pipe.text_encoder) + pipe.prompter.fetch_tokenizer(tokenizer_config.path) + return pipe + + + def denoising_model(self): + return self.dit + + + def encode_video(self, input_video, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)): + latents = self.vae.encode(input_video, device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride) + return latents + + + @torch.no_grad() + def __call__( + self, + # Prompt + prompt, + negative_prompt="", + # Image-to-video + input_image=None, + # First-last-frame-to-video + end_image=None, + # Video-to-video + input_video=None, + denoising_strength=1.0, + # ControlNet + control_video=None, + reference_image=None, + # VACE + vace_video=None, + vace_video_mask=None, + vace_reference_image=None, + vace_scale=1.0, + # Randomness + seed=None, + rand_device="cpu", + # Shape + height=480, + width=832, + num_frames=81, + # Classifier-free guidance + cfg_scale=5.0, + cfg_merge=False, + # Scheduler + num_inference_steps=50, + sigma_shift=5.0, + # Speed control + motion_bucket_id=None, + # VAE tiling + tiled=True, + tile_size=(30, 52), + tile_stride=(15, 26), + # Sliding window + sliding_window_size: Optional[int] = None, + sliding_window_stride: Optional[int] = None, + # Teacache + tea_cache_l1_thresh=None, + tea_cache_model_id="", + # progress_bar + progress_bar_cmd=tqdm, + ): + # Scheduler + self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength, shift=sigma_shift) + + # Inputs + inputs_posi = { + "prompt": prompt, + "tea_cache_l1_thresh": tea_cache_l1_thresh, "tea_cache_model_id": tea_cache_model_id, + } + inputs_nega = { + "negative_prompt": negative_prompt, + "tea_cache_l1_thresh": tea_cache_l1_thresh, "tea_cache_model_id": tea_cache_model_id, + } + inputs_shared = { + "input_image": input_image, + "end_image": end_image, + "input_video": input_video, "denoising_strength": denoising_strength, + "control_video": control_video, "reference_image": reference_image, + "vace_video": vace_video, "vace_video_mask": vace_video_mask, "vace_reference_image": vace_reference_image, "vace_scale": vace_scale, + "seed": seed, "rand_device": rand_device, + "height": height, "width": width, "num_frames": num_frames, + "cfg_scale": cfg_scale, "cfg_merge": cfg_merge, + "num_inference_steps": num_inference_steps, "sigma_shift": sigma_shift, + "motion_bucket_id": motion_bucket_id, + "tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride, + "sliding_window_size": sliding_window_size, "sliding_window_stride": sliding_window_stride, + } + for unit in self.units: + inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega) + + # Denoise + self.load_models_to_device(self.in_iteration_models) + models = {name: getattr(self, name) for name in self.in_iteration_models} + for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)): + timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device) + + # Inference + noise_pred_posi = model_fn_wan_video(**models, **inputs_shared, **inputs_posi, timestep=timestep) + if cfg_scale != 1.0: + if cfg_merge: + noise_pred_posi, noise_pred_nega = noise_pred_posi.chunk(2, dim=0) + else: + noise_pred_nega = model_fn_wan_video(**models, **inputs_shared, **inputs_nega, timestep=timestep) + noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega) + else: + noise_pred = noise_pred_posi + + # Scheduler + inputs_shared["latents"] = self.scheduler.step(noise_pred, self.scheduler.timesteps[progress_id], inputs_shared["latents"]) + + # VACE (TODO: remove it) + if vace_reference_image is not None: + latents = latents[:, :, 1:] + + # Decode + self.load_models_to_device(['vae']) + video = self.vae.decode(inputs_shared["latents"], device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride) + video = self.vae_output_to_video(video) + self.load_models_to_device([]) + + return video + + + +class PipelineUnit: + def __init__( + self, + seperate_cfg: bool = False, + take_over: bool = False, + input_params: tuple[str] = None, + input_params_posi: dict[str, str] = None, + input_params_nega: dict[str, str] = None, + onload_model_names: tuple[str] = None + ): + self.seperate_cfg = seperate_cfg + self.take_over = take_over + self.input_params = input_params + self.input_params_posi = input_params_posi + self.input_params_nega = input_params_nega + self.onload_model_names = onload_model_names + + + def process(self, pipe: WanVideoPipeline, inputs: dict, positive=True, **kwargs) -> dict: + raise NotImplementedError("`process` is not implemented.") + + + +class PipelineUnitRunner: + def __init__(self): + pass + + def __call__(self, unit: PipelineUnit, pipe: WanVideoPipeline, inputs_shared: dict, inputs_posi: dict, inputs_nega: dict) -> tuple[dict, dict]: + if unit.take_over: + # Let the pipeline unit take over this function. + inputs_shared, inputs_posi, inputs_nega = unit.process(pipe, inputs_shared=inputs_shared, inputs_posi=inputs_posi, inputs_nega=inputs_nega) + elif unit.seperate_cfg: + # Positive side + processor_inputs = {name: inputs_posi.get(name_) for name, name_ in unit.input_params_posi.items()} + processor_outputs = unit.process(pipe, **processor_inputs) + inputs_posi.update(processor_outputs) + # Negative side + if inputs_shared["cfg_scale"] != 1: + processor_inputs = {name: inputs_nega.get(name_) for name, name_ in unit.input_params_nega.items()} + processor_outputs = unit.process(pipe, **processor_inputs) + inputs_nega.update(processor_outputs) + else: + inputs_nega.update(processor_outputs) + else: + processor_inputs = {name: inputs_shared.get(name) for name in unit.input_params} + processor_outputs = unit.process(pipe, **processor_inputs) + inputs_shared.update(processor_outputs) + return inputs_shared, inputs_posi, inputs_nega + + + +class WanVideoUnit_ShapeChecker(PipelineUnit): + def __init__(self): + super().__init__(input_params=("height", "width", "num_frames")) + + def process(self, pipe: WanVideoPipeline, height, width, num_frames): + height, width, num_frames = pipe.check_resize_height_width(height, width, num_frames) + return {"height": height, "width": width, "num_frames": num_frames} + + + +class WanVideoUnit_NoiseInitializer(PipelineUnit): + def __init__(self): + super().__init__(input_params=("height", "width", "num_frames", "seed", "rand_device", "vace_reference_image")) + + def process(self, pipe: WanVideoPipeline, height, width, num_frames, seed, rand_device, vace_reference_image): + length = (num_frames - 1) // 4 + 1 + if vace_reference_image is not None: + length += 1 + noise = pipe.generate_noise((1, 16, length, height//8, width//8), seed=seed, rand_device=rand_device) + if vace_reference_image is not None: + noise = torch.concat((noise[:, :, -1:], noise[:, :, :-1]), dim=2) + return {"noise": noise} + + + +class WanVideoUnit_InputVideoEmbedder(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("input_video", "noise", "tiled", "tile_size", "tile_stride"), + onload_model_names=("vae",) + ) + + def process(self, pipe: WanVideoPipeline, input_video, noise, tiled, tile_size, tile_stride): + if input_video is None: + return {"latents": noise} + pipe.load_models_to_device(["vae"]) + input_video = pipe.preprocess_video(input_video) + latents = pipe.encode_video(input_video, tiled, tile_size, tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device) + latents = pipe.scheduler.add_noise(latents, noise, timestep=pipe.scheduler.timesteps[0]) + return {"latents": latents} + + + +class WanVideoUnit_PromptEmbedder(PipelineUnit): + def __init__(self): + super().__init__( + seperate_cfg=True, + input_params_posi={"prompt": "prompt", "positive": "positive"}, + input_params_nega={"prompt": "negative_prompt", "positive": "positive"}, + onload_model_names=("text_encoder",) + ) + + def process(self, pipe: WanVideoPipeline, prompt, positive) -> dict: + pipe.load_models_to_device(self.onload_model_names) + prompt_emb = pipe.prompter.encode_prompt(prompt, positive=positive, device=pipe.device) + return {"context": prompt_emb} + + + +class WanVideoUnit_ImageEmbedder(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("input_image", "end_image", "num_frames", "height", "width", "tiled", "tile_size", "tile_stride"), + onload_model_names=("image_encoder", "vae") + ) + + def process(self, pipe: WanVideoPipeline, input_image, end_image, num_frames, height, width, tiled, tile_size, tile_stride): + if input_image is None: + return {} + pipe.load_models_to_device(self.onload_model_names) + image = pipe.preprocess_image(input_image.resize((width, height))).to(pipe.device) + clip_context = pipe.image_encoder.encode_image([image]) + msk = torch.ones(1, num_frames, height//8, width//8, device=pipe.device) + msk[:, 1:] = 0 + if end_image is not None: + end_image = pipe.preprocess_image(end_image.resize((width, height))).to(pipe.device) + vae_input = torch.concat([image.transpose(0,1), torch.zeros(3, num_frames-2, height, width).to(image.device), end_image.transpose(0,1)],dim=1) + if pipe.dit.has_image_pos_emb: + clip_context = torch.concat([clip_context, pipe.image_encoder.encode_image([end_image])], dim=1) + msk[:, -1:] = 1 + else: + vae_input = torch.concat([image.transpose(0, 1), torch.zeros(3, num_frames-1, height, width).to(image.device)], dim=1) + + msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1) + msk = msk.view(1, msk.shape[1] // 4, 4, height//8, width//8) + msk = msk.transpose(1, 2)[0] + + y = pipe.vae.encode([vae_input.to(dtype=pipe.torch_dtype, device=pipe.device)], device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)[0] + y = y.to(dtype=pipe.torch_dtype, device=pipe.device) + y = torch.concat([msk, y]) + y = y.unsqueeze(0) + clip_context = clip_context.to(dtype=pipe.torch_dtype, device=pipe.device) + y = y.to(dtype=pipe.torch_dtype, device=pipe.device) + return {"clip_feature": clip_context, "y": y} + + + +class WanVideoUnit_FunControl(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("control_video", "num_frames", "height", "width", "tiled", "tile_size", "tile_stride", "clip_feature", "y"), + onload_model_names=("vae") + ) + + def process(self, pipe: WanVideoPipeline, control_video, num_frames, height, width, tiled, tile_size, tile_stride, clip_feature, y): + if control_video is None: + return {} + pipe.load_models_to_device(self.onload_model_names) + control_video = pipe.preprocess_video(control_video) + control_latents = pipe.encode_video(control_video, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride) + control_latents = control_latents.to(dtype=pipe.torch_dtype, device=pipe.device) + if clip_feature is None or y is None: + clip_feature = torch.zeros((1, 257, 1280), dtype=pipe.torch_dtype, device=pipe.device) + y = torch.zeros((1, 16, (num_frames - 1) // 4 + 1, height//8, width//8), dtype=pipe.torch_dtype, device=pipe.device) + else: + y = y[:, -16:] + y = torch.concat([control_latents, y], dim=1) + return {"clip_feature": clip_feature, "y": y} + + + +class WanVideoUnit_FunReference(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("reference_image", "height", "width"), + onload_model_names=("vae") + ) + + def process(self, pipe: WanVideoPipeline, reference_image, height, width): + if reference_image is None: + return {} + pipe.load_models_to_device(["vae"]) + reference_image = reference_image.resize((width, height)) + reference_image = pipe.preprocess_video([reference_image]) + reference_latents = pipe.vae.encode(reference_image, device=pipe.device) + return {"reference_latents": reference_latents} + + + +class WanVideoUnit_SpeedControl(PipelineUnit): + def __init__(self): + super().__init__(input_params=("motion_bucket_id",)) + + def process(self, pipe: WanVideoPipeline, motion_bucket_id): + if motion_bucket_id is None: + return {} + motion_bucket_id = torch.Tensor((motion_bucket_id,)).to(dtype=self.torch_dtype, device=self.device) + return {"motion_bucket_id": motion_bucket_id} + + + +class WanVideoUnit_VACE(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("vace_video", "vace_mask", "vace_reference_image", "vace_scale", "height", "width", "num_frames", "tiled", "tile_size", "tile_stride"), + onload_model_names=("vae",) + ) + + def process( + self, + pipe: WanVideoPipeline, + vace_video, vace_mask, vace_reference_image, vace_scale, + height, width, num_frames, + tiled, tile_size, tile_stride + ): + if vace_video is not None or vace_mask is not None or vace_reference_image is not None: + pipe.load_models_to_device(["vae"]) + if vace_video is None: + vace_video = torch.zeros((1, 3, num_frames, height, width), dtype=pipe.torch_dtype, device=pipe.device) + else: + vace_video = pipe.preprocess_video(vace_video) + vace_video = torch.stack(vace_video, dim=2).to(dtype=pipe.torch_dtype, device=pipe.device) + + if vace_mask is None: + vace_mask = torch.ones_like(vace_video) + else: + vace_mask = pipe.preprocess_video(vace_mask) + vace_mask = torch.stack(vace_mask, dim=2).to(dtype=pipe.torch_dtype, device=pipe.device) + + inactive = vace_video * (1 - vace_mask) + 0 * vace_mask + reactive = vace_video * vace_mask + 0 * (1 - vace_mask) + inactive = pipe.encode_video(inactive, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device) + reactive = pipe.encode_video(reactive, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device) + vace_video_latents = torch.concat((inactive, reactive), dim=1) + + vace_mask_latents = rearrange(vace_mask[0,0], "T (H P) (W Q) -> 1 (P Q) T H W", P=8, Q=8) + vace_mask_latents = torch.nn.functional.interpolate(vace_mask_latents, size=((vace_mask_latents.shape[2] + 3) // 4, vace_mask_latents.shape[3], vace_mask_latents.shape[4]), mode='nearest-exact') + + if vace_reference_image is None: + pass + else: + vace_reference_image = pipe.preprocess_video([vace_reference_image]) + vace_reference_image = torch.stack(vace_reference_image, dim=2).to(dtype=pipe.torch_dtype, device=pipe.device) + vace_reference_latents = pipe.encode_video(vace_reference_image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device) + vace_reference_latents = torch.concat((vace_reference_latents, torch.zeros_like(vace_reference_latents)), dim=1) + vace_video_latents = torch.concat((vace_reference_latents, vace_video_latents), dim=2) + vace_mask_latents = torch.concat((torch.zeros_like(vace_mask_latents[:, :, :1]), vace_mask_latents), dim=2) + + vace_context = torch.concat((vace_video_latents, vace_mask_latents), dim=1) + return {"vace_context": vace_context, "vace_scale": vace_scale} + else: + return {"vace_context": None, "vace_scale": vace_scale} + + + +class WanVideoUnit_TeaCache(PipelineUnit): + def __init__(self): + super().__init__( + seperate_cfg=True, + input_params_posi={"num_inference_steps": "num_inference_steps", "tea_cache_l1_thresh": "tea_cache_l1_thresh", "tea_cache_model_id": "tea_cache_model_id"}, + input_params_nega={"num_inference_steps": "num_inference_steps", "tea_cache_l1_thresh": "tea_cache_l1_thresh", "tea_cache_model_id": "tea_cache_model_id"}, + ) + + def process(self, pipe: WanVideoPipeline, num_inference_steps, tea_cache_l1_thresh, tea_cache_model_id): + if tea_cache_l1_thresh is None: + return {} + return {"tea_cache": TeaCache(num_inference_steps, rel_l1_thresh=tea_cache_l1_thresh, model_id=tea_cache_model_id)} + + + +class WanVideoUnit_CfgMerger(PipelineUnit): + def __init__(self): + super().__init__(take_over=True) + + def process(self, pipe: WanVideoPipeline, inputs_shared, inputs_posi, inputs_nega): + if not inputs_shared["cfg_merge"]: + return inputs_shared, inputs_posi, inputs_nega + inputs_shared["context"] = torch.concat((inputs_posi["context"], inputs_nega["context"]), dim=0) + inputs_posi.clear() + inputs_nega.clear() + return inputs_shared, inputs_posi, inputs_nega + + + +class TeaCache: + def __init__(self, num_inference_steps, rel_l1_thresh, model_id): + self.num_inference_steps = num_inference_steps + self.step = 0 + self.accumulated_rel_l1_distance = 0 + self.previous_modulated_input = None + self.rel_l1_thresh = rel_l1_thresh + self.previous_residual = None + self.previous_hidden_states = None + + self.coefficients_dict = { + "Wan2.1-T2V-1.3B": [-5.21862437e+04, 9.23041404e+03, -5.28275948e+02, 1.36987616e+01, -4.99875664e-02], + "Wan2.1-T2V-14B": [-3.03318725e+05, 4.90537029e+04, -2.65530556e+03, 5.87365115e+01, -3.15583525e-01], + "Wan2.1-I2V-14B-480P": [2.57151496e+05, -3.54229917e+04, 1.40286849e+03, -1.35890334e+01, 1.32517977e-01], + "Wan2.1-I2V-14B-720P": [ 8.10705460e+03, 2.13393892e+03, -3.72934672e+02, 1.66203073e+01, -4.17769401e-02], + } + if model_id not in self.coefficients_dict: + supported_model_ids = ", ".join([i for i in self.coefficients_dict]) + raise ValueError(f"{model_id} is not a supported TeaCache model id. Please choose a valid model id in ({supported_model_ids}).") + self.coefficients = self.coefficients_dict[model_id] + + def check(self, dit: WanModel, x, t_mod): + modulated_inp = t_mod.clone() + if self.step == 0 or self.step == self.num_inference_steps - 1: + should_calc = True + self.accumulated_rel_l1_distance = 0 + else: + coefficients = self.coefficients + rescale_func = np.poly1d(coefficients) + self.accumulated_rel_l1_distance += rescale_func(((modulated_inp-self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean()).cpu().item()) + if self.accumulated_rel_l1_distance < self.rel_l1_thresh: + should_calc = False + else: + should_calc = True + self.accumulated_rel_l1_distance = 0 + self.previous_modulated_input = modulated_inp + self.step += 1 + if self.step == self.num_inference_steps: + self.step = 0 + if should_calc: + self.previous_hidden_states = x.clone() + return not should_calc + + def store(self, hidden_states): + self.previous_residual = hidden_states - self.previous_hidden_states + self.previous_hidden_states = None + + def update(self, hidden_states): + hidden_states = hidden_states + self.previous_residual + return hidden_states + + + +class TemporalTiler_BCTHW: + def __init__(self): + pass + + def build_1d_mask(self, length, left_bound, right_bound, border_width): + x = torch.ones((length,)) + if not left_bound: + x[:border_width] = (torch.arange(border_width) + 1) / border_width + if not right_bound: + x[-border_width:] = torch.flip((torch.arange(border_width) + 1) / border_width, dims=(0,)) + return x + + def build_mask(self, data, is_bound, border_width): + _, _, T, _, _ = data.shape + t = self.build_1d_mask(T, is_bound[0], is_bound[1], border_width[0]) + mask = repeat(t, "T -> 1 1 T 1 1") + return mask + + def run(self, model_fn, sliding_window_size, sliding_window_stride, computation_device, computation_dtype, model_kwargs, tensor_names): + tensor_names = [tensor_name for tensor_name in tensor_names if model_kwargs.get(tensor_name) is not None] + tensor_dict = {tensor_name: model_kwargs[tensor_name] for tensor_name in tensor_names} + B, C, T, H, W = tensor_dict[tensor_names[0]].shape + data_device, data_dtype = tensor_dict[tensor_names[0]].device, tensor_dict[tensor_names[0]].dtype + value = torch.zeros((B, C, T, H, W), device=data_device, dtype=data_dtype) + weight = torch.zeros((1, 1, T, 1, 1), device=data_device, dtype=data_dtype) + for t in range(0, T, sliding_window_stride): + if t - sliding_window_stride >= 0 and t - sliding_window_stride + sliding_window_size >= T: + continue + t_ = min(t + sliding_window_size, T) + model_kwargs.update({ + tensor_name: tensor_dict[tensor_name][:, :, t: t_:, :].to(device=computation_device, dtype=computation_dtype) \ + for tensor_name in tensor_names + }) + model_output = model_fn(**model_kwargs).to(device=data_device, dtype=data_dtype) + mask = self.build_mask( + model_output, + is_bound=(t == 0, t_ == T), + border_width=(sliding_window_size - sliding_window_stride,) + ).to(device=data_device, dtype=data_dtype) + value[:, :, t: t_, :, :] += model_output * mask + weight[:, :, t: t_, :, :] += mask + value /= weight + model_kwargs.update(tensor_dict) + return value + + + +def model_fn_wan_video( + dit: WanModel, + motion_controller: WanMotionControllerModel = None, + vace: VaceWanModel = None, + latents: torch.Tensor = None, + timestep: torch.Tensor = None, + context: torch.Tensor = None, + clip_feature: Optional[torch.Tensor] = None, + y: Optional[torch.Tensor] = None, + reference_latents = None, + vace_context = None, + vace_scale = 1.0, + tea_cache: TeaCache = None, + use_unified_sequence_parallel: bool = False, + motion_bucket_id: Optional[torch.Tensor] = None, + sliding_window_size: Optional[int] = None, + sliding_window_stride: Optional[int] = None, + **kwargs, +): + if sliding_window_size is not None and sliding_window_stride is not None: + model_kwargs = dict( + dit=dit, + motion_controller=motion_controller, + vace=vace, + latents=latents, + timestep=timestep, + context=context, + clip_feature=clip_feature, + y=y, + reference_latents=reference_latents, + vace_context=vace_context, + vace_scale=vace_scale, + tea_cache=tea_cache, + use_unified_sequence_parallel=use_unified_sequence_parallel, + motion_bucket_id=motion_bucket_id, + ) + return TemporalTiler_BCTHW().run( + model_fn_wan_video, + sliding_window_size, sliding_window_stride, + latents.device, latents.dtype, + model_kwargs=model_kwargs, + tensor_names=["latents", "y"] + ) + + if use_unified_sequence_parallel: + import torch.distributed as dist + from xfuser.core.distributed import (get_sequence_parallel_rank, + get_sequence_parallel_world_size, + get_sp_group) + + t = dit.time_embedding(sinusoidal_embedding_1d(dit.freq_dim, timestep)) + t_mod = dit.time_projection(t).unflatten(1, (6, dit.dim)) + if motion_bucket_id is not None and motion_controller is not None: + t_mod = t_mod + motion_controller(motion_bucket_id).unflatten(1, (6, dit.dim)) + context = dit.text_embedding(context) + + x = latents + # Merged cfg + if x.shape[0] != context.shape[0]: + x = torch.concat([x] * context.shape[0], dim=0) + if timestep.shape[0] != context.shape[0]: + timestep = torch.concat([timestep] * context.shape[0], dim=0) + + if dit.has_image_input: + x = torch.cat([x, y], dim=1) # (b, c_x + c_y, f, h, w) + clip_embdding = dit.img_emb(clip_feature) + context = torch.cat([clip_embdding, context], dim=1) + + x, (f, h, w) = dit.patchify(x) + + # Reference image + if reference_latents is not None: + reference_latents = dit.ref_conv(reference_latents[:, :, 0]).flatten(2).transpose(1, 2) + x = torch.concat([reference_latents, x], dim=1) + f += 1 + + freqs = torch.cat([ + dit.freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1), + dit.freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1), + dit.freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1) + ], dim=-1).reshape(f * h * w, 1, -1).to(x.device) + + # TeaCache + if tea_cache is not None: + tea_cache_update = tea_cache.check(dit, x, t_mod) + else: + tea_cache_update = False + + if vace_context is not None: + vace_hints = vace(x, vace_context, context, t_mod, freqs) + + # blocks + if use_unified_sequence_parallel: + if dist.is_initialized() and dist.get_world_size() > 1: + x = torch.chunk(x, get_sequence_parallel_world_size(), dim=1)[get_sequence_parallel_rank()] + if tea_cache_update: + x = tea_cache.update(x) + else: + for block_id, block in enumerate(dit.blocks): + x = block(x, context, t_mod, freqs) + if vace_context is not None and block_id in vace.vace_layers_mapping: + x = x + vace_hints[vace.vace_layers_mapping[block_id]] * vace_scale + if tea_cache is not None: + tea_cache.store(x) + + if reference_latents is not None: + x = x[:, reference_latents.shape[1]:] + f -= 1 + + x = dit.head(x, t) + if use_unified_sequence_parallel: + if dist.is_initialized() and dist.get_world_size() > 1: + x = get_sp_group().all_gather(x, dim=1) + x = dit.unpatchify(x, (f, h, w)) + return x diff --git a/test.py b/test.py new file mode 100644 index 0000000..2c6cfea --- /dev/null +++ b/test.py @@ -0,0 +1,28 @@ +import torch +from diffsynth import ModelManager, save_video, VideoData, save_frames, save_video, download_models +from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig +from diffsynth.controlnets.processors import Annotator +from modelscope import snapshot_download +from tqdm import tqdm + + +# Load models +pipe = WanVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"), + ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"), + ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"), + # ModelConfig(model_id="Wan-AI/Wan2.1-I2V-14B-480P", origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth", offload_device="cpu"), + ], +) +pipe.enable_vram_management(num_persistent_param_in_dit=0) + +# Text-to-video +video = pipe( + prompt="纪实摄影风格画面,一只活泼的小狗在绿茵茵的草地上迅速奔跑。小狗毛色棕黄,两只耳朵立起,神情专注而欢快。阳光洒在它身上,使得毛发看上去格外柔软而闪亮。背景是一片开阔的草地,偶尔点缀着几朵野花,远处隐约可见蓝天和几片白云。透视感鲜明,捕捉小狗奔跑时的动感和四周草地的生机。中景侧面移动视角。", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + seed=0, tiled=True, +) +save_video(video, "video1.mp4", fps=15, quality=5)