Mirror of https://github.com/modelscope/DiffSynth-Studio.git (synced 2026-03-25 02:38:10 +00:00)
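support mask blur
Summary of the changes shown below:
- BasePipeline.merge_latents: adds `blur_kernel_size=33` and `blur_sigma=10.0` parameters; each region mask is cast to the latent's dtype/device and smoothed with torchvision's `GaussianBlur`, then blended as a soft weight (`value += latent * mask * scale`, `weight += mask * scale`) instead of by boolean indexing.
- lets_dance_flux (tiled path): the per-tile slicing of `controlnet_frames` is now skipped when `controlnet_frames` is None.
- LightningModel: `preset_lora_path` may be a comma-separated list; each path is loaded with `model_manager.load_lora`.
- Dependency list: adds `torchvision`.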
This commit is contained in:
@@ -1,6 +1,7 @@
|
|||||||
import torch
|
import torch
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
from torchvision.transforms import GaussianBlur
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@@ -35,14 +36,16 @@ class BasePipeline(torch.nn.Module):
|
|||||||
return video
|
return video
|
||||||
|
|
||||||
|
|
||||||
def merge_latents(self, value, latents, masks, scales):
|
def merge_latents(self, value, latents, masks, scales, blur_kernel_size=33, blur_sigma=10.0):
|
||||||
|
blur = GaussianBlur(kernel_size=blur_kernel_size, sigma=blur_sigma)
|
||||||
height, width = value.shape[-2:]
|
height, width = value.shape[-2:]
|
||||||
weight = torch.ones_like(value)
|
weight = torch.ones_like(value)
|
||||||
for latent, mask, scale in zip(latents, masks, scales):
|
for latent, mask, scale in zip(latents, masks, scales):
|
||||||
mask = self.preprocess_image(mask.resize((width, height))).mean(dim=1, keepdim=True) > 0
|
mask = self.preprocess_image(mask.resize((width, height))).mean(dim=1, keepdim=True) > 0
|
||||||
mask = mask.repeat(1, latent.shape[1], 1, 1)
|
mask = mask.repeat(1, latent.shape[1], 1, 1).to(dtype=latent.dtype, device=latent.device)
|
||||||
value[mask] += latent[mask] * scale
|
mask = blur(mask)
|
||||||
weight[mask] += scale
|
value += latent * mask * scale
|
||||||
|
weight += mask * scale
|
||||||
value /= weight
|
value /= weight
|
||||||
return value
|
return value
|
||||||
|
|
||||||
|
|||||||
@@ -257,6 +257,7 @@ def lets_dance_flux(
|
|||||||
):
|
):
|
||||||
if tiled:
|
if tiled:
|
||||||
def flux_forward_fn(hl, hr, wl, wr):
|
def flux_forward_fn(hl, hr, wl, wr):
|
||||||
|
tiled_controlnet_frames = [f[:, :, hl: hr, wl: wr] for f in controlnet_frames] if controlnet_frames is not None else None
|
||||||
return lets_dance_flux(
|
return lets_dance_flux(
|
||||||
dit=dit,
|
dit=dit,
|
||||||
controlnet=controlnet,
|
controlnet=controlnet,
|
||||||
@@ -267,7 +268,7 @@ def lets_dance_flux(
|
|||||||
guidance=guidance,
|
guidance=guidance,
|
||||||
text_ids=text_ids,
|
text_ids=text_ids,
|
||||||
image_ids=None,
|
image_ids=None,
|
||||||
controlnet_frames=[f[:, :, hl: hr, wl: wr] for f in controlnet_frames],
|
controlnet_frames=tiled_controlnet_frames,
|
||||||
tiled=False,
|
tiled=False,
|
||||||
**kwargs
|
**kwargs
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -22,7 +22,9 @@ class LightningModel(LightningModelForT2ILoRA):
|
|||||||
model_manager.load_models(pretrained_weights[1:])
|
model_manager.load_models(pretrained_weights[1:])
|
||||||
model_manager.load_model(pretrained_weights[0], torch_dtype=quantize)
|
model_manager.load_model(pretrained_weights[0], torch_dtype=quantize)
|
||||||
if preset_lora_path is not None:
|
if preset_lora_path is not None:
|
||||||
model_manager.load_lora(preset_lora_path)
|
preset_lora_path = preset_lora_path.split(",")
|
||||||
|
for path in preset_lora_path:
|
||||||
|
model_manager.load_lora(path)
|
||||||
|
|
||||||
self.pipe = FluxImagePipeline.from_model_manager(model_manager)
|
self.pipe = FluxImagePipeline.from_model_manager(model_manager)
|
||||||
|
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
torch>=2.0.0
|
torch>=2.0.0
|
||||||
|
torchvision
|
||||||
cupy-cuda12x
|
cupy-cuda12x
|
||||||
transformers==4.44.1
|
transformers==4.44.1
|
||||||
controlnet-aux==0.0.7
|
controlnet-aux==0.0.7
|
||||||
|
|||||||
Reference in New Issue
Block a user