support wan2.2 5B I2V

This commit is contained in:
mi804
2025-07-21 14:47:58 +08:00
parent 55951590f5
commit f1f00c4255
3 changed files with 99 additions and 18 deletions

View File

@@ -212,9 +212,16 @@ class DiTBlock(nn.Module):
self.gate = GateModule() self.gate = GateModule()
def forward(self, x, context, t_mod, freqs): def forward(self, x, context, t_mod, freqs):
has_seq = len(t_mod.shape) == 4
chunk_dim = 2 if has_seq else 1
# msa: multi-head self-attention mlp: multi-layer perceptron # msa: multi-head self-attention mlp: multi-layer perceptron
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ( shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
self.modulation.to(dtype=t_mod.dtype, device=t_mod.device) + t_mod).chunk(6, dim=1) self.modulation.to(dtype=t_mod.dtype, device=t_mod.device) + t_mod).chunk(6, dim=chunk_dim)
if has_seq:
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
shift_msa.squeeze(2), scale_msa.squeeze(2), gate_msa.squeeze(2),
shift_mlp.squeeze(2), scale_mlp.squeeze(2), gate_mlp.squeeze(2),
)
input_x = modulate(self.norm1(x), shift_msa, scale_msa) input_x = modulate(self.norm1(x), shift_msa, scale_msa)
x = self.gate(x, gate_msa, self.self_attn(input_x, freqs)) x = self.gate(x, gate_msa, self.self_attn(input_x, freqs))
x = x + self.cross_attn(self.norm3(x), context) x = x + self.cross_attn(self.norm3(x), context)
@@ -253,8 +260,12 @@ class Head(nn.Module):
self.modulation = nn.Parameter(torch.randn(1, 2, dim) / dim**0.5) self.modulation = nn.Parameter(torch.randn(1, 2, dim) / dim**0.5)
def forward(self, x, t_mod): def forward(self, x, t_mod):
shift, scale = (self.modulation.to(dtype=t_mod.dtype, device=t_mod.device) + t_mod).chunk(2, dim=1) if len(t_mod.shape) == 3:
x = (self.head(self.norm(x) * (1 + scale) + shift)) shift, scale = (self.modulation.unsqueeze(0).to(dtype=t_mod.dtype, device=t_mod.device) + t_mod.unsqueeze(2)).chunk(2, dim=2)
x = (self.head(self.norm(x) * (1 + scale.squeeze(2)) + shift.squeeze(2)))
else:
shift, scale = (self.modulation.to(dtype=t_mod.dtype, device=t_mod.device) + t_mod).chunk(2, dim=1)
x = (self.head(self.norm(x) * (1 + scale) + shift))
return x return x
@@ -276,12 +287,14 @@ class WanModel(torch.nn.Module):
has_ref_conv: bool = False, has_ref_conv: bool = False,
add_control_adapter: bool = False, add_control_adapter: bool = False,
in_dim_control_adapter: int = 24, in_dim_control_adapter: int = 24,
is_5b: bool = False,
): ):
super().__init__() super().__init__()
self.dim = dim self.dim = dim
self.freq_dim = freq_dim self.freq_dim = freq_dim
self.has_image_input = has_image_input self.has_image_input = has_image_input
self.patch_size = patch_size self.patch_size = patch_size
self.is_5b = is_5b
self.patch_embedding = nn.Conv3d( self.patch_embedding = nn.Conv3d(
in_dim, dim, kernel_size=patch_size, stride=patch_size) in_dim, dim, kernel_size=patch_size, stride=patch_size)
@@ -672,6 +685,7 @@ class WanModelStateDictConverter:
"num_heads": 24, "num_heads": 24,
"num_layers": 30, "num_layers": 30,
"eps": 1e-6, "eps": 1e-6,
"is_5b": True,
} }
else: else:
config = {} config = {}

View File

