diff --git a/diffsynth/models/sd_unet.py b/diffsynth/models/sd_unet.py index a0d937e..3f12a22 100644 --- a/diffsynth/models/sd_unet.py +++ b/diffsynth/models/sd_unet.py @@ -1,6 +1,6 @@ import torch, math from .attention import Attention -from .tiler import Tiler +from .tiler import TileWorker class Timesteps(torch.nn.Module): @@ -145,7 +145,13 @@ class AttentionBlock(torch.nn.Module): if need_proj_out: self.proj_out = torch.nn.Linear(inner_dim, in_channels) - def forward(self, hidden_states, time_emb, text_emb, res_stack, cross_frame_attention=False, **kwargs): + def forward( + self, + hidden_states, time_emb, text_emb, res_stack, + cross_frame_attention=False, + tiled=False, tile_size=64, tile_stride=32, + **kwargs + ): batch, _, height, width = hidden_states.shape residual = hidden_states @@ -159,11 +165,32 @@ class AttentionBlock(torch.nn.Module): encoder_hidden_states = text_emb.mean(dim=0, keepdim=True) else: encoder_hidden_states = text_emb - for block in self.transformer_blocks: - hidden_states = block( - hidden_states, - encoder_hidden_states=encoder_hidden_states - ) + + if tiled: + tile_size = min(tile_size, min(height, width)) + hidden_states = hidden_states.permute(0, 2, 1).reshape(batch, inner_dim, height, width) + def block_tile_forward(x): + b, c, h, w = x.shape + x = x.permute(0, 2, 3, 1).reshape(b, h*w, c) + x = block(x, encoder_hidden_states) + x = x.reshape(b, h, w, c).permute(0, 3, 1, 2) + return x + for block in self.transformer_blocks: + hidden_states = TileWorker().tiled_forward( + block_tile_forward, + hidden_states, + tile_size, + tile_stride, + tile_device=hidden_states.device, + tile_dtype=hidden_states.dtype + ) + hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim) + else: + for block in self.transformer_blocks: + hidden_states = block( + hidden_states, + encoder_hidden_states=encoder_hidden_states + ) if cross_frame_attention: hidden_states = hidden_states.reshape(batch, height * width, inner_dim) diff --git a/diffsynth/models/sdxl_unet.py b/diffsynth/models/sdxl_unet.py index 51a7ce6..a336259 100644 --- a/diffsynth/models/sdxl_unet.py +++ b/diffsynth/models/sdxl_unet.py @@ -1,6 +1,5 @@ import torch from .sd_unet import Timesteps, ResnetBlock, AttentionBlock, PushBlock, PopBlock, DownSampler, UpSampler -from .tiler import Tiler class SDXLUNet(torch.nn.Module): @@ -108,13 +107,10 @@ class SDXLUNet(torch.nn.Module): # 3. blocks for i, block in enumerate(self.blocks): - if tiled: - hidden_states, time_emb, text_emb, res_stack = self.tiled_inference( - block, hidden_states, time_emb, text_emb, res_stack, - height, width, tile_size, tile_stride - ) - else: - hidden_states, time_emb, text_emb, res_stack = block(hidden_states, time_emb, text_emb, res_stack) + hidden_states, time_emb, text_emb, res_stack = block( + hidden_states, time_emb, text_emb, res_stack, + tiled=tiled, tile_size=tile_size, tile_stride=tile_stride + ) # 4. output hidden_states = self.conv_norm_out(hidden_states) @@ -123,23 +119,6 @@ class SDXLUNet(torch.nn.Module): return hidden_states - def tiled_inference(self, block, hidden_states, time_emb, text_emb, res_stack, height, width, tile_size, tile_stride): - if block.__class__.__name__ in ["ResnetBlock", "AttentionBlock", "DownSampler", "UpSampler"]: - batch_size, inter_channel, inter_height, inter_width = hidden_states.shape - resize_scale = inter_height / height - - hidden_states = Tiler()( - lambda x: block(x, time_emb, text_emb, res_stack)[0], - hidden_states, - int(tile_size * resize_scale), - int(tile_stride * resize_scale), - inter_device=hidden_states.device, - inter_dtype=hidden_states.dtype - ) - else: - hidden_states, time_emb, text_emb, res_stack = block(hidden_states, time_emb, text_emb, res_stack) - return hidden_states, time_emb, text_emb, res_stack - def state_dict_converter(self): return SDXLUNetStateDictConverter() diff --git a/diffsynth/models/tiler.py b/diffsynth/models/tiler.py index 30db58f..af37ff6 100644 --- a/diffsynth/models/tiler.py +++ b/diffsynth/models/tiler.py @@ -2,76 +2,6 @@ import torch from einops import rearrange, repeat -class Tiler(torch.nn.Module): - def __init__(self): - super().__init__() - - def mask(self, height, width, line_width): - x = torch.arange(height).repeat(width, 1).T - y = torch.arange(width).repeat(height, 1) - mask = torch.stack([x + 1, height - x, y + 1, width - y]).min(dim=0).values - mask = (mask / line_width).clip(0, 1) - return mask - - def forward(self, forward_fn, x, tile_size, tile_stride, batch_size=1, inter_device="cpu", inter_dtype=torch.float32): - # Prepare - device = x.device - torch_dtype = x.dtype - - # tile - b, c_in, h_in, w_in = x.shape - x = x.to(device=inter_device, dtype=inter_dtype) - fold_params = { - "kernel_size": (tile_size, tile_size), - "stride": (tile_stride, tile_stride) - } - unfold_operator = torch.nn.Unfold(**fold_params) - x = unfold_operator(x) - x = x.view((b, c_in, tile_size, tile_size, -1)) - - # inference - x_out_stack = [] - for tile_id in range(0, x.shape[-1], batch_size): - - # process input - next_tile_id = min(tile_id + batch_size, x.shape[-1]) - x_in = x[:, :, :, :, tile_id: next_tile_id] - x_in = x_in.to(device=device, dtype=torch_dtype) - x_in = x_in.permute(4, 0, 1, 2, 3) - x_in = x_in.view((x_in.shape[0]*x_in.shape[1], x_in.shape[2], x_in.shape[3], x_in.shape[4])) - - # process output - x_out = forward_fn(x_in) - x_out = x_out.view((next_tile_id - tile_id, b, x_out.shape[1], x_out.shape[2], x_out.shape[3])) - x_out = x_out.permute(1, 2, 3, 4, 0) - x_out = x_out.to(device=inter_device, dtype=inter_dtype) - x_out_stack.append(x_out) - - x = torch.concat(x_out_stack, dim=-1) - - # untile - in2out_scale = x.shape[2] / tile_size - h_out, w_out = int(h_in * in2out_scale), int(w_in * in2out_scale) - - mask = self.mask(int(tile_size * in2out_scale), int(tile_size * in2out_scale), int(tile_stride * in2out_scale * 0.5)) - mask = mask.to(device=inter_device, dtype=inter_dtype) - mask = mask.reshape((1, 1, mask.shape[0], mask.shape[1], 1)) - x = x * mask - - fold_params = { - "kernel_size": (int(tile_size * in2out_scale), int(tile_size * in2out_scale)), - "stride": (int(tile_stride * in2out_scale), int(tile_stride * in2out_scale)) - } - fold_operator = torch.nn.Fold(output_size=(h_out, w_out), **fold_params) - divisor = fold_operator(mask.repeat(1, 1, 1, 1, x.shape[-1]).view(b, -1, x.shape[-1])) - - x = x.view((b, -1, x.shape[-1])) - x = fold_operator(x) / divisor - x = x.to(device=device, dtype=torch_dtype) - - return x - - class TileWorker: def __init__(self): pass diff --git a/diffsynth/pipelines/dancer.py b/diffsynth/pipelines/dancer.py index 67ff24e..91c2fa7 100644 --- a/diffsynth/pipelines/dancer.py +++ b/diffsynth/pipelines/dancer.py @@ -1,6 +1,6 @@ import torch from ..models import SDUNet, SDMotionModel -from ..models.sd_unet import PushBlock, PopBlock +from ..models.sd_unet import PushBlock, PopBlock, ResnetBlock, AttentionBlock from ..models.tiler import TileWorker from ..controlnets import MultiControlNetManager @@ -75,25 +75,14 @@ def lets_dance( hidden_states_output = [] for batch_id in range(0, sample.shape[0], unet_batch_size): batch_id_ = min(batch_id + unet_batch_size, sample.shape[0]) - if tiled: - _, _, inter_height, _ = hidden_states.shape - resize_scale = inter_height / height - hidden_states = TileWorker().tiled_forward( - lambda x: block(x, time_emb, text_emb[batch_id: batch_id_], res_stack)[0], - hidden_states_input[batch_id: batch_id_], - int(tile_size * resize_scale), - int(tile_stride * resize_scale), - tile_device=hidden_states.device, - tile_dtype=hidden_states.dtype - ) - else: - hidden_states, _, _, _ = block( - hidden_states_input[batch_id: batch_id_], - time_emb, - text_emb[batch_id: batch_id_], - res_stack, - cross_frame_attention=cross_frame_attention - ) + hidden_states, _, _, _ = block( + hidden_states_input[batch_id: batch_id_], + time_emb, + text_emb[batch_id: batch_id_], + res_stack, + cross_frame_attention=cross_frame_attention, + tiled=tiled, tile_size=tile_size, tile_stride=tile_stride + ) hidden_states_output.append(hidden_states) hidden_states = torch.concat(hidden_states_output, dim=0) # 4.2 AnimateDiff