From 71b17a3a539b6d7099005513620a4c380442ca78 Mon Sep 17 00:00:00 2001
From: Artiprocher
Date: Tue, 12 Nov 2024 19:20:17 +0800
Subject: [PATCH] update mask blur

---
 diffsynth/pipelines/base.py | 23 ++++++++++++-----------
 1 file changed, 12 insertions(+), 11 deletions(-)

diff --git a/diffsynth/pipelines/base.py b/diffsynth/pipelines/base.py
index fb8813f..c14c5d6 100644
--- a/diffsynth/pipelines/base.py
+++ b/diffsynth/pipelines/base.py
@@ -36,17 +36,18 @@ class BasePipeline(torch.nn.Module):
         return video
 
 
-    def merge_latents(self, value, latents, masks, scales, blur_kernel_size=3, blur_sigma=1.0):
-        blur = GaussianBlur(kernel_size=blur_kernel_size, sigma=blur_sigma)
-        height, width = value.shape[-2:]
-        weight = torch.ones_like(value)
-        for latent, mask, scale in zip(latents, masks, scales):
-            mask = self.preprocess_image(mask.resize((width, height))).mean(dim=1, keepdim=True) > 0
-            mask = mask.repeat(1, latent.shape[1], 1, 1).to(dtype=latent.dtype, device=latent.device)
-            mask = blur(mask)
-            value += latent * mask * scale
-            weight += mask * scale
-        value /= weight
+    def merge_latents(self, value, latents, masks, scales, blur_kernel_size=33, blur_sigma=10.0):
+        if len(latents) > 0:
+            blur = GaussianBlur(kernel_size=blur_kernel_size, sigma=blur_sigma)
+            height, width = value.shape[-2:]
+            weight = torch.ones_like(value)
+            for latent, mask, scale in zip(latents, masks, scales):
+                mask = self.preprocess_image(mask.resize((width, height))).mean(dim=1, keepdim=True) > 0
+                mask = mask.repeat(1, latent.shape[1], 1, 1).to(dtype=latent.dtype, device=latent.device)
+                mask = blur(mask)
+                value += latent * mask * scale
+                weight += mask * scale
+            value /= weight
         return value
 
 