mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-18 22:08:13 +00:00
remove unnecessary params in cache
This commit is contained in:
@@ -1,9 +1,32 @@
|
||||
import torch, json, os
|
||||
import torch, json, os, inspect
|
||||
from ..core import ModelConfig, load_state_dict
|
||||
from ..utils.controlnet import ControlNetInput
|
||||
from .base_pipeline import PipelineUnit
|
||||
from peft import LoraConfig, inject_adapter_in_model
|
||||
|
||||
|
||||
class GeneralUnit_RemoveCache(PipelineUnit):
    """Pipeline unit that prunes cached parameters before later units run.

    A parameter survives only if its name is in ``required_params`` and is not
    listed in the corresponding ``force_remove_params_*`` tuple. Shared,
    positive-prompt and negative-prompt input dicts each get their own
    force-remove list.
    """

    def __init__(self, required_params=tuple(), force_remove_params_shared=tuple(), force_remove_params_posi=tuple(), force_remove_params_nega=tuple()):
        # take_over=True: this unit receives and returns the full input dicts.
        super().__init__(take_over=True)
        self.required_params = required_params
        self.force_remove_params_shared = force_remove_params_shared
        self.force_remove_params_posi = force_remove_params_posi
        self.force_remove_params_nega = force_remove_params_nega

    def process_params(self, inputs, required_params, force_remove_params):
        # Keep a param only when it is required and not explicitly blacklisted.
        return {
            name: value
            for name, value in inputs.items()
            if name in required_params and name not in force_remove_params
        }

    def process(self, pipe, inputs_shared, inputs_posi, inputs_nega):
        filtered_shared = self.process_params(inputs_shared, self.required_params, self.force_remove_params_shared)
        filtered_posi = self.process_params(inputs_posi, self.required_params, self.force_remove_params_posi)
        filtered_nega = self.process_params(inputs_nega, self.required_params, self.force_remove_params_nega)
        return filtered_shared, filtered_posi, filtered_nega
|
||||
|
||||
|
||||
class DiffusionTrainingModule(torch.nn.Module):
    """Base ``torch.nn.Module`` for diffusion training.

    NOTE(review): only the constructor is visible in this view; the rest of
    the class (pipeline setup, unit splitting, LoRA handling) continues
    further down the file.
    """
    def __init__(self):
        # No state of its own; subclasses build the pipeline in their __init__.
        super().__init__()
|
||||
@@ -231,14 +254,30 @@ class DiffusionTrainingModule(torch.nn.Module):
|
||||
setattr(pipe, lora_base_model, model)
|
||||
|
||||
|
||||
def split_pipeline_units(
    self, task, pipe,
    trainable_models=None, lora_base_model=None,
    # TODO: set `remove_unnecessary_params` to `True` by default
    remove_unnecessary_params=False,
    # TODO: move `loss_required_params` to `loss.py`
    loss_required_params=("input_latents", "max_timestep_boundary", "min_timestep_boundary", "first_frame_latents", "video_latents", "audio_input_latents", "num_inference_steps"),
    force_remove_params_shared=tuple(),
    force_remove_params_posi=tuple(),
    force_remove_params_nega=tuple(),
):
    """Split ``pipe``'s units for the given task and return ``pipe``.

    For ``*:data_process`` tasks the pipeline keeps the non-backward units and
    may append a ``GeneralUnit_RemoveCache`` that drops cached params not
    needed by the loss, the model fn, or any retained unit. For ``*:train``
    tasks the pipeline keeps the backward-requiring units. Other tasks leave
    ``pipe`` untouched.

    Args:
        task: task name; only the ``:data_process`` / ``:train`` suffix matters.
        pipe: pipeline object providing ``split_pipeline_units`` and ``model_fn``.
        trainable_models: comma-separated model names requiring gradients.
        lora_base_model: name of the model receiving LoRA adapters, if any.
        remove_unnecessary_params: append the cache-pruning unit when True.
        loss_required_params: param names the loss computation always needs.
        force_remove_params_*: names to drop from the shared / positive /
            negative input dicts regardless of the required set.
    """
    models_require_backward = []
    if trainable_models is not None:
        models_require_backward += trainable_models.split(",")
    if lora_base_model is not None:
        models_require_backward += [lora_base_model]
    if task.endswith(":data_process"):
        other_units, pipe.units = pipe.split_pipeline_units(models_require_backward)
        if remove_unnecessary_params:
            # Collect every param name consumed by the loss, the model fn,
            # or any of the units that will run during training.
            # Fix: use the explicit `pipe` argument, not `self.pipe` — the
            # original read `self.pipe.model_fn`, which only worked because
            # callers happened to assign `self.pipe` before calling here.
            required_params = list(loss_required_params) + [i for i in inspect.signature(pipe.model_fn).parameters]
            for unit in other_units:
                required_params.extend(unit.fetch_input_params())
            required_params = sorted(set(required_params))
            pipe.units.append(GeneralUnit_RemoveCache(required_params, force_remove_params_shared, force_remove_params_posi, force_remove_params_nega))
    elif task.endswith(":train"):
        pipe.units, _ = pipe.split_pipeline_units(models_require_backward)
    return pipe
|
||||
|
||||
@@ -32,7 +32,12 @@ class LTX2TrainingModule(DiffusionTrainingModule):
|
||||
model_configs = self.parse_model_configs(model_paths, model_id_with_origin_paths, fp8_models=fp8_models, offload_models=offload_models, device=device)
|
||||
tokenizer_config = ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized") if tokenizer_path is None else ModelConfig(tokenizer_path)
|
||||
self.pipe = LTX2AudioVideoPipeline.from_pretrained(torch_dtype=torch.bfloat16, device=device, model_configs=model_configs, tokenizer_config=tokenizer_config)
|
||||
self.pipe = self.split_pipeline_units(task, self.pipe, trainable_models, lora_base_model)
|
||||
self.pipe = self.split_pipeline_units(
|
||||
task, self.pipe, trainable_models, lora_base_model,
|
||||
remove_unnecessary_params=True,
|
||||
force_remove_params_shared=("audio_latents", "video_latents"),
|
||||
force_remove_params_nega=("audio_context", "video_context")
|
||||
)
|
||||
# Training mode
|
||||
self.switch_pipe_to_training_mode(
|
||||
self.pipe, trainable_models,
|
||||
|
||||
Reference in New Issue
Block a user