support reference image

This commit is contained in:
Artiprocher
2025-04-11 16:14:24 +08:00
parent 9e78bf5e89
commit a572254a1d
4 changed files with 15 additions and 9 deletions

View File

@@ -20,10 +20,11 @@ class RoPEEmbedding(torch.nn.Module):
self.axes_dim = axes_dim
def rope(self, pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor:
def rope(self, pos: torch.Tensor, dim: int, theta: int, device="cpu") -> torch.Tensor:
assert dim % 2 == 0, "The dimension must be even."
scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
scale = scale.to(device)
omega = 1.0 / (theta**scale)
batch_size, seq_length = pos.shape
@@ -36,9 +37,9 @@ class RoPEEmbedding(torch.nn.Module):
return out.float()
def forward(self, ids):
def forward(self, ids, device="cpu"):
n_axes = ids.shape[-1]
emb = torch.cat([self.rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)], dim=-3)
emb = torch.cat([self.rope(ids[..., i], self.axes_dim[i], self.theta, device) for i in range(n_axes)], dim=-3)
return emb.unsqueeze(1)

View File

@@ -10,9 +10,9 @@ class FluxReferenceEmbedder(torch.nn.Module):
self.pos_embedder = RoPEEmbedding(3072, 10000, [16, 56, 56])
self.idx_embedder = TimestepEmbeddings(256, 256)
def forward(self, image_ids, idx, dtype):
pos_emb = self.pos_embedder(image_ids)
idx_emb = self.idx_embedder(idx, dtype=dtype)
def forward(self, image_ids, idx, dtype, device):
pos_emb = self.pos_embedder(image_ids, device=device)
idx_emb = self.idx_embedder(idx, dtype=dtype).to(device)
length = pos_emb.shape[2]
pos_emb = repeat(pos_emb, "B N L C H W -> 1 N (B L) C H W")
idx_emb = repeat(idx_emb, "B (C H W) -> 1 1 (B L) C H W", C=64, H=2, W=2, L=length)

View File

@@ -694,7 +694,7 @@ def lets_dance_flux(
prompt_emb, image_rotary_emb, attention_mask = dit.process_entity_masks(hidden_states, prompt_emb, entity_prompt_emb, entity_masks, text_ids, image_ids)
else:
prompt_emb = dit.context_embedder(prompt_emb)
image_rotary_emb = dit.pos_embedder(torch.cat((text_ids, image_ids), dim=1))
image_rotary_emb = dit.pos_embedder(torch.cat((text_ids, image_ids), dim=1), device=hidden_states.device)
attention_mask = None
# Reference images