Mirror of https://github.com/modelscope/DiffSynth-Studio.git (synced 2026-03-18 22:08:13 +00:00)
remove unnecessary params in cache
@@ -1,9 +1,32 @@
-import torch, json, os
+import torch, json, os, inspect
 from ..core import ModelConfig, load_state_dict
 from ..utils.controlnet import ControlNetInput
+from .base_pipeline import PipelineUnit
 from peft import LoraConfig, inject_adapter_in_model
 
 
+class GeneralUnit_RemoveCache(PipelineUnit):
+    def __init__(self, required_params=tuple(), force_remove_params_shared=tuple(), force_remove_params_posi=tuple(), force_remove_params_nega=tuple()):
+        super().__init__(take_over=True)
+        self.required_params = required_params
+        self.force_remove_params_shared = force_remove_params_shared
+        self.force_remove_params_posi = force_remove_params_posi
+        self.force_remove_params_nega = force_remove_params_nega
+
+    def process_params(self, inputs, required_params, force_remove_params):
+        inputs_ = {}
+        for name, param in inputs.items():
+            if name in required_params and name not in force_remove_params:
+                inputs_[name] = param
+        return inputs_
+
+    def process(self, pipe, inputs_shared, inputs_posi, inputs_nega):
+        inputs_shared = self.process_params(inputs_shared, self.required_params, self.force_remove_params_shared)
+        inputs_posi = self.process_params(inputs_posi, self.required_params, self.force_remove_params_posi)
+        inputs_nega = self.process_params(inputs_nega, self.required_params, self.force_remove_params_nega)
+        return inputs_shared, inputs_posi, inputs_nega
+
+
 class DiffusionTrainingModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
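The unit added above is appended to the data-processing pipeline and, for each of the shared, positive and negative input dicts, keeps only the entries whose names appear in required_params and are not listed in the corresponding force_remove_params_* tuple. A minimal standalone sketch of that keep/drop rule, with made-up parameter names:

# Sketch of the rule implemented by GeneralUnit_RemoveCache.process_params;
# the parameter names below are illustrative, not taken from the pipeline.
required_params = ("input_latents", "num_inference_steps")
force_remove_params = ("input_latents",)

cached_inputs = {"input_latents": "tensor", "num_inference_steps": 50, "tiled": True}
filtered = {
    name: value
    for name, value in cached_inputs.items()
    if name in required_params and name not in force_remove_params
}
print(filtered)  # {'num_inference_steps': 50}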
@@ -231,14 +254,30 @@ class DiffusionTrainingModule(torch.nn.Module):
         setattr(pipe, lora_base_model, model)
 
 
-    def split_pipeline_units(self, task, pipe, trainable_models=None, lora_base_model=None):
+    def split_pipeline_units(
+        self, task, pipe,
+        trainable_models=None, lora_base_model=None,
+        # TODO: set `remove_unnecessary_params` to `True` by default
+        remove_unnecessary_params=False,
+        # TODO: move `loss_required_params` to `loss.py`
+        loss_required_params=("input_latents", "max_timestep_boundary", "min_timestep_boundary", "first_frame_latents", "video_latents", "audio_input_latents", "num_inference_steps"),
+        force_remove_params_shared=tuple(),
+        force_remove_params_posi=tuple(),
+        force_remove_params_nega=tuple(),
+    ):
         models_require_backward = []
         if trainable_models is not None:
             models_require_backward += trainable_models.split(",")
         if lora_base_model is not None:
             models_require_backward += [lora_base_model]
         if task.endswith(":data_process"):
-            _, pipe.units = pipe.split_pipeline_units(models_require_backward)
+            other_units, pipe.units = pipe.split_pipeline_units(models_require_backward)
+            if remove_unnecessary_params:
+                required_params = list(loss_required_params) + [i for i in inspect.signature(self.pipe.model_fn).parameters]
+                for unit in other_units:
+                    required_params.extend(unit.fetch_input_params())
+                required_params = sorted(list(set(required_params)))
+                pipe.units.append(GeneralUnit_RemoveCache(required_params, force_remove_params_shared, force_remove_params_posi, force_remove_params_nega))
         elif task.endswith(":train"):
             pipe.units, _ = pipe.split_pipeline_units(models_require_backward)
         return pipe
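When remove_unnecessary_params is enabled, the names to keep are collected from three sources: loss_required_params, the argument names of the pipeline's model_fn, and whatever the split-off units report via fetch_input_params(); the union is deduplicated and sorted before GeneralUnit_RemoveCache is appended. A rough standalone illustration of that union, using stub objects in place of the real pipeline (the stubs and parameter names are assumptions, not DiffSynth-Studio APIs):

import inspect

# Stand-in for the pipeline's model_fn; only its argument names matter here.
def model_fn(latents, timestep, prompt_emb=None):
    ...

# Stand-in for a pipeline unit that reports which cached inputs it reads.
class StubUnit:
    def __init__(self, params):
        self.params = params
    def fetch_input_params(self):
        return self.params

loss_required_params = ("input_latents", "num_inference_steps")
other_units = [StubUnit(["prompt_emb"]), StubUnit(["latents", "timestep"])]

required_params = list(loss_required_params) + [p for p in inspect.signature(model_fn).parameters]
for unit in other_units:
    required_params.extend(unit.fetch_input_params())
required_params = sorted(set(required_params))
print(required_params)
# ['input_latents', 'latents', 'num_inference_steps', 'prompt_emb', 'timestep']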
@@ -32,7 +32,12 @@ class LTX2TrainingModule(DiffusionTrainingModule):
         model_configs = self.parse_model_configs(model_paths, model_id_with_origin_paths, fp8_models=fp8_models, offload_models=offload_models, device=device)
         tokenizer_config = ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized") if tokenizer_path is None else ModelConfig(tokenizer_path)
         self.pipe = LTX2AudioVideoPipeline.from_pretrained(torch_dtype=torch.bfloat16, device=device, model_configs=model_configs, tokenizer_config=tokenizer_config)
-        self.pipe = self.split_pipeline_units(task, self.pipe, trainable_models, lora_base_model)
+        self.pipe = self.split_pipeline_units(
+            task, self.pipe, trainable_models, lora_base_model,
+            remove_unnecessary_params=True,
+            force_remove_params_shared=("audio_latents", "video_latents"),
+            force_remove_params_nega=("audio_context", "video_context")
+        )
         # Training mode
         self.switch_pipe_to_training_mode(
             self.pipe, trainable_models,
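Note on the LTX2 call above: the force_remove_params_* tuples take precedence over the required-parameter filter, so a name listed there is dropped from its branch even if it appears in loss_required_params or in the model_fn signature. Here that is used to drop the raw audio/video latents from the shared inputs and the audio/video contexts from the negative inputs during the data-processing task; other DiffusionTrainingModule subclasses can opt in the same way by passing remove_unnecessary_params=True and their own force-remove tuples.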