remove unnecessary params in cache

This commit is contained in:
Artiprocher
2026-03-09 14:09:30 +08:00
parent a38954b72c
commit 13eff18e7d
2 changed files with 48 additions and 4 deletions

View File

@@ -1,9 +1,32 @@
import torch, json, os
import torch, json, os, inspect
from ..core import ModelConfig, load_state_dict
from ..utils.controlnet import ControlNetInput
from .base_pipeline import PipelineUnit
from peft import LoraConfig, inject_adapter_in_model
class GeneralUnit_RemoveCache(PipelineUnit):
def __init__(self, required_params=tuple(), force_remove_params_shared=tuple(), force_remove_params_posi=tuple(), force_remove_params_nega=tuple()):
super().__init__(take_over=True)
self.required_params = required_params
self.force_remove_params_shared = force_remove_params_shared
self.force_remove_params_posi = force_remove_params_posi
self.force_remove_params_nega = force_remove_params_nega
def process_params(self, inputs, required_params, force_remove_params):
inputs_ = {}
for name, param in inputs.items():
if name in required_params and name not in force_remove_params:
inputs_[name] = param
return inputs_
def process(self, pipe, inputs_shared, inputs_posi, inputs_nega):
inputs_shared = self.process_params(inputs_shared, self.required_params, self.force_remove_params_shared)
inputs_posi = self.process_params(inputs_posi, self.required_params, self.force_remove_params_posi)
inputs_nega = self.process_params(inputs_nega, self.required_params, self.force_remove_params_nega)
return inputs_shared, inputs_posi, inputs_nega
class DiffusionTrainingModule(torch.nn.Module):
def __init__(self):
super().__init__()
@@ -231,14 +254,30 @@ class DiffusionTrainingModule(torch.nn.Module):
setattr(pipe, lora_base_model, model)
def split_pipeline_units(self, task, pipe, trainable_models=None, lora_base_model=None):
def split_pipeline_units(
self, task, pipe,
trainable_models=None, lora_base_model=None,
# TODO: set `remove_unnecessary_params` to `True` by default
remove_unnecessary_params=False,
# TODO: move `loss_required_params` to `loss.py`
loss_required_params=("input_latents", "max_timestep_boundary", "min_timestep_boundary", "first_frame_latents", "video_latents", "audio_input_latents", "num_inference_steps"),
force_remove_params_shared=tuple(),
force_remove_params_posi=tuple(),
force_remove_params_nega=tuple(),
):
models_require_backward = []
if trainable_models is not None:
models_require_backward += trainable_models.split(",")
if lora_base_model is not None:
models_require_backward += [lora_base_model]
if task.endswith(":data_process"):
_, pipe.units = pipe.split_pipeline_units(models_require_backward)
other_units, pipe.units = pipe.split_pipeline_units(models_require_backward)
if remove_unnecessary_params:
required_params = list(loss_required_params) + [i for i in inspect.signature(self.pipe.model_fn).parameters]
for unit in other_units:
required_params.extend(unit.fetch_input_params())
required_params = sorted(list(set(required_params)))
pipe.units.append(GeneralUnit_RemoveCache(required_params, force_remove_params_shared, force_remove_params_posi, force_remove_params_nega))
elif task.endswith(":train"):
pipe.units, _ = pipe.split_pipeline_units(models_require_backward)
return pipe

View File

@@ -32,7 +32,12 @@ class LTX2TrainingModule(DiffusionTrainingModule):
model_configs = self.parse_model_configs(model_paths, model_id_with_origin_paths, fp8_models=fp8_models, offload_models=offload_models, device=device)
tokenizer_config = ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized") if tokenizer_path is None else ModelConfig(tokenizer_path)
self.pipe = LTX2AudioVideoPipeline.from_pretrained(torch_dtype=torch.bfloat16, device=device, model_configs=model_configs, tokenizer_config=tokenizer_config)
self.pipe = self.split_pipeline_units(task, self.pipe, trainable_models, lora_base_model)
self.pipe = self.split_pipeline_units(
task, self.pipe, trainable_models, lora_base_model,
remove_unnecessary_params=True,
force_remove_params_shared=("audio_latents", "video_latents"),
force_remove_params_nega=("audio_context", "video_context")
)
# Training mode
self.switch_pipe_to_training_mode(
self.pipe, trainable_models,