mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-24 10:18:12 +00:00
fix a bug in sliding window inference
This commit is contained in:
@@ -1012,12 +1012,16 @@ class TemporalTiler_BCTHW:
|
|||||||
def __init__(self):
|
def __init__(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def build_1d_mask(self, length, left_bound, right_bound, border_width):
|
def build_1d_mask(length, left_bound, right_bound, border_width):
|
||||||
x = torch.ones((length,))
|
x = torch.ones((length,))
|
||||||
|
if border_width == 0:
|
||||||
|
return x
|
||||||
|
|
||||||
|
shift = 0.5
|
||||||
if not left_bound:
|
if not left_bound:
|
||||||
x[:border_width] = (torch.arange(border_width) + 1) / border_width
|
x[:border_width] = (torch.arange(border_width) + shift) / border_width
|
||||||
if not right_bound:
|
if not right_bound:
|
||||||
x[-border_width:] = torch.flip((torch.arange(border_width) + 1) / border_width, dims=(0,))
|
x[-border_width:] = torch.flip((torch.arange(border_width) + shift) / border_width, dims=(0,))
|
||||||
return x
|
return x
|
||||||
|
|
||||||
def build_mask(self, data, is_bound, border_width):
|
def build_mask(self, data, is_bound, border_width):
|
||||||
@@ -1047,7 +1051,7 @@ class TemporalTiler_BCTHW:
|
|||||||
mask = self.build_mask(
|
mask = self.build_mask(
|
||||||
model_output,
|
model_output,
|
||||||
is_bound=(t == 0, t_ == T),
|
is_bound=(t == 0, t_ == T),
|
||||||
border_width=(sliding_window_size - sliding_window_stride + 1,)
|
border_width=(sliding_window_size - sliding_window_stride,)
|
||||||
).to(device=data_device, dtype=data_dtype)
|
).to(device=data_device, dtype=data_dtype)
|
||||||
value[:, :, t: t_, :, :] += model_output * mask
|
value[:, :, t: t_, :, :] += model_output * mask
|
||||||
weight[:, :, t: t_, :, :] += mask
|
weight[:, :, t: t_, :, :] += mask
|
||||||
|
|||||||
Reference in New Issue
Block a user