Mirror of https://github.com/modelscope/DiffSynth-Studio.git (synced 2026-03-22 00:38:11 +00:00)
update TileWorker for better visual quality
@@ -1,6 +1,6 @@
 import torch, math
 from .attention import Attention
-from .tiler import Tiler
+from .tiler import TileWorker
 
 
 class Timesteps(torch.nn.Module):
@@ -145,7 +145,13 @@ class AttentionBlock(torch.nn.Module):
         if need_proj_out:
             self.proj_out = torch.nn.Linear(inner_dim, in_channels)
 
-    def forward(self, hidden_states, time_emb, text_emb, res_stack, cross_frame_attention=False, **kwargs):
+    def forward(
+        self,
+        hidden_states, time_emb, text_emb, res_stack,
+        cross_frame_attention=False,
+        tiled=False, tile_size=64, tile_stride=32,
+        **kwargs
+    ):
         batch, _, height, width = hidden_states.shape
         residual = hidden_states
 
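Note on the signature change above: the three new keyword arguments default to the old, non-tiled behavior, so existing callers are unaffected. A hypothetical invocation of the tiled path (the argument tensors are placeholders for illustration, not part of this commit):

    # Hypothetical call; the input tensors are assumed for illustration.
    block(hidden_states, time_emb, text_emb, res_stack,
          tiled=True, tile_size=64, tile_stride=32)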
@@ -159,11 +165,32 @@ class AttentionBlock(torch.nn.Module):
             encoder_hidden_states = text_emb.mean(dim=0, keepdim=True)
         else:
             encoder_hidden_states = text_emb
-        for block in self.transformer_blocks:
-            hidden_states = block(
-                hidden_states,
-                encoder_hidden_states=encoder_hidden_states
-            )
+
+        if tiled:
+            tile_size = min(tile_size, min(height, width))
+            hidden_states = hidden_states.permute(0, 2, 1).reshape(batch, inner_dim, height, width)
+            def block_tile_forward(x):
+                b, c, h, w = x.shape
+                x = x.permute(0, 2, 3, 1).reshape(b, h*w, c)
+                x = block(x, encoder_hidden_states)
+                x = x.reshape(b, h, w, c).permute(0, 3, 1, 2)
+                return x
+            for block in self.transformer_blocks:
+                hidden_states = TileWorker().tiled_forward(
+                    block_tile_forward,
+                    hidden_states,
+                    tile_size,
+                    tile_stride,
+                    tile_device=hidden_states.device,
+                    tile_dtype=hidden_states.dtype
+                )
+            hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
+        else:
+            for block in self.transformer_blocks:
+                hidden_states = block(
+                    hidden_states,
+                    encoder_hidden_states=encoder_hidden_states
+                )
         if cross_frame_attention:
             hidden_states = hidden_states.reshape(batch, height * width, inner_dim)
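The TileWorker implementation itself is not part of this diff. As a rough mental model only (an assumption about tiler.py, not the actual code), tiled_forward slides an overlapping window across the (batch, channel, height, width) feature map, applies the callback to each tile, and averages the overlapping outputs:

    import torch

    def tiled_forward_sketch(fn, x, tile_size, tile_stride):
        # Minimal sketch with uniform overlap averaging; the real TileWorker
        # may blend with a smoother weight mask and move tiles to
        # tile_device / tile_dtype to save memory.
        b, c, h, w = x.shape
        def starts(full):
            # Tile origins at the given stride, with the last origin clamped
            # so coverage always reaches the edge of the feature map.
            s = list(range(0, max(full - tile_size, 0) + 1, tile_stride))
            last = max(full - tile_size, 0)
            return s if s[-1] == last else s + [last]
        out = torch.zeros_like(x)
        weight = torch.zeros(1, 1, h, w, dtype=x.dtype, device=x.device)
        for top in starts(h):
            for left in starts(w):
                tile = x[:, :, top:top + tile_size, left:left + tile_size]
                out[:, :, top:top + tile_size, left:left + tile_size] += fn(tile)
                weight[:, :, top:top + tile_size, left:left + tile_size] += 1
        return out / weight

In the commit, the callback (block_tile_forward) flattens each spatial tile into a token sequence, runs the transformer block on it, and folds the result back to (b, c, h, w). The callback closes over the loop variable block; this works because tiled_forward invokes the callback synchronously inside each loop iteration, so block is always bound to the transformer block currently being applied.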
||||