rebuild base modules

This commit is contained in:
Artiprocher
2024-07-26 12:15:40 +08:00
parent 9471bff8a4
commit e3f8a576cf
76 changed files with 3253 additions and 3563 deletions

View File

@@ -91,7 +91,7 @@ class SDXLUNet(torch.nn.Module):
**kwargs
):
# 1. time
t_emb = self.time_proj(timestep[None]).to(sample.dtype)
t_emb = self.time_proj(timestep).to(sample.dtype)
t_emb = self.time_embedding(t_emb)
time_embeds = self.add_time_proj(add_time_id)
@@ -133,7 +133,8 @@ class SDXLUNet(torch.nn.Module):
return hidden_states
def state_dict_converter(self):
@staticmethod
def state_dict_converter():
return SDXLUNetStateDictConverter()
@@ -197,7 +198,10 @@ class SDXLUNetStateDictConverter:
if ".proj_in." in name or ".proj_out." in name:
param = param.squeeze()
state_dict_[rename_dict[name]] = param
return state_dict_
if "text_intermediate_proj.weight" in state_dict_:
return state_dict_, {"is_kolors": True}
else:
return state_dict_
def from_civitai(self, state_dict):
rename_dict = {
@@ -1889,4 +1893,7 @@ class SDXLUNetStateDictConverter:
if ".proj_in." in name or ".proj_out." in name:
param = param.squeeze()
state_dict_[rename_dict[name]] = param
return state_dict_
if "text_intermediate_proj.weight" in state_dict_:
return state_dict_, {"is_kolors": True}
else:
return state_dict_