move training code to base trainer

This commit is contained in:
Artiprocher
2025-09-03 12:03:49 +08:00
parent 958ebf1352
commit cb8de6be1b
5 changed files with 79 additions and 129 deletions

View File

@@ -1,4 +1,6 @@
import imageio, os, torch, warnings, torchvision, argparse, json
from ..utils import ModelConfig
from ..models.utils import load_state_dict
from peft import LoraConfig, inject_adapter_in_model
from PIL import Image
import pandas as pd
@@ -424,7 +426,53 @@ class DiffusionTrainingModule(torch.nn.Module):
if isinstance(data[key], torch.Tensor):
data[key] = data[key].to(device)
return data
def parse_model_configs(self, model_paths, model_id_with_origin_paths, enable_fp8_training=False):
offload_dtype = torch.float8_e4m3fn if enable_fp8_training else None
model_configs = []
if model_paths is not None:
model_paths = json.loads(model_paths)
model_configs += [ModelConfig(path=path, offload_dtype=offload_dtype) for path in model_paths]
if model_id_with_origin_paths is not None:
model_id_with_origin_paths = model_id_with_origin_paths.split(",")
model_configs += [ModelConfig(model_id=i.split(":")[0], origin_file_pattern=i.split(":")[1], offload_dtype=offload_dtype) for i in model_id_with_origin_paths]
return model_configs
def switch_pipe_to_training_mode(
self,
pipe,
trainable_models,
lora_base_model, lora_target_modules, lora_rank, lora_checkpoint=None,
enable_fp8_training=False,
):
# Scheduler
pipe.scheduler.set_timesteps(1000, training=True)
# Freeze untrainable models
pipe.freeze_except([] if trainable_models is None else trainable_models.split(","))
# Enable FP8 if pipeline supports
if enable_fp8_training and hasattr(pipe, "_enable_fp8_lora_training"):
pipe._enable_fp8_lora_training(torch.float8_e4m3fn)
# Add LoRA to the base models
if lora_base_model is not None:
model = self.add_lora_to_model(
getattr(pipe, lora_base_model),
target_modules=lora_target_modules.split(","),
lora_rank=lora_rank,
upcast_dtype=pipe.torch_dtype,
)
if lora_checkpoint is not None:
state_dict = load_state_dict(lora_checkpoint)
state_dict = self.mapping_lora_state_dict(state_dict)
load_result = model.load_state_dict(state_dict, strict=False)
print(f"LoRA checkpoint loaded: {lora_checkpoint}, total {len(state_dict)} keys")
if len(load_result[1]) > 0:
print(f"Warning, LoRA key mismatch! Unexpected keys in LoRA checkpoint: {load_result[1]}")
setattr(pipe, lora_base_model, model)
class ModelLogger:

View File

@@ -20,37 +20,16 @@ class FluxTrainingModule(DiffusionTrainingModule):
):
super().__init__()
# Load models
model_configs = []
if model_paths is not None:
model_paths = json.loads(model_paths)
model_configs += [ModelConfig(path=path) for path in model_paths]
if model_id_with_origin_paths is not None:
model_id_with_origin_paths = model_id_with_origin_paths.split(",")
model_configs += [ModelConfig(model_id=i.split(":")[0], origin_file_pattern=i.split(":")[1]) for i in model_id_with_origin_paths]
model_configs = self.parse_model_configs(model_paths, model_id_with_origin_paths, enable_fp8_training=False)
self.pipe = FluxImagePipeline.from_pretrained(torch_dtype=torch.bfloat16, device="cpu", model_configs=model_configs)
# Reset training scheduler
self.pipe.scheduler.set_timesteps(1000, training=True)
# Training mode
self.switch_pipe_to_training_mode(
self, self.pipe, trainable_models,
lora_base_model, lora_target_modules, lora_rank, lora_checkpoint=lora_checkpoint,
enable_fp8_training=False,
)
# Freeze untrainable models
self.pipe.freeze_except([] if trainable_models is None else trainable_models.split(","))
# Add LoRA to the base models
if lora_base_model is not None:
model = self.add_lora_to_model(
getattr(self.pipe, lora_base_model),
target_modules=lora_target_modules.split(","),
lora_rank=lora_rank
)
if lora_checkpoint is not None:
state_dict = load_state_dict(lora_checkpoint)
state_dict = self.mapping_lora_state_dict(state_dict)
load_result = model.load_state_dict(state_dict, strict=False)
print(f"LoRA checkpoint loaded: {lora_checkpoint}, total {len(state_dict)} keys")
if len(load_result[1]) > 0:
print(f"Warning, LoRA key mismatch! Unexpected keys in LoRA checkpoint: {load_result[1]}")
setattr(self.pipe, lora_base_model, model)
# Store other configs
self.use_gradient_checkpointing = use_gradient_checkpointing
self.use_gradient_checkpointing_offload = use_gradient_checkpointing_offload

