update TileWorker for better visual quality

This commit is contained in:
Artiprocher
2024-01-09 22:29:17 +08:00
parent 552355692d
commit 8a460497fa
4 changed files with 47 additions and 122 deletions

View File

@@ -1,6 +1,5 @@
import torch
from .sd_unet import Timesteps, ResnetBlock, AttentionBlock, PushBlock, PopBlock, DownSampler, UpSampler
from .tiler import Tiler
class SDXLUNet(torch.nn.Module):
@@ -108,13 +107,10 @@ class SDXLUNet(torch.nn.Module):
# 3. blocks
for i, block in enumerate(self.blocks):
if tiled:
hidden_states, time_emb, text_emb, res_stack = self.tiled_inference(
block, hidden_states, time_emb, text_emb, res_stack,
height, width, tile_size, tile_stride
)
else:
hidden_states, time_emb, text_emb, res_stack = block(hidden_states, time_emb, text_emb, res_stack)
hidden_states, time_emb, text_emb, res_stack = block(
hidden_states, time_emb, text_emb, res_stack,
tiled=tiled, tile_size=tile_size, tile_stride=tile_stride
)
# 4. output
hidden_states = self.conv_norm_out(hidden_states)
@@ -123,23 +119,6 @@ class SDXLUNet(torch.nn.Module):
return hidden_states
def tiled_inference(self, block, hidden_states, time_emb, text_emb, res_stack, height, width, tile_size, tile_stride):
if block.__class__.__name__ in ["ResnetBlock", "AttentionBlock", "DownSampler", "UpSampler"]:
batch_size, inter_channel, inter_height, inter_width = hidden_states.shape
resize_scale = inter_height / height
hidden_states = Tiler()(
lambda x: block(x, time_emb, text_emb, res_stack)[0],
hidden_states,
int(tile_size * resize_scale),
int(tile_stride * resize_scale),
inter_device=hidden_states.device,
inter_dtype=hidden_states.dtype
)
else:
hidden_states, time_emb, text_emb, res_stack = block(hidden_states, time_emb, text_emb, res_stack)
return hidden_states, time_emb, text_emb, res_stack
def state_dict_converter(self):
return SDXLUNetStateDictConverter()