Mirror of https://github.com/modelscope/DiffSynth-Studio.git (synced 2026-03-23 09:28:12 +00:00)
merge data process to training script
@@ -2,7 +2,7 @@ import torch, os, json
 from diffsynth import load_state_dict
 from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig
 from diffsynth.pipelines.flux_image_new import ControlNetInput
-from diffsynth.trainers.utils import DiffusionTrainingModule, ModelLogger, launch_training_task, qwen_image_parser
+from diffsynth.trainers.utils import DiffusionTrainingModule, ModelLogger, qwen_image_parser, launch_training_task, launch_data_process_task
 from diffsynth.trainers.unified_dataset import UnifiedDataset
 os.environ["TOKENIZERS_PARALLELISM"] = "false"
 
@@ -29,7 +29,7 @@ class QwenImageTrainingModule(DiffusionTrainingModule):
 
         # Training mode
         self.switch_pipe_to_training_mode(
-            self, self.pipe, trainable_models,
+            self.pipe, trainable_models,
             lora_base_model, lora_target_modules, lora_rank, lora_checkpoint=lora_checkpoint,
             enable_fp8_training=enable_fp8_training,
         )
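Review note: the removed line passed `self` explicitly to what is already a bound method, so Python received the instance twice and every positional argument landed one parameter to the right. A toy illustration of the failure mode (hypothetical names, not project code):

class Trainer:
    def switch(self, pipe, trainable_models, lora_rank=32):
        return pipe, trainable_models, lora_rank

t = Trainer()
t.switch("pipe", ["dit"])     # fixed call: ("pipe", ["dit"], 32)
t.switch(t, "pipe", ["dit"])  # old-style call: (t, "pipe", ["dit"]), everything shifted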
@@ -81,9 +81,10 @@ class QwenImageTrainingModule(DiffusionTrainingModule):
         return {**inputs_shared, **inputs_posi}
 
 
-    def forward(self, data, inputs=None):
+    def forward(self, data, inputs=None, return_inputs=False):
         if inputs is None: inputs = self.forward_preprocess(data)
         else: inputs = self.transfer_data_to_device(inputs, self.pipe.device)
+        if return_inputs: return inputs
         models = {name: getattr(self.pipe, name) for name in self.pipe.in_iteration_models}
         loss = self.pipe.training_loss(**models, **inputs)
         return loss
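Review note: the new `return_inputs` flag lets a caller run only the preprocessing half of `forward` and get back the device-ready input dict without computing a loss, which is presumably what `launch_data_process_task` builds on. A hypothetical caching loop under that assumption (`cache_preprocessed_inputs` is an illustrative helper, not part of this commit):

import torch

def cache_preprocessed_inputs(dataloader, model, output_path="inputs_cache.pt"):
    cached = []
    for data in dataloader:
        with torch.no_grad():  # preprocessing only; no loss, no gradients
            inputs = model(data, return_inputs=True)
        # move tensors to CPU before storing so the cache is device-independent
        cached.append({k: (v.cpu() if torch.is_tensor(v) else v) for k, v in inputs.items()})
    torch.save(cached, output_path)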
@@ -123,13 +124,8 @@ if __name__ == "__main__":
         enable_fp8_training=args.enable_fp8_training,
     )
     model_logger = ModelLogger(args.output_path, remove_prefix_in_ckpt=args.remove_prefix_in_ckpt)
-    optimizer = torch.optim.AdamW(model.trainable_modules(), lr=args.learning_rate, weight_decay=args.weight_decay)
-    scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer)
-    launch_training_task(
-        dataset, model, model_logger, optimizer, scheduler,
-        num_epochs=args.num_epochs,
-        gradient_accumulation_steps=args.gradient_accumulation_steps,
-        save_steps=args.save_steps,
-        find_unused_parameters=args.find_unused_parameters,
-        num_workers=args.dataset_num_workers,
-    )
+    launcher_map = {
+        "sft": launch_training_task,
+        "data_process": launch_data_process_task
+    }
+    launcher_map[args.task](dataset, model, model_logger, args=args)
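Review note: the `__main__` block no longer constructs the AdamW optimizer and ConstantLR scheduler itself; since both launchers now receive `args`, that setup presumably moved inside `launch_training_task`. Dispatch is a plain dict lookup on `args.task`, so an unrecognized task name surfaces as a bare KeyError. A sketch of a friendlier failure mode (hypothetical hardening, not in this commit):

def dispatch(task, launcher_map, *call_args, **call_kwargs):
    try:
        launcher = launcher_map[task]
    except KeyError:
        # fail fast with the list of valid task names
        raise SystemExit(f"unknown task {task!r}; expected one of {sorted(launcher_map)}")
    return launcher(*call_args, **call_kwargs)

# e.g. dispatch(args.task, launcher_map, dataset, model, model_logger, args=args)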