mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-23 17:38:10 +00:00
rebuild base modules
This commit is contained in:
@@ -91,7 +91,7 @@ class SDXLUNet(torch.nn.Module):
|
||||
**kwargs
|
||||
):
|
||||
# 1. time
|
||||
t_emb = self.time_proj(timestep[None]).to(sample.dtype)
|
||||
t_emb = self.time_proj(timestep).to(sample.dtype)
|
||||
t_emb = self.time_embedding(t_emb)
|
||||
|
||||
time_embeds = self.add_time_proj(add_time_id)
|
||||
@@ -133,7 +133,8 @@ class SDXLUNet(torch.nn.Module):
|
||||
|
||||
return hidden_states
|
||||
|
||||
def state_dict_converter(self):
|
||||
@staticmethod
|
||||
def state_dict_converter():
|
||||
return SDXLUNetStateDictConverter()
|
||||
|
||||
|
||||
@@ -197,7 +198,10 @@ class SDXLUNetStateDictConverter:
|
||||
if ".proj_in." in name or ".proj_out." in name:
|
||||
param = param.squeeze()
|
||||
state_dict_[rename_dict[name]] = param
|
||||
return state_dict_
|
||||
if "text_intermediate_proj.weight" in state_dict_:
|
||||
return state_dict_, {"is_kolors": True}
|
||||
else:
|
||||
return state_dict_
|
||||
|
||||
def from_civitai(self, state_dict):
|
||||
rename_dict = {
|
||||
@@ -1889,4 +1893,7 @@ class SDXLUNetStateDictConverter:
|
||||
if ".proj_in." in name or ".proj_out." in name:
|
||||
param = param.squeeze()
|
||||
state_dict_[rename_dict[name]] = param
|
||||
return state_dict_
|
||||
if "text_intermediate_proj.weight" in state_dict_:
|
||||
return state_dict_, {"is_kolors": True}
|
||||
else:
|
||||
return state_dict_
|
||||
|
||||
Reference in New Issue
Block a user