Mirror of https://github.com/modelscope/DiffSynth-Studio.git (last synced 2026-03-19 23:08:13 +00:00).
Commit subject: "fix base model".
This commit is contained in:
@@ -9,6 +9,7 @@ class FluxReferenceEmbedder(torch.nn.Module):
|
||||
super().__init__()
|
||||
self.pos_embedder = RoPEEmbedding(3072, 10000, [16, 56, 56])
|
||||
self.idx_embedder = TimestepEmbeddings(256, 256)
|
||||
self.proj = torch.nn.Linear(3072, 3072)
|
||||
|
||||
def forward(self, image_ids, idx, dtype, device):
|
||||
pos_emb = self.pos_embedder(image_ids, device=device)
|
||||
@@ -18,3 +19,13 @@ class FluxReferenceEmbedder(torch.nn.Module):
|
||||
idx_emb = repeat(idx_emb, "B (C H W) -> 1 1 (B L) C H W", C=64, H=2, W=2, L=length)
|
||||
image_rotary_emb = pos_emb + idx_emb
|
||||
return image_rotary_emb
|
||||
|
||||
def init(self):
|
||||
self.idx_embedder.timestep_embedder[-1].load_state_dict({
|
||||
"weight": torch.zeros((256, 256)),
|
||||
"bias": torch.zeros((256,))
|
||||
}),
|
||||
self.proj.load_state_dict({
|
||||
"weight": torch.eye(3072),
|
||||
"bias": torch.zeros((3072,))
|
||||
})
|
||||
|
||||
@@ -709,6 +709,7 @@ def lets_dance_flux(
|
||||
hidden_states_ref = dit.patchify(hidden_states_ref)
|
||||
hidden_states_ref = dit.x_embedder(hidden_states_ref)
|
||||
hidden_states_ref = rearrange(hidden_states_ref, "B L C -> 1 (B L) C")
|
||||
hidden_states_ref = reference_embedder.proj(hidden_states_ref)
|
||||
hidden_states = torch.cat((hidden_states, hidden_states_ref), dim=1)
|
||||
|
||||
# TeaCache
|
||||
|
||||
Reference in New Issue
Block a user