mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-18 22:08:13 +00:00
180 lines
7.4 KiB
Python
180 lines
7.4 KiB
Python
import torch
|
|
from einops import rearrange, repeat
|
|
|
|
|
|
class TileWorker:
    """Apply a 2D model tile by tile and blend the overlapping tile outputs.

    The pipeline is: ``tile`` (cut the input into square tiles), ``tiled_inference``
    (run the model on batches of tiles), ``untile`` (weighted fold back into one
    tensor). ``tiled_forward`` chains the three steps.
    """

    def __init__(self):
        pass

    def mask(self, height, width, border_width):
        """Return a (height, width) blending weight map.

        Each entry is the distance to the nearest edge (counted from 1),
        divided by ``border_width`` and clipped to [0, 1]; entries at least
        ``border_width`` away from every edge are therefore 1.
        """
        row_index = torch.arange(height).unsqueeze(1).expand(height, width)
        col_index = torch.arange(width).unsqueeze(0).expand(height, width)
        edge_distances = torch.stack([
            row_index + 1,
            height - row_index,
            col_index + 1,
            width - col_index,
        ])
        nearest_edge = edge_distances.min(dim=0).values
        return (nearest_edge / border_width).clip(0, 1)

    def tile(self, model_input, tile_size, tile_stride, tile_device, tile_dtype):
        """Cut (b, c, h, w) into tiles of shape (b, c, tile_size, tile_size, tile_num).

        The input is first moved to ``tile_device`` / ``tile_dtype``.
        NOTE(review): no padding is applied when unfolding, so positions not
        covered by any tile presumably get no weight in ``untile`` - confirm
        that callers choose sizes that tile the input exactly.
        """
        batch_size, channel, _height, _width = model_input.shape
        staged = model_input.to(device=tile_device, dtype=tile_dtype)
        columns = torch.nn.functional.unfold(
            staged,
            kernel_size=(tile_size, tile_size),
            stride=(tile_stride, tile_stride),
        )
        # The last axis enumerates the tiles.
        return columns.view((batch_size, channel, tile_size, tile_size, -1))

    def tiled_inference(self, forward_fn, model_input, tile_batch_size, inference_device, inference_dtype, tile_device, tile_dtype):
        """Call ``forward_fn`` on the tiles, ``tile_batch_size`` tiles at a time.

        Tiles are moved to ``inference_device`` / ``inference_dtype`` for the
        call and the results are moved back to ``tile_device`` / ``tile_dtype``.
        Returns the outputs concatenated along the last (tile) axis.
        """
        tile_num = model_input.shape[-1]
        outputs = []
        for start in range(0, tile_num, tile_batch_size):
            stop = min(start + tile_batch_size, tile_num)

            # Fold the tile axis into the batch axis for the model call.
            batch = model_input[:, :, :, :, start:stop]
            batch = batch.to(device=inference_device, dtype=inference_dtype)
            batch = rearrange(batch, "b c h w n -> (n b) c h w")

            # Restore the tile axis and stage the result on the tile device.
            result = forward_fn(batch)
            result = rearrange(result, "(n b) c h w -> b c h w n", n=stop - start)
            outputs.append(result.to(device=tile_device, dtype=tile_dtype))

        return torch.cat(outputs, dim=-1)

    def io_scale(self, model_output, tile_size):
        """Return the output/input size ratio introduced by ``forward_fn``.

        Only axis 2 of ``model_output`` is inspected, i.e. height and width
        are assumed to be scaled by the same factor.
        """
        return model_output.shape[2] / tile_size

    def untile(self, model_output, height, width, tile_size, tile_stride, border_width, tile_device, tile_dtype):
        """Inverse of ``tile``: fold the tiles back into a (height, width) tensor.

        Every tile is weighted by ``mask`` before folding, and the folded sum
        is divided by the folded weights so overlapping regions are averaged.
        """
        weight = self.mask(tile_size, tile_size, border_width)
        weight = weight.to(device=tile_device, dtype=tile_dtype)
        weighted_tiles = model_output * rearrange(weight, "h w -> 1 1 h w 1")

        fold_kwargs = dict(
            output_size=(height, width),
            kernel_size=(tile_size, tile_size),
            stride=(tile_stride, tile_stride),
        )
        # One copy of the weight map per tile, laid out as fold expects.
        weight_columns = repeat(weight, "h w -> 1 (h w) n", n=weighted_tiles.shape[-1])
        value_columns = rearrange(weighted_tiles, "b c h w n -> b (c h w) n")

        summed_values = torch.nn.functional.fold(value_columns, **fold_kwargs)
        summed_weights = torch.nn.functional.fold(weight_columns, **fold_kwargs)
        return summed_values / summed_weights

    def tiled_forward(self, forward_fn, model_input, tile_size, tile_stride, tile_batch_size=1, tile_device="cpu", tile_dtype=torch.float32, border_width=None):
        """Run ``forward_fn`` over ``model_input`` tile by tile and blend the result.

        The model runs on the device/dtype of ``model_input``; tiles are kept
        on ``tile_device`` / ``tile_dtype`` in between. ``border_width``
        defaults to half of ``tile_stride``. The result is returned on the
        device/dtype of ``model_input``.
        """
        # Prepare
        inference_device = model_input.device
        inference_dtype = model_input.dtype
        in_height, in_width = model_input.shape[2], model_input.shape[3]
        if border_width is None:
            border_width = int(tile_stride * 0.5)

        # Tile
        tiles = self.tile(model_input, tile_size, tile_stride, tile_device, tile_dtype)

        # Inference
        tile_outputs = self.tiled_inference(
            forward_fn, tiles, tile_batch_size,
            inference_device, inference_dtype, tile_device, tile_dtype,
        )

        # Resize: express every geometry value in output coordinates.
        scale = self.io_scale(tile_outputs, tile_size)
        out_height, out_width = int(in_height * scale), int(in_width * scale)
        out_tile_size, out_tile_stride = int(tile_size * scale), int(tile_stride * scale)
        out_border_width = int(border_width * scale)

        # Untile
        blended = self.untile(
            tile_outputs, out_height, out_width,
            out_tile_size, out_tile_stride, out_border_width,
            tile_device, tile_dtype,
        )

        # Done!
        return blended.to(device=inference_device, dtype=inference_dtype)
