support flux-controlnet

This commit is contained in:
Artiprocher
2024-10-22 18:52:24 +08:00
parent 72ed76e89e
commit 07d70a6a56
9 changed files with 522 additions and 76 deletions

View File

@@ -35,6 +35,7 @@ from ..models.hunyuan_dit import HunyuanDiT
from ..models.flux_dit import FluxDiT
from ..models.flux_text_encoder import FluxTextEncoder1, FluxTextEncoder2
from ..models.flux_vae import FluxVAEEncoder, FluxVAEDecoder
from ..models.flux_controlnet import FluxControlNet
from ..models.cog_vae import CogVAEEncoder, CogVAEDecoder
from ..models.cog_dit import CogDiT
@@ -80,6 +81,10 @@ model_loader_configs = [
(None, "280189ee084bca10f70907bf6ce1649d", ["cog_vae_encoder", "cog_vae_decoder"], [CogVAEEncoder, CogVAEDecoder], "diffusers"),
(None, "9b9313d104ac4df27991352fec013fd4", ["rife"], [IFNet], "civitai"),
(None, "6b7116078c4170bfbeaedc8fe71f6649", ["esrgan"], [RRDBNet], "civitai"),
(None, "78d18b9101345ff695f312e7e62538c0", ["flux_controlnet"], [FluxControlNet], "diffusers"),
(None, "b001c89139b5f053c715fe772362dd2a", ["flux_controlnet"], [FluxControlNet], "diffusers"),
(None, "52357cb26250681367488a8954c271e8", ["flux_controlnet"], [FluxControlNet], "diffusers"),
(None, "0cfd1740758423a2a854d67c136d1e8c", ["flux_controlnet"], [FluxControlNet], "diffusers"),
]
huggingface_model_loader_configs = [
# These configs are provided for detecting model type automatically.

View File

@@ -1,2 +1,2 @@
from .controlnet_unit import ControlNetConfigUnit, ControlNetUnit, MultiControlNetManager
from .controlnet_unit import ControlNetConfigUnit, ControlNetUnit, MultiControlNetManager, FluxMultiControlNetManager
from .processors import Annotator

View File

@@ -4,10 +4,11 @@ from .processors import Processor_id
class ControlNetConfigUnit:
def __init__(self, processor_id: Processor_id, model_path, scale=1.0):
def __init__(self, processor_id: Processor_id, model_path, scale=1.0, skip_processor=False):
self.processor_id = processor_id
self.model_path = model_path
self.scale = scale
self.skip_processor = skip_processor
class ControlNetUnit:
@@ -60,3 +61,29 @@ class MultiControlNetManager:
else:
res_stack = [i + j for i, j in zip(res_stack, res_stack_)]
return res_stack
class FluxMultiControlNetManager(MultiControlNetManager):
def __init__(self, controlnet_units=[]):
super().__init__(controlnet_units=controlnet_units)
def process_image(self, image, processor_id=None):
if processor_id is None:
processed_image = [processor(image) for processor in self.processors]
else:
processed_image = [self.processors[processor_id](image)]
return processed_image
def __call__(self, conditionings, **kwargs):
res_stack, single_res_stack = None, None
for processor, conditioning, model, scale in zip(self.processors, conditionings, self.models, self.scales):
res_stack_, single_res_stack_ = model(controlnet_conditioning=conditioning, processor_id=processor.processor_id, **kwargs)
res_stack_ = [res * scale for res in res_stack_]
single_res_stack_ = [res * scale for res in single_res_stack_]
if res_stack is None:
res_stack = res_stack_
single_res_stack = single_res_stack_
else:
res_stack = [i + j for i, j in zip(res_stack, res_stack_)]
single_res_stack = [i + j for i, j in zip(single_res_stack, single_res_stack_)]
return res_stack, single_res_stack

View File

@@ -3,37 +3,42 @@ import warnings
with warnings.catch_warnings():
warnings.simplefilter("ignore")
from controlnet_aux.processor import (
CannyDetector, MidasDetector, HEDdetector, LineartDetector, LineartAnimeDetector, OpenposeDetector
CannyDetector, MidasDetector, HEDdetector, LineartDetector, LineartAnimeDetector, OpenposeDetector, NormalBaeDetector
)
Processor_id: TypeAlias = Literal[
"canny", "depth", "softedge", "lineart", "lineart_anime", "openpose", "tile"
"canny", "depth", "softedge", "lineart", "lineart_anime", "openpose", "normal", "tile", "none", "inpaint"
]
class Annotator:
def __init__(self, processor_id: Processor_id, model_path="models/Annotators", detect_resolution=None, device='cuda'):
if processor_id == "canny":
self.processor = CannyDetector()
elif processor_id == "depth":
self.processor = MidasDetector.from_pretrained(model_path).to(device)
elif processor_id == "softedge":
self.processor = HEDdetector.from_pretrained(model_path).to(device)
elif processor_id == "lineart":
self.processor = LineartDetector.from_pretrained(model_path).to(device)
elif processor_id == "lineart_anime":
self.processor = LineartAnimeDetector.from_pretrained(model_path).to(device)
elif processor_id == "openpose":
self.processor = OpenposeDetector.from_pretrained(model_path).to(device)
elif processor_id == "tile":
self.processor = None
def __init__(self, processor_id: Processor_id, model_path="models/Annotators", detect_resolution=None, device='cuda', skip_processor=False):
if not skip_processor:
if processor_id == "canny":
self.processor = CannyDetector()
elif processor_id == "depth":
self.processor = MidasDetector.from_pretrained(model_path).to(device)
elif processor_id == "softedge":
self.processor = HEDdetector.from_pretrained(model_path).to(device)
elif processor_id == "lineart":
self.processor = LineartDetector.from_pretrained(model_path).to(device)
elif processor_id == "lineart_anime":
self.processor = LineartAnimeDetector.from_pretrained(model_path).to(device)
elif processor_id == "openpose":
self.processor = OpenposeDetector.from_pretrained(model_path).to(device)
elif processor_id == "normal":
self.processor = NormalBaeDetector.from_pretrained(model_path).to(device)
elif processor_id == "tile" or processor_id == "none" or processor_id == "inpaint":
self.processor = None
else:
raise ValueError(f"Unsupported processor_id: {processor_id}")
else:
raise ValueError(f"Unsupported processor_id: {processor_id}")
self.processor = None
self.processor_id = processor_id
self.detect_resolution = detect_resolution
def __call__(self, image):
def __call__(self, image, mask=None):
width, height = image.size
if self.processor_id == "openpose":
kwargs = {

View File

@@ -0,0 +1,226 @@
import torch
from einops import rearrange, repeat
from .flux_dit import RoPEEmbedding, TimestepEmbeddings, FluxJointTransformerBlock, FluxSingleTransformerBlock
from .utils import hash_state_dict_keys
class FluxControlNet(torch.nn.Module):
def __init__(self, disable_guidance_embedder=False, num_joint_blocks=5, num_single_blocks=10, num_mode=0, mode_dict={}, additional_input_dim=0):
super().__init__()
self.pos_embedder = RoPEEmbedding(3072, 10000, [16, 56, 56])
self.time_embedder = TimestepEmbeddings(256, 3072)
self.guidance_embedder = None if disable_guidance_embedder else TimestepEmbeddings(256, 3072)
self.pooled_text_embedder = torch.nn.Sequential(torch.nn.Linear(768, 3072), torch.nn.SiLU(), torch.nn.Linear(3072, 3072))
self.context_embedder = torch.nn.Linear(4096, 3072)
self.x_embedder = torch.nn.Linear(64, 3072)
self.blocks = torch.nn.ModuleList([FluxJointTransformerBlock(3072, 24) for _ in range(num_joint_blocks)])
self.single_blocks = torch.nn.ModuleList([FluxSingleTransformerBlock(3072, 24) for _ in range(num_single_blocks)])
self.controlnet_blocks = torch.nn.ModuleList([torch.nn.Linear(3072, 3072) for _ in range(num_joint_blocks)])
self.controlnet_single_blocks = torch.nn.ModuleList([torch.nn.Linear(3072, 3072) for _ in range(num_single_blocks)])
self.mode_dict = mode_dict
self.controlnet_mode_embedder = torch.nn.Embedding(num_mode, 3072) if len(mode_dict) > 0 else None
self.controlnet_x_embedder = torch.nn.Linear(64 + additional_input_dim, 3072)
def prepare_image_ids(self, latents):
batch_size, _, height, width = latents.shape
latent_image_ids = torch.zeros(height // 2, width // 2, 3)
latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None]
latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :]
latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
latent_image_ids = latent_image_ids[None, :].repeat(batch_size, 1, 1, 1)
latent_image_ids = latent_image_ids.reshape(
batch_size, latent_image_id_height * latent_image_id_width, latent_image_id_channels
)
latent_image_ids = latent_image_ids.to(device=latents.device, dtype=latents.dtype)
return latent_image_ids
def patchify(self, hidden_states):
hidden_states = rearrange(hidden_states, "B C (H P) (W Q) -> B (H W) (C P Q)", P=2, Q=2)
return hidden_states
def align_res_stack_to_original_blocks(self, res_stack, num_blocks, hidden_states):
if len(res_stack) == 0:
return [torch.zeros_like(hidden_states)] * num_blocks
interval = (num_blocks + len(res_stack) - 1) // len(res_stack)
aligned_res_stack = [res_stack[block_id // interval] for block_id in range(num_blocks)]
return aligned_res_stack
def forward(
self,
hidden_states,
controlnet_conditioning,
timestep, prompt_emb, pooled_prompt_emb, guidance, text_ids, image_ids=None,
processor_id=None,
tiled=False, tile_size=128, tile_stride=64,
**kwargs
):
if image_ids is None:
image_ids = self.prepare_image_ids(hidden_states)
conditioning = self.time_embedder(timestep, hidden_states.dtype) + self.pooled_text_embedder(pooled_prompt_emb)
if self.guidance_embedder is not None:
guidance = guidance * 1000
conditioning = conditioning + self.guidance_embedder(guidance, hidden_states.dtype)
prompt_emb = self.context_embedder(prompt_emb)
if self.controlnet_mode_embedder is not None: # Different from FluxDiT
processor_id = torch.tensor([self.mode_dict[processor_id]], dtype=torch.int)
processor_id = repeat(processor_id, "D -> B D", B=1).to(text_ids.device)
prompt_emb = torch.concat([self.controlnet_mode_embedder(processor_id), prompt_emb], dim=1)
text_ids = torch.cat([text_ids[:, :1], text_ids], dim=1)
image_rotary_emb = self.pos_embedder(torch.cat((text_ids, image_ids), dim=1))
hidden_states = self.patchify(hidden_states)
hidden_states = self.x_embedder(hidden_states)
controlnet_conditioning = self.patchify(controlnet_conditioning) # Different from FluxDiT
hidden_states = hidden_states + self.controlnet_x_embedder(controlnet_conditioning) # Different from FluxDiT
controlnet_res_stack = []
for block, controlnet_block in zip(self.blocks, self.controlnet_blocks):
hidden_states, prompt_emb = block(hidden_states, prompt_emb, conditioning, image_rotary_emb)
controlnet_res_stack.append(controlnet_block(hidden_states))
controlnet_single_res_stack = []
hidden_states = torch.cat([prompt_emb, hidden_states], dim=1)
for block, controlnet_block in zip(self.single_blocks, self.controlnet_single_blocks):
hidden_states, prompt_emb = block(hidden_states, prompt_emb, conditioning, image_rotary_emb)
controlnet_single_res_stack.append(controlnet_block(hidden_states[:, prompt_emb.shape[1]:]))
controlnet_res_stack = self.align_res_stack_to_original_blocks(controlnet_res_stack, 19, hidden_states[:, prompt_emb.shape[1]:])
controlnet_single_res_stack = self.align_res_stack_to_original_blocks(controlnet_single_res_stack, 38, hidden_states[:, prompt_emb.shape[1]:])
return controlnet_res_stack, controlnet_single_res_stack
@staticmethod
def state_dict_converter():
return FluxControlNetStateDictConverter()
class FluxControlNetStateDictConverter:
def __init__(self):
pass
def from_diffusers(self, state_dict):
hash_value = hash_state_dict_keys(state_dict)
global_rename_dict = {
"context_embedder": "context_embedder",
"x_embedder": "x_embedder",
"time_text_embed.timestep_embedder.linear_1": "time_embedder.timestep_embedder.0",
"time_text_embed.timestep_embedder.linear_2": "time_embedder.timestep_embedder.2",
"time_text_embed.guidance_embedder.linear_1": "guidance_embedder.timestep_embedder.0",
"time_text_embed.guidance_embedder.linear_2": "guidance_embedder.timestep_embedder.2",
"time_text_embed.text_embedder.linear_1": "pooled_text_embedder.0",
"time_text_embed.text_embedder.linear_2": "pooled_text_embedder.2",
"norm_out.linear": "final_norm_out.linear",
"proj_out": "final_proj_out",
}
rename_dict = {
"proj_out": "proj_out",
"norm1.linear": "norm1_a.linear",
"norm1_context.linear": "norm1_b.linear",
"attn.to_q": "attn.a_to_q",
"attn.to_k": "attn.a_to_k",
"attn.to_v": "attn.a_to_v",
"attn.to_out.0": "attn.a_to_out",
"attn.add_q_proj": "attn.b_to_q",
"attn.add_k_proj": "attn.b_to_k",
"attn.add_v_proj": "attn.b_to_v",
"attn.to_add_out": "attn.b_to_out",
"ff.net.0.proj": "ff_a.0",
"ff.net.2": "ff_a.2",
"ff_context.net.0.proj": "ff_b.0",
"ff_context.net.2": "ff_b.2",
"attn.norm_q": "attn.norm_q_a",
"attn.norm_k": "attn.norm_k_a",
"attn.norm_added_q": "attn.norm_q_b",
"attn.norm_added_k": "attn.norm_k_b",
}
rename_dict_single = {
"attn.to_q": "a_to_q",
"attn.to_k": "a_to_k",
"attn.to_v": "a_to_v",
"attn.norm_q": "norm_q_a",
"attn.norm_k": "norm_k_a",
"norm.linear": "norm.linear",
"proj_mlp": "proj_in_besides_attn",
"proj_out": "proj_out",
}
state_dict_ = {}
for name, param in state_dict.items():
if name.endswith(".weight") or name.endswith(".bias"):
suffix = ".weight" if name.endswith(".weight") else ".bias"
prefix = name[:-len(suffix)]
if prefix in global_rename_dict:
state_dict_[global_rename_dict[prefix] + suffix] = param
elif prefix.startswith("transformer_blocks."):
names = prefix.split(".")
names[0] = "blocks"
middle = ".".join(names[2:])
if middle in rename_dict:
name_ = ".".join(names[:2] + [rename_dict[middle]] + [suffix[1:]])
state_dict_[name_] = param
elif prefix.startswith("single_transformer_blocks."):
names = prefix.split(".")
names[0] = "single_blocks"
middle = ".".join(names[2:])
if middle in rename_dict_single:
name_ = ".".join(names[:2] + [rename_dict_single[middle]] + [suffix[1:]])
state_dict_[name_] = param
else:
state_dict_[name] = param
else:
state_dict_[name] = param
for name in list(state_dict_.keys()):
if ".proj_in_besides_attn." in name:
name_ = name.replace(".proj_in_besides_attn.", ".to_qkv_mlp.")
param = torch.concat([
state_dict_[name.replace(".proj_in_besides_attn.", f".a_to_q.")],
state_dict_[name.replace(".proj_in_besides_attn.", f".a_to_k.")],
state_dict_[name.replace(".proj_in_besides_attn.", f".a_to_v.")],
state_dict_[name],
], dim=0)
state_dict_[name_] = param
state_dict_.pop(name.replace(".proj_in_besides_attn.", f".a_to_q."))
state_dict_.pop(name.replace(".proj_in_besides_attn.", f".a_to_k."))
state_dict_.pop(name.replace(".proj_in_besides_attn.", f".a_to_v."))
state_dict_.pop(name)
for name in list(state_dict_.keys()):
for component in ["a", "b"]:
if f".{component}_to_q." in name:
name_ = name.replace(f".{component}_to_q.", f".{component}_to_qkv.")
param = torch.concat([
state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_q.")],
state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_k.")],
state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_v.")],
], dim=0)
state_dict_[name_] = param
state_dict_.pop(name.replace(f".{component}_to_q.", f".{component}_to_q."))
state_dict_.pop(name.replace(f".{component}_to_q.", f".{component}_to_k."))
state_dict_.pop(name.replace(f".{component}_to_q.", f".{component}_to_v."))
if hash_value == "78d18b9101345ff695f312e7e62538c0":
extra_kwargs = {"num_mode": 10, "mode_dict": {"canny": 0, "tile": 1, "depth": 2, "blur": 3, "pose": 4, "gray": 5, "lq": 6}}
elif hash_value == "b001c89139b5f053c715fe772362dd2a":
extra_kwargs = {"num_single_blocks": 0}
elif hash_value == "52357cb26250681367488a8954c271e8":
extra_kwargs = {"num_joint_blocks": 6, "num_single_blocks": 0, "additional_input_dim": 4}
elif hash_value == "0cfd1740758423a2a854d67c136d1e8c":
extra_kwargs = {"num_joint_blocks": 4, "num_single_blocks": 1}
else:
extra_kwargs = {}
return state_dict_, extra_kwargs
def from_civitai(self, state_dict):
return self.from_diffusers(state_dict)

View File

@@ -1,7 +1,4 @@
import os, torch, hashlib, json, importlib
from safetensors import safe_open
from torch import Tensor
from typing_extensions import Literal, TypeAlias
import os, torch, json, importlib
from typing import List
from .downloader import download_models, download_customized_models, Preset_model_id, Preset_model_website
@@ -50,45 +47,7 @@ from ..extensions.RIFE import IFNet
from ..extensions.ESRGAN import RRDBNet
from ..configs.model_config import model_loader_configs, huggingface_model_loader_configs, patch_model_loader_configs
from .utils import load_state_dict, init_weights_on_device
def convert_state_dict_keys_to_single_str(state_dict, with_shape=True):
keys = []
for key, value in state_dict.items():
if isinstance(key, str):
if isinstance(value, Tensor):
if with_shape:
shape = "_".join(map(str, list(value.shape)))
keys.append(key + ":" + shape)
keys.append(key)
elif isinstance(value, dict):
keys.append(key + "|" + convert_state_dict_keys_to_single_str(value, with_shape=with_shape))
keys.sort()
keys_str = ",".join(keys)
return keys_str
def split_state_dict_with_prefix(state_dict):
keys = sorted([key for key in state_dict if isinstance(key, str)])
prefix_dict = {}
for key in keys:
prefix = key if "." not in key else key.split(".")[0]
if prefix not in prefix_dict:
prefix_dict[prefix] = []
prefix_dict[prefix].append(key)
state_dicts = []
for prefix, keys in prefix_dict.items():
sub_state_dict = {key: state_dict[key] for key in keys}
state_dicts.append(sub_state_dict)
return state_dicts
def hash_state_dict_keys(state_dict, with_shape=True):
keys_str = convert_state_dict_keys_to_single_str(state_dict, with_shape=with_shape)
keys_str = keys_str.encode(encoding="UTF-8")
return hashlib.md5(keys_str).hexdigest()
from .utils import load_state_dict, init_weights_on_device, hash_state_dict_keys, split_state_dict_with_prefix
def load_model_from_single_file(state_dict, model_names, model_classes, model_resource, torch_dtype, device):

View File

@@ -1,6 +1,7 @@
import torch, os
from safetensors import safe_open
from contextlib import contextmanager
import hashlib
@contextmanager
def init_weights_on_device(device = torch.device("meta"), include_buffers :bool = False):
@@ -142,3 +143,40 @@ def search_for_files(folder, extensions):
files.append(folder)
break
return files
def convert_state_dict_keys_to_single_str(state_dict, with_shape=True):
keys = []
for key, value in state_dict.items():
if isinstance(key, str):
if isinstance(value, torch.Tensor):
if with_shape:
shape = "_".join(map(str, list(value.shape)))
keys.append(key + ":" + shape)
keys.append(key)
elif isinstance(value, dict):
keys.append(key + "|" + convert_state_dict_keys_to_single_str(value, with_shape=with_shape))
keys.sort()
keys_str = ",".join(keys)
return keys_str
def split_state_dict_with_prefix(state_dict):
keys = sorted([key for key in state_dict if isinstance(key, str)])
prefix_dict = {}
for key in keys:
prefix = key if "." not in key else key.split(".")[0]
if prefix not in prefix_dict:
prefix_dict[prefix] = []
prefix_dict[prefix].append(key)
state_dicts = []
for prefix, keys in prefix_dict.items():
sub_state_dict = {key: state_dict[key] for key in keys}
state_dicts.append(sub_state_dict)
return state_dicts
def hash_state_dict_keys(state_dict, with_shape=True):
keys_str = convert_state_dict_keys_to_single_str(state_dict, with_shape=with_shape)
keys_str = keys_str.encode(encoding="UTF-8")
return hashlib.md5(keys_str).hexdigest()

View File

@@ -1,9 +1,13 @@
from ..models import ModelManager, FluxDiT, FluxTextEncoder1, FluxTextEncoder2, FluxVAEDecoder, FluxVAEEncoder
from ..controlnets import FluxMultiControlNetManager, ControlNetUnit, ControlNetConfigUnit, Annotator
from ..prompters import FluxPrompter
from ..schedulers import FlowMatchScheduler
from .base import BasePipeline
from typing import List
import torch
from tqdm import tqdm
import numpy as np
from PIL import Image
@@ -19,14 +23,15 @@ class FluxImagePipeline(BasePipeline):
self.dit: FluxDiT = None
self.vae_decoder: FluxVAEDecoder = None
self.vae_encoder: FluxVAEEncoder = None
self.model_names = ['text_encoder_1', 'text_encoder_2', 'dit', 'vae_decoder', 'vae_encoder']
self.controlnet: FluxMultiControlNetManager = None
self.model_names = ['text_encoder_1', 'text_encoder_2', 'dit', 'vae_decoder', 'vae_encoder', 'controlnet']
def denoising_model(self):
return self.dit
def fetch_models(self, model_manager: ModelManager, prompt_refiner_classes=[], prompt_extender_classes=[]):
def fetch_models(self, model_manager: ModelManager, controlnet_config_units: List[ControlNetConfigUnit]=[], prompt_refiner_classes=[], prompt_extender_classes=[]):
self.text_encoder_1 = model_manager.fetch_model("flux_text_encoder_1")
self.text_encoder_2 = model_manager.fetch_model("flux_text_encoder_2")
self.dit = model_manager.fetch_model("flux_dit")
@@ -36,14 +41,25 @@ class FluxImagePipeline(BasePipeline):
self.prompter.load_prompt_refiners(model_manager, prompt_refiner_classes)
self.prompter.load_prompt_extenders(model_manager, prompt_extender_classes)
# ControlNets
controlnet_units = []
for config in controlnet_config_units:
controlnet_unit = ControlNetUnit(
Annotator(config.processor_id, device=self.device, skip_processor=config.skip_processor),
model_manager.fetch_model("flux_controlnet", config.model_path),
config.scale
)
controlnet_units.append(controlnet_unit)
self.controlnet = FluxMultiControlNetManager(controlnet_units)
@staticmethod
def from_model_manager(model_manager: ModelManager, prompt_refiner_classes=[], prompt_extender_classes=[], device=None):
def from_model_manager(model_manager: ModelManager, controlnet_config_units: List[ControlNetConfigUnit]=[], prompt_refiner_classes=[], prompt_extender_classes=[], device=None):
pipe = FluxImagePipeline(
device=model_manager.device if device is None else device,
torch_dtype=model_manager.torch_dtype,
)
pipe.fetch_models(model_manager, prompt_refiner_classes,prompt_extender_classes)
pipe.fetch_models(model_manager, controlnet_config_units, prompt_refiner_classes, prompt_extender_classes)
return pipe
@@ -71,17 +87,61 @@ class FluxImagePipeline(BasePipeline):
return {"image_ids": latent_image_ids, "guidance": guidance}
def apply_controlnet_mask_on_latents(self, latents, mask):
mask = (self.preprocess_image(mask) + 1) / 2
mask = mask.mean(dim=1, keepdim=True)
mask = mask.to(dtype=self.torch_dtype, device=self.device)
mask = 1 - torch.nn.functional.interpolate(mask, size=latents.shape[-2:])
latents = torch.concat([latents, mask], dim=1)
return latents
def apply_controlnet_mask_on_image(self, image, mask):
mask = mask.resize(image.size)
mask = self.preprocess_image(mask).mean(dim=[0, 1])
image = np.array(image)
image[mask > 0] = 0
image = Image.fromarray(image)
return image
def prepare_controlnet_input(self, controlnet_image, controlnet_inpaint_mask, tiler_kwargs):
if isinstance(controlnet_image, Image.Image):
controlnet_image = [controlnet_image] * len(self.controlnet.processors)
controlnet_frames = []
for i in range(len(self.controlnet.processors)):
# image annotator
image = self.controlnet.process_image(controlnet_image[i], processor_id=i)[0]
if controlnet_inpaint_mask is not None and self.controlnet.processors[i].processor_id == "inpaint":
image = self.apply_controlnet_mask_on_image(image, controlnet_inpaint_mask)
# image to tensor
image = self.preprocess_image(image).to(device=self.device, dtype=self.torch_dtype)
# vae encoder
image = self.encode_image(image, **tiler_kwargs)
if controlnet_inpaint_mask is not None and self.controlnet.processors[i].processor_id == "inpaint":
image = self.apply_controlnet_mask_on_latents(image, controlnet_inpaint_mask)
# store it
controlnet_frames.append(image)
return controlnet_frames
@torch.no_grad()
def __call__(
self,
prompt,
local_prompts= None,
masks= None,
mask_scales= None,
local_prompts=None,
masks=None,
mask_scales=None,
negative_prompt="",
cfg_scale=1.0,
embedded_guidance=3.5,
input_image=None,
controlnet_image=None,
controlnet_inpaint_mask=None,
denoising_strength=1.0,
height=1024,
width=1024,
@@ -123,19 +183,29 @@ class FluxImagePipeline(BasePipeline):
# Extra input
extra_input = self.prepare_extra_input(latents, guidance=embedded_guidance)
# Prepare ControlNets
if controlnet_image is not None:
controlnet_kwargs = {"controlnet_frames": self.prepare_controlnet_input(controlnet_image, controlnet_inpaint_mask, tiler_kwargs)}
else:
controlnet_kwargs = {"controlnet_frames": None}
# Denoise
self.load_models_to_device(['dit'])
self.load_models_to_device(['dit', 'controlnet'])
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
timestep = timestep.unsqueeze(0).to(self.device)
# Classifier-free guidance
inference_callback = lambda prompt_emb_posi: self.dit(
latents, timestep=timestep, **prompt_emb_posi, **tiler_kwargs, **extra_input
inference_callback = lambda prompt_emb_posi: lets_dance_flux(
dit=self.dit, controlnet=self.controlnet,
hidden_states=latents, timestep=timestep,
**prompt_emb_posi, **tiler_kwargs, **extra_input, **controlnet_kwargs
)
noise_pred_posi = self.control_noise_via_local_prompts(prompt_emb_posi, prompt_emb_locals, masks, mask_scales, inference_callback)
if cfg_scale != 1.0:
noise_pred_nega = self.dit(
latents, timestep=timestep, **prompt_emb_nega, **tiler_kwargs, **extra_input
noise_pred_nega = lets_dance_flux(
dit=self.dit, controlnet=self.controlnet,
hidden_states=latents, timestep=timestep,
**prompt_emb_nega, **tiler_kwargs, **extra_input, **controlnet_kwargs
)
noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
else:
@@ -155,3 +225,75 @@ class FluxImagePipeline(BasePipeline):
# Offload all models
self.load_models_to_device([])
return image
def lets_dance_flux(
dit: FluxDiT,
controlnet: FluxMultiControlNetManager = None,
hidden_states=None,
timestep=None,
prompt_emb=None,
pooled_prompt_emb=None,
guidance=None,
text_ids=None,
image_ids=None,
controlnet_frames=None,
tiled=False,
tile_size=128,
tile_stride=64,
**kwargs
):
# ControlNet
if controlnet is not None and controlnet_frames is not None:
controlnet_extra_kwargs = {
"hidden_states": hidden_states,
"timestep": timestep,
"prompt_emb": prompt_emb,
"pooled_prompt_emb": pooled_prompt_emb,
"guidance": guidance,
"text_ids": text_ids,
"image_ids": image_ids,
"tiled": tiled,
"tile_size": tile_size,
"tile_stride": tile_stride,
}
controlnet_res_stack, controlnet_single_res_stack = controlnet(
controlnet_frames, **controlnet_extra_kwargs
)
if image_ids is None:
image_ids = dit.prepare_image_ids(hidden_states)
conditioning = dit.time_embedder(timestep, hidden_states.dtype) + dit.pooled_text_embedder(pooled_prompt_emb)
if dit.guidance_embedder is not None:
guidance = guidance * 1000
conditioning = conditioning + dit.guidance_embedder(guidance, hidden_states.dtype)
prompt_emb = dit.context_embedder(prompt_emb)
image_rotary_emb = dit.pos_embedder(torch.cat((text_ids, image_ids), dim=1))
height, width = hidden_states.shape[-2:]
hidden_states = dit.patchify(hidden_states)
hidden_states = dit.x_embedder(hidden_states)
# Joint Blocks
for block_id, block in enumerate(dit.blocks):
hidden_states, prompt_emb = block(hidden_states, prompt_emb, conditioning, image_rotary_emb)
# ControlNet
if controlnet is not None and controlnet_frames is not None:
hidden_states = hidden_states + controlnet_res_stack[block_id]
# Single Blocks
hidden_states = torch.cat([prompt_emb, hidden_states], dim=1)
for block_id, block in enumerate(dit.single_blocks):
hidden_states, prompt_emb = block(hidden_states, prompt_emb, conditioning, image_rotary_emb)
# ControlNet
if controlnet is not None and controlnet_frames is not None:
hidden_states[:, prompt_emb.shape[1]:] = hidden_states[:, prompt_emb.shape[1]:] + controlnet_single_res_stack[block_id]
hidden_states = hidden_states[:, prompt_emb.shape[1]:]
hidden_states = dit.final_norm_out(hidden_states, conditioning)
hidden_states = dit.final_proj_out(hidden_states)
hidden_states = dit.unpatchify(hidden_states, height, width)
return hidden_states

View File

@@ -0,0 +1,44 @@
from diffsynth.models.flux_controlnet import FluxControlNet
from diffsynth import load_state_dict, ModelManager, FluxImagePipeline, hash_state_dict_keys, ControlNetConfigUnit
import torch
from PIL import Image
import numpy as np
model_manager = ModelManager(torch_dtype=torch.bfloat16, model_id_list=["FLUX.1-dev"])
model_manager.load_models([
"models/ControlNet/InstantX/FLUX___1-dev-Controlnet-Union-alpha/diffusion_pytorch_model.safetensors",
"models/ControlNet/jasperai/Flux___1-dev-Controlnet-Depth/diffusion_pytorch_model.safetensors",
"models/ControlNet/jasperai/Flux___1-dev-Controlnet-Surface-Normals/diffusion_pytorch_model.safetensors",
"models/ControlNet/jasperai/Flux___1-dev-Controlnet-Upscaler/diffusion_pytorch_model.safetensors",
"models/ControlNet/alimama-creative/FLUX___1-dev-Controlnet-Inpainting-Alpha/diffusion_pytorch_model.safetensors",
"models/ControlNet/alimama-creative/FLUX___1-dev-Controlnet-Inpainting-Beta/diffusion_pytorch_model.safetensors",
"models/ControlNet/Shakker-Labs/FLUX___1-dev-ControlNet-Depth/diffusion_pytorch_model.safetensors",
"models/ControlNet/Shakker-Labs/FLUX___1-dev-ControlNet-Union-Pro/diffusion_pytorch_model.safetensors"
])
pipe = FluxImagePipeline.from_model_manager(model_manager, controlnet_config_units=[
ControlNetConfigUnit(processor_id="canny", model_path="models/ControlNet/InstantX/FLUX___1-dev-Controlnet-Union-alpha/diffusion_pytorch_model.safetensors", scale=0.3),
ControlNetConfigUnit(processor_id="depth", model_path="models/ControlNet/jasperai/Flux___1-dev-Controlnet-Depth/diffusion_pytorch_model.safetensors", scale=0.1),
ControlNetConfigUnit(processor_id="normal", model_path="models/ControlNet/jasperai/Flux___1-dev-Controlnet-Surface-Normals/diffusion_pytorch_model.safetensors", scale=0.1),
ControlNetConfigUnit(processor_id="tile", model_path="models/ControlNet/jasperai/Flux___1-dev-Controlnet-Upscaler/diffusion_pytorch_model.safetensors", scale=0.05),
ControlNetConfigUnit(processor_id="inpaint", model_path="models/ControlNet/alimama-creative/FLUX___1-dev-Controlnet-Inpainting-Alpha/diffusion_pytorch_model.safetensors", scale=0.01),
ControlNetConfigUnit(processor_id="inpaint", model_path="models/ControlNet/alimama-creative/FLUX___1-dev-Controlnet-Inpainting-Beta/diffusion_pytorch_model.safetensors", scale=0.01),
ControlNetConfigUnit(processor_id="depth", model_path="models/ControlNet/Shakker-Labs/FLUX___1-dev-ControlNet-Depth/diffusion_pytorch_model.safetensors", scale=0.05),
ControlNetConfigUnit(processor_id="canny", model_path="models/ControlNet/Shakker-Labs/FLUX___1-dev-ControlNet-Union-Pro/diffusion_pytorch_model.safetensors", scale=0.3),
])
torch.manual_seed(0)
control_image = Image.open("controlnet_input.jpeg").resize((768, 1024))
control_mask = Image.open("controlnet_mask.jpg").resize((768, 1024))
prompt = "masterpiece, best quality, a beautiful girl, CG, blue sky, long red hair, black clothes"
negative_prompt = "oil painting, worst quality, low quality, monochrome, zombie, interlocked fingers, Aissist, cleavage, nsfw,"
image = pipe(
prompt=prompt, negative_prompt=negative_prompt,
embedded_guidance=3.5, num_inference_steps=50,
height=1024, width=768,
controlnet_image=control_image, controlnet_inpaint_mask=control_mask,
)
image.save("image.jpg")