View File

@@ -22,46 +22,18 @@ class QwenImageTrainingModule(DiffusionTrainingModule):
):
super().__init__()
# Load models
offload_dtype = torch.float8_e4m3fn if enable_fp8_training else None
model_configs = []
if model_paths is not None:
model_paths = json.loads(model_paths)
model_configs += [ModelConfig(path=path, offload_dtype=offload_dtype) for path in model_paths]
if model_id_with_origin_paths is not None:
model_id_with_origin_paths = model_id_with_origin_paths.split(",")
model_configs += [ModelConfig(model_id=i.split(":")[0], origin_file_pattern=i.split(":")[1], offload_dtype=offload_dtype) for i in model_id_with_origin_paths]
model_configs = self.parse_model_configs(model_paths, model_id_with_origin_paths, enable_fp8_training=enable_fp8_training)
tokenizer_config = ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/") if tokenizer_path is None else ModelConfig(tokenizer_path)
processor_config = ModelConfig(model_id="Qwen/Qwen-Image-Edit", origin_file_pattern="processor/") if processor_path is None else ModelConfig(processor_path)
self.pipe = QwenImagePipeline.from_pretrained(torch_dtype=torch.bfloat16, device="cpu", model_configs=model_configs, tokenizer_config=tokenizer_config, processor_config=processor_config)
# Training mode
self.switch_pipe_to_training_mode(
self, self.pipe, trainable_models,
lora_base_model, lora_target_modules, lora_rank, lora_checkpoint=lora_checkpoint,
enable_fp8_training=enable_fp8_training,
)
# Enable FP8
if enable_fp8_training:
self.pipe._enable_fp8_lora_training(torch.float8_e4m3fn)
# Reset training scheduler (do it in each training step)
self.pipe.scheduler.set_timesteps(1000, training=True)
# Freeze untrainable models
self.pipe.freeze_except([] if trainable_models is None else trainable_models.split(","))
# Add LoRA to the base models
if lora_base_model is not None:
model = self.add_lora_to_model(
getattr(self.pipe, lora_base_model),
target_modules=lora_target_modules.split(","),
lora_rank=lora_rank,
upcast_dtype=self.pipe.torch_dtype,
)
if lora_checkpoint is not None:
state_dict = load_state_dict(lora_checkpoint)
state_dict = self.mapping_lora_state_dict(state_dict)
load_result = model.load_state_dict(state_dict, strict=False)
print(f"LoRA checkpoint loaded: {lora_checkpoint}, total {len(state_dict)} keys")
if len(load_result[1]) > 0:
print(f"Warning, LoRA key mismatch! Unexpected keys in LoRA checkpoint: {load_result[1]}")
setattr(self.pipe, lora_base_model, model)
# Store other configs
self.use_gradient_checkpointing = use_gradient_checkpointing
self.use_gradient_checkpointing_offload = use_gradient_checkpointing_offload

View File

