support wan2.2-animate-14b

This commit is contained in:
Artiprocher
2025-09-30 12:45:56 +08:00
parent 0d6de58af9
commit a36f2f6032
15 changed files with 999 additions and 33 deletions

View File

@@ -21,6 +21,7 @@ from ..models.wan_video_vae import WanVideoVAE, RMS_norm, CausalConv3d, Upsample
from ..models.wan_video_image_encoder import WanImageEncoder
from ..models.wan_video_vace import VaceWanModel
from ..models.wan_video_motion_controller import WanMotionControllerModel
from ..models.wan_video_animate_adapter import WanAnimateAdapter
from ..schedulers.flow_match import FlowMatchScheduler
from ..prompters import WanPrompter
from ..vram_management import enable_vram_management, AutoWrappedModule, AutoWrappedLinear, WanAutoCastLayerNorm
@@ -45,8 +46,9 @@ class WanVideoPipeline(BasePipeline):
self.motion_controller: WanMotionControllerModel = None
self.vace: VaceWanModel = None
self.vace2: VaceWanModel = None
self.in_iteration_models = ("dit", "motion_controller", "vace")
self.in_iteration_models_2 = ("dit2", "motion_controller", "vace2")
self.animate_adapter: WanAnimateAdapter = None
self.in_iteration_models = ("dit", "motion_controller", "vace", "animate_adapter")
self.in_iteration_models_2 = ("dit2", "motion_controller", "vace2", "animate_adapter")
self.unit_runner = PipelineUnitRunner()
self.units = [
WanVideoUnit_ShapeChecker(),
@@ -62,6 +64,10 @@ class WanVideoPipeline(BasePipeline):
WanVideoUnit_FunCameraControl(),
WanVideoUnit_SpeedControl(),
WanVideoUnit_VACE(),
WanVideoPostUnit_AnimateVideoSplit(),
WanVideoPostUnit_AnimatePoseLatents(),
WanVideoPostUnit_AnimateFacePixelValues(),
WanVideoPostUnit_AnimateInpaint(),
WanVideoUnit_UnifiedSequenceParallel(),
WanVideoUnit_TeaCache(),
WanVideoUnit_CfgMerger(),
@@ -70,13 +76,34 @@ class WanVideoPipeline(BasePipeline):
WanVideoPostUnit_S2V(),
]
self.model_fn = model_fn_wan_video
def load_lora(self, module, path, alpha=1):
loader = GeneralLoRALoader(torch_dtype=self.torch_dtype, device=self.device)
lora = load_state_dict(path, torch_dtype=self.torch_dtype, device=self.device)
loader.load(module, lora, alpha=alpha)
def load_lora(
self,
module: torch.nn.Module,
lora_config: Union[ModelConfig, str] = None,
alpha=1,
hotload=False,
state_dict=None,
):
if state_dict is None:
if isinstance(lora_config, str):
lora = load_state_dict(lora_config, torch_dtype=self.torch_dtype, device=self.device)
else:
lora_config.download_if_necessary()
lora = load_state_dict(lora_config.path, torch_dtype=self.torch_dtype, device=self.device)
else:
lora = state_dict
if hotload:
for name, module in module.named_modules():
if isinstance(module, AutoWrappedLinear):
lora_a_name = f'{name}.lora_A.default.weight'
lora_b_name = f'{name}.lora_B.default.weight'
if lora_a_name in lora and lora_b_name in lora:
module.lora_A_weights.append(lora[lora_a_name] * alpha)
module.lora_B_weights.append(lora[lora_b_name])
else:
loader = GeneralLoRALoader(torch_dtype=self.torch_dtype, device=self.device)
loader.load(module, lora, alpha=alpha)
def training_loss(self, **inputs):
max_timestep_boundary = int(inputs.get("max_timestep_boundary", 1) * self.scheduler.num_train_timesteps)
@@ -359,12 +386,13 @@ class WanVideoPipeline(BasePipeline):
pipe.vae = model_manager.fetch_model("wan_video_vae")
pipe.image_encoder = model_manager.fetch_model("wan_video_image_encoder")
pipe.motion_controller = model_manager.fetch_model("wan_video_motion_controller")
pipe.vace = model_manager.fetch_model("wan_video_vace")
vace = model_manager.fetch_model("wan_video_vace", index=2)
if isinstance(vace, list):
pipe.vace, pipe.vace2 = vace
else:
pipe.vace = vace
pipe.audio_encoder = model_manager.fetch_model("wans2v_audio_encoder")
pipe.animate_adapter = model_manager.fetch_model("wan_video_animate_adapter")
# Size division factor
if pipe.vae is not None:
@@ -417,6 +445,11 @@ class WanVideoPipeline(BasePipeline):
vace_video_mask: Optional[Image.Image] = None,
vace_reference_image: Optional[Image.Image] = None,
vace_scale: Optional[float] = 1.0,
# Animate
animate_pose_video: Optional[list[Image.Image]] = None,
animate_face_video: Optional[list[Image.Image]] = None,
animate_inpaint_video: Optional[list[Image.Image]] = None,
animate_mask_video: Optional[list[Image.Image]] = None,
# Randomness
seed: Optional[int] = None,
rand_device: Optional[str] = "cpu",
@@ -474,6 +507,7 @@ class WanVideoPipeline(BasePipeline):
"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride,
"sliding_window_size": sliding_window_size, "sliding_window_stride": sliding_window_stride,
"input_audio": input_audio, "audio_sample_rate": audio_sample_rate, "s2v_pose_video": s2v_pose_video, "audio_embeds": audio_embeds, "s2v_pose_latents": s2v_pose_latents, "motion_video": motion_video,
"animate_pose_video": animate_pose_video, "animate_face_video": animate_face_video, "animate_inpaint_video": animate_inpaint_video, "animate_mask_video": animate_mask_video,
}
for unit in self.units:
inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega)
@@ -508,7 +542,7 @@ class WanVideoPipeline(BasePipeline):
inputs_shared["latents"][:, :, 0:1] = inputs_shared["first_frame_latents"]
# VACE (TODO: remove it)
if vace_reference_image is not None:
if vace_reference_image is not None or (animate_pose_video is not None and animate_face_video is not None):
inputs_shared["latents"] = inputs_shared["latents"][:, :, 1:]
# post-denoising, pre-decoding processing logic
for unit in self.post_units:
@@ -1021,6 +1055,95 @@ class WanVideoPostUnit_S2V(PipelineUnit):
return {"latents": latents}
class WanVideoPostUnit_AnimateVideoSplit(PipelineUnit):
def __init__(self):
super().__init__(input_params=("input_video", "animate_pose_video", "animate_face_video", "animate_inpaint_video", "animate_mask_video"))
def process(self, pipe: WanVideoPipeline, input_video, animate_pose_video, animate_face_video, animate_inpaint_video, animate_mask_video):
if input_video is None:
return {}
if animate_pose_video is not None:
animate_pose_video = animate_pose_video[:len(input_video) - 4]
if animate_face_video is not None:
animate_face_video = animate_face_video[:len(input_video) - 4]
if animate_inpaint_video is not None:
animate_inpaint_video = animate_inpaint_video[:len(input_video) - 4]
if animate_mask_video is not None:
animate_mask_video = animate_mask_video[:len(input_video) - 4]
return {"animate_pose_video": animate_pose_video, "animate_face_video": animate_face_video, "animate_inpaint_video": animate_inpaint_video, "animate_mask_video": animate_mask_video}
class WanVideoPostUnit_AnimatePoseLatents(PipelineUnit):
def __init__(self):
super().__init__(
input_params=("animate_pose_video", "tiled", "tile_size", "tile_stride"),
onload_model_names=("vae",)
)
def process(self, pipe: WanVideoPipeline, animate_pose_video, tiled, tile_size, tile_stride):
if animate_pose_video is None:
return {}
pipe.load_models_to_device(self.onload_model_names)
animate_pose_video = pipe.preprocess_video(animate_pose_video)
pose_latents = pipe.vae.encode(animate_pose_video, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device)
return {"pose_latents": pose_latents}
class WanVideoPostUnit_AnimateFacePixelValues(PipelineUnit):
def __init__(self):
super().__init__(take_over=True)
def process(self, pipe: WanVideoPipeline, inputs_shared, inputs_posi, inputs_nega):
if inputs_shared.get("animate_face_video", None) is None:
return {}
inputs_posi["face_pixel_values"] = pipe.preprocess_video(inputs_shared["animate_face_video"])
inputs_nega["face_pixel_values"] = torch.zeros_like(inputs_posi["face_pixel_values"]) - 1
return inputs_shared, inputs_posi, inputs_nega
class WanVideoPostUnit_AnimateInpaint(PipelineUnit):
def __init__(self):
super().__init__(
input_params=("animate_inpaint_video", "animate_mask_video", "input_image", "tiled", "tile_size", "tile_stride"),
onload_model_names=("vae",)
)
def get_i2v_mask(self, lat_t, lat_h, lat_w, mask_len=1, mask_pixel_values=None, device="cuda"):
if mask_pixel_values is None:
msk = torch.zeros(1, (lat_t-1) * 4 + 1, lat_h, lat_w, device=device)
else:
msk = mask_pixel_values.clone()
msk[:, :mask_len] = 1
msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1)
msk = msk.view(1, msk.shape[1] // 4, 4, lat_h, lat_w)
msk = msk.transpose(1, 2)[0]
return msk
def process(self, pipe: WanVideoPipeline, animate_inpaint_video, animate_mask_video, input_image, tiled, tile_size, tile_stride):
if animate_inpaint_video is None or animate_mask_video is None:
return {}
pipe.load_models_to_device(self.onload_model_names)
bg_pixel_values = pipe.preprocess_video(animate_inpaint_video)
y_reft = pipe.vae.encode(bg_pixel_values, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)[0].to(dtype=pipe.torch_dtype, device=pipe.device)
_, lat_t, lat_h, lat_w = y_reft.shape
ref_pixel_values = pipe.preprocess_video([input_image])
ref_latents = pipe.vae.encode(ref_pixel_values, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device)
mask_ref = self.get_i2v_mask(1, lat_h, lat_w, 1, device=pipe.device)
y_ref = torch.concat([mask_ref, ref_latents[0]]).to(dtype=torch.bfloat16, device=pipe.device)
mask_pixel_values = 1 - pipe.preprocess_video(animate_mask_video, max_value=1, min_value=0)
mask_pixel_values = rearrange(mask_pixel_values, "b c t h w -> (b t) c h w")
mask_pixel_values = torch.nn.functional.interpolate(mask_pixel_values, size=(lat_h, lat_w), mode='nearest')
mask_pixel_values = rearrange(mask_pixel_values, "(b t) c h w -> b t c h w", b=1)[:,:,0]
msk_reft = self.get_i2v_mask(lat_t, lat_h, lat_w, 0, mask_pixel_values=mask_pixel_values, device=pipe.device)
y_reft = torch.concat([msk_reft, y_reft]).to(dtype=torch.bfloat16, device=pipe.device)
y = torch.concat([y_ref, y_reft], dim=1).unsqueeze(0)
return {"y": y}
class TeaCache:
def __init__(self, num_inference_steps, rel_l1_thresh, model_id):
self.num_inference_steps = num_inference_steps
@@ -1131,6 +1254,7 @@ def model_fn_wan_video(
dit: WanModel,
motion_controller: WanMotionControllerModel = None,
vace: VaceWanModel = None,
animate_adapter: WanAnimateAdapter = None,
latents: torch.Tensor = None,
timestep: torch.Tensor = None,
context: torch.Tensor = None,
@@ -1146,6 +1270,8 @@ def model_fn_wan_video(
tea_cache: TeaCache = None,
use_unified_sequence_parallel: bool = False,
motion_bucket_id: Optional[torch.Tensor] = None,
pose_latents=None,
face_pixel_values=None,
sliding_window_size: Optional[int] = None,
sliding_window_stride: Optional[int] = None,
cfg_merge: bool = False,
@@ -1236,9 +1362,16 @@ def model_fn_wan_video(
if clip_feature is not None and dit.require_clip_embedding:
clip_embdding = dit.img_emb(clip_feature)
context = torch.cat([clip_embdding, context], dim=1)
# Add camera control
x, (f, h, w) = dit.patchify(x, control_camera_latents_input)
# Camera control
x = dit.patchify(x, control_camera_latents_input)
# Animate
x, motion_vec = animate_adapter.after_patch_embedding(x, pose_latents, face_pixel_values)
# Patchify
f, h, w = x.shape[2:]
x = rearrange(x, 'b c f h w -> b (f h w) c').contiguous()
# Reference image
if reference_latents is not None:
@@ -1283,6 +1416,7 @@ def model_fn_wan_video(
return custom_forward
for block_id, block in enumerate(dit.blocks):
# Block
if use_gradient_checkpointing_offload:
with torch.autograd.graph.save_on_cpu():
x = torch.utils.checkpoint.checkpoint(
@@ -1298,12 +1432,18 @@ def model_fn_wan_video(
)
else:
x = block(x, context, t_mod, freqs)
# VACE
if vace_context is not None and block_id in vace.vace_layers_mapping:
current_vace_hint = vace_hints[vace.vace_layers_mapping[block_id]]
if use_unified_sequence_parallel and dist.is_initialized() and dist.get_world_size() > 1:
current_vace_hint = torch.chunk(current_vace_hint, get_sequence_parallel_world_size(), dim=1)[get_sequence_parallel_rank()]
current_vace_hint = torch.nn.functional.pad(current_vace_hint, (0, 0, 0, chunks[0].shape[1] - current_vace_hint.shape[1]), value=0)
x = x + current_vace_hint * vace_scale
# Animate
if pose_latents is not None and face_pixel_values is not None:
x = animate_adapter.after_transformer_block(block_id, x, motion_vec)
if tea_cache is not None:
tea_cache.store(x)