Merge pull request #250 from modelscope/flux-controlnet

Flux controlnet
This commit is contained in:
Zhongjie Duan
2024-10-25 10:58:37 +08:00
committed by GitHub
12 changed files with 936 additions and 82 deletions

View File

@@ -35,6 +35,7 @@ from ..models.hunyuan_dit import HunyuanDiT
from ..models.flux_dit import FluxDiT
from ..models.flux_text_encoder import FluxTextEncoder1, FluxTextEncoder2
from ..models.flux_vae import FluxVAEEncoder, FluxVAEDecoder
from ..models.flux_controlnet import FluxControlNet
from ..models.cog_vae import CogVAEEncoder, CogVAEDecoder
from ..models.cog_dit import CogDiT
@@ -80,6 +81,10 @@ model_loader_configs = [
(None, "280189ee084bca10f70907bf6ce1649d", ["cog_vae_encoder", "cog_vae_decoder"], [CogVAEEncoder, CogVAEDecoder], "diffusers"),
(None, "9b9313d104ac4df27991352fec013fd4", ["rife"], [IFNet], "civitai"),
(None, "6b7116078c4170bfbeaedc8fe71f6649", ["esrgan"], [RRDBNet], "civitai"),
(None, "78d18b9101345ff695f312e7e62538c0", ["flux_controlnet"], [FluxControlNet], "diffusers"),
(None, "b001c89139b5f053c715fe772362dd2a", ["flux_controlnet"], [FluxControlNet], "diffusers"),
(None, "52357cb26250681367488a8954c271e8", ["flux_controlnet"], [FluxControlNet], "diffusers"),
(None, "0cfd1740758423a2a854d67c136d1e8c", ["flux_controlnet"], [FluxControlNet], "diffusers"),
]
huggingface_model_loader_configs = [
# These configs are provided for detecting model type automatically.
@@ -199,7 +204,6 @@ preset_models_on_huggingface = {
("lllyasviel/omost-llama-3-8b-4bits", "model.safetensors.index.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
("lllyasviel/omost-llama-3-8b-4bits", "special_tokens_map.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
],
# Translator
"opus-mt-zh-en": [
("Helsinki-NLP/opus-mt-zh-en", "config.json", "models/translator/opus-mt-zh-en"),
@@ -341,6 +345,24 @@ preset_models_on_modelscope = {
("AI-ModelScope/controlnet-union-sdxl-1.0", "diffusion_pytorch_model_promax.safetensors", "models/ControlNet/controlnet_union"),
("sd_lora/Annotators", "dpt_hybrid-midas-501f0c75.pt", "models/Annotators")
],
"Annotators:Depth": [
("sd_lora/Annotators", "dpt_hybrid-midas-501f0c75.pt", "models/Annotators"),
],
"Annotators:Softedge": [
("sd_lora/Annotators", "ControlNetHED.pth", "models/Annotators"),
],
"Annotators:Lineart": [
("sd_lora/Annotators", "sk_model.pth", "models/Annotators"),
("sd_lora/Annotators", "sk_model2.pth", "models/Annotators"),
],
"Annotators:Normal": [
("sd_lora/Annotators", "scannet.pt", "models/Annotators"),
],
"Annotators:Openpose": [
("sd_lora/Annotators", "body_pose_model.pth", "models/Annotators"),
("sd_lora/Annotators", "facenet.pth", "models/Annotators"),
("sd_lora/Annotators", "hand_pose_model.pth", "models/Annotators"),
],
# AnimateDiff
"AnimateDiff_v2": [
("Shanghai_AI_Laboratory/animatediff", "mm_sd_v15_v2.ckpt", "models/AnimateDiff"),
@@ -482,6 +504,30 @@ preset_models_on_modelscope = {
"models/FLUX/FLUX.1-schnell/flux1-schnell.safetensors"
],
},
"InstantX/FLUX.1-dev-Controlnet-Union-alpha": [
("InstantX/FLUX.1-dev-Controlnet-Union-alpha", "diffusion_pytorch_model.safetensors", "models/ControlNet/InstantX/FLUX.1-dev-Controlnet-Union-alpha"),
],
"jasperai/Flux.1-dev-Controlnet-Depth": [
("jasperai/Flux.1-dev-Controlnet-Depth", "diffusion_pytorch_model.safetensors", "models/ControlNet/jasperai/Flux.1-dev-Controlnet-Depth"),
],
"jasperai/Flux.1-dev-Controlnet-Surface-Normals": [
("jasperai/Flux.1-dev-Controlnet-Surface-Normals", "diffusion_pytorch_model.safetensors", "models/ControlNet/jasperai/Flux.1-dev-Controlnet-Surface-Normals"),
],
"jasperai/Flux.1-dev-Controlnet-Upscaler": [
("jasperai/Flux.1-dev-Controlnet-Upscaler", "diffusion_pytorch_model.safetensors", "models/ControlNet/jasperai/Flux.1-dev-Controlnet-Upscaler"),
],
"alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Alpha": [
("alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Alpha", "diffusion_pytorch_model.safetensors", "models/ControlNet/alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Alpha"),
],
"alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta": [
("alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta", "diffusion_pytorch_model.safetensors", "models/ControlNet/alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta"),
],
"Shakker-Labs/FLUX.1-dev-ControlNet-Depth": [
("Shakker-Labs/FLUX.1-dev-ControlNet-Depth", "diffusion_pytorch_model.safetensors", "models/ControlNet/Shakker-Labs/FLUX.1-dev-ControlNet-Depth"),
],
"Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro": [
("Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro", "diffusion_pytorch_model.safetensors", "models/ControlNet/Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro"),
],
# ESRGAN
"ESRGAN_x4": [
("AI-ModelScope/Real-ESRGAN", "RealESRGAN_x4.pth", "models/ESRGAN"),
@@ -541,10 +587,23 @@ Preset_model_id: TypeAlias = Literal[
"ControlNet_union_sdxl_promax",
"FLUX.1-dev",
"FLUX.1-schnell",
"InstantX/FLUX.1-dev-Controlnet-Union-alpha",
"jasperai/Flux.1-dev-Controlnet-Depth",
"jasperai/Flux.1-dev-Controlnet-Surface-Normals",
"jasperai/Flux.1-dev-Controlnet-Upscaler",
"alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Alpha",
"alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta",
"Shakker-Labs/FLUX.1-dev-ControlNet-Depth",
"Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro",
"SDXL_lora_zyd232_ChineseInkStyle_SDXL_v1_0",
"QwenPrompt",
"OmostPrompt",
"ESRGAN_x4",
"RIFE",
"CogVideoX-5B",
"Annotators:Depth",
"Annotators:Softedge",
"Annotators:Lineart",
"Annotators:Normal",
"Annotators:Openpose",
]

View File

@@ -1,2 +1,2 @@
from .controlnet_unit import ControlNetConfigUnit, ControlNetUnit, MultiControlNetManager
from .controlnet_unit import ControlNetConfigUnit, ControlNetUnit, MultiControlNetManager, FluxMultiControlNetManager
from .processors import Annotator

View File

@@ -4,10 +4,11 @@ from .processors import Processor_id
class ControlNetConfigUnit:
def __init__(self, processor_id: Processor_id, model_path, scale=1.0):
def __init__(self, processor_id: Processor_id, model_path, scale=1.0, skip_processor=False):
self.processor_id = processor_id
self.model_path = model_path
self.scale = scale
self.skip_processor = skip_processor
class ControlNetUnit:
@@ -60,3 +61,29 @@ class MultiControlNetManager:
else:
res_stack = [i + j for i, j in zip(res_stack, res_stack_)]
return res_stack
class FluxMultiControlNetManager(MultiControlNetManager):
def __init__(self, controlnet_units=[]):
super().__init__(controlnet_units=controlnet_units)
def process_image(self, image, processor_id=None):
if processor_id is None:
processed_image = [processor(image) for processor in self.processors]
else:
processed_image = [self.processors[processor_id](image)]
return processed_image
def __call__(self, conditionings, **kwargs):
res_stack, single_res_stack = None, None
for processor, conditioning, model, scale in zip(self.processors, conditionings, self.models, self.scales):
res_stack_, single_res_stack_ = model(controlnet_conditioning=conditioning, processor_id=processor.processor_id, **kwargs)
res_stack_ = [res * scale for res in res_stack_]
single_res_stack_ = [res * scale for res in single_res_stack_]
if res_stack is None:
res_stack = res_stack_
single_res_stack = single_res_stack_
else:
res_stack = [i + j for i, j in zip(res_stack, res_stack_)]
single_res_stack = [i + j for i, j in zip(single_res_stack, single_res_stack_)]
return res_stack, single_res_stack

View File

@@ -3,37 +3,42 @@ import warnings
with warnings.catch_warnings():
warnings.simplefilter("ignore")
from controlnet_aux.processor import (
CannyDetector, MidasDetector, HEDdetector, LineartDetector, LineartAnimeDetector, OpenposeDetector
CannyDetector, MidasDetector, HEDdetector, LineartDetector, LineartAnimeDetector, OpenposeDetector, NormalBaeDetector
)
Processor_id: TypeAlias = Literal[
"canny", "depth", "softedge", "lineart", "lineart_anime", "openpose", "tile"
"canny", "depth", "softedge", "lineart", "lineart_anime", "openpose", "normal", "tile", "none", "inpaint"
]
class Annotator:
def __init__(self, processor_id: Processor_id, model_path="models/Annotators", detect_resolution=None, device='cuda'):
if processor_id == "canny":
self.processor = CannyDetector()
elif processor_id == "depth":
self.processor = MidasDetector.from_pretrained(model_path).to(device)
elif processor_id == "softedge":
self.processor = HEDdetector.from_pretrained(model_path).to(device)
elif processor_id == "lineart":
self.processor = LineartDetector.from_pretrained(model_path).to(device)
elif processor_id == "lineart_anime":
self.processor = LineartAnimeDetector.from_pretrained(model_path).to(device)
elif processor_id == "openpose":
self.processor = OpenposeDetector.from_pretrained(model_path).to(device)
elif processor_id == "tile":
self.processor = None
def __init__(self, processor_id: Processor_id, model_path="models/Annotators", detect_resolution=None, device='cuda', skip_processor=False):
if not skip_processor:
if processor_id == "canny":
self.processor = CannyDetector()
elif processor_id == "depth":
self.processor = MidasDetector.from_pretrained(model_path).to(device)
elif processor_id == "softedge":
self.processor = HEDdetector.from_pretrained(model_path).to(device)
elif processor_id == "lineart":
self.processor = LineartDetector.from_pretrained(model_path).to(device)
elif processor_id == "lineart_anime":
self.processor = LineartAnimeDetector.from_pretrained(model_path).to(device)
elif processor_id == "openpose":
self.processor = OpenposeDetector.from_pretrained(model_path).to(device)
elif processor_id == "normal":
self.processor = NormalBaeDetector.from_pretrained(model_path).to(device)
elif processor_id == "tile" or processor_id == "none" or processor_id == "inpaint":
self.processor = None
else:
raise ValueError(f"Unsupported processor_id: {processor_id}")
else:
raise ValueError(f"Unsupported processor_id: {processor_id}")
self.processor = None
self.processor_id = processor_id
self.detect_resolution = detect_resolution
def __call__(self, image):
def __call__(self, image, mask=None):
width, height = image.size
if self.processor_id == "openpose":
kwargs = {

View File

@@ -0,0 +1,226 @@
import torch
from einops import rearrange, repeat
from .flux_dit import RoPEEmbedding, TimestepEmbeddings, FluxJointTransformerBlock, FluxSingleTransformerBlock
from .utils import hash_state_dict_keys
class FluxControlNet(torch.nn.Module):
def __init__(self, disable_guidance_embedder=False, num_joint_blocks=5, num_single_blocks=10, num_mode=0, mode_dict={}, additional_input_dim=0):
super().__init__()
self.pos_embedder = RoPEEmbedding(3072, 10000, [16, 56, 56])
self.time_embedder = TimestepEmbeddings(256, 3072)
self.guidance_embedder = None if disable_guidance_embedder else TimestepEmbeddings(256, 3072)
self.pooled_text_embedder = torch.nn.Sequential(torch.nn.Linear(768, 3072), torch.nn.SiLU(), torch.nn.Linear(3072, 3072))
self.context_embedder = torch.nn.Linear(4096, 3072)
self.x_embedder = torch.nn.Linear(64, 3072)
self.blocks = torch.nn.ModuleList([FluxJointTransformerBlock(3072, 24) for _ in range(num_joint_blocks)])
self.single_blocks = torch.nn.ModuleList([FluxSingleTransformerBlock(3072, 24) for _ in range(num_single_blocks)])
self.controlnet_blocks = torch.nn.ModuleList([torch.nn.Linear(3072, 3072) for _ in range(num_joint_blocks)])
self.controlnet_single_blocks = torch.nn.ModuleList([torch.nn.Linear(3072, 3072) for _ in range(num_single_blocks)])
self.mode_dict = mode_dict
self.controlnet_mode_embedder = torch.nn.Embedding(num_mode, 3072) if len(mode_dict) > 0 else None
self.controlnet_x_embedder = torch.nn.Linear(64 + additional_input_dim, 3072)
def prepare_image_ids(self, latents):
batch_size, _, height, width = latents.shape
latent_image_ids = torch.zeros(height // 2, width // 2, 3)
latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None]
latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :]
latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
latent_image_ids = latent_image_ids[None, :].repeat(batch_size, 1, 1, 1)
latent_image_ids = latent_image_ids.reshape(
batch_size, latent_image_id_height * latent_image_id_width, latent_image_id_channels
)
latent_image_ids = latent_image_ids.to(device=latents.device, dtype=latents.dtype)
return latent_image_ids
def patchify(self, hidden_states):
hidden_states = rearrange(hidden_states, "B C (H P) (W Q) -> B (H W) (C P Q)", P=2, Q=2)
return hidden_states
def align_res_stack_to_original_blocks(self, res_stack, num_blocks, hidden_states):
if len(res_stack) == 0:
return [torch.zeros_like(hidden_states)] * num_blocks
interval = (num_blocks + len(res_stack) - 1) // len(res_stack)
aligned_res_stack = [res_stack[block_id // interval] for block_id in range(num_blocks)]
return aligned_res_stack
def forward(
self,
hidden_states,
controlnet_conditioning,
timestep, prompt_emb, pooled_prompt_emb, guidance, text_ids, image_ids=None,
processor_id=None,
tiled=False, tile_size=128, tile_stride=64,
**kwargs
):
if image_ids is None:
image_ids = self.prepare_image_ids(hidden_states)
conditioning = self.time_embedder(timestep, hidden_states.dtype) + self.pooled_text_embedder(pooled_prompt_emb)
if self.guidance_embedder is not None:
guidance = guidance * 1000
conditioning = conditioning + self.guidance_embedder(guidance, hidden_states.dtype)
prompt_emb = self.context_embedder(prompt_emb)
if self.controlnet_mode_embedder is not None: # Different from FluxDiT
processor_id = torch.tensor([self.mode_dict[processor_id]], dtype=torch.int)
processor_id = repeat(processor_id, "D -> B D", B=1).to(text_ids.device)
prompt_emb = torch.concat([self.controlnet_mode_embedder(processor_id), prompt_emb], dim=1)
text_ids = torch.cat([text_ids[:, :1], text_ids], dim=1)
image_rotary_emb = self.pos_embedder(torch.cat((text_ids, image_ids), dim=1))
hidden_states = self.patchify(hidden_states)
hidden_states = self.x_embedder(hidden_states)
controlnet_conditioning = self.patchify(controlnet_conditioning) # Different from FluxDiT
hidden_states = hidden_states + self.controlnet_x_embedder(controlnet_conditioning) # Different from FluxDiT
controlnet_res_stack = []
for block, controlnet_block in zip(self.blocks, self.controlnet_blocks):
hidden_states, prompt_emb = block(hidden_states, prompt_emb, conditioning, image_rotary_emb)
controlnet_res_stack.append(controlnet_block(hidden_states))
controlnet_single_res_stack = []
hidden_states = torch.cat([prompt_emb, hidden_states], dim=1)
for block, controlnet_block in zip(self.single_blocks, self.controlnet_single_blocks):
hidden_states, prompt_emb = block(hidden_states, prompt_emb, conditioning, image_rotary_emb)
controlnet_single_res_stack.append(controlnet_block(hidden_states[:, prompt_emb.shape[1]:]))
controlnet_res_stack = self.align_res_stack_to_original_blocks(controlnet_res_stack, 19, hidden_states[:, prompt_emb.shape[1]:])
controlnet_single_res_stack = self.align_res_stack_to_original_blocks(controlnet_single_res_stack, 38, hidden_states[:, prompt_emb.shape[1]:])
return controlnet_res_stack, controlnet_single_res_stack
@staticmethod
def state_dict_converter():
return FluxControlNetStateDictConverter()
class FluxControlNetStateDictConverter:
def __init__(self):
pass
def from_diffusers(self, state_dict):
hash_value = hash_state_dict_keys(state_dict)
global_rename_dict = {
"context_embedder": "context_embedder",
"x_embedder": "x_embedder",
"time_text_embed.timestep_embedder.linear_1": "time_embedder.timestep_embedder.0",
"time_text_embed.timestep_embedder.linear_2": "time_embedder.timestep_embedder.2",
"time_text_embed.guidance_embedder.linear_1": "guidance_embedder.timestep_embedder.0",
"time_text_embed.guidance_embedder.linear_2": "guidance_embedder.timestep_embedder.2",
"time_text_embed.text_embedder.linear_1": "pooled_text_embedder.0",
"time_text_embed.text_embedder.linear_2": "pooled_text_embedder.2",
"norm_out.linear": "final_norm_out.linear",
"proj_out": "final_proj_out",
}
rename_dict = {
"proj_out": "proj_out",
"norm1.linear": "norm1_a.linear",
"norm1_context.linear": "norm1_b.linear",
"attn.to_q": "attn.a_to_q",
"attn.to_k": "attn.a_to_k",
"attn.to_v": "attn.a_to_v",
"attn.to_out.0": "attn.a_to_out",
"attn.add_q_proj": "attn.b_to_q",
"attn.add_k_proj": "attn.b_to_k",
"attn.add_v_proj": "attn.b_to_v",
"attn.to_add_out": "attn.b_to_out",
"ff.net.0.proj": "ff_a.0",
"ff.net.2": "ff_a.2",
"ff_context.net.0.proj": "ff_b.0",
"ff_context.net.2": "ff_b.2",
"attn.norm_q": "attn.norm_q_a",
"attn.norm_k": "attn.norm_k_a",
"attn.norm_added_q": "attn.norm_q_b",
"attn.norm_added_k": "attn.norm_k_b",
}
rename_dict_single = {
"attn.to_q": "a_to_q",
"attn.to_k": "a_to_k",
"attn.to_v": "a_to_v",
"attn.norm_q": "norm_q_a",
"attn.norm_k": "norm_k_a",
"norm.linear": "norm.linear",
"proj_mlp": "proj_in_besides_attn",
"proj_out": "proj_out",
}
state_dict_ = {}
for name, param in state_dict.items():
if name.endswith(".weight") or name.endswith(".bias"):
suffix = ".weight" if name.endswith(".weight") else ".bias"
prefix = name[:-len(suffix)]
if prefix in global_rename_dict:
state_dict_[global_rename_dict[prefix] + suffix] = param
elif prefix.startswith("transformer_blocks."):
names = prefix.split(".")
names[0] = "blocks"
middle = ".".join(names[2:])
if middle in rename_dict:
name_ = ".".join(names[:2] + [rename_dict[middle]] + [suffix[1:]])
state_dict_[name_] = param
elif prefix.startswith("single_transformer_blocks."):
names = prefix.split(".")
names[0] = "single_blocks"
middle = ".".join(names[2:])
if middle in rename_dict_single:
name_ = ".".join(names[:2] + [rename_dict_single[middle]] + [suffix[1:]])
state_dict_[name_] = param
else:
state_dict_[name] = param
else:
state_dict_[name] = param
for name in list(state_dict_.keys()):
if ".proj_in_besides_attn." in name:
name_ = name.replace(".proj_in_besides_attn.", ".to_qkv_mlp.")
param = torch.concat([
state_dict_[name.replace(".proj_in_besides_attn.", f".a_to_q.")],
state_dict_[name.replace(".proj_in_besides_attn.", f".a_to_k.")],
state_dict_[name.replace(".proj_in_besides_attn.", f".a_to_v.")],
state_dict_[name],
], dim=0)
state_dict_[name_] = param
state_dict_.pop(name.replace(".proj_in_besides_attn.", f".a_to_q."))
state_dict_.pop(name.replace(".proj_in_besides_attn.", f".a_to_k."))
state_dict_.pop(name.replace(".proj_in_besides_attn.", f".a_to_v."))
state_dict_.pop(name)
for name in list(state_dict_.keys()):
for component in ["a", "b"]:
if f".{component}_to_q." in name:
name_ = name.replace(f".{component}_to_q.", f".{component}_to_qkv.")
param = torch.concat([
state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_q.")],
state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_k.")],
state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_v.")],
], dim=0)
state_dict_[name_] = param
state_dict_.pop(name.replace(f".{component}_to_q.", f".{component}_to_q."))
state_dict_.pop(name.replace(f".{component}_to_q.", f".{component}_to_k."))
state_dict_.pop(name.replace(f".{component}_to_q.", f".{component}_to_v."))
if hash_value == "78d18b9101345ff695f312e7e62538c0":
extra_kwargs = {"num_mode": 10, "mode_dict": {"canny": 0, "tile": 1, "depth": 2, "blur": 3, "pose": 4, "gray": 5, "lq": 6}}
elif hash_value == "b001c89139b5f053c715fe772362dd2a":
extra_kwargs = {"num_single_blocks": 0}
elif hash_value == "52357cb26250681367488a8954c271e8":
extra_kwargs = {"num_joint_blocks": 6, "num_single_blocks": 0, "additional_input_dim": 4}
elif hash_value == "0cfd1740758423a2a854d67c136d1e8c":
extra_kwargs = {"num_joint_blocks": 4, "num_single_blocks": 1}
else:
extra_kwargs = {}
return state_dict_, extra_kwargs
def from_civitai(self, state_dict):
return self.from_diffusers(state_dict)

View File

@@ -1,7 +1,4 @@
import os, torch, hashlib, json, importlib
from safetensors import safe_open
from torch import Tensor
from typing_extensions import Literal, TypeAlias
import os, torch, json, importlib
from typing import List
from .downloader import download_models, download_customized_models, Preset_model_id, Preset_model_website
@@ -50,45 +47,7 @@ from ..extensions.RIFE import IFNet
from ..extensions.ESRGAN import RRDBNet
from ..configs.model_config import model_loader_configs, huggingface_model_loader_configs, patch_model_loader_configs
from .utils import load_state_dict, init_weights_on_device
def convert_state_dict_keys_to_single_str(state_dict, with_shape=True):
keys = []
for key, value in state_dict.items():
if isinstance(key, str):
if isinstance(value, Tensor):
if with_shape:
shape = "_".join(map(str, list(value.shape)))
keys.append(key + ":" + shape)
keys.append(key)
elif isinstance(value, dict):
keys.append(key + "|" + convert_state_dict_keys_to_single_str(value, with_shape=with_shape))
keys.sort()
keys_str = ",".join(keys)
return keys_str
def split_state_dict_with_prefix(state_dict):
keys = sorted([key for key in state_dict if isinstance(key, str)])
prefix_dict = {}
for key in keys:
prefix = key if "." not in key else key.split(".")[0]
if prefix not in prefix_dict:
prefix_dict[prefix] = []
prefix_dict[prefix].append(key)
state_dicts = []
for prefix, keys in prefix_dict.items():
sub_state_dict = {key: state_dict[key] for key in keys}
state_dicts.append(sub_state_dict)
return state_dicts
def hash_state_dict_keys(state_dict, with_shape=True):
keys_str = convert_state_dict_keys_to_single_str(state_dict, with_shape=with_shape)
keys_str = keys_str.encode(encoding="UTF-8")
return hashlib.md5(keys_str).hexdigest()
from .utils import load_state_dict, init_weights_on_device, hash_state_dict_keys, split_state_dict_with_prefix
def load_model_from_single_file(state_dict, model_names, model_classes, model_resource, torch_dtype, device):

View File

@@ -107,6 +107,60 @@ class TileWorker:
class FastTileWorker:
def __init__(self):
pass
def build_mask(self, data, is_bound):
_, _, H, W = data.shape
h = repeat(torch.arange(H), "H -> H W", H=H, W=W)
w = repeat(torch.arange(W), "W -> H W", H=H, W=W)
border_width = (H + W) // 4
pad = torch.ones_like(h) * border_width
mask = torch.stack([
pad if is_bound[0] else h + 1,
pad if is_bound[1] else H - h,
pad if is_bound[2] else w + 1,
pad if is_bound[3] else W - w
]).min(dim=0).values
mask = mask.clip(1, border_width)
mask = (mask / border_width).to(dtype=data.dtype, device=data.device)
mask = rearrange(mask, "H W -> 1 H W")
return mask
def tiled_forward(self, forward_fn, model_input, tile_size, tile_stride, tile_device="cpu", tile_dtype=torch.float32, border_width=None):
# Prepare
B, C, H, W = model_input.shape
border_width = int(tile_stride*0.5) if border_width is None else border_width
weight = torch.zeros((1, 1, H, W), dtype=tile_dtype, device=tile_device)
values = torch.zeros((B, C, H, W), dtype=tile_dtype, device=tile_device)
# Split tasks
tasks = []
for h in range(0, H, tile_stride):
for w in range(0, W, tile_stride):
if (h-tile_stride >= 0 and h-tile_stride+tile_size >= H) or (w-tile_stride >= 0 and w-tile_stride+tile_size >= W):
continue
h_, w_ = h + tile_size, w + tile_size
if h_ > H: h, h_ = H - tile_size, H
if w_ > W: w, w_ = W - tile_size, W
tasks.append((h, h_, w, w_))
# Run
for hl, hr, wl, wr in tasks:
# Forward
hidden_states_batch = forward_fn(hl, hr, wl, wr).to(dtype=tile_dtype, device=tile_device)
mask = self.build_mask(hidden_states_batch, is_bound=(hl==0, hr>=H, wl==0, wr>=W))
values[:, :, hl:hr, wl:wr] += hidden_states_batch * mask
weight[:, :, hl:hr, wl:wr] += mask
values /= weight
return values
class TileWorker2Dto3D:
"""
Process 3D tensors, but only enable TileWorker on 2D.

View File

@@ -1,6 +1,7 @@
import torch, os
from safetensors import safe_open
from contextlib import contextmanager
import hashlib
@contextmanager
def init_weights_on_device(device = torch.device("meta"), include_buffers :bool = False):
@@ -142,3 +143,40 @@ def search_for_files(folder, extensions):
files.append(folder)
break
return files
def convert_state_dict_keys_to_single_str(state_dict, with_shape=True):
keys = []
for key, value in state_dict.items():
if isinstance(key, str):
if isinstance(value, torch.Tensor):
if with_shape:
shape = "_".join(map(str, list(value.shape)))
keys.append(key + ":" + shape)
keys.append(key)
elif isinstance(value, dict):
keys.append(key + "|" + convert_state_dict_keys_to_single_str(value, with_shape=with_shape))
keys.sort()
keys_str = ",".join(keys)
return keys_str
def split_state_dict_with_prefix(state_dict):
keys = sorted([key for key in state_dict if isinstance(key, str)])
prefix_dict = {}
for key in keys:
prefix = key if "." not in key else key.split(".")[0]
if prefix not in prefix_dict:
prefix_dict[prefix] = []
prefix_dict[prefix].append(key)
state_dicts = []
for prefix, keys in prefix_dict.items():
sub_state_dict = {key: state_dict[key] for key in keys}
state_dicts.append(sub_state_dict)
return state_dicts
def hash_state_dict_keys(state_dict, with_shape=True):
keys_str = convert_state_dict_keys_to_single_str(state_dict, with_shape=with_shape)
keys_str = keys_str.encode(encoding="UTF-8")
return hashlib.md5(keys_str).hexdigest()

View File

@@ -47,9 +47,15 @@ class BasePipeline(torch.nn.Module):
return value
def control_noise_via_local_prompts(self, prompt_emb_global, prompt_emb_locals, masks, mask_scales, inference_callback):
noise_pred_global = inference_callback(prompt_emb_global)
noise_pred_locals = [inference_callback(prompt_emb_local) for prompt_emb_local in prompt_emb_locals]
def control_noise_via_local_prompts(self, prompt_emb_global, prompt_emb_locals, masks, mask_scales, inference_callback, special_kwargs=None, special_local_kwargs_list=None):
if special_kwargs is None:
noise_pred_global = inference_callback(prompt_emb_global)
else:
noise_pred_global = inference_callback(prompt_emb_global, special_kwargs)
if special_local_kwargs_list is None:
noise_pred_locals = [inference_callback(prompt_emb_local) for prompt_emb_local in prompt_emb_locals]
else:
noise_pred_locals = [inference_callback(prompt_emb_local, special_kwargs) for prompt_emb_local, special_kwargs in zip(prompt_emb_locals, special_local_kwargs_list)]
noise_pred = self.merge_latents(noise_pred_global, noise_pred_locals, masks, mask_scales)
return noise_pred

View File

@@ -139,6 +139,8 @@ def lets_dance_xl(
# 0. Text embedding alignment (only for video processing)
if encoder_hidden_states.shape[0] != sample.shape[0]:
encoder_hidden_states = encoder_hidden_states.repeat(sample.shape[0], 1, 1, 1)
if add_text_embeds.shape[0] != sample.shape[0]:
add_text_embeds = add_text_embeds.repeat(sample.shape[0], 1)
# 1. ControlNet
controlnet_insert_block_id = 22
@@ -204,7 +206,7 @@ def lets_dance_xl(
batch_id_ = min(batch_id + unet_batch_size, sample.shape[0])
hidden_states, _, _, _ = block(
hidden_states_input[batch_id: batch_id_],
time_emb,
time_emb[batch_id: batch_id_],
text_emb[batch_id: batch_id_],
res_stack,
cross_frame_attention=cross_frame_attention,

View File

@@ -1,9 +1,14 @@
from ..models import ModelManager, FluxDiT, FluxTextEncoder1, FluxTextEncoder2, FluxVAEDecoder, FluxVAEEncoder
from ..controlnets import FluxMultiControlNetManager, ControlNetUnit, ControlNetConfigUnit, Annotator
from ..prompters import FluxPrompter
from ..schedulers import FlowMatchScheduler
from .base import BasePipeline
from typing import List
import torch
from tqdm import tqdm
import numpy as np
from PIL import Image
from ..models.tiler import FastTileWorker
@@ -19,14 +24,15 @@ class FluxImagePipeline(BasePipeline):
self.dit: FluxDiT = None
self.vae_decoder: FluxVAEDecoder = None
self.vae_encoder: FluxVAEEncoder = None
self.model_names = ['text_encoder_1', 'text_encoder_2', 'dit', 'vae_decoder', 'vae_encoder']
self.controlnet: FluxMultiControlNetManager = None
self.model_names = ['text_encoder_1', 'text_encoder_2', 'dit', 'vae_decoder', 'vae_encoder', 'controlnet']
def denoising_model(self):
return self.dit
def fetch_models(self, model_manager: ModelManager, prompt_refiner_classes=[], prompt_extender_classes=[]):
def fetch_models(self, model_manager: ModelManager, controlnet_config_units: List[ControlNetConfigUnit]=[], prompt_refiner_classes=[], prompt_extender_classes=[]):
self.text_encoder_1 = model_manager.fetch_model("flux_text_encoder_1")
self.text_encoder_2 = model_manager.fetch_model("flux_text_encoder_2")
self.dit = model_manager.fetch_model("flux_dit")
@@ -36,14 +42,25 @@ class FluxImagePipeline(BasePipeline):
self.prompter.load_prompt_refiners(model_manager, prompt_refiner_classes)
self.prompter.load_prompt_extenders(model_manager, prompt_extender_classes)
# ControlNets
controlnet_units = []
for config in controlnet_config_units:
controlnet_unit = ControlNetUnit(
Annotator(config.processor_id, device=self.device, skip_processor=config.skip_processor),
model_manager.fetch_model("flux_controlnet", config.model_path),
config.scale
)
controlnet_units.append(controlnet_unit)
self.controlnet = FluxMultiControlNetManager(controlnet_units)
@staticmethod
def from_model_manager(model_manager: ModelManager, prompt_refiner_classes=[], prompt_extender_classes=[], device=None):
def from_model_manager(model_manager: ModelManager, controlnet_config_units: List[ControlNetConfigUnit]=[], prompt_refiner_classes=[], prompt_extender_classes=[], device=None):
pipe = FluxImagePipeline(
device=model_manager.device if device is None else device,
torch_dtype=model_manager.torch_dtype,
)
pipe.fetch_models(model_manager, prompt_refiner_classes,prompt_extender_classes)
pipe.fetch_models(model_manager, controlnet_config_units, prompt_refiner_classes, prompt_extender_classes)
return pipe
@@ -71,17 +88,62 @@ class FluxImagePipeline(BasePipeline):
return {"image_ids": latent_image_ids, "guidance": guidance}
def apply_controlnet_mask_on_latents(self, latents, mask):
mask = (self.preprocess_image(mask) + 1) / 2
mask = mask.mean(dim=1, keepdim=True)
mask = mask.to(dtype=self.torch_dtype, device=self.device)
mask = 1 - torch.nn.functional.interpolate(mask, size=latents.shape[-2:])
latents = torch.concat([latents, mask], dim=1)
return latents
def apply_controlnet_mask_on_image(self, image, mask):
mask = mask.resize(image.size)
mask = self.preprocess_image(mask).mean(dim=[0, 1])
image = np.array(image)
image[mask > 0] = 0
image = Image.fromarray(image)
return image
def prepare_controlnet_input(self, controlnet_image, controlnet_inpaint_mask, tiler_kwargs):
if isinstance(controlnet_image, Image.Image):
controlnet_image = [controlnet_image] * len(self.controlnet.processors)
controlnet_frames = []
for i in range(len(self.controlnet.processors)):
# image annotator
image = self.controlnet.process_image(controlnet_image[i], processor_id=i)[0]
if controlnet_inpaint_mask is not None and self.controlnet.processors[i].processor_id == "inpaint":
image = self.apply_controlnet_mask_on_image(image, controlnet_inpaint_mask)
# image to tensor
image = self.preprocess_image(image).to(device=self.device, dtype=self.torch_dtype)
# vae encoder
image = self.encode_image(image, **tiler_kwargs)
if controlnet_inpaint_mask is not None and self.controlnet.processors[i].processor_id == "inpaint":
image = self.apply_controlnet_mask_on_latents(image, controlnet_inpaint_mask)
# store it
controlnet_frames.append(image)
return controlnet_frames
@torch.no_grad()
def __call__(
self,
prompt,
local_prompts= None,
masks= None,
mask_scales= None,
local_prompts=None,
masks=None,
mask_scales=None,
negative_prompt="",
cfg_scale=1.0,
embedded_guidance=3.5,
input_image=None,
controlnet_image=None,
controlnet_inpaint_mask=None,
enable_controlnet_on_negative=False,
denoising_strength=1.0,
height=1024,
width=1024,
@@ -123,19 +185,38 @@ class FluxImagePipeline(BasePipeline):
# Extra input
extra_input = self.prepare_extra_input(latents, guidance=embedded_guidance)
# Prepare ControlNets
if controlnet_image is not None:
controlnet_kwargs = {"controlnet_frames": self.prepare_controlnet_input(controlnet_image, controlnet_inpaint_mask, tiler_kwargs)}
if len(masks) > 0 and controlnet_inpaint_mask is not None:
print("The controlnet_inpaint_mask will be overridden by masks.")
local_controlnet_kwargs = [{"controlnet_frames": self.prepare_controlnet_input(controlnet_image, mask, tiler_kwargs)} for mask in masks]
else:
local_controlnet_kwargs = None
else:
controlnet_kwargs, local_controlnet_kwargs = {"controlnet_frames": None}, [{}] * len(masks)
# Denoise
self.load_models_to_device(['dit'])
self.load_models_to_device(['dit', 'controlnet'])
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
timestep = timestep.unsqueeze(0).to(self.device)
# Classifier-free guidance
inference_callback = lambda prompt_emb_posi: self.dit(
latents, timestep=timestep, **prompt_emb_posi, **tiler_kwargs, **extra_input
inference_callback = lambda prompt_emb_posi, controlnet_kwargs: lets_dance_flux(
dit=self.dit, controlnet=self.controlnet,
hidden_states=latents, timestep=timestep,
**prompt_emb_posi, **tiler_kwargs, **extra_input, **controlnet_kwargs
)
noise_pred_posi = self.control_noise_via_local_prompts(
prompt_emb_posi, prompt_emb_locals, masks, mask_scales, inference_callback,
special_kwargs=controlnet_kwargs, special_local_kwargs_list=local_controlnet_kwargs
)
noise_pred_posi = self.control_noise_via_local_prompts(prompt_emb_posi, prompt_emb_locals, masks, mask_scales, inference_callback)
if cfg_scale != 1.0:
noise_pred_nega = self.dit(
latents, timestep=timestep, **prompt_emb_nega, **tiler_kwargs, **extra_input
negative_controlnet_kwargs = controlnet_kwargs if enable_controlnet_on_negative else {}
noise_pred_nega = lets_dance_flux(
dit=self.dit, controlnet=self.controlnet,
hidden_states=latents, timestep=timestep,
**prompt_emb_nega, **tiler_kwargs, **extra_input, **negative_controlnet_kwargs,
)
noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
else:
@@ -155,3 +236,101 @@ class FluxImagePipeline(BasePipeline):
# Offload all models
self.load_models_to_device([])
return image
def lets_dance_flux(
dit: FluxDiT,
controlnet: FluxMultiControlNetManager = None,
hidden_states=None,
timestep=None,
prompt_emb=None,
pooled_prompt_emb=None,
guidance=None,
text_ids=None,
image_ids=None,
controlnet_frames=None,
tiled=False,
tile_size=128,
tile_stride=64,
**kwargs
):
if tiled:
def flux_forward_fn(hl, hr, wl, wr):
return lets_dance_flux(
dit=dit,
controlnet=controlnet,
hidden_states=hidden_states[:, :, hl: hr, wl: wr],
timestep=timestep,
prompt_emb=prompt_emb,
pooled_prompt_emb=pooled_prompt_emb,
guidance=guidance,
text_ids=text_ids,
image_ids=None,
controlnet_frames=[f[:, :, hl: hr, wl: wr] for f in controlnet_frames],
tiled=False,
**kwargs
)
return FastTileWorker().tiled_forward(
flux_forward_fn,
hidden_states,
tile_size=tile_size,
tile_stride=tile_stride,
tile_device=hidden_states.device,
tile_dtype=hidden_states.dtype
)
# ControlNet
if controlnet is not None and controlnet_frames is not None:
controlnet_extra_kwargs = {
"hidden_states": hidden_states,
"timestep": timestep,
"prompt_emb": prompt_emb,
"pooled_prompt_emb": pooled_prompt_emb,
"guidance": guidance,
"text_ids": text_ids,
"image_ids": image_ids,
"tiled": tiled,
"tile_size": tile_size,
"tile_stride": tile_stride,
}
controlnet_res_stack, controlnet_single_res_stack = controlnet(
controlnet_frames, **controlnet_extra_kwargs
)
if image_ids is None:
image_ids = dit.prepare_image_ids(hidden_states)
conditioning = dit.time_embedder(timestep, hidden_states.dtype) + dit.pooled_text_embedder(pooled_prompt_emb)
if dit.guidance_embedder is not None:
guidance = guidance * 1000
conditioning = conditioning + dit.guidance_embedder(guidance, hidden_states.dtype)
prompt_emb = dit.context_embedder(prompt_emb)
image_rotary_emb = dit.pos_embedder(torch.cat((text_ids, image_ids), dim=1))
height, width = hidden_states.shape[-2:]
hidden_states = dit.patchify(hidden_states)
hidden_states = dit.x_embedder(hidden_states)
# Joint Blocks
for block_id, block in enumerate(dit.blocks):
hidden_states, prompt_emb = block(hidden_states, prompt_emb, conditioning, image_rotary_emb)
# ControlNet
if controlnet is not None and controlnet_frames is not None:
hidden_states = hidden_states + controlnet_res_stack[block_id]
# Single Blocks
hidden_states = torch.cat([prompt_emb, hidden_states], dim=1)
for block_id, block in enumerate(dit.single_blocks):
hidden_states, prompt_emb = block(hidden_states, prompt_emb, conditioning, image_rotary_emb)
# ControlNet
if controlnet is not None and controlnet_frames is not None:
hidden_states[:, prompt_emb.shape[1]:] = hidden_states[:, prompt_emb.shape[1]:] + controlnet_single_res_stack[block_id]
hidden_states = hidden_states[:, prompt_emb.shape[1]:]
hidden_states = dit.final_norm_out(hidden_states, conditioning)
hidden_states = dit.final_proj_out(hidden_states)
hidden_states = dit.unpatchify(hidden_states, height, width)
return hidden_states

View File

@@ -0,0 +1,299 @@
from diffsynth import ModelManager, FluxImagePipeline, ControlNetConfigUnit, download_models, download_customized_models
import torch
from PIL import Image
import numpy as np
def example_1():
model_manager = ModelManager(torch_dtype=torch.bfloat16, model_id_list=["FLUX.1-dev", "jasperai/Flux.1-dev-Controlnet-Upscaler"])
pipe = FluxImagePipeline.from_model_manager(model_manager, controlnet_config_units=[
ControlNetConfigUnit(
processor_id="tile",
model_path="models/ControlNet/jasperai/Flux.1-dev-Controlnet-Upscaler/diffusion_pytorch_model.safetensors",
scale=0.7
),
])
image_1 = pipe(
prompt="a photo of a cat, highly detailed",
height=768, width=768,
seed=0
)
image_1.save("image_1.png")
image_2 = pipe(
prompt="a photo of a cat, highly detailed",
controlnet_image=image_1.resize((2048, 2048)),
input_image=image_1.resize((2048, 2048)), denoising_strength=0.99,
height=2048, width=2048, tiled=True,
seed=1
)
image_2.save("image_2.png")
def example_2():
model_manager = ModelManager(torch_dtype=torch.bfloat16, model_id_list=["FLUX.1-dev", "jasperai/Flux.1-dev-Controlnet-Upscaler"])
pipe = FluxImagePipeline.from_model_manager(model_manager, controlnet_config_units=[
ControlNetConfigUnit(
processor_id="tile",
model_path="models/ControlNet/jasperai/Flux.1-dev-Controlnet-Upscaler/diffusion_pytorch_model.safetensors",
scale=0.7
),
])
image_1 = pipe(
prompt="a beautiful Chinese girl, delicate skin texture",
height=768, width=768,
seed=2
)
image_1.save("image_3.png")
image_2 = pipe(
prompt="a beautiful Chinese girl, delicate skin texture",
controlnet_image=image_1.resize((2048, 2048)),
input_image=image_1.resize((2048, 2048)), denoising_strength=0.99,
height=2048, width=2048, tiled=True,
seed=3
)
image_2.save("image_4.png")
def example_3():
model_manager = ModelManager(torch_dtype=torch.bfloat16, model_id_list=["FLUX.1-dev", "InstantX/FLUX.1-dev-Controlnet-Union-alpha"])
pipe = FluxImagePipeline.from_model_manager(model_manager, controlnet_config_units=[
ControlNetConfigUnit(
processor_id="canny",
model_path="models/ControlNet/InstantX/FLUX.1-dev-Controlnet-Union-alpha/diffusion_pytorch_model.safetensors",
scale=0.3
),
ControlNetConfigUnit(
processor_id="depth",
model_path="models/ControlNet/InstantX/FLUX.1-dev-Controlnet-Union-alpha/diffusion_pytorch_model.safetensors",
scale=0.3
),
])
image_1 = pipe(
prompt="a cat is running",
height=1024, width=1024,
seed=4
)
image_1.save("image_5.png")
image_2 = pipe(
prompt="sunshine, a cat is running",
controlnet_image=image_1,
height=1024, width=1024,
seed=5
)
image_2.save("image_6.png")
def example_4():
model_manager = ModelManager(torch_dtype=torch.bfloat16, model_id_list=["FLUX.1-dev", "InstantX/FLUX.1-dev-Controlnet-Union-alpha"])
pipe = FluxImagePipeline.from_model_manager(model_manager, controlnet_config_units=[
ControlNetConfigUnit(
processor_id="canny",
model_path="models/ControlNet/InstantX/FLUX.1-dev-Controlnet-Union-alpha/diffusion_pytorch_model.safetensors",
scale=0.3
),
ControlNetConfigUnit(
processor_id="depth",
model_path="models/ControlNet/InstantX/FLUX.1-dev-Controlnet-Union-alpha/diffusion_pytorch_model.safetensors",
scale=0.3
),
])
image_1 = pipe(
prompt="a beautiful Asian girl, full body, red dress, summer",
height=1024, width=1024,
seed=6
)
image_1.save("image_7.png")
image_2 = pipe(
prompt="a beautiful Asian girl, full body, red dress, winter",
controlnet_image=image_1,
height=1024, width=1024,
seed=7
)
image_2.save("image_8.png")
def example_5():
model_manager = ModelManager(torch_dtype=torch.bfloat16, model_id_list=["FLUX.1-dev", "alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta"])
pipe = FluxImagePipeline.from_model_manager(model_manager, controlnet_config_units=[
ControlNetConfigUnit(
processor_id="inpaint",
model_path="models/ControlNet/alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta/diffusion_pytorch_model.safetensors",
scale=0.9
),
])
image_1 = pipe(
prompt="a cat sitting on a chair",
height=1024, width=1024,
seed=8
)
image_1.save("image_9.png")
mask = np.zeros((1024, 1024, 3), dtype=np.uint8)
mask[100:350, 350: -300] = 255
mask = Image.fromarray(mask)
mask.save("mask_9.png")
image_2 = pipe(
prompt="a cat sitting on a chair, wearing sunglasses",
controlnet_image=image_1, controlnet_inpaint_mask=mask,
height=1024, width=1024,
seed=9
)
image_2.save("image_10.png")
def example_6():
model_manager = ModelManager(torch_dtype=torch.bfloat16, model_id_list=[
"FLUX.1-dev",
"jasperai/Flux.1-dev-Controlnet-Surface-Normals",
"alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta"
])
pipe = FluxImagePipeline.from_model_manager(model_manager, controlnet_config_units=[
ControlNetConfigUnit(
processor_id="inpaint",
model_path="models/ControlNet/alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta/diffusion_pytorch_model.safetensors",
scale=0.9
),
ControlNetConfigUnit(
processor_id="normal",
model_path="models/ControlNet/jasperai/Flux.1-dev-Controlnet-Surface-Normals/diffusion_pytorch_model.safetensors",
scale=0.6
),
])
image_1 = pipe(
prompt="a beautiful Asian woman looking at the sky, wearing a blue t-shirt.",
height=1024, width=1024,
seed=10
)
image_1.save("image_11.png")
mask = np.zeros((1024, 1024, 3), dtype=np.uint8)
mask[-400:, 10:-40] = 255
mask = Image.fromarray(mask)
mask.save("mask_11.png")
image_2 = pipe(
prompt="a beautiful Asian woman looking at the sky, wearing a yellow t-shirt.",
controlnet_image=image_1, controlnet_inpaint_mask=mask,
height=1024, width=1024,
seed=11
)
image_2.save("image_12.png")
def example_7():
model_manager = ModelManager(torch_dtype=torch.bfloat16, model_id_list=[
"FLUX.1-dev",
"InstantX/FLUX.1-dev-Controlnet-Union-alpha",
"alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta",
"jasperai/Flux.1-dev-Controlnet-Upscaler",
])
pipe = FluxImagePipeline.from_model_manager(model_manager, controlnet_config_units=[
ControlNetConfigUnit(
processor_id="inpaint",
model_path="models/ControlNet/alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta/diffusion_pytorch_model.safetensors",
scale=0.9
),
ControlNetConfigUnit(
processor_id="canny",
model_path="models/ControlNet/InstantX/FLUX.1-dev-Controlnet-Union-alpha/diffusion_pytorch_model.safetensors",
scale=0.5
),
])
image_1 = pipe(
prompt="a beautiful Asian woman and a cat on a bed. The woman wears a dress.",
height=1024, width=1024,
seed=100
)
image_1.save("image_13.png")
mask_global = np.zeros((1024, 1024, 3), dtype=np.uint8)
mask_global = Image.fromarray(mask_global)
mask_global.save("mask_13_global.png")
mask_1 = np.zeros((1024, 1024, 3), dtype=np.uint8)
mask_1[300:-100, 30: 450] = 255
mask_1 = Image.fromarray(mask_1)
mask_1.save("mask_13_1.png")
mask_2 = np.zeros((1024, 1024, 3), dtype=np.uint8)
mask_2[500:-100, -400:] = 255
mask_2[-200:-100, -500:-400] = 255
mask_2 = Image.fromarray(mask_2)
mask_2.save("mask_13_2.png")
image_2 = pipe(
prompt="a beautiful Asian woman and a cat on a bed. The woman wears a dress.",
controlnet_image=image_1, controlnet_inpaint_mask=mask_global,
local_prompts=["an orange cat, highly detailed", "a girl wearing a red camisole"], masks=[mask_1, mask_2], mask_scales=[10.0, 10.0],
height=1024, width=1024,
seed=101
)
image_2.save("image_14.png")
model_manager.load_lora("models/lora/FLUX-dev-lora-AntiBlur.safetensors", lora_alpha=2)
image_3 = pipe(
prompt="a beautiful Asian woman wearing a red camisole and an orange cat on a bed. clear background.",
negative_prompt="blur, blurry",
input_image=image_2, denoising_strength=0.7,
height=1024, width=1024,
cfg_scale=2.0, num_inference_steps=50,
seed=102
)
image_3.save("image_15.png")
pipe = FluxImagePipeline.from_model_manager(model_manager, controlnet_config_units=[
ControlNetConfigUnit(
processor_id="tile",
model_path="models/ControlNet/jasperai/Flux.1-dev-Controlnet-Upscaler/diffusion_pytorch_model.safetensors",
scale=0.7
),
])
image_4 = pipe(
prompt="a beautiful Asian woman wearing a red camisole and an orange cat on a bed. highly detailed, delicate skin texture, clear background.",
controlnet_image=image_3.resize((2048, 2048)),
input_image=image_3.resize((2048, 2048)), denoising_strength=0.99,
height=2048, width=2048, tiled=True,
seed=103
)
image_4.save("image_16.png")
image_5 = pipe(
prompt="a beautiful Asian woman wearing a red camisole and an orange cat on a bed. highly detailed, delicate skin texture, clear background.",
controlnet_image=image_4.resize((4096, 4096)),
input_image=image_4.resize((4096, 4096)), denoising_strength=0.99,
height=4096, width=4096, tiled=True,
seed=104
)
image_5.save("image_17.png")
download_models(["Annotators:Depth", "Annotators:Normal"])
download_customized_models(
model_id="LiblibAI/FLUX.1-dev-LoRA-AntiBlur",
origin_file_path="FLUX-dev-lora-AntiBlur.safetensors",
local_dir="models/lora"
)
example_1()
example_2()
example_3()
example_4()
example_5()
example_6()
example_7()