update TileWorker for better visual quality

This commit is contained in:
Artiprocher
2024-01-09 22:29:17 +08:00
parent 552355692d
commit 8a460497fa
4 changed files with 47 additions and 122 deletions

View File

@@ -1,6 +1,6 @@
import torch, math
from .attention import Attention
from .tiler import Tiler
from .tiler import TileWorker
class Timesteps(torch.nn.Module):
@@ -145,7 +145,13 @@ class AttentionBlock(torch.nn.Module):
if need_proj_out:
self.proj_out = torch.nn.Linear(inner_dim, in_channels)
def forward(self, hidden_states, time_emb, text_emb, res_stack, cross_frame_attention=False, **kwargs):
def forward(
self,
hidden_states, time_emb, text_emb, res_stack,
cross_frame_attention=False,
tiled=False, tile_size=64, tile_stride=32,
**kwargs
):
batch, _, height, width = hidden_states.shape
residual = hidden_states
@@ -159,11 +165,32 @@ class AttentionBlock(torch.nn.Module):
encoder_hidden_states = text_emb.mean(dim=0, keepdim=True)
else:
encoder_hidden_states = text_emb
for block in self.transformer_blocks:
hidden_states = block(
hidden_states,
encoder_hidden_states=encoder_hidden_states
)
if tiled:
tile_size = min(tile_size, min(height, width))
hidden_states = hidden_states.permute(0, 2, 1).reshape(batch, inner_dim, height, width)
def block_tile_forward(x):
b, c, h, w = x.shape
x = x.permute(0, 2, 3, 1).reshape(b, h*w, c)
x = block(x, encoder_hidden_states)
x = x.reshape(b, h, w, c).permute(0, 3, 1, 2)
return x
for block in self.transformer_blocks:
hidden_states = TileWorker().tiled_forward(
block_tile_forward,
hidden_states,
tile_size,
tile_stride,
tile_device=hidden_states.device,
tile_dtype=hidden_states.dtype
)
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
else:
for block in self.transformer_blocks:
hidden_states = block(
hidden_states,
encoder_hidden_states=encoder_hidden_states
)
if cross_frame_attention:
hidden_states = hidden_states.reshape(batch, height * width, inner_dim)

View File

@@ -1,6 +1,5 @@
import torch
from .sd_unet import Timesteps, ResnetBlock, AttentionBlock, PushBlock, PopBlock, DownSampler, UpSampler
from .tiler import Tiler
class SDXLUNet(torch.nn.Module):
@@ -108,13 +107,10 @@ class SDXLUNet(torch.nn.Module):
# 3. blocks
for i, block in enumerate(self.blocks):
if tiled:
hidden_states, time_emb, text_emb, res_stack = self.tiled_inference(
block, hidden_states, time_emb, text_emb, res_stack,
height, width, tile_size, tile_stride
)
else:
hidden_states, time_emb, text_emb, res_stack = block(hidden_states, time_emb, text_emb, res_stack)
hidden_states, time_emb, text_emb, res_stack = block(
hidden_states, time_emb, text_emb, res_stack,
tiled=tiled, tile_size=tile_size, tile_stride=tile_stride
)
# 4. output
hidden_states = self.conv_norm_out(hidden_states)
@@ -123,23 +119,6 @@ class SDXLUNet(torch.nn.Module):
return hidden_states
def tiled_inference(self, block, hidden_states, time_emb, text_emb, res_stack, height, width, tile_size, tile_stride):
if block.__class__.__name__ in ["ResnetBlock", "AttentionBlock", "DownSampler", "UpSampler"]:
batch_size, inter_channel, inter_height, inter_width = hidden_states.shape
resize_scale = inter_height / height
hidden_states = Tiler()(
lambda x: block(x, time_emb, text_emb, res_stack)[0],
hidden_states,
int(tile_size * resize_scale),
int(tile_stride * resize_scale),
inter_device=hidden_states.device,
inter_dtype=hidden_states.dtype
)
else:
hidden_states, time_emb, text_emb, res_stack = block(hidden_states, time_emb, text_emb, res_stack)
return hidden_states, time_emb, text_emb, res_stack
def state_dict_converter(self):
return SDXLUNetStateDictConverter()

View File

