From 6f79fd6d775218d2803633ccabe7da896a068022 Mon Sep 17 00:00:00 2001 From: Artiprocher Date: Thu, 1 Aug 2024 10:01:39 +0800 Subject: [PATCH 1/3] support sdxl controlnet union --- diffsynth/configs/model_config.py | 2 + diffsynth/controlnets/controlnet_unit.py | 9 +- diffsynth/controlnets/processors.py | 1 + diffsynth/models/model_manager.py | 1 + diffsynth/models/sd_controlnet.py | 1 + diffsynth/models/sdxl_controlnet.py | 318 +++++++++++++++++++++++ diffsynth/models/sdxl_unet.py | 2 + diffsynth/pipelines/dancer.py | 62 ++++- diffsynth/pipelines/sd3_image.py | 3 +- diffsynth/pipelines/sdxl_image.py | 26 +- 10 files changed, 408 insertions(+), 17 deletions(-) create mode 100644 diffsynth/models/sdxl_controlnet.py diff --git a/diffsynth/configs/model_config.py b/diffsynth/configs/model_config.py index df23db5..b15feed 100644 --- a/diffsynth/configs/model_config.py +++ b/diffsynth/configs/model_config.py @@ -16,6 +16,7 @@ from ..models.sd3_vae_decoder import SD3VAEDecoder from ..models.sd3_vae_encoder import SD3VAEEncoder from ..models.sd_controlnet import SDControlNet +from ..models.sdxl_controlnet import SDXLControlNetUnion from ..models.sd_motion import SDMotionModel from ..models.sdxl_motion import SDXLMotionModel @@ -60,6 +61,7 @@ model_loader_configs = [ (None, "2a07abce74b4bdc696b76254ab474da6", ["svd_image_encoder", "svd_unet", "svd_vae_decoder", "svd_vae_encoder"], [SVDImageEncoder, SVDUNet, SVDVAEDecoder, SVDVAEEncoder], "civitai"), (None, "c96a285a6888465f87de22a984d049fb", ["sd_motion_modules"], [SDMotionModel], "civitai"), (None, "72907b92caed19bdb2adb89aa4063fe2", ["sdxl_motion_modules"], [SDXLMotionModel], "civitai"), + (None, "31d2d9614fba60511fc9bf2604aa01f7", ["sdxl_controlnet"], [SDXLControlNetUnion], "diffusers"), ] huggingface_model_loader_configs = [ # These configs are provided for detecting model type automatically. diff --git a/diffsynth/controlnets/controlnet_unit.py b/diffsynth/controlnets/controlnet_unit.py index 9754bae..63129d4 100644 --- a/diffsynth/controlnets/controlnet_unit.py +++ b/diffsynth/controlnets/controlnet_unit.py @@ -37,13 +37,14 @@ class MultiControlNetManager: def __call__( self, sample, timestep, encoder_hidden_states, conditionings, - tiled=False, tile_size=64, tile_stride=32 + tiled=False, tile_size=64, tile_stride=32, **kwargs ): res_stack = None - for conditioning, model, scale in zip(conditionings, self.models, self.scales): + for processor, conditioning, model, scale in zip(self.processors, conditionings, self.models, self.scales): res_stack_ = model( - sample, timestep, encoder_hidden_states, conditioning, - tiled=tiled, tile_size=tile_size, tile_stride=tile_stride + sample, timestep, encoder_hidden_states, conditioning, **kwargs, + tiled=tiled, tile_size=tile_size, tile_stride=tile_stride, + processor_id=processor.processor_id ) res_stack_ = [res * scale for res in res_stack_] if res_stack is None: diff --git a/diffsynth/controlnets/processors.py b/diffsynth/controlnets/processors.py index 1d23c73..ce16b5d 100644 --- a/diffsynth/controlnets/processors.py +++ b/diffsynth/controlnets/processors.py @@ -47,5 +47,6 @@ class Annotator: detect_resolution = self.detect_resolution if self.detect_resolution is not None else min(width, height) image = self.processor(image, detect_resolution=detect_resolution, image_resolution=min(width, height), **kwargs) image = image.resize((width, height)) + image.save("/mnt/zhongjie/DiffSynth-Studio-kolors/DiffSynth-Studio/input.jpg") return image diff --git a/diffsynth/models/model_manager.py b/diffsynth/models/model_manager.py index 1cc8191..b304384 100644 --- a/diffsynth/models/model_manager.py +++ b/diffsynth/models/model_manager.py @@ -23,6 +23,7 @@ from .sd3_vae_decoder import SD3VAEDecoder from .sd3_vae_encoder import SD3VAEEncoder from .sd_controlnet import SDControlNet +from .sdxl_controlnet import SDXLControlNetUnion from .sd_motion import SDMotionModel from .sdxl_motion import SDXLMotionModel diff --git a/diffsynth/models/sd_controlnet.py b/diffsynth/models/sd_controlnet.py index 8c792eb..910e0db 100644 --- a/diffsynth/models/sd_controlnet.py +++ b/diffsynth/models/sd_controlnet.py @@ -97,6 +97,7 @@ class SDControlNet(torch.nn.Module): self, sample, timestep, encoder_hidden_states, conditioning, tiled=False, tile_size=64, tile_stride=32, + **kwargs ): # 1. time time_emb = self.time_proj(timestep).to(sample.dtype) diff --git a/diffsynth/models/sdxl_controlnet.py b/diffsynth/models/sdxl_controlnet.py new file mode 100644 index 0000000..acddf1c --- /dev/null +++ b/diffsynth/models/sdxl_controlnet.py @@ -0,0 +1,318 @@ +import torch +from .sd_unet import Timesteps, ResnetBlock, AttentionBlock, PushBlock, DownSampler +from .sdxl_unet import SDXLUNet +from .tiler import TileWorker +from .sd_controlnet import ControlNetConditioningLayer +from collections import OrderedDict + + + +class QuickGELU(torch.nn.Module): + + def forward(self, x: torch.Tensor): + return x * torch.sigmoid(1.702 * x) + + + +class ResidualAttentionBlock(torch.nn.Module): + + def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None): + super().__init__() + + self.attn = torch.nn.MultiheadAttention(d_model, n_head) + self.ln_1 = torch.nn.LayerNorm(d_model) + self.mlp = torch.nn.Sequential(OrderedDict([ + ("c_fc", torch.nn.Linear(d_model, d_model * 4)), + ("gelu", QuickGELU()), + ("c_proj", torch.nn.Linear(d_model * 4, d_model)) + ])) + self.ln_2 = torch.nn.LayerNorm(d_model) + self.attn_mask = attn_mask + + def attention(self, x: torch.Tensor): + self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None + return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0] + + def forward(self, x: torch.Tensor): + x = x + self.attention(self.ln_1(x)) + x = x + self.mlp(self.ln_2(x)) + return x + + + +class SDXLControlNetUnion(torch.nn.Module): + def __init__(self, global_pool=False): + super().__init__() + self.time_proj = Timesteps(320) + self.time_embedding = torch.nn.Sequential( + torch.nn.Linear(320, 1280), + torch.nn.SiLU(), + torch.nn.Linear(1280, 1280) + ) + self.add_time_proj = Timesteps(256) + self.add_time_embedding = torch.nn.Sequential( + torch.nn.Linear(2816, 1280), + torch.nn.SiLU(), + torch.nn.Linear(1280, 1280) + ) + self.control_type_proj = Timesteps(256) + self.control_type_embedding = torch.nn.Sequential( + torch.nn.Linear(256 * 8, 1280), + torch.nn.SiLU(), + torch.nn.Linear(1280, 1280) + ) + self.conv_in = torch.nn.Conv2d(4, 320, kernel_size=3, padding=1) + + self.controlnet_conv_in = ControlNetConditioningLayer(channels=(3, 16, 32, 96, 256, 320)) + self.controlnet_transformer = ResidualAttentionBlock(320, 8) + self.task_embedding = torch.nn.Parameter(torch.randn(8, 320)) + self.spatial_ch_projs = torch.nn.Linear(320, 320) + + self.blocks = torch.nn.ModuleList([ + # DownBlock2D + ResnetBlock(320, 320, 1280), + PushBlock(), + ResnetBlock(320, 320, 1280), + PushBlock(), + DownSampler(320), + PushBlock(), + # CrossAttnDownBlock2D + ResnetBlock(320, 640, 1280), + AttentionBlock(10, 64, 640, 2, 2048), + PushBlock(), + ResnetBlock(640, 640, 1280), + AttentionBlock(10, 64, 640, 2, 2048), + PushBlock(), + DownSampler(640), + PushBlock(), + # CrossAttnDownBlock2D + ResnetBlock(640, 1280, 1280), + AttentionBlock(20, 64, 1280, 10, 2048), + PushBlock(), + ResnetBlock(1280, 1280, 1280), + AttentionBlock(20, 64, 1280, 10, 2048), + PushBlock(), + # UNetMidBlock2DCrossAttn + ResnetBlock(1280, 1280, 1280), + AttentionBlock(20, 64, 1280, 10, 2048), + ResnetBlock(1280, 1280, 1280), + PushBlock() + ]) + + self.controlnet_blocks = torch.nn.ModuleList([ + torch.nn.Conv2d(320, 320, kernel_size=(1, 1)), + torch.nn.Conv2d(320, 320, kernel_size=(1, 1)), + torch.nn.Conv2d(320, 320, kernel_size=(1, 1)), + torch.nn.Conv2d(320, 320, kernel_size=(1, 1)), + torch.nn.Conv2d(640, 640, kernel_size=(1, 1)), + torch.nn.Conv2d(640, 640, kernel_size=(1, 1)), + torch.nn.Conv2d(640, 640, kernel_size=(1, 1)), + torch.nn.Conv2d(1280, 1280, kernel_size=(1, 1)), + torch.nn.Conv2d(1280, 1280, kernel_size=(1, 1)), + torch.nn.Conv2d(1280, 1280, kernel_size=(1, 1)), + ]) + + self.global_pool = global_pool + + # 0 -- openpose + # 1 -- depth + # 2 -- hed/pidi/scribble/ted + # 3 -- canny/lineart/anime_lineart/mlsd + # 4 -- normal + # 5 -- segment + # 6 -- tile + # 7 -- repaint + self.task_id = { + "openpose": 0, + "depth": 1, + "softedge": 2, + "canny": 3, + "lineart": 3, + "lineart_anime": 3, + "tile": 6, + "inpaint": 7 + } + + + def fuse_condition_to_input(self, hidden_states, task_id, conditioning): + controlnet_cond = self.controlnet_conv_in(conditioning) + feat_seq = torch.mean(controlnet_cond, dim=(2, 3)) + feat_seq = feat_seq + self.task_embedding[task_id] + x = torch.stack([feat_seq, torch.mean(hidden_states, dim=(2, 3))], dim=1) + x = self.controlnet_transformer(x) + + alpha = self.spatial_ch_projs(x[:,0]).unsqueeze(-1).unsqueeze(-1) + controlnet_cond_fuser = controlnet_cond + alpha + + hidden_states = hidden_states + controlnet_cond_fuser + return hidden_states + + + def forward( + self, + sample, timestep, encoder_hidden_states, + conditioning, processor_id, add_time_id, add_text_embeds, + tiled=False, tile_size=64, tile_stride=32, + unet:SDXLUNet=None, + **kwargs + ): + task_id = self.task_id[processor_id] + + # 1. time + t_emb = self.time_proj(timestep).to(sample.dtype) + t_emb = self.time_embedding(t_emb) + + time_embeds = self.add_time_proj(add_time_id) + time_embeds = time_embeds.reshape((add_text_embeds.shape[0], -1)) + add_embeds = torch.concat([add_text_embeds, time_embeds], dim=-1) + add_embeds = add_embeds.to(sample.dtype) + if unet is not None and unet.is_kolors: + add_embeds = unet.add_time_embedding(add_embeds) + else: + add_embeds = self.add_time_embedding(add_embeds) + + control_type = torch.zeros((sample.shape[0], 8), dtype=sample.dtype, device=sample.device) + control_type[:, task_id] = 1 + control_embeds = self.control_type_proj(control_type.flatten()) + control_embeds = control_embeds.reshape((sample.shape[0], -1)) + control_embeds = control_embeds.to(sample.dtype) + control_embeds = self.control_type_embedding(control_embeds) + time_emb = t_emb + add_embeds + control_embeds + + # 2. pre-process + height, width = sample.shape[2], sample.shape[3] + hidden_states = self.conv_in(sample) + hidden_states = self.fuse_condition_to_input(hidden_states, task_id, conditioning) + text_emb = encoder_hidden_states + if unet is not None and unet.is_kolors: + text_emb = unet.text_intermediate_proj(text_emb) + res_stack = [hidden_states] + + # 3. blocks + for i, block in enumerate(self.blocks): + if tiled and not isinstance(block, PushBlock): + _, _, inter_height, _ = hidden_states.shape + resize_scale = inter_height / height + hidden_states = TileWorker().tiled_forward( + lambda x: block(x, time_emb, text_emb, res_stack)[0], + hidden_states, + int(tile_size * resize_scale), + int(tile_stride * resize_scale), + tile_device=hidden_states.device, + tile_dtype=hidden_states.dtype + ) + else: + hidden_states, _, _, _ = block(hidden_states, time_emb, text_emb, res_stack) + + # 4. ControlNet blocks + controlnet_res_stack = [block(res) for block, res in zip(self.controlnet_blocks, res_stack)] + + # pool + if self.global_pool: + controlnet_res_stack = [res.mean(dim=(2, 3), keepdim=True) for res in controlnet_res_stack] + + return controlnet_res_stack + + @staticmethod + def state_dict_converter(): + return SDXLControlNetUnionStateDictConverter() + + + +class SDXLControlNetUnionStateDictConverter: + def __init__(self): + pass + + def from_diffusers(self, state_dict): + # architecture + block_types = [ + "ResnetBlock", "PushBlock", "ResnetBlock", "PushBlock", "DownSampler", "PushBlock", + "ResnetBlock", "AttentionBlock", "PushBlock", "ResnetBlock", "AttentionBlock", "PushBlock", "DownSampler", "PushBlock", + "ResnetBlock", "AttentionBlock", "PushBlock", "ResnetBlock", "AttentionBlock", "PushBlock", + "ResnetBlock", "AttentionBlock", "ResnetBlock", "PushBlock" + ] + + # controlnet_rename_dict + controlnet_rename_dict = { + "controlnet_cond_embedding.conv_in.weight": "controlnet_conv_in.blocks.0.weight", + "controlnet_cond_embedding.conv_in.bias": "controlnet_conv_in.blocks.0.bias", + "controlnet_cond_embedding.blocks.0.weight": "controlnet_conv_in.blocks.2.weight", + "controlnet_cond_embedding.blocks.0.bias": "controlnet_conv_in.blocks.2.bias", + "controlnet_cond_embedding.blocks.1.weight": "controlnet_conv_in.blocks.4.weight", + "controlnet_cond_embedding.blocks.1.bias": "controlnet_conv_in.blocks.4.bias", + "controlnet_cond_embedding.blocks.2.weight": "controlnet_conv_in.blocks.6.weight", + "controlnet_cond_embedding.blocks.2.bias": "controlnet_conv_in.blocks.6.bias", + "controlnet_cond_embedding.blocks.3.weight": "controlnet_conv_in.blocks.8.weight", + "controlnet_cond_embedding.blocks.3.bias": "controlnet_conv_in.blocks.8.bias", + "controlnet_cond_embedding.blocks.4.weight": "controlnet_conv_in.blocks.10.weight", + "controlnet_cond_embedding.blocks.4.bias": "controlnet_conv_in.blocks.10.bias", + "controlnet_cond_embedding.blocks.5.weight": "controlnet_conv_in.blocks.12.weight", + "controlnet_cond_embedding.blocks.5.bias": "controlnet_conv_in.blocks.12.bias", + "controlnet_cond_embedding.conv_out.weight": "controlnet_conv_in.blocks.14.weight", + "controlnet_cond_embedding.conv_out.bias": "controlnet_conv_in.blocks.14.bias", + "control_add_embedding.linear_1.weight": "control_type_embedding.0.weight", + "control_add_embedding.linear_1.bias": "control_type_embedding.0.bias", + "control_add_embedding.linear_2.weight": "control_type_embedding.2.weight", + "control_add_embedding.linear_2.bias": "control_type_embedding.2.bias", + } + + # Rename each parameter + name_list = sorted([name for name in state_dict]) + rename_dict = {} + block_id = {"ResnetBlock": -1, "AttentionBlock": -1, "DownSampler": -1, "UpSampler": -1} + last_block_type_with_id = {"ResnetBlock": "", "AttentionBlock": "", "DownSampler": "", "UpSampler": ""} + for name in name_list: + names = name.split(".") + if names[0] in ["conv_in", "conv_norm_out", "conv_out", "task_embedding", "spatial_ch_projs"]: + pass + elif name in controlnet_rename_dict: + names = controlnet_rename_dict[name].split(".") + elif names[0] == "controlnet_down_blocks": + names[0] = "controlnet_blocks" + elif names[0] == "controlnet_mid_block": + names = ["controlnet_blocks", "9", names[-1]] + elif names[0] in ["time_embedding", "add_embedding"]: + if names[0] == "add_embedding": + names[0] = "add_time_embedding" + names[1] = {"linear_1": "0", "linear_2": "2"}[names[1]] + elif names[0] == "control_add_embedding": + names[0] = "control_type_embedding" + elif names[0] == "transformer_layes": + names[0] = "controlnet_transformer" + names.pop(1) + elif names[0] in ["down_blocks", "mid_block", "up_blocks"]: + if names[0] == "mid_block": + names.insert(1, "0") + block_type = {"resnets": "ResnetBlock", "attentions": "AttentionBlock", "downsamplers": "DownSampler", "upsamplers": "UpSampler"}[names[2]] + block_type_with_id = ".".join(names[:4]) + if block_type_with_id != last_block_type_with_id[block_type]: + block_id[block_type] += 1 + last_block_type_with_id[block_type] = block_type_with_id + while block_id[block_type] < len(block_types) and block_types[block_id[block_type]] != block_type: + block_id[block_type] += 1 + block_type_with_id = ".".join(names[:4]) + names = ["blocks", str(block_id[block_type])] + names[4:] + if "ff" in names: + ff_index = names.index("ff") + component = ".".join(names[ff_index:ff_index+3]) + component = {"ff.net.0": "act_fn", "ff.net.2": "ff"}[component] + names = names[:ff_index] + [component] + names[ff_index+3:] + if "to_out" in names: + names.pop(names.index("to_out") + 1) + else: + print(name, state_dict[name].shape) + # raise ValueError(f"Unknown parameters: {name}") + rename_dict[name] = ".".join(names) + + # Convert state_dict + state_dict_ = {} + for name, param in state_dict.items(): + if name not in rename_dict: + continue + if ".proj_in." in name or ".proj_out." in name: + param = param.squeeze() + state_dict_[rename_dict[name]] = param + return state_dict_ + + def from_civitai(self, state_dict): + return self.from_diffusers(state_dict) \ No newline at end of file diff --git a/diffsynth/models/sdxl_unet.py b/diffsynth/models/sdxl_unet.py index 84dab7d..9bc63e6 100644 --- a/diffsynth/models/sdxl_unet.py +++ b/diffsynth/models/sdxl_unet.py @@ -83,6 +83,8 @@ class SDXLUNet(torch.nn.Module): self.conv_act = torch.nn.SiLU() self.conv_out = torch.nn.Conv2d(320, 4, kernel_size=3, padding=1) + self.is_kolors = is_kolors + def forward( self, sample, timestep, encoder_hidden_states, add_time_id, add_text_embeds, diff --git a/diffsynth/pipelines/dancer.py b/diffsynth/pipelines/dancer.py index c87018b..60400f8 100644 --- a/diffsynth/pipelines/dancer.py +++ b/diffsynth/pipelines/dancer.py @@ -136,6 +136,34 @@ def lets_dance_xl( device = "cuda", vram_limit_level = 0, ): + # 1. ControlNet + controlnet_insert_block_id = 22 + if controlnet is not None and controlnet_frames is not None: + res_stacks = [] + # process controlnet frames with batch + for batch_id in range(0, sample.shape[0], controlnet_batch_size): + batch_id_ = min(batch_id + controlnet_batch_size, sample.shape[0]) + res_stack = controlnet( + sample[batch_id: batch_id_], + timestep, + encoder_hidden_states[batch_id: batch_id_], + controlnet_frames[:, batch_id: batch_id_], + add_time_id=add_time_id, + add_text_embeds=add_text_embeds, + tiled=tiled, tile_size=tile_size, tile_stride=tile_stride, + unet=unet, # for Kolors, some modules in ControlNets will be replaced. + ) + if vram_limit_level >= 1: + res_stack = [res.cpu() for res in res_stack] + res_stacks.append(res_stack) + # concat the residual + additional_res_stack = [] + for i in range(len(res_stacks[0])): + res = torch.concat([res_stack[i] for res_stack in res_stacks], dim=0) + additional_res_stack.append(res) + else: + additional_res_stack = None + # 2. time t_emb = unet.time_proj(timestep).to(sample.dtype) t_emb = unet.time_embedding(t_emb) @@ -156,11 +184,31 @@ def lets_dance_xl( # 4. blocks for block_id, block in enumerate(unet.blocks): - hidden_states, time_emb, text_emb, res_stack = block( - hidden_states, time_emb, text_emb, res_stack, - tiled=tiled, tile_size=tile_size, tile_stride=tile_stride, - ipadapter_kwargs_list=ipadapter_kwargs_list.get(block_id, {}) - ) + # 4.1 UNet + if isinstance(block, PushBlock): + hidden_states, time_emb, text_emb, res_stack = block(hidden_states, time_emb, text_emb, res_stack) + if vram_limit_level>=1: + res_stack[-1] = res_stack[-1].cpu() + elif isinstance(block, PopBlock): + if vram_limit_level>=1: + res_stack[-1] = res_stack[-1].to(device) + hidden_states, time_emb, text_emb, res_stack = block(hidden_states, time_emb, text_emb, res_stack) + else: + hidden_states_input = hidden_states + hidden_states_output = [] + for batch_id in range(0, sample.shape[0], unet_batch_size): + batch_id_ = min(batch_id + unet_batch_size, sample.shape[0]) + hidden_states, _, _, _ = block( + hidden_states_input[batch_id: batch_id_], + time_emb, + text_emb[batch_id: batch_id_], + res_stack, + cross_frame_attention=cross_frame_attention, + ipadapter_kwargs_list=ipadapter_kwargs_list.get(block_id, {}), + tiled=tiled, tile_size=tile_size, tile_stride=tile_stride, + ) + hidden_states_output.append(hidden_states) + hidden_states = torch.concat(hidden_states_output, dim=0) # 4.2 AnimateDiff if motion_modules is not None: if block_id in motion_modules.call_block_id: @@ -169,6 +217,10 @@ def lets_dance_xl( hidden_states, time_emb, text_emb, res_stack, batch_size=1 ) + # 4.3 ControlNet + if block_id == controlnet_insert_block_id and additional_res_stack is not None: + hidden_states += additional_res_stack.pop().to(device) + res_stack = [res + additional_res for res, additional_res in zip(res_stack, additional_res_stack)] # 5. output hidden_states = unet.conv_norm_out(hidden_states) diff --git a/diffsynth/pipelines/sd3_image.py b/diffsynth/pipelines/sd3_image.py index ad332d6..f52c2ed 100644 --- a/diffsynth/pipelines/sd3_image.py +++ b/diffsynth/pipelines/sd3_image.py @@ -29,8 +29,7 @@ class SD3ImagePipeline(BasePipeline): def fetch_models(self, model_manager: ModelManager, prompt_refiner_classes=[]): self.text_encoder_1 = model_manager.fetch_model("sd3_text_encoder_1") self.text_encoder_2 = model_manager.fetch_model("sd3_text_encoder_2") - if "sd3_text_encoder_3" in model_manager.model: - self.text_encoder_3 = model_manager.fetch_model("sd3_text_encoder_3") + self.text_encoder_3 = model_manager.fetch_model("sd3_text_encoder_3") self.dit = model_manager.fetch_model("sd3_dit") self.vae_decoder = model_manager.fetch_model("sd3_vae_decoder") self.vae_encoder = model_manager.fetch_model("sd3_vae_encoder") diff --git a/diffsynth/pipelines/sdxl_image.py b/diffsynth/pipelines/sdxl_image.py index 1308d81..c214ebd 100644 --- a/diffsynth/pipelines/sdxl_image.py +++ b/diffsynth/pipelines/sdxl_image.py @@ -25,7 +25,7 @@ class SDXLImagePipeline(BasePipeline): self.unet: SDXLUNet = None self.vae_decoder: SDXLVAEDecoder = None self.vae_encoder: SDXLVAEEncoder = None - # self.controlnet: MultiControlNetManager = None (TODO) + self.controlnet: MultiControlNetManager = None self.ipadapter_image_encoder: IpAdapterXLCLIPImageEmbedder = None self.ipadapter: SDXLIpAdapter = None @@ -43,7 +43,16 @@ class SDXLImagePipeline(BasePipeline): self.vae_decoder = model_manager.fetch_model("sdxl_vae_decoder") self.vae_encoder = model_manager.fetch_model("sdxl_vae_encoder") - # ControlNets (TODO) + # ControlNets + controlnet_units = [] + for config in controlnet_config_units: + controlnet_unit = ControlNetUnit( + Annotator(config.processor_id, device=self.device), + model_manager.fetch_model("sdxl_controlnet", config.model_path), + config.scale + ) + controlnet_units.append(controlnet_unit) + self.controlnet = MultiControlNetManager(controlnet_units) # IP-Adapters self.ipadapter = model_manager.fetch_model("sdxl_ipadapter") @@ -150,8 +159,13 @@ class SDXLImagePipeline(BasePipeline): else: ipadapter_kwargs_list_posi, ipadapter_kwargs_list_nega = {"ipadapter_kwargs_list": {}}, {"ipadapter_kwargs_list": {}} - # Prepare ControlNets (TODO) - controlnet_kwargs = {"controlnet_frames": None} + # Prepare ControlNets + if controlnet_image is not None: + controlnet_image = self.controlnet.process_image(controlnet_image).to(device=self.device, dtype=self.torch_dtype) + controlnet_image = controlnet_image.unsqueeze(1) + controlnet_kwargs = {"controlnet_frames": controlnet_image} + else: + controlnet_kwargs = {"controlnet_frames": None} # Prepare extra input extra_input = self.prepare_extra_input(latents) @@ -162,14 +176,14 @@ class SDXLImagePipeline(BasePipeline): # Classifier-free guidance noise_pred_posi = lets_dance_xl( - self.unet, motion_modules=None, controlnet=None, + self.unet, motion_modules=None, controlnet=self.controlnet, sample=latents, timestep=timestep, **extra_input, **prompt_emb_posi, **controlnet_kwargs, **tiler_kwargs, **ipadapter_kwargs_list_posi, device=self.device, ) if cfg_scale != 1.0: noise_pred_nega = lets_dance_xl( - self.unet, motion_modules=None, controlnet=None, + self.unet, motion_modules=None, controlnet=self.controlnet, sample=latents, timestep=timestep, **extra_input, **prompt_emb_nega, **controlnet_kwargs, **tiler_kwargs, **ipadapter_kwargs_list_nega, device=self.device, From f189f9f1bea9fe185ef43601fa8675712a639ef4 Mon Sep 17 00:00:00 2001 From: Artiprocher Date: Fri, 2 Aug 2024 10:31:25 +0800 Subject: [PATCH 2/3] update UI --- diffsynth/pipelines/base.py | 19 +++++++++++++++++ diffsynth/pipelines/hunyuan_image.py | 10 ++++++--- diffsynth/pipelines/sd3_image.py | 7 +++++- diffsynth/pipelines/sd_image.py | 7 +++++- diffsynth/pipelines/sdxl_image.py | 8 ++++++- pages/1_Image_Creator.py | 32 ++++++++++++++++++++++++++++ 6 files changed, 77 insertions(+), 6 deletions(-) diff --git a/diffsynth/pipelines/base.py b/diffsynth/pipelines/base.py index cb83527..8b99c82 100644 --- a/diffsynth/pipelines/base.py +++ b/diffsynth/pipelines/base.py @@ -31,4 +31,23 @@ class BasePipeline(torch.nn.Module): video = vae_output.cpu().permute(1, 2, 0).numpy() video = [Image.fromarray(((image / 2 + 0.5).clip(0, 1) * 255).astype("uint8")) for image in video] return video + + + def merge_latents(self, value, latents, masks, scales): + height, width = value.shape[-2:] + weight = torch.ones_like(value) + for latent, mask, scale in zip(latents, masks, scales): + mask = self.preprocess_image(mask.resize((height, width))).mean(dim=1, keepdim=True) > 0 + mask = mask.repeat(1, latent.shape[1], 1, 1) + value[mask] += latent[mask] * scale + weight[mask] += scale + value /= weight + return value + + + def control_noise_via_local_prompts(self, prompt_emb_global, prompt_emb_locals, masks, mask_scales, inference_callback): + noise_pred_global = inference_callback(prompt_emb_global) + noise_pred_locals = [inference_callback(prompt_emb_local) for prompt_emb_local in prompt_emb_locals] + noise_pred = self.merge_latents(noise_pred_global, noise_pred_locals, masks, mask_scales) + return noise_pred \ No newline at end of file diff --git a/diffsynth/pipelines/hunyuan_image.py b/diffsynth/pipelines/hunyuan_image.py index 241f772..9181431 100644 --- a/diffsynth/pipelines/hunyuan_image.py +++ b/diffsynth/pipelines/hunyuan_image.py @@ -209,6 +209,9 @@ class HunyuanDiTImagePipeline(BasePipeline): def __call__( self, prompt, + local_prompts=[], + masks=[], + mask_scales=[], negative_prompt="", cfg_scale=7.5, clip_skip=1, @@ -241,6 +244,7 @@ class HunyuanDiTImagePipeline(BasePipeline): prompt_emb_posi = self.encode_prompt(prompt, clip_skip=clip_skip, clip_skip_2=clip_skip_2, positive=True) if cfg_scale != 1.0: prompt_emb_nega = self.encode_prompt(negative_prompt, clip_skip=clip_skip, clip_skip_2=clip_skip_2, positive=True) + prompt_emb_locals = [self.encode_prompt(prompt_local, clip_skip=clip_skip, clip_skip_2=clip_skip_2, positive=True) for prompt_local in local_prompts] # Prepare positional id extra_input = self.prepare_extra_input(latents, tiled, tile_size) @@ -250,9 +254,9 @@ class HunyuanDiTImagePipeline(BasePipeline): timestep = torch.tensor([timestep]).to(dtype=self.torch_dtype, device=self.device) # Positive side - noise_pred_posi = self.dit( - latents, timestep=timestep, **prompt_emb_posi, **extra_input, - ) + inference_callback = lambda prompt_emb_posi: self.dit(latents, timestep=timestep, **prompt_emb_posi, **extra_input) + noise_pred_posi = self.control_noise_via_local_prompts(prompt_emb_posi, prompt_emb_locals, masks, mask_scales, inference_callback) + if cfg_scale != 1.0: # Negative side noise_pred_nega = self.dit( diff --git a/diffsynth/pipelines/sd3_image.py b/diffsynth/pipelines/sd3_image.py index f52c2ed..d7dd371 100644 --- a/diffsynth/pipelines/sd3_image.py +++ b/diffsynth/pipelines/sd3_image.py @@ -73,6 +73,9 @@ class SD3ImagePipeline(BasePipeline): def __call__( self, prompt, + local_prompts=[], + masks=[], + mask_scales=[], negative_prompt="", cfg_scale=7.5, input_image=None, @@ -104,15 +107,17 @@ class SD3ImagePipeline(BasePipeline): # Encode prompts prompt_emb_posi = self.encode_prompt(prompt, positive=True) prompt_emb_nega = self.encode_prompt(negative_prompt, positive=False) + prompt_emb_locals = [self.encode_prompt(prompt_local) for prompt_local in local_prompts] # Denoise for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)): timestep = timestep.unsqueeze(0).to(self.device) # Classifier-free guidance - noise_pred_posi = self.dit( + inference_callback = lambda prompt_emb_posi: self.dit( latents, timestep=timestep, **prompt_emb_posi, **tiler_kwargs, ) + noise_pred_posi = self.control_noise_via_local_prompts(prompt_emb_posi, prompt_emb_locals, masks, mask_scales, inference_callback) noise_pred_nega = self.dit( latents, timestep=timestep, **prompt_emb_nega, **tiler_kwargs, ) diff --git a/diffsynth/pipelines/sd_image.py b/diffsynth/pipelines/sd_image.py index 0b0d238..016720d 100644 --- a/diffsynth/pipelines/sd_image.py +++ b/diffsynth/pipelines/sd_image.py @@ -90,6 +90,9 @@ class SDImagePipeline(BasePipeline): def __call__( self, prompt, + local_prompts=[], + masks=[], + mask_scales=[], negative_prompt="", cfg_scale=7.5, clip_skip=1, @@ -125,6 +128,7 @@ class SDImagePipeline(BasePipeline): # Encode prompts prompt_emb_posi = self.encode_prompt(prompt, clip_skip=clip_skip, positive=True) prompt_emb_nega = self.encode_prompt(negative_prompt, clip_skip=clip_skip, positive=False) + prompt_emb_locals = [self.encode_prompt(prompt_local, clip_skip=clip_skip, positive=True) for prompt_local in local_prompts] # IP-Adapter if ipadapter_images is not None: @@ -147,12 +151,13 @@ class SDImagePipeline(BasePipeline): timestep = timestep.unsqueeze(0).to(self.device) # Classifier-free guidance - noise_pred_posi = lets_dance( + inference_callback = lambda prompt_emb_posi: lets_dance( self.unet, motion_modules=None, controlnet=self.controlnet, sample=latents, timestep=timestep, **prompt_emb_posi, **controlnet_kwargs, **tiler_kwargs, **ipadapter_kwargs_list_posi, device=self.device, ) + noise_pred_posi = self.control_noise_via_local_prompts(prompt_emb_posi, prompt_emb_locals, masks, mask_scales, inference_callback) noise_pred_nega = lets_dance( self.unet, motion_modules=None, controlnet=self.controlnet, sample=latents, timestep=timestep, **prompt_emb_nega, **controlnet_kwargs, **tiler_kwargs, **ipadapter_kwargs_list_nega, diff --git a/diffsynth/pipelines/sdxl_image.py b/diffsynth/pipelines/sdxl_image.py index c214ebd..2cd73d8 100644 --- a/diffsynth/pipelines/sdxl_image.py +++ b/diffsynth/pipelines/sdxl_image.py @@ -109,6 +109,9 @@ class SDXLImagePipeline(BasePipeline): def __call__( self, prompt, + local_prompts=[], + masks=[], + mask_scales=[], negative_prompt="", cfg_scale=7.5, clip_skip=1, @@ -146,6 +149,7 @@ class SDXLImagePipeline(BasePipeline): # Encode prompts prompt_emb_posi = self.encode_prompt(prompt, clip_skip=clip_skip, clip_skip_2=clip_skip_2, positive=True) prompt_emb_nega = self.encode_prompt(negative_prompt, clip_skip=clip_skip, clip_skip_2=clip_skip_2, positive=False) + prompt_emb_locals = [self.encode_prompt(prompt_local, clip_skip=clip_skip, clip_skip_2=clip_skip_2, positive=True) for prompt_local in local_prompts] # IP-Adapter if ipadapter_images is not None: @@ -175,12 +179,14 @@ class SDXLImagePipeline(BasePipeline): timestep = timestep.unsqueeze(0).to(self.device) # Classifier-free guidance - noise_pred_posi = lets_dance_xl( + inference_callback = lambda prompt_emb_posi: lets_dance_xl( self.unet, motion_modules=None, controlnet=self.controlnet, sample=latents, timestep=timestep, **extra_input, **prompt_emb_posi, **controlnet_kwargs, **tiler_kwargs, **ipadapter_kwargs_list_posi, device=self.device, ) + noise_pred_posi = self.control_noise_via_local_prompts(prompt_emb_posi, prompt_emb_locals, masks, mask_scales, inference_callback) + if cfg_scale != 1.0: noise_pred_nega = lets_dance_xl( self.unet, motion_modules=None, controlnet=self.controlnet, diff --git a/pages/1_Image_Creator.py b/pages/1_Image_Creator.py index 9fb49ca..2d13782 100644 --- a/pages/1_Image_Creator.py +++ b/pages/1_Image_Creator.py @@ -255,6 +255,37 @@ with column_input: key="canvas" ) + num_painter_layer = st.number_input("Number of painter layers", min_value=0, max_value=10, step=1, value=0) + local_prompts, masks, mask_scales = [], [], [] + white_board = Image.fromarray(np.ones((512, 512, 3), dtype=np.uint8) * 255) + for painter_tab_id in range(num_painter_layer): + with st.expander(f"Painter layer {painter_tab_id}", expanded=True): + enable_local_prompt = st.checkbox(f"Enable prompt {painter_tab_id}", value=True) + local_prompt = st.text_area(f"Prompt {painter_tab_id}") + mask_scale = st.slider(f"Mask scale {painter_tab_id}", min_value=0.0, max_value=3.0, value=1.0) + stroke_width = st.slider(f"Stroke width {painter_tab_id}", min_value=1, max_value=300, value=100) + canvas_result_local = st_canvas( + fill_color="#000000", + stroke_width=stroke_width, + stroke_color="#000000", + background_color="rgba(255, 255, 255, 0)", + background_image=white_board, + update_streamlit=True, + height=512, + width=512, + drawing_mode="freedraw", + key=f"canvas_{painter_tab_id}" + ) + if enable_local_prompt: + local_prompts.append(local_prompt) + if canvas_result_local.image_data is not None: + mask = apply_stroke_to_image(canvas_result_local.image_data, white_board) + else: + mask = white_board + mask = Image.fromarray(255 - np.array(mask)) + masks.append(mask) + mask_scales.append(mask_scale) + with column_output: run_button = st.button("Generate image", type="primary") @@ -282,6 +313,7 @@ with column_output: progress_bar_st = st.progress(0.0) image = pipeline( prompt, negative_prompt=negative_prompt, + local_prompts=local_prompts, masks=masks, mask_scales=mask_scales, cfg_scale=cfg_scale, num_inference_steps=num_inference_steps, height=height, width=width, input_image=input_image, denoising_strength=denoising_strength, From 6877b460c4c6b266986f3fc039598e122330843b Mon Sep 17 00:00:00 2001 From: Artiprocher Date: Fri, 2 Aug 2024 13:47:07 +0800 Subject: [PATCH 3/3] fix bugs --- diffsynth/controlnets/processors.py | 1 - examples/train/README.md | 6 ++++++ 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/diffsynth/controlnets/processors.py b/diffsynth/controlnets/processors.py index ce16b5d..1d23c73 100644 --- a/diffsynth/controlnets/processors.py +++ b/diffsynth/controlnets/processors.py @@ -47,6 +47,5 @@ class Annotator: detect_resolution = self.detect_resolution if self.detect_resolution is not None else min(width, height) image = self.processor(image, detect_resolution=detect_resolution, image_resolution=min(width, height), **kwargs) image = image.resize((width, height)) - image.save("/mnt/zhongjie/DiffSynth-Studio-kolors/DiffSynth-Studio/input.jpg") return image diff --git a/examples/train/README.md b/examples/train/README.md index bbc20b7..354dd90 100644 --- a/examples/train/README.md +++ b/examples/train/README.md @@ -13,6 +13,12 @@ Image Examples of fine-tuned LoRA. The prompt is "一只小狗蹦蹦跳跳,周 |Without LoRA|![image_without_lora](https://github.com/modelscope/DiffSynth-Studio/assets/35051019/9d79ed7a-e8cf-4d98-800a-f182809db318)|![image_without_lora](https://github.com/modelscope/DiffSynth-Studio/assets/35051019/ddb834a5-6366-412b-93dc-6d957230d66e)|![image_without_lora](https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/1aa21de5-a992-4b66-b14f-caa44e08876e)| |With LoRA|![image_with_lora](https://github.com/modelscope/DiffSynth-Studio/assets/35051019/02f62323-6ee5-4788-97a1-549732dbe4f0)|![image_with_lora](https://github.com/modelscope/DiffSynth-Studio/assets/35051019/8e7b2888-d874-4da4-a75b-11b6b214b9bf)|![image_with_lora](https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/83a0a41a-691f-4610-8e7b-d8e17c50a282)| +## Install additional packages + +``` +pip install peft lightning +``` + ## Prepare your dataset We provide an example dataset [here](https://modelscope.cn/datasets/buptwq/lora-stable-diffusion-finetune/files). You need to manage the training images as follows: