Mirror of https://github.com/modelscope/DiffSynth-Studio.git
v1.2
@@ -73,7 +73,7 @@ class DownSampler(torch.nn.Module):
         self.conv = torch.nn.Conv2d(channels, channels, 3, stride=2, padding=padding)
         self.extra_padding = extra_padding
 
-    def forward(self, hidden_states, time_emb, text_emb, res_stack):
+    def forward(self, hidden_states, time_emb, text_emb, res_stack, **kwargs):
         if self.extra_padding:
             hidden_states = torch.nn.functional.pad(hidden_states, (0, 1, 0, 1), mode="constant", value=0)
         hidden_states = self.conv(hidden_states)
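Why the **kwargs change: every block in this UNet shares the (hidden_states, time_emb, text_emb, res_stack) calling convention, so a single driver loop can pass optional flags such as cross_frame_attention to every block; AttentionBlock consumes the flag explicitly (see below), while all other blocks absorb it via **kwargs. A minimal sketch of that pattern (the run_blocks helper is hypothetical, not code from this commit):

    import torch

    def run_blocks(blocks, hidden_states, time_emb, text_emb, cross_frame_attention=False):
        # Each block returns the same 4-tuple it receives, so one loop
        # can thread state through heterogeneous block types.
        res_stack = []
        for block in blocks:
            hidden_states, time_emb, text_emb, res_stack = block(
                hidden_states, time_emb, text_emb, res_stack,
                cross_frame_attention=cross_frame_attention,  # silently absorbed by non-attention blocks via **kwargs
            )
        return hidden_states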
@@ -85,7 +85,7 @@ class UpSampler(torch.nn.Module):
         super().__init__()
         self.conv = torch.nn.Conv2d(channels, channels, 3, padding=1)
 
-    def forward(self, hidden_states, time_emb, text_emb, res_stack):
+    def forward(self, hidden_states, time_emb, text_emb, res_stack, **kwargs):
         hidden_states = torch.nn.functional.interpolate(hidden_states, scale_factor=2.0, mode="nearest")
         hidden_states = self.conv(hidden_states)
         return hidden_states, time_emb, text_emb, res_stack
@@ -105,7 +105,7 @@ class ResnetBlock(torch.nn.Module):
         if in_channels != out_channels:
             self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=True)
 
-    def forward(self, hidden_states, time_emb, text_emb, res_stack):
+    def forward(self, hidden_states, time_emb, text_emb, res_stack, **kwargs):
         x = hidden_states
         x = self.norm1(x)
         x = self.nonlinearity(x)
@@ -125,7 +125,7 @@ class ResnetBlock(torch.nn.Module):
 
 class AttentionBlock(torch.nn.Module):
 
-    def __init__(self, num_attention_heads, attention_head_dim, in_channels, num_layers=1, cross_attention_dim=None, norm_num_groups=32, eps=1e-5):
+    def __init__(self, num_attention_heads, attention_head_dim, in_channels, num_layers=1, cross_attention_dim=None, norm_num_groups=32, eps=1e-5, need_proj_out=True):
         super().__init__()
         inner_dim = num_attention_heads * attention_head_dim
 
@@ -141,10 +141,11 @@ class AttentionBlock(torch.nn.Module):
             )
             for d in range(num_layers)
         ])
-        self.proj_out = torch.nn.Linear(inner_dim, in_channels)
+        self.need_proj_out = need_proj_out
+        if need_proj_out:
+            self.proj_out = torch.nn.Linear(inner_dim, in_channels)
 
-    def forward(self, hidden_states, time_emb, text_emb, res_stack):
+    def forward(self, hidden_states, time_emb, text_emb, res_stack, cross_frame_attention=False, **kwargs):
         batch, _, height, width = hidden_states.shape
         residual = hidden_states
 
@@ -153,15 +154,25 @@ class AttentionBlock(torch.nn.Module):
         hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
         hidden_states = self.proj_in(hidden_states)
 
+        if cross_frame_attention:
+            hidden_states = hidden_states.reshape(1, batch * height * width, inner_dim)
+            encoder_hidden_states = text_emb.mean(dim=0, keepdim=True)
+        else:
+            encoder_hidden_states = text_emb
         for block in self.transformer_blocks:
             hidden_states = block(
                 hidden_states,
-                encoder_hidden_states=text_emb
+                encoder_hidden_states=encoder_hidden_states
             )
+        if cross_frame_attention:
+            hidden_states = hidden_states.reshape(batch, height * width, inner_dim)
 
-        hidden_states = self.proj_out(hidden_states)
-        hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
-        hidden_states = hidden_states + residual
+        if self.need_proj_out:
+            hidden_states = self.proj_out(hidden_states)
+            hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
+            hidden_states = hidden_states + residual
+        else:
+            hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
 
         return hidden_states, time_emb, text_emb, res_stack
 
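What cross_frame_attention does: video frames arrive stacked on the batch dimension, and folding them into a single sequence of batch * height * width tokens lets self-attention attend across frames; the per-frame text embeddings are averaged into one conditioning tensor to match the single folded sequence. A shape-only sketch (all dimensions, including the CLIP-like text shape, are illustrative assumptions, not values from the repo):

    import torch

    batch, inner_dim, height, width = 16, 320, 32, 32       # 16 video frames
    tokens = torch.randn(batch, height * width, inner_dim)

    # Fold all frames into one token sequence so attention spans frames,
    # which is the temporal-consistency trick in the hunk above.
    folded = tokens.reshape(1, batch * height * width, inner_dim)

    # One prompt embedding per frame -> average into a single conditioning
    # tensor for the single folded sequence.
    text_emb = torch.randn(batch, 77, 768)
    encoder_hidden_states = text_emb.mean(dim=0, keepdim=True)  # (1, 77, 768)

    # After the transformer blocks, unfold back to per-frame tokens.
    unfolded = folded.reshape(batch, height * width, inner_dim)

The new need_proj_out flag, by contrast, controls the tail of the forward pass: with the default need_proj_out=True the block behaves exactly as before (projection, reshape, residual add), while need_proj_out=False skips the output projection and the residual add and returns the raw inner_dim transformer features.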
@@ -170,7 +181,7 @@ class PushBlock(torch.nn.Module):
     def __init__(self):
         super().__init__()
 
-    def forward(self, hidden_states, time_emb, text_emb, res_stack):
+    def forward(self, hidden_states, time_emb, text_emb, res_stack, **kwargs):
         res_stack.append(hidden_states)
         return hidden_states, time_emb, text_emb, res_stack
 
@@ -179,7 +190,7 @@ class PopBlock(torch.nn.Module):
     def __init__(self):
         super().__init__()
 
-    def forward(self, hidden_states, time_emb, text_emb, res_stack):
+    def forward(self, hidden_states, time_emb, text_emb, res_stack, **kwargs):
         res_hidden_states = res_stack.pop()
         hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
         return hidden_states, time_emb, text_emb, res_stack
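PushBlock and PopBlock implement the UNet skip connections as an explicit stack threaded through the block chain, which is why res_stack appears in every signature above. A minimal sketch of the round trip (channel sizes are illustrative):

    import torch

    res_stack = []
    h = torch.randn(1, 320, 64, 64)

    res_stack.append(h)              # PushBlock on the down path
    # ... downsampling and mid blocks transform h here ...
    skip = res_stack.pop()           # PopBlock on the up path
    h = torch.cat([h, skip], dim=1)  # concat on channels: 320 + 320 -> 640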