support mask blur

This commit is contained in:
Artiprocher
2024-11-11 18:59:55 +08:00
parent 7e97a96840
commit 8d1d1536d3
4 changed files with 13 additions and 6 deletions

View File

@@ -1,6 +1,7 @@
import torch
import numpy as np
from PIL import Image
from torchvision.transforms import GaussianBlur
@@ -35,14 +36,16 @@ class BasePipeline(torch.nn.Module):
return video
def merge_latents(self, value, latents, masks, scales):
def merge_latents(self, value, latents, masks, scales, blur_kernel_size=33, blur_sigma=10.0):
blur = GaussianBlur(kernel_size=blur_kernel_size, sigma=blur_sigma)
height, width = value.shape[-2:]
weight = torch.ones_like(value)
for latent, mask, scale in zip(latents, masks, scales):
mask = self.preprocess_image(mask.resize((width, height))).mean(dim=1, keepdim=True) > 0
mask = mask.repeat(1, latent.shape[1], 1, 1)
value[mask] += latent[mask] * scale
weight[mask] += scale
mask = mask.repeat(1, latent.shape[1], 1, 1).to(dtype=latent.dtype, device=latent.device)
mask = blur(mask)
value += latent * mask * scale
weight += mask * scale
value /= weight
return value

View File

@@ -257,6 +257,7 @@ def lets_dance_flux(
):
if tiled:
def flux_forward_fn(hl, hr, wl, wr):
tiled_controlnet_frames = [f[:, :, hl: hr, wl: wr] for f in controlnet_frames] if controlnet_frames is not None else None
return lets_dance_flux(
dit=dit,
controlnet=controlnet,
@@ -267,7 +268,7 @@ def lets_dance_flux(
guidance=guidance,
text_ids=text_ids,
image_ids=None,
controlnet_frames=[f[:, :, hl: hr, wl: wr] for f in controlnet_frames],
controlnet_frames=tiled_controlnet_frames,
tiled=False,
**kwargs
)

View File

@@ -22,7 +22,9 @@ class LightningModel(LightningModelForT2ILoRA):
model_manager.load_models(pretrained_weights[1:])
model_manager.load_model(pretrained_weights[0], torch_dtype=quantize)
if preset_lora_path is not None:
model_manager.load_lora(preset_lora_path)
preset_lora_path = preset_lora_path.split(",")
for path in preset_lora_path:
model_manager.load_lora(path)
self.pipe = FluxImagePipeline.from_model_manager(model_manager)

View File

@@ -1,4 +1,5 @@
torch>=2.0.0
torchvision
cupy-cuda12x
transformers==4.44.1
controlnet-aux==0.0.7