fix base model

This commit is contained in:
Artiprocher
2025-04-14 13:11:45 +08:00
parent 7a06a58f49
commit 2a5355b7cb
4 changed files with 47 additions and 9 deletions

View File

@@ -9,6 +9,7 @@ class FluxReferenceEmbedder(torch.nn.Module):
super().__init__()
self.pos_embedder = RoPEEmbedding(3072, 10000, [16, 56, 56])
self.idx_embedder = TimestepEmbeddings(256, 256)
self.proj = torch.nn.Linear(3072, 3072)
def forward(self, image_ids, idx, dtype, device):
pos_emb = self.pos_embedder(image_ids, device=device)
@@ -18,3 +19,13 @@ class FluxReferenceEmbedder(torch.nn.Module):
idx_emb = repeat(idx_emb, "B (C H W) -> 1 1 (B L) C H W", C=64, H=2, W=2, L=length)
image_rotary_emb = pos_emb + idx_emb
return image_rotary_emb
def init(self):
self.idx_embedder.timestep_embedder[-1].load_state_dict({
"weight": torch.zeros((256, 256)),
"bias": torch.zeros((256,))
}),
self.proj.load_state_dict({
"weight": torch.eye(3072),
"bias": torch.zeros((3072,))
})