update TileWorker for better visual quality

This commit is contained in:
Artiprocher
2024-01-09 22:29:17 +08:00
parent 552355692d
commit 8a460497fa
4 changed files with 47 additions and 122 deletions

View File

@@ -1,6 +1,6 @@
import torch
from ..models import SDUNet, SDMotionModel
from ..models.sd_unet import PushBlock, PopBlock
from ..models.sd_unet import PushBlock, PopBlock, ResnetBlock, AttentionBlock
from ..models.tiler import TileWorker
from ..controlnets import MultiControlNetManager
@@ -75,25 +75,14 @@ def lets_dance(
hidden_states_output = []
for batch_id in range(0, sample.shape[0], unet_batch_size):
batch_id_ = min(batch_id + unet_batch_size, sample.shape[0])
if tiled:
_, _, inter_height, _ = hidden_states.shape
resize_scale = inter_height / height
hidden_states = TileWorker().tiled_forward(
lambda x: block(x, time_emb, text_emb[batch_id: batch_id_], res_stack)[0],
hidden_states_input[batch_id: batch_id_],
int(tile_size * resize_scale),
int(tile_stride * resize_scale),
tile_device=hidden_states.device,
tile_dtype=hidden_states.dtype
)
else:
hidden_states, _, _, _ = block(
hidden_states_input[batch_id: batch_id_],
time_emb,
text_emb[batch_id: batch_id_],
res_stack,
cross_frame_attention=cross_frame_attention
)
hidden_states, _, _, _ = block(
hidden_states_input[batch_id: batch_id_],
time_emb,
text_emb[batch_id: batch_id_],
res_stack,
cross_frame_attention=cross_frame_attention,
tiled=tiled, tile_size=tile_size, tile_stride=tile_stride
)
hidden_states_output.append(hidden_states)
hidden_states = torch.concat(hidden_states_output, dim=0)
# 4.2 AnimateDiff