mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-18 22:08:13 +00:00
move training code to base trainer
This commit is contained in:
@@ -1,4 +1,6 @@
|
||||
import imageio, os, torch, warnings, torchvision, argparse, json
|
||||
from ..utils import ModelConfig
|
||||
from ..models.utils import load_state_dict
|
||||
from peft import LoraConfig, inject_adapter_in_model
|
||||
from PIL import Image
|
||||
import pandas as pd
|
||||
@@ -424,7 +426,53 @@ class DiffusionTrainingModule(torch.nn.Module):
|
||||
if isinstance(data[key], torch.Tensor):
|
||||
data[key] = data[key].to(device)
|
||||
return data
|
||||
|
||||
|
||||
|
||||
def parse_model_configs(self, model_paths, model_id_with_origin_paths, enable_fp8_training=False):
|
||||
offload_dtype = torch.float8_e4m3fn if enable_fp8_training else None
|
||||
model_configs = []
|
||||
if model_paths is not None:
|
||||
model_paths = json.loads(model_paths)
|
||||
model_configs += [ModelConfig(path=path, offload_dtype=offload_dtype) for path in model_paths]
|
||||
if model_id_with_origin_paths is not None:
|
||||
model_id_with_origin_paths = model_id_with_origin_paths.split(",")
|
||||
model_configs += [ModelConfig(model_id=i.split(":")[0], origin_file_pattern=i.split(":")[1], offload_dtype=offload_dtype) for i in model_id_with_origin_paths]
|
||||
return model_configs
|
||||
|
||||
|
||||
def switch_pipe_to_training_mode(
|
||||
self,
|
||||
pipe,
|
||||
trainable_models,
|
||||
lora_base_model, lora_target_modules, lora_rank, lora_checkpoint=None,
|
||||
enable_fp8_training=False,
|
||||
):
|
||||
# Scheduler
|
||||
pipe.scheduler.set_timesteps(1000, training=True)
|
||||
|
||||
# Freeze untrainable models
|
||||
pipe.freeze_except([] if trainable_models is None else trainable_models.split(","))
|
||||
|
||||
# Enable FP8 if pipeline supports
|
||||
if enable_fp8_training and hasattr(pipe, "_enable_fp8_lora_training"):
|
||||
pipe._enable_fp8_lora_training(torch.float8_e4m3fn)
|
||||
|
||||
# Add LoRA to the base models
|
||||
if lora_base_model is not None:
|
||||
model = self.add_lora_to_model(
|
||||
getattr(pipe, lora_base_model),
|
||||
target_modules=lora_target_modules.split(","),
|
||||
lora_rank=lora_rank,
|
||||
upcast_dtype=pipe.torch_dtype,
|
||||
)
|
||||
if lora_checkpoint is not None:
|
||||
state_dict = load_state_dict(lora_checkpoint)
|
||||
state_dict = self.mapping_lora_state_dict(state_dict)
|
||||
load_result = model.load_state_dict(state_dict, strict=False)
|
||||
print(f"LoRA checkpoint loaded: {lora_checkpoint}, total {len(state_dict)} keys")
|
||||
if len(load_result[1]) > 0:
|
||||
print(f"Warning, LoRA key mismatch! Unexpected keys in LoRA checkpoint: {load_result[1]}")
|
||||
setattr(pipe, lora_base_model, model)
|
||||
|
||||
|
||||
class ModelLogger:
|
||||
|
||||
Reference in New Issue
Block a user