Merge pull request #892 from modelscope/dev2-dzj

refine training framework
This commit is contained in:
Zhongjie Duan
2025-09-04 15:53:52 +08:00
committed by GitHub
6 changed files with 105 additions and 285 deletions

View File

@@ -1,4 +1,6 @@
import imageio, os, torch, warnings, torchvision, argparse, json
from ..utils import ModelConfig
from ..models.utils import load_state_dict
from peft import LoraConfig, inject_adapter_in_model
from PIL import Image
import pandas as pd
@@ -424,7 +426,53 @@ class DiffusionTrainingModule(torch.nn.Module):
if isinstance(data[key], torch.Tensor):
data[key] = data[key].to(device)
return data
def parse_model_configs(self, model_paths, model_id_with_origin_paths, enable_fp8_training=False):
offload_dtype = torch.float8_e4m3fn if enable_fp8_training else None
model_configs = []
if model_paths is not None:
model_paths = json.loads(model_paths)
model_configs += [ModelConfig(path=path, offload_dtype=offload_dtype) for path in model_paths]
if model_id_with_origin_paths is not None:
model_id_with_origin_paths = model_id_with_origin_paths.split(",")
model_configs += [ModelConfig(model_id=i.split(":")[0], origin_file_pattern=i.split(":")[1], offload_dtype=offload_dtype) for i in model_id_with_origin_paths]
return model_configs
def switch_pipe_to_training_mode(
self,
pipe,
trainable_models,
lora_base_model, lora_target_modules, lora_rank, lora_checkpoint=None,
enable_fp8_training=False,
):
# Scheduler
pipe.scheduler.set_timesteps(1000, training=True)
# Freeze untrainable models
pipe.freeze_except([] if trainable_models is None else trainable_models.split(","))
# Enable FP8 if pipeline supports
if enable_fp8_training and hasattr(pipe, "_enable_fp8_lora_training"):
pipe._enable_fp8_lora_training(torch.float8_e4m3fn)
# Add LoRA to the base models
if lora_base_model is not None:
model = self.add_lora_to_model(
getattr(pipe, lora_base_model),
target_modules=lora_target_modules.split(","),
lora_rank=lora_rank,
upcast_dtype=pipe.torch_dtype,
)
if lora_checkpoint is not None:
state_dict = load_state_dict(lora_checkpoint)
state_dict = self.mapping_lora_state_dict(state_dict)
load_result = model.load_state_dict(state_dict, strict=False)
print(f"LoRA checkpoint loaded: {lora_checkpoint}, total {len(state_dict)} keys")
if len(load_result[1]) > 0:
print(f"Warning, LoRA key mismatch! Unexpected keys in LoRA checkpoint: {load_result[1]}")
setattr(pipe, lora_base_model, model)
class ModelLogger:
@@ -472,14 +520,26 @@ def launch_training_task(
dataset: torch.utils.data.Dataset,
model: DiffusionTrainingModule,
model_logger: ModelLogger,
optimizer: torch.optim.Optimizer,
scheduler: torch.optim.lr_scheduler.LRScheduler,
learning_rate: float = 1e-5,
weight_decay: float = 1e-2,
num_workers: int = 8,
save_steps: int = None,
num_epochs: int = 1,
gradient_accumulation_steps: int = 1,
find_unused_parameters: bool = False,
args = None,
):
if args is not None:
learning_rate = args.learning_rate
weight_decay = args.weight_decay
num_workers = args.dataset_num_workers
save_steps = args.save_steps
num_epochs = args.num_epochs
gradient_accumulation_steps = args.gradient_accumulation_steps
find_unused_parameters = args.find_unused_parameters
optimizer = torch.optim.AdamW(model.trainable_modules(), lr=learning_rate, weight_decay=weight_decay)
scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer)
dataloader = torch.utils.data.DataLoader(dataset, shuffle=True, collate_fn=lambda x: x[0], num_workers=num_workers)
accelerator = Accelerator(
gradient_accumulation_steps=gradient_accumulation_steps,
@@ -509,8 +569,12 @@ def launch_data_process_task(
model: DiffusionTrainingModule,
model_logger: ModelLogger,
num_workers: int = 8,
args = None,
):
dataloader = torch.utils.data.DataLoader(dataset, shuffle=True, collate_fn=lambda x: x[0], num_workers=num_workers)
if args is not None:
num_workers = args.dataset_num_workers
dataloader = torch.utils.data.DataLoader(dataset, shuffle=False, collate_fn=lambda x: x[0], num_workers=num_workers)
accelerator = Accelerator()
model, dataloader = accelerator.prepare(model, dataloader)
@@ -520,7 +584,7 @@ def launch_data_process_task(
folder = os.path.join(model_logger.output_path, str(accelerator.process_index))
os.makedirs(folder, exist_ok=True)
save_path = os.path.join(model_logger.output_path, str(accelerator.process_index), f"{data_id}.pth")
data = model(data)
data = model(data, return_inputs=True)
torch.save(data, save_path)
@@ -623,4 +687,5 @@ def qwen_image_parser():
parser.add_argument("--weight_decay", type=float, default=0.01, help="Weight decay.")
parser.add_argument("--processor_path", type=str, default=None, help="Path to the processor. If provided, the processor will be used for image editing.")
parser.add_argument("--enable_fp8_training", default=False, action="store_true", help="Whether to enable FP8 training. Only available for LoRA training on a single GPU.")
parser.add_argument("--task", type=str, default="sft", required=False, help="Task type.")
return parser