mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-19 14:58:12 +00:00
32 lines
1.1 KiB
Python
32 lines
1.1 KiB
Python
from .sd3_dit import TimestepEmbeddings
|
|
from .flux_dit import RoPEEmbedding
|
|
import torch
|
|
from einops import repeat
|
|
|
|
|
|
class FluxReferenceEmbedder(torch.nn.Module):
    """Rotary position embedding for reference-image tokens plus a learned
    embedding of each entry's reference index.

    Sub-modules:
        pos_embedder: ``RoPEEmbedding(3072, 10000, [16, 56, 56])``, applied to
            ``image_ids`` in :meth:`forward`.
        idx_embedder: ``TimestepEmbeddings(256, 256)``, applied to ``idx`` in
            :meth:`forward`.
        proj: ``Linear(3072, 3072)``. It is not used by :meth:`forward` in this
            class (presumably applied by the caller -- confirm).
    """

    def __init__(self):
        super().__init__()
        self.pos_embedder = RoPEEmbedding(3072, 10000, [16, 56, 56])
        self.idx_embedder = TimestepEmbeddings(256, 256)
        self.proj = torch.nn.Linear(3072, 3072)

    def forward(self, image_ids, idx, dtype, device):
        """Return the combined rotary embedding for all reference tokens.

        Args:
            image_ids: Position ids passed to ``pos_embedder`` (together with
                ``device``).
            idx: Values passed to ``idx_embedder``. There must be one per
                leading entry ``B`` of the positional embedding, otherwise the
                final addition fails to broadcast.
            dtype: Dtype forwarded to ``idx_embedder``.
            device: Device on which both embeddings are produced.

        Returns:
            Tensor of shape ``(1, N, B * L, C, H, W)``. It is the positional
            embedding with its batch and token axes merged, plus each entry's
            index embedding copied to all ``L`` of that entry's tokens.
        """
        pos_emb = self.pos_embedder(image_ids, device=device)
        idx_emb = self.idx_embedder(idx, dtype=dtype).to(device)
        # Take the per-token layout from the positional embedding instead of
        # hard-coding C=64, H=2, W=2. The index vector is then always split to
        # match it, and einops raises if its width is not C * H * W.
        _, _, length, c, h, w = pos_emb.shape
        # Merge batch and token axes so all entries form one token sequence.
        pos_emb = repeat(pos_emb, "B N L C H W -> 1 N (B L) C H W")
        # View each entry's index vector as (C, H, W) and copy it to every one
        # of that entry's L tokens; the size-1 N axis broadcasts over pos_emb.
        idx_emb = repeat(
            idx_emb,
            "B (C H W) -> 1 1 (B L) C H W",
            C=c,
            H=h,
            W=w,
            L=length,
        )
        return pos_emb + idx_emb

    def init(self):
        """Reset selected weights in place.

        - Zeroes the weight and bias of the last layer of
          ``idx_embedder.timestep_embedder``. Presumably this makes the index
          embedding zero until trained; that depends on TimestepEmbeddings'
          forward, so confirm in sd3_dit.
        - Sets ``proj`` to the identity map (identity weight, zero bias).

        The tensors are filled in place, so each parameter keeps its current
        device, dtype and ``requires_grad`` flag. No full-size temporary
        (e.g. a 3072x3072 identity on the CPU) is built and copied in.
        """
        last = self.idx_embedder.timestep_embedder[-1]
        torch.nn.init.zeros_(last.weight)
        torch.nn.init.zeros_(last.bias)
        torch.nn.init.eye_(self.proj.weight)
        torch.nn.init.zeros_(self.proj.bias)
|