mirror of https://github.com/modelscope/DiffSynth-Studio.git (synced 2026-03-18 22:08:13 +00:00)
support wan2.2-animate-14b
@@ -21,6 +21,7 @@ from ..models.wan_video_vae import WanVideoVAE, RMS_norm, CausalConv3d, Upsample
 from ..models.wan_video_image_encoder import WanImageEncoder
 from ..models.wan_video_vace import VaceWanModel
 from ..models.wan_video_motion_controller import WanMotionControllerModel
+from ..models.wan_video_animate_adapter import WanAnimateAdapter
 from ..schedulers.flow_match import FlowMatchScheduler
 from ..prompters import WanPrompter
 from ..vram_management import enable_vram_management, AutoWrappedModule, AutoWrappedLinear, WanAutoCastLayerNorm
@@ -45,8 +46,9 @@ class WanVideoPipeline(BasePipeline):
         self.motion_controller: WanMotionControllerModel = None
         self.vace: VaceWanModel = None
         self.vace2: VaceWanModel = None
-        self.in_iteration_models = ("dit", "motion_controller", "vace")
-        self.in_iteration_models_2 = ("dit2", "motion_controller", "vace2")
+        self.animate_adapter: WanAnimateAdapter = None
+        self.in_iteration_models = ("dit", "motion_controller", "vace", "animate_adapter")
+        self.in_iteration_models_2 = ("dit2", "motion_controller", "vace2", "animate_adapter")
         self.unit_runner = PipelineUnitRunner()
         self.units = [
             WanVideoUnit_ShapeChecker(),
@@ -62,6 +64,10 @@ class WanVideoPipeline(BasePipeline):
             WanVideoUnit_FunCameraControl(),
             WanVideoUnit_SpeedControl(),
             WanVideoUnit_VACE(),
+            WanVideoPostUnit_AnimateVideoSplit(),
+            WanVideoPostUnit_AnimatePoseLatents(),
+            WanVideoPostUnit_AnimateFacePixelValues(),
+            WanVideoPostUnit_AnimateInpaint(),
             WanVideoUnit_UnifiedSequenceParallel(),
             WanVideoUnit_TeaCache(),
             WanVideoUnit_CfgMerger(),
@@ -70,13 +76,34 @@ class WanVideoPipeline(BasePipeline):
             WanVideoPostUnit_S2V(),
         ]
         self.model_fn = model_fn_wan_video


-    def load_lora(self, module, path, alpha=1):
-        loader = GeneralLoRALoader(torch_dtype=self.torch_dtype, device=self.device)
-        lora = load_state_dict(path, torch_dtype=self.torch_dtype, device=self.device)
-        loader.load(module, lora, alpha=alpha)
+    def load_lora(
+        self,
+        module: torch.nn.Module,
+        lora_config: Union[ModelConfig, str] = None,
+        alpha=1,
+        hotload=False,
+        state_dict=None,
+    ):
+        if state_dict is None:
+            if isinstance(lora_config, str):
+                lora = load_state_dict(lora_config, torch_dtype=self.torch_dtype, device=self.device)
+            else:
+                lora_config.download_if_necessary()
+                lora = load_state_dict(lora_config.path, torch_dtype=self.torch_dtype, device=self.device)
+        else:
+            lora = state_dict
+        if hotload:
+            for name, module in module.named_modules():
+                if isinstance(module, AutoWrappedLinear):
+                    lora_a_name = f'{name}.lora_A.default.weight'
+                    lora_b_name = f'{name}.lora_B.default.weight'
+                    if lora_a_name in lora and lora_b_name in lora:
+                        module.lora_A_weights.append(lora[lora_a_name] * alpha)
+                        module.lora_B_weights.append(lora[lora_b_name])
+        else:
+            loader = GeneralLoRALoader(torch_dtype=self.torch_dtype, device=self.device)
+            loader.load(module, lora, alpha=alpha)

     def training_loss(self, **inputs):
         max_timestep_boundary = int(inputs.get("max_timestep_boundary", 1) * self.scheduler.num_train_timesteps)
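A usage sketch of the extended loader (the pipeline setup and checkpoint path are hypothetical; `hotload=True` only takes effect on modules whose linear layers are wrapped as `AutoWrappedLinear` by VRAM management, since those are what collect the `lora_A_weights`/`lora_B_weights` applied at forward time):

    pipe = WanVideoPipeline.from_pretrained(...)  # hypothetical setup

    # Merge the LoRA into the module weights (previous behavior):
    pipe.load_lora(pipe.dit, lora_config="models/lora/example_lora.safetensors", alpha=1.0)

    # Hot-load without merging, e.g. to stack or swap LoRAs cheaply:
    pipe.load_lora(pipe.dit, lora_config="models/lora/example_lora.safetensors", alpha=1.0, hotload=True)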
@@ -359,12 +386,13 @@ class WanVideoPipeline(BasePipeline):
         pipe.vae = model_manager.fetch_model("wan_video_vae")
         pipe.image_encoder = model_manager.fetch_model("wan_video_image_encoder")
         pipe.motion_controller = model_manager.fetch_model("wan_video_motion_controller")
-        pipe.vace = model_manager.fetch_model("wan_video_vace")
         vace = model_manager.fetch_model("wan_video_vace", index=2)
         if isinstance(vace, list):
             pipe.vace, pipe.vace2 = vace
         else:
             pipe.vace = vace
         pipe.audio_encoder = model_manager.fetch_model("wans2v_audio_encoder")
+        pipe.animate_adapter = model_manager.fetch_model("wan_video_animate_adapter")
+
         # Size division factor
         if pipe.vae is not None:
@@ -417,6 +445,11 @@ class WanVideoPipeline(BasePipeline):
         vace_video_mask: Optional[Image.Image] = None,
         vace_reference_image: Optional[Image.Image] = None,
         vace_scale: Optional[float] = 1.0,
+        # Animate
+        animate_pose_video: Optional[list[Image.Image]] = None,
+        animate_face_video: Optional[list[Image.Image]] = None,
+        animate_inpaint_video: Optional[list[Image.Image]] = None,
+        animate_mask_video: Optional[list[Image.Image]] = None,
         # Randomness
         seed: Optional[int] = None,
         rand_device: Optional[str] = "cpu",
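A hedged call sketch for the new inputs (variable names and frame counts are hypothetical; each `animate_*` argument is a plain list of PIL images, matching the signature above, and is trimmed to `len(input_video) - 4` frames by `WanVideoPostUnit_AnimateVideoSplit`):

    video = pipe(
        prompt="a person dancing",
        input_video=input_frames,                  # e.g. 81 frames
        animate_pose_video=pose_frames,            # pose-skeleton renders
        animate_face_video=face_frames,            # face crops driving expression
        animate_inpaint_video=background_frames,   # background to preserve
        animate_mask_video=mask_frames,            # white marks the region to regenerate
        seed=0,
    )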
@@ -474,6 +507,7 @@ class WanVideoPipeline(BasePipeline):
             "tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride,
             "sliding_window_size": sliding_window_size, "sliding_window_stride": sliding_window_stride,
             "input_audio": input_audio, "audio_sample_rate": audio_sample_rate, "s2v_pose_video": s2v_pose_video, "audio_embeds": audio_embeds, "s2v_pose_latents": s2v_pose_latents, "motion_video": motion_video,
+            "animate_pose_video": animate_pose_video, "animate_face_video": animate_face_video, "animate_inpaint_video": animate_inpaint_video, "animate_mask_video": animate_mask_video,
         }
         for unit in self.units:
             inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega)
@@ -508,7 +542,7 @@ class WanVideoPipeline(BasePipeline):
             inputs_shared["latents"][:, :, 0:1] = inputs_shared["first_frame_latents"]

         # VACE (TODO: remove it)
-        if vace_reference_image is not None:
+        if vace_reference_image is not None or (animate_pose_video is not None and animate_face_video is not None):
             inputs_shared["latents"] = inputs_shared["latents"][:, :, 1:]
         # post-denoising, pre-decoding processing logic
         for unit in self.post_units:
@@ -1021,6 +1055,95 @@ class WanVideoPostUnit_S2V(PipelineUnit):
         return {"latents": latents}


+class WanVideoPostUnit_AnimateVideoSplit(PipelineUnit):
+    def __init__(self):
+        super().__init__(input_params=("input_video", "animate_pose_video", "animate_face_video", "animate_inpaint_video", "animate_mask_video"))
+
+    def process(self, pipe: WanVideoPipeline, input_video, animate_pose_video, animate_face_video, animate_inpaint_video, animate_mask_video):
+        if input_video is None:
+            return {}
+        if animate_pose_video is not None:
+            animate_pose_video = animate_pose_video[:len(input_video) - 4]
+        if animate_face_video is not None:
+            animate_face_video = animate_face_video[:len(input_video) - 4]
+        if animate_inpaint_video is not None:
+            animate_inpaint_video = animate_inpaint_video[:len(input_video) - 4]
+        if animate_mask_video is not None:
+            animate_mask_video = animate_mask_video[:len(input_video) - 4]
+        return {"animate_pose_video": animate_pose_video, "animate_face_video": animate_face_video, "animate_inpaint_video": animate_inpaint_video, "animate_mask_video": animate_mask_video}
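To make the split unit's contract concrete, a minimal sketch (frame counts and sizes are hypothetical): every provided conditioning stream is cut to the same length, so downstream units see aligned inputs.

    from PIL import Image

    input_video = [Image.new("RGB", (832, 480))] * 81
    pose_frames = [Image.new("RGB", (832, 480))] * 100
    pose_frames = pose_frames[:len(input_video) - 4]
    assert len(pose_frames) == 77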
+
+
+class WanVideoPostUnit_AnimatePoseLatents(PipelineUnit):
+    def __init__(self):
+        super().__init__(
+            input_params=("animate_pose_video", "tiled", "tile_size", "tile_stride"),
+            onload_model_names=("vae",)
+        )
+
+    def process(self, pipe: WanVideoPipeline, animate_pose_video, tiled, tile_size, tile_stride):
+        if animate_pose_video is None:
+            return {}
+        pipe.load_models_to_device(self.onload_model_names)
+        animate_pose_video = pipe.preprocess_video(animate_pose_video)
+        pose_latents = pipe.vae.encode(animate_pose_video, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device)
+        return {"pose_latents": pose_latents}
+
+
+class WanVideoPostUnit_AnimateFacePixelValues(PipelineUnit):
+    def __init__(self):
+        super().__init__(take_over=True)
+
+    def process(self, pipe: WanVideoPipeline, inputs_shared, inputs_posi, inputs_nega):
+        if inputs_shared.get("animate_face_video", None) is None:
+            return inputs_shared, inputs_posi, inputs_nega
+        inputs_posi["face_pixel_values"] = pipe.preprocess_video(inputs_shared["animate_face_video"])
+        inputs_nega["face_pixel_values"] = torch.zeros_like(inputs_posi["face_pixel_values"]) - 1
+        return inputs_shared, inputs_posi, inputs_nega
+
+
+class WanVideoPostUnit_AnimateInpaint(PipelineUnit):
+    def __init__(self):
+        super().__init__(
+            input_params=("animate_inpaint_video", "animate_mask_video", "input_image", "tiled", "tile_size", "tile_stride"),
+            onload_model_names=("vae",)
+        )
+
+    def get_i2v_mask(self, lat_t, lat_h, lat_w, mask_len=1, mask_pixel_values=None, device="cuda"):
+        if mask_pixel_values is None:
+            msk = torch.zeros(1, (lat_t - 1) * 4 + 1, lat_h, lat_w, device=device)
+        else:
+            msk = mask_pixel_values.clone()
+        msk[:, :mask_len] = 1
+        msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1)
+        msk = msk.view(1, msk.shape[1] // 4, 4, lat_h, lat_w)
+        msk = msk.transpose(1, 2)[0]
+        return msk
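A shape walk-through of `get_i2v_mask` under assumed sizes (`lat_t=21`, `lat_h=60`, `lat_w=104`, i.e. an 81-frame 480x832 video through Wan's 4x-temporal / 8x-spatial VAE, where the first latent frame covers one pixel frame and every later one covers four):

    import torch

    msk = torch.zeros(1, (21 - 1) * 4 + 1, 60, 104)   # (1, 81, 60, 104): one entry per pixel frame
    msk[:, :1] = 1                                    # mark the first frame as known
    msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1)
    msk = msk.view(1, msk.shape[1] // 4, 4, 60, 104).transpose(1, 2)[0]
    print(msk.shape)                                  # torch.Size([4, 21, 60, 104])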
+    def process(self, pipe: WanVideoPipeline, animate_inpaint_video, animate_mask_video, input_image, tiled, tile_size, tile_stride):
+        if animate_inpaint_video is None or animate_mask_video is None:
+            return {}
+        pipe.load_models_to_device(self.onload_model_names)
+
+        bg_pixel_values = pipe.preprocess_video(animate_inpaint_video)
+        y_reft = pipe.vae.encode(bg_pixel_values, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)[0].to(dtype=pipe.torch_dtype, device=pipe.device)
+        _, lat_t, lat_h, lat_w = y_reft.shape
+
+        ref_pixel_values = pipe.preprocess_video([input_image])
+        ref_latents = pipe.vae.encode(ref_pixel_values, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device)
+        mask_ref = self.get_i2v_mask(1, lat_h, lat_w, 1, device=pipe.device)
+        y_ref = torch.concat([mask_ref, ref_latents[0]]).to(dtype=torch.bfloat16, device=pipe.device)
+
+        mask_pixel_values = 1 - pipe.preprocess_video(animate_mask_video, max_value=1, min_value=0)
+        mask_pixel_values = rearrange(mask_pixel_values, "b c t h w -> (b t) c h w")
+        mask_pixel_values = torch.nn.functional.interpolate(mask_pixel_values, size=(lat_h, lat_w), mode='nearest')
+        mask_pixel_values = rearrange(mask_pixel_values, "(b t) c h w -> b t c h w", b=1)[:, :, 0]
+        msk_reft = self.get_i2v_mask(lat_t, lat_h, lat_w, 0, mask_pixel_values=mask_pixel_values, device=pipe.device)
+
+        y_reft = torch.concat([msk_reft, y_reft]).to(dtype=torch.bfloat16, device=pipe.device)
+        y = torch.concat([y_ref, y_reft], dim=1).unsqueeze(0)
+        return {"y": y}
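Shape bookkeeping for the returned conditioning, assuming Wan's 16-channel video latents:

    # mask_ref:  (4, 1, lat_h, lat_w)        y_ref:  (20, 1, lat_h, lat_w)
    # msk_reft:  (4, lat_t, lat_h, lat_w)    y_reft: (20, lat_t, lat_h, lat_w)
    # y = concat([y_ref, y_reft], dim=1).unsqueeze(0) -> (1, 20, lat_t + 1, lat_h, lat_w)

so the encoded reference image rides as an extra leading latent frame in front of the masked background latents.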
+

 class TeaCache:
     def __init__(self, num_inference_steps, rel_l1_thresh, model_id):
         self.num_inference_steps = num_inference_steps
@@ -1131,6 +1254,7 @@ def model_fn_wan_video(
     dit: WanModel,
     motion_controller: WanMotionControllerModel = None,
     vace: VaceWanModel = None,
+    animate_adapter: WanAnimateAdapter = None,
     latents: torch.Tensor = None,
     timestep: torch.Tensor = None,
     context: torch.Tensor = None,
@@ -1146,6 +1270,8 @@ def model_fn_wan_video(
     tea_cache: TeaCache = None,
     use_unified_sequence_parallel: bool = False,
     motion_bucket_id: Optional[torch.Tensor] = None,
+    pose_latents=None,
+    face_pixel_values=None,
     sliding_window_size: Optional[int] = None,
     sliding_window_stride: Optional[int] = None,
     cfg_merge: bool = False,
@@ -1236,9 +1362,16 @@ def model_fn_wan_video(
     if clip_feature is not None and dit.require_clip_embedding:
         clip_embdding = dit.img_emb(clip_feature)
         context = torch.cat([clip_embdding, context], dim=1)

-    # Add camera control
-    x, (f, h, w) = dit.patchify(x, control_camera_latents_input)
+    # Camera control
+    x = dit.patchify(x, control_camera_latents_input)
+
+    # Animate
+    x, motion_vec = animate_adapter.after_patch_embedding(x, pose_latents, face_pixel_values)
+
+    # Patchify
+    f, h, w = x.shape[2:]
+    x = rearrange(x, 'b c f h w -> b (f h w) c').contiguous()

     # Reference image
     if reference_latents is not None:
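The reordering above is the structural point of this hunk: `dit.patchify` now returns the 5-D feature map rather than flattened tokens, so the animate adapter can inject pose latents and face features while the `(f, h, w)` grid is still intact; only afterwards does the `rearrange` flatten the grid into the token axis. Condensed, with assumed shapes:

    x = dit.patchify(x, control_camera_latents_input)           # (b, c, f, h, w)
    x, motion_vec = animate_adapter.after_patch_embedding(x, pose_latents, face_pixel_values)
    f, h, w = x.shape[2:]
    x = rearrange(x, 'b c f h w -> b (f h w) c').contiguous()   # (b, f*h*w, c) tokens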
@@ -1283,6 +1416,7 @@ def model_fn_wan_video(
             return custom_forward

     for block_id, block in enumerate(dit.blocks):
+        # Block
         if use_gradient_checkpointing_offload:
             with torch.autograd.graph.save_on_cpu():
                 x = torch.utils.checkpoint.checkpoint(
@@ -1298,12 +1432,18 @@ def model_fn_wan_video(
                 )
         else:
            x = block(x, context, t_mod, freqs)

+        # VACE
         if vace_context is not None and block_id in vace.vace_layers_mapping:
             current_vace_hint = vace_hints[vace.vace_layers_mapping[block_id]]
             if use_unified_sequence_parallel and dist.is_initialized() and dist.get_world_size() > 1:
                 current_vace_hint = torch.chunk(current_vace_hint, get_sequence_parallel_world_size(), dim=1)[get_sequence_parallel_rank()]
                 current_vace_hint = torch.nn.functional.pad(current_vace_hint, (0, 0, 0, chunks[0].shape[1] - current_vace_hint.shape[1]), value=0)
             x = x + current_vace_hint * vace_scale
+
+        # Animate
+        if pose_latents is not None and face_pixel_values is not None:
+            x = animate_adapter.after_transformer_block(block_id, x, motion_vec)
+
     if tea_cache is not None:
         tea_cache.store(x)
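Taken together, the per-block flow now interleaves three residual paths. A simplified sketch (checkpointing and sequence-parallel branches omitted; names as in the diff):

    for block_id, block in enumerate(dit.blocks):
        x = block(x, context, t_mod, freqs)
        if vace_context is not None and block_id in vace.vace_layers_mapping:
            x = x + vace_hints[vace.vace_layers_mapping[block_id]] * vace_scale
        if pose_latents is not None and face_pixel_values is not None:
            x = animate_adapter.after_transformer_block(block_id, x, motion_vec)
    if tea_cache is not None:
        tea_cache.store(x)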