mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-18 22:08:13 +00:00
support mask blur
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
import torch
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
from torchvision.transforms import GaussianBlur
|
||||
|
||||
|
||||
|
||||
@@ -35,14 +36,16 @@ class BasePipeline(torch.nn.Module):
|
||||
return video
|
||||
|
||||
|
||||
def merge_latents(self, value, latents, masks, scales):
|
||||
def merge_latents(self, value, latents, masks, scales, blur_kernel_size=33, blur_sigma=10.0):
|
||||
blur = GaussianBlur(kernel_size=blur_kernel_size, sigma=blur_sigma)
|
||||
height, width = value.shape[-2:]
|
||||
weight = torch.ones_like(value)
|
||||
for latent, mask, scale in zip(latents, masks, scales):
|
||||
mask = self.preprocess_image(mask.resize((width, height))).mean(dim=1, keepdim=True) > 0
|
||||
mask = mask.repeat(1, latent.shape[1], 1, 1)
|
||||
value[mask] += latent[mask] * scale
|
||||
weight[mask] += scale
|
||||
mask = mask.repeat(1, latent.shape[1], 1, 1).to(dtype=latent.dtype, device=latent.device)
|
||||
mask = blur(mask)
|
||||
value += latent * mask * scale
|
||||
weight += mask * scale
|
||||
value /= weight
|
||||
return value
|
||||
|
||||
|
||||
@@ -257,6 +257,7 @@ def lets_dance_flux(
|
||||
):
|
||||
if tiled:
|
||||
def flux_forward_fn(hl, hr, wl, wr):
|
||||
tiled_controlnet_frames = [f[:, :, hl: hr, wl: wr] for f in controlnet_frames] if controlnet_frames is not None else None
|
||||
return lets_dance_flux(
|
||||
dit=dit,
|
||||
controlnet=controlnet,
|
||||
@@ -267,7 +268,7 @@ def lets_dance_flux(
|
||||
guidance=guidance,
|
||||
text_ids=text_ids,
|
||||
image_ids=None,
|
||||
controlnet_frames=[f[:, :, hl: hr, wl: wr] for f in controlnet_frames],
|
||||
controlnet_frames=tiled_controlnet_frames,
|
||||
tiled=False,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
@@ -22,7 +22,9 @@ class LightningModel(LightningModelForT2ILoRA):
|
||||
model_manager.load_models(pretrained_weights[1:])
|
||||
model_manager.load_model(pretrained_weights[0], torch_dtype=quantize)
|
||||
if preset_lora_path is not None:
|
||||
model_manager.load_lora(preset_lora_path)
|
||||
preset_lora_path = preset_lora_path.split(",")
|
||||
for path in preset_lora_path:
|
||||
model_manager.load_lora(path)
|
||||
|
||||
self.pipe = FluxImagePipeline.from_model_manager(model_manager)
|
||||
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
torch>=2.0.0
|
||||
torchvision
|
||||
cupy-cuda12x
|
||||
transformers==4.44.1
|
||||
controlnet-aux==0.0.7
|
||||
|
||||
Reference in New Issue
Block a user