update mask blur

This commit is contained in:
Artiprocher
2024-11-12 19:20:17 +08:00
parent d46b8b8fd7
commit 71b17a3a53

View File

@@ -36,17 +36,18 @@ class BasePipeline(torch.nn.Module):
return video return video
def merge_latents(self, value, latents, masks, scales, blur_kernel_size=3, blur_sigma=1.0): def merge_latents(self, value, latents, masks, scales, blur_kernel_size=33, blur_sigma=10.0):
blur = GaussianBlur(kernel_size=blur_kernel_size, sigma=blur_sigma) if len(latents) > 0:
height, width = value.shape[-2:] blur = GaussianBlur(kernel_size=blur_kernel_size, sigma=blur_sigma)
weight = torch.ones_like(value) height, width = value.shape[-2:]
for latent, mask, scale in zip(latents, masks, scales): weight = torch.ones_like(value)
mask = self.preprocess_image(mask.resize((width, height))).mean(dim=1, keepdim=True) > 0 for latent, mask, scale in zip(latents, masks, scales):
mask = mask.repeat(1, latent.shape[1], 1, 1).to(dtype=latent.dtype, device=latent.device) mask = self.preprocess_image(mask.resize((width, height))).mean(dim=1, keepdim=True) > 0
mask = blur(mask) mask = mask.repeat(1, latent.shape[1], 1, 1).to(dtype=latent.dtype, device=latent.device)
value += latent * mask * scale mask = blur(mask)
weight += mask * scale value += latent * mask * scale
value /= weight weight += mask * scale
value /= weight
return value return value