From e3d89cec0c8511e9e1e59508650355e0ce89c20a Mon Sep 17 00:00:00 2001 From: mi804 <1576993271@qq.com> Date: Wed, 25 Dec 2024 17:19:31 +0800 Subject: [PATCH 1/8] temp commit for entity control --- diffsynth/models/flux_dit.py | 128 +++++++++++++----- diffsynth/pipelines/flux_image.py | 93 +++++++++++-- examples/EntityControl/entity_control_flux.py | 54 ++++++++ examples/EntityControl/entity_inpaint_flux.py | 59 ++++++++ examples/EntityControl/utils.py | 59 ++++++++ 5 files changed, 349 insertions(+), 44 deletions(-) create mode 100644 examples/EntityControl/entity_control_flux.py create mode 100644 examples/EntityControl/entity_inpaint_flux.py create mode 100644 examples/EntityControl/utils.py diff --git a/diffsynth/models/flux_dit.py b/diffsynth/models/flux_dit.py index faf58cd..8deea3b 100644 --- a/diffsynth/models/flux_dit.py +++ b/diffsynth/models/flux_dit.py @@ -40,7 +40,7 @@ class RoPEEmbedding(torch.nn.Module): n_axes = ids.shape[-1] emb = torch.cat([self.rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)], dim=-3) return emb.unsqueeze(1) - + class FluxJointAttention(torch.nn.Module): @@ -70,7 +70,7 @@ class FluxJointAttention(torch.nn.Module): xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1] return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk) - def forward(self, hidden_states_a, hidden_states_b, image_rotary_emb, ipadapter_kwargs_list=None): + def forward(self, hidden_states_a, hidden_states_b, image_rotary_emb, attn_mask=None, ipadapter_kwargs_list=None): batch_size = hidden_states_a.shape[0] # Part A @@ -91,7 +91,7 @@ class FluxJointAttention(torch.nn.Module): q, k = self.apply_rope(q, k, image_rotary_emb) - hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v) + hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask) hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim) 
hidden_states = hidden_states.to(q.dtype) hidden_states_b, hidden_states_a = hidden_states[:, :hidden_states_b.shape[1]], hidden_states[:, hidden_states_b.shape[1]:] @@ -103,7 +103,7 @@ class FluxJointAttention(torch.nn.Module): else: hidden_states_b = self.b_to_out(hidden_states_b) return hidden_states_a, hidden_states_b - + class FluxJointTransformerBlock(torch.nn.Module): @@ -129,12 +129,12 @@ class FluxJointTransformerBlock(torch.nn.Module): ) - def forward(self, hidden_states_a, hidden_states_b, temb, image_rotary_emb, ipadapter_kwargs_list=None): + def forward(self, hidden_states_a, hidden_states_b, temb, image_rotary_emb, attn_mask=None, ipadapter_kwargs_list=None): norm_hidden_states_a, gate_msa_a, shift_mlp_a, scale_mlp_a, gate_mlp_a = self.norm1_a(hidden_states_a, emb=temb) norm_hidden_states_b, gate_msa_b, shift_mlp_b, scale_mlp_b, gate_mlp_b = self.norm1_b(hidden_states_b, emb=temb) # Attention - attn_output_a, attn_output_b = self.attn(norm_hidden_states_a, norm_hidden_states_b, image_rotary_emb, ipadapter_kwargs_list) + attn_output_a, attn_output_b = self.attn(norm_hidden_states_a, norm_hidden_states_b, image_rotary_emb, attn_mask, ipadapter_kwargs_list) # Part A hidden_states_a = hidden_states_a + gate_msa_a * attn_output_a @@ -147,7 +147,7 @@ class FluxJointTransformerBlock(torch.nn.Module): hidden_states_b = hidden_states_b + gate_mlp_b * self.ff_b(norm_hidden_states_b) return hidden_states_a, hidden_states_b - + class FluxSingleAttention(torch.nn.Module): @@ -184,7 +184,7 @@ class FluxSingleAttention(torch.nn.Module): hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim) hidden_states = hidden_states.to(q.dtype) return hidden_states - + class AdaLayerNormSingle(torch.nn.Module): @@ -200,7 +200,7 @@ class AdaLayerNormSingle(torch.nn.Module): shift_msa, scale_msa, gate_msa = emb.chunk(3, dim=1) x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None] return x, gate_msa - + class 
FluxSingleTransformerBlock(torch.nn.Module): @@ -225,8 +225,8 @@ class FluxSingleTransformerBlock(torch.nn.Module): xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1] return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk) - - def process_attention(self, hidden_states, image_rotary_emb, ipadapter_kwargs_list=None): + + def process_attention(self, hidden_states, image_rotary_emb, attn_mask=None, ipadapter_kwargs_list=None): batch_size = hidden_states.shape[0] qkv = hidden_states.view(batch_size, -1, 3 * self.num_heads, self.head_dim).transpose(1, 2) @@ -235,7 +235,7 @@ class FluxSingleTransformerBlock(torch.nn.Module): q, k = self.apply_rope(q, k, image_rotary_emb) - hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v) + hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask) hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim) hidden_states = hidden_states.to(q.dtype) if ipadapter_kwargs_list is not None: @@ -243,21 +243,21 @@ class FluxSingleTransformerBlock(torch.nn.Module): return hidden_states - def forward(self, hidden_states_a, hidden_states_b, temb, image_rotary_emb, ipadapter_kwargs_list=None): + def forward(self, hidden_states_a, hidden_states_b, temb, image_rotary_emb, attn_mask=None, ipadapter_kwargs_list=None): residual = hidden_states_a norm_hidden_states, gate = self.norm(hidden_states_a, emb=temb) hidden_states_a = self.to_qkv_mlp(norm_hidden_states) attn_output, mlp_hidden_states = hidden_states_a[:, :, :self.dim * 3], hidden_states_a[:, :, self.dim * 3:] - attn_output = self.process_attention(attn_output, image_rotary_emb, ipadapter_kwargs_list) + attn_output = self.process_attention(attn_output, image_rotary_emb, attn_mask, ipadapter_kwargs_list) mlp_hidden_states = torch.nn.functional.gelu(mlp_hidden_states, approximate="tanh") hidden_states_a = torch.cat([attn_output, mlp_hidden_states], 
dim=2) hidden_states_a = gate.unsqueeze(1) * self.proj_out(hidden_states_a) hidden_states_a = residual + hidden_states_a - + return hidden_states_a, hidden_states_b - + class AdaLayerNormContinuous(torch.nn.Module): @@ -300,7 +300,7 @@ class FluxDiT(torch.nn.Module): def unpatchify(self, hidden_states, height, width): hidden_states = rearrange(hidden_states, "B (H W) (C P Q) -> B C (H P) (W Q)", P=2, Q=2, H=height//2, W=width//2) return hidden_states - + def prepare_image_ids(self, latents): batch_size, _, height, width = latents.shape @@ -317,7 +317,7 @@ class FluxDiT(torch.nn.Module): latent_image_ids = latent_image_ids.to(device=latents.device, dtype=latents.dtype) return latent_image_ids - + def tiled_forward( self, @@ -337,12 +337,45 @@ class FluxDiT(torch.nn.Module): ) return hidden_states + def construct_mask(self, entity_masks, prompt_seq_len, image_seq_len): + N = len(entity_masks) + batch_size = entity_masks[0].shape[0] + total_seq_len = N * prompt_seq_len + image_seq_len + patched_masks = [self.patchify(entity_masks[i]) for i in range(N)] + attention_mask = torch.ones((batch_size, total_seq_len, total_seq_len), dtype=torch.bool).to(device=entity_masks[0].device) + + image_start = N * prompt_seq_len + image_end = N * prompt_seq_len + image_seq_len + # prompt-image mask + for i in range(N): + prompt_start = i * prompt_seq_len + prompt_end = (i + 1) * prompt_seq_len + image_mask = torch.sum(patched_masks[i], dim=-1) > 0 + image_mask = image_mask.unsqueeze(1).repeat(1, prompt_seq_len, 1) + # prompt update with image + attention_mask[:, prompt_start:prompt_end, image_start:image_end] = image_mask + # image update with prompt + attention_mask[:, image_start:image_end, prompt_start:prompt_end] = image_mask.transpose(1, 2) + # prompt-prompt mask + for i in range(N): + for j in range(N): + if i != j: + prompt_start_i = i * prompt_seq_len + prompt_end_i = (i + 1) * prompt_seq_len + prompt_start_j = j * prompt_seq_len + prompt_end_j = (j + 1) * prompt_seq_len + 
attention_mask[:, prompt_start_i:prompt_end_i, prompt_start_j:prompt_end_j] = False + + attention_mask = attention_mask.float() + attention_mask[attention_mask == 0] = float('-inf') + attention_mask[attention_mask == 1] = 0 + return attention_mask def forward( self, hidden_states, timestep, prompt_emb, pooled_prompt_emb, guidance, text_ids, image_ids=None, - tiled=False, tile_size=128, tile_stride=64, + tiled=False, tile_size=128, tile_stride=64, entity_prompts=None, entity_masks=None, use_gradient_checkpointing=False, **kwargs ): @@ -353,46 +386,70 @@ class FluxDiT(torch.nn.Module): tile_size=tile_size, tile_stride=tile_stride, **kwargs ) - + if image_ids is None: image_ids = self.prepare_image_ids(hidden_states) - + conditioning = self.time_embedder(timestep, hidden_states.dtype) + self.pooled_text_embedder(pooled_prompt_emb) if self.guidance_embedder is not None: guidance = guidance * 1000 conditioning = conditioning + self.guidance_embedder(guidance, hidden_states.dtype) - prompt_emb = self.context_embedder(prompt_emb) - image_rotary_emb = self.pos_embedder(torch.cat((text_ids, image_ids), dim=1)) + repeat_dim = hidden_states.shape[1] height, width = hidden_states.shape[-2:] hidden_states = self.patchify(hidden_states) hidden_states = self.x_embedder(hidden_states) - + + max_masks = 0 + attention_mask = None + prompt_embs = [prompt_emb] + if entity_masks is not None: + # entity_masks + batch_size, max_masks = entity_masks.shape[0], entity_masks.shape[1] + entity_masks = entity_masks.repeat(1, 1, repeat_dim, 1, 1) + entity_masks = [entity_masks[:, i, None].squeeze(1) for i in range(max_masks)] + # global mask + global_mask = torch.ones_like(entity_masks[0]).to(device=hidden_states.device, dtype=hidden_states.dtype) + entity_masks = entity_masks + [global_mask] # append global to last + # attention mask + attention_mask = self.construct_mask(entity_masks, prompt_emb.shape[1], hidden_states.shape[1]) + attention_mask = 
attention_mask.to(device=hidden_states.device, dtype=hidden_states.dtype) + attention_mask = attention_mask.unsqueeze(1) + # embds: n_masks * b * seq * d + local_embs = [entity_prompts[:, i, None].squeeze(1) for i in range(max_masks)] + prompt_embs = local_embs + prompt_embs # append global to last + prompt_embs = [self.context_embedder(prompt_emb) for prompt_emb in prompt_embs] + prompt_emb = torch.cat(prompt_embs, dim=1) + + # positional embedding + text_ids = torch.cat([text_ids] * (max_masks + 1), dim=1) + image_rotary_emb = self.pos_embedder(torch.cat((text_ids, image_ids), dim=1)) + def create_custom_forward(module): def custom_forward(*inputs): return module(*inputs) return custom_forward - + for block in self.blocks: if self.training and use_gradient_checkpointing: hidden_states, prompt_emb = torch.utils.checkpoint.checkpoint( create_custom_forward(block), - hidden_states, prompt_emb, conditioning, image_rotary_emb, + hidden_states, prompt_emb, conditioning, image_rotary_emb, attention_mask, use_reentrant=False, ) else: - hidden_states, prompt_emb = block(hidden_states, prompt_emb, conditioning, image_rotary_emb) + hidden_states, prompt_emb = block(hidden_states, prompt_emb, conditioning, image_rotary_emb, attention_mask) hidden_states = torch.cat([prompt_emb, hidden_states], dim=1) for block in self.single_blocks: if self.training and use_gradient_checkpointing: hidden_states, prompt_emb = torch.utils.checkpoint.checkpoint( create_custom_forward(block), - hidden_states, prompt_emb, conditioning, image_rotary_emb, + hidden_states, prompt_emb, conditioning, image_rotary_emb, attention_mask, use_reentrant=False, ) else: - hidden_states, prompt_emb = block(hidden_states, prompt_emb, conditioning, image_rotary_emb) + hidden_states, prompt_emb = block(hidden_states, prompt_emb, conditioning, image_rotary_emb, attention_mask) hidden_states = hidden_states[:, prompt_emb.shape[1]:] hidden_states = self.final_norm_out(hidden_states, conditioning) @@ -400,7 +457,7 @@ 
class FluxDiT(torch.nn.Module): hidden_states = self.unpatchify(hidden_states, height, width) return hidden_states - + def quantize(self): def cast_to(weight, dtype=None, device=None, copy=False): @@ -440,16 +497,16 @@ class FluxDiT(torch.nn.Module): class Linear(torch.nn.Linear): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - + def forward(self,input,**kwargs): weight,bias= cast_bias_weight(self,input) return torch.nn.functional.linear(input,weight,bias) - + class RMSNorm(torch.nn.Module): def __init__(self, module): super().__init__() self.module = module - + def forward(self,hidden_states,**kwargs): weight= cast_weight(self.module,hidden_states) input_dtype = hidden_states.dtype @@ -457,7 +514,7 @@ class FluxDiT(torch.nn.Module): hidden_states = hidden_states * torch.rsqrt(variance + self.module.eps) hidden_states = hidden_states.to(input_dtype) * weight return hidden_states - + def replace_layer(model): for name, module in model.named_children(): if isinstance(module, torch.nn.Linear): @@ -483,7 +540,6 @@ class FluxDiT(torch.nn.Module): @staticmethod def state_dict_converter(): return FluxDiTStateDictConverter() - class FluxDiTStateDictConverter: @@ -587,7 +643,7 @@ class FluxDiTStateDictConverter: state_dict_.pop(name.replace(f".{component}_to_q.", f".{component}_to_k.")) state_dict_.pop(name.replace(f".{component}_to_q.", f".{component}_to_v.")) return state_dict_ - + def from_civitai(self, state_dict): rename_dict = { "time_in.in_layer.bias": "time_embedder.timestep_embedder.0.bias", diff --git a/diffsynth/pipelines/flux_image.py b/diffsynth/pipelines/flux_image.py index 42d142c..c012f09 100644 --- a/diffsynth/pipelines/flux_image.py +++ b/diffsynth/pipelines/flux_image.py @@ -139,6 +139,39 @@ class FluxImagePipeline(BasePipeline): images = [self.preprocess_image(image).to(device=self.device, dtype=self.torch_dtype) for image in images] return torch.cat(images, dim=0) + def inpaint_fusion(self, latents, inpaint_latents, pred_noise, 
fg_mask, bg_mask, progress_id, background_weight=0.): + # inpaint noise + inpaint_noise = (latents - inpaint_latents) / self.scheduler.sigmas[progress_id] + # merge noise + weight = torch.ones_like(inpaint_noise) + inpaint_noise[fg_mask] = pred_noise[fg_mask] + inpaint_noise[bg_mask] += pred_noise[bg_mask] * background_weight + weight[bg_mask] += background_weight + inpaint_noise /= weight + return inpaint_noise + + def preprocess_masks(self, masks, height, width, dim): + out_masks = [] + for mask in masks: + mask = self.preprocess_image(mask.resize((width, height), resample=Image.NEAREST)).mean(dim=1, keepdim=True) > 0 + mask = mask.repeat(1, dim, 1, 1).to(device=self.device, dtype=self.torch_dtype) + out_masks.append(mask) + return out_masks + + def prepare_entity_inputs(self, entity_prompts, entity_masks, width, height, t5_sequence_length=512, inpaint_input=None): + fg_mask, bg_mask = None, None + if inpaint_input is not None: + from copy import deepcopy + masks_ = deepcopy(entity_masks) + fg_masks = torch.cat([self.preprocess_image(mask.resize((width//8, height//8))).mean(dim=1, keepdim=True) for mask in masks_]) + fg_masks = (fg_masks > 0).float() + fg_mask = fg_masks.sum(dim=0, keepdim=True).repeat(1, 16, 1, 1) > 0 + bg_mask = ~fg_mask + entity_masks = self.preprocess_masks(entity_masks, height//8, width//8, 1) + entity_masks = torch.cat(entity_masks, dim=0).unsqueeze(0) # b, n_mask, c, h, w + entity_prompts = self.encode_prompt(entity_prompts, t5_sequence_length=t5_sequence_length)['prompt_emb'].unsqueeze(0) + return entity_prompts, entity_masks, fg_mask, bg_mask + @torch.no_grad() def __call__( self, @@ -160,6 +193,10 @@ class FluxImagePipeline(BasePipeline): width=1024, num_inference_steps=30, t5_sequence_length=512, + inpaint_input=None, + entity_prompts=None, + entity_masks=None, + use_seperated_negtive_prompt=True, tiled=False, tile_size=128, tile_stride=64, @@ -176,12 +213,13 @@ class FluxImagePipeline(BasePipeline): 
self.scheduler.set_timesteps(num_inference_steps, denoising_strength) # Prepare latent tensors - if input_image is not None: + if input_image is not None or inpaint_input is not None: + input_image = input_image or inpaint_input self.load_models_to_device(['vae_encoder']) image = self.preprocess_image(input_image).to(device=self.device, dtype=self.torch_dtype) - latents = self.encode_image(image, **tiler_kwargs) + input_latents = self.encode_image(image, **tiler_kwargs) noise = self.generate_noise((1, 16, height//8, width//8), seed=seed, device=self.device, dtype=self.torch_dtype) - latents = self.scheduler.add_noise(latents, noise, timestep=self.scheduler.timesteps[0]) + latents = self.scheduler.add_noise(input_latents, noise, timestep=self.scheduler.timesteps[0]) else: latents = self.generate_noise((1, 16, height//8, width//8), seed=seed, device=self.device, dtype=self.torch_dtype) @@ -195,6 +233,14 @@ class FluxImagePipeline(BasePipeline): prompt_emb_nega = self.encode_prompt(negative_prompt, positive=False, t5_sequence_length=t5_sequence_length) prompt_emb_locals = [self.encode_prompt(prompt_local, t5_sequence_length=t5_sequence_length) for prompt_local in local_prompts] + # Entity control + negative_entity_prompts = None + negative_masks = None + if entity_masks is not None: + entity_prompts, entity_masks, fg_mask, bg_mask = self.prepare_entity_inputs(entity_prompts, entity_masks, width, height, t5_sequence_length, inpaint_input) + if use_seperated_negtive_prompt and cfg_scale != 1.0: + negative_entity_prompts = prompt_emb_nega['prompt_emb'].unsqueeze(1).repeat(1, entity_masks.shape[1], 1, 1) + negative_masks = entity_masks # Extra input extra_input = self.prepare_extra_input(latents, guidance=embedded_guidance) @@ -229,18 +275,20 @@ class FluxImagePipeline(BasePipeline): # Classifier-free guidance inference_callback = lambda prompt_emb_posi, controlnet_kwargs: lets_dance_flux( dit=self.dit, controlnet=self.controlnet, - hidden_states=latents, 
timestep=timestep, + hidden_states=latents, timestep=timestep, entity_prompts=entity_prompts, entity_masks=entity_masks, **prompt_emb_posi, **tiler_kwargs, **extra_input, **controlnet_kwargs, **ipadapter_kwargs_list_posi, ) noise_pred_posi = self.control_noise_via_local_prompts( prompt_emb_posi, prompt_emb_locals, masks, mask_scales, inference_callback, special_kwargs=controlnet_kwargs, special_local_kwargs_list=local_controlnet_kwargs ) + if inpaint_input: + noise_pred_posi = self.inpaint_fusion(latents, input_latents, noise_pred_posi, fg_mask, bg_mask, progress_id) if cfg_scale != 1.0: negative_controlnet_kwargs = controlnet_kwargs if enable_controlnet_on_negative else {} noise_pred_nega = lets_dance_flux( dit=self.dit, controlnet=self.controlnet, - hidden_states=latents, timestep=timestep, + hidden_states=latents, timestep=timestep, entity_prompts=negative_entity_prompts, entity_masks=negative_masks, **prompt_emb_nega, **tiler_kwargs, **extra_input, **negative_controlnet_kwargs, **ipadapter_kwargs_list_nega, ) noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega) @@ -278,6 +326,8 @@ def lets_dance_flux( tiled=False, tile_size=128, tile_stride=64, + entity_prompts=None, + entity_masks=None, ipadapter_kwargs_list={}, **kwargs ): @@ -333,13 +383,38 @@ def lets_dance_flux( if dit.guidance_embedder is not None: guidance = guidance * 1000 conditioning = conditioning + dit.guidance_embedder(guidance, hidden_states.dtype) - prompt_emb = dit.context_embedder(prompt_emb) - image_rotary_emb = dit.pos_embedder(torch.cat((text_ids, image_ids), dim=1)) + repeat_dim = hidden_states.shape[1] height, width = hidden_states.shape[-2:] hidden_states = dit.patchify(hidden_states) hidden_states = dit.x_embedder(hidden_states) - + + # Entity Control + max_masks = 0 + attention_mask = None + prompt_embs = [prompt_emb] + if entity_masks is not None: + # entity_masks + batch_size, max_masks = entity_masks.shape[0], entity_masks.shape[1] + entity_masks = 
entity_masks.repeat(1, 1, repeat_dim, 1, 1) + entity_masks = [entity_masks[:, i, None].squeeze(1) for i in range(max_masks)] + # global mask + global_mask = torch.ones_like(entity_masks[0]).to(device=hidden_states.device, dtype=hidden_states.dtype) + entity_masks = entity_masks + [global_mask] # append global to last + # attention mask + attention_mask = dit.construct_mask(entity_masks, prompt_emb.shape[1], hidden_states.shape[1]) + attention_mask = attention_mask.to(device=hidden_states.device, dtype=hidden_states.dtype) + attention_mask = attention_mask.unsqueeze(1) + # embds: n_masks * b * seq * d + local_embs = [entity_prompts[:, i, None].squeeze(1) for i in range(max_masks)] + prompt_embs = local_embs + prompt_embs # append global to last + prompt_embs = [dit.context_embedder(prompt_emb) for prompt_emb in prompt_embs] + prompt_emb = torch.cat(prompt_embs, dim=1) + + # positional embedding + text_ids = torch.cat([text_ids] * (max_masks + 1), dim=1) + image_rotary_emb = dit.pos_embedder(torch.cat((text_ids, image_ids), dim=1)) + # Joint Blocks for block_id, block in enumerate(dit.blocks): hidden_states, prompt_emb = block( @@ -347,6 +422,7 @@ def lets_dance_flux( prompt_emb, conditioning, image_rotary_emb, + attention_mask, ipadapter_kwargs_list=ipadapter_kwargs_list.get(block_id, None)) # ControlNet if controlnet is not None and controlnet_frames is not None: @@ -361,6 +437,7 @@ def lets_dance_flux( prompt_emb, conditioning, image_rotary_emb, + attention_mask, ipadapter_kwargs_list=ipadapter_kwargs_list.get( block_id + num_joint_blocks, None)) # ControlNet diff --git a/examples/EntityControl/entity_control_flux.py b/examples/EntityControl/entity_control_flux.py new file mode 100644 index 0000000..5e74e0b --- /dev/null +++ b/examples/EntityControl/entity_control_flux.py @@ -0,0 +1,54 @@ +import torch +from diffsynth import ModelManager, FluxImagePipeline, download_customized_models, FluxImageLoraPipeline +from examples.EntityControl.utils import visualize_masks 
+import os +import json +from PIL import Image + +# lora_path = download_customized_models( +# model_id="DiffSynth-Studio/ArtAug-lora-FLUX.1dev-v1", +# origin_file_path="merged_lora.safetensors", +# local_dir="models/lora" +# )[0] + +lora_path = '/root/model_bf16.safetensors' +model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cuda") +model_manager.load_models([ + "t2i_models/FLUX/FLUX.1-dev/text_encoder/model.safetensors", + "t2i_models/FLUX/FLUX.1-dev/text_encoder_2", + "t2i_models/FLUX/FLUX.1-dev/ae.safetensors", + "t2i_models/FLUX/FLUX.1-dev/flux1-dev.safetensors" +]) +model_manager.load_lora(lora_path, lora_alpha=1.) + +pipe = FluxImagePipeline.from_model_manager(model_manager) + +mask_dir = '/mnt/nas1/zhanghong/DiffSynth-Studio/workdirs/tmp_mask' +image_shape = 1024 +guidance = 3.5 +cfg = 3.0 +negative_prompt = "worst quality, low quality, monochrome, zombie, interlocked fingers, Aissist, cleavage, nsfw," +names = ['row_2_1'] +seeds = [0] +# use this to apply regional attention in negative prompt prediction for better results with more time +use_seperated_negtive_prompt = False +for name, seed in zip(names, seeds): + out_dir = f'workdirs/entity_control/{name}' + os.makedirs(out_dir, exist_ok=True) + cur_dir = os.path.join(mask_dir, name) + metas = json.load(open(os.path.join(mask_dir, name, 'prompts.json'))) + for seed in range(3, 10): + prompt = metas['global_prompt'] + mask_prompts = metas['mask_prompts'] + masks = [Image.open(os.path.join(mask_dir, name, f"{mask_idx}.png")).resize((image_shape, image_shape), resample=Image.NEAREST) for mask_idx in range(len(mask_prompts))] + torch.manual_seed(seed) + image = pipe( + prompt=prompt, + cfg_scale=cfg, + negative_prompt=negative_prompt, + num_inference_steps=50, embedded_guidance=guidance, height=image_shape, width=image_shape, + entity_prompts=mask_prompts, entity_masks=masks, + use_seperated_negtive_prompt=use_seperated_negtive_prompt + ) + use_sep = f'_sepneg' if use_seperated_negtive_prompt 
else '' + visualize_masks(image, masks, mask_prompts, os.path.join(out_dir, f"{name}_{seed}{use_sep}.png")) diff --git a/examples/EntityControl/entity_inpaint_flux.py b/examples/EntityControl/entity_inpaint_flux.py new file mode 100644 index 0000000..5027563 --- /dev/null +++ b/examples/EntityControl/entity_inpaint_flux.py @@ -0,0 +1,59 @@ +import torch +from diffsynth import ModelManager, FluxImagePipeline, download_customized_models, FluxImageLoraPipeline +from examples.EntityControl.utils import visualize_masks +import os +import json +from PIL import Image + +# lora_path = download_customized_models( +# model_id="DiffSynth-Studio/ArtAug-lora-FLUX.1dev-v1", +# origin_file_path="merged_lora.safetensors", +# local_dir="models/lora" +# )[0] + +lora_path = '/root/model_bf16.safetensors' +model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cuda") +model_manager.load_models([ + "t2i_models/FLUX/FLUX.1-dev/text_encoder/model.safetensors", + "t2i_models/FLUX/FLUX.1-dev/text_encoder_2", + "t2i_models/FLUX/FLUX.1-dev/ae.safetensors", + "t2i_models/FLUX/FLUX.1-dev/flux1-dev.safetensors" +]) +model_manager.load_lora(lora_path, lora_alpha=1.) + +pipe = FluxImagePipeline.from_model_manager(model_manager) + +mask_dir = '/mnt/nas1/zhanghong/DiffSynth-Studio/workdirs/tmp_mask' +image_shape = 1024 +guidance = 3.5 +cfg = 3.0 +negative_prompt = "worst quality, low quality, monochrome, zombie, interlocked fingers, Aissist, cleavage, nsfw," +names = ['inpaint2'] +seeds = [0] +use_seperated_negtive_prompt = False +for name, seed in zip(names, seeds): + out_dir = f'workdirs/paper_app/inpaint/elc/{name}' + os.makedirs(out_dir, exist_ok=True) + cur_dir = os.path.join(mask_dir, name) + metas = json.load(open(os.path.join(mask_dir, name, 'prompts.json'))) + inpaint_input = Image.open(os.path.join(cur_dir, 'input.png')).convert('RGB') + prompt = metas['global_prompt'] + prompt = 'A person with a dog walking on the cloud. 
A rocket in the sky' + mask_prompts = metas['mask_prompts'] + masks = [Image.open(os.path.join(mask_dir, name, f"{mask_idx}.png")).resize((image_shape, image_shape), resample=Image.NEAREST) for mask_idx in range(len(mask_prompts))] + torch.manual_seed(seed) + image = pipe( + prompt=prompt, + cfg_scale=cfg, + negative_prompt=negative_prompt, + num_inference_steps=50, + embedded_guidance=guidance, + height=image_shape, + width=image_shape, + entity_prompts=mask_prompts, + entity_masks=masks, + inpaint_input=inpaint_input, + use_seperated_negtive_prompt=use_seperated_negtive_prompt, + ) + use_sep = f'_sepneg' if use_seperated_negtive_prompt else '' + visualize_masks(image, masks, mask_prompts, os.path.join(out_dir, f"{name}_{seed}{use_sep}.png")) diff --git a/examples/EntityControl/utils.py b/examples/EntityControl/utils.py new file mode 100644 index 0000000..573c8a6 --- /dev/null +++ b/examples/EntityControl/utils.py @@ -0,0 +1,59 @@ +from PIL import Image, ImageDraw, ImageFont +import random + +def visualize_masks(image, masks, mask_prompts, output_path, font_size=35, use_random_colors=False): + # Create a blank image for overlays + overlay = Image.new('RGBA', image.size, (0, 0, 0, 0)) + + colors = [ + (165, 238, 173, 80), + (76, 102, 221, 80), + (221, 160, 77, 80), + (204, 93, 71, 80), + (145, 187, 149, 80), + (134, 141, 172, 80), + (157, 137, 109, 80), + (153, 104, 95, 80), + (165, 238, 173, 80), + (76, 102, 221, 80), + (221, 160, 77, 80), + (204, 93, 71, 80), + (145, 187, 149, 80), + (134, 141, 172, 80), + (157, 137, 109, 80), + (153, 104, 95, 80), + ] + # Generate random colors for each mask + if use_random_colors: + colors = [(random.randint(0, 255), random.randint(0, 255), random.randint(0, 255), 80) for _ in range(len(masks))] + + # Font settings + try: + font = ImageFont.truetype("arial", font_size) # Adjust as needed + except IOError: + font = ImageFont.load_default(font_size) + + # Overlay each mask onto the overlay image + for mask, mask_prompt, color in 
zip(masks, mask_prompts, colors): + # Convert mask to RGBA mode + mask_rgba = mask.convert('RGBA') + mask_data = mask_rgba.getdata() + new_data = [(color if item[:3] == (255, 255, 255) else (0, 0, 0, 0)) for item in mask_data] + mask_rgba.putdata(new_data) + + # Draw the mask prompt text on the mask + draw = ImageDraw.Draw(mask_rgba) + mask_bbox = mask.getbbox() # Get the bounding box of the mask + text_position = (mask_bbox[0] + 10, mask_bbox[1] + 10) # Adjust text position based on mask position + draw.text(text_position, mask_prompt, fill=(255, 255, 255, 255), font=font) + + # Alpha composite the overlay with this mask + overlay = Image.alpha_composite(overlay, mask_rgba) + + # Composite the overlay onto the original image + result = Image.alpha_composite(image.convert('RGBA'), overlay) + + # Save or display the resulting image + result.save(output_path) + + return result From b6620f3dde192f994b094c71d61119f8f6e9addd Mon Sep 17 00:00:00 2001 From: mi804 <1576993271@qq.com> Date: Tue, 31 Dec 2024 14:04:28 +0800 Subject: [PATCH 2/8] update_example entity control --- apps/gradio/entity_level_control.py | 320 ++++++++++++++++++ examples/EntityControl/README.md | 1 + examples/EntityControl/entity_control.py | 57 ++++ examples/EntityControl/entity_control_flux.py | 54 --- .../EntityControl/entity_control_ipadapter.py | 51 +++ examples/EntityControl/entity_inpaint.py | 58 ++++ examples/EntityControl/entity_inpaint_flux.py | 59 ---- 7 files changed, 487 insertions(+), 113 deletions(-) create mode 100644 apps/gradio/entity_level_control.py create mode 100644 examples/EntityControl/README.md create mode 100644 examples/EntityControl/entity_control.py delete mode 100644 examples/EntityControl/entity_control_flux.py create mode 100644 examples/EntityControl/entity_control_ipadapter.py create mode 100644 examples/EntityControl/entity_inpaint.py delete mode 100644 examples/EntityControl/entity_inpaint_flux.py diff --git a/apps/gradio/entity_level_control.py 
b/apps/gradio/entity_level_control.py new file mode 100644 index 0000000..d914cd0 --- /dev/null +++ b/apps/gradio/entity_level_control.py @@ -0,0 +1,320 @@ +import gradio as gr +from diffsynth import ModelManager, FluxImagePipeline +import os, torch +from PIL import Image +import numpy as np +from PIL import ImageDraw, ImageFont +import random +import json + +lora_checkpoint_path = 'models/lora/entity_control/model_bf16.safetensors' +save_masks_dir = 'workdirs/tmp_mask' + +def save_mask_prompts(masks, mask_prompts, global_prompt, seed=0, random_dir='0000000'): + save_dir = os.path.join(save_masks_dir, random_dir) + os.makedirs(save_dir, exist_ok=True) + for i, mask in enumerate(masks): + save_path = os.path.join(save_dir, f'{i}.png') + mask.save(save_path) + sample = { + "global_prompt": global_prompt, + "mask_prompts": mask_prompts, + "seed": seed, + } + with open(os.path.join(save_dir, f"prompts.json"), 'w') as f: + json.dump(sample, f, indent=4) + +def visualize_masks(image, masks, mask_prompts, font_size=35, use_random_colors=False): + # Create a blank image for overlays + overlay = Image.new('RGBA', image.size, (0, 0, 0, 0)) + colors = [ + (165, 238, 173, 80), + (76, 102, 221, 80), + (221, 160, 77, 80), + (204, 93, 71, 80), + (145, 187, 149, 80), + (134, 141, 172, 80), + (157, 137, 109, 80), + (153, 104, 95, 80), + (165, 238, 173, 80), + (76, 102, 221, 80), + (221, 160, 77, 80), + (204, 93, 71, 80), + (145, 187, 149, 80), + (134, 141, 172, 80), + (157, 137, 109, 80), + (153, 104, 95, 80), + ] + # Generate random colors for each mask + if use_random_colors: + colors = [(random.randint(0, 255), random.randint(0, 255), random.randint(0, 255), 80) for _ in range(len(masks))] + # Font settings + try: + font = ImageFont.truetype("arial", font_size) # Adjust as needed + except IOError: + font = ImageFont.load_default(font_size) + # Overlay each mask onto the overlay image + for mask, mask_prompt, color in zip(masks, mask_prompts, colors): + # Convert mask to RGBA 
mode + mask_rgba = mask.convert('RGBA') + mask_data = mask_rgba.getdata() + new_data = [(color if item[:3] == (255, 255, 255) else (0, 0, 0, 0)) for item in mask_data] + mask_rgba.putdata(new_data) + # Draw the mask prompt text on the mask + draw = ImageDraw.Draw(mask_rgba) + mask_bbox = mask.getbbox() # Get the bounding box of the mask + text_position = (mask_bbox[0] + 10, mask_bbox[1] + 10) # Adjust text position based on mask position + draw.text(text_position, mask_prompt, fill=(255, 255, 255, 255), font=font) + # Alpha composite the overlay with this mask + overlay = Image.alpha_composite(overlay, mask_rgba) + # Composite the overlay onto the original image + result = Image.alpha_composite(image.convert('RGBA'), overlay) + return result + +config = { + "model_config": { + "FLUX": { + "model_folder": "models/FLUX", + "pipeline_class": FluxImagePipeline, + "default_parameters": { + "cfg_scale": 3.0, + "embedded_guidance": 3.5, + "num_inference_steps": 30, + } + }, + }, + "max_num_painter_layers": 8, + "max_num_model_cache": 1, +} + + +def load_model_list(model_type): + if model_type is None: + return [] + folder = config["model_config"][model_type]["model_folder"] + file_list = [i for i in os.listdir(folder) if i.endswith(".safetensors")] + if model_type in ["HunyuanDiT", "Kolors", "FLUX"]: + file_list += [i for i in os.listdir(folder) if os.path.isdir(os.path.join(folder, i))] + file_list = sorted(file_list) + return file_list + + +model_dict = {} + +def load_model(model_type, model_path): + global model_dict + model_key = f"{model_type}:{model_path}" + if model_key in model_dict: + return model_dict[model_key] + model_path = os.path.join(config["model_config"][model_type]["model_folder"], model_path) + model_manager = ModelManager() + if model_type == "FLUX": + model_manager.torch_dtype = torch.bfloat16 + file_list = [ + os.path.join(model_path, "text_encoder/model.safetensors"), + os.path.join(model_path, "text_encoder_2"), + ] + for file_name in 
os.listdir(model_path): + if file_name.endswith(".safetensors"): + file_list.append(os.path.join(model_path, file_name)) + model_manager.load_models(file_list) + model_manager.load_lora(lora_checkpoint_path, lora_alpha=1.) + + else: + model_manager.load_model(model_path) + pipe = config["model_config"][model_type]["pipeline_class"].from_model_manager(model_manager) + while len(model_dict) + 1 > config["max_num_model_cache"]: + key = next(iter(model_dict.keys())) + model_manager_to_release, _ = model_dict[key] + model_manager_to_release.to("cpu") + del model_dict[key] + torch.cuda.empty_cache() + model_dict[model_key] = model_manager, pipe + return model_manager, pipe + + +with gr.Blocks() as app: + gr.Markdown(""" + # 实体级控制文生图模型EliGen + **UI说明** + 1. 点击Load model读取模型,然后左侧界面为文生图输入参数;右侧Painter为局部控制区域绘制区域,每个局部控制条件由其Local prompt和绘制的mask组成,支持精准控制文生图和Inpainting两种模式。 + 2. **精准控制生图模式:** 输入Globalprompt;激活并绘制一个或多个局部控制条件,点击Generate生成图像; Global Prompt推荐包含每个Local Prompt。 + 3. **Inpainting模式:** 你可以上传图像,或者将上一步生成的图像设置为Inpaint Input Image,采用类似的方式输入局部控制条件,进行局部重绘。 + 4. 尽情创造! + """) + gr.Markdown(""" + # Entity-Level Controlled Text-to-Image Model: EliGen + **UI Instructions** + 1. Click "Load model" to load the model. The left interface is for text-to-image input parameters; the right "Painter" is the area for drawing local control regions. Each local control condition consists of its Local Prompt and the drawn mask, supporting both precise control text-to-image and Inpainting modes. + 2. **Precise Control Image Generation Mode:** Enter the Global Prompt; activate and draw one or more local control conditions, then click "Generate" to create the image. It is recommended that the Global Prompt includes all Local Prompts. + 3. **Inpainting Mode:** You can upload an image or set the image generated in the previous step as the "Inpaint Input Image." Use a similar method to input local control conditions for local redrawing. + 4. Enjoy! 
+ """) + with gr.Row(): + random_mask_dir = gr.State('') + with gr.Column(scale=382, min_width=100): + model_type = gr.State('FLUX') + model_path = gr.State('FLUX.1-dev') + with gr.Accordion(label="Model"): + load_model_button = gr.Button(value="Load model") + + with gr.Accordion(label="Global prompt"): + prompt = gr.Textbox(label="Prompt", lines=3) + negative_prompt = gr.Textbox(label="Negative prompt", value="worst quality, low quality, monochrome, zombie, interlocked fingers, Aissist, cleavage, nsfw,", lines=1) + cfg_scale = gr.Slider(minimum=1.0, maximum=10.0, value=7.0, step=0.1, interactive=True, label="Classifier-free guidance scale") + embedded_guidance = gr.Slider(minimum=0.0, maximum=10.0, value=0.0, step=0.1, interactive=True, label="Embedded guidance scale") + + with gr.Accordion(label="Inference Options"): + num_inference_steps = gr.Slider(minimum=1, maximum=100, value=20, step=1, interactive=True, label="Inference steps") + height = gr.Slider(minimum=64, maximum=2048, value=1024, step=64, interactive=True, label="Height") + width = gr.Slider(minimum=64, maximum=2048, value=1024, step=64, interactive=True, label="Width") + return_with_mask = gr.Checkbox(value=True, interactive=True, label="show result with mask painting") + with gr.Column(): + use_fixed_seed = gr.Checkbox(value=True, interactive=False, label="Use fixed seed") + seed = gr.Number(minimum=0, maximum=10**9, value=0, interactive=True, label="Random seed", show_label=False) + with gr.Accordion(label="Inpaint Input Image (Testing)"): + input_image = gr.Image(sources=None, show_label=False, interactive=True, type="pil") + background_weight = gr.Slider(minimum=0.0, maximum=1000., value=0., step=1, interactive=False, label="background_weight") + + with gr.Column(): + reset_input_button = gr.Button(value="Reset Inpaint Input") + send_input_to_painter = gr.Button(value="Set as painter's background") + @gr.on(inputs=[input_image], outputs=[input_image], triggers=reset_input_button.click) + def 
reset_input_image(input_image): + return None + @gr.on( + inputs=[model_type, model_path, prompt, negative_prompt, cfg_scale, embedded_guidance, num_inference_steps, height, width, return_with_mask], + outputs=[prompt, negative_prompt, cfg_scale, embedded_guidance, num_inference_steps, height, width, return_with_mask, load_model_button, random_mask_dir], + triggers=load_model_button.click + ) + def model_path_to_default_params(model_type, model_path, prompt, negative_prompt, cfg_scale, embedded_guidance, num_inference_steps, height, width, return_with_mask): + load_model(model_type, model_path) + cfg_scale = config["model_config"][model_type]["default_parameters"].get("cfg_scale", cfg_scale) + embedded_guidance = config["model_config"][model_type]["default_parameters"].get("embedded_guidance", embedded_guidance) + num_inference_steps = config["model_config"][model_type]["default_parameters"].get("num_inference_steps", num_inference_steps) + height = config["model_config"][model_type]["default_parameters"].get("height", height) + width = config["model_config"][model_type]["default_parameters"].get("width", width) + return_with_mask = config["model_config"][model_type]["default_parameters"].get("return_with_mask", return_with_mask) + return prompt, negative_prompt, cfg_scale, embedded_guidance, num_inference_steps, height, width, return_with_mask, gr.update(value="Loaded FLUX"), gr.State(f'{random.randint(0, 1000000):08d}') + + + with gr.Column(scale=618, min_width=100): + with gr.Accordion(label="Painter"): + enable_local_prompt_list = [] + local_prompt_list = [] + mask_scale_list = [] + canvas_list = [] + for painter_layer_id in range(config["max_num_painter_layers"]): + with gr.Tab(label=f"Layer {painter_layer_id}"): + enable_local_prompt = gr.Checkbox(label="Enable", value=False, key=f"enable_local_prompt_{painter_layer_id}") + local_prompt = gr.Textbox(label="Local prompt", key=f"local_prompt_{painter_layer_id}") + mask_scale = gr.Slider(minimum=0.0, 
maximum=5.0, value=1.0, step=0.1, interactive=True, label="Mask scale", key=f"mask_scale_{painter_layer_id}") + canvas = gr.ImageEditor(canvas_size=(512, 1), sources=None, layers=False, interactive=True, image_mode="RGBA", + brush=gr.Brush(default_size=30, default_color="#000000", colors=["#000000"]), + label="Painter", key=f"canvas_{painter_layer_id}") + @gr.on(inputs=[height, width, canvas], outputs=canvas, triggers=[height.change, width.change, canvas.clear, enable_local_prompt.change], show_progress="hidden") + def resize_canvas(height, width, canvas): + h, w = canvas["background"].shape[:2] + if h != height or width != w: + return np.ones((height, width, 3), dtype=np.uint8) * 255 + else: + return canvas + + enable_local_prompt_list.append(enable_local_prompt) + local_prompt_list.append(local_prompt) + mask_scale_list.append(mask_scale) + canvas_list.append(canvas) + with gr.Accordion(label="Results"): + run_button = gr.Button(value="Generate", variant="primary") + output_image = gr.Image(sources=None, show_label=False, interactive=False, type="pil") + with gr.Row(): + with gr.Column(): + output_to_painter_button = gr.Button(value="Set as painter's background") + with gr.Column(): + output_to_input_button = gr.Button(value="Set as input image") + real_output = gr.State(None) + mask_out = gr.State(None) + + @gr.on( + inputs=[model_type, model_path, prompt, negative_prompt, cfg_scale, embedded_guidance, num_inference_steps, height, width, return_with_mask, seed, input_image, background_weight, random_mask_dir] + enable_local_prompt_list + local_prompt_list + mask_scale_list + canvas_list, + outputs=[output_image, real_output, mask_out], + triggers=run_button.click + ) + def generate_image(model_type, model_path, prompt, negative_prompt, cfg_scale, embedded_guidance, num_inference_steps, height, width, return_with_mask, seed, input_image, background_weight, random_mask_dir, *args, progress=gr.Progress()): + print("save to:", random_mask_dir.value) + _, pipe = 
load_model(model_type, model_path) + input_params = { + "prompt": prompt, + "negative_prompt": negative_prompt, + "cfg_scale": cfg_scale, + "num_inference_steps": num_inference_steps, + "height": height, + "width": width, + "progress_bar_cmd": progress.tqdm, + } + if isinstance(pipe, FluxImagePipeline): + input_params["embedded_guidance"] = embedded_guidance + if input_image is not None: + input_params["inpaint_input"] = input_image.resize((width, height)).convert("RGB") + + enable_local_prompt_list, local_prompt_list, mask_scale_list, canvas_list = ( + args[0 * config["max_num_painter_layers"]: 1 * config["max_num_painter_layers"]], + args[1 * config["max_num_painter_layers"]: 2 * config["max_num_painter_layers"]], + args[2 * config["max_num_painter_layers"]: 3 * config["max_num_painter_layers"]], + args[3 * config["max_num_painter_layers"]: 4 * config["max_num_painter_layers"]] + ) + local_prompts, masks, mask_scales = [], [], [] + for enable_local_prompt, local_prompt, mask_scale, canvas in zip( + enable_local_prompt_list, local_prompt_list, mask_scale_list, canvas_list + ): + if enable_local_prompt: + local_prompts.append(local_prompt) + masks.append(Image.fromarray(canvas["layers"][0][:, :, -1]).convert("RGB")) + mask_scales.append(mask_scale) + entity_masks = None if len(masks) == 0 else masks + entity_prompts = None if len(local_prompts) == 0 else local_prompts + input_params.update({ + "entity_prompts": entity_prompts, + "entity_masks": entity_masks, + }) + torch.manual_seed(seed) + image = pipe(**input_params) + # visualize masks + masks = [mask.resize(image.size) for mask in masks] + image_with_mask = visualize_masks(image, masks, local_prompts) + # save_mask_prompts(masks, local_prompts, prompt, seed, random_mask_dir.value) + + real_output = gr.State(image) + mask_out = gr.State(image_with_mask) + + if return_with_mask: + return image_with_mask, real_output, mask_out + return image, real_output, mask_out + + @gr.on(inputs=[input_image] + canvas_list, 
outputs=canvas_list, triggers=send_input_to_painter.click) + def send_input_to_painter_background(input_image, *canvas_list): + if input_image is None: + return tuple(canvas_list) + for canvas in canvas_list: + h, w = canvas["background"].shape[:2] + canvas["background"] = input_image.resize((w, h)) + return tuple(canvas_list) + @gr.on(inputs=[real_output] + canvas_list, outputs=canvas_list, triggers=output_to_painter_button.click) + def send_output_to_painter_background(real_output, *canvas_list): + if real_output is None: + return tuple(canvas_list) + for canvas in canvas_list: + h, w = canvas["background"].shape[:2] + canvas["background"] = real_output.value.resize((w, h)) + return tuple(canvas_list) + @gr.on(inputs=[return_with_mask, real_output, mask_out], outputs=[output_image], triggers=[return_with_mask.change], show_progress="hidden") + def show_output(return_with_mask, real_output, mask_out): + if return_with_mask: + return mask_out.value + else: + return real_output.value + @gr.on(inputs=[real_output], outputs=[input_image], triggers=output_to_input_button.click) + def send_output_to_pipe_input(real_output): + return real_output.value + +app.launch() diff --git a/examples/EntityControl/README.md b/examples/EntityControl/README.md new file mode 100644 index 0000000..784c590 --- /dev/null +++ b/examples/EntityControl/README.md @@ -0,0 +1 @@ +# EliGen: Entity-Level Controlled Image Generation diff --git a/examples/EntityControl/entity_control.py b/examples/EntityControl/entity_control.py new file mode 100644 index 0000000..2735682 --- /dev/null +++ b/examples/EntityControl/entity_control.py @@ -0,0 +1,57 @@ +import torch +from diffsynth import ModelManager, FluxImagePipeline, download_customized_models +from examples.EntityControl.utils import visualize_masks +from PIL import Image +import requests +from io import BytesIO + +# download and load model +lora_path = download_customized_models( + model_id="DiffSynth-Studio/Eligen", + 
origin_file_path="model_bf16.safetensors", + local_dir="models/lora/entity_control" +)[0] +model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cuda", model_id_list=["FLUX.1-dev"]) +model_manager.load_lora(lora_path, lora_alpha=1.) +pipe = FluxImagePipeline.from_model_manager(model_manager) + +# prepare inputs +image_shape = 1024 +seed = 4 +# set True to apply regional attention in negative prompt prediction for better results with more time +use_seperated_negtive_prompt = False +mask_urls = [ + 'https://github.com/user-attachments/assets/02905f6e-40c2-4482-9abe-b1ce50ccabbf', + 'https://github.com/user-attachments/assets/a4cf4361-abf7-4556-ba94-74683eda4cb7', + 'https://github.com/user-attachments/assets/b6595ff4-7269-4d8f-acf0-5df40bd6c59f', + 'https://github.com/user-attachments/assets/941d39a7-3aa1-437f-8b2a-4adb15d2fb3e', + 'https://github.com/user-attachments/assets/400c4086-5398-4291-b1b5-22d8483c08d9', + 'https://github.com/user-attachments/assets/ce324c77-fa1d-4aad-a5cb-698f0d5eca70', + 'https://github.com/user-attachments/assets/4e62325f-a60c-44f7-b53b-6da0869bb9db' +] +# prepare entity masks, entity prompts, global prompt and negative prompt +masks = [] +for url in mask_urls: + response = requests.get(url) + mask = Image.open(BytesIO(response.content)).resize((image_shape, image_shape), resample=Image.NEAREST) + masks.append(mask) +entity_prompts = ["A beautiful woman", "mirror", "necklace", "glasses", "earring", "white dress", "jewelry headpiece"] +global_prompt = "A beautiful woman wearing white dress, holding a mirror, with a warm light background;" +negative_prompt = "worst quality, low quality, monochrome, zombie, interlocked fingers, Aissist, cleavage, nsfw" + +# generate image +torch.manual_seed(seed) +image = pipe( + prompt=global_prompt, + cfg_scale=3.0, + negative_prompt=negative_prompt, + num_inference_steps=50, + embedded_guidance=3.5, + height=image_shape, + width=image_shape, + entity_prompts=entity_prompts, + 
entity_masks=masks, + use_seperated_negtive_prompt=use_seperated_negtive_prompt, +) +image.save(f"entity_control.png") +visualize_masks(image, masks, entity_prompts, f"entity_control_with_mask.png") diff --git a/examples/EntityControl/entity_control_flux.py b/examples/EntityControl/entity_control_flux.py deleted file mode 100644 index 5e74e0b..0000000 --- a/examples/EntityControl/entity_control_flux.py +++ /dev/null @@ -1,54 +0,0 @@ -import torch -from diffsynth import ModelManager, FluxImagePipeline, download_customized_models, FluxImageLoraPipeline -from examples.EntityControl.utils import visualize_masks -import os -import json -from PIL import Image - -# lora_path = download_customized_models( -# model_id="DiffSynth-Studio/ArtAug-lora-FLUX.1dev-v1", -# origin_file_path="merged_lora.safetensors", -# local_dir="models/lora" -# )[0] - -lora_path = '/root/model_bf16.safetensors' -model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cuda") -model_manager.load_models([ - "t2i_models/FLUX/FLUX.1-dev/text_encoder/model.safetensors", - "t2i_models/FLUX/FLUX.1-dev/text_encoder_2", - "t2i_models/FLUX/FLUX.1-dev/ae.safetensors", - "t2i_models/FLUX/FLUX.1-dev/flux1-dev.safetensors" -]) -model_manager.load_lora(lora_path, lora_alpha=1.) 
- -pipe = FluxImagePipeline.from_model_manager(model_manager) - -mask_dir = '/mnt/nas1/zhanghong/DiffSynth-Studio/workdirs/tmp_mask' -image_shape = 1024 -guidance = 3.5 -cfg = 3.0 -negative_prompt = "worst quality, low quality, monochrome, zombie, interlocked fingers, Aissist, cleavage, nsfw," -names = ['row_2_1'] -seeds = [0] -# use this to apply regional attention in negative prompt prediction for better results with more time -use_seperated_negtive_prompt = False -for name, seed in zip(names, seeds): - out_dir = f'workdirs/entity_control/{name}' - os.makedirs(out_dir, exist_ok=True) - cur_dir = os.path.join(mask_dir, name) - metas = json.load(open(os.path.join(mask_dir, name, 'prompts.json'))) - for seed in range(3, 10): - prompt = metas['global_prompt'] - mask_prompts = metas['mask_prompts'] - masks = [Image.open(os.path.join(mask_dir, name, f"{mask_idx}.png")).resize((image_shape, image_shape), resample=Image.NEAREST) for mask_idx in range(len(mask_prompts))] - torch.manual_seed(seed) - image = pipe( - prompt=prompt, - cfg_scale=cfg, - negative_prompt=negative_prompt, - num_inference_steps=50, embedded_guidance=guidance, height=image_shape, width=image_shape, - entity_prompts=mask_prompts, entity_masks=masks, - use_seperated_negtive_prompt=use_seperated_negtive_prompt - ) - use_sep = f'_sepneg' if use_seperated_negtive_prompt else '' - visualize_masks(image, masks, mask_prompts, os.path.join(out_dir, f"{name}_{seed}{use_sep}.png")) diff --git a/examples/EntityControl/entity_control_ipadapter.py b/examples/EntityControl/entity_control_ipadapter.py new file mode 100644 index 0000000..76ae869 --- /dev/null +++ b/examples/EntityControl/entity_control_ipadapter.py @@ -0,0 +1,51 @@ +import torch +from diffsynth import ModelManager, FluxImagePipeline, download_customized_models +from examples.EntityControl.utils import visualize_masks +from PIL import Image +import requests +from io import BytesIO + +lora_path = download_customized_models( + 
model_id="DiffSynth-Studio/Eligen", + origin_file_path="model_bf16.safetensors", + local_dir="models/lora/entity_control" +)[0] +model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cuda", model_id_list=["FLUX.1-dev", "InstantX/FLUX.1-dev-IP-Adapter"]) +model_manager.load_lora(lora_path, lora_alpha=1.) +pipe = FluxImagePipeline.from_model_manager(model_manager) + +# prepare inputs +image_shape = 1024 +seed = 4 +# set True to apply regional attention in negative prompt prediction for better results with more time +use_seperated_negtive_prompt = False +mask_urls = [ + 'https://github.com/user-attachments/assets/e6745b3f-ab2b-4612-9bb5-b7235474a9a4', + 'https://github.com/user-attachments/assets/5ddf9a89-32fa-4540-89ad-e956130942b3', + 'https://github.com/user-attachments/assets/9d8a0bb0-6817-497e-af85-44f2512afe79' +] +# prepare entity masks, entity prompts, global prompt and negative prompt +masks = [] +for url in mask_urls: + response = requests.get(url) + mask = Image.open(BytesIO(response.content)).resize((image_shape, image_shape), resample=Image.NEAREST) + masks.append(mask) +entity_prompts = ['A girl', 'hat', 'sunset'] +global_prompt = "A girl wearing a hat, looking at the sunset" +negative_prompt = "worst quality, low quality, monochrome, zombie, interlocked fingers, Aissist, cleavage, nsfw" + +response = requests.get('https://github.com/user-attachments/assets/019bbfaa-04b3-4de6-badb-32b67c29a1bc') +reference_img = Image.open(BytesIO(response.content)).convert('RGB').resize((image_shape, image_shape)) + +torch.manual_seed(seed) +image = pipe( + prompt=global_prompt, + cfg_scale=3.0, + negative_prompt=negative_prompt, + num_inference_steps=50, embedded_guidance=3.5, height=image_shape, width=image_shape, + entity_prompts=entity_prompts, entity_masks=masks, + use_seperated_negtive_prompt=use_seperated_negtive_prompt, + ipadapter_images=[reference_img], ipadapter_scale=0.7 +) +image.save(f"styled_entity_control.png") +visualize_masks(image, masks, 
entity_prompts, f"styled_entity_control_with_mask.png") diff --git a/examples/EntityControl/entity_inpaint.py b/examples/EntityControl/entity_inpaint.py new file mode 100644 index 0000000..eae4641 --- /dev/null +++ b/examples/EntityControl/entity_inpaint.py @@ -0,0 +1,58 @@ +import torch +from diffsynth import ModelManager, FluxImagePipeline, download_customized_models, FluxImageLoraPipeline +from examples.EntityControl.utils import visualize_masks +import os +import json +from PIL import Image +import requests +from io import BytesIO + +# download and load model +lora_path = download_customized_models( + model_id="DiffSynth-Studio/Eligen", + origin_file_path="model_bf16.safetensors", + local_dir="models/lora/entity_control" +)[0] +model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cuda", model_id_list=["FLUX.1-dev"]) +model_manager.load_lora(lora_path, lora_alpha=1.) +pipe = FluxImagePipeline.from_model_manager(model_manager) + +# prepare inputs +image_shape = 1024 +seed = 0 +# set True to apply regional attention in negative prompt prediction for better results with more time +use_seperated_negtive_prompt = False +mask_urls = [ + 'https://github.com/user-attachments/assets/0cf78663-5314-4280-a065-31ded7a24a46', + 'https://github.com/user-attachments/assets/bd3938b8-72a8-4d56-814f-f6445971b91d' +] +# prepare entity masks, entity prompts, global prompt and negative prompt +masks = [] +for url in mask_urls: + response = requests.get(url) + mask = Image.open(BytesIO(response.content)).resize((image_shape, image_shape), resample=Image.NEAREST) + masks.append(mask) +entity_prompts = ["A person wear red shirt", "Airplane"] +global_prompt = "A person walking on the path in front of a house; An airplane in the sky" +negative_prompt = "worst quality, low quality, monochrome, zombie, interlocked fingers, Aissist, cleavage, nsfw, blur" + +response = requests.get('https://github.com/user-attachments/assets/fa4d6ba5-08fd-4fc7-adbb-19898d839364') +inpaint_input = 
Image.open(BytesIO(response.content)).convert('RGB').resize((image_shape, image_shape)) + +# generate image +torch.manual_seed(seed) +image = pipe( + prompt=global_prompt, + cfg_scale=3.0, + negative_prompt=negative_prompt, + num_inference_steps=50, + embedded_guidance=3.5, + height=image_shape, + width=image_shape, + entity_prompts=entity_prompts, + entity_masks=masks, + inpaint_input=inpaint_input, + use_seperated_negtive_prompt=use_seperated_negtive_prompt, +) +image.save(f"entity_inpaint.png") +visualize_masks(image, masks, entity_prompts, f"entity_inpaint_with_mask.png") \ No newline at end of file diff --git a/examples/EntityControl/entity_inpaint_flux.py b/examples/EntityControl/entity_inpaint_flux.py deleted file mode 100644 index 5027563..0000000 --- a/examples/EntityControl/entity_inpaint_flux.py +++ /dev/null @@ -1,59 +0,0 @@ -import torch -from diffsynth import ModelManager, FluxImagePipeline, download_customized_models, FluxImageLoraPipeline -from examples.EntityControl.utils import visualize_masks -import os -import json -from PIL import Image - -# lora_path = download_customized_models( -# model_id="DiffSynth-Studio/ArtAug-lora-FLUX.1dev-v1", -# origin_file_path="merged_lora.safetensors", -# local_dir="models/lora" -# )[0] - -lora_path = '/root/model_bf16.safetensors' -model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cuda") -model_manager.load_models([ - "t2i_models/FLUX/FLUX.1-dev/text_encoder/model.safetensors", - "t2i_models/FLUX/FLUX.1-dev/text_encoder_2", - "t2i_models/FLUX/FLUX.1-dev/ae.safetensors", - "t2i_models/FLUX/FLUX.1-dev/flux1-dev.safetensors" -]) -model_manager.load_lora(lora_path, lora_alpha=1.) 
- -pipe = FluxImagePipeline.from_model_manager(model_manager) - -mask_dir = '/mnt/nas1/zhanghong/DiffSynth-Studio/workdirs/tmp_mask' -image_shape = 1024 -guidance = 3.5 -cfg = 3.0 -negative_prompt = "worst quality, low quality, monochrome, zombie, interlocked fingers, Aissist, cleavage, nsfw," -names = ['inpaint2'] -seeds = [0] -use_seperated_negtive_prompt = False -for name, seed in zip(names, seeds): - out_dir = f'workdirs/paper_app/inpaint/elc/{name}' - os.makedirs(out_dir, exist_ok=True) - cur_dir = os.path.join(mask_dir, name) - metas = json.load(open(os.path.join(mask_dir, name, 'prompts.json'))) - inpaint_input = Image.open(os.path.join(cur_dir, 'input.png')).convert('RGB') - prompt = metas['global_prompt'] - prompt = 'A person with a dog walking on the cloud. A rocket in the sky' - mask_prompts = metas['mask_prompts'] - masks = [Image.open(os.path.join(mask_dir, name, f"{mask_idx}.png")).resize((image_shape, image_shape), resample=Image.NEAREST) for mask_idx in range(len(mask_prompts))] - torch.manual_seed(seed) - image = pipe( - prompt=prompt, - cfg_scale=cfg, - negative_prompt=negative_prompt, - num_inference_steps=50, - embedded_guidance=guidance, - height=image_shape, - width=image_shape, - entity_prompts=mask_prompts, - entity_masks=masks, - inpaint_input=inpaint_input, - use_seperated_negtive_prompt=use_seperated_negtive_prompt, - ) - use_sep = f'_sepneg' if use_seperated_negtive_prompt else '' - visualize_masks(image, masks, mask_prompts, os.path.join(out_dir, f"{name}_{seed}{use_sep}.png")) From c087f68d74abdd106c407f5b8f3786bc34088b53 Mon Sep 17 00:00:00 2001 From: mi804 <1576993271@qq.com> Date: Tue, 31 Dec 2024 17:08:44 +0800 Subject: [PATCH 3/8] update readme --- examples/EntityControl/README.md | 53 ++++++++++++++++++++++++++++++++ 1 file changed, 53 insertions(+) diff --git a/examples/EntityControl/README.md b/examples/EntityControl/README.md index 784c590..14d716a 100644 --- a/examples/EntityControl/README.md +++ 
b/examples/EntityControl/README.md @@ -1 +1,54 @@ # EliGen: Entity-Level Controlled Image Generation + +## Introduction + +We propose EliGen, a novel approach that leverages fine-grained entity-level information to enable precise and controllable text-to-image generation. EliGen excels in tasks such as entity-level controlled image generation and image inpainting, while its applicability is not limited to these areas. Additionally, it can be seamlessly integrated with existing community models, such as the IP-Adapter. + +* Paper: Coming soon +* Github: https://github.com/modelscope/DiffSynth-Studio +* Model: [ModelScope](https://www.modelscope.cn/models/DiffSynth-Studio/Eligen) +* Training dataset: Coming soon + +## Methodology + + + +We introduce a regional attention mechanism within the DiT framework to effectively process the conditions of each entity. This mechanism enables the local prompt associated with each entity to semantically influence specific regions through regional attention. To further enhance the layout control capabilities of EliGen, we meticulously curate an entity-annotated dataset and fine-tune the model using the LoRA framework. + +1. **Regional Attention**: Regional attention is shown in the figure above, and it can be easily applied to other text-to-image models. Its core principle involves transforming the positional information of each entity into an attention mask, ensuring that the mechanism only affects the designated regions. + +2. **Dataset with Entity Annotation**: To curate a dedicated entity control dataset, we start by randomly selecting captions from DiffusionDB and generating the corresponding source image using Flux. Next, we employ Qwen2-VL 72B, recognized for its advanced grounding capabilities among MLLMs, to randomly identify entities within the image. These entities are annotated with local prompts and bounding boxes for precise localization, forming the foundation of our dataset for further training. + +3. 
**Training**: We apply LoRA and DeepSpeed to fine-tune the regional attention with the curated dataset, enabling EliGen to perform effective entity-level control. + +## Usage +1. **Entity-Level Controlled Image Generation** +See [./entity_control.py](./entity_control.py) for usage. +2. **Image Inpainting** + To apply EliGen to the image inpainting task, we propose an inpainting fusion pipeline to preserve the non-painting areas while enabling precise, entity-level modifications over inpainting regions. + See [./entity_inpaint.py](./entity_inpaint.py) for usage. +3. **Styled Entity Control** + EliGen can be seamlessly integrated with existing community models. We have provided an example of how to integrate it with the IP-Adapter. See [./entity_control_ipadapter.py](./entity_control_ipadapter.py) for usage. +4. **Play with EliGen using UI** + Download the checkpoint of EliGen from [ModelScope](https://www.modelscope.cn/models/DiffSynth-Studio/Eligen) to `models/lora/entity_control` and run the following command to try the interactive UI: + ```bash + python apps/gradio/entity_level_control.py + ``` +## Examples +### Entity-Level Controlled Image Generation + +||| +|-|-| +||| +||| +||| +### Image Inpainting +|Inpainting Input|Inpainting Output| +|-|-| +||| +||| +### Styled Entity Control +|Style Reference|Entity Control Variant 1|Entity Control Variant 2|Entity Control Variant 3| +|-|-|-|-| +||||| +||||| \ No newline at end of file From fd6e661203616d7b36bf20826f179eb15367b56f Mon Sep 17 00:00:00 2001 From: mi804 <1576993271@qq.com> Date: Tue, 31 Dec 2024 17:50:20 +0800 Subject: [PATCH 4/8] update readme --- README.md | 6 +++++- examples/EntityControl/README.md | 10 ++++++++++ 2 files changed, 15 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 3053a92..9ed8664 100644 --- a/README.md +++ b/README.md @@ -34,7 +34,11 @@ Until now, DiffSynth Studio has supported the following models: * [Stable Diffusion](https://huggingface.co/runwayml/stable-diffusion-v1-5) ## 
News - +- **December 31, 2024** We propose EliGen, a novel framework for precise entity-level controlled text-to-image generation, complemented by an inpainting fusion pipeline to extend its capabilities to image inpainting tasks. EliGen seamlessly integrates with existing community models, such as IP-Adapter, enhancing its versatility. For more details, see [./examples/EntityControl](./examples/EntityControl/README.md). + * Paper: Coming soon + * Github: https://github.com/modelscope/DiffSynth-Studio + * Model: [ModelScope](https://www.modelscope.cn/models/DiffSynth-Studio/Eligen) + * Training dataset: Coming soon - **December 19, 2024** We implement advanced VRAM management for HunyuanVideo, making it possible to generate videos at a resolution of 129x720x1280 using 24GB of VRAM, or at 129x512x384 resolution with just 6GB of VRAM. Please refer to [./examples/HunyuanVideo/](./examples/HunyuanVideo/) for more details. diff --git a/examples/EntityControl/README.md b/examples/EntityControl/README.md index 14d716a..5a71375 100644 --- a/examples/EntityControl/README.md +++ b/examples/EntityControl/README.md @@ -37,6 +37,16 @@ See [./entity_control.py](./entity_control.py) for usage. ## Examples ### Entity-Level Controlled Image Generation +1. The effect of generating images with continuously changing entity positions. +