mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-23 00:58:11 +00:00
ipadapter for sdxl
This commit is contained in:
@@ -47,15 +47,15 @@ class BasicTransformerBlock(torch.nn.Module):
|
||||
self.ff = torch.nn.Linear(dim * 4, dim)
|
||||
|
||||
|
||||
def forward(self, hidden_states, encoder_hidden_states):
|
||||
def forward(self, hidden_states, encoder_hidden_states, ipadapter_kwargs=None):
|
||||
# 1. Self-Attention
|
||||
norm_hidden_states = self.norm1(hidden_states)
|
||||
attn_output = self.attn1(norm_hidden_states, encoder_hidden_states=None,)
|
||||
attn_output = self.attn1(norm_hidden_states, encoder_hidden_states=None)
|
||||
hidden_states = attn_output + hidden_states
|
||||
|
||||
# 2. Cross-Attention
|
||||
norm_hidden_states = self.norm2(hidden_states)
|
||||
attn_output = self.attn2(norm_hidden_states, encoder_hidden_states=encoder_hidden_states)
|
||||
attn_output = self.attn2(norm_hidden_states, encoder_hidden_states=encoder_hidden_states, ipadapter_kwargs=ipadapter_kwargs)
|
||||
hidden_states = attn_output + hidden_states
|
||||
|
||||
# 3. Feed-forward
|
||||
@@ -150,6 +150,7 @@ class AttentionBlock(torch.nn.Module):
|
||||
hidden_states, time_emb, text_emb, res_stack,
|
||||
cross_frame_attention=False,
|
||||
tiled=False, tile_size=64, tile_stride=32,
|
||||
ipadapter_kwargs_list={},
|
||||
**kwargs
|
||||
):
|
||||
batch, _, height, width = hidden_states.shape
|
||||
@@ -188,10 +189,11 @@ class AttentionBlock(torch.nn.Module):
|
||||
)
|
||||
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
|
||||
else:
|
||||
for block in self.transformer_blocks:
|
||||
for block_id, block in enumerate(self.transformer_blocks):
|
||||
hidden_states = block(
|
||||
hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
ipadapter_kwargs=ipadapter_kwargs_list.get(block_id, None)
|
||||
)
|
||||
if cross_frame_attention:
|
||||
hidden_states = hidden_states.reshape(batch, height * width, inner_dim)
|
||||
|
||||
Reference in New Issue
Block a user