@@ -2,76 +2,6 @@ import torch
from einops import rearrange, repeat
class Tiler(torch.nn.Module):
def __init__(self):
super().__init__()
def mask(self, height, width, line_width):
x = torch.arange(height).repeat(width, 1).T
y = torch.arange(width).repeat(height, 1)
mask = torch.stack([x + 1, height - x, y + 1, width - y]).min(dim=0).values
mask = (mask / line_width).clip(0, 1)
return mask
def forward(self, forward_fn, x, tile_size, tile_stride, batch_size=1, inter_device="cpu", inter_dtype=torch.float32):
# Prepare
device = x.device
torch_dtype = x.dtype
# tile
b, c_in, h_in, w_in = x.shape
x = x.to(device=inter_device, dtype=inter_dtype)
fold_params = {
"kernel_size": (tile_size, tile_size),
"stride": (tile_stride, tile_stride)
}
unfold_operator = torch.nn.Unfold(**fold_params)
x = unfold_operator(x)
x = x.view((b, c_in, tile_size, tile_size, -1))
# inference
x_out_stack = []
for tile_id in range(0, x.shape[-1], batch_size):
# process input
next_tile_id = min(tile_id + batch_size, x.shape[-1])
x_in = x[:, :, :, :, tile_id: next_tile_id]
x_in = x_in.to(device=device, dtype=torch_dtype)
x_in = x_in.permute(4, 0, 1, 2, 3)
x_in = x_in.view((x_in.shape[0]*x_in.shape[1], x_in.shape[2], x_in.shape[3], x_in.shape[4]))
# process output
x_out = forward_fn(x_in)
x_out = x_out.view((next_tile_id - tile_id, b, x_out.shape[1], x_out.shape[2], x_out.shape[3]))
x_out = x_out.permute(1, 2, 3, 4, 0)
x_out = x_out.to(device=inter_device, dtype=inter_dtype)
x_out_stack.append(x_out)
x = torch.concat(x_out_stack, dim=-1)
# untile
in2out_scale = x.shape[2] / tile_size
h_out, w_out = int(h_in * in2out_scale), int(w_in * in2out_scale)
mask = self.mask(int(tile_size * in2out_scale), int(tile_size * in2out_scale), int(tile_stride * in2out_scale * 0.5))
mask = mask.to(device=inter_device, dtype=inter_dtype)
mask = mask.reshape((1, 1, mask.shape[0], mask.shape[1], 1))
x = x * mask
fold_params = {
"kernel_size": (int(tile_size * in2out_scale), int(tile_size * in2out_scale)),
"stride": (int(tile_stride * in2out_scale), int(tile_stride * in2out_scale))
}
fold_operator = torch.nn.Fold(output_size=(h_out, w_out), **fold_params)
divisor = fold_operator(mask.repeat(1, 1, 1, 1, x.shape[-1]).view(b, -1, x.shape[-1]))
x = x.view((b, -1, x.shape[-1]))
x = fold_operator(x) / divisor
x = x.to(device=device, dtype=torch_dtype)
return x
class TileWorker:
def __init__(self):
pass

View File

@@ -1,6 +1,6 @@
import torch
from ..models import SDUNet, SDMotionModel
from ..models.sd_unet import PushBlock, PopBlock
from ..models.sd_unet import PushBlock, PopBlock, ResnetBlock, AttentionBlock
from ..models.tiler import TileWorker
from ..controlnets import MultiControlNetManager
@@ -75,25 +75,14 @@ def lets_dance(
hidden_states_output = []
for batch_id in range(0, sample.shape[0], unet_batch_size):
batch_id_ = min(batch_id + unet_batch_size, sample.shape[0])
if tiled:
_, _, inter_height, _ = hidden_states.shape
resize_scale = inter_height / height
hidden_states = TileWorker().tiled_forward(
lambda x: block(x, time_emb, text_emb[batch_id: batch_id_], res_stack)[0],
hidden_states_input[batch_id: batch_id_],
int(tile_size * resize_scale),
int(tile_stride * resize_scale),
tile_device=hidden_states.device,
tile_dtype=hidden_states.dtype
)
else:
hidden_states, _, _, _ = block(
hidden_states_input[batch_id: batch_id_],
time_emb,
text_emb[batch_id: batch_id_],
res_stack,
cross_frame_attention=cross_frame_attention
)
hidden_states, _, _, _ = block(
hidden_states_input[batch_id: batch_id_],
time_emb,
text_emb[batch_id: batch_id_],
res_stack,
cross_frame_attention=cross_frame_attention,
tiled=tiled, tile_size=tile_size, tile_stride=tile_stride
)
hidden_states_output.append(hidden_states)
hidden_states = torch.concat(hidden_states_output, dim=0)
# 4.2 AnimateDiff