support wan2.2 A14B I2V&T2V

This commit is contained in:
mi804
2025-07-25 17:09:53 +08:00
parent 3aed244c6f
commit 9015d08927
6 changed files with 175 additions and 9 deletions

View File

@@ -226,10 +226,11 @@ class WanVideoPipeline(BasePipeline):
self.text_encoder: WanTextEncoder = None
self.image_encoder: WanImageEncoder = None
self.dit: WanModel = None
self.dit2: WanModel = None
self.vae: WanVideoVAE = None
self.motion_controller: WanMotionControllerModel = None
self.vace: VaceWanModel = None
self.in_iteration_models = ("dit", "motion_controller", "vace")
self.in_iteration_models = ("dit", "dit2", "motion_controller", "vace")
self.unit_runner = PipelineUnitRunner()
self.units = [
WanVideoUnit_ShapeChecker(),
@@ -238,6 +239,7 @@ class WanVideoPipeline(BasePipeline):
WanVideoUnit_PromptEmbedder(),
WanVideoUnit_ImageEmbedder(),
WanVideoUnit_ImageVaeEmbedder(),
WanVideoUnit_ImageEmbedderNoClip(),
WanVideoUnit_FunControl(),
WanVideoUnit_FunReference(),
WanVideoUnit_FunCameraControl(),
@@ -329,6 +331,37 @@ class WanVideoPipeline(BasePipeline):
),
vram_limit=vram_limit,
)
if self.dit2 is not None:
dtype = next(iter(self.dit2.parameters())).dtype
device = "cpu" if vram_limit is not None else self.device
enable_vram_management(
self.dit2,
module_map = {
torch.nn.Linear: AutoWrappedLinear,
torch.nn.Conv3d: AutoWrappedModule,
torch.nn.LayerNorm: WanAutoCastLayerNorm,
RMSNorm: AutoWrappedModule,
torch.nn.Conv2d: AutoWrappedModule,
},
module_config = dict(
offload_dtype=dtype,
offload_device="cpu",
onload_dtype=dtype,
onload_device=device,
computation_dtype=self.torch_dtype,
computation_device=self.device,
),
max_num_param=num_persistent_param_in_dit,
overflow_module_config = dict(
offload_dtype=dtype,
offload_device="cpu",
onload_dtype=dtype,
onload_device="cpu",
computation_dtype=self.torch_dtype,
computation_device=self.device,
),
vram_limit=vram_limit,
)
if self.vae is not None:
dtype = next(iter(self.vae.parameters())).dtype
enable_vram_management(
@@ -427,6 +460,10 @@ class WanVideoPipeline(BasePipeline):
for block in self.dit.blocks:
block.self_attn.forward = types.MethodType(usp_attn_forward, block.self_attn)
self.dit.forward = types.MethodType(usp_dit_forward, self.dit)
if self.dit2 is not None:
for block in self.dit2.blocks:
block.self_attn.forward = types.MethodType(usp_attn_forward, block.self_attn)
self.dit2.forward = types.MethodType(usp_dit_forward, self.dit2)
self.sp_size = get_sequence_parallel_world_size()
self.use_unified_sequence_parallel = True
@@ -473,6 +510,9 @@ class WanVideoPipeline(BasePipeline):
# Load models
pipe.text_encoder = model_manager.fetch_model("wan_video_text_encoder")
pipe.dit = model_manager.fetch_model("wan_video_dit")
num_dits = len([model_name for model_name in model_manager.model_name if model_name == "wan_video_dit"])
if num_dits == 2:
pipe.dit2 = [model for model, model_name in zip(model_manager.model, model_manager.model_name) if model_name == "wan_video_dit"][-1]
pipe.vae = model_manager.fetch_model("wan_video_vae")
pipe.image_encoder = model_manager.fetch_model("wan_video_image_encoder")
pipe.motion_controller = model_manager.fetch_model("wan_video_motion_controller")
@@ -523,6 +563,8 @@ class WanVideoPipeline(BasePipeline):
# Classifier-free guidance
cfg_scale: Optional[float] = 5.0,
cfg_merge: Optional[bool] = False,
# Boundary
boundary: Optional[float] = 0.875,
# Scheduler
num_inference_steps: Optional[int] = 50,
sigma_shift: Optional[float] = 5.0,
@@ -575,8 +617,12 @@ class WanVideoPipeline(BasePipeline):
self.load_models_to_device(self.in_iteration_models)
models = {name: getattr(self, name) for name in self.in_iteration_models}
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
# switch high_noise DiT to low_noise DiT
if models.get("dit2") is not None and timestep.item() < boundary * self.scheduler.num_train_timesteps:
print("switching to low noise DiT")
self.load_models_to_device(["dit2", "motion_controller", "vace"])
models["dit"] = models.pop("dit2")
timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device)
# Inference
noise_pred_posi = self.model_fn(**models, **inputs_shared, **inputs_posi, timestep=timestep)
if cfg_scale != 1.0:
@@ -737,7 +783,7 @@ class WanVideoUnit_ImageEmbedder(PipelineUnit):
)
def process(self, pipe: WanVideoPipeline, input_image, end_image, num_frames, height, width, tiled, tile_size, tile_stride):
if input_image is None or pipe.dit.seperated_timestep:
if input_image is None or pipe.image_encoder is None:
return {}
pipe.load_models_to_device(self.onload_model_names)
image = pipe.preprocess_image(input_image.resize((width, height))).to(pipe.device)
@@ -767,6 +813,9 @@ class WanVideoUnit_ImageEmbedder(PipelineUnit):
class WanVideoUnit_ImageVaeEmbedder(PipelineUnit):
"""
Encode input image to latents using VAE. This unit is for Wan-AI/Wan2.2-TI2V-5B.
"""
def __init__(self):
super().__init__(
input_params=("input_image", "noise", "num_frames", "height", "width", "tiled", "tile_size", "tile_stride"),
@@ -815,6 +864,42 @@ class WanVideoUnit_ImageVaeEmbedder(PipelineUnit):
return out1, out2
class WanVideoUnit_ImageEmbedderNoClip(PipelineUnit):
"""
Encode input image to fused_y using only VAE. This unit is for Wan-AI/Wan2.2-I2V-A14B.
"""
def __init__(self):
super().__init__(
input_params=("input_image", "end_image", "num_frames", "height", "width", "tiled", "tile_size", "tile_stride"),
onload_model_names=("vae")
)
def process(self, pipe: WanVideoPipeline, input_image, end_image, num_frames, height, width, tiled, tile_size, tile_stride):
if input_image is None or pipe.image_encoder is not None or pipe.dit.seperated_timestep:
return {}
pipe.load_models_to_device(self.onload_model_names)
image = pipe.preprocess_image(input_image.resize((width, height))).to(pipe.device)
msk = torch.ones(1, num_frames, height//8, width//8, device=pipe.device)
msk[:, 1:] = 0
if end_image is not None:
end_image = pipe.preprocess_image(end_image.resize((width, height))).to(pipe.device)
vae_input = torch.concat([image.transpose(0,1), torch.zeros(3, num_frames-2, height, width).to(image.device), end_image.transpose(0,1)],dim=1)
msk[:, -1:] = 1
else:
vae_input = torch.concat([image.transpose(0, 1), torch.zeros(3, num_frames-1, height, width).to(image.device)], dim=1)
msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1)
msk = msk.view(1, msk.shape[1] // 4, 4, height//8, width//8)
msk = msk.transpose(1, 2)[0]
y = pipe.vae.encode([vae_input.to(dtype=pipe.torch_dtype, device=pipe.device)], device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)[0]
y = y.to(dtype=pipe.torch_dtype, device=pipe.device)
y = torch.concat([msk, y])
y = y.unsqueeze(0)
y = y.to(dtype=pipe.torch_dtype, device=pipe.device)
return {"fused_y": y}
class WanVideoUnit_FunControl(PipelineUnit):
def __init__(self):
super().__init__(
@@ -1116,6 +1201,7 @@ def model_fn_wan_video(
context: torch.Tensor = None,
clip_feature: Optional[torch.Tensor] = None,
y: Optional[torch.Tensor] = None,
fused_y: Optional[torch.Tensor] = None,
reference_latents = None,
vace_context = None,
vace_scale = 1.0,
@@ -1181,11 +1267,13 @@ def model_fn_wan_video(
x = torch.concat([x] * context.shape[0], dim=0)
if timestep.shape[0] != context.shape[0]:
timestep = torch.concat([timestep] * context.shape[0], dim=0)
if dit.has_image_input:
x = torch.cat([x, y], dim=1) # (b, c_x + c_y, f, h, w)
clip_embdding = dit.img_emb(clip_feature)
context = torch.cat([clip_embdding, context], dim=1)
if fused_y is not None:
x = torch.cat([x, fused_y], dim=1) # (b, c_x + c_y + c_fused_y, f, h, w)
# Add camera control
x, (f, h, w) = dit.patchify(x, control_camera_latents_input)