move training code to base trainer

This commit is contained in:
Artiprocher
2025-09-03 12:03:49 +08:00
parent 958ebf1352
commit cb8de6be1b
5 changed files with 79 additions and 129 deletions

View File

@@ -1,4 +1,6 @@
import imageio, os, torch, warnings, torchvision, argparse, json
from ..utils import ModelConfig
from ..models.utils import load_state_dict
from peft import LoraConfig, inject_adapter_in_model
from PIL import Image
import pandas as pd
@@ -424,7 +426,53 @@ class DiffusionTrainingModule(torch.nn.Module):
if isinstance(data[key], torch.Tensor):
data[key] = data[key].to(device)
return data
def parse_model_configs(self, model_paths, model_id_with_origin_paths, enable_fp8_training=False):
    """Translate the model-selection arguments into a list of ``ModelConfig``.

    Args:
        model_paths: JSON-encoded string, or ``None`` to skip. The decoded
            value is iterated and each element is passed as ``path`` to one
            ``ModelConfig``.
        model_id_with_origin_paths: Comma-separated entries of the form
            ``model_id:origin_file_pattern``, or ``None`` to skip. Each entry
            is split on its first ``":"`` only, so the pattern part may
            itself contain colons.
        enable_fp8_training: When true, every config is created with
            ``offload_dtype=torch.float8_e4m3fn``; otherwise with ``None``.

    Returns:
        A list of ``ModelConfig`` objects: the path-based configs first,
        followed by the id-based configs, each group in input order.

    Raises:
        ValueError: If an entry of ``model_id_with_origin_paths`` has no
            ``":"`` separator.
    """
    offload_dtype = torch.float8_e4m3fn if enable_fp8_training else None
    model_configs = []
    if model_paths is not None:
        model_configs += [
            ModelConfig(path=path, offload_dtype=offload_dtype)
            for path in json.loads(model_paths)
        ]
    if model_id_with_origin_paths is not None:
        # Validate every entry up front so a malformed argument is reported
        # clearly, before any config object is constructed.
        id_pattern_pairs = []
        for entry in model_id_with_origin_paths.split(","):
            model_id, separator, origin_file_pattern = entry.partition(":")
            if not separator:
                raise ValueError(
                    f"Invalid model_id_with_origin_paths entry {entry!r}: "
                    "expected the form 'model_id:origin_file_pattern'."
                )
            id_pattern_pairs.append((model_id, origin_file_pattern))
        model_configs += [
            ModelConfig(
                model_id=model_id,
                origin_file_pattern=origin_file_pattern,
                offload_dtype=offload_dtype,
            )
            for model_id, origin_file_pattern in id_pattern_pairs
        ]
    return model_configs
def switch_pipe_to_training_mode(
    self,
    pipe,
    trainable_models,
    lora_base_model, lora_target_modules, lora_rank, lora_checkpoint=None,
    enable_fp8_training=False,
):
    """Prepare ``pipe`` for training, mutating it in place.

    Steps, in order:
      1. Call ``pipe.scheduler.set_timesteps(1000, training=True)``.
      2. Call ``pipe.freeze_except`` with the comma-separated names in
         ``trainable_models`` (an empty list when it is ``None``).
      3. If ``enable_fp8_training`` is set and the pipeline exposes
         ``_enable_fp8_lora_training``, call it with ``torch.float8_e4m3fn``.
      4. If ``lora_base_model`` is given, wrap the attribute of that name on
         ``pipe`` via ``self.add_lora_to_model``, optionally load
         ``lora_checkpoint`` into it, and store the result back on ``pipe``.

    Args:
        pipe: Pipeline object to reconfigure.
        trainable_models: Comma-separated names passed to
            ``pipe.freeze_except``, or ``None``.
        lora_base_model: Name of the ``pipe`` attribute to add LoRA to, or
            ``None`` to skip the LoRA step entirely.
        lora_target_modules: Comma-separated target module names; only read
            when ``lora_base_model`` is not ``None``.
        lora_rank: Forwarded to ``self.add_lora_to_model`` as ``lora_rank``.
        lora_checkpoint: Optional checkpoint to load into the LoRA model
            (non-strict load).
        enable_fp8_training: Whether to request FP8 LoRA training.

    Returns:
        None.
    """
    # Scheduler
    pipe.scheduler.set_timesteps(1000, training=True)

    # Freeze untrainable models
    if trainable_models is None:
        trainable_names = []
    else:
        trainable_names = trainable_models.split(",")
    pipe.freeze_except(trainable_names)

    # Enable FP8 if pipeline supports
    if enable_fp8_training:
        if hasattr(pipe, "_enable_fp8_lora_training"):
            pipe._enable_fp8_lora_training(torch.float8_e4m3fn)

    # Nothing more to do when no LoRA base model was requested.
    if lora_base_model is None:
        return

    # Add LoRA to the base model
    base_model = getattr(pipe, lora_base_model)
    target_modules = lora_target_modules.split(",")
    lora_model = self.add_lora_to_model(
        base_model,
        target_modules=target_modules,
        lora_rank=lora_rank,
        upcast_dtype=pipe.torch_dtype,
    )

    if lora_checkpoint is not None:
        lora_weights = self.mapping_lora_state_dict(load_state_dict(lora_checkpoint))
        # Non-strict load: mismatched keys are reported below, not raised.
        load_result = lora_model.load_state_dict(lora_weights, strict=False)
        print(f"LoRA checkpoint loaded: {lora_checkpoint}, total {len(lora_weights)} keys")
        unexpected_keys = load_result[1]
        if len(unexpected_keys) > 0:
            print(f"Warning, LoRA key mismatch! Unexpected keys in LoRA checkpoint: {unexpected_keys}")

    # Install the LoRA-wrapped model whether or not a checkpoint was loaded.
    setattr(pipe, lora_base_model, lora_model)
class ModelLogger: