support HunyuanDiT

This commit is contained in:
Artiprocher
2024-06-05 13:00:39 +08:00
parent 83461d400c
commit 78f53a0754
23 changed files with 69988 additions and 185 deletions

View File

@@ -24,6 +24,9 @@ from .svd_vae_encoder import SVDVAEEncoder
from .sdxl_ipadapter import SDXLIpAdapter, IpAdapterCLIPImageEmbedder
from .hunyuan_dit_text_encoder import HunyuanDiTCLIPTextEncoder, HunyuanDiTT5TextEncoder
from .hunyuan_dit import HunyuanDiT
class ModelManager:
def __init__(self, torch_dtype=torch.float16, device="cuda"):
@@ -83,6 +86,22 @@ class ModelManager:
param_name = "vision_model.encoder.layers.47.self_attn.v_proj.weight"
return param_name in state_dict
def is_hunyuan_dit_clip_text_encoder(self, state_dict):
param_name = "bert.encoder.layer.23.attention.output.dense.weight"
return param_name in state_dict
def is_hunyuan_dit_t5_text_encoder(self, state_dict):
param_name = "encoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight"
return param_name in state_dict
def is_hunyuan_dit(self, state_dict):
param_name = "final_layer.adaLN_modulation.1.weight"
return param_name in state_dict
def is_diffusers_vae(self, state_dict):
param_name = "quant_conv.weight"
return param_name in state_dict
def load_stable_video_diffusion(self, state_dict, components=None, file_path=""):
component_dict = {
"image_encoder": SVDImageEncoder,
@@ -223,6 +242,45 @@ class ModelManager:
self.model[component] = model
self.model_path[component] = file_path
def load_hunyuan_dit_clip_text_encoder(self, state_dict, file_path=""):
component = "hunyuan_dit_clip_text_encoder"
model = HunyuanDiTCLIPTextEncoder()
model.load_state_dict(model.state_dict_converter().from_civitai(state_dict))
model.to(self.torch_dtype).to(self.device)
self.model[component] = model
self.model_path[component] = file_path
def load_hunyuan_dit_t5_text_encoder(self, state_dict, file_path=""):
component = "hunyuan_dit_t5_text_encoder"
model = HunyuanDiTT5TextEncoder()
model.load_state_dict(model.state_dict_converter().from_civitai(state_dict))
model.to(self.torch_dtype).to(self.device)
self.model[component] = model
self.model_path[component] = file_path
def load_hunyuan_dit(self, state_dict, file_path=""):
component = "hunyuan_dit"
model = HunyuanDiT()
model.load_state_dict(model.state_dict_converter().from_civitai(state_dict))
model.to(self.torch_dtype).to(self.device)
self.model[component] = model
self.model_path[component] = file_path
def load_diffusers_vae(self, state_dict, file_path=""):
# TODO: detect SD and SDXL
component = "vae_encoder"
model = SDXLVAEEncoder()
model.load_state_dict(model.state_dict_converter().from_diffusers(state_dict))
model.to(self.torch_dtype).to(self.device)
self.model[component] = model
self.model_path[component] = file_path
component = "vae_decoder"
model = SDXLVAEDecoder()
model.load_state_dict(model.state_dict_converter().from_diffusers(state_dict))
model.to(self.torch_dtype).to(self.device)
self.model[component] = model
self.model_path[component] = file_path
def search_for_embeddings(self, state_dict):
embeddings = []
for k in state_dict:
@@ -276,6 +334,14 @@ class ModelManager:
self.load_ipadapter_xl(state_dict, file_path=file_path)
elif self.is_ipadapter_xl_image_encoder(state_dict):
self.load_ipadapter_xl_image_encoder(state_dict, file_path=file_path)
elif self.is_hunyuan_dit_clip_text_encoder(state_dict):
self.load_hunyuan_dit_clip_text_encoder(state_dict, file_path=file_path)
elif self.is_hunyuan_dit_t5_text_encoder(state_dict):
self.load_hunyuan_dit_t5_text_encoder(state_dict, file_path=file_path)
elif self.is_hunyuan_dit(state_dict):
self.load_hunyuan_dit(state_dict, file_path=file_path)
elif self.is_diffusers_vae(state_dict):
self.load_diffusers_vae(state_dict, file_path=file_path)
def load_models(self, file_path_list, lora_alphas=[]):
for file_path in file_path_list: