This commit is contained in:
Artiprocher
2023-12-30 21:01:24 +08:00
parent b9771db163
commit d24ddaacaa
19 changed files with 2252 additions and 34 deletions

View File

@@ -73,7 +73,7 @@ class DownSampler(torch.nn.Module):
self.conv = torch.nn.Conv2d(channels, channels, 3, stride=2, padding=padding)
self.extra_padding = extra_padding
def forward(self, hidden_states, time_emb, text_emb, res_stack):
def forward(self, hidden_states, time_emb, text_emb, res_stack, **kwargs):
if self.extra_padding:
hidden_states = torch.nn.functional.pad(hidden_states, (0, 1, 0, 1), mode="constant", value=0)
hidden_states = self.conv(hidden_states)
@@ -85,7 +85,7 @@ class UpSampler(torch.nn.Module):
super().__init__()
self.conv = torch.nn.Conv2d(channels, channels, 3, padding=1)
def forward(self, hidden_states, time_emb, text_emb, res_stack):
def forward(self, hidden_states, time_emb, text_emb, res_stack, **kwargs):
hidden_states = torch.nn.functional.interpolate(hidden_states, scale_factor=2.0, mode="nearest")
hidden_states = self.conv(hidden_states)
return hidden_states, time_emb, text_emb, res_stack
@@ -105,7 +105,7 @@ class ResnetBlock(torch.nn.Module):
if in_channels != out_channels:
self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=True)
def forward(self, hidden_states, time_emb, text_emb, res_stack):
def forward(self, hidden_states, time_emb, text_emb, res_stack, **kwargs):
x = hidden_states
x = self.norm1(x)
x = self.nonlinearity(x)
@@ -125,7 +125,7 @@ class ResnetBlock(torch.nn.Module):
class AttentionBlock(torch.nn.Module):
def __init__(self, num_attention_heads, attention_head_dim, in_channels, num_layers=1, cross_attention_dim=None, norm_num_groups=32, eps=1e-5):
def __init__(self, num_attention_heads, attention_head_dim, in_channels, num_layers=1, cross_attention_dim=None, norm_num_groups=32, eps=1e-5, need_proj_out=True):
super().__init__()
inner_dim = num_attention_heads * attention_head_dim
@@ -141,10 +141,11 @@ class AttentionBlock(torch.nn.Module):
)
for d in range(num_layers)
])
self.need_proj_out = need_proj_out
if need_proj_out:
self.proj_out = torch.nn.Linear(inner_dim, in_channels)
self.proj_out = torch.nn.Linear(inner_dim, in_channels)
def forward(self, hidden_states, time_emb, text_emb, res_stack):
def forward(self, hidden_states, time_emb, text_emb, res_stack, cross_frame_attention=False, **kwargs):
batch, _, height, width = hidden_states.shape
residual = hidden_states
@@ -153,15 +154,25 @@ class AttentionBlock(torch.nn.Module):
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
hidden_states = self.proj_in(hidden_states)
if cross_frame_attention:
hidden_states = hidden_states.reshape(1, batch * height * width, inner_dim)
encoder_hidden_states = text_emb.mean(dim=0, keepdim=True)
else:
encoder_hidden_states = text_emb
for block in self.transformer_blocks:
hidden_states = block(
hidden_states,
encoder_hidden_states=text_emb
encoder_hidden_states=encoder_hidden_states
)
if cross_frame_attention:
hidden_states = hidden_states.reshape(batch, height * width, inner_dim)
hidden_states = self.proj_out(hidden_states)
hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
hidden_states = hidden_states + residual
if self.need_proj_out:
hidden_states = self.proj_out(hidden_states)
hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
hidden_states = hidden_states + residual
else:
hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
return hidden_states, time_emb, text_emb, res_stack
@@ -170,7 +181,7 @@ class PushBlock(torch.nn.Module):
def __init__(self):
super().__init__()
def forward(self, hidden_states, time_emb, text_emb, res_stack):
def forward(self, hidden_states, time_emb, text_emb, res_stack, **kwargs):
res_stack.append(hidden_states)
return hidden_states, time_emb, text_emb, res_stack
@@ -179,7 +190,7 @@ class PopBlock(torch.nn.Module):
def __init__(self):
super().__init__()
def forward(self, hidden_states, time_emb, text_emb, res_stack):
def forward(self, hidden_states, time_emb, text_emb, res_stack, **kwargs):
res_hidden_states = res_stack.pop()
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
return hidden_states, time_emb, text_emb, res_stack