mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-24 10:18:12 +00:00
refine code
This commit is contained in:
@@ -14,6 +14,8 @@ class WanTrainingModule(DiffusionTrainingModule):
|
||||
use_gradient_checkpointing=True,
|
||||
use_gradient_checkpointing_offload=False,
|
||||
extra_inputs=None,
|
||||
max_timestep_boundary=1.0,
|
||||
min_timestep_boundary=0.0,
|
||||
):
|
||||
super().__init__()
|
||||
# Load models
|
||||
@@ -45,6 +47,8 @@ class WanTrainingModule(DiffusionTrainingModule):
|
||||
self.use_gradient_checkpointing = use_gradient_checkpointing
|
||||
self.use_gradient_checkpointing_offload = use_gradient_checkpointing_offload
|
||||
self.extra_inputs = extra_inputs.split(",") if extra_inputs is not None else []
|
||||
self.max_timestep_boundary = max_timestep_boundary
|
||||
self.min_timestep_boundary = min_timestep_boundary
|
||||
|
||||
|
||||
def forward_preprocess(self, data):
|
||||
@@ -69,6 +73,8 @@ class WanTrainingModule(DiffusionTrainingModule):
|
||||
"use_gradient_checkpointing_offload": self.use_gradient_checkpointing_offload,
|
||||
"cfg_merge": False,
|
||||
"vace_scale": 1,
|
||||
"max_timestep_boundary": self.max_timestep_boundary,
|
||||
"min_timestep_boundary": self.min_timestep_boundary,
|
||||
}
|
||||
|
||||
# Extra inputs
|
||||
@@ -106,6 +112,8 @@ if __name__ == "__main__":
|
||||
lora_rank=args.lora_rank,
|
||||
use_gradient_checkpointing_offload=args.use_gradient_checkpointing_offload,
|
||||
extra_inputs=args.extra_inputs,
|
||||
max_timestep_boundary=args.max_timestep_boundary,
|
||||
min_timestep_boundary=args.min_timestep_boundary,
|
||||
)
|
||||
model_logger = ModelLogger(
|
||||
args.output_path,
|
||||
|
||||
Reference in New Issue
Block a user