mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-24 01:48:13 +00:00
Load HunyuanVideo I2V model
This commit is contained in:
@@ -4,6 +4,7 @@ from .utils import init_weights_on_device
|
||||
from einops import rearrange, repeat
|
||||
from tqdm import tqdm
|
||||
from typing import Union, Tuple, List
|
||||
from .utils import hash_state_dict_keys
|
||||
|
||||
|
||||
def HunyuanVideoRope(latents):
|
||||
@@ -555,7 +556,7 @@ class FinalLayer(torch.nn.Module):
|
||||
|
||||
|
||||
class HunyuanVideoDiT(torch.nn.Module):
|
||||
def __init__(self, in_channels=16, hidden_size=3072, text_dim=4096, num_double_blocks=20, num_single_blocks=40):
|
||||
def __init__(self, in_channels=16, hidden_size=3072, text_dim=4096, num_double_blocks=20, num_single_blocks=40, guidance_embed=True):
|
||||
super().__init__()
|
||||
self.img_in = PatchEmbed(in_channels=in_channels, embed_dim=hidden_size)
|
||||
self.txt_in = SingleTokenRefiner(in_channels=text_dim, hidden_size=hidden_size)
|
||||
@@ -565,7 +566,7 @@ class HunyuanVideoDiT(torch.nn.Module):
|
||||
torch.nn.SiLU(),
|
||||
torch.nn.Linear(hidden_size, hidden_size)
|
||||
)
|
||||
self.guidance_in = TimestepEmbeddings(256, hidden_size, computation_device="cpu")
|
||||
self.guidance_in = TimestepEmbeddings(256, hidden_size, computation_device="cpu") if guidance_embed else None
|
||||
self.double_blocks = torch.nn.ModuleList([MMDoubleStreamBlock(hidden_size) for _ in range(num_double_blocks)])
|
||||
self.single_blocks = torch.nn.ModuleList([MMSingleStreamBlock(hidden_size) for _ in range(num_single_blocks)])
|
||||
self.final_layer = FinalLayer(hidden_size)
|
||||
@@ -610,7 +611,9 @@ class HunyuanVideoDiT(torch.nn.Module):
|
||||
):
|
||||
B, C, T, H, W = x.shape
|
||||
|
||||
vec = self.time_in(t, dtype=torch.float32) + self.vector_in(pooled_prompt_emb) + self.guidance_in(guidance * 1000, dtype=torch.float32)
|
||||
vec = self.time_in(t, dtype=torch.float32) + self.vector_in(pooled_prompt_emb)
|
||||
if self.guidance_in is not None:
|
||||
vec += self.guidance_in(guidance * 1000, dtype=torch.float32)
|
||||
img = self.img_in(x)
|
||||
txt = self.txt_in(prompt_emb, t, text_mask)
|
||||
|
||||
@@ -783,6 +786,7 @@ class HunyuanVideoDiTStateDictConverter:
|
||||
pass
|
||||
|
||||
def from_civitai(self, state_dict):
|
||||
origin_hash_key = hash_state_dict_keys(state_dict, with_shape=True)
|
||||
if "module" in state_dict:
|
||||
state_dict = state_dict["module"]
|
||||
direct_dict = {
|
||||
@@ -882,4 +886,6 @@ class HunyuanVideoDiTStateDictConverter:
|
||||
state_dict_[name_] = param
|
||||
else:
|
||||
pass
|
||||
if origin_hash_key == "ae3c22aaa28bfae6f3688f796c9814ae":
|
||||
return state_dict_, {"in_channels": 33, "guidance_embed":False}
|
||||
return state_dict_
|
||||
|
||||
Reference in New Issue
Block a user