mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-18 22:08:13 +00:00
add video_vae and dit for ltx-2
This commit is contained in:
9
diffsynth/utils/state_dict_converters/ltx2_dit.py
Normal file
9
diffsynth/utils/state_dict_converters/ltx2_dit.py
Normal file
@@ -0,0 +1,9 @@
|
||||
def LTXModelStateDictConverter(state_dict):
|
||||
state_dict_ = {}
|
||||
for name in state_dict:
|
||||
if name.startswith("model.diffusion_model."):
|
||||
new_name = name.replace("model.diffusion_model.", "")
|
||||
if new_name.startswith("audio_embeddings_connector.") or new_name.startswith("video_embeddings_connector."):
|
||||
continue
|
||||
state_dict_[new_name] = state_dict[name]
|
||||
return state_dict_
|
||||
22
diffsynth/utils/state_dict_converters/ltx2_video_vae.py
Normal file
22
diffsynth/utils/state_dict_converters/ltx2_video_vae.py
Normal file
@@ -0,0 +1,22 @@
|
||||
def LTX2VideoEncoderStateDictConverter(state_dict):
|
||||
state_dict_ = {}
|
||||
for name in state_dict:
|
||||
if name.startswith("vae.encoder."):
|
||||
new_name = name.replace("vae.encoder.", "")
|
||||
state_dict_[new_name] = state_dict[name]
|
||||
elif name.startswith("vae.per_channel_statistics."):
|
||||
new_name = name.replace("vae.per_channel_statistics.", "per_channel_statistics.")
|
||||
state_dict_[new_name] = state_dict[name]
|
||||
return state_dict_
|
||||
|
||||
|
||||
def LTX2VideoDecoderStateDictConverter(state_dict):
|
||||
state_dict_ = {}
|
||||
for name in state_dict:
|
||||
if name.startswith("vae.decoder."):
|
||||
new_name = name.replace("vae.decoder.", "")
|
||||
state_dict_[new_name] = state_dict[name]
|
||||
elif name.startswith("vae.per_channel_statistics."):
|
||||
new_name = name.replace("vae.per_channel_statistics.", "per_channel_statistics.")
|
||||
state_dict_[new_name] = state_dict[name]
|
||||
return state_dict_
|
||||
22
diffsynth/utils/test/load_model.py
Normal file
22
diffsynth/utils/test/load_model.py
Normal file
@@ -0,0 +1,22 @@
|
||||
import torch
|
||||
from diffsynth.models.model_loader import ModelPool
|
||||
from diffsynth.core.loader import ModelConfig
|
||||
|
||||
|
||||
def test_model_loading(model_name,
|
||||
model_config: ModelConfig,
|
||||
vram_limit: float = None,
|
||||
device="cpu",
|
||||
torch_dtype=torch.bfloat16):
|
||||
model_pool = ModelPool()
|
||||
model_config.download_if_necessary()
|
||||
vram_config = model_config.vram_config()
|
||||
vram_config["computation_dtype"] = torch_dtype
|
||||
vram_config["computation_device"] = device
|
||||
model_pool.auto_load_model(
|
||||
model_config.path,
|
||||
vram_config=vram_config,
|
||||
vram_limit=vram_limit,
|
||||
clear_parameters=model_config.clear_parameters,
|
||||
)
|
||||
return model_pool.fetch_model(model_name)
|
||||
Reference in New Issue
Block a user