mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-24 10:18:12 +00:00
move training code to base trainer
This commit is contained in:
@@ -1,4 +1,6 @@
|
|||||||
import imageio, os, torch, warnings, torchvision, argparse, json
|
import imageio, os, torch, warnings, torchvision, argparse, json
|
||||||
|
from ..utils import ModelConfig
|
||||||
|
from ..models.utils import load_state_dict
|
||||||
from peft import LoraConfig, inject_adapter_in_model
|
from peft import LoraConfig, inject_adapter_in_model
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
@@ -426,6 +428,52 @@ class DiffusionTrainingModule(torch.nn.Module):
|
|||||||
return data
|
return data
|
||||||
|
|
||||||
|
|
||||||
|
def parse_model_configs(self, model_paths, model_id_with_origin_paths, enable_fp8_training=False):
|
||||||
|
offload_dtype = torch.float8_e4m3fn if enable_fp8_training else None
|
||||||
|
model_configs = []
|
||||||
|
if model_paths is not None:
|
||||||
|
model_paths = json.loads(model_paths)
|
||||||
|
model_configs += [ModelConfig(path=path, offload_dtype=offload_dtype) for path in model_paths]
|
||||||
|
if model_id_with_origin_paths is not None:
|
||||||
|
model_id_with_origin_paths = model_id_with_origin_paths.split(",")
|
||||||
|
model_configs += [ModelConfig(model_id=i.split(":")[0], origin_file_pattern=i.split(":")[1], offload_dtype=offload_dtype) for i in model_id_with_origin_paths]
|
||||||
|
return model_configs
|
||||||
|
|
||||||
|
|
||||||
|
def switch_pipe_to_training_mode(
|
||||||
|
self,
|
||||||
|
pipe,
|
||||||
|
trainable_models,
|
||||||
|
lora_base_model, lora_target_modules, lora_rank, lora_checkpoint=None,
|
||||||
|
enable_fp8_training=False,
|
||||||
|
):
|
||||||
|
# Scheduler
|
||||||
|
pipe.scheduler.set_timesteps(1000, training=True)
|
||||||
|
|
||||||
|
# Freeze untrainable models
|
||||||
|
pipe.freeze_except([] if trainable_models is None else trainable_models.split(","))
|
||||||
|
|
||||||
|
# Enable FP8 if pipeline supports
|
||||||
|
if enable_fp8_training and hasattr(pipe, "_enable_fp8_lora_training"):
|
||||||
|
pipe._enable_fp8_lora_training(torch.float8_e4m3fn)
|
||||||
|
|
||||||
|
# Add LoRA to the base models
|
||||||
|
if lora_base_model is not None:
|
||||||
|
model = self.add_lora_to_model(
|
||||||
|
getattr(pipe, lora_base_model),
|
||||||
|
target_modules=lora_target_modules.split(","),
|
||||||
|
lora_rank=lora_rank,
|
||||||
|
upcast_dtype=pipe.torch_dtype,
|
||||||
|
)
|
||||||
|
if lora_checkpoint is not None:
|
||||||
|
state_dict = load_state_dict(lora_checkpoint)
|
||||||
|
state_dict = self.mapping_lora_state_dict(state_dict)
|
||||||
|
load_result = model.load_state_dict(state_dict, strict=False)
|
||||||
|
print(f"LoRA checkpoint loaded: {lora_checkpoint}, total {len(state_dict)} keys")
|
||||||
|
if len(load_result[1]) > 0:
|
||||||
|
print(f"Warning, LoRA key mismatch! Unexpected keys in LoRA checkpoint: {load_result[1]}")
|
||||||
|
setattr(pipe, lora_base_model, model)
|
||||||
|
|
||||||
|
|
||||||
class ModelLogger:
|
class ModelLogger:
|
||||||
def __init__(self, output_path, remove_prefix_in_ckpt=None, state_dict_converter=lambda x:x):
|
def __init__(self, output_path, remove_prefix_in_ckpt=None, state_dict_converter=lambda x:x):
|
||||||
|
|||||||
@@ -20,36 +20,15 @@ class FluxTrainingModule(DiffusionTrainingModule):
|
|||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
# Load models
|
# Load models
|
||||||
model_configs = []
|
model_configs = self.parse_model_configs(model_paths, model_id_with_origin_paths, enable_fp8_training=False)
|
||||||
if model_paths is not None:
|
|
||||||
model_paths = json.loads(model_paths)
|
|
||||||
model_configs += [ModelConfig(path=path) for path in model_paths]
|
|
||||||
if model_id_with_origin_paths is not None:
|
|
||||||
model_id_with_origin_paths = model_id_with_origin_paths.split(",")
|
|
||||||
model_configs += [ModelConfig(model_id=i.split(":")[0], origin_file_pattern=i.split(":")[1]) for i in model_id_with_origin_paths]
|
|
||||||
self.pipe = FluxImagePipeline.from_pretrained(torch_dtype=torch.bfloat16, device="cpu", model_configs=model_configs)
|
self.pipe = FluxImagePipeline.from_pretrained(torch_dtype=torch.bfloat16, device="cpu", model_configs=model_configs)
|
||||||
|
|
||||||
# Reset training scheduler
|
# Training mode
|
||||||
self.pipe.scheduler.set_timesteps(1000, training=True)
|
self.switch_pipe_to_training_mode(
|
||||||
|
self, self.pipe, trainable_models,
|
||||||
# Freeze untrainable models
|
lora_base_model, lora_target_modules, lora_rank, lora_checkpoint=lora_checkpoint,
|
||||||
self.pipe.freeze_except([] if trainable_models is None else trainable_models.split(","))
|
enable_fp8_training=False,
|
||||||
|
)
|
||||||
# Add LoRA to the base models
|
|
||||||
if lora_base_model is not None:
|
|
||||||
model = self.add_lora_to_model(
|
|
||||||
getattr(self.pipe, lora_base_model),
|
|
||||||
target_modules=lora_target_modules.split(","),
|
|
||||||
lora_rank=lora_rank
|
|
||||||
)
|
|
||||||
if lora_checkpoint is not None:
|
|
||||||
state_dict = load_state_dict(lora_checkpoint)
|
|
||||||
state_dict = self.mapping_lora_state_dict(state_dict)
|
|
||||||
load_result = model.load_state_dict(state_dict, strict=False)
|
|
||||||
print(f"LoRA checkpoint loaded: {lora_checkpoint}, total {len(state_dict)} keys")
|
|
||||||
if len(load_result[1]) > 0:
|
|
||||||
print(f"Warning, LoRA key mismatch! Unexpected keys in LoRA checkpoint: {load_result[1]}")
|
|
||||||
setattr(self.pipe, lora_base_model, model)
|
|
||||||
|
|
||||||
# Store other configs
|
# Store other configs
|
||||||
self.use_gradient_checkpointing = use_gradient_checkpointing
|
self.use_gradient_checkpointing = use_gradient_checkpointing
|
||||||
|
|||||||
@@ -22,45 +22,17 @@ class QwenImageTrainingModule(DiffusionTrainingModule):
|
|||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
# Load models
|
# Load models
|
||||||
offload_dtype = torch.float8_e4m3fn if enable_fp8_training else None
|
model_configs = self.parse_model_configs(model_paths, model_id_with_origin_paths, enable_fp8_training=enable_fp8_training)
|
||||||
model_configs = []
|
|
||||||
if model_paths is not None:
|
|
||||||
model_paths = json.loads(model_paths)
|
|
||||||
model_configs += [ModelConfig(path=path, offload_dtype=offload_dtype) for path in model_paths]
|
|
||||||
if model_id_with_origin_paths is not None:
|
|
||||||
model_id_with_origin_paths = model_id_with_origin_paths.split(",")
|
|
||||||
model_configs += [ModelConfig(model_id=i.split(":")[0], origin_file_pattern=i.split(":")[1], offload_dtype=offload_dtype) for i in model_id_with_origin_paths]
|
|
||||||
|
|
||||||
tokenizer_config = ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/") if tokenizer_path is None else ModelConfig(tokenizer_path)
|
tokenizer_config = ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/") if tokenizer_path is None else ModelConfig(tokenizer_path)
|
||||||
processor_config = ModelConfig(model_id="Qwen/Qwen-Image-Edit", origin_file_pattern="processor/") if processor_path is None else ModelConfig(processor_path)
|
processor_config = ModelConfig(model_id="Qwen/Qwen-Image-Edit", origin_file_pattern="processor/") if processor_path is None else ModelConfig(processor_path)
|
||||||
self.pipe = QwenImagePipeline.from_pretrained(torch_dtype=torch.bfloat16, device="cpu", model_configs=model_configs, tokenizer_config=tokenizer_config, processor_config=processor_config)
|
self.pipe = QwenImagePipeline.from_pretrained(torch_dtype=torch.bfloat16, device="cpu", model_configs=model_configs, tokenizer_config=tokenizer_config, processor_config=processor_config)
|
||||||
|
|
||||||
# Enable FP8
|
# Training mode
|
||||||
if enable_fp8_training:
|
self.switch_pipe_to_training_mode(
|
||||||
self.pipe._enable_fp8_lora_training(torch.float8_e4m3fn)
|
self, self.pipe, trainable_models,
|
||||||
|
lora_base_model, lora_target_modules, lora_rank, lora_checkpoint=lora_checkpoint,
|
||||||
# Reset training scheduler (do it in each training step)
|
enable_fp8_training=enable_fp8_training,
|
||||||
self.pipe.scheduler.set_timesteps(1000, training=True)
|
)
|
||||||
|
|
||||||
# Freeze untrainable models
|
|
||||||
self.pipe.freeze_except([] if trainable_models is None else trainable_models.split(","))
|
|
||||||
|
|
||||||
# Add LoRA to the base models
|
|
||||||
if lora_base_model is not None:
|
|
||||||
model = self.add_lora_to_model(
|
|
||||||
getattr(self.pipe, lora_base_model),
|
|
||||||
target_modules=lora_target_modules.split(","),
|
|
||||||
lora_rank=lora_rank,
|
|
||||||
upcast_dtype=self.pipe.torch_dtype,
|
|
||||||
)
|
|
||||||
if lora_checkpoint is not None:
|
|
||||||
state_dict = load_state_dict(lora_checkpoint)
|
|
||||||
state_dict = self.mapping_lora_state_dict(state_dict)
|
|
||||||
load_result = model.load_state_dict(state_dict, strict=False)
|
|
||||||
print(f"LoRA checkpoint loaded: {lora_checkpoint}, total {len(state_dict)} keys")
|
|
||||||
if len(load_result[1]) > 0:
|
|
||||||
print(f"Warning, LoRA key mismatch! Unexpected keys in LoRA checkpoint: {load_result[1]}")
|
|
||||||
setattr(self.pipe, lora_base_model, model)
|
|
||||||
|
|
||||||
# Store other configs
|
# Store other configs
|
||||||
self.use_gradient_checkpointing = use_gradient_checkpointing
|
self.use_gradient_checkpointing = use_gradient_checkpointing
|
||||||
|
|||||||
@@ -22,45 +22,17 @@ class QwenImageTrainingModule(DiffusionTrainingModule):
|
|||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
# Load models
|
# Load models
|
||||||
offload_dtype = torch.float8_e4m3fn if enable_fp8_training else None
|
model_configs = self.parse_model_configs(model_paths, model_id_with_origin_paths, enable_fp8_training=enable_fp8_training)
|
||||||
model_configs = []
|
|
||||||
if model_paths is not None:
|
|
||||||
model_paths = json.loads(model_paths)
|
|
||||||
model_configs += [ModelConfig(path=path, offload_dtype=offload_dtype) for path in model_paths]
|
|
||||||
if model_id_with_origin_paths is not None:
|
|
||||||
model_id_with_origin_paths = model_id_with_origin_paths.split(",")
|
|
||||||
model_configs += [ModelConfig(model_id=i.split(":")[0], origin_file_pattern=i.split(":")[1], offload_dtype=offload_dtype) for i in model_id_with_origin_paths]
|
|
||||||
|
|
||||||
tokenizer_config = ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/") if tokenizer_path is None else ModelConfig(tokenizer_path)
|
tokenizer_config = ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/") if tokenizer_path is None else ModelConfig(tokenizer_path)
|
||||||
processor_config = ModelConfig(model_id="Qwen/Qwen-Image-Edit", origin_file_pattern="processor/") if processor_path is None else ModelConfig(processor_path)
|
processor_config = ModelConfig(model_id="Qwen/Qwen-Image-Edit", origin_file_pattern="processor/") if processor_path is None else ModelConfig(processor_path)
|
||||||
self.pipe = QwenImagePipeline.from_pretrained(torch_dtype=torch.bfloat16, device="cpu", model_configs=model_configs, tokenizer_config=tokenizer_config, processor_config=processor_config)
|
self.pipe = QwenImagePipeline.from_pretrained(torch_dtype=torch.bfloat16, device="cpu", model_configs=model_configs, tokenizer_config=tokenizer_config, processor_config=processor_config)
|
||||||
|
|
||||||
# Enable FP8
|
# Training mode
|
||||||
if enable_fp8_training:
|
self.switch_pipe_to_training_mode(
|
||||||
self.pipe._enable_fp8_lora_training(torch.float8_e4m3fn)
|
self, self.pipe, trainable_models,
|
||||||
|
lora_base_model, lora_target_modules, lora_rank, lora_checkpoint=lora_checkpoint,
|
||||||
# Reset training scheduler (do it in each training step)
|
enable_fp8_training=enable_fp8_training,
|
||||||
self.pipe.scheduler.set_timesteps(1000, training=True)
|
)
|
||||||
|
|
||||||
# Freeze untrainable models
|
|
||||||
self.pipe.freeze_except([] if trainable_models is None else trainable_models.split(","))
|
|
||||||
|
|
||||||
# Add LoRA to the base models
|
|
||||||
if lora_base_model is not None:
|
|
||||||
model = self.add_lora_to_model(
|
|
||||||
getattr(self.pipe, lora_base_model),
|
|
||||||
target_modules=lora_target_modules.split(","),
|
|
||||||
lora_rank=lora_rank,
|
|
||||||
upcast_dtype=self.pipe.torch_dtype,
|
|
||||||
)
|
|
||||||
if lora_checkpoint is not None:
|
|
||||||
state_dict = load_state_dict(lora_checkpoint)
|
|
||||||
state_dict = self.mapping_lora_state_dict(state_dict)
|
|
||||||
load_result = model.load_state_dict(state_dict, strict=False)
|
|
||||||
print(f"LoRA checkpoint loaded: {lora_checkpoint}, total {len(state_dict)} keys")
|
|
||||||
if len(load_result[1]) > 0:
|
|
||||||
print(f"Warning, LoRA key mismatch! Unexpected keys in LoRA checkpoint: {load_result[1]}")
|
|
||||||
setattr(self.pipe, lora_base_model, model)
|
|
||||||
|
|
||||||
# Store other configs
|
# Store other configs
|
||||||
self.use_gradient_checkpointing = use_gradient_checkpointing
|
self.use_gradient_checkpointing = use_gradient_checkpointing
|
||||||
|
|||||||
@@ -21,36 +21,15 @@ class WanTrainingModule(DiffusionTrainingModule):
|
|||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
# Load models
|
# Load models
|
||||||
model_configs = []
|
model_configs = self.parse_model_configs(model_paths, model_id_with_origin_paths, enable_fp8_training=False)
|
||||||
if model_paths is not None:
|
|
||||||
model_paths = json.loads(model_paths)
|
|
||||||
model_configs += [ModelConfig(path=path) for path in model_paths]
|
|
||||||
if model_id_with_origin_paths is not None:
|
|
||||||
model_id_with_origin_paths = model_id_with_origin_paths.split(",")
|
|
||||||
model_configs += [ModelConfig(model_id=i.split(":")[0], origin_file_pattern=i.split(":")[1]) for i in model_id_with_origin_paths]
|
|
||||||
self.pipe = WanVideoPipeline.from_pretrained(torch_dtype=torch.bfloat16, device="cpu", model_configs=model_configs)
|
self.pipe = WanVideoPipeline.from_pretrained(torch_dtype=torch.bfloat16, device="cpu", model_configs=model_configs)
|
||||||
|
|
||||||
# Reset training scheduler
|
# Training mode
|
||||||
self.pipe.scheduler.set_timesteps(1000, training=True)
|
self.switch_pipe_to_training_mode(
|
||||||
|
self, self.pipe, trainable_models,
|
||||||
# Freeze untrainable models
|
lora_base_model, lora_target_modules, lora_rank, lora_checkpoint=lora_checkpoint,
|
||||||
self.pipe.freeze_except([] if trainable_models is None else trainable_models.split(","))
|
enable_fp8_training=False,
|
||||||
|
)
|
||||||
# Add LoRA to the base models
|
|
||||||
if lora_base_model is not None:
|
|
||||||
model = self.add_lora_to_model(
|
|
||||||
getattr(self.pipe, lora_base_model),
|
|
||||||
target_modules=lora_target_modules.split(","),
|
|
||||||
lora_rank=lora_rank
|
|
||||||
)
|
|
||||||
if lora_checkpoint is not None:
|
|
||||||
state_dict = load_state_dict(lora_checkpoint)
|
|
||||||
state_dict = self.mapping_lora_state_dict(state_dict)
|
|
||||||
load_result = model.load_state_dict(state_dict, strict=False)
|
|
||||||
print(f"LoRA checkpoint loaded: {lora_checkpoint}, total {len(state_dict)} keys")
|
|
||||||
if len(load_result[1]) > 0:
|
|
||||||
print(f"Warning, LoRA key mismatch! Unexpected keys in LoRA checkpoint: {load_result[1]}")
|
|
||||||
setattr(self.pipe, lora_base_model, model)
|
|
||||||
|
|
||||||
# Store other configs
|
# Store other configs
|
||||||
self.use_gradient_checkpointing = use_gradient_checkpointing
|
self.use_gradient_checkpointing = use_gradient_checkpointing
|
||||||
|
|||||||
Reference in New Issue
Block a user