Merge pull request #877 from mi804/wans2v_framepack

support s2v framepack
This commit is contained in:
Zhongjie Duan
2025-09-02 16:54:37 +08:00
committed by GitHub
8 changed files with 220 additions and 56 deletions

View File

@@ -205,7 +205,7 @@ save_video(video, "video1.mp4", fps=15, quality=5)
| Model ID | Extra Parameters | Inference | Full Training | Validate After Full Training | LoRA Training | Validate After LoRA Training |
|-|-|-|-|-|-|-|
|[Wan-AI/Wan2.2-S2V-14B](https://www.modelscope.cn/models/Wan-AI/Wan2.2-S2V-14B)|`input_image`, `input_audio`, `audio_sample_rate`, `s2v_pose_video`|[code](./examples/wanvideo/model_inference/Wan2.2-S2V-14B.py)|-|-|-|-|
|[Wan-AI/Wan2.2-S2V-14B](https://www.modelscope.cn/models/Wan-AI/Wan2.2-S2V-14B)|`input_image`, `input_audio`, `audio_sample_rate`, `s2v_pose_video`|[code](./examples/wanvideo/model_inference/Wan2.2-S2V-14B_multi_clips.py)|-|-|-|-|
|[Wan-AI/Wan2.2-I2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-I2V-A14B)|`input_image`|[code](./examples/wanvideo/model_inference/Wan2.2-I2V-A14B.py)|[code](./examples/wanvideo/model_training/full/Wan2.2-I2V-A14B.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.2-I2V-A14B.py)|[code](./examples/wanvideo/model_training/lora/Wan2.2-I2V-A14B.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.2-I2V-A14B.py)|
|[Wan-AI/Wan2.2-T2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-T2V-A14B)||[code](./examples/wanvideo/model_inference/Wan2.2-T2V-A14B.py)|[code](./examples/wanvideo/model_training/full/Wan2.2-T2V-A14B.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.2-T2V-A14B.py)|[code](./examples/wanvideo/model_training/lora/Wan2.2-T2V-A14B.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.2-T2V-A14B.py)|
|[Wan-AI/Wan2.2-TI2V-5B](https://modelscope.cn/models/Wan-AI/Wan2.2-TI2V-5B)|`input_image`|[code](./examples/wanvideo/model_inference/Wan2.2-TI2V-5B.py)|[code](./examples/wanvideo/model_training/full/Wan2.2-TI2V-5B.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.2-TI2V-5B.py)|[code](./examples/wanvideo/model_training/lora/Wan2.2-TI2V-5B.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.2-TI2V-5B.py)|
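For orientation, here is a minimal sketch of how the extra parameters listed for Wan2.2-S2V-14B are passed to the pipeline. It is an illustration only: the `model_configs` list is abbreviated, and the file paths are placeholders; see the linked example scripts for the full setup.

```python
import librosa
import torch
from PIL import Image
from diffsynth import VideoData, save_video
from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig

# pipeline loading as in the linked Wan2.2-S2V-14B example (model_configs abbreviated here)
pipe = WanVideoPipeline.from_pretrained(
    torch_dtype=torch.bfloat16, device="cuda",
    model_configs=[ModelConfig(model_id="Wan-AI/Wan2.2-S2V-14B", origin_file_pattern="diffusion_pytorch_model*.safetensors")],  # plus the T5 / VAE / wav2vec entries
    audio_processor_config=ModelConfig(model_id="Wan-AI/Wan2.2-S2V-14B", origin_file_pattern="wav2vec2-large-xlsr-53-english/"),
)

input_audio, sr = librosa.load("speech.wav", sr=16000)       # placeholder path; raw waveform at 16 kHz
video = pipe(
    prompt="a person is talking",
    input_image=Image.open("reference.png").convert("RGB"),  # placeholder path
    input_audio=input_audio,
    audio_sample_rate=sr,
    s2v_pose_video=VideoData("pose.mp4", height=448, width=832),  # optional pose control, placeholder path
    num_frames=81, height=448, width=832,
    num_inference_steps=40, seed=0,
)
save_video(video, "s2v.mp4", fps=16, quality=5)
```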

View File

@@ -205,7 +205,7 @@ save_video(video, "video1.mp4", fps=15, quality=5)
|Model ID|Extra Parameters|Inference|Full Training|Validate After Full Training|LoRA Training|Validate After LoRA Training|
|-|-|-|-|-|-|-|
|[Wan-AI/Wan2.2-S2V-14B](https://www.modelscope.cn/models/Wan-AI/Wan2.2-S2V-14B)|`input_image`, `input_audio`, `audio_sample_rate`, `s2v_pose_video`|[code](./examples/wanvideo/model_inference/Wan2.2-S2V-14B.py)|-|-|-|-|
|[Wan-AI/Wan2.2-S2V-14B](https://www.modelscope.cn/models/Wan-AI/Wan2.2-S2V-14B)|`input_image`, `input_audio`, `audio_sample_rate`, `s2v_pose_video`|[code](./examples/wanvideo/model_inference/Wan2.2-S2V-14B_multi_clips.py)|-|-|-|-|
|[Wan-AI/Wan2.2-I2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-I2V-A14B)|`input_image`|[code](./examples/wanvideo/model_inference/Wan2.2-I2V-A14B.py)|[code](./examples/wanvideo/model_training/full/Wan2.2-I2V-A14B.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.2-I2V-A14B.py)|[code](./examples/wanvideo/model_training/lora/Wan2.2-I2V-A14B.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.2-I2V-A14B.py)|
|[Wan-AI/Wan2.2-T2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-T2V-A14B)||[code](./examples/wanvideo/model_inference/Wan2.2-T2V-A14B.py)|[code](./examples/wanvideo/model_training/full/Wan2.2-T2V-A14B.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.2-T2V-A14B.py)|[code](./examples/wanvideo/model_training/lora/Wan2.2-T2V-A14B.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.2-T2V-A14B.py)|
|[Wan-AI/Wan2.2-TI2V-5B](https://modelscope.cn/models/Wan-AI/Wan2.2-TI2V-5B)|`input_image`|[code](./examples/wanvideo/model_inference/Wan2.2-TI2V-5B.py)|[code](./examples/wanvideo/model_training/full/Wan2.2-TI2V-5B.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.2-TI2V-5B.py)|[code](./examples/wanvideo/model_training/lora/Wan2.2-TI2V-5B.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.2-TI2V-5B.py)|

View File

@@ -410,7 +410,6 @@ class WanS2VModel(torch.nn.Module):
self.cond_encoder = nn.Conv3d(cond_dim, dim, kernel_size=patch_size, stride=patch_size)
self.casual_audio_encoder = CausalAudioEncoder(dim=audio_dim, out_dim=dim, num_token=num_audio_token, need_global=enable_adain)
all_modules, all_modules_names = torch_dfs(self.blocks, parent_name="root.transformer_blocks")
# TODO: refactor dfs
self.audio_injector = AudioInjector_WAN(
all_modules,
all_modules_names,
@@ -449,7 +448,6 @@ class WanS2VModel(torch.nn.Module):
def inject_motion(self, x, rope_embs, mask_input, motion_latents, drop_motion_frames=True, add_last_motion=2):
# inject the motion-frame tokens into the hidden states
# TODO: check drop_motion_frames = False
mot, mot_remb = self.process_motion_frame_pack(motion_latents, drop_motion_frames=drop_motion_frames, add_last_motion=add_last_motion)
if len(mot) > 0:
x = torch.cat([x, mot[0]], dim=1)
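As a rough sketch of the token-injection pattern used by `inject_motion` (an illustration only, assuming the motion tokens are appended along the sequence dimension and stripped again after the transformer blocks; the shapes below are made up):

```python
import torch

x = torch.randn(1, 1560, 5120)             # video tokens: (batch, seq_len, dim), illustrative sizes
motion_tokens = torch.randn(1, 380, 5120)  # tokens packed from the motion-frame latents
seq_len_x = x.shape[1]                     # remember the original sequence length
x = torch.cat([x, motion_tokens], dim=1)   # motion tokens join the sequence for attention
# ... transformer blocks attend over the combined sequence ...
x = x[:, :seq_len_x]                       # motion tokens are discarded before unpatchify
```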

View File

@@ -183,6 +183,13 @@ class WanS2VAudioEncoder(torch.nn.Module):
return batch_audio_eb, min_batch_num
def get_audio_feats_per_inference(self, input_audio, sample_rate, processor, fps=16, batch_frames=80, m=0, dtype=torch.float32, device='cpu'):
audio_feat = self.extract_audio_feat(input_audio, sample_rate, processor, return_all_layers=True, dtype=dtype, device=device)
audio_embed_bucket, min_batch_num = self.get_audio_embed_bucket_fps(audio_feat, fps=fps, batch_frames=batch_frames, m=m)
audio_embed_bucket = audio_embed_bucket.unsqueeze(0).permute(0, 2, 3, 1).to(device, dtype)
audio_embeds = [audio_embed_bucket[..., i * batch_frames:(i + 1) * batch_frames] for i in range(min_batch_num)]
return audio_embeds
@staticmethod
def state_dict_converter():
return WanS2VAudioEncoderStateDictConverter()
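A small sketch of the per-clip slicing that `get_audio_feats_per_inference` performs on the bucketed audio embedding (the tensor layout below is assumed for illustration, not the encoder's real shape):

```python
import torch

batch_frames, min_batch_num = 80, 3   # one 80-frame bucket of audio features per generated clip
audio_embed_bucket = torch.randn(1, 25, 1024, batch_frames * min_batch_num)  # (b, layers, dim, frames), assumed
audio_embeds = [
    audio_embed_bucket[..., i * batch_frames:(i + 1) * batch_frames]
    for i in range(min_batch_num)
]
assert all(chunk.shape[-1] == batch_frames for chunk in audio_embeds)
```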

View File

@@ -65,6 +65,9 @@ class WanVideoPipeline(BasePipeline):
WanVideoUnit_TeaCache(),
WanVideoUnit_CfgMerger(),
]
self.post_units = [
WanVideoPostUnit_S2V(),
]
self.model_fn = model_fn_wan_video
@@ -391,9 +394,12 @@ class WanVideoPipeline(BasePipeline):
input_video: Optional[list[Image.Image]] = None,
denoising_strength: Optional[float] = 1.0,
# Speech-to-video
input_audio: Optional[str] = None,
input_audio: Optional[np.ndarray] = None,
audio_embeds: Optional[torch.Tensor] = None,
audio_sample_rate: Optional[int] = 16000,
s2v_pose_video: Optional[list[Image.Image]] = None,
s2v_pose_latents: Optional[torch.Tensor] = None,
motion_video: Optional[list[Image.Image]] = None,
# ControlNet
control_video: Optional[list[Image.Image]] = None,
reference_image: Optional[Image.Image] = None,
@@ -462,7 +468,7 @@ class WanVideoPipeline(BasePipeline):
"motion_bucket_id": motion_bucket_id,
"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride,
"sliding_window_size": sliding_window_size, "sliding_window_stride": sliding_window_stride,
"input_audio": input_audio, "audio_sample_rate": audio_sample_rate, "s2v_pose_video": s2v_pose_video,
"input_audio": input_audio, "audio_sample_rate": audio_sample_rate, "s2v_pose_video": s2v_pose_video, "audio_embeds": audio_embeds, "s2v_pose_latents": s2v_pose_latents, "motion_video": motion_video,
}
for unit in self.units:
inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega)
@@ -498,7 +504,9 @@ class WanVideoPipeline(BasePipeline):
# VACE (TODO: remove it)
if vace_reference_image is not None:
inputs_shared["latents"] = inputs_shared["latents"][:, :, 1:]
# post-denoising, pre-decoding processing logic
for unit in self.post_units:
inputs_shared, _, _ = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega)
# Decode
self.load_models_to_device(['vae'])
video = self.vae.decode(inputs_shared["latents"], device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
@@ -909,59 +917,91 @@ class WanVideoUnit_S2V(PipelineUnit):
onload_model_names=("audio_encoder", "vae",)
)
def process_audio(self, pipe: WanVideoPipeline, input_audio, audio_sample_rate, num_frames):
if input_audio is None or pipe.audio_encoder is None or pipe.audio_processor is None:
return {}
def process_audio(self, pipe: WanVideoPipeline, input_audio, audio_sample_rate, num_frames, fps=16, audio_embeds=None, return_all=False):
if audio_embeds is not None:
return {"audio_embeds": audio_embeds}
pipe.load_models_to_device(["audio_encoder"])
z = pipe.audio_encoder.extract_audio_feat(input_audio, audio_sample_rate, pipe.audio_processor, return_all_layers=True, dtype=pipe.torch_dtype, device=pipe.device)
audio_embed_bucket, num_repeat = pipe.audio_encoder.get_audio_embed_bucket_fps(
z, fps=16, batch_frames=num_frames - 1, m=0
)
audio_embed_bucket = audio_embed_bucket.unsqueeze(0).to(pipe.device, pipe.torch_dtype)
if len(audio_embed_bucket.shape) == 3:
audio_embed_bucket = audio_embed_bucket.permute(0, 2, 1)
elif len(audio_embed_bucket.shape) == 4:
audio_embed_bucket = audio_embed_bucket.permute(0, 2, 3, 1)
audio_embed_bucket = audio_embed_bucket[..., 0:num_frames-1]
return {"audio_input": audio_embed_bucket}
audio_embeds = pipe.audio_encoder.get_audio_feats_per_inference(input_audio, audio_sample_rate, pipe.audio_processor, fps=fps, batch_frames=num_frames-1, dtype=pipe.torch_dtype, device=pipe.device)
if return_all:
return audio_embeds
else:
return {"audio_embeds": audio_embeds[0]}
def process_motion_latents(self, pipe: WanVideoPipeline, height, width, tiled, tile_size, tile_stride):
def process_motion_latents(self, pipe: WanVideoPipeline, height, width, tiled, tile_size, tile_stride, motion_video=None):
pipe.load_models_to_device(["vae"])
# TODO: may support motion latents as input, which relates to `drop_motion_frames = False`
motion_frames = 73
lat_motion_frames = (motion_frames + 3) // 4 # 19
motion_latents = torch.zeros([1, 3, motion_frames, height, width], dtype=pipe.torch_dtype, device=pipe.device)
kwargs = {}
if motion_video is not None and len(motion_video) > 0:
assert len(motion_video) == motion_frames, f"motion video must have {motion_frames} frames, but got {len(motion_video)}"
motion_latents = pipe.preprocess_video(motion_video)
kwargs["drop_motion_frames"] = False
else:
motion_latents = torch.zeros([1, 3, motion_frames, height, width], dtype=pipe.torch_dtype, device=pipe.device)
kwargs["drop_motion_frames"] = True
motion_latents = pipe.vae.encode(motion_latents, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device)
return {"motion_latents": motion_latents}
kwargs.update({"motion_latents": motion_latents})
return kwargs
def process_pose_cond(self, pipe: WanVideoPipeline, s2v_pose_video, num_frames, height, width, tiled, tile_size, tile_stride):
def process_pose_cond(self, pipe: WanVideoPipeline, s2v_pose_video, num_frames, height, width, tiled, tile_size, tile_stride, s2v_pose_latents=None, num_repeats=1, return_all=False):
if s2v_pose_latents is not None:
return {"s2v_pose_latents": s2v_pose_latents}
if s2v_pose_video is None:
return {"pose_cond": None}
return {"s2v_pose_latents": None}
pipe.load_models_to_device(["vae"])
input_video = pipe.preprocess_video(s2v_pose_video)
# get num_frames-1 frames
input_video = input_video[:, :, :num_frames]
infer_frames = num_frames - 1
input_video = pipe.preprocess_video(s2v_pose_video)[:, :, :infer_frames * num_repeats]
# pad if not enough frames
padding_frames = num_frames - input_video.shape[2]
padding_frames = infer_frames * num_repeats - input_video.shape[2]
input_video = torch.cat([input_video, -torch.ones(1, 3, padding_frames, height, width, device=input_video.device, dtype=input_video.dtype)], dim=2)
# encode to latents
input_latents = pipe.vae.encode(input_video, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device)
return {"pose_cond": input_latents[:,:,1:]}
input_videos = input_video.chunk(num_repeats, dim=2)
pose_conds = []
for r in range(num_repeats):
cond = input_videos[r]
cond = torch.cat([cond[:, :, 0:1].repeat(1, 1, 1, 1, 1), cond], dim=2)
cond_latents = pipe.vae.encode(cond, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device)
pose_conds.append(cond_latents[:,:,1:])
if return_all:
return pose_conds
else:
return {"s2v_pose_latents": pose_conds[0]}
def process(self, pipe: WanVideoPipeline, inputs_shared, inputs_posi, inputs_nega):
if inputs_shared.get("input_audio") is None or pipe.audio_encoder is None or pipe.audio_processor is None:
if (inputs_shared.get("input_audio") is None and inputs_shared.get("audio_embeds") is None) or pipe.audio_encoder is None or pipe.audio_processor is None:
return inputs_shared, inputs_posi, inputs_nega
input_audio, audio_sample_rate, s2v_pose_video, num_frames, height, width = inputs_shared.get("input_audio"), inputs_shared.get("audio_sample_rate"), inputs_shared.get("s2v_pose_video"), inputs_shared.get("num_frames"), inputs_shared.get("height"), inputs_shared.get("width")
tiled, tile_size, tile_stride = inputs_shared.get("tiled"), inputs_shared.get("tile_size"), inputs_shared.get("tile_stride")
num_frames, height, width, tiled, tile_size, tile_stride = inputs_shared.get("num_frames"), inputs_shared.get("height"), inputs_shared.get("width"), inputs_shared.get("tiled"), inputs_shared.get("tile_size"), inputs_shared.get("tile_stride")
input_audio, audio_embeds, audio_sample_rate = inputs_shared.pop("input_audio"), inputs_shared.pop("audio_embeds"), inputs_shared.get("audio_sample_rate")
s2v_pose_video, s2v_pose_latents, motion_video = inputs_shared.pop("s2v_pose_video"), inputs_shared.pop("s2v_pose_latents"), inputs_shared.pop("motion_video")
audio_input_positive = self.process_audio(pipe, input_audio, audio_sample_rate, num_frames)
audio_input_positive = self.process_audio(pipe, input_audio, audio_sample_rate, num_frames, audio_embeds=audio_embeds)
inputs_posi.update(audio_input_positive)
inputs_nega.update({"audio_input": 0.0 * audio_input_positive["audio_input"]})
inputs_nega.update({"audio_embeds": 0.0 * audio_input_positive["audio_embeds"]})
inputs_shared.update(self.process_motion_latents(pipe, height, width, tiled, tile_size, tile_stride))
inputs_shared.update(self.process_pose_cond(pipe, s2v_pose_video, num_frames, height, width, tiled, tile_size, tile_stride))
inputs_shared.update(self.process_motion_latents(pipe, height, width, tiled, tile_size, tile_stride, motion_video))
inputs_shared.update(self.process_pose_cond(pipe, s2v_pose_video, num_frames, height, width, tiled, tile_size, tile_stride, s2v_pose_latents=s2v_pose_latents))
return inputs_shared, inputs_posi, inputs_nega
@staticmethod
def pre_calculate_audio_pose(pipe: WanVideoPipeline, input_audio=None, audio_sample_rate=16000, s2v_pose_video=None, num_frames=81, height=448, width=832, fps=16, tiled=True, tile_size=(30, 52), tile_stride=(15, 26)):
assert pipe.audio_encoder is not None and pipe.audio_processor is not None, "Please load audio encoder and audio processor first."
shapes = WanVideoUnit_ShapeChecker().process(pipe, height, width, num_frames)
height, width, num_frames = shapes["height"], shapes["width"], shapes["num_frames"]
unit = WanVideoUnit_S2V()
audio_embeds = unit.process_audio(pipe, input_audio, audio_sample_rate, num_frames, fps, return_all=True)
pose_latents = unit.process_pose_cond(pipe, s2v_pose_video, num_frames, height, width, num_repeats=len(audio_embeds), return_all=True, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
pose_latents = None if s2v_pose_video is None else pose_latents
return audio_embeds, pose_latents, len(audio_embeds)
class WanVideoPostUnit_S2V(PipelineUnit):
def __init__(self):
super().__init__(input_params=("latents", "motion_latents", "drop_motion_frames"))
def process(self, pipe: WanVideoPipeline, latents, motion_latents, drop_motion_frames):
if pipe.audio_encoder is None or motion_latents is None or drop_motion_frames:
return {}
latents = torch.cat([motion_latents, latents[:,:,1:]], dim=2)
return {"latents": latents}
class TeaCache:
def __init__(self, num_inference_steps, rel_l1_thresh, model_id):
@@ -1081,9 +1121,10 @@ def model_fn_wan_video(
reference_latents = None,
vace_context = None,
vace_scale = 1.0,
audio_input: Optional[torch.Tensor] = None,
audio_embeds: Optional[torch.Tensor] = None,
motion_latents: Optional[torch.Tensor] = None,
pose_cond: Optional[torch.Tensor] = None,
s2v_pose_latents: Optional[torch.Tensor] = None,
drop_motion_frames: bool = True,
tea_cache: TeaCache = None,
use_unified_sequence_parallel: bool = False,
motion_bucket_id: Optional[torch.Tensor] = None,
@@ -1122,15 +1163,16 @@ def model_fn_wan_video(
batch_size=2 if cfg_merge else 1
)
# wan2.2 s2v
if audio_input is not None:
if audio_embeds is not None:
return model_fn_wans2v(
dit=dit,
latents=latents,
timestep=timestep,
context=context,
audio_input=audio_input,
audio_embeds=audio_embeds,
motion_latents=motion_latents,
pose_cond=pose_cond,
s2v_pose_latents=s2v_pose_latents,
drop_motion_frames=drop_motion_frames,
use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
use_gradient_checkpointing=use_gradient_checkpointing,
use_unified_sequence_parallel=use_unified_sequence_parallel,
@@ -1261,9 +1303,10 @@ def model_fn_wans2v(
latents,
timestep,
context,
audio_input,
audio_embeds,
motion_latents,
pose_cond,
s2v_pose_latents,
drop_motion_frames=True,
use_gradient_checkpointing_offload=False,
use_gradient_checkpointing=False,
use_unified_sequence_parallel=False,
@@ -1280,11 +1323,11 @@ def model_fn_wans2v(
context = dit.text_embedding(context)
# audio encode
audio_emb_global, merged_audio_emb = dit.cal_audio_emb(audio_input)
audio_emb_global, merged_audio_emb = dit.cal_audio_emb(audio_embeds)
# x and pose_cond
pose_cond = torch.zeros_like(x) if pose_cond is None else pose_cond
x, (f, h, w) = dit.patchify(dit.patch_embedding(x) + dit.cond_encoder(pose_cond))
# x and s2v_pose_latents
s2v_pose_latents = torch.zeros_like(x) if s2v_pose_latents is None else s2v_pose_latents
x, (f, h, w) = dit.patchify(dit.patch_embedding(x) + dit.cond_encoder(s2v_pose_latents))
seq_len_x = seq_len_x_global = x.shape[1] # global used for unified sequence parallel
# reference image
@@ -1296,7 +1339,7 @@ def model_fn_wans2v(
# freqs
pre_compute_freqs = rope_precompute(x.detach().view(1, x.size(1), dit.num_heads, dit.dim // dit.num_heads), grid_sizes, dit.freqs, start=None)
# motion
x, pre_compute_freqs, mask = dit.inject_motion(x, pre_compute_freqs, mask, motion_latents, add_last_motion=2)
x, pre_compute_freqs, mask = dit.inject_motion(x, pre_compute_freqs, mask, motion_latents, drop_motion_frames=drop_motion_frames, add_last_motion=2)
x = x + dit.trainable_cond_mask(mask).to(x.dtype)
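To make the pose-conditioning arithmetic in `process_pose_cond` concrete, here is a toy illustration of the padding and per-clip chunking; the frame counts follow the code above, while the spatial size is shrunk for brevity.

```python
import torch

infer_frames, num_repeats = 80, 2
height, width = 8, 8                                   # tiny spatial size, for illustration only
pose = torch.rand(1, 3, 150, height, width)            # a pose video 10 frames short of 2 clips
padding_frames = infer_frames * num_repeats - pose.shape[2]
pose = torch.cat([pose, -torch.ones(1, 3, padding_frames, height, width)], dim=2)  # pad with -1
clips = pose.chunk(num_repeats, dim=2)                 # one 80-frame chunk per clip
# each chunk is then prepended with a copy of its first frame, VAE-encoded,
# and the leading latent frame is dropped before being used as s2v_pose_latents
print(padding_frames, [c.shape[2] for c in clips])     # 10 [80, 80]
```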

View File

@@ -48,7 +48,7 @@ save_video(video, "video1.mp4", fps=15, quality=5)
| Model ID | Extra Parameters | Inference | Full Training | Full Training Validation | LoRA Training | LoRA Training Validation |
|-|-|-|-|-|-|-|
|[Wan-AI/Wan2.2-S2V-14B](https://www.modelscope.cn/models/Wan-AI/Wan2.2-S2V-14B)|`input_image`, `input_audio`, `audio_sample_rate`, `s2v_pose_video`|[code](./model_inference/Wan2.2-S2V-14B.py)|-|-|-|-|
|[Wan-AI/Wan2.2-S2V-14B](https://www.modelscope.cn/models/Wan-AI/Wan2.2-S2V-14B)|`input_image`, `input_audio`, `audio_sample_rate`, `s2v_pose_video`|[code](./model_inference/Wan2.2-S2V-14B_multi_clips.py)|-|-|-|-|
|[Wan-AI/Wan2.2-I2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-I2V-A14B)|`input_image`|[code](./model_inference/Wan2.2-I2V-A14B.py)|[code](./model_training/full/Wan2.2-I2V-A14B.sh)|[code](./model_training/validate_full/Wan2.2-I2V-A14B.py)|[code](./model_training/lora/Wan2.2-I2V-A14B.sh)|[code](./model_training/validate_lora/Wan2.2-I2V-A14B.py)|
|[Wan-AI/Wan2.2-T2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-T2V-A14B)||[code](./model_inference/Wan2.2-T2V-A14B.py)|[code](./model_training/full/Wan2.2-T2V-A14B.sh)|[code](./model_training/validate_full/Wan2.2-T2V-A14B.py)|[code](./model_training/lora/Wan2.2-T2V-A14B.sh)|[code](./model_training/validate_lora/Wan2.2-T2V-A14B.py)|
|[Wan-AI/Wan2.2-TI2V-5B](https://modelscope.cn/models/Wan-AI/Wan2.2-TI2V-5B)|`input_image`|[code](./model_inference/Wan2.2-TI2V-5B.py)|[code](./model_training/full/Wan2.2-TI2V-5B.sh)|[code](./model_training/validate_full/Wan2.2-TI2V-5B.py)|[code](./model_training/lora/Wan2.2-TI2V-5B.sh)|[code](./model_training/validate_lora/Wan2.2-TI2V-5B.py)|

View File

@@ -48,7 +48,7 @@ save_video(video, "video1.mp4", fps=15, quality=5)
|Model ID|Extra Parameters|Inference|Full Training|Full Training Validation|LoRA Training|LoRA Training Validation|
|-|-|-|-|-|-|-|
|[Wan-AI/Wan2.2-S2V-14B](https://www.modelscope.cn/models/Wan-AI/Wan2.2-S2V-14B)|`input_image`, `input_audio`, `audio_sample_rate`, `s2v_pose_video`|[code](./model_inference/Wan2.2-S2V-14B.py)|-|-|-|-|
|[Wan-AI/Wan2.2-S2V-14B](https://www.modelscope.cn/models/Wan-AI/Wan2.2-S2V-14B)|`input_image`, `input_audio`, `audio_sample_rate`, `s2v_pose_video`|[code](./model_inference/Wan2.2-S2V-14B_multi_clips.py)|-|-|-|-|
|[Wan-AI/Wan2.2-I2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-I2V-A14B)|`input_image`|[code](./model_inference/Wan2.2-I2V-A14B.py)|[code](./model_training/full/Wan2.2-I2V-A14B.sh)|[code](./model_training/validate_full/Wan2.2-I2V-A14B.py)|[code](./model_training/lora/Wan2.2-I2V-A14B.sh)|[code](./model_training/validate_lora/Wan2.2-I2V-A14B.py)|
|[Wan-AI/Wan2.2-T2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-T2V-A14B)||[code](./model_inference/Wan2.2-T2V-A14B.py)|[code](./model_training/full/Wan2.2-T2V-A14B.sh)|[code](./model_training/validate_full/Wan2.2-T2V-A14B.py)|[code](./model_training/lora/Wan2.2-T2V-A14B.sh)|[code](./model_training/validate_lora/Wan2.2-T2V-A14B.py)|
|[Wan-AI/Wan2.2-TI2V-5B](https://modelscope.cn/models/Wan-AI/Wan2.2-TI2V-5B)|`input_image`|[code](./model_inference/Wan2.2-TI2V-5B.py)|[code](./model_training/full/Wan2.2-TI2V-5B.sh)|[code](./model_training/validate_full/Wan2.2-TI2V-5B.py)|[code](./model_training/lora/Wan2.2-TI2V-5B.sh)|[code](./model_training/validate_lora/Wan2.2-TI2V-5B.py)|

View File

@@ -0,0 +1,116 @@
import torch
from PIL import Image
import librosa
from diffsynth import VideoData, save_video_with_audio
from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig, WanVideoUnit_S2V
from modelscope import dataset_snapshot_download
def speech_to_video(
prompt,
input_image,
audio_path,
negative_prompt="",
num_clip=None,
audio_sample_rate=16000,
pose_video_path=None,
infer_frames=80,
height=448,
width=832,
num_inference_steps=40,
fps=16, # an fps of 16 is recommended for S2V
motion_frames=73, # hyperparameter of Wan2.2-S2V
save_path=None,
):
# S2V audio input; a 16 kHz sampling rate is recommended
input_audio, sample_rate = librosa.load(audio_path, sr=audio_sample_rate)
# S2V uses the first num_frames frames of the pose video as reference; its height and width must match input_image, and its fps should be 16, the same as the output video fps.
pose_video = VideoData(pose_video_path, height=height, width=width) if pose_video_path is not None else None
audio_embeds, pose_latents, num_repeat = WanVideoUnit_S2V.pre_calculate_audio_pose(
pipe=pipe,
input_audio=input_audio,
audio_sample_rate=sample_rate,
s2v_pose_video=pose_video,
num_frames=infer_frames + 1,
height=height,
width=width,
fps=fps,
)
num_repeat = min(num_repeat, num_clip) if num_clip is not None else num_repeat
print(f"Generating {num_repeat} video clips...")
motion_videos = []
video = []
for r in range(num_repeat):
s2v_pose_latents = pose_latents[r] if pose_latents is not None else None
current_clip = pipe(
prompt=prompt,
input_image=input_image,
negative_prompt=negative_prompt,
seed=0,
num_frames=infer_frames + 1,
height=height,
width=width,
audio_embeds=audio_embeds[r],
s2v_pose_latents=s2v_pose_latents,
motion_video=motion_videos,
num_inference_steps=num_inference_steps,
)
current_clip = current_clip[-infer_frames:]
if r == 0:
current_clip = current_clip[3:]
# keep the last `motion_frames` frames as the motion reference (framepack) for the next clip
overlap_frames_num = min(motion_frames, len(current_clip))
motion_videos = motion_videos[overlap_frames_num:] + current_clip[-overlap_frames_num:]
video.extend(current_clip)
save_video_with_audio(video, save_path, audio_path, fps=16, quality=5)
print(f"Processed clip {r + 1} of {num_repeat}.")
return video
pipe = WanVideoPipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="Wan-AI/Wan2.2-S2V-14B", origin_file_pattern="diffusion_pytorch_model*.safetensors"),
ModelConfig(model_id="Wan-AI/Wan2.2-S2V-14B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth"),
ModelConfig(model_id="Wan-AI/Wan2.2-S2V-14B", origin_file_pattern="wav2vec2-large-xlsr-53-english/model.safetensors"),
ModelConfig(model_id="Wan-AI/Wan2.2-S2V-14B", origin_file_pattern="Wan2.1_VAE.pth"),
],
audio_processor_config=ModelConfig(model_id="Wan-AI/Wan2.2-S2V-14B", origin_file_pattern="wav2vec2-large-xlsr-53-english/"),
)
dataset_snapshot_download(
dataset_id="DiffSynth-Studio/example_video_dataset",
local_dir="./data/example_video_dataset",
allow_file_pattern=f"wans2v/*",
)
infer_frames = 80 # must be a multiple of 4 (4n)
height = 448
width = 832
prompt = "a person is singing"
negative_prompt = "画面模糊,最差质量,画面模糊,细节模糊不清,情绪激动剧烈,手快速抖动,字幕,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走"
input_image = Image.open("data/example_video_dataset/wans2v/pose.png").convert("RGB").resize((width, height))
video_with_audio = speech_to_video(
prompt=prompt,
input_image=input_image,
audio_path='data/example_video_dataset/wans2v/sing.MP3',
negative_prompt=negative_prompt,
pose_video_path='data/example_video_dataset/wans2v/pose.mp4',
save_path="video_with_audio_full.mp4",
infer_frames=infer_frames,
height=height,
width=width,
)
# num_clip=n generates only the first n clips, i.e. n * infer_frames frames in total.
video_with_audio_pose = speech_to_video(
prompt=prompt,
input_image=input_image,
audio_path='data/example_video_dataset/wans2v/sing.MP3',
negative_prompt=negative_prompt,
pose_video_path='data/example_video_dataset/wans2v/pose.mp4',
save_path="video_with_audio_pose_clip_2.mp4",
num_clip=2
)