rebuild base modules

This commit is contained in:
Artiprocher
2024-07-26 12:15:40 +08:00
parent 9471bff8a4
commit e3f8a576cf
76 changed files with 3253 additions and 3563 deletions

View File

@@ -323,7 +323,7 @@ class SDUNet(torch.nn.Module):
def forward(self, sample, timestep, encoder_hidden_states, **kwargs):
# 1. time
time_emb = self.time_proj(timestep[None]).to(sample.dtype)
time_emb = self.time_proj(timestep).to(sample.dtype)
time_emb = self.time_embedding(time_emb)
# 2. pre-process
@@ -342,7 +342,8 @@ class SDUNet(torch.nn.Module):
return hidden_states
def state_dict_converter(self):
@staticmethod
def state_dict_converter():
return SDUNetStateDictConverter()