rebuild base modules

This commit is contained in:
Artiprocher
2024-07-26 12:15:40 +08:00
parent 9471bff8a4
commit e3f8a576cf
76 changed files with 3253 additions and 3563 deletions

View File

@@ -19,7 +19,8 @@ class SD3TextEncoder1(SDTextEncoder):
pooled_embeds = embeds[torch.arange(embeds.shape[0]), input_ids.to(dtype=torch.int).argmax(dim=-1)]
return pooled_embeds, hidden_states
def state_dict_converter(self):
@staticmethod
def state_dict_converter():
return SD3TextEncoder1StateDictConverter()
@@ -28,7 +29,8 @@ class SD3TextEncoder2(SDXLTextEncoder2):
def __init__(self):
super().__init__()
def state_dict_converter(self):
@staticmethod
def state_dict_converter():
return SD3TextEncoder2StateDictConverter()
@@ -72,7 +74,8 @@ class SD3TextEncoder3(T5EncoderModel):
prompt_emb = outputs.last_hidden_state
return prompt_emb
def state_dict_converter(self):
@staticmethod
def state_dict_converter():
return SD3TextEncoder3StateDictConverter()