@@ -237,6 +237,7 @@ class WanVideoPipeline(BasePipeline):
WanVideoUnit_InputVideoEmbedder(), WanVideoUnit_InputVideoEmbedder(),
WanVideoUnit_PromptEmbedder(), WanVideoUnit_PromptEmbedder(),
WanVideoUnit_ImageEmbedder(), WanVideoUnit_ImageEmbedder(),
WanVideoUnit_ImageEmbedder5B(),
WanVideoUnit_FunControl(), WanVideoUnit_FunControl(),
WanVideoUnit_FunReference(), WanVideoUnit_FunReference(),
WanVideoUnit_FunCameraControl(), WanVideoUnit_FunCameraControl(),
@@ -736,7 +737,7 @@ class WanVideoUnit_ImageEmbedder(PipelineUnit):
) )
def process(self, pipe: WanVideoPipeline, input_image, end_image, num_frames, height, width, tiled, tile_size, tile_stride): def process(self, pipe: WanVideoPipeline, input_image, end_image, num_frames, height, width, tiled, tile_size, tile_stride):
if input_image is None: if input_image is None or pipe.dit.is_5b:
return {} return {}
pipe.load_models_to_device(self.onload_model_names) pipe.load_models_to_device(self.onload_model_names)
image = pipe.preprocess_image(input_image.resize((width, height))).to(pipe.device) image = pipe.preprocess_image(input_image.resize((width, height))).to(pipe.device)
@@ -765,6 +766,54 @@ class WanVideoUnit_ImageEmbedder(PipelineUnit):
return {"clip_feature": clip_context, "y": y} return {"clip_feature": clip_context, "y": y}
class WanVideoUnit_ImageEmbedder5B(PipelineUnit):
def __init__(self):
super().__init__(
input_params=("input_image", "noise", "num_frames", "height", "width", "tiled", "tile_size", "tile_stride"),
onload_model_names=("vae")
)
def process(self, pipe: WanVideoPipeline, input_image, noise, num_frames, height, width, tiled, tile_size, tile_stride):
if input_image is None or not pipe.dit.is_5b:
return {}
pipe.load_models_to_device(self.onload_model_names)
image = pipe.preprocess_image(input_image.resize((width, height))).transpose(0, 1).to(pipe.device)
z = pipe.vae.encode([image.to(dtype=pipe.torch_dtype, device=pipe.device)], device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)[0]
_, mask2 = self.masks_like([noise.squeeze(0)], zero=True)
latents = (1. - mask2[0]) * z + mask2[0] * noise.squeeze(0)
latents = latents.unsqueeze(0)
seq_len = ((num_frames - 1) // 4 + 1) * (height // pipe.vae.upsampling_factor) * (width // pipe.vae.upsampling_factor) // (2 * 2)
if hasattr(pipe, "use_unified_sequence_parallel") and pipe.use_unified_sequence_parallel:
import math
seq_len = int(math.ceil(seq_len / pipe.sp_size)) * pipe.sp_size
return {"latents": latents, "mask_5b": mask2[0].unsqueeze(0), "seq_len": seq_len}
@staticmethod
def masks_like(tensor, zero=False, generator=None, p=0.2):
assert isinstance(tensor, list)
out1 = [torch.ones(u.shape, dtype=u.dtype, device=u.device) for u in tensor]
out2 = [torch.ones(u.shape, dtype=u.dtype, device=u.device) for u in tensor]
if zero:
if generator is not None:
for u, v in zip(out1, out2):
random_num = torch.rand(1, generator=generator, device=generator.device).item()
if random_num < p:
u[:, 0] = torch.normal(mean=-3.5, std=0.5, size=(1,), device=u.device, generator=generator).expand_as(u[:, 0]).exp()
v[:, 0] = torch.zeros_like(v[:, 0])
else:
u[:, 0] = u[:, 0]
v[:, 0] = v[:, 0]
else:
for u, v in zip(out1, out2):
u[:, 0] = torch.zeros_like(u[:, 0])
v[:, 0] = torch.zeros_like(v[:, 0])
return out1, out2
class WanVideoUnit_FunControl(PipelineUnit): class WanVideoUnit_FunControl(PipelineUnit):
def __init__(self): def __init__(self):
@@ -1113,8 +1162,15 @@ def model_fn_wan_video(
get_sequence_parallel_world_size, get_sequence_parallel_world_size,
get_sp_group) get_sp_group)
t = dit.time_embedding(sinusoidal_embedding_1d(dit.freq_dim, timestep)) if dit.is_5b and "mask_5b" in kwargs:
t_mod = dit.time_projection(t).unflatten(1, (6, dit.dim)) temp_ts = (kwargs["mask_5b"][0][0][:, ::2, ::2] * timestep).flatten()
temp_ts= torch.cat([temp_ts, temp_ts.new_ones(kwargs["seq_len"] - temp_ts.size(0)) * timestep])
timestep = temp_ts.unsqueeze(0).flatten()
t = dit.time_embedding(sinusoidal_embedding_1d(dit.freq_dim, timestep).unflatten(0, (latents.size(0), kwargs["seq_len"])))
t_mod = dit.time_projection(t).unflatten(2, (6, dit.dim))
else:
t = dit.time_embedding(sinusoidal_embedding_1d(dit.freq_dim, timestep))
t_mod = dit.time_projection(t).unflatten(1, (6, dit.dim))
if motion_bucket_id is not None and motion_controller is not None: if motion_bucket_id is not None and motion_controller is not None:
t_mod = t_mod + motion_controller(motion_bucket_id).unflatten(1, (6, dit.dim)) t_mod = t_mod + motion_controller(motion_bucket_id).unflatten(1, (6, dit.dim))
context = dit.text_embedding(context) context = dit.text_embedding(context)

