Fix a bug in sliding-window inference

This commit is contained in:
ziyannchen
2025-07-20 11:13:20 +00:00
parent 05c6b49b90
commit c05b1a2fd0

View File

@@ -1012,12 +1012,16 @@ class TemporalTiler_BCTHW:
def __init__(self):
pass
def build_1d_mask(self, length, left_bound, right_bound, border_width):
def build_1d_mask(length, left_bound, right_bound, border_width):
x = torch.ones((length,))
if border_width == 0:
return x
shift = 0.5
if not left_bound:
x[:border_width] = (torch.arange(border_width) + 1) / border_width
x[:border_width] = (torch.arange(border_width) + shift) / border_width
if not right_bound:
x[-border_width:] = torch.flip((torch.arange(border_width) + 1) / border_width, dims=(0,))
x[-border_width:] = torch.flip((torch.arange(border_width) + shift) / border_width, dims=(0,))
return x
def build_mask(self, data, is_bound, border_width):
@@ -1047,7 +1051,7 @@ class TemporalTiler_BCTHW:
mask = self.build_mask(
model_output,
is_bound=(t == 0, t_ == T),
border_width=(sliding_window_size - sliding_window_stride + 1,)
border_width=(sliding_window_size - sliding_window_stride,)
).to(device=data_device, dtype=data_dtype)
value[:, :, t: t_, :, :] += model_output * mask
weight[:, :, t: t_, :, :] += mask