Merge pull request #263 from modelscope/super-alignment

support mask blur
This commit is contained in:
Zhongjie Duan
2024-11-11 19:24:30 +08:00
committed by GitHub
3 changed files with 11 additions and 5 deletions

View File

@@ -1,6 +1,7 @@
import torch
import numpy as np
from PIL import Image
from torchvision.transforms import GaussianBlur
@@ -35,14 +36,16 @@ class BasePipeline(torch.nn.Module):
return video
def merge_latents(self, value, latents, masks, scales, blur_kernel_size=33, blur_sigma=10.0):
    """Blend per-region latents into ``value`` using Gaussian-blurred masks.

    Each mask is resized to the spatial size of ``value``, binarized,
    blurred, and then used as a per-pixel weight for its latent. The result
    is a weighted average in which ``value`` itself carries unit weight.

    Args:
        value: Base tensor. Its last two dimensions are read as
            (height, width). It is modified in place and also returned.
        latents: Iterable of tensors to blend in, one per mask.
            Assumed broadcast-compatible with ``value`` -- TODO confirm
            against callers.
        masks: Iterable of objects exposing ``resize((width, height))``
            whose result ``self.preprocess_image`` accepts (presumably PIL
            images -- verify against callers).
        scales: Iterable of per-latent scalar weights.
        blur_kernel_size: Kernel size forwarded to ``GaussianBlur``.
            NOTE(review): torchvision expects an odd, positive size.
        blur_sigma: Standard deviation forwarded to ``GaussianBlur``.

    Returns:
        ``value`` after accumulation and normalization (same object that
        was passed in).
    """
    # The blur does not depend on the individual mask, so build it once.
    blur = GaussianBlur(kernel_size=blur_kernel_size, sigma=blur_sigma)
    height, width = value.shape[-2:]
    # Starts at one so the incoming `value` counts with weight 1 in the
    # final division.
    weight = torch.ones_like(value)
    for latent, mask, scale in zip(latents, masks, scales):
        # Collapse dim 1 and threshold: any non-zero pixel is "inside".
        # Assumes preprocess_image returns (batch, channel, H, W) -- TODO confirm.
        mask = self.preprocess_image(mask.resize((width, height))).mean(dim=1, keepdim=True) > 0
        # Expand to the latent's channel count and cast from bool so the
        # mask can be blurred and multiplied with the latent.
        mask = mask.repeat(1, latent.shape[1], 1, 1).to(dtype=latent.dtype, device=latent.device)
        # Feather the hard edge so neighbouring regions fade into each other.
        mask = blur(mask)
        value += latent * mask * scale
        weight += mask * scale
    # Normalize the accumulated sum into a weighted average.
    value /= weight
    return value