mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-18 22:08:13 +00:00
support flux-controlnet
This commit is contained in:
@@ -35,6 +35,7 @@ from ..models.hunyuan_dit import HunyuanDiT
|
||||
from ..models.flux_dit import FluxDiT
|
||||
from ..models.flux_text_encoder import FluxTextEncoder1, FluxTextEncoder2
|
||||
from ..models.flux_vae import FluxVAEEncoder, FluxVAEDecoder
|
||||
from ..models.flux_controlnet import FluxControlNet
|
||||
|
||||
from ..models.cog_vae import CogVAEEncoder, CogVAEDecoder
|
||||
from ..models.cog_dit import CogDiT
|
||||
@@ -80,6 +81,10 @@ model_loader_configs = [
|
||||
(None, "280189ee084bca10f70907bf6ce1649d", ["cog_vae_encoder", "cog_vae_decoder"], [CogVAEEncoder, CogVAEDecoder], "diffusers"),
|
||||
(None, "9b9313d104ac4df27991352fec013fd4", ["rife"], [IFNet], "civitai"),
|
||||
(None, "6b7116078c4170bfbeaedc8fe71f6649", ["esrgan"], [RRDBNet], "civitai"),
|
||||
(None, "78d18b9101345ff695f312e7e62538c0", ["flux_controlnet"], [FluxControlNet], "diffusers"),
|
||||
(None, "b001c89139b5f053c715fe772362dd2a", ["flux_controlnet"], [FluxControlNet], "diffusers"),
|
||||
(None, "52357cb26250681367488a8954c271e8", ["flux_controlnet"], [FluxControlNet], "diffusers"),
|
||||
(None, "0cfd1740758423a2a854d67c136d1e8c", ["flux_controlnet"], [FluxControlNet], "diffusers"),
|
||||
]
|
||||
huggingface_model_loader_configs = [
|
||||
# These configs are provided for detecting model type automatically.
|
||||
|
||||
@@ -1,2 +1,2 @@
|
||||
from .controlnet_unit import ControlNetConfigUnit, ControlNetUnit, MultiControlNetManager
|
||||
from .controlnet_unit import ControlNetConfigUnit, ControlNetUnit, MultiControlNetManager, FluxMultiControlNetManager
|
||||
from .processors import Annotator
|
||||
|
||||
@@ -4,10 +4,11 @@ from .processors import Processor_id
|
||||
|
||||
|
||||
class ControlNetConfigUnit:
|
||||
def __init__(self, processor_id: Processor_id, model_path, scale=1.0):
|
||||
def __init__(self, processor_id: Processor_id, model_path, scale=1.0, skip_processor=False):
|
||||
self.processor_id = processor_id
|
||||
self.model_path = model_path
|
||||
self.scale = scale
|
||||
self.skip_processor = skip_processor
|
||||
|
||||
|
||||
class ControlNetUnit:
|
||||
@@ -60,3 +61,29 @@ class MultiControlNetManager:
|
||||
else:
|
||||
res_stack = [i + j for i, j in zip(res_stack, res_stack_)]
|
||||
return res_stack
|
||||
|
||||
|
||||
class FluxMultiControlNetManager(MultiControlNetManager):
|
||||
def __init__(self, controlnet_units=[]):
|
||||
super().__init__(controlnet_units=controlnet_units)
|
||||
|
||||
def process_image(self, image, processor_id=None):
|
||||
if processor_id is None:
|
||||
processed_image = [processor(image) for processor in self.processors]
|
||||
else:
|
||||
processed_image = [self.processors[processor_id](image)]
|
||||
return processed_image
|
||||
|
||||
def __call__(self, conditionings, **kwargs):
|
||||
res_stack, single_res_stack = None, None
|
||||
for processor, conditioning, model, scale in zip(self.processors, conditionings, self.models, self.scales):
|
||||
res_stack_, single_res_stack_ = model(controlnet_conditioning=conditioning, processor_id=processor.processor_id, **kwargs)
|
||||
res_stack_ = [res * scale for res in res_stack_]
|
||||
single_res_stack_ = [res * scale for res in single_res_stack_]
|
||||
if res_stack is None:
|
||||
res_stack = res_stack_
|
||||
single_res_stack = single_res_stack_
|
||||
else:
|
||||
res_stack = [i + j for i, j in zip(res_stack, res_stack_)]
|
||||
single_res_stack = [i + j for i, j in zip(single_res_stack, single_res_stack_)]
|
||||
return res_stack, single_res_stack
|
||||
|
||||
@@ -3,37 +3,42 @@ import warnings
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore")
|
||||
from controlnet_aux.processor import (
|
||||
CannyDetector, MidasDetector, HEDdetector, LineartDetector, LineartAnimeDetector, OpenposeDetector
|
||||
CannyDetector, MidasDetector, HEDdetector, LineartDetector, LineartAnimeDetector, OpenposeDetector, NormalBaeDetector
|
||||
)
|
||||
|
||||
|
||||
Processor_id: TypeAlias = Literal[
|
||||
"canny", "depth", "softedge", "lineart", "lineart_anime", "openpose", "tile"
|
||||
"canny", "depth", "softedge", "lineart", "lineart_anime", "openpose", "normal", "tile", "none", "inpaint"
|
||||
]
|
||||
|
||||
class Annotator:
|
||||
def __init__(self, processor_id: Processor_id, model_path="models/Annotators", detect_resolution=None, device='cuda'):
|
||||
if processor_id == "canny":
|
||||
self.processor = CannyDetector()
|
||||
elif processor_id == "depth":
|
||||
self.processor = MidasDetector.from_pretrained(model_path).to(device)
|
||||
elif processor_id == "softedge":
|
||||
self.processor = HEDdetector.from_pretrained(model_path).to(device)
|
||||
elif processor_id == "lineart":
|
||||
self.processor = LineartDetector.from_pretrained(model_path).to(device)
|
||||
elif processor_id == "lineart_anime":
|
||||
self.processor = LineartAnimeDetector.from_pretrained(model_path).to(device)
|
||||
elif processor_id == "openpose":
|
||||
self.processor = OpenposeDetector.from_pretrained(model_path).to(device)
|
||||
elif processor_id == "tile":
|
||||
self.processor = None
|
||||
def __init__(self, processor_id: Processor_id, model_path="models/Annotators", detect_resolution=None, device='cuda', skip_processor=False):
|
||||
if not skip_processor:
|
||||
if processor_id == "canny":
|
||||
self.processor = CannyDetector()
|
||||
elif processor_id == "depth":
|
||||
self.processor = MidasDetector.from_pretrained(model_path).to(device)
|
||||
elif processor_id == "softedge":
|
||||
self.processor = HEDdetector.from_pretrained(model_path).to(device)
|
||||
elif processor_id == "lineart":
|
||||
self.processor = LineartDetector.from_pretrained(model_path).to(device)
|
||||
elif processor_id == "lineart_anime":
|
||||
self.processor = LineartAnimeDetector.from_pretrained(model_path).to(device)
|
||||
elif processor_id == "openpose":
|
||||
self.processor = OpenposeDetector.from_pretrained(model_path).to(device)
|
||||
elif processor_id == "normal":
|
||||
self.processor = NormalBaeDetector.from_pretrained(model_path).to(device)
|
||||
elif processor_id == "tile" or processor_id == "none" or processor_id == "inpaint":
|
||||
self.processor = None
|
||||
else:
|
||||
raise ValueError(f"Unsupported processor_id: {processor_id}")
|
||||
else:
|
||||
raise ValueError(f"Unsupported processor_id: {processor_id}")
|
||||
self.processor = None
|
||||
|
||||
self.processor_id = processor_id
|
||||
self.detect_resolution = detect_resolution
|
||||
|
||||
def __call__(self, image):
|
||||
def __call__(self, image, mask=None):
|
||||
width, height = image.size
|
||||
if self.processor_id == "openpose":
|
||||
kwargs = {
|
||||
|
||||
226
diffsynth/models/flux_controlnet.py
Normal file
226
diffsynth/models/flux_controlnet.py
Normal file
@@ -0,0 +1,226 @@
|
||||
import torch
|
||||
from einops import rearrange, repeat
|
||||
from .flux_dit import RoPEEmbedding, TimestepEmbeddings, FluxJointTransformerBlock, FluxSingleTransformerBlock
|
||||
from .utils import hash_state_dict_keys
|
||||
|
||||
|
||||
|
||||
class FluxControlNet(torch.nn.Module):
|
||||
def __init__(self, disable_guidance_embedder=False, num_joint_blocks=5, num_single_blocks=10, num_mode=0, mode_dict={}, additional_input_dim=0):
|
||||
super().__init__()
|
||||
self.pos_embedder = RoPEEmbedding(3072, 10000, [16, 56, 56])
|
||||
self.time_embedder = TimestepEmbeddings(256, 3072)
|
||||
self.guidance_embedder = None if disable_guidance_embedder else TimestepEmbeddings(256, 3072)
|
||||
self.pooled_text_embedder = torch.nn.Sequential(torch.nn.Linear(768, 3072), torch.nn.SiLU(), torch.nn.Linear(3072, 3072))
|
||||
self.context_embedder = torch.nn.Linear(4096, 3072)
|
||||
self.x_embedder = torch.nn.Linear(64, 3072)
|
||||
|
||||
self.blocks = torch.nn.ModuleList([FluxJointTransformerBlock(3072, 24) for _ in range(num_joint_blocks)])
|
||||
self.single_blocks = torch.nn.ModuleList([FluxSingleTransformerBlock(3072, 24) for _ in range(num_single_blocks)])
|
||||
|
||||
self.controlnet_blocks = torch.nn.ModuleList([torch.nn.Linear(3072, 3072) for _ in range(num_joint_blocks)])
|
||||
self.controlnet_single_blocks = torch.nn.ModuleList([torch.nn.Linear(3072, 3072) for _ in range(num_single_blocks)])
|
||||
|
||||
self.mode_dict = mode_dict
|
||||
self.controlnet_mode_embedder = torch.nn.Embedding(num_mode, 3072) if len(mode_dict) > 0 else None
|
||||
self.controlnet_x_embedder = torch.nn.Linear(64 + additional_input_dim, 3072)
|
||||
|
||||
|
||||
def prepare_image_ids(self, latents):
|
||||
batch_size, _, height, width = latents.shape
|
||||
latent_image_ids = torch.zeros(height // 2, width // 2, 3)
|
||||
latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None]
|
||||
latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :]
|
||||
|
||||
latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
|
||||
|
||||
latent_image_ids = latent_image_ids[None, :].repeat(batch_size, 1, 1, 1)
|
||||
latent_image_ids = latent_image_ids.reshape(
|
||||
batch_size, latent_image_id_height * latent_image_id_width, latent_image_id_channels
|
||||
)
|
||||
latent_image_ids = latent_image_ids.to(device=latents.device, dtype=latents.dtype)
|
||||
|
||||
return latent_image_ids
|
||||
|
||||
|
||||
def patchify(self, hidden_states):
|
||||
hidden_states = rearrange(hidden_states, "B C (H P) (W Q) -> B (H W) (C P Q)", P=2, Q=2)
|
||||
return hidden_states
|
||||
|
||||
|
||||
def align_res_stack_to_original_blocks(self, res_stack, num_blocks, hidden_states):
|
||||
if len(res_stack) == 0:
|
||||
return [torch.zeros_like(hidden_states)] * num_blocks
|
||||
interval = (num_blocks + len(res_stack) - 1) // len(res_stack)
|
||||
aligned_res_stack = [res_stack[block_id // interval] for block_id in range(num_blocks)]
|
||||
return aligned_res_stack
|
||||
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states,
|
||||
controlnet_conditioning,
|
||||
timestep, prompt_emb, pooled_prompt_emb, guidance, text_ids, image_ids=None,
|
||||
processor_id=None,
|
||||
tiled=False, tile_size=128, tile_stride=64,
|
||||
**kwargs
|
||||
):
|
||||
if image_ids is None:
|
||||
image_ids = self.prepare_image_ids(hidden_states)
|
||||
|
||||
conditioning = self.time_embedder(timestep, hidden_states.dtype) + self.pooled_text_embedder(pooled_prompt_emb)
|
||||
if self.guidance_embedder is not None:
|
||||
guidance = guidance * 1000
|
||||
conditioning = conditioning + self.guidance_embedder(guidance, hidden_states.dtype)
|
||||
prompt_emb = self.context_embedder(prompt_emb)
|
||||
if self.controlnet_mode_embedder is not None: # Different from FluxDiT
|
||||
processor_id = torch.tensor([self.mode_dict[processor_id]], dtype=torch.int)
|
||||
processor_id = repeat(processor_id, "D -> B D", B=1).to(text_ids.device)
|
||||
prompt_emb = torch.concat([self.controlnet_mode_embedder(processor_id), prompt_emb], dim=1)
|
||||
text_ids = torch.cat([text_ids[:, :1], text_ids], dim=1)
|
||||
image_rotary_emb = self.pos_embedder(torch.cat((text_ids, image_ids), dim=1))
|
||||
|
||||
hidden_states = self.patchify(hidden_states)
|
||||
hidden_states = self.x_embedder(hidden_states)
|
||||
controlnet_conditioning = self.patchify(controlnet_conditioning) # Different from FluxDiT
|
||||
hidden_states = hidden_states + self.controlnet_x_embedder(controlnet_conditioning) # Different from FluxDiT
|
||||
|
||||
controlnet_res_stack = []
|
||||
for block, controlnet_block in zip(self.blocks, self.controlnet_blocks):
|
||||
hidden_states, prompt_emb = block(hidden_states, prompt_emb, conditioning, image_rotary_emb)
|
||||
controlnet_res_stack.append(controlnet_block(hidden_states))
|
||||
|
||||
controlnet_single_res_stack = []
|
||||
hidden_states = torch.cat([prompt_emb, hidden_states], dim=1)
|
||||
for block, controlnet_block in zip(self.single_blocks, self.controlnet_single_blocks):
|
||||
hidden_states, prompt_emb = block(hidden_states, prompt_emb, conditioning, image_rotary_emb)
|
||||
controlnet_single_res_stack.append(controlnet_block(hidden_states[:, prompt_emb.shape[1]:]))
|
||||
|
||||
controlnet_res_stack = self.align_res_stack_to_original_blocks(controlnet_res_stack, 19, hidden_states[:, prompt_emb.shape[1]:])
|
||||
controlnet_single_res_stack = self.align_res_stack_to_original_blocks(controlnet_single_res_stack, 38, hidden_states[:, prompt_emb.shape[1]:])
|
||||
|
||||
return controlnet_res_stack, controlnet_single_res_stack
|
||||
|
||||
|
||||
@staticmethod
|
||||
def state_dict_converter():
|
||||
return FluxControlNetStateDictConverter()
|
||||
|
||||
|
||||
|
||||
class FluxControlNetStateDictConverter:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def from_diffusers(self, state_dict):
|
||||
hash_value = hash_state_dict_keys(state_dict)
|
||||
global_rename_dict = {
|
||||
"context_embedder": "context_embedder",
|
||||
"x_embedder": "x_embedder",
|
||||
"time_text_embed.timestep_embedder.linear_1": "time_embedder.timestep_embedder.0",
|
||||
"time_text_embed.timestep_embedder.linear_2": "time_embedder.timestep_embedder.2",
|
||||
"time_text_embed.guidance_embedder.linear_1": "guidance_embedder.timestep_embedder.0",
|
||||
"time_text_embed.guidance_embedder.linear_2": "guidance_embedder.timestep_embedder.2",
|
||||
"time_text_embed.text_embedder.linear_1": "pooled_text_embedder.0",
|
||||
"time_text_embed.text_embedder.linear_2": "pooled_text_embedder.2",
|
||||
"norm_out.linear": "final_norm_out.linear",
|
||||
"proj_out": "final_proj_out",
|
||||
}
|
||||
rename_dict = {
|
||||
"proj_out": "proj_out",
|
||||
"norm1.linear": "norm1_a.linear",
|
||||
"norm1_context.linear": "norm1_b.linear",
|
||||
"attn.to_q": "attn.a_to_q",
|
||||
"attn.to_k": "attn.a_to_k",
|
||||
"attn.to_v": "attn.a_to_v",
|
||||
"attn.to_out.0": "attn.a_to_out",
|
||||
"attn.add_q_proj": "attn.b_to_q",
|
||||
"attn.add_k_proj": "attn.b_to_k",
|
||||
"attn.add_v_proj": "attn.b_to_v",
|
||||
"attn.to_add_out": "attn.b_to_out",
|
||||
"ff.net.0.proj": "ff_a.0",
|
||||
"ff.net.2": "ff_a.2",
|
||||
"ff_context.net.0.proj": "ff_b.0",
|
||||
"ff_context.net.2": "ff_b.2",
|
||||
"attn.norm_q": "attn.norm_q_a",
|
||||
"attn.norm_k": "attn.norm_k_a",
|
||||
"attn.norm_added_q": "attn.norm_q_b",
|
||||
"attn.norm_added_k": "attn.norm_k_b",
|
||||
}
|
||||
rename_dict_single = {
|
||||
"attn.to_q": "a_to_q",
|
||||
"attn.to_k": "a_to_k",
|
||||
"attn.to_v": "a_to_v",
|
||||
"attn.norm_q": "norm_q_a",
|
||||
"attn.norm_k": "norm_k_a",
|
||||
"norm.linear": "norm.linear",
|
||||
"proj_mlp": "proj_in_besides_attn",
|
||||
"proj_out": "proj_out",
|
||||
}
|
||||
state_dict_ = {}
|
||||
for name, param in state_dict.items():
|
||||
if name.endswith(".weight") or name.endswith(".bias"):
|
||||
suffix = ".weight" if name.endswith(".weight") else ".bias"
|
||||
prefix = name[:-len(suffix)]
|
||||
if prefix in global_rename_dict:
|
||||
state_dict_[global_rename_dict[prefix] + suffix] = param
|
||||
elif prefix.startswith("transformer_blocks."):
|
||||
names = prefix.split(".")
|
||||
names[0] = "blocks"
|
||||
middle = ".".join(names[2:])
|
||||
if middle in rename_dict:
|
||||
name_ = ".".join(names[:2] + [rename_dict[middle]] + [suffix[1:]])
|
||||
state_dict_[name_] = param
|
||||
elif prefix.startswith("single_transformer_blocks."):
|
||||
names = prefix.split(".")
|
||||
names[0] = "single_blocks"
|
||||
middle = ".".join(names[2:])
|
||||
if middle in rename_dict_single:
|
||||
name_ = ".".join(names[:2] + [rename_dict_single[middle]] + [suffix[1:]])
|
||||
state_dict_[name_] = param
|
||||
else:
|
||||
state_dict_[name] = param
|
||||
else:
|
||||
state_dict_[name] = param
|
||||
for name in list(state_dict_.keys()):
|
||||
if ".proj_in_besides_attn." in name:
|
||||
name_ = name.replace(".proj_in_besides_attn.", ".to_qkv_mlp.")
|
||||
param = torch.concat([
|
||||
state_dict_[name.replace(".proj_in_besides_attn.", f".a_to_q.")],
|
||||
state_dict_[name.replace(".proj_in_besides_attn.", f".a_to_k.")],
|
||||
state_dict_[name.replace(".proj_in_besides_attn.", f".a_to_v.")],
|
||||
state_dict_[name],
|
||||
], dim=0)
|
||||
state_dict_[name_] = param
|
||||
state_dict_.pop(name.replace(".proj_in_besides_attn.", f".a_to_q."))
|
||||
state_dict_.pop(name.replace(".proj_in_besides_attn.", f".a_to_k."))
|
||||
state_dict_.pop(name.replace(".proj_in_besides_attn.", f".a_to_v."))
|
||||
state_dict_.pop(name)
|
||||
for name in list(state_dict_.keys()):
|
||||
for component in ["a", "b"]:
|
||||
if f".{component}_to_q." in name:
|
||||
name_ = name.replace(f".{component}_to_q.", f".{component}_to_qkv.")
|
||||
param = torch.concat([
|
||||
state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_q.")],
|
||||
state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_k.")],
|
||||
state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_v.")],
|
||||
], dim=0)
|
||||
state_dict_[name_] = param
|
||||
state_dict_.pop(name.replace(f".{component}_to_q.", f".{component}_to_q."))
|
||||
state_dict_.pop(name.replace(f".{component}_to_q.", f".{component}_to_k."))
|
||||
state_dict_.pop(name.replace(f".{component}_to_q.", f".{component}_to_v."))
|
||||
if hash_value == "78d18b9101345ff695f312e7e62538c0":
|
||||
extra_kwargs = {"num_mode": 10, "mode_dict": {"canny": 0, "tile": 1, "depth": 2, "blur": 3, "pose": 4, "gray": 5, "lq": 6}}
|
||||
elif hash_value == "b001c89139b5f053c715fe772362dd2a":
|
||||
extra_kwargs = {"num_single_blocks": 0}
|
||||
elif hash_value == "52357cb26250681367488a8954c271e8":
|
||||
extra_kwargs = {"num_joint_blocks": 6, "num_single_blocks": 0, "additional_input_dim": 4}
|
||||
elif hash_value == "0cfd1740758423a2a854d67c136d1e8c":
|
||||
extra_kwargs = {"num_joint_blocks": 4, "num_single_blocks": 1}
|
||||
else:
|
||||
extra_kwargs = {}
|
||||
return state_dict_, extra_kwargs
|
||||
|
||||
|
||||
def from_civitai(self, state_dict):
|
||||
return self.from_diffusers(state_dict)
|
||||
@@ -1,7 +1,4 @@
|
||||
import os, torch, hashlib, json, importlib
|
||||
from safetensors import safe_open
|
||||
from torch import Tensor
|
||||
from typing_extensions import Literal, TypeAlias
|
||||
import os, torch, json, importlib
|
||||
from typing import List
|
||||
|
||||
from .downloader import download_models, download_customized_models, Preset_model_id, Preset_model_website
|
||||
@@ -50,45 +47,7 @@ from ..extensions.RIFE import IFNet
|
||||
from ..extensions.ESRGAN import RRDBNet
|
||||
|
||||
from ..configs.model_config import model_loader_configs, huggingface_model_loader_configs, patch_model_loader_configs
|
||||
from .utils import load_state_dict, init_weights_on_device
|
||||
|
||||
|
||||
|
||||
def convert_state_dict_keys_to_single_str(state_dict, with_shape=True):
|
||||
keys = []
|
||||
for key, value in state_dict.items():
|
||||
if isinstance(key, str):
|
||||
if isinstance(value, Tensor):
|
||||
if with_shape:
|
||||
shape = "_".join(map(str, list(value.shape)))
|
||||
keys.append(key + ":" + shape)
|
||||
keys.append(key)
|
||||
elif isinstance(value, dict):
|
||||
keys.append(key + "|" + convert_state_dict_keys_to_single_str(value, with_shape=with_shape))
|
||||
keys.sort()
|
||||
keys_str = ",".join(keys)
|
||||
return keys_str
|
||||
|
||||
|
||||
def split_state_dict_with_prefix(state_dict):
|
||||
keys = sorted([key for key in state_dict if isinstance(key, str)])
|
||||
prefix_dict = {}
|
||||
for key in keys:
|
||||
prefix = key if "." not in key else key.split(".")[0]
|
||||
if prefix not in prefix_dict:
|
||||
prefix_dict[prefix] = []
|
||||
prefix_dict[prefix].append(key)
|
||||
state_dicts = []
|
||||
for prefix, keys in prefix_dict.items():
|
||||
sub_state_dict = {key: state_dict[key] for key in keys}
|
||||
state_dicts.append(sub_state_dict)
|
||||
return state_dicts
|
||||
|
||||
|
||||
def hash_state_dict_keys(state_dict, with_shape=True):
|
||||
keys_str = convert_state_dict_keys_to_single_str(state_dict, with_shape=with_shape)
|
||||
keys_str = keys_str.encode(encoding="UTF-8")
|
||||
return hashlib.md5(keys_str).hexdigest()
|
||||
from .utils import load_state_dict, init_weights_on_device, hash_state_dict_keys, split_state_dict_with_prefix
|
||||
|
||||
|
||||
def load_model_from_single_file(state_dict, model_names, model_classes, model_resource, torch_dtype, device):
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import torch, os
|
||||
from safetensors import safe_open
|
||||
from contextlib import contextmanager
|
||||
import hashlib
|
||||
|
||||
@contextmanager
|
||||
def init_weights_on_device(device = torch.device("meta"), include_buffers :bool = False):
|
||||
@@ -142,3 +143,40 @@ def search_for_files(folder, extensions):
|
||||
files.append(folder)
|
||||
break
|
||||
return files
|
||||
|
||||
|
||||
def convert_state_dict_keys_to_single_str(state_dict, with_shape=True):
|
||||
keys = []
|
||||
for key, value in state_dict.items():
|
||||
if isinstance(key, str):
|
||||
if isinstance(value, torch.Tensor):
|
||||
if with_shape:
|
||||
shape = "_".join(map(str, list(value.shape)))
|
||||
keys.append(key + ":" + shape)
|
||||
keys.append(key)
|
||||
elif isinstance(value, dict):
|
||||
keys.append(key + "|" + convert_state_dict_keys_to_single_str(value, with_shape=with_shape))
|
||||
keys.sort()
|
||||
keys_str = ",".join(keys)
|
||||
return keys_str
|
||||
|
||||
|
||||
def split_state_dict_with_prefix(state_dict):
|
||||
keys = sorted([key for key in state_dict if isinstance(key, str)])
|
||||
prefix_dict = {}
|
||||
for key in keys:
|
||||
prefix = key if "." not in key else key.split(".")[0]
|
||||
if prefix not in prefix_dict:
|
||||
prefix_dict[prefix] = []
|
||||
prefix_dict[prefix].append(key)
|
||||
state_dicts = []
|
||||
for prefix, keys in prefix_dict.items():
|
||||
sub_state_dict = {key: state_dict[key] for key in keys}
|
||||
state_dicts.append(sub_state_dict)
|
||||
return state_dicts
|
||||
|
||||
|
||||
def hash_state_dict_keys(state_dict, with_shape=True):
|
||||
keys_str = convert_state_dict_keys_to_single_str(state_dict, with_shape=with_shape)
|
||||
keys_str = keys_str.encode(encoding="UTF-8")
|
||||
return hashlib.md5(keys_str).hexdigest()
|
||||
@@ -1,9 +1,13 @@
|
||||
from ..models import ModelManager, FluxDiT, FluxTextEncoder1, FluxTextEncoder2, FluxVAEDecoder, FluxVAEEncoder
|
||||
from ..controlnets import FluxMultiControlNetManager, ControlNetUnit, ControlNetConfigUnit, Annotator
|
||||
from ..prompters import FluxPrompter
|
||||
from ..schedulers import FlowMatchScheduler
|
||||
from .base import BasePipeline
|
||||
from typing import List
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
|
||||
|
||||
@@ -19,14 +23,15 @@ class FluxImagePipeline(BasePipeline):
|
||||
self.dit: FluxDiT = None
|
||||
self.vae_decoder: FluxVAEDecoder = None
|
||||
self.vae_encoder: FluxVAEEncoder = None
|
||||
self.model_names = ['text_encoder_1', 'text_encoder_2', 'dit', 'vae_decoder', 'vae_encoder']
|
||||
self.controlnet: FluxMultiControlNetManager = None
|
||||
self.model_names = ['text_encoder_1', 'text_encoder_2', 'dit', 'vae_decoder', 'vae_encoder', 'controlnet']
|
||||
|
||||
|
||||
def denoising_model(self):
|
||||
return self.dit
|
||||
|
||||
|
||||
def fetch_models(self, model_manager: ModelManager, prompt_refiner_classes=[], prompt_extender_classes=[]):
|
||||
def fetch_models(self, model_manager: ModelManager, controlnet_config_units: List[ControlNetConfigUnit]=[], prompt_refiner_classes=[], prompt_extender_classes=[]):
|
||||
self.text_encoder_1 = model_manager.fetch_model("flux_text_encoder_1")
|
||||
self.text_encoder_2 = model_manager.fetch_model("flux_text_encoder_2")
|
||||
self.dit = model_manager.fetch_model("flux_dit")
|
||||
@@ -36,14 +41,25 @@ class FluxImagePipeline(BasePipeline):
|
||||
self.prompter.load_prompt_refiners(model_manager, prompt_refiner_classes)
|
||||
self.prompter.load_prompt_extenders(model_manager, prompt_extender_classes)
|
||||
|
||||
# ControlNets
|
||||
controlnet_units = []
|
||||
for config in controlnet_config_units:
|
||||
controlnet_unit = ControlNetUnit(
|
||||
Annotator(config.processor_id, device=self.device, skip_processor=config.skip_processor),
|
||||
model_manager.fetch_model("flux_controlnet", config.model_path),
|
||||
config.scale
|
||||
)
|
||||
controlnet_units.append(controlnet_unit)
|
||||
self.controlnet = FluxMultiControlNetManager(controlnet_units)
|
||||
|
||||
|
||||
@staticmethod
|
||||
def from_model_manager(model_manager: ModelManager, prompt_refiner_classes=[], prompt_extender_classes=[], device=None):
|
||||
def from_model_manager(model_manager: ModelManager, controlnet_config_units: List[ControlNetConfigUnit]=[], prompt_refiner_classes=[], prompt_extender_classes=[], device=None):
|
||||
pipe = FluxImagePipeline(
|
||||
device=model_manager.device if device is None else device,
|
||||
torch_dtype=model_manager.torch_dtype,
|
||||
)
|
||||
pipe.fetch_models(model_manager, prompt_refiner_classes,prompt_extender_classes)
|
||||
pipe.fetch_models(model_manager, controlnet_config_units, prompt_refiner_classes, prompt_extender_classes)
|
||||
return pipe
|
||||
|
||||
|
||||
@@ -71,17 +87,61 @@ class FluxImagePipeline(BasePipeline):
|
||||
return {"image_ids": latent_image_ids, "guidance": guidance}
|
||||
|
||||
|
||||
def apply_controlnet_mask_on_latents(self, latents, mask):
|
||||
mask = (self.preprocess_image(mask) + 1) / 2
|
||||
mask = mask.mean(dim=1, keepdim=True)
|
||||
mask = mask.to(dtype=self.torch_dtype, device=self.device)
|
||||
mask = 1 - torch.nn.functional.interpolate(mask, size=latents.shape[-2:])
|
||||
latents = torch.concat([latents, mask], dim=1)
|
||||
return latents
|
||||
|
||||
|
||||
def apply_controlnet_mask_on_image(self, image, mask):
|
||||
mask = mask.resize(image.size)
|
||||
mask = self.preprocess_image(mask).mean(dim=[0, 1])
|
||||
image = np.array(image)
|
||||
image[mask > 0] = 0
|
||||
image = Image.fromarray(image)
|
||||
return image
|
||||
|
||||
|
||||
def prepare_controlnet_input(self, controlnet_image, controlnet_inpaint_mask, tiler_kwargs):
|
||||
if isinstance(controlnet_image, Image.Image):
|
||||
controlnet_image = [controlnet_image] * len(self.controlnet.processors)
|
||||
|
||||
controlnet_frames = []
|
||||
for i in range(len(self.controlnet.processors)):
|
||||
# image annotator
|
||||
image = self.controlnet.process_image(controlnet_image[i], processor_id=i)[0]
|
||||
if controlnet_inpaint_mask is not None and self.controlnet.processors[i].processor_id == "inpaint":
|
||||
image = self.apply_controlnet_mask_on_image(image, controlnet_inpaint_mask)
|
||||
|
||||
# image to tensor
|
||||
image = self.preprocess_image(image).to(device=self.device, dtype=self.torch_dtype)
|
||||
|
||||
# vae encoder
|
||||
image = self.encode_image(image, **tiler_kwargs)
|
||||
if controlnet_inpaint_mask is not None and self.controlnet.processors[i].processor_id == "inpaint":
|
||||
image = self.apply_controlnet_mask_on_latents(image, controlnet_inpaint_mask)
|
||||
|
||||
# store it
|
||||
controlnet_frames.append(image)
|
||||
return controlnet_frames
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(
|
||||
self,
|
||||
prompt,
|
||||
local_prompts= None,
|
||||
masks= None,
|
||||
mask_scales= None,
|
||||
local_prompts=None,
|
||||
masks=None,
|
||||
mask_scales=None,
|
||||
negative_prompt="",
|
||||
cfg_scale=1.0,
|
||||
embedded_guidance=3.5,
|
||||
input_image=None,
|
||||
controlnet_image=None,
|
||||
controlnet_inpaint_mask=None,
|
||||
denoising_strength=1.0,
|
||||
height=1024,
|
||||
width=1024,
|
||||
@@ -123,19 +183,29 @@ class FluxImagePipeline(BasePipeline):
|
||||
# Extra input
|
||||
extra_input = self.prepare_extra_input(latents, guidance=embedded_guidance)
|
||||
|
||||
# Prepare ControlNets
|
||||
if controlnet_image is not None:
|
||||
controlnet_kwargs = {"controlnet_frames": self.prepare_controlnet_input(controlnet_image, controlnet_inpaint_mask, tiler_kwargs)}
|
||||
else:
|
||||
controlnet_kwargs = {"controlnet_frames": None}
|
||||
|
||||
# Denoise
|
||||
self.load_models_to_device(['dit'])
|
||||
self.load_models_to_device(['dit', 'controlnet'])
|
||||
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
|
||||
timestep = timestep.unsqueeze(0).to(self.device)
|
||||
|
||||
# Classifier-free guidance
|
||||
inference_callback = lambda prompt_emb_posi: self.dit(
|
||||
latents, timestep=timestep, **prompt_emb_posi, **tiler_kwargs, **extra_input
|
||||
inference_callback = lambda prompt_emb_posi: lets_dance_flux(
|
||||
dit=self.dit, controlnet=self.controlnet,
|
||||
hidden_states=latents, timestep=timestep,
|
||||
**prompt_emb_posi, **tiler_kwargs, **extra_input, **controlnet_kwargs
|
||||
)
|
||||
noise_pred_posi = self.control_noise_via_local_prompts(prompt_emb_posi, prompt_emb_locals, masks, mask_scales, inference_callback)
|
||||
if cfg_scale != 1.0:
|
||||
noise_pred_nega = self.dit(
|
||||
latents, timestep=timestep, **prompt_emb_nega, **tiler_kwargs, **extra_input
|
||||
noise_pred_nega = lets_dance_flux(
|
||||
dit=self.dit, controlnet=self.controlnet,
|
||||
hidden_states=latents, timestep=timestep,
|
||||
**prompt_emb_nega, **tiler_kwargs, **extra_input, **controlnet_kwargs
|
||||
)
|
||||
noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
|
||||
else:
|
||||
@@ -155,3 +225,75 @@ class FluxImagePipeline(BasePipeline):
|
||||
# Offload all models
|
||||
self.load_models_to_device([])
|
||||
return image
|
||||
|
||||
|
||||
|
||||
def lets_dance_flux(
|
||||
dit: FluxDiT,
|
||||
controlnet: FluxMultiControlNetManager = None,
|
||||
hidden_states=None,
|
||||
timestep=None,
|
||||
prompt_emb=None,
|
||||
pooled_prompt_emb=None,
|
||||
guidance=None,
|
||||
text_ids=None,
|
||||
image_ids=None,
|
||||
controlnet_frames=None,
|
||||
tiled=False,
|
||||
tile_size=128,
|
||||
tile_stride=64,
|
||||
**kwargs
|
||||
):
|
||||
# ControlNet
|
||||
if controlnet is not None and controlnet_frames is not None:
|
||||
controlnet_extra_kwargs = {
|
||||
"hidden_states": hidden_states,
|
||||
"timestep": timestep,
|
||||
"prompt_emb": prompt_emb,
|
||||
"pooled_prompt_emb": pooled_prompt_emb,
|
||||
"guidance": guidance,
|
||||
"text_ids": text_ids,
|
||||
"image_ids": image_ids,
|
||||
"tiled": tiled,
|
||||
"tile_size": tile_size,
|
||||
"tile_stride": tile_stride,
|
||||
}
|
||||
controlnet_res_stack, controlnet_single_res_stack = controlnet(
|
||||
controlnet_frames, **controlnet_extra_kwargs
|
||||
)
|
||||
|
||||
if image_ids is None:
|
||||
image_ids = dit.prepare_image_ids(hidden_states)
|
||||
|
||||
conditioning = dit.time_embedder(timestep, hidden_states.dtype) + dit.pooled_text_embedder(pooled_prompt_emb)
|
||||
if dit.guidance_embedder is not None:
|
||||
guidance = guidance * 1000
|
||||
conditioning = conditioning + dit.guidance_embedder(guidance, hidden_states.dtype)
|
||||
prompt_emb = dit.context_embedder(prompt_emb)
|
||||
image_rotary_emb = dit.pos_embedder(torch.cat((text_ids, image_ids), dim=1))
|
||||
|
||||
height, width = hidden_states.shape[-2:]
|
||||
hidden_states = dit.patchify(hidden_states)
|
||||
hidden_states = dit.x_embedder(hidden_states)
|
||||
|
||||
# Joint Blocks
|
||||
for block_id, block in enumerate(dit.blocks):
|
||||
hidden_states, prompt_emb = block(hidden_states, prompt_emb, conditioning, image_rotary_emb)
|
||||
# ControlNet
|
||||
if controlnet is not None and controlnet_frames is not None:
|
||||
hidden_states = hidden_states + controlnet_res_stack[block_id]
|
||||
|
||||
# Single Blocks
|
||||
hidden_states = torch.cat([prompt_emb, hidden_states], dim=1)
|
||||
for block_id, block in enumerate(dit.single_blocks):
|
||||
hidden_states, prompt_emb = block(hidden_states, prompt_emb, conditioning, image_rotary_emb)
|
||||
# ControlNet
|
||||
if controlnet is not None and controlnet_frames is not None:
|
||||
hidden_states[:, prompt_emb.shape[1]:] = hidden_states[:, prompt_emb.shape[1]:] + controlnet_single_res_stack[block_id]
|
||||
hidden_states = hidden_states[:, prompt_emb.shape[1]:]
|
||||
|
||||
hidden_states = dit.final_norm_out(hidden_states, conditioning)
|
||||
hidden_states = dit.final_proj_out(hidden_states)
|
||||
hidden_states = dit.unpatchify(hidden_states, height, width)
|
||||
|
||||
return hidden_states
|
||||
|
||||
44
examples/image_synthesis/flux_controlnet.py
Normal file
44
examples/image_synthesis/flux_controlnet.py
Normal file
@@ -0,0 +1,44 @@
|
||||
from diffsynth.models.flux_controlnet import FluxControlNet
|
||||
from diffsynth import load_state_dict, ModelManager, FluxImagePipeline, hash_state_dict_keys, ControlNetConfigUnit
|
||||
import torch
|
||||
from PIL import Image
|
||||
import numpy as np
|
||||
|
||||
|
||||
model_manager = ModelManager(torch_dtype=torch.bfloat16, model_id_list=["FLUX.1-dev"])
|
||||
model_manager.load_models([
|
||||
"models/ControlNet/InstantX/FLUX___1-dev-Controlnet-Union-alpha/diffusion_pytorch_model.safetensors",
|
||||
"models/ControlNet/jasperai/Flux___1-dev-Controlnet-Depth/diffusion_pytorch_model.safetensors",
|
||||
"models/ControlNet/jasperai/Flux___1-dev-Controlnet-Surface-Normals/diffusion_pytorch_model.safetensors",
|
||||
"models/ControlNet/jasperai/Flux___1-dev-Controlnet-Upscaler/diffusion_pytorch_model.safetensors",
|
||||
"models/ControlNet/alimama-creative/FLUX___1-dev-Controlnet-Inpainting-Alpha/diffusion_pytorch_model.safetensors",
|
||||
"models/ControlNet/alimama-creative/FLUX___1-dev-Controlnet-Inpainting-Beta/diffusion_pytorch_model.safetensors",
|
||||
"models/ControlNet/Shakker-Labs/FLUX___1-dev-ControlNet-Depth/diffusion_pytorch_model.safetensors",
|
||||
"models/ControlNet/Shakker-Labs/FLUX___1-dev-ControlNet-Union-Pro/diffusion_pytorch_model.safetensors"
|
||||
])
|
||||
pipe = FluxImagePipeline.from_model_manager(model_manager, controlnet_config_units=[
|
||||
ControlNetConfigUnit(processor_id="canny", model_path="models/ControlNet/InstantX/FLUX___1-dev-Controlnet-Union-alpha/diffusion_pytorch_model.safetensors", scale=0.3),
|
||||
ControlNetConfigUnit(processor_id="depth", model_path="models/ControlNet/jasperai/Flux___1-dev-Controlnet-Depth/diffusion_pytorch_model.safetensors", scale=0.1),
|
||||
ControlNetConfigUnit(processor_id="normal", model_path="models/ControlNet/jasperai/Flux___1-dev-Controlnet-Surface-Normals/diffusion_pytorch_model.safetensors", scale=0.1),
|
||||
ControlNetConfigUnit(processor_id="tile", model_path="models/ControlNet/jasperai/Flux___1-dev-Controlnet-Upscaler/diffusion_pytorch_model.safetensors", scale=0.05),
|
||||
ControlNetConfigUnit(processor_id="inpaint", model_path="models/ControlNet/alimama-creative/FLUX___1-dev-Controlnet-Inpainting-Alpha/diffusion_pytorch_model.safetensors", scale=0.01),
|
||||
ControlNetConfigUnit(processor_id="inpaint", model_path="models/ControlNet/alimama-creative/FLUX___1-dev-Controlnet-Inpainting-Beta/diffusion_pytorch_model.safetensors", scale=0.01),
|
||||
ControlNetConfigUnit(processor_id="depth", model_path="models/ControlNet/Shakker-Labs/FLUX___1-dev-ControlNet-Depth/diffusion_pytorch_model.safetensors", scale=0.05),
|
||||
ControlNetConfigUnit(processor_id="canny", model_path="models/ControlNet/Shakker-Labs/FLUX___1-dev-ControlNet-Union-Pro/diffusion_pytorch_model.safetensors", scale=0.3),
|
||||
])
|
||||
|
||||
torch.manual_seed(0)
|
||||
|
||||
control_image = Image.open("controlnet_input.jpeg").resize((768, 1024))
|
||||
control_mask = Image.open("controlnet_mask.jpg").resize((768, 1024))
|
||||
|
||||
prompt = "masterpiece, best quality, a beautiful girl, CG, blue sky, long red hair, black clothes"
|
||||
negative_prompt = "oil painting, worst quality, low quality, monochrome, zombie, interlocked fingers, Aissist, cleavage, nsfw,"
|
||||
|
||||
image = pipe(
|
||||
prompt=prompt, negative_prompt=negative_prompt,
|
||||
embedded_guidance=3.5, num_inference_steps=50,
|
||||
height=1024, width=768,
|
||||
controlnet_image=control_image, controlnet_inpaint_mask=control_mask,
|
||||
)
|
||||
image.save("image.jpg")
|
||||
Reference in New Issue
Block a user