Files
DiffSynth-Studio/diffsynth/diffusion/training_module.py
2026-03-17 13:34:25 +08:00

348 lines
16 KiB
Python

import torch, json, os, inspect
from ..core import ModelConfig, load_state_dict
from ..utils.controlnet import ControlNetInput
from .base_pipeline import PipelineUnit
from peft import LoraConfig, inject_adapter_in_model
class GeneralUnit_RemoveCache(PipelineUnit):
    """Pipeline unit that filters the three input dicts down to an allow-list.

    Only used for training.
    """

    def __init__(self, required_params=tuple(), force_remove_params_shared=tuple(), force_remove_params_posi=tuple(), force_remove_params_nega=tuple()):
        """Store the allow-list and the per-dict deny-lists.

        Args:
            required_params: Names that may be kept in any of the input dicts.
            force_remove_params_shared: Names always dropped from the shared inputs.
            force_remove_params_posi: Names always dropped from the positive inputs.
            force_remove_params_nega: Names always dropped from the negative inputs.
        """
        super().__init__(take_over=True)
        self.required_params = required_params
        self.force_remove_params_shared = force_remove_params_shared
        self.force_remove_params_posi = force_remove_params_posi
        self.force_remove_params_nega = force_remove_params_nega

    def process_params(self, inputs, required_params, force_remove_params):
        """Return a new dict with the entries of ``inputs`` that are required and not force-removed."""
        return {
            key: value
            for key, value in inputs.items()
            if key in required_params and key not in force_remove_params
        }

    def process(self, pipe, inputs_shared, inputs_posi, inputs_nega):
        """Filter the shared, positive and negative inputs and return them as a 3-tuple."""
        groups = (
            (inputs_shared, self.force_remove_params_shared),
            (inputs_posi, self.force_remove_params_posi),
            (inputs_nega, self.force_remove_params_nega),
        )
        return tuple(
            self.process_params(group, self.required_params, denied)
            for group, denied in groups
        )
class GeneralUnit_SkillProcessInputs(PipelineUnit):
    """Pipeline unit that prepares ``skill_inputs`` for the pipeline's skill model.

    Only used for training.
    """

    def __init__(self, data_processor):
        """Keep an optional callable applied to ``skill_inputs`` before the skill model sees them."""
        super().__init__(
            input_params=("skill_inputs",),
            output_params=("skill_inputs",),
        )
        self.data_processor = data_processor

    def process(self, pipe, skill_inputs):
        """Return ``{"skill_inputs": ...}`` after processing, or ``{}`` if ``pipe`` has no ``skill_model``."""
        if hasattr(pipe, "skill_model"):
            prepared = skill_inputs
            # Optional user-supplied preprocessing; called with the inputs as keyword arguments.
            if self.data_processor is not None:
                prepared = self.data_processor(**prepared)
            prepared = pipe.skill_model.process_inputs(pipe=pipe, **prepared)
            return {"skill_inputs": prepared}
        return {}
class GeneralUnit_SkillForward(PipelineUnit):
    """Pipeline unit that runs the skill model's forward pass on ``skill_inputs``.

    Only used for training.
    """

    def __init__(self):
        """Declare ``skill_inputs`` as input, ``skill_cache`` as output and ``skill_model`` as the model to onload."""
        super().__init__(
            input_params=("skill_inputs",),
            output_params=("skill_cache",),
            onload_model_names=("skill_model",)
        )

    def process(self, pipe, skill_inputs):
        """Return ``{"skill_cache": ...}`` from the skill model, or ``{}`` if ``pipe`` has no ``skill_model``."""
        if hasattr(pipe, "skill_model"):
            # The inputs dict is expanded into keyword arguments of ``forward``.
            return {"skill_cache": pipe.skill_model.forward(**skill_inputs)}
        return {}
