add video_vae and dit for ltx-2

This commit is contained in:
mi804
2026-01-27 19:34:09 +08:00
parent ffb7a138f7
commit 00da4b6c4f
8 changed files with 3743 additions and 2 deletions

View File

@@ -0,0 +1,9 @@
def LTXModelStateDictConverter(state_dict):
state_dict_ = {}
for name in state_dict:
if name.startswith("model.diffusion_model."):
new_name = name.replace("model.diffusion_model.", "")
if new_name.startswith("audio_embeddings_connector.") or new_name.startswith("video_embeddings_connector."):
continue
state_dict_[new_name] = state_dict[name]
return state_dict_

View File

@@ -0,0 +1,22 @@
def LTX2VideoEncoderStateDictConverter(state_dict):
state_dict_ = {}
for name in state_dict:
if name.startswith("vae.encoder."):
new_name = name.replace("vae.encoder.", "")
state_dict_[new_name] = state_dict[name]
elif name.startswith("vae.per_channel_statistics."):
new_name = name.replace("vae.per_channel_statistics.", "per_channel_statistics.")
state_dict_[new_name] = state_dict[name]
return state_dict_
def LTX2VideoDecoderStateDictConverter(state_dict):
state_dict_ = {}
for name in state_dict:
if name.startswith("vae.decoder."):
new_name = name.replace("vae.decoder.", "")
state_dict_[new_name] = state_dict[name]
elif name.startswith("vae.per_channel_statistics."):
new_name = name.replace("vae.per_channel_statistics.", "per_channel_statistics.")
state_dict_[new_name] = state_dict[name]
return state_dict_

View File

@@ -0,0 +1,22 @@
import torch
from diffsynth.models.model_loader import ModelPool
from diffsynth.core.loader import ModelConfig
def test_model_loading(model_name,
model_config: ModelConfig,
vram_limit: float = None,
device="cpu",
torch_dtype=torch.bfloat16):
model_pool = ModelPool()
model_config.download_if_necessary()
vram_config = model_config.vram_config()
vram_config["computation_dtype"] = torch_dtype
vram_config["computation_device"] = device
model_pool.auto_load_model(
model_config.path,
vram_config=vram_config,
vram_limit=vram_limit,
clear_parameters=model_config.clear_parameters,
)
return model_pool.fetch_model(model_name)