support flux-controlnet

This commit is contained in:
Artiprocher
2024-10-22 18:52:24 +08:00
parent 72ed76e89e
commit 07d70a6a56
9 changed files with 522 additions and 76 deletions

View File

@@ -35,6 +35,7 @@ from ..models.hunyuan_dit import HunyuanDiT
from ..models.flux_dit import FluxDiT from ..models.flux_dit import FluxDiT
from ..models.flux_text_encoder import FluxTextEncoder1, FluxTextEncoder2 from ..models.flux_text_encoder import FluxTextEncoder1, FluxTextEncoder2
from ..models.flux_vae import FluxVAEEncoder, FluxVAEDecoder from ..models.flux_vae import FluxVAEEncoder, FluxVAEDecoder
from ..models.flux_controlnet import FluxControlNet
from ..models.cog_vae import CogVAEEncoder, CogVAEDecoder from ..models.cog_vae import CogVAEEncoder, CogVAEDecoder
from ..models.cog_dit import CogDiT from ..models.cog_dit import CogDiT
@@ -80,6 +81,10 @@ model_loader_configs = [
(None, "280189ee084bca10f70907bf6ce1649d", ["cog_vae_encoder", "cog_vae_decoder"], [CogVAEEncoder, CogVAEDecoder], "diffusers"), (None, "280189ee084bca10f70907bf6ce1649d", ["cog_vae_encoder", "cog_vae_decoder"], [CogVAEEncoder, CogVAEDecoder], "diffusers"),
(None, "9b9313d104ac4df27991352fec013fd4", ["rife"], [IFNet], "civitai"), (None, "9b9313d104ac4df27991352fec013fd4", ["rife"], [IFNet], "civitai"),
(None, "6b7116078c4170bfbeaedc8fe71f6649", ["esrgan"], [RRDBNet], "civitai"), (None, "6b7116078c4170bfbeaedc8fe71f6649", ["esrgan"], [RRDBNet], "civitai"),
(None, "78d18b9101345ff695f312e7e62538c0", ["flux_controlnet"], [FluxControlNet], "diffusers"),
(None, "b001c89139b5f053c715fe772362dd2a", ["flux_controlnet"], [FluxControlNet], "diffusers"),
(None, "52357cb26250681367488a8954c271e8", ["flux_controlnet"], [FluxControlNet], "diffusers"),
(None, "0cfd1740758423a2a854d67c136d1e8c", ["flux_controlnet"], [FluxControlNet], "diffusers"),
] ]
huggingface_model_loader_configs = [ huggingface_model_loader_configs = [
# These configs are provided for detecting model type automatically. # These configs are provided for detecting model type automatically.

View File

@@ -1,2 +1,2 @@
from .controlnet_unit import ControlNetConfigUnit, ControlNetUnit, MultiControlNetManager from .controlnet_unit import ControlNetConfigUnit, ControlNetUnit, MultiControlNetManager, FluxMultiControlNetManager
from .processors import Annotator from .processors import Annotator

View File

@@ -4,10 +4,11 @@ from .processors import Processor_id
class ControlNetConfigUnit: class ControlNetConfigUnit:
def __init__(self, processor_id: Processor_id, model_path, scale=1.0): def __init__(self, processor_id: Processor_id, model_path, scale=1.0, skip_processor=False):
self.processor_id = processor_id self.processor_id = processor_id
self.model_path = model_path self.model_path = model_path
self.scale = scale self.scale = scale
self.skip_processor = skip_processor
class ControlNetUnit: class ControlNetUnit:
@@ -60,3 +61,29 @@ class MultiControlNetManager:
else: else:
res_stack = [i + j for i, j in zip(res_stack, res_stack_)] res_stack = [i + j for i, j in zip(res_stack, res_stack_)]
return res_stack return res_stack
class FluxMultiControlNetManager(MultiControlNetManager):
    """Manage several FLUX ControlNets as one unit.

    FLUX ControlNets produce two residual stacks — one for the joint
    transformer blocks and one for the single transformer blocks — so this
    manager aggregates both, unlike the base class which handles only one.
    """

    def __init__(self, controlnet_units=None):
        # Fix: the original used a mutable default argument (controlnet_units=[]),
        # which is shared across calls. Use None as the sentinel instead.
        super().__init__(controlnet_units=[] if controlnet_units is None else controlnet_units)

    def process_image(self, image, processor_id=None):
        """Run annotator(s) on `image`.

        processor_id: integer index into self.processors, or None to run
            every processor.
        Returns a list of processed images (one per processor run).
        """
        if processor_id is None:
            processed_image = [processor(image) for processor in self.processors]
        else:
            processed_image = [self.processors[processor_id](image)]
        return processed_image

    def __call__(self, conditionings, **kwargs):
        """Run every ControlNet and sum their scaled residual stacks.

        conditionings: one conditioning tensor per ControlNet unit, aligned
            with self.processors / self.models / self.scales.
        Returns (res_stack, single_res_stack): element-wise sums of the
        per-unit residuals, each scaled by its unit's conditioning scale.
        """
        res_stack, single_res_stack = None, None
        for processor, conditioning, model, scale in zip(self.processors, conditionings, self.models, self.scales):
            res_stack_, single_res_stack_ = model(controlnet_conditioning=conditioning, processor_id=processor.processor_id, **kwargs)
            # Apply this unit's conditioning scale to both stacks.
            res_stack_ = [res * scale for res in res_stack_]
            single_res_stack_ = [res * scale for res in single_res_stack_]
            if res_stack is None:
                res_stack = res_stack_
                single_res_stack = single_res_stack_
            else:
                # Accumulate residuals across units element-wise.
                res_stack = [i + j for i, j in zip(res_stack, res_stack_)]
                single_res_stack = [i + j for i, j in zip(single_res_stack, single_res_stack_)]
        return res_stack, single_res_stack

View File

@@ -3,37 +3,42 @@ import warnings
with warnings.catch_warnings(): with warnings.catch_warnings():
warnings.simplefilter("ignore") warnings.simplefilter("ignore")
from controlnet_aux.processor import ( from controlnet_aux.processor import (
CannyDetector, MidasDetector, HEDdetector, LineartDetector, LineartAnimeDetector, OpenposeDetector CannyDetector, MidasDetector, HEDdetector, LineartDetector, LineartAnimeDetector, OpenposeDetector, NormalBaeDetector
) )
Processor_id: TypeAlias = Literal[ Processor_id: TypeAlias = Literal[
"canny", "depth", "softedge", "lineart", "lineart_anime", "openpose", "tile" "canny", "depth", "softedge", "lineart", "lineart_anime", "openpose", "normal", "tile", "none", "inpaint"
] ]
class Annotator: class Annotator:
def __init__(self, processor_id: Processor_id, model_path="models/Annotators", detect_resolution=None, device='cuda'): def __init__(self, processor_id: Processor_id, model_path="models/Annotators", detect_resolution=None, device='cuda', skip_processor=False):
if processor_id == "canny": if not skip_processor:
self.processor = CannyDetector() if processor_id == "canny":
elif processor_id == "depth": self.processor = CannyDetector()
self.processor = MidasDetector.from_pretrained(model_path).to(device) elif processor_id == "depth":
elif processor_id == "softedge": self.processor = MidasDetector.from_pretrained(model_path).to(device)
self.processor = HEDdetector.from_pretrained(model_path).to(device) elif processor_id == "softedge":
elif processor_id == "lineart": self.processor = HEDdetector.from_pretrained(model_path).to(device)
self.processor = LineartDetector.from_pretrained(model_path).to(device) elif processor_id == "lineart":
elif processor_id == "lineart_anime": self.processor = LineartDetector.from_pretrained(model_path).to(device)
self.processor = LineartAnimeDetector.from_pretrained(model_path).to(device) elif processor_id == "lineart_anime":
elif processor_id == "openpose": self.processor = LineartAnimeDetector.from_pretrained(model_path).to(device)
self.processor = OpenposeDetector.from_pretrained(model_path).to(device) elif processor_id == "openpose":
elif processor_id == "tile": self.processor = OpenposeDetector.from_pretrained(model_path).to(device)
self.processor = None elif processor_id == "normal":
self.processor = NormalBaeDetector.from_pretrained(model_path).to(device)
elif processor_id == "tile" or processor_id == "none" or processor_id == "inpaint":
self.processor = None
else:
raise ValueError(f"Unsupported processor_id: {processor_id}")
else: else:
raise ValueError(f"Unsupported processor_id: {processor_id}") self.processor = None
self.processor_id = processor_id self.processor_id = processor_id
self.detect_resolution = detect_resolution self.detect_resolution = detect_resolution
def __call__(self, image): def __call__(self, image, mask=None):
width, height = image.size width, height = image.size
if self.processor_id == "openpose": if self.processor_id == "openpose":
kwargs = { kwargs = {

View File

@@ -0,0 +1,226 @@
import torch
from einops import rearrange, repeat
from .flux_dit import RoPEEmbedding, TimestepEmbeddings, FluxJointTransformerBlock, FluxSingleTransformerBlock
from .utils import hash_state_dict_keys
class FluxControlNet(torch.nn.Module):
    """ControlNet for the FLUX DiT.

    Mirrors the first `num_joint_blocks` joint blocks and `num_single_blocks`
    single blocks of the FLUX transformer, and emits one residual per block
    through zero-init-style linear projections.

    Args:
        disable_guidance_embedder: drop the guidance embedder (for variants
            trained without embedded guidance).
        num_joint_blocks / num_single_blocks: how many transformer blocks
            this ControlNet replicates.
        num_mode: size of the control-mode embedding table (union models).
        mode_dict: maps processor_id strings to mode indices; an empty dict
            disables mode embedding.
        additional_input_dim: extra channels in the conditioning latents
            (e.g. an inpainting mask channel).
    """

    def __init__(self, disable_guidance_embedder=False, num_joint_blocks=5, num_single_blocks=10, num_mode=0, mode_dict=None, additional_input_dim=0):
        super().__init__()
        # Fix: the original used a mutable default argument (mode_dict={});
        # use None as the sentinel to avoid a shared default object.
        mode_dict = {} if mode_dict is None else mode_dict
        self.pos_embedder = RoPEEmbedding(3072, 10000, [16, 56, 56])
        self.time_embedder = TimestepEmbeddings(256, 3072)
        self.guidance_embedder = None if disable_guidance_embedder else TimestepEmbeddings(256, 3072)
        self.pooled_text_embedder = torch.nn.Sequential(torch.nn.Linear(768, 3072), torch.nn.SiLU(), torch.nn.Linear(3072, 3072))
        self.context_embedder = torch.nn.Linear(4096, 3072)
        self.x_embedder = torch.nn.Linear(64, 3072)
        self.blocks = torch.nn.ModuleList([FluxJointTransformerBlock(3072, 24) for _ in range(num_joint_blocks)])
        self.single_blocks = torch.nn.ModuleList([FluxSingleTransformerBlock(3072, 24) for _ in range(num_single_blocks)])
        # One output projection per replicated block produces the residual.
        self.controlnet_blocks = torch.nn.ModuleList([torch.nn.Linear(3072, 3072) for _ in range(num_joint_blocks)])
        self.controlnet_single_blocks = torch.nn.ModuleList([torch.nn.Linear(3072, 3072) for _ in range(num_single_blocks)])
        self.mode_dict = mode_dict
        self.controlnet_mode_embedder = torch.nn.Embedding(num_mode, 3072) if len(mode_dict) > 0 else None
        self.controlnet_x_embedder = torch.nn.Linear(64 + additional_input_dim, 3072)

    def prepare_image_ids(self, latents):
        """Build per-patch (0, row, col) position ids for RoPE from latent shape."""
        batch_size, _, height, width = latents.shape
        latent_image_ids = torch.zeros(height // 2, width // 2, 3)
        latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None]
        latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :]

        latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
        latent_image_ids = latent_image_ids[None, :].repeat(batch_size, 1, 1, 1)
        latent_image_ids = latent_image_ids.reshape(
            batch_size, latent_image_id_height * latent_image_id_width, latent_image_id_channels
        )
        latent_image_ids = latent_image_ids.to(device=latents.device, dtype=latents.dtype)

        return latent_image_ids

    def patchify(self, hidden_states):
        """Fold 2x2 spatial patches into the channel dimension: (B, C, H, W) -> (B, HW/4, 4C)."""
        hidden_states = rearrange(hidden_states, "B C (H P) (W Q) -> B (H W) (C P Q)", P=2, Q=2)
        return hidden_states

    def align_res_stack_to_original_blocks(self, res_stack, num_blocks, hidden_states):
        """Stretch this ControlNet's residuals to cover the full DiT's `num_blocks`.

        Each of the DiT's blocks reuses the residual of the nearest earlier
        ControlNet block; an empty stack yields all-zero residuals.
        """
        if len(res_stack) == 0:
            return [torch.zeros_like(hidden_states)] * num_blocks
        interval = (num_blocks + len(res_stack) - 1) // len(res_stack)
        aligned_res_stack = [res_stack[block_id // interval] for block_id in range(num_blocks)]
        return aligned_res_stack

    def forward(
        self,
        hidden_states,
        controlnet_conditioning,
        timestep, prompt_emb, pooled_prompt_emb, guidance, text_ids, image_ids=None,
        processor_id=None,
        tiled=False, tile_size=128, tile_stride=64,
        **kwargs
    ):
        """Compute ControlNet residuals for the FLUX DiT.

        Returns (controlnet_res_stack, controlnet_single_res_stack), aligned
        to the FLUX-dev DiT's 19 joint blocks and 38 single blocks.
        """
        if image_ids is None:
            image_ids = self.prepare_image_ids(hidden_states)

        conditioning = self.time_embedder(timestep, hidden_states.dtype) + self.pooled_text_embedder(pooled_prompt_emb)
        if self.guidance_embedder is not None:
            guidance = guidance * 1000
            conditioning = conditioning + self.guidance_embedder(guidance, hidden_states.dtype)
        prompt_emb = self.context_embedder(prompt_emb)
        if self.controlnet_mode_embedder is not None: # Different from FluxDiT
            # Union models: prepend a learned mode token for the active processor.
            processor_id = torch.tensor([self.mode_dict[processor_id]], dtype=torch.int)
            processor_id = repeat(processor_id, "D -> B D", B=1).to(text_ids.device)
            prompt_emb = torch.concat([self.controlnet_mode_embedder(processor_id), prompt_emb], dim=1)
            text_ids = torch.cat([text_ids[:, :1], text_ids], dim=1)
        image_rotary_emb = self.pos_embedder(torch.cat((text_ids, image_ids), dim=1))

        hidden_states = self.patchify(hidden_states)
        hidden_states = self.x_embedder(hidden_states)
        # Inject the conditioning latents additively into the input embedding.
        controlnet_conditioning = self.patchify(controlnet_conditioning) # Different from FluxDiT
        hidden_states = hidden_states + self.controlnet_x_embedder(controlnet_conditioning) # Different from FluxDiT

        controlnet_res_stack = []
        for block, controlnet_block in zip(self.blocks, self.controlnet_blocks):
            hidden_states, prompt_emb = block(hidden_states, prompt_emb, conditioning, image_rotary_emb)
            controlnet_res_stack.append(controlnet_block(hidden_states))

        controlnet_single_res_stack = []
        hidden_states = torch.cat([prompt_emb, hidden_states], dim=1)
        for block, controlnet_block in zip(self.single_blocks, self.controlnet_single_blocks):
            hidden_states, prompt_emb = block(hidden_states, prompt_emb, conditioning, image_rotary_emb)
            # Only the image-token part of the sequence feeds the residual.
            controlnet_single_res_stack.append(controlnet_block(hidden_states[:, prompt_emb.shape[1]:]))

        # 19 / 38 are the joint / single block counts of the full FLUX-dev DiT.
        controlnet_res_stack = self.align_res_stack_to_original_blocks(controlnet_res_stack, 19, hidden_states[:, prompt_emb.shape[1]:])
        controlnet_single_res_stack = self.align_res_stack_to_original_blocks(controlnet_single_res_stack, 38, hidden_states[:, prompt_emb.shape[1]:])

        return controlnet_res_stack, controlnet_single_res_stack

    @staticmethod
    def state_dict_converter():
        """Return the converter that maps Diffusers checkpoints to this layout."""
        return FluxControlNetStateDictConverter()
class FluxControlNetStateDictConverter:
    """Convert Diffusers-format FLUX ControlNet checkpoints to this library's naming scheme.

    Also fuses separate q/k/v (and proj_mlp) projections into single matrices,
    and selects per-checkpoint architecture hyperparameters by state-dict hash.
    """

    def __init__(self):
        pass

    def from_diffusers(self, state_dict):
        """Rename and fuse a Diffusers FLUX ControlNet state dict.

        Returns (state_dict_, extra_kwargs): the converted weights plus the
        constructor kwargs matching the checkpoint variant identified by
        its key hash (unknown hashes fall back to default kwargs).
        """
        hash_value = hash_state_dict_keys(state_dict)
        # Renames for top-level (non-block) parameters.
        global_rename_dict = {
            "context_embedder": "context_embedder",
            "x_embedder": "x_embedder",
            "time_text_embed.timestep_embedder.linear_1": "time_embedder.timestep_embedder.0",
            "time_text_embed.timestep_embedder.linear_2": "time_embedder.timestep_embedder.2",
            "time_text_embed.guidance_embedder.linear_1": "guidance_embedder.timestep_embedder.0",
            "time_text_embed.guidance_embedder.linear_2": "guidance_embedder.timestep_embedder.2",
            "time_text_embed.text_embedder.linear_1": "pooled_text_embedder.0",
            "time_text_embed.text_embedder.linear_2": "pooled_text_embedder.2",
            "norm_out.linear": "final_norm_out.linear",
            "proj_out": "final_proj_out",
        }
        # Renames for parameters inside joint transformer blocks.
        rename_dict = {
            "proj_out": "proj_out",
            "norm1.linear": "norm1_a.linear",
            "norm1_context.linear": "norm1_b.linear",
            "attn.to_q": "attn.a_to_q",
            "attn.to_k": "attn.a_to_k",
            "attn.to_v": "attn.a_to_v",
            "attn.to_out.0": "attn.a_to_out",
            "attn.add_q_proj": "attn.b_to_q",
            "attn.add_k_proj": "attn.b_to_k",
            "attn.add_v_proj": "attn.b_to_v",
            "attn.to_add_out": "attn.b_to_out",
            "ff.net.0.proj": "ff_a.0",
            "ff.net.2": "ff_a.2",
            "ff_context.net.0.proj": "ff_b.0",
            "ff_context.net.2": "ff_b.2",
            "attn.norm_q": "attn.norm_q_a",
            "attn.norm_k": "attn.norm_k_a",
            "attn.norm_added_q": "attn.norm_q_b",
            "attn.norm_added_k": "attn.norm_k_b",
        }
        # Renames for parameters inside single transformer blocks.
        rename_dict_single = {
            "attn.to_q": "a_to_q",
            "attn.to_k": "a_to_k",
            "attn.to_v": "a_to_v",
            "attn.norm_q": "norm_q_a",
            "attn.norm_k": "norm_k_a",
            "norm.linear": "norm.linear",
            "proj_mlp": "proj_in_besides_attn",
            "proj_out": "proj_out",
        }
        state_dict_ = {}
        for name, param in state_dict.items():
            if name.endswith(".weight") or name.endswith(".bias"):
                suffix = ".weight" if name.endswith(".weight") else ".bias"
                prefix = name[:-len(suffix)]
                if prefix in global_rename_dict:
                    state_dict_[global_rename_dict[prefix] + suffix] = param
                elif prefix.startswith("transformer_blocks."):
                    names = prefix.split(".")
                    names[0] = "blocks"
                    middle = ".".join(names[2:])
                    if middle in rename_dict:
                        name_ = ".".join(names[:2] + [rename_dict[middle]] + [suffix[1:]])
                        state_dict_[name_] = param
                elif prefix.startswith("single_transformer_blocks."):
                    names = prefix.split(".")
                    names[0] = "single_blocks"
                    middle = ".".join(names[2:])
                    if middle in rename_dict_single:
                        name_ = ".".join(names[:2] + [rename_dict_single[middle]] + [suffix[1:]])
                        state_dict_[name_] = param
                else:
                    # Keep params (e.g. controlnet-specific ones) unchanged.
                    state_dict_[name] = param
            else:
                state_dict_[name] = param
        # Fuse single-block q/k/v + proj_mlp into one to_qkv_mlp matrix.
        for name in list(state_dict_.keys()):
            if ".proj_in_besides_attn." in name:
                name_ = name.replace(".proj_in_besides_attn.", ".to_qkv_mlp.")
                param = torch.concat([
                    state_dict_[name.replace(".proj_in_besides_attn.", f".a_to_q.")],
                    state_dict_[name.replace(".proj_in_besides_attn.", f".a_to_k.")],
                    state_dict_[name.replace(".proj_in_besides_attn.", f".a_to_v.")],
                    state_dict_[name],
                ], dim=0)
                state_dict_[name_] = param
                state_dict_.pop(name.replace(".proj_in_besides_attn.", f".a_to_q."))
                state_dict_.pop(name.replace(".proj_in_besides_attn.", f".a_to_k."))
                state_dict_.pop(name.replace(".proj_in_besides_attn.", f".a_to_v."))
                state_dict_.pop(name)
        # Fuse remaining separate q/k/v projections into single qkv matrices
        # (for both the "a" image stream and the "b" text stream).
        for name in list(state_dict_.keys()):
            for component in ["a", "b"]:
                if f".{component}_to_q." in name:
                    name_ = name.replace(f".{component}_to_q.", f".{component}_to_qkv.")
                    param = torch.concat([
                        state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_q.")],
                        state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_k.")],
                        state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_v.")],
                    ], dim=0)
                    state_dict_[name_] = param
                    state_dict_.pop(name.replace(f".{component}_to_q.", f".{component}_to_q."))
                    state_dict_.pop(name.replace(f".{component}_to_q.", f".{component}_to_k."))
                    state_dict_.pop(name.replace(f".{component}_to_q.", f".{component}_to_v."))
        # Per-checkpoint architecture hyperparameters, keyed by state-dict hash.
        if hash_value == "78d18b9101345ff695f312e7e62538c0":
            # InstantX union model: multiple control modes via mode embedding.
            extra_kwargs = {"num_mode": 10, "mode_dict": {"canny": 0, "tile": 1, "depth": 2, "blur": 3, "pose": 4, "gray": 5, "lq": 6}}
        elif hash_value == "b001c89139b5f053c715fe772362dd2a":
            extra_kwargs = {"num_single_blocks": 0}
        elif hash_value == "52357cb26250681367488a8954c271e8":
            # Inpainting variant: 4 extra input channels for the mask.
            extra_kwargs = {"num_joint_blocks": 6, "num_single_blocks": 0, "additional_input_dim": 4}
        elif hash_value == "0cfd1740758423a2a854d67c136d1e8c":
            extra_kwargs = {"num_joint_blocks": 4, "num_single_blocks": 1}
        else:
            extra_kwargs = {}
        return state_dict_, extra_kwargs

    def from_civitai(self, state_dict):
        # These checkpoints share the Diffusers layout, so reuse that path.
        return self.from_diffusers(state_dict)

View File

@@ -1,7 +1,4 @@
import os, torch, hashlib, json, importlib import os, torch, json, importlib
from safetensors import safe_open
from torch import Tensor
from typing_extensions import Literal, TypeAlias
from typing import List from typing import List
from .downloader import download_models, download_customized_models, Preset_model_id, Preset_model_website from .downloader import download_models, download_customized_models, Preset_model_id, Preset_model_website
@@ -50,45 +47,7 @@ from ..extensions.RIFE import IFNet
from ..extensions.ESRGAN import RRDBNet from ..extensions.ESRGAN import RRDBNet
from ..configs.model_config import model_loader_configs, huggingface_model_loader_configs, patch_model_loader_configs from ..configs.model_config import model_loader_configs, huggingface_model_loader_configs, patch_model_loader_configs
from .utils import load_state_dict, init_weights_on_device from .utils import load_state_dict, init_weights_on_device, hash_state_dict_keys, split_state_dict_with_prefix
def convert_state_dict_keys_to_single_str(state_dict, with_shape=True):
keys = []
for key, value in state_dict.items():
if isinstance(key, str):
if isinstance(value, Tensor):
if with_shape:
shape = "_".join(map(str, list(value.shape)))
keys.append(key + ":" + shape)
keys.append(key)
elif isinstance(value, dict):
keys.append(key + "|" + convert_state_dict_keys_to_single_str(value, with_shape=with_shape))
keys.sort()
keys_str = ",".join(keys)
return keys_str
def split_state_dict_with_prefix(state_dict):
keys = sorted([key for key in state_dict if isinstance(key, str)])
prefix_dict = {}
for key in keys:
prefix = key if "." not in key else key.split(".")[0]
if prefix not in prefix_dict:
prefix_dict[prefix] = []
prefix_dict[prefix].append(key)
state_dicts = []
for prefix, keys in prefix_dict.items():
sub_state_dict = {key: state_dict[key] for key in keys}
state_dicts.append(sub_state_dict)
return state_dicts
def hash_state_dict_keys(state_dict, with_shape=True):
keys_str = convert_state_dict_keys_to_single_str(state_dict, with_shape=with_shape)
keys_str = keys_str.encode(encoding="UTF-8")
return hashlib.md5(keys_str).hexdigest()
def load_model_from_single_file(state_dict, model_names, model_classes, model_resource, torch_dtype, device): def load_model_from_single_file(state_dict, model_names, model_classes, model_resource, torch_dtype, device):

View File

@@ -1,6 +1,7 @@
import torch, os import torch, os
from safetensors import safe_open from safetensors import safe_open
from contextlib import contextmanager from contextlib import contextmanager
import hashlib
@contextmanager @contextmanager
def init_weights_on_device(device = torch.device("meta"), include_buffers :bool = False): def init_weights_on_device(device = torch.device("meta"), include_buffers :bool = False):
@@ -142,3 +143,40 @@ def search_for_files(folder, extensions):
files.append(folder) files.append(folder)
break break
return files return files
def convert_state_dict_keys_to_single_str(state_dict, with_shape=True):
keys = []
for key, value in state_dict.items():
if isinstance(key, str):
if isinstance(value, torch.Tensor):
if with_shape:
shape = "_".join(map(str, list(value.shape)))
keys.append(key + ":" + shape)
keys.append(key)
elif isinstance(value, dict):
keys.append(key + "|" + convert_state_dict_keys_to_single_str(value, with_shape=with_shape))
keys.sort()
keys_str = ",".join(keys)
return keys_str
def split_state_dict_with_prefix(state_dict):
    """Split a state dict into sub-dicts grouped by top-level key prefix.

    The prefix is the part of the key before the first ".", or the whole key
    if it has no dot. Non-string keys are dropped. Groups appear in the order
    of their first (sorted) key; keys within a group are sorted.
    """
    grouped = {}
    for key in sorted(k for k in state_dict if isinstance(k, str)):
        # partition() returns the whole key as the head when "." is absent.
        prefix, _, _ = key.partition(".")
        grouped.setdefault(prefix, {})[key] = state_dict[key]
    return list(grouped.values())
def hash_state_dict_keys(state_dict, with_shape=True):
    """Return an MD5 hex digest fingerprinting a state dict's key structure.

    Used to recognize known checkpoint architectures; `with_shape` folds
    tensor shapes into the fingerprint as well.
    """
    serialized = convert_state_dict_keys_to_single_str(state_dict, with_shape=with_shape)
    return hashlib.md5(serialized.encode(encoding="UTF-8")).hexdigest()

View File

@@ -1,9 +1,13 @@
from ..models import ModelManager, FluxDiT, FluxTextEncoder1, FluxTextEncoder2, FluxVAEDecoder, FluxVAEEncoder from ..models import ModelManager, FluxDiT, FluxTextEncoder1, FluxTextEncoder2, FluxVAEDecoder, FluxVAEEncoder
from ..controlnets import FluxMultiControlNetManager, ControlNetUnit, ControlNetConfigUnit, Annotator
from ..prompters import FluxPrompter from ..prompters import FluxPrompter
from ..schedulers import FlowMatchScheduler from ..schedulers import FlowMatchScheduler
from .base import BasePipeline from .base import BasePipeline
from typing import List
import torch import torch
from tqdm import tqdm from tqdm import tqdm
import numpy as np
from PIL import Image
@@ -19,14 +23,15 @@ class FluxImagePipeline(BasePipeline):
self.dit: FluxDiT = None self.dit: FluxDiT = None
self.vae_decoder: FluxVAEDecoder = None self.vae_decoder: FluxVAEDecoder = None
self.vae_encoder: FluxVAEEncoder = None self.vae_encoder: FluxVAEEncoder = None
self.model_names = ['text_encoder_1', 'text_encoder_2', 'dit', 'vae_decoder', 'vae_encoder'] self.controlnet: FluxMultiControlNetManager = None
self.model_names = ['text_encoder_1', 'text_encoder_2', 'dit', 'vae_decoder', 'vae_encoder', 'controlnet']
def denoising_model(self): def denoising_model(self):
return self.dit return self.dit
def fetch_models(self, model_manager: ModelManager, prompt_refiner_classes=[], prompt_extender_classes=[]): def fetch_models(self, model_manager: ModelManager, controlnet_config_units: List[ControlNetConfigUnit]=[], prompt_refiner_classes=[], prompt_extender_classes=[]):
self.text_encoder_1 = model_manager.fetch_model("flux_text_encoder_1") self.text_encoder_1 = model_manager.fetch_model("flux_text_encoder_1")
self.text_encoder_2 = model_manager.fetch_model("flux_text_encoder_2") self.text_encoder_2 = model_manager.fetch_model("flux_text_encoder_2")
self.dit = model_manager.fetch_model("flux_dit") self.dit = model_manager.fetch_model("flux_dit")
@@ -36,14 +41,25 @@ class FluxImagePipeline(BasePipeline):
self.prompter.load_prompt_refiners(model_manager, prompt_refiner_classes) self.prompter.load_prompt_refiners(model_manager, prompt_refiner_classes)
self.prompter.load_prompt_extenders(model_manager, prompt_extender_classes) self.prompter.load_prompt_extenders(model_manager, prompt_extender_classes)
# ControlNets
controlnet_units = []
for config in controlnet_config_units:
controlnet_unit = ControlNetUnit(
Annotator(config.processor_id, device=self.device, skip_processor=config.skip_processor),
model_manager.fetch_model("flux_controlnet", config.model_path),
config.scale
)
controlnet_units.append(controlnet_unit)
self.controlnet = FluxMultiControlNetManager(controlnet_units)
@staticmethod @staticmethod
def from_model_manager(model_manager: ModelManager, prompt_refiner_classes=[], prompt_extender_classes=[], device=None): def from_model_manager(model_manager: ModelManager, controlnet_config_units: List[ControlNetConfigUnit]=[], prompt_refiner_classes=[], prompt_extender_classes=[], device=None):
pipe = FluxImagePipeline( pipe = FluxImagePipeline(
device=model_manager.device if device is None else device, device=model_manager.device if device is None else device,
torch_dtype=model_manager.torch_dtype, torch_dtype=model_manager.torch_dtype,
) )
pipe.fetch_models(model_manager, prompt_refiner_classes,prompt_extender_classes) pipe.fetch_models(model_manager, controlnet_config_units, prompt_refiner_classes, prompt_extender_classes)
return pipe return pipe
@@ -71,17 +87,61 @@ class FluxImagePipeline(BasePipeline):
return {"image_ids": latent_image_ids, "guidance": guidance} return {"image_ids": latent_image_ids, "guidance": guidance}
def apply_controlnet_mask_on_latents(self, latents, mask):
    """Append an inverted inpainting mask as an extra channel on `latents`.

    latents: latent tensor of shape (B, C, H, W).
    mask: PIL image marking the region to inpaint.
    Returns latents with one extra mask channel, consumed by ControlNet
    variants built with additional_input_dim > 0.
    """
    # NOTE(review): assumes preprocess_image maps pixels to [-1, 1] — the
    # shift below recovers [0, 1]; confirm against BasePipeline.
    mask = (self.preprocess_image(mask) + 1) / 2
    # Collapse the channel dimension to a single-channel mask.
    mask = mask.mean(dim=1, keepdim=True)
    mask = mask.to(dtype=self.torch_dtype, device=self.device)
    # Invert and resize to the latent resolution.
    mask = 1 - torch.nn.functional.interpolate(mask, size=latents.shape[-2:])
    latents = torch.concat([latents, mask], dim=1)
    return latents
def apply_controlnet_mask_on_image(self, image, mask):
    """Black out the masked region of a ControlNet conditioning image.

    image: PIL conditioning image.
    mask: PIL mask; pixels where the processed mask is positive are zeroed.
    Returns a new PIL image.
    """
    mask = mask.resize(image.size)
    # NOTE(review): mean over dims [0, 1] presumably collapses the batch and
    # channel axes of the preprocessed mask tensor — confirm its shape.
    mask = self.preprocess_image(mask).mean(dim=[0, 1])
    image = np.array(image)
    # Zero out pixels inside the inpainting region.
    image[mask > 0] = 0
    image = Image.fromarray(image)
    return image
def prepare_controlnet_input(self, controlnet_image, controlnet_inpaint_mask, tiler_kwargs):
    """Turn user-supplied ControlNet image(s) into per-unit latent conditionings.

    controlnet_image: a single PIL image (broadcast to every ControlNet unit)
        or a list with one image per unit.
    controlnet_inpaint_mask: optional PIL mask, applied only to units whose
        processor_id is "inpaint" (on the image, then again on the latents).
    tiler_kwargs: forwarded to the VAE encoder for tiled encoding.
    Returns a list of conditioning latent tensors, one per ControlNet unit.
    """
    if isinstance(controlnet_image, Image.Image):
        controlnet_image = [controlnet_image] * len(self.controlnet.processors)
    controlnet_frames = []
    for i in range(len(self.controlnet.processors)):
        # image annotator
        image = self.controlnet.process_image(controlnet_image[i], processor_id=i)[0]
        if controlnet_inpaint_mask is not None and self.controlnet.processors[i].processor_id == "inpaint":
            image = self.apply_controlnet_mask_on_image(image, controlnet_inpaint_mask)
        # image to tensor
        image = self.preprocess_image(image).to(device=self.device, dtype=self.torch_dtype)
        # vae encoder
        image = self.encode_image(image, **tiler_kwargs)
        if controlnet_inpaint_mask is not None and self.controlnet.processors[i].processor_id == "inpaint":
            image = self.apply_controlnet_mask_on_latents(image, controlnet_inpaint_mask)
        # store it
        controlnet_frames.append(image)
    return controlnet_frames
@torch.no_grad() @torch.no_grad()
def __call__( def __call__(
self, self,
prompt, prompt,
local_prompts= None, local_prompts=None,
masks= None, masks=None,
mask_scales= None, mask_scales=None,
negative_prompt="", negative_prompt="",
cfg_scale=1.0, cfg_scale=1.0,
embedded_guidance=3.5, embedded_guidance=3.5,
input_image=None, input_image=None,
controlnet_image=None,
controlnet_inpaint_mask=None,
denoising_strength=1.0, denoising_strength=1.0,
height=1024, height=1024,
width=1024, width=1024,
@@ -123,19 +183,29 @@ class FluxImagePipeline(BasePipeline):
# Extra input # Extra input
extra_input = self.prepare_extra_input(latents, guidance=embedded_guidance) extra_input = self.prepare_extra_input(latents, guidance=embedded_guidance)
# Prepare ControlNets
if controlnet_image is not None:
controlnet_kwargs = {"controlnet_frames": self.prepare_controlnet_input(controlnet_image, controlnet_inpaint_mask, tiler_kwargs)}
else:
controlnet_kwargs = {"controlnet_frames": None}
# Denoise # Denoise
self.load_models_to_device(['dit']) self.load_models_to_device(['dit', 'controlnet'])
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)): for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
timestep = timestep.unsqueeze(0).to(self.device) timestep = timestep.unsqueeze(0).to(self.device)
# Classifier-free guidance # Classifier-free guidance
inference_callback = lambda prompt_emb_posi: self.dit( inference_callback = lambda prompt_emb_posi: lets_dance_flux(
latents, timestep=timestep, **prompt_emb_posi, **tiler_kwargs, **extra_input dit=self.dit, controlnet=self.controlnet,
hidden_states=latents, timestep=timestep,
**prompt_emb_posi, **tiler_kwargs, **extra_input, **controlnet_kwargs
) )
noise_pred_posi = self.control_noise_via_local_prompts(prompt_emb_posi, prompt_emb_locals, masks, mask_scales, inference_callback) noise_pred_posi = self.control_noise_via_local_prompts(prompt_emb_posi, prompt_emb_locals, masks, mask_scales, inference_callback)
if cfg_scale != 1.0: if cfg_scale != 1.0:
noise_pred_nega = self.dit( noise_pred_nega = lets_dance_flux(
latents, timestep=timestep, **prompt_emb_nega, **tiler_kwargs, **extra_input dit=self.dit, controlnet=self.controlnet,
hidden_states=latents, timestep=timestep,
**prompt_emb_nega, **tiler_kwargs, **extra_input, **controlnet_kwargs
) )
noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega) noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
else: else:
@@ -155,3 +225,75 @@ class FluxImagePipeline(BasePipeline):
# Offload all models # Offload all models
self.load_models_to_device([]) self.load_models_to_device([])
return image return image
def lets_dance_flux(
    dit: FluxDiT,
    controlnet: FluxMultiControlNetManager = None,
    hidden_states=None,
    timestep=None,
    prompt_emb=None,
    pooled_prompt_emb=None,
    guidance=None,
    text_ids=None,
    image_ids=None,
    controlnet_frames=None,
    tiled=False,
    tile_size=128,
    tile_stride=64,
    **kwargs
):
    """Run one FLUX DiT denoising step, injecting ControlNet residuals per block.

    Replicates FluxDiT's forward pass but adds, after every joint block and
    every single block, the corresponding residual produced by `controlnet`.
    ControlNet injection is skipped when either `controlnet` or
    `controlnet_frames` is None. Returns the predicted noise tensor with the
    same spatial shape as `hidden_states`.
    """
    # ControlNet
    if controlnet is not None and controlnet_frames is not None:
        # Run every ControlNet unit once up front; residuals are reused below.
        controlnet_extra_kwargs = {
            "hidden_states": hidden_states,
            "timestep": timestep,
            "prompt_emb": prompt_emb,
            "pooled_prompt_emb": pooled_prompt_emb,
            "guidance": guidance,
            "text_ids": text_ids,
            "image_ids": image_ids,
            "tiled": tiled,
            "tile_size": tile_size,
            "tile_stride": tile_stride,
        }
        controlnet_res_stack, controlnet_single_res_stack = controlnet(
            controlnet_frames, **controlnet_extra_kwargs
        )

    if image_ids is None:
        image_ids = dit.prepare_image_ids(hidden_states)

    conditioning = dit.time_embedder(timestep, hidden_states.dtype) + dit.pooled_text_embedder(pooled_prompt_emb)
    if dit.guidance_embedder is not None:
        guidance = guidance * 1000
        conditioning = conditioning + dit.guidance_embedder(guidance, hidden_states.dtype)
    prompt_emb = dit.context_embedder(prompt_emb)
    image_rotary_emb = dit.pos_embedder(torch.cat((text_ids, image_ids), dim=1))

    # Remember the spatial size before patchify so we can unpatchify at the end.
    height, width = hidden_states.shape[-2:]
    hidden_states = dit.patchify(hidden_states)
    hidden_states = dit.x_embedder(hidden_states)

    # Joint Blocks
    for block_id, block in enumerate(dit.blocks):
        hidden_states, prompt_emb = block(hidden_states, prompt_emb, conditioning, image_rotary_emb)
        # ControlNet
        if controlnet is not None and controlnet_frames is not None:
            hidden_states = hidden_states + controlnet_res_stack[block_id]

    # Single Blocks: text and image tokens are processed as one sequence.
    hidden_states = torch.cat([prompt_emb, hidden_states], dim=1)
    for block_id, block in enumerate(dit.single_blocks):
        hidden_states, prompt_emb = block(hidden_states, prompt_emb, conditioning, image_rotary_emb)
        # ControlNet
        if controlnet is not None and controlnet_frames is not None:
            # Residuals apply only to the image-token part of the sequence.
            hidden_states[:, prompt_emb.shape[1]:] = hidden_states[:, prompt_emb.shape[1]:] + controlnet_single_res_stack[block_id]
    # Drop the text tokens; keep only the image tokens for the output head.
    hidden_states = hidden_states[:, prompt_emb.shape[1]:]

    hidden_states = dit.final_norm_out(hidden_states, conditioning)
    hidden_states = dit.final_proj_out(hidden_states)
    hidden_states = dit.unpatchify(hidden_states, height, width)

    return hidden_states

View File

@@ -0,0 +1,44 @@
from diffsynth.models.flux_controlnet import FluxControlNet
from diffsynth import load_state_dict, ModelManager, FluxImagePipeline, hash_state_dict_keys, ControlNetConfigUnit
import torch
from PIL import Image
import numpy as np
model_manager = ModelManager(torch_dtype=torch.bfloat16, model_id_list=["FLUX.1-dev"])
model_manager.load_models([
"models/ControlNet/InstantX/FLUX___1-dev-Controlnet-Union-alpha/diffusion_pytorch_model.safetensors",
"models/ControlNet/jasperai/Flux___1-dev-Controlnet-Depth/diffusion_pytorch_model.safetensors",
"models/ControlNet/jasperai/Flux___1-dev-Controlnet-Surface-Normals/diffusion_pytorch_model.safetensors",
"models/ControlNet/jasperai/Flux___1-dev-Controlnet-Upscaler/diffusion_pytorch_model.safetensors",
"models/ControlNet/alimama-creative/FLUX___1-dev-Controlnet-Inpainting-Alpha/diffusion_pytorch_model.safetensors",
"models/ControlNet/alimama-creative/FLUX___1-dev-Controlnet-Inpainting-Beta/diffusion_pytorch_model.safetensors",
"models/ControlNet/Shakker-Labs/FLUX___1-dev-ControlNet-Depth/diffusion_pytorch_model.safetensors",
"models/ControlNet/Shakker-Labs/FLUX___1-dev-ControlNet-Union-Pro/diffusion_pytorch_model.safetensors"
])
pipe = FluxImagePipeline.from_model_manager(model_manager, controlnet_config_units=[
ControlNetConfigUnit(processor_id="canny", model_path="models/ControlNet/InstantX/FLUX___1-dev-Controlnet-Union-alpha/diffusion_pytorch_model.safetensors", scale=0.3),
ControlNetConfigUnit(processor_id="depth", model_path="models/ControlNet/jasperai/Flux___1-dev-Controlnet-Depth/diffusion_pytorch_model.safetensors", scale=0.1),
ControlNetConfigUnit(processor_id="normal", model_path="models/ControlNet/jasperai/Flux___1-dev-Controlnet-Surface-Normals/diffusion_pytorch_model.safetensors", scale=0.1),
ControlNetConfigUnit(processor_id="tile", model_path="models/ControlNet/jasperai/Flux___1-dev-Controlnet-Upscaler/diffusion_pytorch_model.safetensors", scale=0.05),
ControlNetConfigUnit(processor_id="inpaint", model_path="models/ControlNet/alimama-creative/FLUX___1-dev-Controlnet-Inpainting-Alpha/diffusion_pytorch_model.safetensors", scale=0.01),
ControlNetConfigUnit(processor_id="inpaint", model_path="models/ControlNet/alimama-creative/FLUX___1-dev-Controlnet-Inpainting-Beta/diffusion_pytorch_model.safetensors", scale=0.01),
ControlNetConfigUnit(processor_id="depth", model_path="models/ControlNet/Shakker-Labs/FLUX___1-dev-ControlNet-Depth/diffusion_pytorch_model.safetensors", scale=0.05),
ControlNetConfigUnit(processor_id="canny", model_path="models/ControlNet/Shakker-Labs/FLUX___1-dev-ControlNet-Union-Pro/diffusion_pytorch_model.safetensors", scale=0.3),
])
torch.manual_seed(0)
control_image = Image.open("controlnet_input.jpeg").resize((768, 1024))
control_mask = Image.open("controlnet_mask.jpg").resize((768, 1024))
prompt = "masterpiece, best quality, a beautiful girl, CG, blue sky, long red hair, black clothes"
negative_prompt = "oil painting, worst quality, low quality, monochrome, zombie, interlocked fingers, Aissist, cleavage, nsfw,"
image = pipe(
prompt=prompt, negative_prompt=negative_prompt,
embedded_guidance=3.5, num_inference_steps=50,
height=1024, width=768,
controlnet_image=control_image, controlnet_inpaint_mask=control_mask,
)
image.save("image.jpg")