mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-24 10:18:12 +00:00
@@ -72,6 +72,9 @@ class WanTrainingModule(DiffusionTrainingModule):
|
||||
inputs_shared[extra_input] = data[extra_input][0]
|
||||
else:
|
||||
inputs_shared[extra_input] = data[extra_input]
|
||||
if inputs_shared.get("framewise_decoding", False):
|
||||
# WanToDance global model
|
||||
inputs_shared["num_frames"] = 4 * (len(data["video"]) - 1) + 1
|
||||
return inputs_shared
|
||||
|
||||
def get_pipeline_inputs(self, data):
|
||||
@@ -117,6 +120,7 @@ def wan_parser():
|
||||
parser.add_argument("--max_timestep_boundary", type=float, default=1.0, help="Max timestep boundary (for mixed models, e.g., Wan-AI/Wan2.2-I2V-A14B).")
|
||||
parser.add_argument("--min_timestep_boundary", type=float, default=0.0, help="Min timestep boundary (for mixed models, e.g., Wan-AI/Wan2.2-I2V-A14B).")
|
||||
parser.add_argument("--initialize_model_on_cpu", default=False, action="store_true", help="Whether to initialize models on CPU.")
|
||||
parser.add_argument("--framewise_decoding", default=False, action="store_true", help="Enable it if this model is a WanToDance global model.")
|
||||
return parser
|
||||
|
||||
|
||||
@@ -140,12 +144,13 @@ if __name__ == "__main__":
|
||||
height_division_factor=16,
|
||||
width_division_factor=16,
|
||||
num_frames=args.num_frames,
|
||||
time_division_factor=4,
|
||||
time_division_remainder=1,
|
||||
time_division_factor=4 if not args.framewise_decoding else 1,
|
||||
time_division_remainder=1 if not args.framewise_decoding else 0,
|
||||
),
|
||||
special_operator_map={
|
||||
"animate_face_video": ToAbsolutePath(args.dataset_base_path) >> LoadVideo(args.num_frames, 4, 1, frame_processor=ImageCropAndResize(512, 512, None, 16, 16)),
|
||||
"input_audio": ToAbsolutePath(args.dataset_base_path) >> LoadAudio(sr=16000),
|
||||
"wantodance_music_path": ToAbsolutePath(args.dataset_base_path),
|
||||
}
|
||||
)
|
||||
model = WanTrainingModule(
|
||||
|
||||
Reference in New Issue
Block a user