wan-refactor

This commit is contained in:
Artiprocher
2025-06-13 13:04:35 +08:00
parent 7e6a3c7897
commit 8dd24169cc
47 changed files with 810 additions and 310 deletions

View File

@@ -13,14 +13,7 @@ class WanTrainingModule(DiffusionTrainingModule):
lora_base_model=None, lora_target_modules="q,k,v,o,ffn.0,ffn.2", lora_rank=32,
use_gradient_checkpointing=True,
use_gradient_checkpointing_offload=False,
# Extra inputs
input_contains_input_image=False,
input_contains_end_image=False,
input_contains_control_video=False,
input_contains_reference_image=False,
input_contains_vace_video=False,
input_contains_vace_reference_image=False,
input_contains_motion_bucket_id=False,
extra_inputs=None,
):
super().__init__()
# Load models
@@ -51,13 +44,7 @@ class WanTrainingModule(DiffusionTrainingModule):
# Store other configs
self.use_gradient_checkpointing = use_gradient_checkpointing
self.use_gradient_checkpointing_offload = use_gradient_checkpointing_offload
self.input_contains_input_image = input_contains_input_image
self.input_contains_end_image = input_contains_end_image
self.input_contains_control_video = input_contains_control_video
self.input_contains_reference_image = input_contains_reference_image
self.input_contains_vace_video = input_contains_vace_video
self.input_contains_vace_reference_image = input_contains_vace_reference_image
self.input_contains_motion_bucket_id = input_contains_motion_bucket_id
self.extra_inputs = extra_inputs.split(",") if extra_inputs is not None else []
def forward_preprocess(self, data):
@@ -85,13 +72,13 @@ class WanTrainingModule(DiffusionTrainingModule):
}
# Extra inputs
if self.input_contains_input_image: inputs_shared["input_image"] = data["video"][0]
if self.input_contains_end_image: inputs_shared["end_image"] = data["video"][-1]
if self.input_contains_control_video: inputs_shared["control_video"] = data["control_video"]
if self.input_contains_reference_image: inputs_shared["reference_image"] = data["reference_image"]
if self.input_contains_vace_video: inputs_shared["vace_video"] = data["vace_video"]
if self.input_contains_vace_reference_image: inputs_shared["vace_reference_image"] = data["vace_reference_image"]
if self.input_contains_motion_bucket_id: inputs_shared["motion_bucket_id"] = data["motion_bucket_id"]
for extra_input in self.extra_inputs:
if extra_input == "input_image":
inputs_shared["input_image"] = data["video"][0]
elif extra_input == "end_image":
inputs_shared["end_image"] = data["video"][-1]
else:
inputs_shared[extra_input] = data[extra_input]
# Pipeline units will automatically process the input parameters.
for unit in self.pipe.units:
@@ -118,12 +105,6 @@ if __name__ == "__main__":
lora_target_modules=args.lora_target_modules,
lora_rank=args.lora_rank,
use_gradient_checkpointing_offload=args.use_gradient_checkpointing_offload,
input_contains_input_image=args.input_contains_input_image,
input_contains_end_image=args.input_contains_end_image,
input_contains_control_video=args.input_contains_control_video,
input_contains_reference_image=args.input_contains_reference_image,
input_contains_vace_video=args.input_contains_vace_video,
input_contains_vace_reference_image=args.input_contains_vace_reference_image,
input_contains_motion_bucket_id=args.input_contains_motion_bucket_id,
extra_inputs=args.extra_inputs,
)
launch_training_task(model, dataset, args=args)