mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-18 22:08:13 +00:00
support mask blur
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
import torch
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
from torchvision.transforms import GaussianBlur
|
||||
|
||||
|
||||
|
||||
@@ -35,14 +36,16 @@ class BasePipeline(torch.nn.Module):
|
||||
return video
|
||||
|
||||
|
||||
def merge_latents(self, value, latents, masks, scales):
|
||||
def merge_latents(self, value, latents, masks, scales, blur_kernel_size=33, blur_sigma=10.0):
|
||||
blur = GaussianBlur(kernel_size=blur_kernel_size, sigma=blur_sigma)
|
||||
height, width = value.shape[-2:]
|
||||
weight = torch.ones_like(value)
|
||||
for latent, mask, scale in zip(latents, masks, scales):
|
||||
mask = self.preprocess_image(mask.resize((width, height))).mean(dim=1, keepdim=True) > 0
|
||||
mask = mask.repeat(1, latent.shape[1], 1, 1)
|
||||
value[mask] += latent[mask] * scale
|
||||
weight[mask] += scale
|
||||
mask = mask.repeat(1, latent.shape[1], 1, 1).to(dtype=latent.dtype, device=latent.device)
|
||||
mask = blur(mask)
|
||||
value += latent * mask * scale
|
||||
weight += mask * scale
|
||||
value /= weight
|
||||
return value
|
||||
|
||||
|
||||
@@ -257,6 +257,7 @@ def lets_dance_flux(
|
||||
):
|
||||
if tiled:
|
||||
def flux_forward_fn(hl, hr, wl, wr):
|
||||
tiled_controlnet_frames = [f[:, :, hl: hr, wl: wr] for f in controlnet_frames] if controlnet_frames is not None else None
|
||||
return lets_dance_flux(
|
||||
dit=dit,
|
||||
controlnet=controlnet,
|
||||
@@ -267,7 +268,7 @@ def lets_dance_flux(
|
||||
guidance=guidance,
|
||||
text_ids=text_ids,
|
||||
image_ids=None,
|
||||
controlnet_frames=[f[:, :, hl: hr, wl: wr] for f in controlnet_frames],
|
||||
controlnet_frames=tiled_controlnet_frames,
|
||||
tiled=False,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user