From f1f00c425521209eff96f43b092563be908a172a Mon Sep 17 00:00:00 2001 From: mi804 <1576993271@qq.com> Date: Mon, 21 Jul 2025 14:47:58 +0800 Subject: [PATCH] support wan2.2 5B I2V --- diffsynth/models/wan_video_dit.py | 20 +++++- diffsynth/pipelines/wan_video_new.py | 66 +++++++++++++++++-- .../model_inference/Wan2.2-TI2V-5B.py | 31 ++++++--- 3 files changed, 99 insertions(+), 18 deletions(-) diff --git a/diffsynth/models/wan_video_dit.py b/diffsynth/models/wan_video_dit.py index 2daf1b4..b7df3ff 100644 --- a/diffsynth/models/wan_video_dit.py +++ b/diffsynth/models/wan_video_dit.py @@ -212,9 +212,16 @@ class DiTBlock(nn.Module): self.gate = GateModule() def forward(self, x, context, t_mod, freqs): + has_seq = len(t_mod.shape) == 4 + chunk_dim = 2 if has_seq else 1 # msa: multi-head self-attention mlp: multi-layer perceptron shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ( - self.modulation.to(dtype=t_mod.dtype, device=t_mod.device) + t_mod).chunk(6, dim=1) + self.modulation.to(dtype=t_mod.dtype, device=t_mod.device) + t_mod).chunk(6, dim=chunk_dim) + if has_seq: + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ( + shift_msa.squeeze(2), scale_msa.squeeze(2), gate_msa.squeeze(2), + shift_mlp.squeeze(2), scale_mlp.squeeze(2), gate_mlp.squeeze(2), + ) input_x = modulate(self.norm1(x), shift_msa, scale_msa) x = self.gate(x, gate_msa, self.self_attn(input_x, freqs)) x = x + self.cross_attn(self.norm3(x), context) @@ -253,8 +260,12 @@ class Head(nn.Module): self.modulation = nn.Parameter(torch.randn(1, 2, dim) / dim**0.5) def forward(self, x, t_mod): - shift, scale = (self.modulation.to(dtype=t_mod.dtype, device=t_mod.device) + t_mod).chunk(2, dim=1) - x = (self.head(self.norm(x) * (1 + scale) + shift)) + if len(t_mod.shape) == 3: + shift, scale = (self.modulation.unsqueeze(0).to(dtype=t_mod.dtype, device=t_mod.device) + t_mod.unsqueeze(2)).chunk(2, dim=2) + x = (self.head(self.norm(x) * (1 + scale.squeeze(2)) + shift.squeeze(2))) + else: + shift, scale = (self.modulation.to(dtype=t_mod.dtype, device=t_mod.device) + t_mod).chunk(2, dim=1) + x = (self.head(self.norm(x) * (1 + scale) + shift)) return x @@ -276,12 +287,14 @@ class WanModel(torch.nn.Module): has_ref_conv: bool = False, add_control_adapter: bool = False, in_dim_control_adapter: int = 24, + is_5b: bool = False, ): super().__init__() self.dim = dim self.freq_dim = freq_dim self.has_image_input = has_image_input self.patch_size = patch_size + self.is_5b = is_5b self.patch_embedding = nn.Conv3d( in_dim, dim, kernel_size=patch_size, stride=patch_size) @@ -672,6 +685,7 @@ class WanModelStateDictConverter: "num_heads": 24, "num_layers": 30, "eps": 1e-6, + "is_5b": True, } else: config = {} diff --git a/diffsynth/pipelines/wan_video_new.py b/diffsynth/pipelines/wan_video_new.py index 91a6f7b..1136403 100644 --- a/diffsynth/pipelines/wan_video_new.py +++ b/diffsynth/pipelines/wan_video_new.py @@ -237,6 +237,7 @@ class WanVideoPipeline(BasePipeline): WanVideoUnit_InputVideoEmbedder(), WanVideoUnit_PromptEmbedder(), WanVideoUnit_ImageEmbedder(), + WanVideoUnit_ImageEmbedder5B(), WanVideoUnit_FunControl(), WanVideoUnit_FunReference(), WanVideoUnit_FunCameraControl(), @@ -736,7 +737,7 @@ class WanVideoUnit_ImageEmbedder(PipelineUnit): ) def process(self, pipe: WanVideoPipeline, input_image, end_image, num_frames, height, width, tiled, tile_size, tile_stride): - if input_image is None: + if input_image is None or pipe.dit.is_5b: return {} pipe.load_models_to_device(self.onload_model_names) image = pipe.preprocess_image(input_image.resize((width, height))).to(pipe.device) @@ -764,7 +765,55 @@ class WanVideoUnit_ImageEmbedder(PipelineUnit): y = y.to(dtype=pipe.torch_dtype, device=pipe.device) return {"clip_feature": clip_context, "y": y} - + +class WanVideoUnit_ImageEmbedder5B(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("input_image", "noise", "num_frames", "height", "width", "tiled", "tile_size", "tile_stride"), + onload_model_names=("vae") + ) + + def process(self, pipe: WanVideoPipeline, input_image, noise, num_frames, height, width, tiled, tile_size, tile_stride): + if input_image is None or not pipe.dit.is_5b: + return {} + pipe.load_models_to_device(self.onload_model_names) + image = pipe.preprocess_image(input_image.resize((width, height))).transpose(0, 1).to(pipe.device) + z = pipe.vae.encode([image.to(dtype=pipe.torch_dtype, device=pipe.device)], device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)[0] + + _, mask2 = self.masks_like([noise.squeeze(0)], zero=True) + latents = (1. - mask2[0]) * z + mask2[0] * noise.squeeze(0) + latents = latents.unsqueeze(0) + + seq_len = ((num_frames - 1) // 4 + 1) * (height // pipe.vae.upsampling_factor) * (width // pipe.vae.upsampling_factor) // (2 * 2) + if hasattr(pipe, "use_unified_sequence_parallel") and pipe.use_unified_sequence_parallel: + import math + seq_len = int(math.ceil(seq_len / pipe.sp_size)) * pipe.sp_size + + return {"latents": latents, "mask_5b": mask2[0].unsqueeze(0), "seq_len": seq_len} + + @staticmethod + def masks_like(tensor, zero=False, generator=None, p=0.2): + assert isinstance(tensor, list) + out1 = [torch.ones(u.shape, dtype=u.dtype, device=u.device) for u in tensor] + out2 = [torch.ones(u.shape, dtype=u.dtype, device=u.device) for u in tensor] + + if zero: + if generator is not None: + for u, v in zip(out1, out2): + random_num = torch.rand(1, generator=generator, device=generator.device).item() + if random_num < p: + u[:, 0] = torch.normal(mean=-3.5, std=0.5, size=(1,), device=u.device, generator=generator).expand_as(u[:, 0]).exp() + v[:, 0] = torch.zeros_like(v[:, 0]) + else: + u[:, 0] = u[:, 0] + v[:, 0] = v[:, 0] + else: + for u, v in zip(out1, out2): + u[:, 0] = torch.zeros_like(u[:, 0]) + v[:, 0] = torch.zeros_like(v[:, 0]) + + return out1, out2 + class WanVideoUnit_FunControl(PipelineUnit): def __init__(self): @@ -1112,9 +1161,16 @@ def model_fn_wan_video( from xfuser.core.distributed import (get_sequence_parallel_rank, get_sequence_parallel_world_size, get_sp_group) - - t = dit.time_embedding(sinusoidal_embedding_1d(dit.freq_dim, timestep)) - t_mod = dit.time_projection(t).unflatten(1, (6, dit.dim)) + + if dit.is_5b and "mask_5b" in kwargs: + temp_ts = (kwargs["mask_5b"][0][0][:, ::2, ::2] * timestep).flatten() + temp_ts= torch.cat([temp_ts, temp_ts.new_ones(kwargs["seq_len"] - temp_ts.size(0)) * timestep]) + timestep = temp_ts.unsqueeze(0).flatten() + t = dit.time_embedding(sinusoidal_embedding_1d(dit.freq_dim, timestep).unflatten(0, (latents.size(0), kwargs["seq_len"]))) + t_mod = dit.time_projection(t).unflatten(2, (6, dit.dim)) + else: + t = dit.time_embedding(sinusoidal_embedding_1d(dit.freq_dim, timestep)) + t_mod = dit.time_projection(t).unflatten(1, (6, dit.dim)) if motion_bucket_id is not None and motion_controller is not None: t_mod = t_mod + motion_controller(motion_bucket_id).unflatten(1, (6, dit.dim)) context = dit.text_embedding(context) diff --git a/examples/wanvideo/model_inference/Wan2.2-TI2V-5B.py b/examples/wanvideo/model_inference/Wan2.2-TI2V-5B.py index 93ac975..b737b47 100644 --- a/examples/wanvideo/model_inference/Wan2.2-TI2V-5B.py +++ b/examples/wanvideo/model_inference/Wan2.2-TI2V-5B.py @@ -6,13 +6,6 @@ from modelscope import snapshot_download from diffsynth.models.utils import load_state_dict, hash_state_dict_keys from modelscope import dataset_snapshot_download -dataset_snapshot_download( - dataset_id="DiffSynth-Studio/examples_in_diffsynth", - local_dir="./", - allow_file_pattern=["data/examples/wan/depth_video.mp4", "data/examples/wan/cat_fightning.jpg"] -) - - pipe = WanVideoPipeline.from_pretrained( torch_dtype=torch.bfloat16, device="cuda", @@ -26,9 +19,27 @@ pipe.enable_vram_management() # Text-to-video video = pipe( - prompt="一名宇航员身穿太空服,面朝镜头骑着一匹机械马在火星表面驰骋。红色的荒凉地表延伸至远方,点缀着巨大的陨石坑和奇特的岩石结构。机械马的步伐稳健,扬起微弱的尘埃,展现出未来科技与原始探索的完美结合。宇航员手持操控装置,目光坚定,仿佛正在开辟人类的新疆域。背景是深邃的宇宙和蔚蓝的地球,画面既科幻又充满希望,让人不禁畅想未来的星际生活。", + prompt="两只可爱的橘猫戴上拳击手套,站在一个拳击台上搏斗。", negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", - seed=0, tiled=True, - height=704, width=1280, + seed=0, tiled=False, + height=704, width=1248, + num_frames=121, ) save_video(video, "video1.mp4", fps=15, quality=5) + +# Image-to-video +dataset_snapshot_download( + dataset_id="DiffSynth-Studio/examples_in_diffsynth", + local_dir="./", + allow_file_pattern=["data/examples/wan/cat_fightning.jpg"] +) +input_image = Image.open("data/examples/wan/cat_fightning.jpg").resize((1248, 704)) +video = pipe( + prompt="两只可爱的橘猫戴上拳击手套,站在一个拳击台上搏斗。", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + seed=0, tiled=False, + height=704, width=1248, + input_image=input_image, + num_frames=121, +) +save_video(video, "video2.mp4", fps=15, quality=5)