From 9015d08927331a7b5b559ed17412558279690c33 Mon Sep 17 00:00:00 2001 From: mi804 <1576993271@qq.com> Date: Fri, 25 Jul 2025 17:09:53 +0800 Subject: [PATCH] support wan2.2 A14B I2V&T2V --- diffsynth/configs/model_config.py | 2 + diffsynth/models/wan_video_dit.py | 19 ++++ diffsynth/pipelines/wan_video_new.py | 96 ++++++++++++++++++- .../model_inference/Wan2.2-I2V-A14B.py | 32 +++++++ .../model_inference/Wan2.2-T2V-A14B.py | 27 ++++++ .../model_inference/Wan2.2-TI2V-5B.py | 8 +- 6 files changed, 175 insertions(+), 9 deletions(-) create mode 100644 examples/wanvideo/model_inference/Wan2.2-I2V-A14B.py create mode 100644 examples/wanvideo/model_inference/Wan2.2-T2V-A14B.py diff --git a/diffsynth/configs/model_config.py b/diffsynth/configs/model_config.py index 0903f79..6448fbf 100644 --- a/diffsynth/configs/model_config.py +++ b/diffsynth/configs/model_config.py @@ -141,6 +141,8 @@ model_loader_configs = [ (None, "ac6a5aa74f4a0aab6f64eb9a72f19901", ["wan_video_dit"], [WanModel], "civitai"), (None, "b61c605c2adbd23124d152ed28e049ae", ["wan_video_dit"], [WanModel], "civitai"), (None, "1f5ab7703c6fc803fdded85ff040c316", ["wan_video_dit"], [WanModel], "civitai"), + (None, "aafcfd9672c3a2456dc46e1cb6e52c70", ["wan_video_dit"], [WanModel], "civitai"), + (None, "5b013604280dd715f8457c6ed6d6a626", ["wan_video_dit"], [WanModel], "civitai"), (None, "a61453409b67cd3246cf0c3bebad47ba", ["wan_video_dit", "wan_video_vace"], [WanModel, VaceWanModel], "civitai"), (None, "7a513e1f257a861512b1afd387a8ecd9", ["wan_video_dit", "wan_video_vace"], [WanModel, VaceWanModel], "civitai"), (None, "cb104773c6c2cb6df4f9529ad5c60d0b", ["wan_video_dit"], [WanModel], "diffusers"), diff --git a/diffsynth/models/wan_video_dit.py b/diffsynth/models/wan_video_dit.py index 1106669..3262057 100644 --- a/diffsynth/models/wan_video_dit.py +++ b/diffsynth/models/wan_video_dit.py @@ -352,6 +352,7 @@ class WanModel(torch.nn.Module): context: torch.Tensor, clip_feature: Optional[torch.Tensor] = None, y: Optional[torch.Tensor] = None, + fused_y: Optional[torch.Tensor] = None, use_gradient_checkpointing: bool = False, use_gradient_checkpointing_offload: bool = False, **kwargs, @@ -365,6 +366,8 @@ class WanModel(torch.nn.Module): x = torch.cat([x, y], dim=1) # (b, c_x + c_y, f, h, w) clip_embdding = self.img_emb(clip_feature) context = torch.cat([clip_embdding, context], dim=1) + if fused_y is not None: + x = torch.cat([x, fused_y], dim=1) # (b, c_x + c_y + c_fused_y, f, h, w) x, (f, h, w) = self.patchify(x) @@ -673,6 +676,7 @@ class WanModelStateDictConverter: "in_dim_control_adapter": 24, } elif hash_state_dict_keys(state_dict) == "1f5ab7703c6fc803fdded85ff040c316": + # Wan-AI/Wan2.2-TI2V-5B config = { "has_image_input": False, "patch_size": [1, 2, 2], @@ -687,6 +691,21 @@ class WanModelStateDictConverter: "eps": 1e-6, "seperated_timestep": True, } + elif hash_state_dict_keys(state_dict) == "5b013604280dd715f8457c6ed6d6a626": + # Wan-AI/Wan2.2-I2V-A14B + config = { + "has_image_input": False, + "patch_size": [1, 2, 2], + "in_dim": 36, + "dim": 5120, + "ffn_dim": 13824, + "freq_dim": 256, + "text_dim": 4096, + "out_dim": 16, + "num_heads": 40, + "num_layers": 40, + "eps": 1e-6, + } else: config = {} return state_dict, config diff --git a/diffsynth/pipelines/wan_video_new.py b/diffsynth/pipelines/wan_video_new.py index 7ea0e3c..d108ad4 100644 --- a/diffsynth/pipelines/wan_video_new.py +++ b/diffsynth/pipelines/wan_video_new.py @@ -226,10 +226,11 @@ class WanVideoPipeline(BasePipeline): self.text_encoder: WanTextEncoder = None self.image_encoder: WanImageEncoder = None self.dit: WanModel = None + self.dit2: WanModel = None self.vae: WanVideoVAE = None self.motion_controller: WanMotionControllerModel = None self.vace: VaceWanModel = None - self.in_iteration_models = ("dit", "motion_controller", "vace") + self.in_iteration_models = ("dit", "dit2", "motion_controller", "vace") self.unit_runner = PipelineUnitRunner() self.units = [ WanVideoUnit_ShapeChecker(), @@ -238,6 +239,7 @@ class WanVideoPipeline(BasePipeline): WanVideoUnit_PromptEmbedder(), WanVideoUnit_ImageEmbedder(), WanVideoUnit_ImageVaeEmbedder(), + WanVideoUnit_ImageEmbedderNoClip(), WanVideoUnit_FunControl(), WanVideoUnit_FunReference(), WanVideoUnit_FunCameraControl(), @@ -329,6 +331,37 @@ class WanVideoPipeline(BasePipeline): ), vram_limit=vram_limit, ) + if self.dit2 is not None: + dtype = next(iter(self.dit2.parameters())).dtype + device = "cpu" if vram_limit is not None else self.device + enable_vram_management( + self.dit2, + module_map = { + torch.nn.Linear: AutoWrappedLinear, + torch.nn.Conv3d: AutoWrappedModule, + torch.nn.LayerNorm: WanAutoCastLayerNorm, + RMSNorm: AutoWrappedModule, + torch.nn.Conv2d: AutoWrappedModule, + }, + module_config = dict( + offload_dtype=dtype, + offload_device="cpu", + onload_dtype=dtype, + onload_device=device, + computation_dtype=self.torch_dtype, + computation_device=self.device, + ), + max_num_param=num_persistent_param_in_dit, + overflow_module_config = dict( + offload_dtype=dtype, + offload_device="cpu", + onload_dtype=dtype, + onload_device="cpu", + computation_dtype=self.torch_dtype, + computation_device=self.device, + ), + vram_limit=vram_limit, + ) if self.vae is not None: dtype = next(iter(self.vae.parameters())).dtype enable_vram_management( @@ -427,6 +460,10 @@ class WanVideoPipeline(BasePipeline): for block in self.dit.blocks: block.self_attn.forward = types.MethodType(usp_attn_forward, block.self_attn) self.dit.forward = types.MethodType(usp_dit_forward, self.dit) + if self.dit2 is not None: + for block in self.dit2.blocks: + block.self_attn.forward = types.MethodType(usp_attn_forward, block.self_attn) + self.dit2.forward = types.MethodType(usp_dit_forward, self.dit2) self.sp_size = get_sequence_parallel_world_size() self.use_unified_sequence_parallel = True @@ -473,6 +510,9 @@ class WanVideoPipeline(BasePipeline): # Load models pipe.text_encoder = model_manager.fetch_model("wan_video_text_encoder") pipe.dit = model_manager.fetch_model("wan_video_dit") + num_dits = len([model_name for model_name in model_manager.model_name if model_name == "wan_video_dit"]) + if num_dits == 2: + pipe.dit2 = [model for model, model_name in zip(model_manager.model, model_manager.model_name) if model_name == "wan_video_dit"][-1] pipe.vae = model_manager.fetch_model("wan_video_vae") pipe.image_encoder = model_manager.fetch_model("wan_video_image_encoder") pipe.motion_controller = model_manager.fetch_model("wan_video_motion_controller") @@ -523,6 +563,8 @@ class WanVideoPipeline(BasePipeline): # Classifier-free guidance cfg_scale: Optional[float] = 5.0, cfg_merge: Optional[bool] = False, + # Boundary + boundary: Optional[float] = 0.875, # Scheduler num_inference_steps: Optional[int] = 50, sigma_shift: Optional[float] = 5.0, @@ -575,8 +617,12 @@ class WanVideoPipeline(BasePipeline): self.load_models_to_device(self.in_iteration_models) models = {name: getattr(self, name) for name in self.in_iteration_models} for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)): + # switch high_noise DiT to low_noise DiT + if models.get("dit2") is not None and timestep.item() < boundary * self.scheduler.num_train_timesteps: + print("switching to low noise DiT") + self.load_models_to_device(["dit2", "motion_controller", "vace"]) + models["dit"] = models.pop("dit2") timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device) - # Inference noise_pred_posi = self.model_fn(**models, **inputs_shared, **inputs_posi, timestep=timestep) if cfg_scale != 1.0: @@ -737,7 +783,7 @@ class WanVideoUnit_ImageEmbedder(PipelineUnit): ) def process(self, pipe: WanVideoPipeline, input_image, end_image, num_frames, height, width, tiled, tile_size, tile_stride): - if input_image is None or pipe.dit.seperated_timestep: + if input_image is None or pipe.image_encoder is None: return {} pipe.load_models_to_device(self.onload_model_names) image = pipe.preprocess_image(input_image.resize((width, height))).to(pipe.device) @@ -767,6 +813,9 @@ class WanVideoUnit_ImageEmbedder(PipelineUnit): class WanVideoUnit_ImageVaeEmbedder(PipelineUnit): + """ + Encode input image to latents using VAE. This unit is for Wan-AI/Wan2.2-TI2V-5B. + """ def __init__(self): super().__init__( input_params=("input_image", "noise", "num_frames", "height", "width", "tiled", "tile_size", "tile_stride"), @@ -815,6 +864,42 @@ class WanVideoUnit_ImageVaeEmbedder(PipelineUnit): return out1, out2 +class WanVideoUnit_ImageEmbedderNoClip(PipelineUnit): + """ + Encode input image to fused_y using only VAE. This unit is for Wan-AI/Wan2.2-I2V-A14B. + """ + def __init__(self): + super().__init__( + input_params=("input_image", "end_image", "num_frames", "height", "width", "tiled", "tile_size", "tile_stride"), + onload_model_names=("vae") + ) + + def process(self, pipe: WanVideoPipeline, input_image, end_image, num_frames, height, width, tiled, tile_size, tile_stride): + if input_image is None or pipe.image_encoder is not None or pipe.dit.seperated_timestep: + return {} + pipe.load_models_to_device(self.onload_model_names) + image = pipe.preprocess_image(input_image.resize((width, height))).to(pipe.device) + msk = torch.ones(1, num_frames, height//8, width//8, device=pipe.device) + msk[:, 1:] = 0 + if end_image is not None: + end_image = pipe.preprocess_image(end_image.resize((width, height))).to(pipe.device) + vae_input = torch.concat([image.transpose(0,1), torch.zeros(3, num_frames-2, height, width).to(image.device), end_image.transpose(0,1)],dim=1) + msk[:, -1:] = 1 + else: + vae_input = torch.concat([image.transpose(0, 1), torch.zeros(3, num_frames-1, height, width).to(image.device)], dim=1) + + msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1) + msk = msk.view(1, msk.shape[1] // 4, 4, height//8, width//8) + msk = msk.transpose(1, 2)[0] + + y = pipe.vae.encode([vae_input.to(dtype=pipe.torch_dtype, device=pipe.device)], device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)[0] + y = y.to(dtype=pipe.torch_dtype, device=pipe.device) + y = torch.concat([msk, y]) + y = y.unsqueeze(0) + y = y.to(dtype=pipe.torch_dtype, device=pipe.device) + return {"fused_y": y} + + class WanVideoUnit_FunControl(PipelineUnit): def __init__(self): super().__init__( @@ -1116,6 +1201,7 @@ def model_fn_wan_video( context: torch.Tensor = None, clip_feature: Optional[torch.Tensor] = None, y: Optional[torch.Tensor] = None, + fused_y: Optional[torch.Tensor] = None, reference_latents = None, vace_context = None, vace_scale = 1.0, @@ -1181,11 +1267,13 @@ def model_fn_wan_video( x = torch.concat([x] * context.shape[0], dim=0) if timestep.shape[0] != context.shape[0]: timestep = torch.concat([timestep] * context.shape[0], dim=0) - + if dit.has_image_input: x = torch.cat([x, y], dim=1) # (b, c_x + c_y, f, h, w) clip_embdding = dit.img_emb(clip_feature) context = torch.cat([clip_embdding, context], dim=1) + if fused_y is not None: + x = torch.cat([x, fused_y], dim=1) # (b, c_x + c_y + c_fused_y, f, h, w) # Add camera control x, (f, h, w) = dit.patchify(x, control_camera_latents_input) diff --git a/examples/wanvideo/model_inference/Wan2.2-I2V-A14B.py b/examples/wanvideo/model_inference/Wan2.2-I2V-A14B.py new file mode 100644 index 0000000..9782a2c --- /dev/null +++ b/examples/wanvideo/model_inference/Wan2.2-I2V-A14B.py @@ -0,0 +1,32 @@ +import torch +from PIL import Image +from diffsynth import save_video +from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig +from modelscope import dataset_snapshot_download + +pipe = WanVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Wan-AI/Wan2.2-I2V-A14B", origin_file_pattern="high_noise_model/diffusion_pytorch_model*.safetensors", offload_device="cpu"), + ModelConfig(model_id="Wan-AI/Wan2.2-I2V-A14B", origin_file_pattern="low_noise_model/diffusion_pytorch_model*.safetensors", offload_device="cpu"), + ModelConfig(model_id="Wan-AI/Wan2.2-I2V-A14B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"), + ModelConfig(model_id="Wan-AI/Wan2.2-I2V-A14B", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"), + ], +) +pipe.enable_vram_management() + +dataset_snapshot_download( + dataset_id="DiffSynth-Studio/examples_in_diffsynth", + local_dir="./", + allow_file_pattern=["data/examples/wan/cat_fightning.jpg"] +) +input_image = Image.open("data/examples/wan/cat_fightning.jpg").resize((832, 480)) +# Text-to-video +video = pipe( + prompt="Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage.", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + seed=0, tiled=True, + input_image=input_image, +) +save_video(video, "video1.mp4", fps=15, quality=5) diff --git a/examples/wanvideo/model_inference/Wan2.2-T2V-A14B.py b/examples/wanvideo/model_inference/Wan2.2-T2V-A14B.py new file mode 100644 index 0000000..de9ae5f --- /dev/null +++ b/examples/wanvideo/model_inference/Wan2.2-T2V-A14B.py @@ -0,0 +1,27 @@ +import torch +from diffsynth import save_video +from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig +from modelscope import snapshot_download + +snapshot_download("Wan-AI/Wan2.2-T2V-A14B", local_dir="models/Wan-AI/Wan2.2-T2V-A14B") + + +pipe = WanVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Wan-AI/Wan2.2-T2V-A14B", origin_file_pattern="high_noise_model/diffusion_pytorch_model*.safetensors", offload_device="cpu"), + ModelConfig(model_id="Wan-AI/Wan2.2-T2V-A14B", origin_file_pattern="low_noise_model/diffusion_pytorch_model*.safetensors", offload_device="cpu"), + ModelConfig(model_id="Wan-AI/Wan2.2-T2V-A14B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"), + ModelConfig(model_id="Wan-AI/Wan2.2-T2V-A14B", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"), + ], +) +pipe.enable_vram_management() + +# Text-to-video +video = pipe( + prompt="Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage.", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + seed=0, tiled=True, +) +save_video(video, "video1.mp4", fps=15, quality=5) diff --git a/examples/wanvideo/model_inference/Wan2.2-TI2V-5B.py b/examples/wanvideo/model_inference/Wan2.2-TI2V-5B.py index b737b47..f41a941 100644 --- a/examples/wanvideo/model_inference/Wan2.2-TI2V-5B.py +++ b/examples/wanvideo/model_inference/Wan2.2-TI2V-5B.py @@ -1,17 +1,15 @@ import torch from PIL import Image -from diffsynth import save_video, VideoData +from diffsynth import save_video from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig -from modelscope import snapshot_download -from diffsynth.models.utils import load_state_dict, hash_state_dict_keys from modelscope import dataset_snapshot_download pipe = WanVideoPipeline.from_pretrained( torch_dtype=torch.bfloat16, device="cuda", model_configs=[ - ModelConfig(model_id="Wan-AI/Wan2.1-T2V-14B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"), - ModelConfig(model_id="Wan-AI/Wan2.2-TI2V-5B", origin_file_pattern="model_shards/model-*.safetensors", offload_device="cpu"), + ModelConfig(model_id="Wan-AI/Wan2.2-TI2V-5B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"), + ModelConfig(model_id="Wan-AI/Wan2.2-TI2V-5B", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"), ModelConfig(model_id="Wan-AI/Wan2.2-TI2V-5B", origin_file_pattern="Wan2.2_VAE.safetensors", offload_device="cpu"), ], )