Remove unnecessary params in cache

This commit is contained in:
Artiprocher
2026-03-09 14:09:30 +08:00
parent a38954b72c
commit 13eff18e7d
2 changed files with 48 additions and 4 deletions

View File

@@ -32,7 +32,12 @@ class LTX2TrainingModule(DiffusionTrainingModule):
model_configs = self.parse_model_configs(model_paths, model_id_with_origin_paths, fp8_models=fp8_models, offload_models=offload_models, device=device)
tokenizer_config = ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized") if tokenizer_path is None else ModelConfig(tokenizer_path)
self.pipe = LTX2AudioVideoPipeline.from_pretrained(torch_dtype=torch.bfloat16, device=device, model_configs=model_configs, tokenizer_config=tokenizer_config)
self.pipe = self.split_pipeline_units(task, self.pipe, trainable_models, lora_base_model)
self.pipe = self.split_pipeline_units(
task, self.pipe, trainable_models, lora_base_model,
remove_unnecessary_params=True,
force_remove_params_shared=("audio_latents", "video_latents"),
force_remove_params_nega=("audio_context", "video_context")
)
# Training mode
self.switch_pipe_to_training_mode(
self.pipe, trainable_models,