support wan2.2 5B I2V

This commit is contained in:
mi804
2025-07-21 14:47:58 +08:00
parent 55951590f5
commit f1f00c4255
3 changed files with 99 additions and 18 deletions

View File

@@ -237,6 +237,7 @@ class WanVideoPipeline(BasePipeline):
WanVideoUnit_InputVideoEmbedder(),
WanVideoUnit_PromptEmbedder(),
WanVideoUnit_ImageEmbedder(),
WanVideoUnit_ImageEmbedder5B(),
WanVideoUnit_FunControl(),
WanVideoUnit_FunReference(),
WanVideoUnit_FunCameraControl(),
@@ -736,7 +737,7 @@ class WanVideoUnit_ImageEmbedder(PipelineUnit):
)
def process(self, pipe: WanVideoPipeline, input_image, end_image, num_frames, height, width, tiled, tile_size, tile_stride):
if input_image is None:
if input_image is None or pipe.dit.is_5b:
return {}
pipe.load_models_to_device(self.onload_model_names)
image = pipe.preprocess_image(input_image.resize((width, height))).to(pipe.device)
@@ -764,7 +765,55 @@ class WanVideoUnit_ImageEmbedder(PipelineUnit):
y = y.to(dtype=pipe.torch_dtype, device=pipe.device)
return {"clip_feature": clip_context, "y": y}
class WanVideoUnit_ImageEmbedder5B(PipelineUnit):
def __init__(self):
super().__init__(
input_params=("input_image", "noise", "num_frames", "height", "width", "tiled", "tile_size", "tile_stride"),
onload_model_names=("vae")
)
def process(self, pipe: WanVideoPipeline, input_image, noise, num_frames, height, width, tiled, tile_size, tile_stride):
if input_image is None or not pipe.dit.is_5b:
return {}
pipe.load_models_to_device(self.onload_model_names)
image = pipe.preprocess_image(input_image.resize((width, height))).transpose(0, 1).to(pipe.device)
z = pipe.vae.encode([image.to(dtype=pipe.torch_dtype, device=pipe.device)], device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)[0]
_, mask2 = self.masks_like([noise.squeeze(0)], zero=True)
latents = (1. - mask2[0]) * z + mask2[0] * noise.squeeze(0)
latents = latents.unsqueeze(0)
seq_len = ((num_frames - 1) // 4 + 1) * (height // pipe.vae.upsampling_factor) * (width // pipe.vae.upsampling_factor) // (2 * 2)
if hasattr(pipe, "use_unified_sequence_parallel") and pipe.use_unified_sequence_parallel:
import math
seq_len = int(math.ceil(seq_len / pipe.sp_size)) * pipe.sp_size
return {"latents": latents, "mask_5b": mask2[0].unsqueeze(0), "seq_len": seq_len}
@staticmethod
def masks_like(tensor, zero=False, generator=None, p=0.2):
assert isinstance(tensor, list)
out1 = [torch.ones(u.shape, dtype=u.dtype, device=u.device) for u in tensor]
out2 = [torch.ones(u.shape, dtype=u.dtype, device=u.device) for u in tensor]
if zero:
if generator is not None:
for u, v in zip(out1, out2):
random_num = torch.rand(1, generator=generator, device=generator.device).item()
if random_num < p:
u[:, 0] = torch.normal(mean=-3.5, std=0.5, size=(1,), device=u.device, generator=generator).expand_as(u[:, 0]).exp()
v[:, 0] = torch.zeros_like(v[:, 0])
else:
u[:, 0] = u[:, 0]
v[:, 0] = v[:, 0]
else:
for u, v in zip(out1, out2):
u[:, 0] = torch.zeros_like(u[:, 0])
v[:, 0] = torch.zeros_like(v[:, 0])
return out1, out2
class WanVideoUnit_FunControl(PipelineUnit):
def __init__(self):
@@ -1112,9 +1161,16 @@ def model_fn_wan_video(
from xfuser.core.distributed import (get_sequence_parallel_rank,
get_sequence_parallel_world_size,
get_sp_group)
t = dit.time_embedding(sinusoidal_embedding_1d(dit.freq_dim, timestep))
t_mod = dit.time_projection(t).unflatten(1, (6, dit.dim))
if dit.is_5b and "mask_5b" in kwargs:
temp_ts = (kwargs["mask_5b"][0][0][:, ::2, ::2] * timestep).flatten()
temp_ts= torch.cat([temp_ts, temp_ts.new_ones(kwargs["seq_len"] - temp_ts.size(0)) * timestep])
timestep = temp_ts.unsqueeze(0).flatten()
t = dit.time_embedding(sinusoidal_embedding_1d(dit.freq_dim, timestep).unflatten(0, (latents.size(0), kwargs["seq_len"])))
t_mod = dit.time_projection(t).unflatten(2, (6, dit.dim))
else:
t = dit.time_embedding(sinusoidal_embedding_1d(dit.freq_dim, timestep))
t_mod = dit.time_projection(t).unflatten(1, (6, dit.dim))
if motion_bucket_id is not None and motion_controller is not None:
t_mod = t_mod + motion_controller(motion_bucket_id).unflatten(1, (6, dit.dim))
context = dit.text_embedding(context)