mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-24 01:48:13 +00:00
wan-refactor
This commit is contained in:
@@ -13,14 +13,7 @@ class WanTrainingModule(DiffusionTrainingModule):
|
||||
lora_base_model=None, lora_target_modules="q,k,v,o,ffn.0,ffn.2", lora_rank=32,
|
||||
use_gradient_checkpointing=True,
|
||||
use_gradient_checkpointing_offload=False,
|
||||
# Extra inputs
|
||||
input_contains_input_image=False,
|
||||
input_contains_end_image=False,
|
||||
input_contains_control_video=False,
|
||||
input_contains_reference_image=False,
|
||||
input_contains_vace_video=False,
|
||||
input_contains_vace_reference_image=False,
|
||||
input_contains_motion_bucket_id=False,
|
||||
extra_inputs=None,
|
||||
):
|
||||
super().__init__()
|
||||
# Load models
|
||||
@@ -51,13 +44,7 @@ class WanTrainingModule(DiffusionTrainingModule):
|
||||
# Store other configs
|
||||
self.use_gradient_checkpointing = use_gradient_checkpointing
|
||||
self.use_gradient_checkpointing_offload = use_gradient_checkpointing_offload
|
||||
self.input_contains_input_image = input_contains_input_image
|
||||
self.input_contains_end_image = input_contains_end_image
|
||||
self.input_contains_control_video = input_contains_control_video
|
||||
self.input_contains_reference_image = input_contains_reference_image
|
||||
self.input_contains_vace_video = input_contains_vace_video
|
||||
self.input_contains_vace_reference_image = input_contains_vace_reference_image
|
||||
self.input_contains_motion_bucket_id = input_contains_motion_bucket_id
|
||||
self.extra_inputs = extra_inputs.split(",") if extra_inputs is not None else []
|
||||
|
||||
|
||||
def forward_preprocess(self, data):
|
||||
@@ -85,13 +72,13 @@ class WanTrainingModule(DiffusionTrainingModule):
|
||||
}
|
||||
|
||||
# Extra inputs
|
||||
if self.input_contains_input_image: inputs_shared["input_image"] = data["video"][0]
|
||||
if self.input_contains_end_image: inputs_shared["end_image"] = data["video"][-1]
|
||||
if self.input_contains_control_video: inputs_shared["control_video"] = data["control_video"]
|
||||
if self.input_contains_reference_image: inputs_shared["reference_image"] = data["reference_image"]
|
||||
if self.input_contains_vace_video: inputs_shared["vace_video"] = data["vace_video"]
|
||||
if self.input_contains_vace_reference_image: inputs_shared["vace_reference_image"] = data["vace_reference_image"]
|
||||
if self.input_contains_motion_bucket_id: inputs_shared["motion_bucket_id"] = data["motion_bucket_id"]
|
||||
for extra_input in self.extra_inputs:
|
||||
if extra_input == "input_image":
|
||||
inputs_shared["input_image"] = data["video"][0]
|
||||
elif extra_input == "end_image":
|
||||
inputs_shared["end_image"] = data["video"][-1]
|
||||
else:
|
||||
inputs_shared[extra_input] = data[extra_input]
|
||||
|
||||
# Pipeline units will automatically process the input parameters.
|
||||
for unit in self.pipe.units:
|
||||
@@ -118,12 +105,6 @@ if __name__ == "__main__":
|
||||
lora_target_modules=args.lora_target_modules,
|
||||
lora_rank=args.lora_rank,
|
||||
use_gradient_checkpointing_offload=args.use_gradient_checkpointing_offload,
|
||||
input_contains_input_image=args.input_contains_input_image,
|
||||
input_contains_end_image=args.input_contains_end_image,
|
||||
input_contains_control_video=args.input_contains_control_video,
|
||||
input_contains_reference_image=args.input_contains_reference_image,
|
||||
input_contains_vace_video=args.input_contains_vace_video,
|
||||
input_contains_vace_reference_image=args.input_contains_vace_reference_image,
|
||||
input_contains_motion_bucket_id=args.input_contains_motion_bucket_id,
|
||||
extra_inputs=args.extra_inputs,
|
||||
)
|
||||
launch_training_task(model, dataset, args=args)
|
||||
|
||||
Reference in New Issue
Block a user