mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-24 18:28:10 +00:00
Merge pull request #263 from modelscope/super-alignment
support mask blur
This commit is contained in:
@@ -1,6 +1,7 @@
|
|||||||
import torch
|
import torch
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
from torchvision.transforms import GaussianBlur
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@@ -35,14 +36,16 @@ class BasePipeline(torch.nn.Module):
|
|||||||
return video
|
return video
|
||||||
|
|
||||||
|
|
||||||
def merge_latents(self, value, latents, masks, scales):
|
def merge_latents(self, value, latents, masks, scales, blur_kernel_size=33, blur_sigma=10.0):
|
||||||
|
blur = GaussianBlur(kernel_size=blur_kernel_size, sigma=blur_sigma)
|
||||||
height, width = value.shape[-2:]
|
height, width = value.shape[-2:]
|
||||||
weight = torch.ones_like(value)
|
weight = torch.ones_like(value)
|
||||||
for latent, mask, scale in zip(latents, masks, scales):
|
for latent, mask, scale in zip(latents, masks, scales):
|
||||||
mask = self.preprocess_image(mask.resize((width, height))).mean(dim=1, keepdim=True) > 0
|
mask = self.preprocess_image(mask.resize((width, height))).mean(dim=1, keepdim=True) > 0
|
||||||
mask = mask.repeat(1, latent.shape[1], 1, 1)
|
mask = mask.repeat(1, latent.shape[1], 1, 1).to(dtype=latent.dtype, device=latent.device)
|
||||||
value[mask] += latent[mask] * scale
|
mask = blur(mask)
|
||||||
weight[mask] += scale
|
value += latent * mask * scale
|
||||||
|
weight += mask * scale
|
||||||
value /= weight
|
value /= weight
|
||||||
return value
|
return value
|
||||||
|
|
||||||
|
|||||||
@@ -22,7 +22,9 @@ class LightningModel(LightningModelForT2ILoRA):
|
|||||||
model_manager.load_models(pretrained_weights[1:])
|
model_manager.load_models(pretrained_weights[1:])
|
||||||
model_manager.load_model(pretrained_weights[0], torch_dtype=quantize)
|
model_manager.load_model(pretrained_weights[0], torch_dtype=quantize)
|
||||||
if preset_lora_path is not None:
|
if preset_lora_path is not None:
|
||||||
model_manager.load_lora(preset_lora_path)
|
preset_lora_path = preset_lora_path.split(",")
|
||||||
|
for path in preset_lora_path:
|
||||||
|
model_manager.load_lora(path)
|
||||||
|
|
||||||
self.pipe = FluxImagePipeline.from_model_manager(model_manager)
|
self.pipe = FluxImagePipeline.from_model_manager(model_manager)
|
||||||
|
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
torch>=2.0.0
|
torch>=2.0.0
|
||||||
|
torchvision
|
||||||
cupy-cuda12x
|
cupy-cuda12x
|
||||||
transformers==4.46.2
|
transformers==4.46.2
|
||||||
controlnet-aux==0.0.7
|
controlnet-aux==0.0.7
|
||||||
|
|||||||
Reference in New Issue
Block a user