Support WanToDance (#1361)

* support wantodance

* update docs

* bugfix
This commit is contained in:
Zhongjie Duan
2026-03-20 16:40:35 +08:00
committed by GitHub
parent ba0626e38f
commit 52ba5d414e
22 changed files with 1210 additions and 13 deletions

View File

@@ -72,6 +72,9 @@ class WanTrainingModule(DiffusionTrainingModule):
inputs_shared[extra_input] = data[extra_input][0]
else:
inputs_shared[extra_input] = data[extra_input]
+if inputs_shared.get("framewise_decoding", False):
+# WanToDance global model
+inputs_shared["num_frames"] = 4 * (len(data["video"]) - 1) + 1
return inputs_shared
def get_pipeline_inputs(self, data):
@@ -117,6 +120,7 @@ def wan_parser():
parser.add_argument("--max_timestep_boundary", type=float, default=1.0, help="Max timestep boundary (for mixed models, e.g., Wan-AI/Wan2.2-I2V-A14B).")
parser.add_argument("--min_timestep_boundary", type=float, default=0.0, help="Min timestep boundary (for mixed models, e.g., Wan-AI/Wan2.2-I2V-A14B).")
parser.add_argument("--initialize_model_on_cpu", default=False, action="store_true", help="Whether to initialize models on CPU.")
+parser.add_argument("--framewise_decoding", default=False, action="store_true", help="Enable it if this model is a WanToDance global model.")
return parser
@@ -140,12 +144,13 @@ if __name__ == "__main__":
height_division_factor=16,
width_division_factor=16,
num_frames=args.num_frames,
-time_division_factor=4,
-time_division_remainder=1,
+time_division_factor=4 if not args.framewise_decoding else 1,
+time_division_remainder=1 if not args.framewise_decoding else 0,
),
special_operator_map={
"animate_face_video": ToAbsolutePath(args.dataset_base_path) >> LoadVideo(args.num_frames, 4, 1, frame_processor=ImageCropAndResize(512, 512, None, 16, 16)),
"input_audio": ToAbsolutePath(args.dataset_base_path) >> LoadAudio(sr=16000),
+"wantodance_music_path": ToAbsolutePath(args.dataset_base_path),
}
)
model = WanTrainingModule(