From 8d1d1536d3eeb0132563b75a89a9219448c5e087 Mon Sep 17 00:00:00 2001 From: Artiprocher Date: Mon, 11 Nov 2024 18:59:55 +0800 Subject: [PATCH] support mask blur --- diffsynth/pipelines/base.py | 11 +++++++---- diffsynth/pipelines/flux_image.py | 3 ++- examples/train/flux/train_flux_lora.py | 4 +++- requirements.txt | 1 + 4 files changed, 13 insertions(+), 6 deletions(-) diff --git a/diffsynth/pipelines/base.py b/diffsynth/pipelines/base.py index b968bb6..d76eadb 100644 --- a/diffsynth/pipelines/base.py +++ b/diffsynth/pipelines/base.py @@ -1,6 +1,7 @@ import torch import numpy as np from PIL import Image +from torchvision.transforms import GaussianBlur @@ -35,14 +36,16 @@ class BasePipeline(torch.nn.Module): return video - def merge_latents(self, value, latents, masks, scales): + def merge_latents(self, value, latents, masks, scales, blur_kernel_size=33, blur_sigma=10.0): + blur = GaussianBlur(kernel_size=blur_kernel_size, sigma=blur_sigma) height, width = value.shape[-2:] weight = torch.ones_like(value) for latent, mask, scale in zip(latents, masks, scales): mask = self.preprocess_image(mask.resize((width, height))).mean(dim=1, keepdim=True) > 0 - mask = mask.repeat(1, latent.shape[1], 1, 1) - value[mask] += latent[mask] * scale - weight[mask] += scale + mask = mask.repeat(1, latent.shape[1], 1, 1).to(dtype=latent.dtype, device=latent.device) + mask = blur(mask) + value += latent * mask * scale + weight += mask * scale value /= weight return value diff --git a/diffsynth/pipelines/flux_image.py b/diffsynth/pipelines/flux_image.py index 89d730f..517fb48 100644 --- a/diffsynth/pipelines/flux_image.py +++ b/diffsynth/pipelines/flux_image.py @@ -257,6 +257,7 @@ def lets_dance_flux( ): if tiled: def flux_forward_fn(hl, hr, wl, wr): + tiled_controlnet_frames = [f[:, :, hl: hr, wl: wr] for f in controlnet_frames] if controlnet_frames is not None else None return lets_dance_flux( dit=dit, controlnet=controlnet, @@ -267,7 +268,7 @@ def lets_dance_flux( guidance=guidance, text_ids=text_ids, image_ids=None, - controlnet_frames=[f[:, :, hl: hr, wl: wr] for f in controlnet_frames], + controlnet_frames=tiled_controlnet_frames, tiled=False, **kwargs ) diff --git a/examples/train/flux/train_flux_lora.py b/examples/train/flux/train_flux_lora.py index 0bf118f..4efeed3 100644 --- a/examples/train/flux/train_flux_lora.py +++ b/examples/train/flux/train_flux_lora.py @@ -22,7 +22,9 @@ class LightningModel(LightningModelForT2ILoRA): model_manager.load_models(pretrained_weights[1:]) model_manager.load_model(pretrained_weights[0], torch_dtype=quantize) if preset_lora_path is not None: - model_manager.load_lora(preset_lora_path) + preset_lora_path = preset_lora_path.split(",") + for path in preset_lora_path: + model_manager.load_lora(path) self.pipe = FluxImagePipeline.from_model_manager(model_manager) diff --git a/requirements.txt b/requirements.txt index df207cc..e53cc11 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,5 @@ torch>=2.0.0 +torchvision cupy-cuda12x transformers==4.44.1 controlnet-aux==0.0.7