|
|
|
|
|
|
|
|
class TileWorker2Dto3D:
    """
    Process 3D tensors, but only enable TileWorker on 2D.

    Inputs have shape (B, C, T, H, W). Tiles are cut along H and W only;
    every tile keeps the full T axis.
    """

    def __init__(self):
        pass

    @staticmethod
    def _edge_ramp(length, low_is_bound, high_is_bound, border_width):
        # 1-D distance (counted from 1) to the nearest end of one axis.
        # An end flagged as a boundary contributes `border_width` instead of a
        # ramp, so no fade is applied on that side.
        index = torch.arange(length)
        flat = torch.full((length,), border_width)
        low = flat if low_is_bound else index + 1
        high = flat if high_is_bound else length - index
        return torch.minimum(low, high)

    @staticmethod
    def _axis_spans(length, tile_size, tile_stride):
        # List the (start, stop) tile spans along one axis. The last span is
        # shifted back so it ends exactly at `length`.
        spans = []
        for start in range(0, length, tile_stride):
            previous = start - tile_stride
            if previous >= 0 and previous + tile_size >= length:
                # The previous tile already reached the end of the axis, and
                # so would every later one.
                break
            stop = start + tile_size
            if stop > length:
                spans.append((max(length - tile_size, 0), length))
            else:
                spans.append((start, stop))
        return spans

    def build_mask(self, T, H, W, dtype, device, is_bound, border_width):
        """Build a (1, 1, T, H, W) blending weight for one tile.

        Args:
            T, H, W: size of the tile output along each axis.
            dtype, device: dtype and device of the returned mask.
            is_bound: six flags ordered (t_low, t_high, h_low, h_high, w_low,
                w_high); a true flag disables the fade on that side.
            border_width: width of the fade region; ``None`` selects
                ``(H + W) // 4``.

        Returns:
            Weights in (0, 1]: distance to the nearest non-boundary side,
            clipped to [1, border_width] and divided by ``border_width``.
        """
        if border_width is None:
            border_width = (H + W) // 4
        # A zero width (tiny tiles, or an explicit 0) would make the weights
        # 0 / 0 = NaN; a width of 1 means "no fade" and yields all-ones.
        border_width = max(border_width, 1)

        # Combine three 1-D ramps by broadcasting instead of materialising
        # full (T, H, W) index grids for every side.
        ramp_t = self._edge_ramp(T, is_bound[0], is_bound[1], border_width).view(T, 1, 1)
        ramp_h = self._edge_ramp(H, is_bound[2], is_bound[3], border_width).view(1, H, 1)
        ramp_w = self._edge_ramp(W, is_bound[4], is_bound[5], border_width).view(1, 1, W)
        mask = torch.minimum(torch.minimum(ramp_t, ramp_h), ramp_w)

        mask = mask.clip(1, border_width)
        mask = (mask / border_width).to(dtype=dtype, device=device)
        return mask[None, None]

    def tiled_forward(
        self,
        forward_fn,
        model_input,
        tile_size, tile_stride,
        tile_device="cpu", tile_dtype=torch.float32,
        computation_device="cuda", computation_dtype=torch.float32,
        border_width=None, scales=(1, 1, 1, 1),
        progress_bar=lambda x:x
    ):
        """Run ``forward_fn`` on H/W tiles of ``model_input`` and blend the outputs.

        Args:
            forward_fn: callable applied to each (B, C, T, tile_H, tile_W) tile.
            model_input: tensor of shape (B, C, T, H, W).
            tile_size: (tile_size_H, tile_size_W) in input coordinates.
            tile_stride: (tile_stride_H, tile_stride_W) in input coordinates.
            tile_device, tile_dtype: where the accumulators and masks live.
            computation_device, computation_dtype: where each tile is sent
                before calling ``forward_fn``.
            border_width: fade width passed to ``build_mask``.
            scales: (scale_C, scale_T, scale_H, scale_W), the output/input
                size ratio of ``forward_fn`` along each axis.
            progress_bar: wrapper applied to the task list (e.g. tqdm).

        Returns:
            The weighted average of all tile outputs, on ``tile_device``.
        """
        B, C, T, H, W = model_input.shape
        scale_C, scale_T, scale_H, scale_W = scales
        tile_size_H, tile_size_W = tile_size
        tile_stride_H, tile_stride_W = tile_stride

        out_T = int(T * scale_T)
        out_H, out_W = int(H * scale_H), int(W * scale_W)
        value = torch.zeros((B, int(C * scale_C), out_T, out_H, out_W), dtype=tile_dtype, device=tile_device)
        weight = torch.zeros((1, 1, out_T, out_H, out_W), dtype=tile_dtype, device=tile_device)

        # Split tasks
        h_spans = self._axis_spans(H, tile_size_H, tile_stride_H)
        w_spans = self._axis_spans(W, tile_size_W, tile_stride_W)
        tasks = [(hl, hr, wl, wr) for hl, hr in h_spans for wl, wr in w_spans]

        # Run
        for hl, hr, wl, wr in progress_bar(tasks):
            mask = self.build_mask(
                out_T, int((hr - hl) * scale_H), int((wr - wl) * scale_W),
                tile_dtype, tile_device,
                # T is never tiled, so it is always treated as a boundary.
                is_bound=(True, True, hl == 0, hr >= H, wl == 0, wr >= W),
                border_width=border_width
            )
            grid_input = model_input[:, :, :, hl:hr, wl:wr].to(dtype=computation_dtype, device=computation_device)
            grid_output = forward_fn(grid_input).to(dtype=tile_dtype, device=tile_device)

            rows = slice(int(hl * scale_H), int(hr * scale_H))
            cols = slice(int(wl * scale_W), int(wr * scale_W))
            value[:, :, :, rows, cols] += grid_output * mask
            weight[:, :, :, rows, cols] += mask

        return value / weight