mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-21 16:18:13 +00:00
update TileWorker for better visual quality
This commit is contained in:
@@ -1,6 +1,6 @@
|
||||
import torch
|
||||
from ..models import SDUNet, SDMotionModel
|
||||
from ..models.sd_unet import PushBlock, PopBlock
|
||||
from ..models.sd_unet import PushBlock, PopBlock, ResnetBlock, AttentionBlock
|
||||
from ..models.tiler import TileWorker
|
||||
from ..controlnets import MultiControlNetManager
|
||||
|
||||
@@ -75,25 +75,14 @@ def lets_dance(
|
||||
hidden_states_output = []
|
||||
for batch_id in range(0, sample.shape[0], unet_batch_size):
|
||||
batch_id_ = min(batch_id + unet_batch_size, sample.shape[0])
|
||||
if tiled:
|
||||
_, _, inter_height, _ = hidden_states.shape
|
||||
resize_scale = inter_height / height
|
||||
hidden_states = TileWorker().tiled_forward(
|
||||
lambda x: block(x, time_emb, text_emb[batch_id: batch_id_], res_stack)[0],
|
||||
hidden_states_input[batch_id: batch_id_],
|
||||
int(tile_size * resize_scale),
|
||||
int(tile_stride * resize_scale),
|
||||
tile_device=hidden_states.device,
|
||||
tile_dtype=hidden_states.dtype
|
||||
)
|
||||
else:
|
||||
hidden_states, _, _, _ = block(
|
||||
hidden_states_input[batch_id: batch_id_],
|
||||
time_emb,
|
||||
text_emb[batch_id: batch_id_],
|
||||
res_stack,
|
||||
cross_frame_attention=cross_frame_attention
|
||||
)
|
||||
hidden_states, _, _, _ = block(
|
||||
hidden_states_input[batch_id: batch_id_],
|
||||
time_emb,
|
||||
text_emb[batch_id: batch_id_],
|
||||
res_stack,
|
||||
cross_frame_attention=cross_frame_attention,
|
||||
tiled=tiled, tile_size=tile_size, tile_stride=tile_stride
|
||||
)
|
||||
hidden_states_output.append(hidden_states)
|
||||
hidden_states = torch.concat(hidden_states_output, dim=0)
|
||||
# 4.2 AnimateDiff
|
||||
|
||||
Reference in New Issue
Block a user