@@ -22,46 +22,18 @@ class QwenImageTrainingModule(DiffusionTrainingModule):
):
super().__init__()
# Load models
offload_dtype = torch.float8_e4m3fn if enable_fp8_training else None
model_configs = []
if model_paths is not None:
model_paths = json.loads(model_paths)
model_configs += [ModelConfig(path=path, offload_dtype=offload_dtype) for path in model_paths]
if model_id_with_origin_paths is not None:
model_id_with_origin_paths = model_id_with_origin_paths.split(",")
model_configs += [ModelConfig(model_id=i.split(":")[0], origin_file_pattern=i.split(":")[1], offload_dtype=offload_dtype) for i in model_id_with_origin_paths]
model_configs = self.parse_model_configs(model_paths, model_id_with_origin_paths, enable_fp8_training=enable_fp8_training)
tokenizer_config = ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/") if tokenizer_path is None else ModelConfig(tokenizer_path)
processor_config = ModelConfig(model_id="Qwen/Qwen-Image-Edit", origin_file_pattern="processor/") if processor_path is None else ModelConfig(processor_path)
self.pipe = QwenImagePipeline.from_pretrained(torch_dtype=torch.bfloat16, device="cpu", model_configs=model_configs, tokenizer_config=tokenizer_config, processor_config=processor_config)
# Training mode
self.switch_pipe_to_training_mode(
self, self.pipe, trainable_models,
lora_base_model, lora_target_modules, lora_rank, lora_checkpoint=lora_checkpoint,
enable_fp8_training=enable_fp8_training,
)
# Enable FP8
if enable_fp8_training:
self.pipe._enable_fp8_lora_training(torch.float8_e4m3fn)
# Reset training scheduler (do it in each training step)
self.pipe.scheduler.set_timesteps(1000, training=True)
# Freeze untrainable models
self.pipe.freeze_except([] if trainable_models is None else trainable_models.split(","))
# Add LoRA to the base models
if lora_base_model is not None:
model = self.add_lora_to_model(
getattr(self.pipe, lora_base_model),
target_modules=lora_target_modules.split(","),
lora_rank=lora_rank,
upcast_dtype=self.pipe.torch_dtype,
)
if lora_checkpoint is not None:
state_dict = load_state_dict(lora_checkpoint)
state_dict = self.mapping_lora_state_dict(state_dict)
load_result = model.load_state_dict(state_dict, strict=False)
print(f"LoRA checkpoint loaded: {lora_checkpoint}, total {len(state_dict)} keys")
if len(load_result[1]) > 0:
print(f"Warning, LoRA key mismatch! Unexpected keys in LoRA checkpoint: {load_result[1]}")
setattr(self.pipe, lora_base_model, model)
# Store other configs
self.use_gradient_checkpointing = use_gradient_checkpointing
self.use_gradient_checkpointing_offload = use_gradient_checkpointing_offload

View File

@@ -21,37 +21,16 @@ class WanTrainingModule(DiffusionTrainingModule):
):
super().__init__()
# Load models
model_configs = []
if model_paths is not None:
model_paths = json.loads(model_paths)
model_configs += [ModelConfig(path=path) for path in model_paths]
if model_id_with_origin_paths is not None:
model_id_with_origin_paths = model_id_with_origin_paths.split(",")
model_configs += [ModelConfig(model_id=i.split(":")[0], origin_file_pattern=i.split(":")[1]) for i in model_id_with_origin_paths]
model_configs = self.parse_model_configs(model_paths, model_id_with_origin_paths, enable_fp8_training=False)
self.pipe = WanVideoPipeline.from_pretrained(torch_dtype=torch.bfloat16, device="cpu", model_configs=model_configs)
# Reset training scheduler
self.pipe.scheduler.set_timesteps(1000, training=True)
# Training mode
self.switch_pipe_to_training_mode(
self, self.pipe, trainable_models,
lora_base_model, lora_target_modules, lora_rank, lora_checkpoint=lora_checkpoint,
enable_fp8_training=False,
)
# Freeze untrainable models
self.pipe.freeze_except([] if trainable_models is None else trainable_models.split(","))
# Add LoRA to the base models
if lora_base_model is not None:
model = self.add_lora_to_model(
getattr(self.pipe, lora_base_model),
target_modules=lora_target_modules.split(","),
lora_rank=lora_rank
)
if lora_checkpoint is not None:
state_dict = load_state_dict(lora_checkpoint)
state_dict = self.mapping_lora_state_dict(state_dict)
load_result = model.load_state_dict(state_dict, strict=False)
print(f"LoRA checkpoint loaded: {lora_checkpoint}, total {len(state_dict)} keys")
if len(load_result[1]) > 0:
print(f"Warning, LoRA key mismatch! Unexpected keys in LoRA checkpoint: {load_result[1]}")
setattr(self.pipe, lora_base_model, model)
# Store other configs
self.use_gradient_checkpointing = use_gradient_checkpointing
self.use_gradient_checkpointing_offload = use_gradient_checkpointing_offload