mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-22 16:50:47 +00:00
support reference image
This commit is contained in:
@@ -10,9 +10,9 @@ class FluxReferenceEmbedder(torch.nn.Module):
|
||||
self.pos_embedder = RoPEEmbedding(3072, 10000, [16, 56, 56])
|
||||
self.idx_embedder = TimestepEmbeddings(256, 256)
|
||||
|
||||
def forward(self, image_ids, idx, dtype):
|
||||
pos_emb = self.pos_embedder(image_ids)
|
||||
idx_emb = self.idx_embedder(idx, dtype=dtype)
|
||||
def forward(self, image_ids, idx, dtype, device):
|
||||
pos_emb = self.pos_embedder(image_ids, device=device)
|
||||
idx_emb = self.idx_embedder(idx, dtype=dtype).to(device)
|
||||
length = pos_emb.shape[2]
|
||||
pos_emb = repeat(pos_emb, "B N L C H W -> 1 N (B L) C H W")
|
||||
idx_emb = repeat(idx_emb, "B (C H W) -> 1 1 (B L) C H W", C=64, H=2, W=2, L=length)
|
||||
|
||||
Reference in New Issue
Block a user