Merge pull request #250 from modelscope/flux-controlnet

Flux controlnet
This commit is contained in:
Zhongjie Duan
2024-10-25 10:58:37 +08:00
committed by GitHub
12 changed files with 936 additions and 82 deletions

View File

@@ -35,6 +35,7 @@ from ..models.hunyuan_dit import HunyuanDiT
from ..models.flux_dit import FluxDiT from ..models.flux_dit import FluxDiT
from ..models.flux_text_encoder import FluxTextEncoder1, FluxTextEncoder2 from ..models.flux_text_encoder import FluxTextEncoder1, FluxTextEncoder2
from ..models.flux_vae import FluxVAEEncoder, FluxVAEDecoder from ..models.flux_vae import FluxVAEEncoder, FluxVAEDecoder
from ..models.flux_controlnet import FluxControlNet
from ..models.cog_vae import CogVAEEncoder, CogVAEDecoder from ..models.cog_vae import CogVAEEncoder, CogVAEDecoder
from ..models.cog_dit import CogDiT from ..models.cog_dit import CogDiT
@@ -80,6 +81,10 @@ model_loader_configs = [
(None, "280189ee084bca10f70907bf6ce1649d", ["cog_vae_encoder", "cog_vae_decoder"], [CogVAEEncoder, CogVAEDecoder], "diffusers"), (None, "280189ee084bca10f70907bf6ce1649d", ["cog_vae_encoder", "cog_vae_decoder"], [CogVAEEncoder, CogVAEDecoder], "diffusers"),
(None, "9b9313d104ac4df27991352fec013fd4", ["rife"], [IFNet], "civitai"), (None, "9b9313d104ac4df27991352fec013fd4", ["rife"], [IFNet], "civitai"),
(None, "6b7116078c4170bfbeaedc8fe71f6649", ["esrgan"], [RRDBNet], "civitai"), (None, "6b7116078c4170bfbeaedc8fe71f6649", ["esrgan"], [RRDBNet], "civitai"),
(None, "78d18b9101345ff695f312e7e62538c0", ["flux_controlnet"], [FluxControlNet], "diffusers"),
(None, "b001c89139b5f053c715fe772362dd2a", ["flux_controlnet"], [FluxControlNet], "diffusers"),
(None, "52357cb26250681367488a8954c271e8", ["flux_controlnet"], [FluxControlNet], "diffusers"),
(None, "0cfd1740758423a2a854d67c136d1e8c", ["flux_controlnet"], [FluxControlNet], "diffusers"),
] ]
huggingface_model_loader_configs = [ huggingface_model_loader_configs = [
# These configs are provided for detecting model type automatically. # These configs are provided for detecting model type automatically.
@@ -199,7 +204,6 @@ preset_models_on_huggingface = {
("lllyasviel/omost-llama-3-8b-4bits", "model.safetensors.index.json", "models/OmostPrompt/omost-llama-3-8b-4bits"), ("lllyasviel/omost-llama-3-8b-4bits", "model.safetensors.index.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
("lllyasviel/omost-llama-3-8b-4bits", "special_tokens_map.json", "models/OmostPrompt/omost-llama-3-8b-4bits"), ("lllyasviel/omost-llama-3-8b-4bits", "special_tokens_map.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
], ],
# Translator # Translator
"opus-mt-zh-en": [ "opus-mt-zh-en": [
("Helsinki-NLP/opus-mt-zh-en", "config.json", "models/translator/opus-mt-zh-en"), ("Helsinki-NLP/opus-mt-zh-en", "config.json", "models/translator/opus-mt-zh-en"),
@@ -341,6 +345,24 @@ preset_models_on_modelscope = {
("AI-ModelScope/controlnet-union-sdxl-1.0", "diffusion_pytorch_model_promax.safetensors", "models/ControlNet/controlnet_union"), ("AI-ModelScope/controlnet-union-sdxl-1.0", "diffusion_pytorch_model_promax.safetensors", "models/ControlNet/controlnet_union"),
("sd_lora/Annotators", "dpt_hybrid-midas-501f0c75.pt", "models/Annotators") ("sd_lora/Annotators", "dpt_hybrid-midas-501f0c75.pt", "models/Annotators")
], ],
"Annotators:Depth": [
("sd_lora/Annotators", "dpt_hybrid-midas-501f0c75.pt", "models/Annotators"),
],
"Annotators:Softedge": [
("sd_lora/Annotators", "ControlNetHED.pth", "models/Annotators"),
],
"Annotators:Lineart": [
("sd_lora/Annotators", "sk_model.pth", "models/Annotators"),
("sd_lora/Annotators", "sk_model2.pth", "models/Annotators"),
],
"Annotators:Normal": [
("sd_lora/Annotators", "scannet.pt", "models/Annotators"),
],
"Annotators:Openpose": [
("sd_lora/Annotators", "body_pose_model.pth", "models/Annotators"),
("sd_lora/Annotators", "facenet.pth", "models/Annotators"),
("sd_lora/Annotators", "hand_pose_model.pth", "models/Annotators"),
],
# AnimateDiff # AnimateDiff
"AnimateDiff_v2": [ "AnimateDiff_v2": [
("Shanghai_AI_Laboratory/animatediff", "mm_sd_v15_v2.ckpt", "models/AnimateDiff"), ("Shanghai_AI_Laboratory/animatediff", "mm_sd_v15_v2.ckpt", "models/AnimateDiff"),
@@ -482,6 +504,30 @@ preset_models_on_modelscope = {
"models/FLUX/FLUX.1-schnell/flux1-schnell.safetensors" "models/FLUX/FLUX.1-schnell/flux1-schnell.safetensors"
], ],
}, },
"InstantX/FLUX.1-dev-Controlnet-Union-alpha": [
("InstantX/FLUX.1-dev-Controlnet-Union-alpha", "diffusion_pytorch_model.safetensors", "models/ControlNet/InstantX/FLUX.1-dev-Controlnet-Union-alpha"),
],
"jasperai/Flux.1-dev-Controlnet-Depth": [
("jasperai/Flux.1-dev-Controlnet-Depth", "diffusion_pytorch_model.safetensors", "models/ControlNet/jasperai/Flux.1-dev-Controlnet-Depth"),
],
"jasperai/Flux.1-dev-Controlnet-Surface-Normals": [
("jasperai/Flux.1-dev-Controlnet-Surface-Normals", "diffusion_pytorch_model.safetensors", "models/ControlNet/jasperai/Flux.1-dev-Controlnet-Surface-Normals"),
],
"jasperai/Flux.1-dev-Controlnet-Upscaler": [
("jasperai/Flux.1-dev-Controlnet-Upscaler", "diffusion_pytorch_model.safetensors", "models/ControlNet/jasperai/Flux.1-dev-Controlnet-Upscaler"),
],
"alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Alpha": [
("alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Alpha", "diffusion_pytorch_model.safetensors", "models/ControlNet/alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Alpha"),
],
"alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta": [
("alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta", "diffusion_pytorch_model.safetensors", "models/ControlNet/alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta"),
],
"Shakker-Labs/FLUX.1-dev-ControlNet-Depth": [
("Shakker-Labs/FLUX.1-dev-ControlNet-Depth", "diffusion_pytorch_model.safetensors", "models/ControlNet/Shakker-Labs/FLUX.1-dev-ControlNet-Depth"),
],
"Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro": [
("Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro", "diffusion_pytorch_model.safetensors", "models/ControlNet/Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro"),
],
# ESRGAN # ESRGAN
"ESRGAN_x4": [ "ESRGAN_x4": [
("AI-ModelScope/Real-ESRGAN", "RealESRGAN_x4.pth", "models/ESRGAN"), ("AI-ModelScope/Real-ESRGAN", "RealESRGAN_x4.pth", "models/ESRGAN"),
@@ -541,10 +587,23 @@ Preset_model_id: TypeAlias = Literal[
"ControlNet_union_sdxl_promax", "ControlNet_union_sdxl_promax",
"FLUX.1-dev", "FLUX.1-dev",
"FLUX.1-schnell", "FLUX.1-schnell",
"InstantX/FLUX.1-dev-Controlnet-Union-alpha",
"jasperai/Flux.1-dev-Controlnet-Depth",
"jasperai/Flux.1-dev-Controlnet-Surface-Normals",
"jasperai/Flux.1-dev-Controlnet-Upscaler",
"alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Alpha",
"alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta",
"Shakker-Labs/FLUX.1-dev-ControlNet-Depth",
"Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro",
"SDXL_lora_zyd232_ChineseInkStyle_SDXL_v1_0", "SDXL_lora_zyd232_ChineseInkStyle_SDXL_v1_0",
"QwenPrompt", "QwenPrompt",
"OmostPrompt", "OmostPrompt",
"ESRGAN_x4", "ESRGAN_x4",
"RIFE", "RIFE",
"CogVideoX-5B", "CogVideoX-5B",
"Annotators:Depth",
"Annotators:Softedge",
"Annotators:Lineart",
"Annotators:Normal",
"Annotators:Openpose",
] ]

View File

@@ -1,2 +1,2 @@
from .controlnet_unit import ControlNetConfigUnit, ControlNetUnit, MultiControlNetManager from .controlnet_unit import ControlNetConfigUnit, ControlNetUnit, MultiControlNetManager, FluxMultiControlNetManager
from .processors import Annotator from .processors import Annotator

View File

@@ -4,10 +4,11 @@ from .processors import Processor_id
class ControlNetConfigUnit: class ControlNetConfigUnit:
def __init__(self, processor_id: Processor_id, model_path, scale=1.0): def __init__(self, processor_id: Processor_id, model_path, scale=1.0, skip_processor=False):
self.processor_id = processor_id self.processor_id = processor_id
self.model_path = model_path self.model_path = model_path
self.scale = scale self.scale = scale
self.skip_processor = skip_processor
class ControlNetUnit: class ControlNetUnit:
@@ -60,3 +61,29 @@ class MultiControlNetManager:
else: else:
res_stack = [i + j for i, j in zip(res_stack, res_stack_)] res_stack = [i + j for i, j in zip(res_stack, res_stack_)]
return res_stack return res_stack
class FluxMultiControlNetManager(MultiControlNetManager):
def __init__(self, controlnet_units=[]):
super().__init__(controlnet_units=controlnet_units)
def process_image(self, image, processor_id=None):
if processor_id is None:
processed_image = [processor(image) for processor in self.processors]
else:
processed_image = [self.processors[processor_id](image)]
return processed_image
def __call__(self, conditionings, **kwargs):
res_stack, single_res_stack = None, None
for processor, conditioning, model, scale in zip(self.processors, conditionings, self.models, self.scales):
res_stack_, single_res_stack_ = model(controlnet_conditioning=conditioning, processor_id=processor.processor_id, **kwargs)
res_stack_ = [res * scale for res in res_stack_]
single_res_stack_ = [res * scale for res in single_res_stack_]
if res_stack is None:
res_stack = res_stack_
single_res_stack = single_res_stack_
else:
res_stack = [i + j for i, j in zip(res_stack, res_stack_)]
single_res_stack = [i + j for i, j in zip(single_res_stack, single_res_stack_)]
return res_stack, single_res_stack

View File

@@ -3,37 +3,42 @@ import warnings
with warnings.catch_warnings(): with warnings.catch_warnings():
warnings.simplefilter("ignore") warnings.simplefilter("ignore")
from controlnet_aux.processor import ( from controlnet_aux.processor import (
CannyDetector, MidasDetector, HEDdetector, LineartDetector, LineartAnimeDetector, OpenposeDetector CannyDetector, MidasDetector, HEDdetector, LineartDetector, LineartAnimeDetector, OpenposeDetector, NormalBaeDetector
) )
Processor_id: TypeAlias = Literal[ Processor_id: TypeAlias = Literal[
"canny", "depth", "softedge", "lineart", "lineart_anime", "openpose", "tile" "canny", "depth", "softedge", "lineart", "lineart_anime", "openpose", "normal", "tile", "none", "inpaint"
] ]
class Annotator: class Annotator:
def __init__(self, processor_id: Processor_id, model_path="models/Annotators", detect_resolution=None, device='cuda'): def __init__(self, processor_id: Processor_id, model_path="models/Annotators", detect_resolution=None, device='cuda', skip_processor=False):
if processor_id == "canny": if not skip_processor:
self.processor = CannyDetector() if processor_id == "canny":
elif processor_id == "depth": self.processor = CannyDetector()
self.processor = MidasDetector.from_pretrained(model_path).to(device) elif processor_id == "depth":
elif processor_id == "softedge": self.processor = MidasDetector.from_pretrained(model_path).to(device)
self.processor = HEDdetector.from_pretrained(model_path).to(device) elif processor_id == "softedge":
elif processor_id == "lineart": self.processor = HEDdetector.from_pretrained(model_path).to(device)
self.processor = LineartDetector.from_pretrained(model_path).to(device) elif processor_id == "lineart":
elif processor_id == "lineart_anime": self.processor = LineartDetector.from_pretrained(model_path).to(device)
self.processor = LineartAnimeDetector.from_pretrained(model_path).to(device) elif processor_id == "lineart_anime":
elif processor_id == "openpose": self.processor = LineartAnimeDetector.from_pretrained(model_path).to(device)
self.processor = OpenposeDetector.from_pretrained(model_path).to(device) elif processor_id == "openpose":
elif processor_id == "tile": self.processor = OpenposeDetector.from_pretrained(model_path).to(device)
self.processor = None elif processor_id == "normal":
self.processor = NormalBaeDetector.from_pretrained(model_path).to(device)
elif processor_id == "tile" or processor_id == "none" or processor_id == "inpaint":
self.processor = None
else:
raise ValueError(f"Unsupported processor_id: {processor_id}")
else: else:
raise ValueError(f"Unsupported processor_id: {processor_id}") self.processor = None
self.processor_id = processor_id self.processor_id = processor_id
self.detect_resolution = detect_resolution self.detect_resolution = detect_resolution
def __call__(self, image): def __call__(self, image, mask=None):
width, height = image.size width, height = image.size
if self.processor_id == "openpose": if self.processor_id == "openpose":
kwargs = { kwargs = {

View File

@@ -0,0 +1,226 @@
import torch
from einops import rearrange, repeat
from .flux_dit import RoPEEmbedding, TimestepEmbeddings, FluxJointTransformerBlock, FluxSingleTransformerBlock
from .utils import hash_state_dict_keys
class FluxControlNet(torch.nn.Module):
def __init__(self, disable_guidance_embedder=False, num_joint_blocks=5, num_single_blocks=10, num_mode=0, mode_dict={}, additional_input_dim=0):
super().__init__()
self.pos_embedder = RoPEEmbedding(3072, 10000, [16, 56, 56])
self.time_embedder = TimestepEmbeddings(256, 3072)
self.guidance_embedder = None if disable_guidance_embedder else TimestepEmbeddings(256, 3072)
self.pooled_text_embedder = torch.nn.Sequential(torch.nn.Linear(768, 3072), torch.nn.SiLU(), torch.nn.Linear(3072, 3072))
self.context_embedder = torch.nn.Linear(4096, 3072)
self.x_embedder = torch.nn.Linear(64, 3072)
self.blocks = torch.nn.ModuleList([FluxJointTransformerBlock(3072, 24) for _ in range(num_joint_blocks)])
self.single_blocks = torch.nn.ModuleList([FluxSingleTransformerBlock(3072, 24) for _ in range(num_single_blocks)])
self.controlnet_blocks = torch.nn.ModuleList([torch.nn.Linear(3072, 3072) for _ in range(num_joint_blocks)])
self.controlnet_single_blocks = torch.nn.ModuleList([torch.nn.Linear(3072, 3072) for _ in range(num_single_blocks)])
self.mode_dict = mode_dict
self.controlnet_mode_embedder = torch.nn.Embedding(num_mode, 3072) if len(mode_dict) > 0 else None
self.controlnet_x_embedder = torch.nn.Linear(64 + additional_input_dim, 3072)
def prepare_image_ids(self, latents):
batch_size, _, height, width = latents.shape
latent_image_ids = torch.zeros(height // 2, width // 2, 3)
latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None]
latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :]
latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
latent_image_ids = latent_image_ids[None, :].repeat(batch_size, 1, 1, 1)
latent_image_ids = latent_image_ids.reshape(
batch_size, latent_image_id_height * latent_image_id_width, latent_image_id_channels
)
latent_image_ids = latent_image_ids.to(device=latents.device, dtype=latents.dtype)
return latent_image_ids
def patchify(self, hidden_states):
hidden_states = rearrange(hidden_states, "B C (H P) (W Q) -> B (H W) (C P Q)", P=2, Q=2)
return hidden_states
def align_res_stack_to_original_blocks(self, res_stack, num_blocks, hidden_states):
if len(res_stack) == 0:
return [torch.zeros_like(hidden_states)] * num_blocks
interval = (num_blocks + len(res_stack) - 1) // len(res_stack)
aligned_res_stack = [res_stack[block_id // interval] for block_id in range(num_blocks)]
return aligned_res_stack
def forward(
self,
hidden_states,
controlnet_conditioning,
timestep, prompt_emb, pooled_prompt_emb, guidance, text_ids, image_ids=None,
processor_id=None,
tiled=False, tile_size=128, tile_stride=64,
**kwargs
):
if image_ids is None:
image_ids = self.prepare_image_ids(hidden_states)
conditioning = self.time_embedder(timestep, hidden_states.dtype) + self.pooled_text_embedder(pooled_prompt_emb)
if self.guidance_embedder is not None:
guidance = guidance * 1000
conditioning = conditioning + self.guidance_embedder(guidance, hidden_states.dtype)
prompt_emb = self.context_embedder(prompt_emb)
if self.controlnet_mode_embedder is not None: # Different from FluxDiT
processor_id = torch.tensor([self.mode_dict[processor_id]], dtype=torch.int)
processor_id = repeat(processor_id, "D -> B D", B=1).to(text_ids.device)
prompt_emb = torch.concat([self.controlnet_mode_embedder(processor_id), prompt_emb], dim=1)
text_ids = torch.cat([text_ids[:, :1], text_ids], dim=1)
image_rotary_emb = self.pos_embedder(torch.cat((text_ids, image_ids), dim=1))
hidden_states = self.patchify(hidden_states)
hidden_states = self.x_embedder(hidden_states)
controlnet_conditioning = self.patchify(controlnet_conditioning) # Different from FluxDiT
hidden_states = hidden_states + self.controlnet_x_embedder(controlnet_conditioning) # Different from FluxDiT
controlnet_res_stack = []
for block, controlnet_block in zip(self.blocks, self.controlnet_blocks):
hidden_states, prompt_emb = block(hidden_states, prompt_emb, conditioning, image_rotary_emb)
controlnet_res_stack.append(controlnet_block(hidden_states))
controlnet_single_res_stack = []
hidden_states = torch.cat([prompt_emb, hidden_states], dim=1)
for block, controlnet_block in zip(self.single_blocks, self.controlnet_single_blocks):
hidden_states, prompt_emb = block(hidden_states, prompt_emb, conditioning, image_rotary_emb)
controlnet_single_res_stack.append(controlnet_block(hidden_states[:, prompt_emb.shape[1]:]))
controlnet_res_stack = self.align_res_stack_to_original_blocks(controlnet_res_stack, 19, hidden_states[:, prompt_emb.shape[1]:])
controlnet_single_res_stack = self.align_res_stack_to_original_blocks(controlnet_single_res_stack, 38, hidden_states[:, prompt_emb.shape[1]:])
return controlnet_res_stack, controlnet_single_res_stack
@staticmethod
def state_dict_converter():
return FluxControlNetStateDictConverter()
class FluxControlNetStateDictConverter:
def __init__(self):
pass
def from_diffusers(self, state_dict):
hash_value = hash_state_dict_keys(state_dict)
global_rename_dict = {
"context_embedder": "context_embedder",
"x_embedder": "x_embedder",
"time_text_embed.timestep_embedder.linear_1": "time_embedder.timestep_embedder.0",
"time_text_embed.timestep_embedder.linear_2": "time_embedder.timestep_embedder.2",
"time_text_embed.guidance_embedder.linear_1": "guidance_embedder.timestep_embedder.0",
"time_text_embed.guidance_embedder.linear_2": "guidance_embedder.timestep_embedder.2",
"time_text_embed.text_embedder.linear_1": "pooled_text_embedder.0",
"time_text_embed.text_embedder.linear_2": "pooled_text_embedder.2",
"norm_out.linear": "final_norm_out.linear",
"proj_out": "final_proj_out",
}
rename_dict = {
"proj_out": "proj_out",
"norm1.linear": "norm1_a.linear",
"norm1_context.linear": "norm1_b.linear",
"attn.to_q": "attn.a_to_q",
"attn.to_k": "attn.a_to_k",
"attn.to_v": "attn.a_to_v",
"attn.to_out.0": "attn.a_to_out",
"attn.add_q_proj": "attn.b_to_q",
"attn.add_k_proj": "attn.b_to_k",
"attn.add_v_proj": "attn.b_to_v",
"attn.to_add_out": "attn.b_to_out",
"ff.net.0.proj": "ff_a.0",
"ff.net.2": "ff_a.2",
"ff_context.net.0.proj": "ff_b.0",
"ff_context.net.2": "ff_b.2",
"attn.norm_q": "attn.norm_q_a",
"attn.norm_k": "attn.norm_k_a",
"attn.norm_added_q": "attn.norm_q_b",
"attn.norm_added_k": "attn.norm_k_b",
}
rename_dict_single = {
"attn.to_q": "a_to_q",
"attn.to_k": "a_to_k",
"attn.to_v": "a_to_v",
"attn.norm_q": "norm_q_a",
"attn.norm_k": "norm_k_a",
"norm.linear": "norm.linear",
"proj_mlp": "proj_in_besides_attn",
"proj_out": "proj_out",
}
state_dict_ = {}
for name, param in state_dict.items():
if name.endswith(".weight") or name.endswith(".bias"):
suffix = ".weight" if name.endswith(".weight") else ".bias"
prefix = name[:-len(suffix)]
if prefix in global_rename_dict:
state_dict_[global_rename_dict[prefix] + suffix] = param
elif prefix.startswith("transformer_blocks."):
names = prefix.split(".")
names[0] = "blocks"
middle = ".".join(names[2:])
if middle in rename_dict:
name_ = ".".join(names[:2] + [rename_dict[middle]] + [suffix[1:]])
state_dict_[name_] = param
elif prefix.startswith("single_transformer_blocks."):
names = prefix.split(".")
names[0] = "single_blocks"
middle = ".".join(names[2:])
if middle in rename_dict_single:
name_ = ".".join(names[:2] + [rename_dict_single[middle]] + [suffix[1:]])
state_dict_[name_] = param
else:
state_dict_[name] = param
else:
state_dict_[name] = param
for name in list(state_dict_.keys()):
if ".proj_in_besides_attn." in name:
name_ = name.replace(".proj_in_besides_attn.", ".to_qkv_mlp.")
param = torch.concat([
state_dict_[name.replace(".proj_in_besides_attn.", f".a_to_q.")],
state_dict_[name.replace(".proj_in_besides_attn.", f".a_to_k.")],
state_dict_[name.replace(".proj_in_besides_attn.", f".a_to_v.")],
state_dict_[name],
], dim=0)
state_dict_[name_] = param
state_dict_.pop(name.replace(".proj_in_besides_attn.", f".a_to_q."))
state_dict_.pop(name.replace(".proj_in_besides_attn.", f".a_to_k."))
state_dict_.pop(name.replace(".proj_in_besides_attn.", f".a_to_v."))
state_dict_.pop(name)
for name in list(state_dict_.keys()):
for component in ["a", "b"]:
if f".{component}_to_q." in name:
name_ = name.replace(f".{component}_to_q.", f".{component}_to_qkv.")
param = torch.concat([
state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_q.")],
state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_k.")],
state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_v.")],
], dim=0)
state_dict_[name_] = param
state_dict_.pop(name.replace(f".{component}_to_q.", f".{component}_to_q."))
state_dict_.pop(name.replace(f".{component}_to_q.", f".{component}_to_k."))
state_dict_.pop(name.replace(f".{component}_to_q.", f".{component}_to_v."))
if hash_value == "78d18b9101345ff695f312e7e62538c0":
extra_kwargs = {"num_mode": 10, "mode_dict": {"canny": 0, "tile": 1, "depth": 2, "blur": 3, "pose": 4, "gray": 5, "lq": 6}}
elif hash_value == "b001c89139b5f053c715fe772362dd2a":
extra_kwargs = {"num_single_blocks": 0}
elif hash_value == "52357cb26250681367488a8954c271e8":
extra_kwargs = {"num_joint_blocks": 6, "num_single_blocks": 0, "additional_input_dim": 4}
elif hash_value == "0cfd1740758423a2a854d67c136d1e8c":
extra_kwargs = {"num_joint_blocks": 4, "num_single_blocks": 1}
else:
extra_kwargs = {}
return state_dict_, extra_kwargs
def from_civitai(self, state_dict):
return self.from_diffusers(state_dict)

View File

@@ -1,7 +1,4 @@
import os, torch, hashlib, json, importlib import os, torch, json, importlib
from safetensors import safe_open
from torch import Tensor
from typing_extensions import Literal, TypeAlias
from typing import List from typing import List
from .downloader import download_models, download_customized_models, Preset_model_id, Preset_model_website from .downloader import download_models, download_customized_models, Preset_model_id, Preset_model_website
@@ -50,45 +47,7 @@ from ..extensions.RIFE import IFNet
from ..extensions.ESRGAN import RRDBNet from ..extensions.ESRGAN import RRDBNet
from ..configs.model_config import model_loader_configs, huggingface_model_loader_configs, patch_model_loader_configs from ..configs.model_config import model_loader_configs, huggingface_model_loader_configs, patch_model_loader_configs
from .utils import load_state_dict, init_weights_on_device from .utils import load_state_dict, init_weights_on_device, hash_state_dict_keys, split_state_dict_with_prefix
def convert_state_dict_keys_to_single_str(state_dict, with_shape=True):
keys = []
for key, value in state_dict.items():
if isinstance(key, str):
if isinstance(value, Tensor):
if with_shape:
shape = "_".join(map(str, list(value.shape)))
keys.append(key + ":" + shape)
keys.append(key)
elif isinstance(value, dict):
keys.append(key + "|" + convert_state_dict_keys_to_single_str(value, with_shape=with_shape))
keys.sort()
keys_str = ",".join(keys)
return keys_str
def split_state_dict_with_prefix(state_dict):
keys = sorted([key for key in state_dict if isinstance(key, str)])
prefix_dict = {}
for key in keys:
prefix = key if "." not in key else key.split(".")[0]
if prefix not in prefix_dict:
prefix_dict[prefix] = []
prefix_dict[prefix].append(key)
state_dicts = []
for prefix, keys in prefix_dict.items():
sub_state_dict = {key: state_dict[key] for key in keys}
state_dicts.append(sub_state_dict)
return state_dicts
def hash_state_dict_keys(state_dict, with_shape=True):
keys_str = convert_state_dict_keys_to_single_str(state_dict, with_shape=with_shape)
keys_str = keys_str.encode(encoding="UTF-8")
return hashlib.md5(keys_str).hexdigest()
def load_model_from_single_file(state_dict, model_names, model_classes, model_resource, torch_dtype, device): def load_model_from_single_file(state_dict, model_names, model_classes, model_resource, torch_dtype, device):

View File

@@ -107,6 +107,60 @@ class TileWorker:
class FastTileWorker:
def __init__(self):
pass
def build_mask(self, data, is_bound):
_, _, H, W = data.shape
h = repeat(torch.arange(H), "H -> H W", H=H, W=W)
w = repeat(torch.arange(W), "W -> H W", H=H, W=W)
border_width = (H + W) // 4
pad = torch.ones_like(h) * border_width
mask = torch.stack([
pad if is_bound[0] else h + 1,
pad if is_bound[1] else H - h,
pad if is_bound[2] else w + 1,
pad if is_bound[3] else W - w
]).min(dim=0).values
mask = mask.clip(1, border_width)
mask = (mask / border_width).to(dtype=data.dtype, device=data.device)
mask = rearrange(mask, "H W -> 1 H W")
return mask
def tiled_forward(self, forward_fn, model_input, tile_size, tile_stride, tile_device="cpu", tile_dtype=torch.float32, border_width=None):
# Prepare
B, C, H, W = model_input.shape
border_width = int(tile_stride*0.5) if border_width is None else border_width
weight = torch.zeros((1, 1, H, W), dtype=tile_dtype, device=tile_device)
values = torch.zeros((B, C, H, W), dtype=tile_dtype, device=tile_device)
# Split tasks
tasks = []
for h in range(0, H, tile_stride):
for w in range(0, W, tile_stride):
if (h-tile_stride >= 0 and h-tile_stride+tile_size >= H) or (w-tile_stride >= 0 and w-tile_stride+tile_size >= W):
continue
h_, w_ = h + tile_size, w + tile_size
if h_ > H: h, h_ = H - tile_size, H
if w_ > W: w, w_ = W - tile_size, W
tasks.append((h, h_, w, w_))
# Run
for hl, hr, wl, wr in tasks:
# Forward
hidden_states_batch = forward_fn(hl, hr, wl, wr).to(dtype=tile_dtype, device=tile_device)
mask = self.build_mask(hidden_states_batch, is_bound=(hl==0, hr>=H, wl==0, wr>=W))
values[:, :, hl:hr, wl:wr] += hidden_states_batch * mask
weight[:, :, hl:hr, wl:wr] += mask
values /= weight
return values
class TileWorker2Dto3D: class TileWorker2Dto3D:
""" """
Process 3D tensors, but only enable TileWorker on 2D. Process 3D tensors, but only enable TileWorker on 2D.

View File

@@ -1,6 +1,7 @@
import torch, os import torch, os
from safetensors import safe_open from safetensors import safe_open
from contextlib import contextmanager from contextlib import contextmanager
import hashlib
@contextmanager @contextmanager
def init_weights_on_device(device = torch.device("meta"), include_buffers :bool = False): def init_weights_on_device(device = torch.device("meta"), include_buffers :bool = False):
@@ -142,3 +143,40 @@ def search_for_files(folder, extensions):
files.append(folder) files.append(folder)
break break
return files return files
def convert_state_dict_keys_to_single_str(state_dict, with_shape=True):
keys = []
for key, value in state_dict.items():
if isinstance(key, str):
if isinstance(value, torch.Tensor):
if with_shape:
shape = "_".join(map(str, list(value.shape)))
keys.append(key + ":" + shape)
keys.append(key)
elif isinstance(value, dict):
keys.append(key + "|" + convert_state_dict_keys_to_single_str(value, with_shape=with_shape))
keys.sort()
keys_str = ",".join(keys)
return keys_str
def split_state_dict_with_prefix(state_dict):
keys = sorted([key for key in state_dict if isinstance(key, str)])
prefix_dict = {}
for key in keys:
prefix = key if "." not in key else key.split(".")[0]
if prefix not in prefix_dict:
prefix_dict[prefix] = []
prefix_dict[prefix].append(key)
state_dicts = []
for prefix, keys in prefix_dict.items():
sub_state_dict = {key: state_dict[key] for key in keys}
state_dicts.append(sub_state_dict)
return state_dicts
def hash_state_dict_keys(state_dict, with_shape=True):
keys_str = convert_state_dict_keys_to_single_str(state_dict, with_shape=with_shape)
keys_str = keys_str.encode(encoding="UTF-8")
return hashlib.md5(keys_str).hexdigest()

View File

@@ -47,9 +47,15 @@ class BasePipeline(torch.nn.Module):
return value return value
def control_noise_via_local_prompts(self, prompt_emb_global, prompt_emb_locals, masks, mask_scales, inference_callback): def control_noise_via_local_prompts(self, prompt_emb_global, prompt_emb_locals, masks, mask_scales, inference_callback, special_kwargs=None, special_local_kwargs_list=None):
noise_pred_global = inference_callback(prompt_emb_global) if special_kwargs is None:
noise_pred_locals = [inference_callback(prompt_emb_local) for prompt_emb_local in prompt_emb_locals] noise_pred_global = inference_callback(prompt_emb_global)
else:
noise_pred_global = inference_callback(prompt_emb_global, special_kwargs)
if special_local_kwargs_list is None:
noise_pred_locals = [inference_callback(prompt_emb_local) for prompt_emb_local in prompt_emb_locals]
else:
noise_pred_locals = [inference_callback(prompt_emb_local, special_kwargs) for prompt_emb_local, special_kwargs in zip(prompt_emb_locals, special_local_kwargs_list)]
noise_pred = self.merge_latents(noise_pred_global, noise_pred_locals, masks, mask_scales) noise_pred = self.merge_latents(noise_pred_global, noise_pred_locals, masks, mask_scales)
return noise_pred return noise_pred

View File

@@ -139,6 +139,8 @@ def lets_dance_xl(
# 0. Text embedding alignment (only for video processing) # 0. Text embedding alignment (only for video processing)
if encoder_hidden_states.shape[0] != sample.shape[0]: if encoder_hidden_states.shape[0] != sample.shape[0]:
encoder_hidden_states = encoder_hidden_states.repeat(sample.shape[0], 1, 1, 1) encoder_hidden_states = encoder_hidden_states.repeat(sample.shape[0], 1, 1, 1)
if add_text_embeds.shape[0] != sample.shape[0]:
add_text_embeds = add_text_embeds.repeat(sample.shape[0], 1)
# 1. ControlNet # 1. ControlNet
controlnet_insert_block_id = 22 controlnet_insert_block_id = 22
@@ -204,7 +206,7 @@ def lets_dance_xl(
batch_id_ = min(batch_id + unet_batch_size, sample.shape[0]) batch_id_ = min(batch_id + unet_batch_size, sample.shape[0])
hidden_states, _, _, _ = block( hidden_states, _, _, _ = block(
hidden_states_input[batch_id: batch_id_], hidden_states_input[batch_id: batch_id_],
time_emb, time_emb[batch_id: batch_id_],
text_emb[batch_id: batch_id_], text_emb[batch_id: batch_id_],
res_stack, res_stack,
cross_frame_attention=cross_frame_attention, cross_frame_attention=cross_frame_attention,

View File

@@ -1,9 +1,14 @@
from ..models import ModelManager, FluxDiT, FluxTextEncoder1, FluxTextEncoder2, FluxVAEDecoder, FluxVAEEncoder from ..models import ModelManager, FluxDiT, FluxTextEncoder1, FluxTextEncoder2, FluxVAEDecoder, FluxVAEEncoder
from ..controlnets import FluxMultiControlNetManager, ControlNetUnit, ControlNetConfigUnit, Annotator
from ..prompters import FluxPrompter from ..prompters import FluxPrompter
from ..schedulers import FlowMatchScheduler from ..schedulers import FlowMatchScheduler
from .base import BasePipeline from .base import BasePipeline
from typing import List
import torch import torch
from tqdm import tqdm from tqdm import tqdm
import numpy as np
from PIL import Image
from ..models.tiler import FastTileWorker
@@ -19,14 +24,15 @@ class FluxImagePipeline(BasePipeline):
self.dit: FluxDiT = None self.dit: FluxDiT = None
self.vae_decoder: FluxVAEDecoder = None self.vae_decoder: FluxVAEDecoder = None
self.vae_encoder: FluxVAEEncoder = None self.vae_encoder: FluxVAEEncoder = None
self.model_names = ['text_encoder_1', 'text_encoder_2', 'dit', 'vae_decoder', 'vae_encoder'] self.controlnet: FluxMultiControlNetManager = None
self.model_names = ['text_encoder_1', 'text_encoder_2', 'dit', 'vae_decoder', 'vae_encoder', 'controlnet']
def denoising_model(self): def denoising_model(self):
return self.dit return self.dit
def fetch_models(self, model_manager: ModelManager, prompt_refiner_classes=[], prompt_extender_classes=[]): def fetch_models(self, model_manager: ModelManager, controlnet_config_units: List[ControlNetConfigUnit]=[], prompt_refiner_classes=[], prompt_extender_classes=[]):
self.text_encoder_1 = model_manager.fetch_model("flux_text_encoder_1") self.text_encoder_1 = model_manager.fetch_model("flux_text_encoder_1")
self.text_encoder_2 = model_manager.fetch_model("flux_text_encoder_2") self.text_encoder_2 = model_manager.fetch_model("flux_text_encoder_2")
self.dit = model_manager.fetch_model("flux_dit") self.dit = model_manager.fetch_model("flux_dit")
@@ -36,14 +42,25 @@ class FluxImagePipeline(BasePipeline):
self.prompter.load_prompt_refiners(model_manager, prompt_refiner_classes) self.prompter.load_prompt_refiners(model_manager, prompt_refiner_classes)
self.prompter.load_prompt_extenders(model_manager, prompt_extender_classes) self.prompter.load_prompt_extenders(model_manager, prompt_extender_classes)
# ControlNets
controlnet_units = []
for config in controlnet_config_units:
controlnet_unit = ControlNetUnit(
Annotator(config.processor_id, device=self.device, skip_processor=config.skip_processor),
model_manager.fetch_model("flux_controlnet", config.model_path),
config.scale
)
controlnet_units.append(controlnet_unit)
self.controlnet = FluxMultiControlNetManager(controlnet_units)
@staticmethod @staticmethod
def from_model_manager(model_manager: ModelManager, prompt_refiner_classes=[], prompt_extender_classes=[], device=None): def from_model_manager(model_manager: ModelManager, controlnet_config_units: List[ControlNetConfigUnit]=[], prompt_refiner_classes=[], prompt_extender_classes=[], device=None):
pipe = FluxImagePipeline( pipe = FluxImagePipeline(
device=model_manager.device if device is None else device, device=model_manager.device if device is None else device,
torch_dtype=model_manager.torch_dtype, torch_dtype=model_manager.torch_dtype,
) )
pipe.fetch_models(model_manager, prompt_refiner_classes,prompt_extender_classes) pipe.fetch_models(model_manager, controlnet_config_units, prompt_refiner_classes, prompt_extender_classes)
return pipe return pipe
@@ -71,17 +88,62 @@ class FluxImagePipeline(BasePipeline):
return {"image_ids": latent_image_ids, "guidance": guidance} return {"image_ids": latent_image_ids, "guidance": guidance}
def apply_controlnet_mask_on_latents(self, latents, mask):
mask = (self.preprocess_image(mask) + 1) / 2
mask = mask.mean(dim=1, keepdim=True)
mask = mask.to(dtype=self.torch_dtype, device=self.device)
mask = 1 - torch.nn.functional.interpolate(mask, size=latents.shape[-2:])
latents = torch.concat([latents, mask], dim=1)
return latents
def apply_controlnet_mask_on_image(self, image, mask):
mask = mask.resize(image.size)
mask = self.preprocess_image(mask).mean(dim=[0, 1])
image = np.array(image)
image[mask > 0] = 0
image = Image.fromarray(image)
return image
def prepare_controlnet_input(self, controlnet_image, controlnet_inpaint_mask, tiler_kwargs):
if isinstance(controlnet_image, Image.Image):
controlnet_image = [controlnet_image] * len(self.controlnet.processors)
controlnet_frames = []
for i in range(len(self.controlnet.processors)):
# image annotator
image = self.controlnet.process_image(controlnet_image[i], processor_id=i)[0]
if controlnet_inpaint_mask is not None and self.controlnet.processors[i].processor_id == "inpaint":
image = self.apply_controlnet_mask_on_image(image, controlnet_inpaint_mask)
# image to tensor
image = self.preprocess_image(image).to(device=self.device, dtype=self.torch_dtype)
# vae encoder
image = self.encode_image(image, **tiler_kwargs)
if controlnet_inpaint_mask is not None and self.controlnet.processors[i].processor_id == "inpaint":
image = self.apply_controlnet_mask_on_latents(image, controlnet_inpaint_mask)
# store it
controlnet_frames.append(image)
return controlnet_frames
@torch.no_grad() @torch.no_grad()
def __call__( def __call__(
self, self,
prompt, prompt,
local_prompts= None, local_prompts=None,
masks= None, masks=None,
mask_scales= None, mask_scales=None,
negative_prompt="", negative_prompt="",
cfg_scale=1.0, cfg_scale=1.0,
embedded_guidance=3.5, embedded_guidance=3.5,
input_image=None, input_image=None,
controlnet_image=None,
controlnet_inpaint_mask=None,
enable_controlnet_on_negative=False,
denoising_strength=1.0, denoising_strength=1.0,
height=1024, height=1024,
width=1024, width=1024,
@@ -123,19 +185,38 @@ class FluxImagePipeline(BasePipeline):
# Extra input # Extra input
extra_input = self.prepare_extra_input(latents, guidance=embedded_guidance) extra_input = self.prepare_extra_input(latents, guidance=embedded_guidance)
# Prepare ControlNets
if controlnet_image is not None:
controlnet_kwargs = {"controlnet_frames": self.prepare_controlnet_input(controlnet_image, controlnet_inpaint_mask, tiler_kwargs)}
if len(masks) > 0 and controlnet_inpaint_mask is not None:
print("The controlnet_inpaint_mask will be overridden by masks.")
local_controlnet_kwargs = [{"controlnet_frames": self.prepare_controlnet_input(controlnet_image, mask, tiler_kwargs)} for mask in masks]
else:
local_controlnet_kwargs = None
else:
controlnet_kwargs, local_controlnet_kwargs = {"controlnet_frames": None}, [{}] * len(masks)
# Denoise # Denoise
self.load_models_to_device(['dit']) self.load_models_to_device(['dit', 'controlnet'])
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)): for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
timestep = timestep.unsqueeze(0).to(self.device) timestep = timestep.unsqueeze(0).to(self.device)
# Classifier-free guidance # Classifier-free guidance
inference_callback = lambda prompt_emb_posi: self.dit( inference_callback = lambda prompt_emb_posi, controlnet_kwargs: lets_dance_flux(
latents, timestep=timestep, **prompt_emb_posi, **tiler_kwargs, **extra_input dit=self.dit, controlnet=self.controlnet,
hidden_states=latents, timestep=timestep,
**prompt_emb_posi, **tiler_kwargs, **extra_input, **controlnet_kwargs
)
noise_pred_posi = self.control_noise_via_local_prompts(
prompt_emb_posi, prompt_emb_locals, masks, mask_scales, inference_callback,
special_kwargs=controlnet_kwargs, special_local_kwargs_list=local_controlnet_kwargs
) )
noise_pred_posi = self.control_noise_via_local_prompts(prompt_emb_posi, prompt_emb_locals, masks, mask_scales, inference_callback)
if cfg_scale != 1.0: if cfg_scale != 1.0:
noise_pred_nega = self.dit( negative_controlnet_kwargs = controlnet_kwargs if enable_controlnet_on_negative else {}
latents, timestep=timestep, **prompt_emb_nega, **tiler_kwargs, **extra_input noise_pred_nega = lets_dance_flux(
dit=self.dit, controlnet=self.controlnet,
hidden_states=latents, timestep=timestep,
**prompt_emb_nega, **tiler_kwargs, **extra_input, **negative_controlnet_kwargs,
) )
noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega) noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
else: else:
@@ -155,3 +236,101 @@ class FluxImagePipeline(BasePipeline):
# Offload all models # Offload all models
self.load_models_to_device([]) self.load_models_to_device([])
return image return image
def lets_dance_flux(
dit: FluxDiT,
controlnet: FluxMultiControlNetManager = None,
hidden_states=None,
timestep=None,
prompt_emb=None,
pooled_prompt_emb=None,
guidance=None,
text_ids=None,
image_ids=None,
controlnet_frames=None,
tiled=False,
tile_size=128,
tile_stride=64,
**kwargs
):
if tiled:
def flux_forward_fn(hl, hr, wl, wr):
return lets_dance_flux(
dit=dit,
controlnet=controlnet,
hidden_states=hidden_states[:, :, hl: hr, wl: wr],
timestep=timestep,
prompt_emb=prompt_emb,
pooled_prompt_emb=pooled_prompt_emb,
guidance=guidance,
text_ids=text_ids,
image_ids=None,
controlnet_frames=[f[:, :, hl: hr, wl: wr] for f in controlnet_frames],
tiled=False,
**kwargs
)
return FastTileWorker().tiled_forward(
flux_forward_fn,
hidden_states,
tile_size=tile_size,
tile_stride=tile_stride,
tile_device=hidden_states.device,
tile_dtype=hidden_states.dtype
)
# ControlNet
if controlnet is not None and controlnet_frames is not None:
controlnet_extra_kwargs = {
"hidden_states": hidden_states,
"timestep": timestep,
"prompt_emb": prompt_emb,
"pooled_prompt_emb": pooled_prompt_emb,
"guidance": guidance,
"text_ids": text_ids,
"image_ids": image_ids,
"tiled": tiled,
"tile_size": tile_size,
"tile_stride": tile_stride,
}
controlnet_res_stack, controlnet_single_res_stack = controlnet(
controlnet_frames, **controlnet_extra_kwargs
)
if image_ids is None:
image_ids = dit.prepare_image_ids(hidden_states)
conditioning = dit.time_embedder(timestep, hidden_states.dtype) + dit.pooled_text_embedder(pooled_prompt_emb)
if dit.guidance_embedder is not None:
guidance = guidance * 1000
conditioning = conditioning + dit.guidance_embedder(guidance, hidden_states.dtype)
prompt_emb = dit.context_embedder(prompt_emb)
image_rotary_emb = dit.pos_embedder(torch.cat((text_ids, image_ids), dim=1))
height, width = hidden_states.shape[-2:]
hidden_states = dit.patchify(hidden_states)
hidden_states = dit.x_embedder(hidden_states)
# Joint Blocks
for block_id, block in enumerate(dit.blocks):
hidden_states, prompt_emb = block(hidden_states, prompt_emb, conditioning, image_rotary_emb)
# ControlNet
if controlnet is not None and controlnet_frames is not None:
hidden_states = hidden_states + controlnet_res_stack[block_id]
# Single Blocks
hidden_states = torch.cat([prompt_emb, hidden_states], dim=1)
for block_id, block in enumerate(dit.single_blocks):
hidden_states, prompt_emb = block(hidden_states, prompt_emb, conditioning, image_rotary_emb)
# ControlNet
if controlnet is not None and controlnet_frames is not None:
hidden_states[:, prompt_emb.shape[1]:] = hidden_states[:, prompt_emb.shape[1]:] + controlnet_single_res_stack[block_id]
hidden_states = hidden_states[:, prompt_emb.shape[1]:]
hidden_states = dit.final_norm_out(hidden_states, conditioning)
hidden_states = dit.final_proj_out(hidden_states)
hidden_states = dit.unpatchify(hidden_states, height, width)
return hidden_states

View File

@@ -0,0 +1,299 @@
from diffsynth import ModelManager, FluxImagePipeline, ControlNetConfigUnit, download_models, download_customized_models
import torch
from PIL import Image
import numpy as np
def example_1():
model_manager = ModelManager(torch_dtype=torch.bfloat16, model_id_list=["FLUX.1-dev", "jasperai/Flux.1-dev-Controlnet-Upscaler"])
pipe = FluxImagePipeline.from_model_manager(model_manager, controlnet_config_units=[
ControlNetConfigUnit(
processor_id="tile",
model_path="models/ControlNet/jasperai/Flux.1-dev-Controlnet-Upscaler/diffusion_pytorch_model.safetensors",
scale=0.7
),
])
image_1 = pipe(
prompt="a photo of a cat, highly detailed",
height=768, width=768,
seed=0
)
image_1.save("image_1.png")
image_2 = pipe(
prompt="a photo of a cat, highly detailed",
controlnet_image=image_1.resize((2048, 2048)),
input_image=image_1.resize((2048, 2048)), denoising_strength=0.99,
height=2048, width=2048, tiled=True,
seed=1
)
image_2.save("image_2.png")
def example_2():
model_manager = ModelManager(torch_dtype=torch.bfloat16, model_id_list=["FLUX.1-dev", "jasperai/Flux.1-dev-Controlnet-Upscaler"])
pipe = FluxImagePipeline.from_model_manager(model_manager, controlnet_config_units=[
ControlNetConfigUnit(
processor_id="tile",
model_path="models/ControlNet/jasperai/Flux.1-dev-Controlnet-Upscaler/diffusion_pytorch_model.safetensors",
scale=0.7
),
])
image_1 = pipe(
prompt="a beautiful Chinese girl, delicate skin texture",
height=768, width=768,
seed=2
)
image_1.save("image_3.png")
image_2 = pipe(
prompt="a beautiful Chinese girl, delicate skin texture",
controlnet_image=image_1.resize((2048, 2048)),
input_image=image_1.resize((2048, 2048)), denoising_strength=0.99,
height=2048, width=2048, tiled=True,
seed=3
)
image_2.save("image_4.png")
def example_3():
model_manager = ModelManager(torch_dtype=torch.bfloat16, model_id_list=["FLUX.1-dev", "InstantX/FLUX.1-dev-Controlnet-Union-alpha"])
pipe = FluxImagePipeline.from_model_manager(model_manager, controlnet_config_units=[
ControlNetConfigUnit(
processor_id="canny",
model_path="models/ControlNet/InstantX/FLUX.1-dev-Controlnet-Union-alpha/diffusion_pytorch_model.safetensors",
scale=0.3
),
ControlNetConfigUnit(
processor_id="depth",
model_path="models/ControlNet/InstantX/FLUX.1-dev-Controlnet-Union-alpha/diffusion_pytorch_model.safetensors",
scale=0.3
),
])
image_1 = pipe(
prompt="a cat is running",
height=1024, width=1024,
seed=4
)
image_1.save("image_5.png")
image_2 = pipe(
prompt="sunshine, a cat is running",
controlnet_image=image_1,
height=1024, width=1024,
seed=5
)
image_2.save("image_6.png")
def example_4():
model_manager = ModelManager(torch_dtype=torch.bfloat16, model_id_list=["FLUX.1-dev", "InstantX/FLUX.1-dev-Controlnet-Union-alpha"])
pipe = FluxImagePipeline.from_model_manager(model_manager, controlnet_config_units=[
ControlNetConfigUnit(
processor_id="canny",
model_path="models/ControlNet/InstantX/FLUX.1-dev-Controlnet-Union-alpha/diffusion_pytorch_model.safetensors",
scale=0.3
),
ControlNetConfigUnit(
processor_id="depth",
model_path="models/ControlNet/InstantX/FLUX.1-dev-Controlnet-Union-alpha/diffusion_pytorch_model.safetensors",
scale=0.3
),
])
image_1 = pipe(
prompt="a beautiful Asian girl, full body, red dress, summer",
height=1024, width=1024,
seed=6
)
image_1.save("image_7.png")
image_2 = pipe(
prompt="a beautiful Asian girl, full body, red dress, winter",
controlnet_image=image_1,
height=1024, width=1024,
seed=7
)
image_2.save("image_8.png")
def example_5():
model_manager = ModelManager(torch_dtype=torch.bfloat16, model_id_list=["FLUX.1-dev", "alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta"])
pipe = FluxImagePipeline.from_model_manager(model_manager, controlnet_config_units=[
ControlNetConfigUnit(
processor_id="inpaint",
model_path="models/ControlNet/alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta/diffusion_pytorch_model.safetensors",
scale=0.9
),
])
image_1 = pipe(
prompt="a cat sitting on a chair",
height=1024, width=1024,
seed=8
)
image_1.save("image_9.png")
mask = np.zeros((1024, 1024, 3), dtype=np.uint8)
mask[100:350, 350: -300] = 255
mask = Image.fromarray(mask)
mask.save("mask_9.png")
image_2 = pipe(
prompt="a cat sitting on a chair, wearing sunglasses",
controlnet_image=image_1, controlnet_inpaint_mask=mask,
height=1024, width=1024,
seed=9
)
image_2.save("image_10.png")
def example_6():
model_manager = ModelManager(torch_dtype=torch.bfloat16, model_id_list=[
"FLUX.1-dev",
"jasperai/Flux.1-dev-Controlnet-Surface-Normals",
"alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta"
])
pipe = FluxImagePipeline.from_model_manager(model_manager, controlnet_config_units=[
ControlNetConfigUnit(
processor_id="inpaint",
model_path="models/ControlNet/alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta/diffusion_pytorch_model.safetensors",
scale=0.9
),
ControlNetConfigUnit(
processor_id="normal",
model_path="models/ControlNet/jasperai/Flux.1-dev-Controlnet-Surface-Normals/diffusion_pytorch_model.safetensors",
scale=0.6
),
])
image_1 = pipe(
prompt="a beautiful Asian woman looking at the sky, wearing a blue t-shirt.",
height=1024, width=1024,
seed=10
)
image_1.save("image_11.png")
mask = np.zeros((1024, 1024, 3), dtype=np.uint8)
mask[-400:, 10:-40] = 255
mask = Image.fromarray(mask)
mask.save("mask_11.png")
image_2 = pipe(
prompt="a beautiful Asian woman looking at the sky, wearing a yellow t-shirt.",
controlnet_image=image_1, controlnet_inpaint_mask=mask,
height=1024, width=1024,
seed=11
)
image_2.save("image_12.png")
def example_7():
model_manager = ModelManager(torch_dtype=torch.bfloat16, model_id_list=[
"FLUX.1-dev",
"InstantX/FLUX.1-dev-Controlnet-Union-alpha",
"alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta",
"jasperai/Flux.1-dev-Controlnet-Upscaler",
])
pipe = FluxImagePipeline.from_model_manager(model_manager, controlnet_config_units=[
ControlNetConfigUnit(
processor_id="inpaint",
model_path="models/ControlNet/alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta/diffusion_pytorch_model.safetensors",
scale=0.9
),
ControlNetConfigUnit(
processor_id="canny",
model_path="models/ControlNet/InstantX/FLUX.1-dev-Controlnet-Union-alpha/diffusion_pytorch_model.safetensors",
scale=0.5
),
])
image_1 = pipe(
prompt="a beautiful Asian woman and a cat on a bed. The woman wears a dress.",
height=1024, width=1024,
seed=100
)
image_1.save("image_13.png")
mask_global = np.zeros((1024, 1024, 3), dtype=np.uint8)
mask_global = Image.fromarray(mask_global)
mask_global.save("mask_13_global.png")
mask_1 = np.zeros((1024, 1024, 3), dtype=np.uint8)
mask_1[300:-100, 30: 450] = 255
mask_1 = Image.fromarray(mask_1)
mask_1.save("mask_13_1.png")
mask_2 = np.zeros((1024, 1024, 3), dtype=np.uint8)
mask_2[500:-100, -400:] = 255
mask_2[-200:-100, -500:-400] = 255
mask_2 = Image.fromarray(mask_2)
mask_2.save("mask_13_2.png")
image_2 = pipe(
prompt="a beautiful Asian woman and a cat on a bed. The woman wears a dress.",
controlnet_image=image_1, controlnet_inpaint_mask=mask_global,
local_prompts=["an orange cat, highly detailed", "a girl wearing a red camisole"], masks=[mask_1, mask_2], mask_scales=[10.0, 10.0],
height=1024, width=1024,
seed=101
)
image_2.save("image_14.png")
model_manager.load_lora("models/lora/FLUX-dev-lora-AntiBlur.safetensors", lora_alpha=2)
image_3 = pipe(
prompt="a beautiful Asian woman wearing a red camisole and an orange cat on a bed. clear background.",
negative_prompt="blur, blurry",
input_image=image_2, denoising_strength=0.7,
height=1024, width=1024,
cfg_scale=2.0, num_inference_steps=50,
seed=102
)
image_3.save("image_15.png")
pipe = FluxImagePipeline.from_model_manager(model_manager, controlnet_config_units=[
ControlNetConfigUnit(
processor_id="tile",
model_path="models/ControlNet/jasperai/Flux.1-dev-Controlnet-Upscaler/diffusion_pytorch_model.safetensors",
scale=0.7
),
])
image_4 = pipe(
prompt="a beautiful Asian woman wearing a red camisole and an orange cat on a bed. highly detailed, delicate skin texture, clear background.",
controlnet_image=image_3.resize((2048, 2048)),
input_image=image_3.resize((2048, 2048)), denoising_strength=0.99,
height=2048, width=2048, tiled=True,
seed=103
)
image_4.save("image_16.png")
image_5 = pipe(
prompt="a beautiful Asian woman wearing a red camisole and an orange cat on a bed. highly detailed, delicate skin texture, clear background.",
controlnet_image=image_4.resize((4096, 4096)),
input_image=image_4.resize((4096, 4096)), denoising_strength=0.99,
height=4096, width=4096, tiled=True,
seed=104
)
image_5.save("image_17.png")
download_models(["Annotators:Depth", "Annotators:Normal"])
download_customized_models(
model_id="LiblibAI/FLUX.1-dev-LoRA-AntiBlur",
origin_file_path="FLUX-dev-lora-AntiBlur.safetensors",
local_dir="models/lora"
)
example_1()
example_2()
example_3()
example_4()
example_5()
example_6()
example_7()