class DiffusionTrainingModule(torch.nn.Module):
    """Base ``torch.nn.Module`` bundling helpers used when training a diffusion pipeline.

    It provides LoRA injection and checkpoint key mapping, model-config parsing,
    device transfer of nested data, and reconfiguration of a pipeline's units
    for the data-processing and training stages.
    """

    def __init__(self):
        super().__init__()

    def to(self, *args, **kwargs):
        """Forward ``.to(*args, **kwargs)`` to every direct child module and return ``self``."""
        for name, model in self.named_children():
            model.to(*args, **kwargs)
        return self

    def trainable_modules(self):
        """Return a lazy iterator over the parameters whose ``requires_grad`` is true."""
        trainable_modules = filter(lambda p: p.requires_grad, self.parameters())
        return trainable_modules

    def trainable_param_names(self):
        """Return the set of names of parameters whose ``requires_grad`` is true."""
        return {name for name, param in self.named_parameters() if param.requires_grad}

    def add_lora_to_model(self, model, target_modules, lora_rank, lora_alpha=None, upcast_dtype=None):
        """Inject a LoRA adapter into ``model`` and return the patched model.

        Args:
            model: Module to patch.
            target_modules: Module name(s) handed to ``LoraConfig``. A
                one-element list is unwrapped to its single element.
            lora_rank: LoRA rank ``r``.
            lora_alpha: LoRA alpha; defaults to ``lora_rank`` when ``None``.
            upcast_dtype: If given, every parameter that requires grad after
                injection is converted to this dtype.
        """
        if lora_alpha is None:
            lora_alpha = lora_rank
        if isinstance(target_modules, list) and len(target_modules) == 1:
            target_modules = target_modules[0]
        lora_config = LoraConfig(r=lora_rank, lora_alpha=lora_alpha, target_modules=target_modules)
        model = inject_adapter_in_model(lora_config, model)
        if upcast_dtype is not None:
            for param in model.parameters():
                if param.requires_grad:
                    param.data = param.to(upcast_dtype)
        return model

    def mapping_lora_state_dict(self, state_dict):
        """Return only the LoRA A/B weights of ``state_dict``, keyed in ``.default.`` adapter form.

        Keys containing ``lora_A.weight`` / ``lora_B.weight`` are renamed to
        ``lora_A.default.weight`` / ``lora_B.default.weight``; keys already in
        that form are kept as-is; every other key is dropped.
        """
        new_state_dict = {}
        for key, value in state_dict.items():
            if "lora_A.weight" in key or "lora_B.weight" in key:
                new_key = key.replace("lora_A.weight", "lora_A.default.weight").replace("lora_B.weight", "lora_B.default.weight")
                new_state_dict[new_key] = value
            elif "lora_A.default.weight" in key or "lora_B.default.weight" in key:
                new_state_dict[key] = value
        return new_state_dict

    def export_trainable_state_dict(self, state_dict, remove_prefix=None):
        """Keep only the trainable entries of ``state_dict``, optionally stripping a key prefix.

        Args:
            state_dict: Mapping from parameter name to value.
            remove_prefix: If given, this prefix is removed from keys that
                start with it; other keys are kept unchanged.
        """
        trainable_param_names = self.trainable_param_names()
        state_dict = {name: param for name, param in state_dict.items() if name in trainable_param_names}
        if remove_prefix is not None:
            state_dict_ = {}
            for name, param in state_dict.items():
                if name.startswith(remove_prefix):
                    name = name[len(remove_prefix):]
                state_dict_[name] = param
            state_dict = state_dict_
        return state_dict

    def transfer_data_to_device(self, data, device, torch_float_dtype=None):
        """Recursively move tensors inside ``data`` to ``device``.

        Tuples, lists and dicts are rebuilt with their elements transferred.
        Tensors of dtype float32/float16/bfloat16 are additionally cast to
        ``torch_float_dtype`` when it is given. Anything else (including
        ``None``) is returned unchanged.
        """
        if data is None:
            return data
        if isinstance(data, torch.Tensor):
            data = data.to(device)
            if torch_float_dtype is not None and data.dtype in [torch.float, torch.float16, torch.bfloat16]:
                data = data.to(torch_float_dtype)
            return data
        if isinstance(data, tuple):
            return tuple(self.transfer_data_to_device(x, device, torch_float_dtype) for x in data)
        if isinstance(data, list):
            return [self.transfer_data_to_device(x, device, torch_float_dtype) for x in data]
        if isinstance(data, dict):
            return {i: self.transfer_data_to_device(data[i], device, torch_float_dtype) for i in data}
        return data

    def parse_vram_config(self, fp8=False, offload=False, device="cpu"):
        """Build the VRAM-management keyword arguments for a ``ModelConfig``.

        ``fp8`` takes precedence over ``offload``; with neither set an empty
        dict is returned.
        """
        if fp8:
            return {
                "offload_dtype": torch.float8_e4m3fn,
                "offload_device": device,
                "onload_dtype": torch.float8_e4m3fn,
                "onload_device": device,
                "preparing_dtype": torch.float8_e4m3fn,
                "preparing_device": device,
                "computation_dtype": torch.bfloat16,
                "computation_device": device,
            }
        elif offload:
            return {
                "offload_dtype": "disk",
                "offload_device": "disk",
                "onload_dtype": "disk",
                "onload_device": "disk",
                "preparing_dtype": torch.bfloat16,
                "preparing_device": device,
                "computation_dtype": torch.bfloat16,
                "computation_device": device,
                "clear_parameters": True,
            }
        else:
            return {}

    def parse_model_configs(self, model_paths, model_id_with_origin_paths, fp8_models=None, offload_models=None, device="cpu"):
        """Turn command-line style model specifications into a list of ``ModelConfig``.

        Args:
            model_paths: JSON-encoded list of paths, or ``None``.
            model_id_with_origin_paths: Comma-separated entries parsed by
                ``parse_path_or_model_id``, or ``None``.
            fp8_models: Comma-separated entries (matched verbatim against the
                entries above) that get the fp8 VRAM config.
            offload_models: Comma-separated entries that get the offload VRAM config.
            device: Device placed into the VRAM config.
        """
        fp8_models = [] if fp8_models is None else fp8_models.split(",")
        offload_models = [] if offload_models is None else offload_models.split(",")
        model_configs = []
        if model_paths is not None:
            model_paths = json.loads(model_paths)
            for path in model_paths:
                vram_config = self.parse_vram_config(
                    fp8=path in fp8_models,
                    offload=path in offload_models,
                    device=device
                )
                model_configs.append(ModelConfig(path=path, **vram_config))
        if model_id_with_origin_paths is not None:
            model_id_with_origin_paths = model_id_with_origin_paths.split(",")
            for model_id_with_origin_path in model_id_with_origin_paths:
                vram_config = self.parse_vram_config(
                    fp8=model_id_with_origin_path in fp8_models,
                    offload=model_id_with_origin_path in offload_models,
                    device=device
                )
                config = self.parse_path_or_model_id(model_id_with_origin_path)
                # NOTE(review): if the entry is an existing local path, `config` was
                # built from `path` only, so the path is presumably lost here because
                # only model_id/origin_file_pattern are copied - confirm against ModelConfig.
                model_configs.append(ModelConfig(model_id=config.model_id, origin_file_pattern=config.origin_file_pattern, **vram_config))
        return model_configs

    def parse_path_or_model_id(self, model_id_with_origin_path, default_value=None):
        """Parse a local path or a ``model_id:origin_file_pattern`` string into a ``ModelConfig``.

        Args:
            model_id_with_origin_path: Existing filesystem path, or a string
                whose last ``:`` separates the model id from the file pattern.
            default_value: Returned when the first argument is ``None``.

        Raises:
            ValueError: If the string is not an existing path and contains no ``:``.
        """
        if model_id_with_origin_path is None:
            return default_value
        if os.path.exists(model_id_with_origin_path):
            return ModelConfig(path=model_id_with_origin_path)
        if ":" not in model_id_with_origin_path:
            # The separator is ":", so the message names that format.
            raise ValueError(f"Failed to parse model config: {model_id_with_origin_path}. This is neither a valid path nor in the format of `model_id:origin_file_pattern`.")
        # Split on the last ":" so that the model id itself may contain colons.
        model_id, _, origin_file_pattern = model_id_with_origin_path.rpartition(":")
        return ModelConfig(model_id=model_id, origin_file_pattern=origin_file_pattern)

    @staticmethod
    def _join_module_name(name_prefix, name):
        """Join two dotted-name fragments, skipping empty ones so no stray dot is produced."""
        return ".".join(part for part in (name_prefix, name) if part)

    def auto_detect_lora_target_modules(
        self,
        model: torch.nn.Module,
        search_for_linear=False,
        linear_detector=lambda x: min(x.weight.shape) >= 512,
        block_list_detector=lambda x: isinstance(x, torch.nn.ModuleList) and len(x) > 1,
        name_prefix="",
    ):
        """Collect qualified names of ``torch.nn.Linear`` layers to patch with LoRA.

        The module tree is walked child by child. Once a child satisfies
        ``block_list_detector``, every ``Linear`` inside it (at any depth) that
        satisfies ``linear_detector`` is reported.

        Args:
            model: Root module to inspect.
            search_for_linear: When true, scan all submodules of ``model``
                directly instead of descending child by child.
            linear_detector: Predicate selecting which ``Linear`` layers qualify.
            block_list_detector: Predicate marking a module as a block list.
            name_prefix: Dotted name of ``model`` relative to the original root.

        Returns:
            List of dotted module names.
        """
        lora_target_modules = []
        if search_for_linear:
            for name, module in model.named_modules():
                # named_modules() also yields the root with an empty name; joining
                # only non-empty parts avoids reporting "prefix." for it.
                module_name = self._join_module_name(name_prefix, name)
                if isinstance(module, torch.nn.Linear) and linear_detector(module):
                    lora_target_modules.append(module_name)
        else:
            for name, module in model.named_children():
                module_name = self._join_module_name(name_prefix, name)
                lora_target_modules += self.auto_detect_lora_target_modules(
                    module,
                    search_for_linear=block_list_detector(module),
                    linear_detector=linear_detector,
                    block_list_detector=block_list_detector,
                    name_prefix=module_name,
                )
        return lora_target_modules

    def parse_lora_target_modules(self, model, lora_target_modules):
        """Return the LoRA target module names as a list.

        An empty string triggers auto-detection on ``model``; otherwise the
        string is split on commas.
        """
        if lora_target_modules == "":
            print("No LoRA target modules specified. The framework will automatically search for them.")
            lora_target_modules = self.auto_detect_lora_target_modules(model)
            print(f"LoRA will be patched at {lora_target_modules}.")
        else:
            lora_target_modules = lora_target_modules.split(",")
        return lora_target_modules

    def load_training_skill_model(self, pipe, path_or_model_id):
        """Load a skill model into ``pipe`` and append the two skill units; no-op when the id is ``None``."""
        if path_or_model_id is None:
            return pipe
        model_config = self.parse_path_or_model_id(path_or_model_id)
        pipe.load_training_skill_model(model_config)
        pipe.units.append(GeneralUnit_SkillProcessInputs(pipe.skill_data_processor))
        pipe.units.append(GeneralUnit_SkillForward())
        return pipe

    def switch_pipe_to_training_mode(
        self,
        pipe,
        trainable_models=None,
        lora_base_model=None, lora_target_modules="", lora_rank=32, lora_checkpoint=None,
        preset_lora_path=None, preset_lora_model=None,
        task="sft",
    ):
        """Configure ``pipe`` in place for training.

        Sets the scheduler to training timesteps, freezes every model except
        those named in ``trainable_models`` (comma-separated), optionally loads
        a preset LoRA, and optionally injects a fresh LoRA into
        ``lora_base_model`` (skipped for ``*:data_process`` tasks), loading
        ``lora_checkpoint`` into it when given. Returns ``None``.
        """
        # Scheduler
        pipe.scheduler.set_timesteps(1000, training=True)

        # Freeze untrainable models
        pipe.freeze_except([] if trainable_models is None else trainable_models.split(","))

        # Preset LoRA
        if preset_lora_path is not None:
            pipe.load_lora(getattr(pipe, preset_lora_model), preset_lora_path)

        # FP8
        # FP8 relies on a model-specific memory management scheme.
        # It is delegated to the subclass.

        # Add LoRA to the base models
        if lora_base_model is not None and not task.endswith(":data_process"):
            if (not hasattr(pipe, lora_base_model)) or getattr(pipe, lora_base_model) is None:
                print(f"No {lora_base_model} models in the pipeline. We cannot patch LoRA on the model. If this occurs during the data processing stage, it is normal.")
                return
            model = self.add_lora_to_model(
                getattr(pipe, lora_base_model),
                target_modules=self.parse_lora_target_modules(getattr(pipe, lora_base_model), lora_target_modules),
                lora_rank=lora_rank,
                upcast_dtype=pipe.torch_dtype,
            )
            if lora_checkpoint is not None:
                state_dict = load_state_dict(lora_checkpoint)
                state_dict = self.mapping_lora_state_dict(state_dict)
                # strict=False: the checkpoint only carries LoRA weights.
                load_result = model.load_state_dict(state_dict, strict=False)
                print(f"LoRA checkpoint loaded: {lora_checkpoint}, total {len(state_dict)} keys")
                # Index 1 of the load result holds the unexpected keys.
                if len(load_result[1]) > 0:
                    print(f"Warning, LoRA key mismatch! Unexpected keys in LoRA checkpoint: {load_result[1]}")
            setattr(pipe, lora_base_model, model)

    def split_pipeline_units(
        self, task, pipe,
        trainable_models=None, lora_base_model=None,
        # TODO: set `remove_unnecessary_params` to `True` by default
        remove_unnecessary_params=False,
        # TODO: move `loss_required_params` to `loss.py`
        loss_required_params=("input_latents", "max_timestep_boundary", "min_timestep_boundary", "first_frame_latents", "video_latents", "audio_input_latents", "num_inference_steps"),
        force_remove_params_shared=tuple(),
        force_remove_params_posi=tuple(),
        force_remove_params_nega=tuple(),
    ):
        """Restrict ``pipe.units`` to the part needed by the given task stage.

        ``pipe.split_pipeline_units`` is called with the names of the models
        that require backward (trainable models plus the LoRA base model) and
        returns two unit lists. For ``*:data_process`` tasks the second list is
        kept, for ``*:train`` tasks the first; other tasks leave the pipeline
        untouched.

        With ``remove_unnecessary_params`` in the data-process stage, a
        ``GeneralUnit_RemoveCache`` is appended that keeps only
        ``loss_required_params``, the parameters of ``pipe.model_fn`` and the
        input parameters of the units that were split off.

        Returns:
            The same ``pipe`` object.
        """
        models_require_backward = []
        if trainable_models is not None:
            models_require_backward += trainable_models.split(",")
        if lora_base_model is not None:
            models_require_backward += [lora_base_model]
        if task.endswith(":data_process"):
            other_units, pipe.units = pipe.split_pipeline_units(models_require_backward)
            if remove_unnecessary_params:
                # Inspect the pipeline that was passed in, not an attribute that may not exist on self.
                required_params = list(loss_required_params) + list(inspect.signature(pipe.model_fn).parameters)
                for unit in other_units:
                    required_params.extend(unit.fetch_input_params())
                required_params = sorted(set(required_params))
                pipe.units.append(GeneralUnit_RemoveCache(required_params, force_remove_params_shared, force_remove_params_posi, force_remove_params_nega))
        elif task.endswith(":train"):
            pipe.units, _ = pipe.split_pipeline_units(models_require_backward)
        return pipe

    def parse_extra_inputs(self, data, extra_inputs, inputs_shared):
        """Copy the entries named in ``extra_inputs`` from ``data`` into ``inputs_shared``.

        Names starting with ``blockwise_controlnet_`` or ``controlnet_`` are
        grouped per family, with the prefix stripped, into one
        ``ControlNetInput`` stored as a one-element list under
        ``blockwise_controlnet_inputs`` / ``controlnet_inputs``. All other
        names are copied over unchanged.

        Returns:
            The (mutated) ``inputs_shared`` dict.
        """
        controlnet_keys_map = (
            ("blockwise_controlnet_", "blockwise_controlnet_inputs",),
            ("controlnet_", "controlnet_inputs"),
        )
        controlnet_inputs = {}
        for extra_input in extra_inputs:
            for prefix, name in controlnet_keys_map:
                if extra_input.startswith(prefix):
                    # Strip only the leading prefix; later occurrences belong to the field name.
                    field_name = extra_input[len(prefix):]
                    controlnet_inputs.setdefault(name, {})[field_name] = data[extra_input]
                    break
            else:
                # No controlnet prefix matched: plain shared input.
                inputs_shared[extra_input] = data[extra_input]
        for name, params in controlnet_inputs.items():
            inputs_shared[name] = [ControlNetInput(**params)]
        return inputs_shared