diff --git a/diffsynth/diffusion/training_module.py b/diffsynth/diffusion/training_module.py
index cc2e79d..0a00118 100644
--- a/diffsynth/diffusion/training_module.py
+++ b/diffsynth/diffusion/training_module.py
@@ -1,9 +1,33 @@
-import torch, json, os
+import torch, json, os, inspect
 from ..core import ModelConfig, load_state_dict
 from ..utils.controlnet import ControlNetInput
+from .base_pipeline import PipelineUnit
 from peft import LoraConfig, inject_adapter_in_model
 
 
+class GeneralUnit_RemoveCache(PipelineUnit):
+    """Pipeline unit that drops cached inputs not listed in `required_params`, plus any explicitly force-removed ones."""
+    def __init__(self, required_params=tuple(), force_remove_params_shared=tuple(), force_remove_params_posi=tuple(), force_remove_params_nega=tuple()):
+        super().__init__(take_over=True)
+        self.required_params = required_params
+        self.force_remove_params_shared = force_remove_params_shared
+        self.force_remove_params_posi = force_remove_params_posi
+        self.force_remove_params_nega = force_remove_params_nega
+
+    def process_params(self, inputs, required_params, force_remove_params):
+        inputs_ = {}
+        for name, param in inputs.items():
+            if name in required_params and name not in force_remove_params:
+                inputs_[name] = param
+        return inputs_
+
+    def process(self, pipe, inputs_shared, inputs_posi, inputs_nega):
+        inputs_shared = self.process_params(inputs_shared, self.required_params, self.force_remove_params_shared)
+        inputs_posi = self.process_params(inputs_posi, self.required_params, self.force_remove_params_posi)
+        inputs_nega = self.process_params(inputs_nega, self.required_params, self.force_remove_params_nega)
+        return inputs_shared, inputs_posi, inputs_nega
+
+
 class DiffusionTrainingModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
@@ -231,14 +255,30 @@ class DiffusionTrainingModule(torch.nn.Module):
         setattr(pipe, lora_base_model, model)
 
-    def split_pipeline_units(self, task, pipe, trainable_models=None, lora_base_model=None):
+    def split_pipeline_units(
+        self, task, pipe,
+        trainable_models=None, lora_base_model=None,
+        # TODO: set `remove_unnecessary_params` to `True` by default
+        remove_unnecessary_params=False,
+        # TODO: move `loss_required_params` to `loss.py`
+        loss_required_params=("input_latents", "max_timestep_boundary", "min_timestep_boundary", "first_frame_latents", "video_latents", "audio_input_latents", "num_inference_steps"),
+        force_remove_params_shared=tuple(),
+        force_remove_params_posi=tuple(),
+        force_remove_params_nega=tuple(),
+    ):
         models_require_backward = []
         if trainable_models is not None:
             models_require_backward += trainable_models.split(",")
         if lora_base_model is not None:
             models_require_backward += [lora_base_model]
         if task.endswith(":data_process"):
-            _, pipe.units = pipe.split_pipeline_units(models_require_backward)
+            other_units, pipe.units = pipe.split_pipeline_units(models_require_backward)
+            if remove_unnecessary_params:
+                # Use the `pipe` parameter, not `self.pipe`: subclasses call this before assigning `self.pipe`.
+                required_params = list(loss_required_params) + list(inspect.signature(pipe.model_fn).parameters)
+                for unit in other_units:
+                    required_params.extend(unit.fetch_input_params())
+                required_params = sorted(set(required_params))
+                pipe.units.append(GeneralUnit_RemoveCache(required_params, force_remove_params_shared, force_remove_params_posi, force_remove_params_nega))
         elif task.endswith(":train"):
             pipe.units, _ = pipe.split_pipeline_units(models_require_backward)
         return pipe
 
diff --git a/examples/ltx2/model_training/train.py b/examples/ltx2/model_training/train.py
index 3eb023a..980d903 100644
--- a/examples/ltx2/model_training/train.py
+++ b/examples/ltx2/model_training/train.py
@@ -32,7 +32,12 @@ class LTX2TrainingModule(DiffusionTrainingModule):
         model_configs = self.parse_model_configs(model_paths, model_id_with_origin_paths, fp8_models=fp8_models, offload_models=offload_models, device=device)
         tokenizer_config = ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized") if tokenizer_path is None else ModelConfig(tokenizer_path)
         self.pipe = LTX2AudioVideoPipeline.from_pretrained(torch_dtype=torch.bfloat16, device=device, model_configs=model_configs, tokenizer_config=tokenizer_config)
-        self.pipe = self.split_pipeline_units(task, self.pipe, trainable_models, lora_base_model)
+        self.pipe = self.split_pipeline_units(
+            task, self.pipe, trainable_models, lora_base_model,
+            remove_unnecessary_params=True,
+            force_remove_params_shared=("audio_latents", "video_latents"),
+            force_remove_params_nega=("audio_context", "video_context")
+        )
         # Training mode
         self.switch_pipe_to_training_mode(
             self.pipe, trainable_models,