Add IP-Adapter support for SDXL

This commit is contained in:
Artiprocher
2024-05-14 23:24:24 +08:00
parent 3b5bbb5773
commit 83461d400c
8 changed files with 251 additions and 27 deletions

View File

@@ -47,15 +47,15 @@ class BasicTransformerBlock(torch.nn.Module):
self.ff = torch.nn.Linear(dim * 4, dim)
def forward(self, hidden_states, encoder_hidden_states):
def forward(self, hidden_states, encoder_hidden_states, ipadapter_kwargs=None):
# 1. Self-Attention
norm_hidden_states = self.norm1(hidden_states)
attn_output = self.attn1(norm_hidden_states, encoder_hidden_states=None,)
attn_output = self.attn1(norm_hidden_states, encoder_hidden_states=None)
hidden_states = attn_output + hidden_states
# 2. Cross-Attention
norm_hidden_states = self.norm2(hidden_states)
attn_output = self.attn2(norm_hidden_states, encoder_hidden_states=encoder_hidden_states)
attn_output = self.attn2(norm_hidden_states, encoder_hidden_states=encoder_hidden_states, ipadapter_kwargs=ipadapter_kwargs)
hidden_states = attn_output + hidden_states
# 3. Feed-forward
@@ -150,6 +150,7 @@ class AttentionBlock(torch.nn.Module):
hidden_states, time_emb, text_emb, res_stack,
cross_frame_attention=False,
tiled=False, tile_size=64, tile_stride=32,
ipadapter_kwargs_list={},
**kwargs
):
batch, _, height, width = hidden_states.shape
@@ -188,10 +189,11 @@ class AttentionBlock(torch.nn.Module):
)
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
else:
for block in self.transformer_blocks:
for block_id, block in enumerate(self.transformer_blocks):
hidden_states = block(
hidden_states,
encoder_hidden_states=encoder_hidden_states
encoder_hidden_states=encoder_hidden_states,
ipadapter_kwargs=ipadapter_kwargs_list.get(block_id, None)
)
if cross_frame_attention:
hidden_states = hidden_states.reshape(batch, height * width, inner_dim)