View File

@@ -6,13 +6,6 @@ from modelscope import snapshot_download
from diffsynth.models.utils import load_state_dict, hash_state_dict_keys from diffsynth.models.utils import load_state_dict, hash_state_dict_keys
from modelscope import dataset_snapshot_download from modelscope import dataset_snapshot_download
dataset_snapshot_download(
dataset_id="DiffSynth-Studio/examples_in_diffsynth",
local_dir="./",
allow_file_pattern=["data/examples/wan/depth_video.mp4", "data/examples/wan/cat_fightning.jpg"]
)
pipe = WanVideoPipeline.from_pretrained( pipe = WanVideoPipeline.from_pretrained(
torch_dtype=torch.bfloat16, torch_dtype=torch.bfloat16,
device="cuda", device="cuda",
@@ -26,9 +19,27 @@ pipe.enable_vram_management()
# Text-to-video # Text-to-video
video = pipe( video = pipe(
prompt="一名宇航员身穿太空服,面朝镜头骑着一匹机械马在火星表面驰骋。红色的荒凉地表延伸至远方,点缀着巨大的陨石坑和奇特的岩石结构。机械马的步伐稳健,扬起微弱的尘埃,展现出未来科技与原始探索的完美结合。宇航员手持操控装置,目光坚定,仿佛正在开辟人类的新疆域。背景是深邃的宇宙和蔚蓝的地球,画面既科幻又充满希望,让人不禁畅想未来的星际生活", prompt="两只可爱的橘猫戴上拳击手套,站在一个拳击台上搏斗",
negative_prompt="色调艳丽过曝静态细节模糊不清字幕风格作品画作画面静止整体发灰最差质量低质量JPEG压缩残留丑陋的残缺的多余的手指画得不好的手部画得不好的脸部畸形的毁容的形态畸形的肢体手指融合静止不动的画面杂乱的背景三条腿背景人很多倒着走", negative_prompt="色调艳丽过曝静态细节模糊不清字幕风格作品画作画面静止整体发灰最差质量低质量JPEG压缩残留丑陋的残缺的多余的手指画得不好的手部画得不好的脸部畸形的毁容的形态畸形的肢体手指融合静止不动的画面杂乱的背景三条腿背景人很多倒着走",
seed=0, tiled=True, seed=0, tiled=False,
height=704, width=1280, height=704, width=1248,
num_frames=121,
) )
save_video(video, "video1.mp4", fps=15, quality=5) save_video(video, "video1.mp4", fps=15, quality=5)
# Image-to-video
dataset_snapshot_download(
dataset_id="DiffSynth-Studio/examples_in_diffsynth",
local_dir="./",
allow_file_pattern=["data/examples/wan/cat_fightning.jpg"]
)
input_image = Image.open("data/examples/wan/cat_fightning.jpg").resize((1248, 704))
video = pipe(
prompt="两只可爱的橘猫戴上拳击手套,站在一个拳击台上搏斗。",
negative_prompt="色调艳丽过曝静态细节模糊不清字幕风格作品画作画面静止整体发灰最差质量低质量JPEG压缩残留丑陋的残缺的多余的手指画得不好的手部画得不好的脸部畸形的毁容的形态畸形的肢体手指融合静止不动的画面杂乱的背景三条腿背景人很多倒着走",
seed=0, tiled=False,
height=704, width=1248,
input_image=input_image,
num_frames=121,
)
save_video(video, "video2.mp4", fps=15, quality=5)