From 07d70a6a56d6668d340f219ab12afcf78c0b7cbe Mon Sep 17 00:00:00 2001 From: Artiprocher Date: Tue, 22 Oct 2024 18:52:24 +0800 Subject: [PATCH 1/6] support flux-controlnet --- diffsynth/configs/model_config.py | 5 + diffsynth/controlnets/__init__.py | 2 +- diffsynth/controlnets/controlnet_unit.py | 29 ++- diffsynth/controlnets/processors.py | 43 ++-- diffsynth/models/flux_controlnet.py | 226 ++++++++++++++++++++ diffsynth/models/model_manager.py | 45 +--- diffsynth/models/utils.py | 38 ++++ diffsynth/pipelines/flux_image.py | 166 ++++++++++++-- examples/image_synthesis/flux_controlnet.py | 44 ++++ 9 files changed, 522 insertions(+), 76 deletions(-) create mode 100644 diffsynth/models/flux_controlnet.py create mode 100644 examples/image_synthesis/flux_controlnet.py diff --git a/diffsynth/configs/model_config.py b/diffsynth/configs/model_config.py index 27223e9..963af72 100644 --- a/diffsynth/configs/model_config.py +++ b/diffsynth/configs/model_config.py @@ -35,6 +35,7 @@ from ..models.hunyuan_dit import HunyuanDiT from ..models.flux_dit import FluxDiT from ..models.flux_text_encoder import FluxTextEncoder1, FluxTextEncoder2 from ..models.flux_vae import FluxVAEEncoder, FluxVAEDecoder +from ..models.flux_controlnet import FluxControlNet from ..models.cog_vae import CogVAEEncoder, CogVAEDecoder from ..models.cog_dit import CogDiT @@ -80,6 +81,10 @@ model_loader_configs = [ (None, "280189ee084bca10f70907bf6ce1649d", ["cog_vae_encoder", "cog_vae_decoder"], [CogVAEEncoder, CogVAEDecoder], "diffusers"), (None, "9b9313d104ac4df27991352fec013fd4", ["rife"], [IFNet], "civitai"), (None, "6b7116078c4170bfbeaedc8fe71f6649", ["esrgan"], [RRDBNet], "civitai"), + (None, "78d18b9101345ff695f312e7e62538c0", ["flux_controlnet"], [FluxControlNet], "diffusers"), + (None, "b001c89139b5f053c715fe772362dd2a", ["flux_controlnet"], [FluxControlNet], "diffusers"), + (None, "52357cb26250681367488a8954c271e8", ["flux_controlnet"], [FluxControlNet], "diffusers"), + (None, "0cfd1740758423a2a854d67c136d1e8c", ["flux_controlnet"], [FluxControlNet], "diffusers"), ] huggingface_model_loader_configs = [ # These configs are provided for detecting model type automatically. diff --git a/diffsynth/controlnets/__init__.py b/diffsynth/controlnets/__init__.py index b08ba4c..a3e15ad 100644 --- a/diffsynth/controlnets/__init__.py +++ b/diffsynth/controlnets/__init__.py @@ -1,2 +1,2 @@ -from .controlnet_unit import ControlNetConfigUnit, ControlNetUnit, MultiControlNetManager +from .controlnet_unit import ControlNetConfigUnit, ControlNetUnit, MultiControlNetManager, FluxMultiControlNetManager from .processors import Annotator diff --git a/diffsynth/controlnets/controlnet_unit.py b/diffsynth/controlnets/controlnet_unit.py index f03fec5..fba09b6 100644 --- a/diffsynth/controlnets/controlnet_unit.py +++ b/diffsynth/controlnets/controlnet_unit.py @@ -4,10 +4,11 @@ from .processors import Processor_id class ControlNetConfigUnit: - def __init__(self, processor_id: Processor_id, model_path, scale=1.0): + def __init__(self, processor_id: Processor_id, model_path, scale=1.0, skip_processor=False): self.processor_id = processor_id self.model_path = model_path self.scale = scale + self.skip_processor = skip_processor class ControlNetUnit: @@ -60,3 +61,29 @@ class MultiControlNetManager: else: res_stack = [i + j for i, j in zip(res_stack, res_stack_)] return res_stack + + +class FluxMultiControlNetManager(MultiControlNetManager): + def __init__(self, controlnet_units=[]): + super().__init__(controlnet_units=controlnet_units) + + def process_image(self, image, processor_id=None): + if processor_id is None: + processed_image = [processor(image) for processor in self.processors] + else: + processed_image = [self.processors[processor_id](image)] + return processed_image + + def __call__(self, conditionings, **kwargs): + res_stack, single_res_stack = None, None + for processor, conditioning, model, scale in zip(self.processors, conditionings, self.models, self.scales): + res_stack_, single_res_stack_ = model(controlnet_conditioning=conditioning, processor_id=processor.processor_id, **kwargs) + res_stack_ = [res * scale for res in res_stack_] + single_res_stack_ = [res * scale for res in single_res_stack_] + if res_stack is None: + res_stack = res_stack_ + single_res_stack = single_res_stack_ + else: + res_stack = [i + j for i, j in zip(res_stack, res_stack_)] + single_res_stack = [i + j for i, j in zip(single_res_stack, single_res_stack_)] + return res_stack, single_res_stack diff --git a/diffsynth/controlnets/processors.py b/diffsynth/controlnets/processors.py index 1d23c73..71e47da 100644 --- a/diffsynth/controlnets/processors.py +++ b/diffsynth/controlnets/processors.py @@ -3,37 +3,42 @@ import warnings with warnings.catch_warnings(): warnings.simplefilter("ignore") from controlnet_aux.processor import ( - CannyDetector, MidasDetector, HEDdetector, LineartDetector, LineartAnimeDetector, OpenposeDetector + CannyDetector, MidasDetector, HEDdetector, LineartDetector, LineartAnimeDetector, OpenposeDetector, NormalBaeDetector ) Processor_id: TypeAlias = Literal[ - "canny", "depth", "softedge", "lineart", "lineart_anime", "openpose", "tile" + "canny", "depth", "softedge", "lineart", "lineart_anime", "openpose", "normal", "tile", "none", "inpaint" ] class Annotator: - def __init__(self, processor_id: Processor_id, model_path="models/Annotators", detect_resolution=None, device='cuda'): - if processor_id == "canny": - self.processor = CannyDetector() - elif processor_id == "depth": - self.processor = MidasDetector.from_pretrained(model_path).to(device) - elif processor_id == "softedge": - self.processor = HEDdetector.from_pretrained(model_path).to(device) - elif processor_id == "lineart": - self.processor = LineartDetector.from_pretrained(model_path).to(device) - elif processor_id == "lineart_anime": - self.processor = LineartAnimeDetector.from_pretrained(model_path).to(device) - elif processor_id == "openpose": - self.processor = OpenposeDetector.from_pretrained(model_path).to(device) - elif processor_id == "tile": - self.processor = None + def __init__(self, processor_id: Processor_id, model_path="models/Annotators", detect_resolution=None, device='cuda', skip_processor=False): + if not skip_processor: + if processor_id == "canny": + self.processor = CannyDetector() + elif processor_id == "depth": + self.processor = MidasDetector.from_pretrained(model_path).to(device) + elif processor_id == "softedge": + self.processor = HEDdetector.from_pretrained(model_path).to(device) + elif processor_id == "lineart": + self.processor = LineartDetector.from_pretrained(model_path).to(device) + elif processor_id == "lineart_anime": + self.processor = LineartAnimeDetector.from_pretrained(model_path).to(device) + elif processor_id == "openpose": + self.processor = OpenposeDetector.from_pretrained(model_path).to(device) + elif processor_id == "normal": + self.processor = NormalBaeDetector.from_pretrained(model_path).to(device) + elif processor_id == "tile" or processor_id == "none" or processor_id == "inpaint": + self.processor = None + else: + raise ValueError(f"Unsupported processor_id: {processor_id}") else: - raise ValueError(f"Unsupported processor_id: {processor_id}") + self.processor = None self.processor_id = processor_id self.detect_resolution = detect_resolution - def __call__(self, image): + def __call__(self, image, mask=None): width, height = image.size if self.processor_id == "openpose": kwargs = { diff --git a/diffsynth/models/flux_controlnet.py b/diffsynth/models/flux_controlnet.py new file mode 100644 index 0000000..d6053b1 --- /dev/null +++ b/diffsynth/models/flux_controlnet.py @@ -0,0 +1,226 @@ +import torch +from einops import rearrange, repeat +from .flux_dit import RoPEEmbedding, TimestepEmbeddings, FluxJointTransformerBlock, FluxSingleTransformerBlock +from .utils import hash_state_dict_keys + + + +class FluxControlNet(torch.nn.Module): + def __init__(self, disable_guidance_embedder=False, num_joint_blocks=5, num_single_blocks=10, num_mode=0, mode_dict={}, additional_input_dim=0): + super().__init__() + self.pos_embedder = RoPEEmbedding(3072, 10000, [16, 56, 56]) + self.time_embedder = TimestepEmbeddings(256, 3072) + self.guidance_embedder = None if disable_guidance_embedder else TimestepEmbeddings(256, 3072) + self.pooled_text_embedder = torch.nn.Sequential(torch.nn.Linear(768, 3072), torch.nn.SiLU(), torch.nn.Linear(3072, 3072)) + self.context_embedder = torch.nn.Linear(4096, 3072) + self.x_embedder = torch.nn.Linear(64, 3072) + + self.blocks = torch.nn.ModuleList([FluxJointTransformerBlock(3072, 24) for _ in range(num_joint_blocks)]) + self.single_blocks = torch.nn.ModuleList([FluxSingleTransformerBlock(3072, 24) for _ in range(num_single_blocks)]) + + self.controlnet_blocks = torch.nn.ModuleList([torch.nn.Linear(3072, 3072) for _ in range(num_joint_blocks)]) + self.controlnet_single_blocks = torch.nn.ModuleList([torch.nn.Linear(3072, 3072) for _ in range(num_single_blocks)]) + + self.mode_dict = mode_dict + self.controlnet_mode_embedder = torch.nn.Embedding(num_mode, 3072) if len(mode_dict) > 0 else None + self.controlnet_x_embedder = torch.nn.Linear(64 + additional_input_dim, 3072) + + + def prepare_image_ids(self, latents): + batch_size, _, height, width = latents.shape + latent_image_ids = torch.zeros(height // 2, width // 2, 3) + latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None] + latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :] + + latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape + + latent_image_ids = latent_image_ids[None, :].repeat(batch_size, 1, 1, 1) + latent_image_ids = latent_image_ids.reshape( + batch_size, latent_image_id_height * latent_image_id_width, latent_image_id_channels + ) + latent_image_ids = latent_image_ids.to(device=latents.device, dtype=latents.dtype) + + return latent_image_ids + + + def patchify(self, hidden_states): + hidden_states = rearrange(hidden_states, "B C (H P) (W Q) -> B (H W) (C P Q)", P=2, Q=2) + return hidden_states + + + def align_res_stack_to_original_blocks(self, res_stack, num_blocks, hidden_states): + if len(res_stack) == 0: + return [torch.zeros_like(hidden_states)] * num_blocks + interval = (num_blocks + len(res_stack) - 1) // len(res_stack) + aligned_res_stack = [res_stack[block_id // interval] for block_id in range(num_blocks)] + return aligned_res_stack + + + def forward( + self, + hidden_states, + controlnet_conditioning, + timestep, prompt_emb, pooled_prompt_emb, guidance, text_ids, image_ids=None, + processor_id=None, + tiled=False, tile_size=128, tile_stride=64, + **kwargs + ): + if image_ids is None: + image_ids = self.prepare_image_ids(hidden_states) + + conditioning = self.time_embedder(timestep, hidden_states.dtype) + self.pooled_text_embedder(pooled_prompt_emb) + if self.guidance_embedder is not None: + guidance = guidance * 1000 + conditioning = conditioning + self.guidance_embedder(guidance, hidden_states.dtype) + prompt_emb = self.context_embedder(prompt_emb) + if self.controlnet_mode_embedder is not None: # Different from FluxDiT + processor_id = torch.tensor([self.mode_dict[processor_id]], dtype=torch.int) + processor_id = repeat(processor_id, "D -> B D", B=1).to(text_ids.device) + prompt_emb = torch.concat([self.controlnet_mode_embedder(processor_id), prompt_emb], dim=1) + text_ids = torch.cat([text_ids[:, :1], text_ids], dim=1) + image_rotary_emb = self.pos_embedder(torch.cat((text_ids, image_ids), dim=1)) + + hidden_states = self.patchify(hidden_states) + hidden_states = self.x_embedder(hidden_states) + controlnet_conditioning = self.patchify(controlnet_conditioning) # Different from FluxDiT + hidden_states = hidden_states + self.controlnet_x_embedder(controlnet_conditioning) # Different from FluxDiT + + controlnet_res_stack = [] + for block, controlnet_block in zip(self.blocks, self.controlnet_blocks): + hidden_states, prompt_emb = block(hidden_states, prompt_emb, conditioning, image_rotary_emb) + controlnet_res_stack.append(controlnet_block(hidden_states)) + + controlnet_single_res_stack = [] + hidden_states = torch.cat([prompt_emb, hidden_states], dim=1) + for block, controlnet_block in zip(self.single_blocks, self.controlnet_single_blocks): + hidden_states, prompt_emb = block(hidden_states, prompt_emb, conditioning, image_rotary_emb) + controlnet_single_res_stack.append(controlnet_block(hidden_states[:, prompt_emb.shape[1]:])) + + controlnet_res_stack = self.align_res_stack_to_original_blocks(controlnet_res_stack, 19, hidden_states[:, prompt_emb.shape[1]:]) + controlnet_single_res_stack = self.align_res_stack_to_original_blocks(controlnet_single_res_stack, 38, hidden_states[:, prompt_emb.shape[1]:]) + + return controlnet_res_stack, controlnet_single_res_stack + + + @staticmethod + def state_dict_converter(): + return FluxControlNetStateDictConverter() + + + +class FluxControlNetStateDictConverter: + def __init__(self): + pass + + def from_diffusers(self, state_dict): + hash_value = hash_state_dict_keys(state_dict) + global_rename_dict = { + "context_embedder": "context_embedder", + "x_embedder": "x_embedder", + "time_text_embed.timestep_embedder.linear_1": "time_embedder.timestep_embedder.0", + "time_text_embed.timestep_embedder.linear_2": "time_embedder.timestep_embedder.2", + "time_text_embed.guidance_embedder.linear_1": "guidance_embedder.timestep_embedder.0", + "time_text_embed.guidance_embedder.linear_2": "guidance_embedder.timestep_embedder.2", + "time_text_embed.text_embedder.linear_1": "pooled_text_embedder.0", + "time_text_embed.text_embedder.linear_2": "pooled_text_embedder.2", + "norm_out.linear": "final_norm_out.linear", + "proj_out": "final_proj_out", + } + rename_dict = { + "proj_out": "proj_out", + "norm1.linear": "norm1_a.linear", + "norm1_context.linear": "norm1_b.linear", + "attn.to_q": "attn.a_to_q", + "attn.to_k": "attn.a_to_k", + "attn.to_v": "attn.a_to_v", + "attn.to_out.0": "attn.a_to_out", + "attn.add_q_proj": "attn.b_to_q", + "attn.add_k_proj": "attn.b_to_k", + "attn.add_v_proj": "attn.b_to_v", + "attn.to_add_out": "attn.b_to_out", + "ff.net.0.proj": "ff_a.0", + "ff.net.2": "ff_a.2", + "ff_context.net.0.proj": "ff_b.0", + "ff_context.net.2": "ff_b.2", + "attn.norm_q": "attn.norm_q_a", + "attn.norm_k": "attn.norm_k_a", + "attn.norm_added_q": "attn.norm_q_b", + "attn.norm_added_k": "attn.norm_k_b", + } + rename_dict_single = { + "attn.to_q": "a_to_q", + "attn.to_k": "a_to_k", + "attn.to_v": "a_to_v", + "attn.norm_q": "norm_q_a", + "attn.norm_k": "norm_k_a", + "norm.linear": "norm.linear", + "proj_mlp": "proj_in_besides_attn", + "proj_out": "proj_out", + } + state_dict_ = {} + for name, param in state_dict.items(): + if name.endswith(".weight") or name.endswith(".bias"): + suffix = ".weight" if name.endswith(".weight") else ".bias" + prefix = name[:-len(suffix)] + if prefix in global_rename_dict: + state_dict_[global_rename_dict[prefix] + suffix] = param + elif prefix.startswith("transformer_blocks."): + names = prefix.split(".") + names[0] = "blocks" + middle = ".".join(names[2:]) + if middle in rename_dict: + name_ = ".".join(names[:2] + [rename_dict[middle]] + [suffix[1:]]) + state_dict_[name_] = param + elif prefix.startswith("single_transformer_blocks."): + names = prefix.split(".") + names[0] = "single_blocks" + middle = ".".join(names[2:]) + if middle in rename_dict_single: + name_ = ".".join(names[:2] + [rename_dict_single[middle]] + [suffix[1:]]) + state_dict_[name_] = param + else: + state_dict_[name] = param + else: + state_dict_[name] = param + for name in list(state_dict_.keys()): + if ".proj_in_besides_attn." in name: + name_ = name.replace(".proj_in_besides_attn.", ".to_qkv_mlp.") + param = torch.concat([ + state_dict_[name.replace(".proj_in_besides_attn.", f".a_to_q.")], + state_dict_[name.replace(".proj_in_besides_attn.", f".a_to_k.")], + state_dict_[name.replace(".proj_in_besides_attn.", f".a_to_v.")], + state_dict_[name], + ], dim=0) + state_dict_[name_] = param + state_dict_.pop(name.replace(".proj_in_besides_attn.", f".a_to_q.")) + state_dict_.pop(name.replace(".proj_in_besides_attn.", f".a_to_k.")) + state_dict_.pop(name.replace(".proj_in_besides_attn.", f".a_to_v.")) + state_dict_.pop(name) + for name in list(state_dict_.keys()): + for component in ["a", "b"]: + if f".{component}_to_q." in name: + name_ = name.replace(f".{component}_to_q.", f".{component}_to_qkv.") + param = torch.concat([ + state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_q.")], + state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_k.")], + state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_v.")], + ], dim=0) + state_dict_[name_] = param + state_dict_.pop(name.replace(f".{component}_to_q.", f".{component}_to_q.")) + state_dict_.pop(name.replace(f".{component}_to_q.", f".{component}_to_k.")) + state_dict_.pop(name.replace(f".{component}_to_q.", f".{component}_to_v.")) + if hash_value == "78d18b9101345ff695f312e7e62538c0": + extra_kwargs = {"num_mode": 10, "mode_dict": {"canny": 0, "tile": 1, "depth": 2, "blur": 3, "pose": 4, "gray": 5, "lq": 6}} + elif hash_value == "b001c89139b5f053c715fe772362dd2a": + extra_kwargs = {"num_single_blocks": 0} + elif hash_value == "52357cb26250681367488a8954c271e8": + extra_kwargs = {"num_joint_blocks": 6, "num_single_blocks": 0, "additional_input_dim": 4} + elif hash_value == "0cfd1740758423a2a854d67c136d1e8c": + extra_kwargs = {"num_joint_blocks": 4, "num_single_blocks": 1} + else: + extra_kwargs = {} + return state_dict_, extra_kwargs + + + def from_civitai(self, state_dict): + return self.from_diffusers(state_dict) diff --git a/diffsynth/models/model_manager.py b/diffsynth/models/model_manager.py index 9a5c4f1..c68bdde 100644 --- a/diffsynth/models/model_manager.py +++ b/diffsynth/models/model_manager.py @@ -1,7 +1,4 @@ -import os, torch, hashlib, json, importlib -from safetensors import safe_open -from torch import Tensor -from typing_extensions import Literal, TypeAlias +import os, torch, json, importlib from typing import List from .downloader import download_models, download_customized_models, Preset_model_id, Preset_model_website @@ -50,45 +47,7 @@ from ..extensions.RIFE import IFNet from ..extensions.ESRGAN import RRDBNet from ..configs.model_config import model_loader_configs, huggingface_model_loader_configs, patch_model_loader_configs -from .utils import load_state_dict, init_weights_on_device - - - -def convert_state_dict_keys_to_single_str(state_dict, with_shape=True): - keys = [] - for key, value in state_dict.items(): - if isinstance(key, str): - if isinstance(value, Tensor): - if with_shape: - shape = "_".join(map(str, list(value.shape))) - keys.append(key + ":" + shape) - keys.append(key) - elif isinstance(value, dict): - keys.append(key + "|" + convert_state_dict_keys_to_single_str(value, with_shape=with_shape)) - keys.sort() - keys_str = ",".join(keys) - return keys_str - - -def split_state_dict_with_prefix(state_dict): - keys = sorted([key for key in state_dict if isinstance(key, str)]) - prefix_dict = {} - for key in keys: - prefix = key if "." not in key else key.split(".")[0] - if prefix not in prefix_dict: - prefix_dict[prefix] = [] - prefix_dict[prefix].append(key) - state_dicts = [] - for prefix, keys in prefix_dict.items(): - sub_state_dict = {key: state_dict[key] for key in keys} - state_dicts.append(sub_state_dict) - return state_dicts - - -def hash_state_dict_keys(state_dict, with_shape=True): - keys_str = convert_state_dict_keys_to_single_str(state_dict, with_shape=with_shape) - keys_str = keys_str.encode(encoding="UTF-8") - return hashlib.md5(keys_str).hexdigest() +from .utils import load_state_dict, init_weights_on_device, hash_state_dict_keys, split_state_dict_with_prefix def load_model_from_single_file(state_dict, model_names, model_classes, model_resource, torch_dtype, device): diff --git a/diffsynth/models/utils.py b/diffsynth/models/utils.py index bd579e4..e18e2dd 100644 --- a/diffsynth/models/utils.py +++ b/diffsynth/models/utils.py @@ -1,6 +1,7 @@ import torch, os from safetensors import safe_open from contextlib import contextmanager +import hashlib @contextmanager def init_weights_on_device(device = torch.device("meta"), include_buffers :bool = False): @@ -142,3 +143,40 @@ def search_for_files(folder, extensions): files.append(folder) break return files + + +def convert_state_dict_keys_to_single_str(state_dict, with_shape=True): + keys = [] + for key, value in state_dict.items(): + if isinstance(key, str): + if isinstance(value, torch.Tensor): + if with_shape: + shape = "_".join(map(str, list(value.shape))) + keys.append(key + ":" + shape) + keys.append(key) + elif isinstance(value, dict): + keys.append(key + "|" + convert_state_dict_keys_to_single_str(value, with_shape=with_shape)) + keys.sort() + keys_str = ",".join(keys) + return keys_str + + +def split_state_dict_with_prefix(state_dict): + keys = sorted([key for key in state_dict if isinstance(key, str)]) + prefix_dict = {} + for key in keys: + prefix = key if "." not in key else key.split(".")[0] + if prefix not in prefix_dict: + prefix_dict[prefix] = [] + prefix_dict[prefix].append(key) + state_dicts = [] + for prefix, keys in prefix_dict.items(): + sub_state_dict = {key: state_dict[key] for key in keys} + state_dicts.append(sub_state_dict) + return state_dicts + + +def hash_state_dict_keys(state_dict, with_shape=True): + keys_str = convert_state_dict_keys_to_single_str(state_dict, with_shape=with_shape) + keys_str = keys_str.encode(encoding="UTF-8") + return hashlib.md5(keys_str).hexdigest() \ No newline at end of file diff --git a/diffsynth/pipelines/flux_image.py b/diffsynth/pipelines/flux_image.py index 06f5649..176651e 100644 --- a/diffsynth/pipelines/flux_image.py +++ b/diffsynth/pipelines/flux_image.py @@ -1,9 +1,13 @@ from ..models import ModelManager, FluxDiT, FluxTextEncoder1, FluxTextEncoder2, FluxVAEDecoder, FluxVAEEncoder +from ..controlnets import FluxMultiControlNetManager, ControlNetUnit, ControlNetConfigUnit, Annotator from ..prompters import FluxPrompter from ..schedulers import FlowMatchScheduler from .base import BasePipeline +from typing import List import torch from tqdm import tqdm +import numpy as np +from PIL import Image @@ -19,14 +23,15 @@ class FluxImagePipeline(BasePipeline): self.dit: FluxDiT = None self.vae_decoder: FluxVAEDecoder = None self.vae_encoder: FluxVAEEncoder = None - self.model_names = ['text_encoder_1', 'text_encoder_2', 'dit', 'vae_decoder', 'vae_encoder'] + self.controlnet: FluxMultiControlNetManager = None + self.model_names = ['text_encoder_1', 'text_encoder_2', 'dit', 'vae_decoder', 'vae_encoder', 'controlnet'] def denoising_model(self): return self.dit - def fetch_models(self, model_manager: ModelManager, prompt_refiner_classes=[], prompt_extender_classes=[]): + def fetch_models(self, model_manager: ModelManager, controlnet_config_units: List[ControlNetConfigUnit]=[], prompt_refiner_classes=[], prompt_extender_classes=[]): self.text_encoder_1 = model_manager.fetch_model("flux_text_encoder_1") self.text_encoder_2 = model_manager.fetch_model("flux_text_encoder_2") self.dit = model_manager.fetch_model("flux_dit") @@ -36,14 +41,25 @@ class FluxImagePipeline(BasePipeline): self.prompter.load_prompt_refiners(model_manager, prompt_refiner_classes) self.prompter.load_prompt_extenders(model_manager, prompt_extender_classes) + # ControlNets + controlnet_units = [] + for config in controlnet_config_units: + controlnet_unit = ControlNetUnit( + Annotator(config.processor_id, device=self.device, skip_processor=config.skip_processor), + model_manager.fetch_model("flux_controlnet", config.model_path), + config.scale + ) + controlnet_units.append(controlnet_unit) + self.controlnet = FluxMultiControlNetManager(controlnet_units) + @staticmethod - def from_model_manager(model_manager: ModelManager, prompt_refiner_classes=[], prompt_extender_classes=[], device=None): + def from_model_manager(model_manager: ModelManager, controlnet_config_units: List[ControlNetConfigUnit]=[], prompt_refiner_classes=[], prompt_extender_classes=[], device=None): pipe = FluxImagePipeline( device=model_manager.device if device is None else device, torch_dtype=model_manager.torch_dtype, ) - pipe.fetch_models(model_manager, prompt_refiner_classes,prompt_extender_classes) + pipe.fetch_models(model_manager, controlnet_config_units, prompt_refiner_classes, prompt_extender_classes) return pipe @@ -71,17 +87,61 @@ class FluxImagePipeline(BasePipeline): return {"image_ids": latent_image_ids, "guidance": guidance} + def apply_controlnet_mask_on_latents(self, latents, mask): + mask = (self.preprocess_image(mask) + 1) / 2 + mask = mask.mean(dim=1, keepdim=True) + mask = mask.to(dtype=self.torch_dtype, device=self.device) + mask = 1 - torch.nn.functional.interpolate(mask, size=latents.shape[-2:]) + latents = torch.concat([latents, mask], dim=1) + return latents + + + def apply_controlnet_mask_on_image(self, image, mask): + mask = mask.resize(image.size) + mask = self.preprocess_image(mask).mean(dim=[0, 1]) + image = np.array(image) + image[mask > 0] = 0 + image = Image.fromarray(image) + return image + + + def prepare_controlnet_input(self, controlnet_image, controlnet_inpaint_mask, tiler_kwargs): + if isinstance(controlnet_image, Image.Image): + controlnet_image = [controlnet_image] * len(self.controlnet.processors) + + controlnet_frames = [] + for i in range(len(self.controlnet.processors)): + # image annotator + image = self.controlnet.process_image(controlnet_image[i], processor_id=i)[0] + if controlnet_inpaint_mask is not None and self.controlnet.processors[i].processor_id == "inpaint": + image = self.apply_controlnet_mask_on_image(image, controlnet_inpaint_mask) + + # image to tensor + image = self.preprocess_image(image).to(device=self.device, dtype=self.torch_dtype) + + # vae encoder + image = self.encode_image(image, **tiler_kwargs) + if controlnet_inpaint_mask is not None and self.controlnet.processors[i].processor_id == "inpaint": + image = self.apply_controlnet_mask_on_latents(image, controlnet_inpaint_mask) + + # store it + controlnet_frames.append(image) + return controlnet_frames + + @torch.no_grad() def __call__( self, prompt, - local_prompts= None, - masks= None, - mask_scales= None, + local_prompts=None, + masks=None, + mask_scales=None, negative_prompt="", cfg_scale=1.0, embedded_guidance=3.5, input_image=None, + controlnet_image=None, + controlnet_inpaint_mask=None, denoising_strength=1.0, height=1024, width=1024, @@ -123,19 +183,29 @@ class FluxImagePipeline(BasePipeline): # Extra input extra_input = self.prepare_extra_input(latents, guidance=embedded_guidance) + # Prepare ControlNets + if controlnet_image is not None: + controlnet_kwargs = {"controlnet_frames": self.prepare_controlnet_input(controlnet_image, controlnet_inpaint_mask, tiler_kwargs)} + else: + controlnet_kwargs = {"controlnet_frames": None} + # Denoise - self.load_models_to_device(['dit']) + self.load_models_to_device(['dit', 'controlnet']) for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)): timestep = timestep.unsqueeze(0).to(self.device) # Classifier-free guidance - inference_callback = lambda prompt_emb_posi: self.dit( - latents, timestep=timestep, **prompt_emb_posi, **tiler_kwargs, **extra_input + inference_callback = lambda prompt_emb_posi: lets_dance_flux( + dit=self.dit, controlnet=self.controlnet, + hidden_states=latents, timestep=timestep, + **prompt_emb_posi, **tiler_kwargs, **extra_input, **controlnet_kwargs ) noise_pred_posi = self.control_noise_via_local_prompts(prompt_emb_posi, prompt_emb_locals, masks, mask_scales, inference_callback) if cfg_scale != 1.0: - noise_pred_nega = self.dit( - latents, timestep=timestep, **prompt_emb_nega, **tiler_kwargs, **extra_input + noise_pred_nega = lets_dance_flux( + dit=self.dit, controlnet=self.controlnet, + hidden_states=latents, timestep=timestep, + **prompt_emb_nega, **tiler_kwargs, **extra_input, **controlnet_kwargs ) noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega) else: @@ -155,3 +225,75 @@ class FluxImagePipeline(BasePipeline): # Offload all models self.load_models_to_device([]) return image + + + +def lets_dance_flux( + dit: FluxDiT, + controlnet: FluxMultiControlNetManager = None, + hidden_states=None, + timestep=None, + prompt_emb=None, + pooled_prompt_emb=None, + guidance=None, + text_ids=None, + image_ids=None, + controlnet_frames=None, + tiled=False, + tile_size=128, + tile_stride=64, + **kwargs +): + # ControlNet + if controlnet is not None and controlnet_frames is not None: + controlnet_extra_kwargs = { + "hidden_states": hidden_states, + "timestep": timestep, + "prompt_emb": prompt_emb, + "pooled_prompt_emb": pooled_prompt_emb, + "guidance": guidance, + "text_ids": text_ids, + "image_ids": image_ids, + "tiled": tiled, + "tile_size": tile_size, + "tile_stride": tile_stride, + } + controlnet_res_stack, controlnet_single_res_stack = controlnet( + controlnet_frames, **controlnet_extra_kwargs + ) + + if image_ids is None: + image_ids = dit.prepare_image_ids(hidden_states) + + conditioning = dit.time_embedder(timestep, hidden_states.dtype) + dit.pooled_text_embedder(pooled_prompt_emb) + if dit.guidance_embedder is not None: + guidance = guidance * 1000 + conditioning = conditioning + dit.guidance_embedder(guidance, hidden_states.dtype) + prompt_emb = dit.context_embedder(prompt_emb) + image_rotary_emb = dit.pos_embedder(torch.cat((text_ids, image_ids), dim=1)) + + height, width = hidden_states.shape[-2:] + hidden_states = dit.patchify(hidden_states) + hidden_states = dit.x_embedder(hidden_states) + + # Joint Blocks + for block_id, block in enumerate(dit.blocks): + hidden_states, prompt_emb = block(hidden_states, prompt_emb, conditioning, image_rotary_emb) + # ControlNet + if controlnet is not None and controlnet_frames is not None: + hidden_states = hidden_states + controlnet_res_stack[block_id] + + # Single Blocks + hidden_states = torch.cat([prompt_emb, hidden_states], dim=1) + for block_id, block in enumerate(dit.single_blocks): + hidden_states, prompt_emb = block(hidden_states, prompt_emb, conditioning, image_rotary_emb) + # ControlNet + if controlnet is not None and controlnet_frames is not None: + hidden_states[:, prompt_emb.shape[1]:] = hidden_states[:, prompt_emb.shape[1]:] + controlnet_single_res_stack[block_id] + hidden_states = hidden_states[:, prompt_emb.shape[1]:] + + hidden_states = dit.final_norm_out(hidden_states, conditioning) + hidden_states = dit.final_proj_out(hidden_states) + hidden_states = dit.unpatchify(hidden_states, height, width) + + return hidden_states diff --git a/examples/image_synthesis/flux_controlnet.py b/examples/image_synthesis/flux_controlnet.py new file mode 100644 index 0000000..80be320 --- /dev/null +++ b/examples/image_synthesis/flux_controlnet.py @@ -0,0 +1,44 @@ +from diffsynth.models.flux_controlnet import FluxControlNet +from diffsynth import load_state_dict, ModelManager, FluxImagePipeline, hash_state_dict_keys, ControlNetConfigUnit +import torch +from PIL import Image +import numpy as np + + +model_manager = ModelManager(torch_dtype=torch.bfloat16, model_id_list=["FLUX.1-dev"]) +model_manager.load_models([ + "models/ControlNet/InstantX/FLUX___1-dev-Controlnet-Union-alpha/diffusion_pytorch_model.safetensors", + "models/ControlNet/jasperai/Flux___1-dev-Controlnet-Depth/diffusion_pytorch_model.safetensors", + "models/ControlNet/jasperai/Flux___1-dev-Controlnet-Surface-Normals/diffusion_pytorch_model.safetensors", + "models/ControlNet/jasperai/Flux___1-dev-Controlnet-Upscaler/diffusion_pytorch_model.safetensors", + "models/ControlNet/alimama-creative/FLUX___1-dev-Controlnet-Inpainting-Alpha/diffusion_pytorch_model.safetensors", + "models/ControlNet/alimama-creative/FLUX___1-dev-Controlnet-Inpainting-Beta/diffusion_pytorch_model.safetensors", + "models/ControlNet/Shakker-Labs/FLUX___1-dev-ControlNet-Depth/diffusion_pytorch_model.safetensors", + "models/ControlNet/Shakker-Labs/FLUX___1-dev-ControlNet-Union-Pro/diffusion_pytorch_model.safetensors" +]) +pipe = FluxImagePipeline.from_model_manager(model_manager, controlnet_config_units=[ + ControlNetConfigUnit(processor_id="canny", model_path="models/ControlNet/InstantX/FLUX___1-dev-Controlnet-Union-alpha/diffusion_pytorch_model.safetensors", scale=0.3), + ControlNetConfigUnit(processor_id="depth", model_path="models/ControlNet/jasperai/Flux___1-dev-Controlnet-Depth/diffusion_pytorch_model.safetensors", scale=0.1), + ControlNetConfigUnit(processor_id="normal", model_path="models/ControlNet/jasperai/Flux___1-dev-Controlnet-Surface-Normals/diffusion_pytorch_model.safetensors", scale=0.1), + ControlNetConfigUnit(processor_id="tile", model_path="models/ControlNet/jasperai/Flux___1-dev-Controlnet-Upscaler/diffusion_pytorch_model.safetensors", scale=0.05), + ControlNetConfigUnit(processor_id="inpaint", model_path="models/ControlNet/alimama-creative/FLUX___1-dev-Controlnet-Inpainting-Alpha/diffusion_pytorch_model.safetensors", scale=0.01), + ControlNetConfigUnit(processor_id="inpaint", model_path="models/ControlNet/alimama-creative/FLUX___1-dev-Controlnet-Inpainting-Beta/diffusion_pytorch_model.safetensors", scale=0.01), + ControlNetConfigUnit(processor_id="depth", model_path="models/ControlNet/Shakker-Labs/FLUX___1-dev-ControlNet-Depth/diffusion_pytorch_model.safetensors", scale=0.05), + ControlNetConfigUnit(processor_id="canny", model_path="models/ControlNet/Shakker-Labs/FLUX___1-dev-ControlNet-Union-Pro/diffusion_pytorch_model.safetensors", scale=0.3), +]) + +torch.manual_seed(0) + +control_image = Image.open("controlnet_input.jpeg").resize((768, 1024)) +control_mask = Image.open("controlnet_mask.jpg").resize((768, 1024)) + +prompt = "masterpiece, best quality, a beautiful girl, CG, blue sky, long red hair, black clothes" +negative_prompt = "oil painting, worst quality, low quality, monochrome, zombie, interlocked fingers, Aissist, cleavage, nsfw," + +image = pipe( + prompt=prompt, negative_prompt=negative_prompt, + embedded_guidance=3.5, num_inference_steps=50, + height=1024, width=768, + controlnet_image=control_image, controlnet_inpaint_mask=control_mask, +) +image.save("image.jpg") From aa054db1c799269916fc964628f6d1070879738d Mon Sep 17 00:00:00 2001 From: Artiprocher Date: Wed, 23 Oct 2024 14:24:41 +0800 Subject: [PATCH 2/6] bug fix --- diffsynth/pipelines/dancer.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/diffsynth/pipelines/dancer.py b/diffsynth/pipelines/dancer.py index b7746d3..593b57c 100644 --- a/diffsynth/pipelines/dancer.py +++ b/diffsynth/pipelines/dancer.py @@ -139,6 +139,8 @@ def lets_dance_xl( # 0. Text embedding alignment (only for video processing) if encoder_hidden_states.shape[0] != sample.shape[0]: encoder_hidden_states = encoder_hidden_states.repeat(sample.shape[0], 1, 1, 1) + if add_text_embeds.shape[0] != sample.shape[0]: + add_text_embeds = add_text_embeds.repeat(sample.shape[0], 1) # 1. ControlNet controlnet_insert_block_id = 22 @@ -204,7 +206,7 @@ def lets_dance_xl( batch_id_ = min(batch_id + unet_batch_size, sample.shape[0]) hidden_states, _, _, _ = block( hidden_states_input[batch_id: batch_id_], - time_emb, + time_emb[batch_id: batch_id_], text_emb[batch_id: batch_id_], res_stack, cross_frame_attention=cross_frame_attention, From 105fe3961ca772eb9344a23253198eaf6f811f94 Mon Sep 17 00:00:00 2001 From: Artiprocher Date: Thu, 24 Oct 2024 15:42:46 +0800 Subject: [PATCH 3/6] update examples --- diffsynth/configs/model_config.py | 56 +++- diffsynth/models/tiler.py | 54 ++++ diffsynth/pipelines/base.py | 9 +- diffsynth/pipelines/flux_image.py | 45 ++- examples/ControlNet/flux_controlnet.py | 299 ++++++++++++++++++++ examples/image_synthesis/flux_controlnet.py | 44 --- 6 files changed, 455 insertions(+), 52 deletions(-) create mode 100644 examples/ControlNet/flux_controlnet.py delete mode 100644 examples/image_synthesis/flux_controlnet.py diff --git a/diffsynth/configs/model_config.py b/diffsynth/configs/model_config.py index 963af72..55ab380 100644 --- a/diffsynth/configs/model_config.py +++ b/diffsynth/configs/model_config.py @@ -204,7 +204,6 @@ preset_models_on_huggingface = { ("lllyasviel/omost-llama-3-8b-4bits", "model.safetensors.index.json", "models/OmostPrompt/omost-llama-3-8b-4bits"), ("lllyasviel/omost-llama-3-8b-4bits", "special_tokens_map.json", "models/OmostPrompt/omost-llama-3-8b-4bits"), ], - # Translator "opus-mt-zh-en": [ ("Helsinki-NLP/opus-mt-zh-en", "config.json", "models/translator/opus-mt-zh-en"), @@ -346,6 +345,24 @@ preset_models_on_modelscope = { ("AI-ModelScope/controlnet-union-sdxl-1.0", "diffusion_pytorch_model_promax.safetensors", "models/ControlNet/controlnet_union"), ("sd_lora/Annotators", "dpt_hybrid-midas-501f0c75.pt", "models/Annotators") ], + "Annotators:Depth": [ + ("lllyasviel/Annotators", "dpt_hybrid-midas-501f0c75.pt", "models/Annotators"), + ], + "Annotators:Softedge": [ + ("lllyasviel/Annotators", "ControlNetHED.pth", "models/Annotators"), + ], + "Annotators:Lineart": [ + ("lllyasviel/Annotators", "sk_model.pth", "models/Annotators"), + ("lllyasviel/Annotators", "sk_model2.pth", "models/Annotators"), + ], + "Annotators:Normal": [ + ("lllyasviel/Annotators", "scannet.pt", "models/Annotators"), + ], + "Annotators:Openpose": [ + ("lllyasviel/Annotators", "body_pose_model.pth", "models/Annotators"), + ("lllyasviel/Annotators", "facenet.pth", "models/Annotators"), + ("lllyasviel/Annotators", "hand_pose_model.pth", "models/Annotators"), + ], # AnimateDiff "AnimateDiff_v2": [ ("Shanghai_AI_Laboratory/animatediff", "mm_sd_v15_v2.ckpt", "models/AnimateDiff"), @@ -487,6 +504,30 @@ preset_models_on_modelscope = { "models/FLUX/FLUX.1-schnell/flux1-schnell.safetensors" ], }, + "InstantX/FLUX.1-dev-Controlnet-Union-alpha": [ + ("InstantX/FLUX.1-dev-Controlnet-Union-alpha", "diffusion_pytorch_model.safetensors", "models/ControlNet/InstantX/FLUX.1-dev-Controlnet-Union-alpha"), + ], + "jasperai/Flux.1-dev-Controlnet-Depth": [ + ("jasperai/Flux.1-dev-Controlnet-Depth", "diffusion_pytorch_model.safetensors", "models/ControlNet/jasperai/Flux.1-dev-Controlnet-Depth"), + ], + "jasperai/Flux.1-dev-Controlnet-Surface-Normals": [ + ("jasperai/Flux.1-dev-Controlnet-Surface-Normals", "diffusion_pytorch_model.safetensors", "models/ControlNet/jasperai/Flux.1-dev-Controlnet-Surface-Normals"), + ], + "jasperai/Flux.1-dev-Controlnet-Upscaler": [ + ("jasperai/Flux.1-dev-Controlnet-Upscaler", "diffusion_pytorch_model.safetensors", "models/ControlNet/jasperai/Flux.1-dev-Controlnet-Upscaler"), + ], + "alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Alpha": [ + ("alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Alpha", "diffusion_pytorch_model.safetensors", "models/ControlNet/alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Alpha"), + ], + "alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta": [ + ("alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta", "diffusion_pytorch_model.safetensors", "models/ControlNet/alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta"), + ], + "Shakker-Labs/FLUX.1-dev-ControlNet-Depth": [ + ("Shakker-Labs/FLUX.1-dev-ControlNet-Depth", "diffusion_pytorch_model.safetensors", "models/ControlNet/Shakker-Labs/FLUX.1-dev-ControlNet-Depth"), + ], + "Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro": [ + ("Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro", "diffusion_pytorch_model.safetensors", "models/ControlNet/Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro"), + ], # ESRGAN "ESRGAN_x4": [ ("AI-ModelScope/Real-ESRGAN", "RealESRGAN_x4.pth", "models/ESRGAN"), @@ -546,10 +587,23 @@ Preset_model_id: TypeAlias = Literal[ "ControlNet_union_sdxl_promax", "FLUX.1-dev", "FLUX.1-schnell", + "InstantX/FLUX.1-dev-Controlnet-Union-alpha", + "jasperai/Flux.1-dev-Controlnet-Depth", + "jasperai/Flux.1-dev-Controlnet-Surface-Normals", + "jasperai/Flux.1-dev-Controlnet-Upscaler", + "alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Alpha", + "alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta", + "Shakker-Labs/FLUX.1-dev-ControlNet-Depth", + "Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro", "SDXL_lora_zyd232_ChineseInkStyle_SDXL_v1_0", "QwenPrompt", "OmostPrompt", "ESRGAN_x4", "RIFE", "CogVideoX-5B", + "Annotators:Depth", + "Annotators:Softedge", + "Annotators:Lineart", + "Annotators:Normal", + "Annotators:Openpose", ] diff --git a/diffsynth/models/tiler.py b/diffsynth/models/tiler.py index 6f36cdf..77c443b 100644 --- a/diffsynth/models/tiler.py +++ b/diffsynth/models/tiler.py @@ -107,6 +107,60 @@ class TileWorker: +class FastTileWorker: + def __init__(self): + pass + + + def build_mask(self, data, is_bound): + _, _, H, W = data.shape + h = repeat(torch.arange(H), "H -> H W", H=H, W=W) + w = repeat(torch.arange(W), "W -> H W", H=H, W=W) + border_width = (H + W) // 4 + pad = torch.ones_like(h) * border_width + mask = torch.stack([ + pad if is_bound[0] else h + 1, + pad if is_bound[1] else H - h, + pad if is_bound[2] else w + 1, + pad if is_bound[3] else W - w + ]).min(dim=0).values + mask = mask.clip(1, border_width) + mask = (mask / border_width).to(dtype=data.dtype, device=data.device) + mask = rearrange(mask, "H W -> 1 H W") + return mask + + + def tiled_forward(self, forward_fn, model_input, tile_size, tile_stride, tile_device="cpu", tile_dtype=torch.float32, border_width=None): + # Prepare + B, C, H, W = model_input.shape + border_width = int(tile_stride*0.5) if border_width is None else border_width + weight = torch.zeros((1, 1, H, W), dtype=tile_dtype, device=tile_device) + values = torch.zeros((B, C, H, W), dtype=tile_dtype, device=tile_device) + + # Split tasks + tasks = [] + for h in range(0, H, tile_stride): + for w in range(0, W, tile_stride): + if (h-tile_stride >= 0 and h-tile_stride+tile_size >= H) or (w-tile_stride >= 0 and w-tile_stride+tile_size >= W): + continue + h_, w_ = h + tile_size, w + tile_size + if h_ > H: h, h_ = H - tile_size, H + if w_ > W: w, w_ = W - tile_size, W + tasks.append((h, h_, w, w_)) + + # Run + for hl, hr, wl, wr in tasks: + # Forward + hidden_states_batch = forward_fn(hl, hr, wl, wr).to(dtype=tile_dtype, device=tile_device) + + mask = self.build_mask(hidden_states_batch, is_bound=(hl==0, hr>=H, wl==0, wr>=W)) + values[:, :, hl:hr, wl:wr] += hidden_states_batch * mask + weight[:, :, hl:hr, wl:wr] += mask + values /= weight + return values + + + class TileWorker2Dto3D: """ Process 3D tensors, but only enable TileWorker on 2D. diff --git a/diffsynth/pipelines/base.py b/diffsynth/pipelines/base.py index 55cfc14..f8f7178 100644 --- a/diffsynth/pipelines/base.py +++ b/diffsynth/pipelines/base.py @@ -47,9 +47,12 @@ class BasePipeline(torch.nn.Module): return value - def control_noise_via_local_prompts(self, prompt_emb_global, prompt_emb_locals, masks, mask_scales, inference_callback): - noise_pred_global = inference_callback(prompt_emb_global) - noise_pred_locals = [inference_callback(prompt_emb_local) for prompt_emb_local in prompt_emb_locals] + def control_noise_via_local_prompts(self, prompt_emb_global, prompt_emb_locals, masks, mask_scales, inference_callback, special_kwargs={}, special_local_kwargs_list=None): + noise_pred_global = inference_callback(prompt_emb_global, special_kwargs) + if special_local_kwargs_list is None: + noise_pred_locals = [inference_callback(prompt_emb_local) for prompt_emb_local in prompt_emb_locals] + else: + noise_pred_locals = [inference_callback(prompt_emb_local, special_kwargs) for prompt_emb_local, special_kwargs in zip(prompt_emb_locals, special_local_kwargs_list)] noise_pred = self.merge_latents(noise_pred_global, noise_pred_locals, masks, mask_scales) return noise_pred diff --git a/diffsynth/pipelines/flux_image.py b/diffsynth/pipelines/flux_image.py index 176651e..89d730f 100644 --- a/diffsynth/pipelines/flux_image.py +++ b/diffsynth/pipelines/flux_image.py @@ -8,6 +8,7 @@ import torch from tqdm import tqdm import numpy as np from PIL import Image +from ..models.tiler import FastTileWorker @@ -142,6 +143,7 @@ class FluxImagePipeline(BasePipeline): input_image=None, controlnet_image=None, controlnet_inpaint_mask=None, + enable_controlnet_on_negative=False, denoising_strength=1.0, height=1024, width=1024, @@ -186,8 +188,13 @@ class FluxImagePipeline(BasePipeline): # Prepare ControlNets if controlnet_image is not None: controlnet_kwargs = {"controlnet_frames": self.prepare_controlnet_input(controlnet_image, controlnet_inpaint_mask, tiler_kwargs)} + if len(masks) > 0 and controlnet_inpaint_mask is not None: + print("The controlnet_inpaint_mask will be overridden by masks.") + local_controlnet_kwargs = [{"controlnet_frames": self.prepare_controlnet_input(controlnet_image, mask, tiler_kwargs)} for mask in masks] + else: + local_controlnet_kwargs = None else: - controlnet_kwargs = {"controlnet_frames": None} + controlnet_kwargs, local_controlnet_kwargs = {"controlnet_frames": None}, [{}] * len(masks) # Denoise self.load_models_to_device(['dit', 'controlnet']) @@ -195,17 +202,21 @@ class FluxImagePipeline(BasePipeline): timestep = timestep.unsqueeze(0).to(self.device) # Classifier-free guidance - inference_callback = lambda prompt_emb_posi: lets_dance_flux( + inference_callback = lambda prompt_emb_posi, controlnet_kwargs: lets_dance_flux( dit=self.dit, controlnet=self.controlnet, hidden_states=latents, timestep=timestep, **prompt_emb_posi, **tiler_kwargs, **extra_input, **controlnet_kwargs ) - noise_pred_posi = self.control_noise_via_local_prompts(prompt_emb_posi, prompt_emb_locals, masks, mask_scales, inference_callback) + noise_pred_posi = self.control_noise_via_local_prompts( + prompt_emb_posi, prompt_emb_locals, masks, mask_scales, inference_callback, + special_kwargs=controlnet_kwargs, special_local_kwargs_list=local_controlnet_kwargs + ) if cfg_scale != 1.0: + negative_controlnet_kwargs = controlnet_kwargs if enable_controlnet_on_negative else {} noise_pred_nega = lets_dance_flux( dit=self.dit, controlnet=self.controlnet, hidden_states=latents, timestep=timestep, - **prompt_emb_nega, **tiler_kwargs, **extra_input, **controlnet_kwargs + **prompt_emb_nega, **tiler_kwargs, **extra_input, **negative_controlnet_kwargs, ) noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega) else: @@ -244,6 +255,32 @@ def lets_dance_flux( tile_stride=64, **kwargs ): + if tiled: + def flux_forward_fn(hl, hr, wl, wr): + return lets_dance_flux( + dit=dit, + controlnet=controlnet, + hidden_states=hidden_states[:, :, hl: hr, wl: wr], + timestep=timestep, + prompt_emb=prompt_emb, + pooled_prompt_emb=pooled_prompt_emb, + guidance=guidance, + text_ids=text_ids, + image_ids=None, + controlnet_frames=[f[:, :, hl: hr, wl: wr] for f in controlnet_frames], + tiled=False, + **kwargs + ) + return FastTileWorker().tiled_forward( + flux_forward_fn, + hidden_states, + tile_size=tile_size, + tile_stride=tile_stride, + tile_device=hidden_states.device, + tile_dtype=hidden_states.dtype + ) + + # ControlNet if controlnet is not None and controlnet_frames is not None: controlnet_extra_kwargs = { diff --git a/examples/ControlNet/flux_controlnet.py b/examples/ControlNet/flux_controlnet.py new file mode 100644 index 0000000..6fc7526 --- /dev/null +++ b/examples/ControlNet/flux_controlnet.py @@ -0,0 +1,299 @@ +from diffsynth import ModelManager, FluxImagePipeline, ControlNetConfigUnit, download_models, download_customized_models +import torch +from PIL import Image +import numpy as np + + + +def example_1(): + model_manager = ModelManager(torch_dtype=torch.bfloat16, model_id_list=["FLUX.1-dev", "jasperai/Flux.1-dev-Controlnet-Upscaler"]) + pipe = FluxImagePipeline.from_model_manager(model_manager, controlnet_config_units=[ + ControlNetConfigUnit( + processor_id="tile", + model_path="models/ControlNet/jasperai/Flux.1-dev-Controlnet-Upscaler/diffusion_pytorch_model.safetensors", + scale=0.7 + ), + ]) + + image_1 = pipe( + prompt="a photo of a cat, highly detailed", + height=768, width=768, + seed=0 + ) + image_1.save("image_1.png") + + image_2 = pipe( + prompt="a photo of a cat, highly detailed", + controlnet_image=image_1.resize((2048, 2048)), + input_image=image_1.resize((2048, 2048)), denoising_strength=0.99, + height=2048, width=2048, tiled=True, + seed=1 + ) + image_2.save("image_2.png") + + + +def example_2(): + model_manager = ModelManager(torch_dtype=torch.bfloat16, model_id_list=["FLUX.1-dev", "jasperai/Flux.1-dev-Controlnet-Upscaler"]) + pipe = FluxImagePipeline.from_model_manager(model_manager, controlnet_config_units=[ + ControlNetConfigUnit( + processor_id="tile", + model_path="models/ControlNet/jasperai/Flux.1-dev-Controlnet-Upscaler/diffusion_pytorch_model.safetensors", + scale=0.7 + ), + ]) + + image_1 = pipe( + prompt="a beautiful Chinese girl, delicate skin texture", + height=768, width=768, + seed=2 + ) + image_1.save("image_3.png") + + image_2 = pipe( + prompt="a beautiful Chinese girl, delicate skin texture", + controlnet_image=image_1.resize((2048, 2048)), + input_image=image_1.resize((2048, 2048)), denoising_strength=0.99, + height=2048, width=2048, tiled=True, + seed=3 + ) + image_2.save("image_4.png") + + +def example_3(): + model_manager = ModelManager(torch_dtype=torch.bfloat16, model_id_list=["FLUX.1-dev", "InstantX/FLUX.1-dev-Controlnet-Union-alpha"]) + pipe = FluxImagePipeline.from_model_manager(model_manager, controlnet_config_units=[ + ControlNetConfigUnit( + processor_id="canny", + model_path="models/ControlNet/InstantX/FLUX.1-dev-Controlnet-Union-alpha/diffusion_pytorch_model.safetensors", + scale=0.3 + ), + ControlNetConfigUnit( + processor_id="depth", + model_path="models/ControlNet/InstantX/FLUX.1-dev-Controlnet-Union-alpha/diffusion_pytorch_model.safetensors", + scale=0.3 + ), + ]) + + image_1 = pipe( + prompt="a cat is running", + height=1024, width=1024, + seed=4 + ) + image_1.save("image_5.png") + + image_2 = pipe( + prompt="sunshine, a cat is running", + controlnet_image=image_1, + height=1024, width=1024, + seed=5 + ) + image_2.save("image_6.png") + + +def example_4(): + model_manager = ModelManager(torch_dtype=torch.bfloat16, model_id_list=["FLUX.1-dev", "InstantX/FLUX.1-dev-Controlnet-Union-alpha"]) + pipe = FluxImagePipeline.from_model_manager(model_manager, controlnet_config_units=[ + ControlNetConfigUnit( + processor_id="canny", + model_path="models/ControlNet/InstantX/FLUX.1-dev-Controlnet-Union-alpha/diffusion_pytorch_model.safetensors", + scale=0.3 + ), + ControlNetConfigUnit( + processor_id="depth", + model_path="models/ControlNet/InstantX/FLUX.1-dev-Controlnet-Union-alpha/diffusion_pytorch_model.safetensors", + scale=0.3 + ), + ]) + + image_1 = pipe( + prompt="a beautiful Asian girl, full body, red dress, summer", + height=1024, width=1024, + seed=6 + ) + image_1.save("image_7.png") + + image_2 = pipe( + prompt="a beautiful Asian girl, full body, red dress, winter", + controlnet_image=image_1, + height=1024, width=1024, + seed=7 + ) + image_2.save("image_8.png") + + + +def example_5(): + model_manager = ModelManager(torch_dtype=torch.bfloat16, model_id_list=["FLUX.1-dev", "alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta"]) + pipe = FluxImagePipeline.from_model_manager(model_manager, controlnet_config_units=[ + ControlNetConfigUnit( + processor_id="inpaint", + model_path="models/ControlNet/alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta/diffusion_pytorch_model.safetensors", + scale=0.9 + ), + ]) + + image_1 = pipe( + prompt="a cat sitting on a chair", + height=1024, width=1024, + seed=8 + ) + image_1.save("image_9.png") + + mask = np.zeros((1024, 1024, 3), dtype=np.uint8) + mask[100:350, 350: -300] = 255 + mask = Image.fromarray(mask) + mask.save("mask_9.png") + + image_2 = pipe( + prompt="a cat sitting on a chair, wearing sunglasses", + controlnet_image=image_1, controlnet_inpaint_mask=mask, + height=1024, width=1024, + seed=9 + ) + image_2.save("image_10.png") + + + +def example_6(): + model_manager = ModelManager(torch_dtype=torch.bfloat16, model_id_list=[ + "FLUX.1-dev", + "jasperai/Flux.1-dev-Controlnet-Surface-Normals", + "alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta" + ]) + pipe = FluxImagePipeline.from_model_manager(model_manager, controlnet_config_units=[ + ControlNetConfigUnit( + processor_id="inpaint", + model_path="models/ControlNet/alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta/diffusion_pytorch_model.safetensors", + scale=0.9 + ), + ControlNetConfigUnit( + processor_id="normal", + model_path="models/ControlNet/jasperai/Flux.1-dev-Controlnet-Surface-Normals/diffusion_pytorch_model.safetensors", + scale=0.6 + ), + ]) + + image_1 = pipe( + prompt="a beautiful Asian woman looking at the sky, wearing a blue t-shirt.", + height=1024, width=1024, + seed=10 + ) + image_1.save("image_11.png") + + mask = np.zeros((1024, 1024, 3), dtype=np.uint8) + mask[-400:, 10:-40] = 255 + mask = Image.fromarray(mask) + mask.save("mask_11.png") + + image_2 = pipe( + prompt="a beautiful Asian woman looking at the sky, wearing a yellow t-shirt.", + controlnet_image=image_1, controlnet_inpaint_mask=mask, + height=1024, width=1024, + seed=11 + ) + image_2.save("image_12.png") + + +def example_7(): + model_manager = ModelManager(torch_dtype=torch.bfloat16, model_id_list=[ + "FLUX.1-dev", + "InstantX/FLUX.1-dev-Controlnet-Union-alpha", + "alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta", + "jasperai/Flux.1-dev-Controlnet-Upscaler", + ]) + pipe = FluxImagePipeline.from_model_manager(model_manager, controlnet_config_units=[ + ControlNetConfigUnit( + processor_id="inpaint", + model_path="models/ControlNet/alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta/diffusion_pytorch_model.safetensors", + scale=0.9 + ), + ControlNetConfigUnit( + processor_id="canny", + model_path="models/ControlNet/InstantX/FLUX.1-dev-Controlnet-Union-alpha/diffusion_pytorch_model.safetensors", + scale=0.5 + ), + ]) + + image_1 = pipe( + prompt="a beautiful Asian woman and a cat on a bed. The woman wears a dress.", + height=1024, width=1024, + seed=100 + ) + image_1.save("image_13.png") + + mask_global = np.zeros((1024, 1024, 3), dtype=np.uint8) + mask_global = Image.fromarray(mask_global) + mask_global.save("mask_13_global.png") + + mask_1 = np.zeros((1024, 1024, 3), dtype=np.uint8) + mask_1[300:-100, 30: 450] = 255 + mask_1 = Image.fromarray(mask_1) + mask_1.save("mask_13_1.png") + + mask_2 = np.zeros((1024, 1024, 3), dtype=np.uint8) + mask_2[500:-100, -400:] = 255 + mask_2[-200:-100, -500:-400] = 255 + mask_2 = Image.fromarray(mask_2) + mask_2.save("mask_13_2.png") + + image_2 = pipe( + prompt="a beautiful Asian woman and a cat on a bed. The woman wears a dress.", + controlnet_image=image_1, controlnet_inpaint_mask=mask_global, + local_prompts=["an orange cat, highly detailed", "a girl wearing a red camisole"], masks=[mask_1, mask_2], mask_scales=[10.0, 10.0], + height=1024, width=1024, + seed=101 + ) + image_2.save("image_14.png") + + model_manager.load_lora("models/lora/FLUX-dev-lora-AntiBlur.safetensors", lora_alpha=2) + image_3 = pipe( + prompt="a beautiful Asian woman wearing a red camisole and an orange cat on a bed. clear background.", + negative_prompt="blur, blurry", + input_image=image_2, denoising_strength=0.7, + height=1024, width=1024, + cfg_scale=2.0, num_inference_steps=50, + seed=102 + ) + image_3.save("image_15.png") + + pipe = FluxImagePipeline.from_model_manager(model_manager, controlnet_config_units=[ + ControlNetConfigUnit( + processor_id="tile", + model_path="models/ControlNet/jasperai/Flux.1-dev-Controlnet-Upscaler/diffusion_pytorch_model.safetensors", + scale=0.7 + ), + ]) + image_4 = pipe( + prompt="a beautiful Asian woman wearing a red camisole and an orange cat on a bed. highly detailed, delicate skin texture, clear background.", + controlnet_image=image_3.resize((2048, 2048)), + input_image=image_3.resize((2048, 2048)), denoising_strength=0.99, + height=2048, width=2048, tiled=True, + seed=103 + ) + image_4.save("image_16.png") + + image_5 = pipe( + prompt="a beautiful Asian woman wearing a red camisole and an orange cat on a bed. highly detailed, delicate skin texture, clear background.", + controlnet_image=image_4.resize((4096, 4096)), + input_image=image_4.resize((4096, 4096)), denoising_strength=0.99, + height=4096, width=4096, tiled=True, + seed=104 + ) + image_5.save("image_17.png") + + + +download_models(["Annotators:Depth", "Annotators:Normal"]) +download_customized_models( + model_id="LiblibAI/FLUX.1-dev-LoRA-AntiBlur", + origin_file_path="FLUX-dev-lora-AntiBlur.safetensors", + local_dir="models/lora" +) +example_1() +example_2() +example_3() +example_4() +example_5() +example_6() +example_7() diff --git a/examples/image_synthesis/flux_controlnet.py b/examples/image_synthesis/flux_controlnet.py deleted file mode 100644 index 80be320..0000000 --- a/examples/image_synthesis/flux_controlnet.py +++ /dev/null @@ -1,44 +0,0 @@ -from diffsynth.models.flux_controlnet import FluxControlNet -from diffsynth import load_state_dict, ModelManager, FluxImagePipeline, hash_state_dict_keys, ControlNetConfigUnit -import torch -from PIL import Image -import numpy as np - - -model_manager = ModelManager(torch_dtype=torch.bfloat16, model_id_list=["FLUX.1-dev"]) -model_manager.load_models([ - "models/ControlNet/InstantX/FLUX___1-dev-Controlnet-Union-alpha/diffusion_pytorch_model.safetensors", - "models/ControlNet/jasperai/Flux___1-dev-Controlnet-Depth/diffusion_pytorch_model.safetensors", - "models/ControlNet/jasperai/Flux___1-dev-Controlnet-Surface-Normals/diffusion_pytorch_model.safetensors", - "models/ControlNet/jasperai/Flux___1-dev-Controlnet-Upscaler/diffusion_pytorch_model.safetensors", - "models/ControlNet/alimama-creative/FLUX___1-dev-Controlnet-Inpainting-Alpha/diffusion_pytorch_model.safetensors", - "models/ControlNet/alimama-creative/FLUX___1-dev-Controlnet-Inpainting-Beta/diffusion_pytorch_model.safetensors", - "models/ControlNet/Shakker-Labs/FLUX___1-dev-ControlNet-Depth/diffusion_pytorch_model.safetensors", - "models/ControlNet/Shakker-Labs/FLUX___1-dev-ControlNet-Union-Pro/diffusion_pytorch_model.safetensors" -]) -pipe = FluxImagePipeline.from_model_manager(model_manager, controlnet_config_units=[ - ControlNetConfigUnit(processor_id="canny", model_path="models/ControlNet/InstantX/FLUX___1-dev-Controlnet-Union-alpha/diffusion_pytorch_model.safetensors", scale=0.3), - ControlNetConfigUnit(processor_id="depth", model_path="models/ControlNet/jasperai/Flux___1-dev-Controlnet-Depth/diffusion_pytorch_model.safetensors", scale=0.1), - ControlNetConfigUnit(processor_id="normal", model_path="models/ControlNet/jasperai/Flux___1-dev-Controlnet-Surface-Normals/diffusion_pytorch_model.safetensors", scale=0.1), - ControlNetConfigUnit(processor_id="tile", model_path="models/ControlNet/jasperai/Flux___1-dev-Controlnet-Upscaler/diffusion_pytorch_model.safetensors", scale=0.05), - ControlNetConfigUnit(processor_id="inpaint", model_path="models/ControlNet/alimama-creative/FLUX___1-dev-Controlnet-Inpainting-Alpha/diffusion_pytorch_model.safetensors", scale=0.01), - ControlNetConfigUnit(processor_id="inpaint", model_path="models/ControlNet/alimama-creative/FLUX___1-dev-Controlnet-Inpainting-Beta/diffusion_pytorch_model.safetensors", scale=0.01), - ControlNetConfigUnit(processor_id="depth", model_path="models/ControlNet/Shakker-Labs/FLUX___1-dev-ControlNet-Depth/diffusion_pytorch_model.safetensors", scale=0.05), - ControlNetConfigUnit(processor_id="canny", model_path="models/ControlNet/Shakker-Labs/FLUX___1-dev-ControlNet-Union-Pro/diffusion_pytorch_model.safetensors", scale=0.3), -]) - -torch.manual_seed(0) - -control_image = Image.open("controlnet_input.jpeg").resize((768, 1024)) -control_mask = Image.open("controlnet_mask.jpg").resize((768, 1024)) - -prompt = "masterpiece, best quality, a beautiful girl, CG, blue sky, long red hair, black clothes" -negative_prompt = "oil painting, worst quality, low quality, monochrome, zombie, interlocked fingers, Aissist, cleavage, nsfw," - -image = pipe( - prompt=prompt, negative_prompt=negative_prompt, - embedded_guidance=3.5, num_inference_steps=50, - height=1024, width=768, - controlnet_image=control_image, controlnet_inpaint_mask=control_mask, -) -image.save("image.jpg") From 45feef9413941a02b9272f4a12de075c6ab1e7d0 Mon Sep 17 00:00:00 2001 From: Artiprocher Date: Thu, 24 Oct 2024 16:10:15 +0800 Subject: [PATCH 4/6] update model config --- diffsynth/configs/model_config.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/diffsynth/configs/model_config.py b/diffsynth/configs/model_config.py index 55ab380..09b6ee4 100644 --- a/diffsynth/configs/model_config.py +++ b/diffsynth/configs/model_config.py @@ -346,22 +346,22 @@ preset_models_on_modelscope = { ("sd_lora/Annotators", "dpt_hybrid-midas-501f0c75.pt", "models/Annotators") ], "Annotators:Depth": [ - ("lllyasviel/Annotators", "dpt_hybrid-midas-501f0c75.pt", "models/Annotators"), + ("sd_lora/Annotators", "dpt_hybrid-midas-501f0c75.pt", "models/Annotators"), ], "Annotators:Softedge": [ - ("lllyasviel/Annotators", "ControlNetHED.pth", "models/Annotators"), + ("sd_lora/Annotators", "ControlNetHED.pth", "models/Annotators"), ], "Annotators:Lineart": [ - ("lllyasviel/Annotators", "sk_model.pth", "models/Annotators"), - ("lllyasviel/Annotators", "sk_model2.pth", "models/Annotators"), + ("sd_lora/Annotators", "sk_model.pth", "models/Annotators"), + ("sd_lora/Annotators", "sk_model2.pth", "models/Annotators"), ], "Annotators:Normal": [ - ("lllyasviel/Annotators", "scannet.pt", "models/Annotators"), + ("sd_lora/Annotators", "scannet.pt", "models/Annotators"), ], "Annotators:Openpose": [ - ("lllyasviel/Annotators", "body_pose_model.pth", "models/Annotators"), - ("lllyasviel/Annotators", "facenet.pth", "models/Annotators"), - ("lllyasviel/Annotators", "hand_pose_model.pth", "models/Annotators"), + ("sd_lora/Annotators", "body_pose_model.pth", "models/Annotators"), + ("sd_lora/Annotators", "facenet.pth", "models/Annotators"), + ("sd_lora/Annotators", "hand_pose_model.pth", "models/Annotators"), ], # AnimateDiff "AnimateDiff_v2": [ From a6d6553ceea678e58b9dbfafd3d0023cda12d392 Mon Sep 17 00:00:00 2001 From: Artiprocher Date: Thu, 24 Oct 2024 17:36:22 +0800 Subject: [PATCH 5/6] bug fix --- diffsynth/pipelines/base.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/diffsynth/pipelines/base.py b/diffsynth/pipelines/base.py index f8f7178..b968bb6 100644 --- a/diffsynth/pipelines/base.py +++ b/diffsynth/pipelines/base.py @@ -47,8 +47,11 @@ class BasePipeline(torch.nn.Module): return value - def control_noise_via_local_prompts(self, prompt_emb_global, prompt_emb_locals, masks, mask_scales, inference_callback, special_kwargs={}, special_local_kwargs_list=None): - noise_pred_global = inference_callback(prompt_emb_global, special_kwargs) + def control_noise_via_local_prompts(self, prompt_emb_global, prompt_emb_locals, masks, mask_scales, inference_callback, special_kwargs=None, special_local_kwargs_list=None): + if special_kwargs is None: + noise_pred_global = inference_callback(prompt_emb_global) + else: + noise_pred_global = inference_callback(prompt_emb_global, special_kwargs) if special_local_kwargs_list is None: noise_pred_locals = [inference_callback(prompt_emb_local) for prompt_emb_local in prompt_emb_locals] else: From 2edc485ec1bd1092205a0a5642cd936dcef86129 Mon Sep 17 00:00:00 2001 From: Zhongjie Duan <35051019+Artiprocher@users.noreply.github.com> Date: Fri, 25 Oct 2024 00:16:11 +0800 Subject: [PATCH 6/6] Update requirements.txt --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 9af7c82..df207cc 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,6 @@ torch>=2.0.0 cupy-cuda12x -transformers +transformers==4.44.1 controlnet-aux==0.0.7 imageio imageio[ffmpeg]