From e3d89cec0c8511e9e1e59508650355e0ce89c20a Mon Sep 17 00:00:00 2001 From: mi804 <1576993271@qq.com> Date: Wed, 25 Dec 2024 17:19:31 +0800 Subject: [PATCH 1/8] temp commit for entity control --- diffsynth/models/flux_dit.py | 128 +++++++++++++----- diffsynth/pipelines/flux_image.py | 93 +++++++++++-- examples/EntityControl/entity_control_flux.py | 54 ++++++++ examples/EntityControl/entity_inpaint_flux.py | 59 ++++++++ examples/EntityControl/utils.py | 59 ++++++++ 5 files changed, 349 insertions(+), 44 deletions(-) create mode 100644 examples/EntityControl/entity_control_flux.py create mode 100644 examples/EntityControl/entity_inpaint_flux.py create mode 100644 examples/EntityControl/utils.py diff --git a/diffsynth/models/flux_dit.py b/diffsynth/models/flux_dit.py index faf58cd..8deea3b 100644 --- a/diffsynth/models/flux_dit.py +++ b/diffsynth/models/flux_dit.py @@ -40,7 +40,7 @@ class RoPEEmbedding(torch.nn.Module): n_axes = ids.shape[-1] emb = torch.cat([self.rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)], dim=-3) return emb.unsqueeze(1) - + class FluxJointAttention(torch.nn.Module): @@ -70,7 +70,7 @@ class FluxJointAttention(torch.nn.Module): xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1] return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk) - def forward(self, hidden_states_a, hidden_states_b, image_rotary_emb, ipadapter_kwargs_list=None): + def forward(self, hidden_states_a, hidden_states_b, image_rotary_emb, attn_mask=None, ipadapter_kwargs_list=None): batch_size = hidden_states_a.shape[0] # Part A @@ -91,7 +91,7 @@ class FluxJointAttention(torch.nn.Module): q, k = self.apply_rope(q, k, image_rotary_emb) - hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v) + hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask) hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim) 
hidden_states = hidden_states.to(q.dtype) hidden_states_b, hidden_states_a = hidden_states[:, :hidden_states_b.shape[1]], hidden_states[:, hidden_states_b.shape[1]:] @@ -103,7 +103,7 @@ class FluxJointAttention(torch.nn.Module): else: hidden_states_b = self.b_to_out(hidden_states_b) return hidden_states_a, hidden_states_b - + class FluxJointTransformerBlock(torch.nn.Module): @@ -129,12 +129,12 @@ class FluxJointTransformerBlock(torch.nn.Module): ) - def forward(self, hidden_states_a, hidden_states_b, temb, image_rotary_emb, ipadapter_kwargs_list=None): + def forward(self, hidden_states_a, hidden_states_b, temb, image_rotary_emb, attn_mask=None, ipadapter_kwargs_list=None): norm_hidden_states_a, gate_msa_a, shift_mlp_a, scale_mlp_a, gate_mlp_a = self.norm1_a(hidden_states_a, emb=temb) norm_hidden_states_b, gate_msa_b, shift_mlp_b, scale_mlp_b, gate_mlp_b = self.norm1_b(hidden_states_b, emb=temb) # Attention - attn_output_a, attn_output_b = self.attn(norm_hidden_states_a, norm_hidden_states_b, image_rotary_emb, ipadapter_kwargs_list) + attn_output_a, attn_output_b = self.attn(norm_hidden_states_a, norm_hidden_states_b, image_rotary_emb, attn_mask, ipadapter_kwargs_list) # Part A hidden_states_a = hidden_states_a + gate_msa_a * attn_output_a @@ -147,7 +147,7 @@ class FluxJointTransformerBlock(torch.nn.Module): hidden_states_b = hidden_states_b + gate_mlp_b * self.ff_b(norm_hidden_states_b) return hidden_states_a, hidden_states_b - + class FluxSingleAttention(torch.nn.Module): @@ -184,7 +184,7 @@ class FluxSingleAttention(torch.nn.Module): hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim) hidden_states = hidden_states.to(q.dtype) return hidden_states - + class AdaLayerNormSingle(torch.nn.Module): @@ -200,7 +200,7 @@ class AdaLayerNormSingle(torch.nn.Module): shift_msa, scale_msa, gate_msa = emb.chunk(3, dim=1) x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None] return x, gate_msa - + class 
FluxSingleTransformerBlock(torch.nn.Module): @@ -225,8 +225,8 @@ class FluxSingleTransformerBlock(torch.nn.Module): xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1] return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk) - - def process_attention(self, hidden_states, image_rotary_emb, ipadapter_kwargs_list=None): + + def process_attention(self, hidden_states, image_rotary_emb, attn_mask=None, ipadapter_kwargs_list=None): batch_size = hidden_states.shape[0] qkv = hidden_states.view(batch_size, -1, 3 * self.num_heads, self.head_dim).transpose(1, 2) @@ -235,7 +235,7 @@ class FluxSingleTransformerBlock(torch.nn.Module): q, k = self.apply_rope(q, k, image_rotary_emb) - hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v) + hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask) hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim) hidden_states = hidden_states.to(q.dtype) if ipadapter_kwargs_list is not None: @@ -243,21 +243,21 @@ class FluxSingleTransformerBlock(torch.nn.Module): return hidden_states - def forward(self, hidden_states_a, hidden_states_b, temb, image_rotary_emb, ipadapter_kwargs_list=None): + def forward(self, hidden_states_a, hidden_states_b, temb, image_rotary_emb, attn_mask=None, ipadapter_kwargs_list=None): residual = hidden_states_a norm_hidden_states, gate = self.norm(hidden_states_a, emb=temb) hidden_states_a = self.to_qkv_mlp(norm_hidden_states) attn_output, mlp_hidden_states = hidden_states_a[:, :, :self.dim * 3], hidden_states_a[:, :, self.dim * 3:] - attn_output = self.process_attention(attn_output, image_rotary_emb, ipadapter_kwargs_list) + attn_output = self.process_attention(attn_output, image_rotary_emb, attn_mask, ipadapter_kwargs_list) mlp_hidden_states = torch.nn.functional.gelu(mlp_hidden_states, approximate="tanh") hidden_states_a = torch.cat([attn_output, mlp_hidden_states], 
dim=2) hidden_states_a = gate.unsqueeze(1) * self.proj_out(hidden_states_a) hidden_states_a = residual + hidden_states_a - + return hidden_states_a, hidden_states_b - + class AdaLayerNormContinuous(torch.nn.Module): @@ -300,7 +300,7 @@ class FluxDiT(torch.nn.Module): def unpatchify(self, hidden_states, height, width): hidden_states = rearrange(hidden_states, "B (H W) (C P Q) -> B C (H P) (W Q)", P=2, Q=2, H=height//2, W=width//2) return hidden_states - + def prepare_image_ids(self, latents): batch_size, _, height, width = latents.shape @@ -317,7 +317,7 @@ class FluxDiT(torch.nn.Module): latent_image_ids = latent_image_ids.to(device=latents.device, dtype=latents.dtype) return latent_image_ids - + def tiled_forward( self, @@ -337,12 +337,45 @@ class FluxDiT(torch.nn.Module): ) return hidden_states + def construct_mask(self, entity_masks, prompt_seq_len, image_seq_len): + N = len(entity_masks) + batch_size = entity_masks[0].shape[0] + total_seq_len = N * prompt_seq_len + image_seq_len + patched_masks = [self.patchify(entity_masks[i]) for i in range(N)] + attention_mask = torch.ones((batch_size, total_seq_len, total_seq_len), dtype=torch.bool).to(device=entity_masks[0].device) + + image_start = N * prompt_seq_len + image_end = N * prompt_seq_len + image_seq_len + # prompt-image mask + for i in range(N): + prompt_start = i * prompt_seq_len + prompt_end = (i + 1) * prompt_seq_len + image_mask = torch.sum(patched_masks[i], dim=-1) > 0 + image_mask = image_mask.unsqueeze(1).repeat(1, prompt_seq_len, 1) + # prompt update with image + attention_mask[:, prompt_start:prompt_end, image_start:image_end] = image_mask + # image update with prompt + attention_mask[:, image_start:image_end, prompt_start:prompt_end] = image_mask.transpose(1, 2) + # prompt-prompt mask + for i in range(N): + for j in range(N): + if i != j: + prompt_start_i = i * prompt_seq_len + prompt_end_i = (i + 1) * prompt_seq_len + prompt_start_j = j * prompt_seq_len + prompt_end_j = (j + 1) * prompt_seq_len + 
attention_mask[:, prompt_start_i:prompt_end_i, prompt_start_j:prompt_end_j] = False + + attention_mask = attention_mask.float() + attention_mask[attention_mask == 0] = float('-inf') + attention_mask[attention_mask == 1] = 0 + return attention_mask def forward( self, hidden_states, timestep, prompt_emb, pooled_prompt_emb, guidance, text_ids, image_ids=None, - tiled=False, tile_size=128, tile_stride=64, + tiled=False, tile_size=128, tile_stride=64, entity_prompts=None, entity_masks=None, use_gradient_checkpointing=False, **kwargs ): @@ -353,46 +386,70 @@ class FluxDiT(torch.nn.Module): tile_size=tile_size, tile_stride=tile_stride, **kwargs ) - + if image_ids is None: image_ids = self.prepare_image_ids(hidden_states) - + conditioning = self.time_embedder(timestep, hidden_states.dtype) + self.pooled_text_embedder(pooled_prompt_emb) if self.guidance_embedder is not None: guidance = guidance * 1000 conditioning = conditioning + self.guidance_embedder(guidance, hidden_states.dtype) - prompt_emb = self.context_embedder(prompt_emb) - image_rotary_emb = self.pos_embedder(torch.cat((text_ids, image_ids), dim=1)) + repeat_dim = hidden_states.shape[1] height, width = hidden_states.shape[-2:] hidden_states = self.patchify(hidden_states) hidden_states = self.x_embedder(hidden_states) - + + max_masks = 0 + attention_mask = None + prompt_embs = [prompt_emb] + if entity_masks is not None: + # entity_masks + batch_size, max_masks = entity_masks.shape[0], entity_masks.shape[1] + entity_masks = entity_masks.repeat(1, 1, repeat_dim, 1, 1) + entity_masks = [entity_masks[:, i, None].squeeze(1) for i in range(max_masks)] + # global mask + global_mask = torch.ones_like(entity_masks[0]).to(device=hidden_states.device, dtype=hidden_states.dtype) + entity_masks = entity_masks + [global_mask] # append global to last + # attention mask + attention_mask = self.construct_mask(entity_masks, prompt_emb.shape[1], hidden_states.shape[1]) + attention_mask = 
attention_mask.to(device=hidden_states.device, dtype=hidden_states.dtype) + attention_mask = attention_mask.unsqueeze(1) + # embds: n_masks * b * seq * d + local_embs = [entity_prompts[:, i, None].squeeze(1) for i in range(max_masks)] + prompt_embs = local_embs + prompt_embs # append global to last + prompt_embs = [self.context_embedder(prompt_emb) for prompt_emb in prompt_embs] + prompt_emb = torch.cat(prompt_embs, dim=1) + + # positional embedding + text_ids = torch.cat([text_ids] * (max_masks + 1), dim=1) + image_rotary_emb = self.pos_embedder(torch.cat((text_ids, image_ids), dim=1)) + def create_custom_forward(module): def custom_forward(*inputs): return module(*inputs) return custom_forward - + for block in self.blocks: if self.training and use_gradient_checkpointing: hidden_states, prompt_emb = torch.utils.checkpoint.checkpoint( create_custom_forward(block), - hidden_states, prompt_emb, conditioning, image_rotary_emb, + hidden_states, prompt_emb, conditioning, image_rotary_emb, attention_mask, use_reentrant=False, ) else: - hidden_states, prompt_emb = block(hidden_states, prompt_emb, conditioning, image_rotary_emb) + hidden_states, prompt_emb = block(hidden_states, prompt_emb, conditioning, image_rotary_emb, attention_mask) hidden_states = torch.cat([prompt_emb, hidden_states], dim=1) for block in self.single_blocks: if self.training and use_gradient_checkpointing: hidden_states, prompt_emb = torch.utils.checkpoint.checkpoint( create_custom_forward(block), - hidden_states, prompt_emb, conditioning, image_rotary_emb, + hidden_states, prompt_emb, conditioning, image_rotary_emb, attention_mask, use_reentrant=False, ) else: - hidden_states, prompt_emb = block(hidden_states, prompt_emb, conditioning, image_rotary_emb) + hidden_states, prompt_emb = block(hidden_states, prompt_emb, conditioning, image_rotary_emb, attention_mask) hidden_states = hidden_states[:, prompt_emb.shape[1]:] hidden_states = self.final_norm_out(hidden_states, conditioning) @@ -400,7 +457,7 @@ 
class FluxDiT(torch.nn.Module): hidden_states = self.unpatchify(hidden_states, height, width) return hidden_states - + def quantize(self): def cast_to(weight, dtype=None, device=None, copy=False): @@ -440,16 +497,16 @@ class FluxDiT(torch.nn.Module): class Linear(torch.nn.Linear): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - + def forward(self,input,**kwargs): weight,bias= cast_bias_weight(self,input) return torch.nn.functional.linear(input,weight,bias) - + class RMSNorm(torch.nn.Module): def __init__(self, module): super().__init__() self.module = module - + def forward(self,hidden_states,**kwargs): weight= cast_weight(self.module,hidden_states) input_dtype = hidden_states.dtype @@ -457,7 +514,7 @@ class FluxDiT(torch.nn.Module): hidden_states = hidden_states * torch.rsqrt(variance + self.module.eps) hidden_states = hidden_states.to(input_dtype) * weight return hidden_states - + def replace_layer(model): for name, module in model.named_children(): if isinstance(module, torch.nn.Linear): @@ -483,7 +540,6 @@ class FluxDiT(torch.nn.Module): @staticmethod def state_dict_converter(): return FluxDiTStateDictConverter() - class FluxDiTStateDictConverter: @@ -587,7 +643,7 @@ class FluxDiTStateDictConverter: state_dict_.pop(name.replace(f".{component}_to_q.", f".{component}_to_k.")) state_dict_.pop(name.replace(f".{component}_to_q.", f".{component}_to_v.")) return state_dict_ - + def from_civitai(self, state_dict): rename_dict = { "time_in.in_layer.bias": "time_embedder.timestep_embedder.0.bias", diff --git a/diffsynth/pipelines/flux_image.py b/diffsynth/pipelines/flux_image.py index 42d142c..c012f09 100644 --- a/diffsynth/pipelines/flux_image.py +++ b/diffsynth/pipelines/flux_image.py @@ -139,6 +139,39 @@ class FluxImagePipeline(BasePipeline): images = [self.preprocess_image(image).to(device=self.device, dtype=self.torch_dtype) for image in images] return torch.cat(images, dim=0) + def inpaint_fusion(self, latents, inpaint_latents, pred_noise, 
fg_mask, bg_mask, progress_id, background_weight=0.): + # inpaint noise + inpaint_noise = (latents - inpaint_latents) / self.scheduler.sigmas[progress_id] + # merge noise + weight = torch.ones_like(inpaint_noise) + inpaint_noise[fg_mask] = pred_noise[fg_mask] + inpaint_noise[bg_mask] += pred_noise[bg_mask] * background_weight + weight[bg_mask] += background_weight + inpaint_noise /= weight + return inpaint_noise + + def preprocess_masks(self, masks, height, width, dim): + out_masks = [] + for mask in masks: + mask = self.preprocess_image(mask.resize((width, height), resample=Image.NEAREST)).mean(dim=1, keepdim=True) > 0 + mask = mask.repeat(1, dim, 1, 1).to(device=self.device, dtype=self.torch_dtype) + out_masks.append(mask) + return out_masks + + def prepare_entity_inputs(self, entity_prompts, entity_masks, width, height, t5_sequence_length=512, inpaint_input=None): + fg_mask, bg_mask = None, None + if inpaint_input is not None: + from copy import deepcopy + masks_ = deepcopy(entity_masks) + fg_masks = torch.cat([self.preprocess_image(mask.resize((width//8, height//8))).mean(dim=1, keepdim=True) for mask in masks_]) + fg_masks = (fg_masks > 0).float() + fg_mask = fg_masks.sum(dim=0, keepdim=True).repeat(1, 16, 1, 1) > 0 + bg_mask = ~fg_mask + entity_masks = self.preprocess_masks(entity_masks, height//8, width//8, 1) + entity_masks = torch.cat(entity_masks, dim=0).unsqueeze(0) # b, n_mask, c, h, w + entity_prompts = self.encode_prompt(entity_prompts, t5_sequence_length=t5_sequence_length)['prompt_emb'].unsqueeze(0) + return entity_prompts, entity_masks, fg_mask, bg_mask + @torch.no_grad() def __call__( self, @@ -160,6 +193,10 @@ class FluxImagePipeline(BasePipeline): width=1024, num_inference_steps=30, t5_sequence_length=512, + inpaint_input=None, + entity_prompts=None, + entity_masks=None, + use_seperated_negtive_prompt=True, tiled=False, tile_size=128, tile_stride=64, @@ -176,12 +213,13 @@ class FluxImagePipeline(BasePipeline): 
self.scheduler.set_timesteps(num_inference_steps, denoising_strength) # Prepare latent tensors - if input_image is not None: + if input_image is not None or inpaint_input is not None: + input_image = input_image or inpaint_input self.load_models_to_device(['vae_encoder']) image = self.preprocess_image(input_image).to(device=self.device, dtype=self.torch_dtype) - latents = self.encode_image(image, **tiler_kwargs) + input_latents = self.encode_image(image, **tiler_kwargs) noise = self.generate_noise((1, 16, height//8, width//8), seed=seed, device=self.device, dtype=self.torch_dtype) - latents = self.scheduler.add_noise(latents, noise, timestep=self.scheduler.timesteps[0]) + latents = self.scheduler.add_noise(input_latents, noise, timestep=self.scheduler.timesteps[0]) else: latents = self.generate_noise((1, 16, height//8, width//8), seed=seed, device=self.device, dtype=self.torch_dtype) @@ -195,6 +233,14 @@ class FluxImagePipeline(BasePipeline): prompt_emb_nega = self.encode_prompt(negative_prompt, positive=False, t5_sequence_length=t5_sequence_length) prompt_emb_locals = [self.encode_prompt(prompt_local, t5_sequence_length=t5_sequence_length) for prompt_local in local_prompts] + # Entity control + negative_entity_prompts = None + negative_masks = None + if entity_masks is not None: + entity_prompts, entity_masks, fg_mask, bg_mask = self.prepare_entity_inputs(entity_prompts, entity_masks, width, height, t5_sequence_length, inpaint_input) + if use_seperated_negtive_prompt and cfg_scale != 1.0: + negative_entity_prompts = prompt_emb_nega['prompt_emb'].unsqueeze(1).repeat(1, entity_masks.shape[1], 1, 1) + negative_masks = entity_masks # Extra input extra_input = self.prepare_extra_input(latents, guidance=embedded_guidance) @@ -229,18 +275,20 @@ class FluxImagePipeline(BasePipeline): # Classifier-free guidance inference_callback = lambda prompt_emb_posi, controlnet_kwargs: lets_dance_flux( dit=self.dit, controlnet=self.controlnet, - hidden_states=latents, 
timestep=timestep, + hidden_states=latents, timestep=timestep, entity_prompts=entity_prompts, entity_masks=entity_masks, **prompt_emb_posi, **tiler_kwargs, **extra_input, **controlnet_kwargs, **ipadapter_kwargs_list_posi, ) noise_pred_posi = self.control_noise_via_local_prompts( prompt_emb_posi, prompt_emb_locals, masks, mask_scales, inference_callback, special_kwargs=controlnet_kwargs, special_local_kwargs_list=local_controlnet_kwargs ) + if inpaint_input: + noise_pred_posi = self.inpaint_fusion(latents, input_latents, noise_pred_posi, fg_mask, bg_mask, progress_id) if cfg_scale != 1.0: negative_controlnet_kwargs = controlnet_kwargs if enable_controlnet_on_negative else {} noise_pred_nega = lets_dance_flux( dit=self.dit, controlnet=self.controlnet, - hidden_states=latents, timestep=timestep, + hidden_states=latents, timestep=timestep, entity_prompts=negative_entity_prompts, entity_masks=negative_masks, **prompt_emb_nega, **tiler_kwargs, **extra_input, **negative_controlnet_kwargs, **ipadapter_kwargs_list_nega, ) noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega) @@ -278,6 +326,8 @@ def lets_dance_flux( tiled=False, tile_size=128, tile_stride=64, + entity_prompts=None, + entity_masks=None, ipadapter_kwargs_list={}, **kwargs ): @@ -333,13 +383,38 @@ def lets_dance_flux( if dit.guidance_embedder is not None: guidance = guidance * 1000 conditioning = conditioning + dit.guidance_embedder(guidance, hidden_states.dtype) - prompt_emb = dit.context_embedder(prompt_emb) - image_rotary_emb = dit.pos_embedder(torch.cat((text_ids, image_ids), dim=1)) + repeat_dim = hidden_states.shape[1] height, width = hidden_states.shape[-2:] hidden_states = dit.patchify(hidden_states) hidden_states = dit.x_embedder(hidden_states) - + + # Entity Control + max_masks = 0 + attention_mask = None + prompt_embs = [prompt_emb] + if entity_masks is not None: + # entity_masks + batch_size, max_masks = entity_masks.shape[0], entity_masks.shape[1] + entity_masks = 
entity_masks.repeat(1, 1, repeat_dim, 1, 1) + entity_masks = [entity_masks[:, i, None].squeeze(1) for i in range(max_masks)] + # global mask + global_mask = torch.ones_like(entity_masks[0]).to(device=hidden_states.device, dtype=hidden_states.dtype) + entity_masks = entity_masks + [global_mask] # append global to last + # attention mask + attention_mask = dit.construct_mask(entity_masks, prompt_emb.shape[1], hidden_states.shape[1]) + attention_mask = attention_mask.to(device=hidden_states.device, dtype=hidden_states.dtype) + attention_mask = attention_mask.unsqueeze(1) + # embds: n_masks * b * seq * d + local_embs = [entity_prompts[:, i, None].squeeze(1) for i in range(max_masks)] + prompt_embs = local_embs + prompt_embs # append global to last + prompt_embs = [dit.context_embedder(prompt_emb) for prompt_emb in prompt_embs] + prompt_emb = torch.cat(prompt_embs, dim=1) + + # positional embedding + text_ids = torch.cat([text_ids] * (max_masks + 1), dim=1) + image_rotary_emb = dit.pos_embedder(torch.cat((text_ids, image_ids), dim=1)) + # Joint Blocks for block_id, block in enumerate(dit.blocks): hidden_states, prompt_emb = block( @@ -347,6 +422,7 @@ def lets_dance_flux( prompt_emb, conditioning, image_rotary_emb, + attention_mask, ipadapter_kwargs_list=ipadapter_kwargs_list.get(block_id, None)) # ControlNet if controlnet is not None and controlnet_frames is not None: @@ -361,6 +437,7 @@ def lets_dance_flux( prompt_emb, conditioning, image_rotary_emb, + attention_mask, ipadapter_kwargs_list=ipadapter_kwargs_list.get( block_id + num_joint_blocks, None)) # ControlNet diff --git a/examples/EntityControl/entity_control_flux.py b/examples/EntityControl/entity_control_flux.py new file mode 100644 index 0000000..5e74e0b --- /dev/null +++ b/examples/EntityControl/entity_control_flux.py @@ -0,0 +1,54 @@ +import torch +from diffsynth import ModelManager, FluxImagePipeline, download_customized_models, FluxImageLoraPipeline +from examples.EntityControl.utils import visualize_masks 
+import os +import json +from PIL import Image + +# lora_path = download_customized_models( +# model_id="DiffSynth-Studio/ArtAug-lora-FLUX.1dev-v1", +# origin_file_path="merged_lora.safetensors", +# local_dir="models/lora" +# )[0] + +lora_path = '/root/model_bf16.safetensors' +model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cuda") +model_manager.load_models([ + "t2i_models/FLUX/FLUX.1-dev/text_encoder/model.safetensors", + "t2i_models/FLUX/FLUX.1-dev/text_encoder_2", + "t2i_models/FLUX/FLUX.1-dev/ae.safetensors", + "t2i_models/FLUX/FLUX.1-dev/flux1-dev.safetensors" +]) +model_manager.load_lora(lora_path, lora_alpha=1.) + +pipe = FluxImagePipeline.from_model_manager(model_manager) + +mask_dir = '/mnt/nas1/zhanghong/DiffSynth-Studio/workdirs/tmp_mask' +image_shape = 1024 +guidance = 3.5 +cfg = 3.0 +negative_prompt = "worst quality, low quality, monochrome, zombie, interlocked fingers, Aissist, cleavage, nsfw," +names = ['row_2_1'] +seeds = [0] +# use this to apply regional attention in negative prompt prediction for better results with more time +use_seperated_negtive_prompt = False +for name, seed in zip(names, seeds): + out_dir = f'workdirs/entity_control/{name}' + os.makedirs(out_dir, exist_ok=True) + cur_dir = os.path.join(mask_dir, name) + metas = json.load(open(os.path.join(mask_dir, name, 'prompts.json'))) + for seed in range(3, 10): + prompt = metas['global_prompt'] + mask_prompts = metas['mask_prompts'] + masks = [Image.open(os.path.join(mask_dir, name, f"{mask_idx}.png")).resize((image_shape, image_shape), resample=Image.NEAREST) for mask_idx in range(len(mask_prompts))] + torch.manual_seed(seed) + image = pipe( + prompt=prompt, + cfg_scale=cfg, + negative_prompt=negative_prompt, + num_inference_steps=50, embedded_guidance=guidance, height=image_shape, width=image_shape, + entity_prompts=mask_prompts, entity_masks=masks, + use_seperated_negtive_prompt=use_seperated_negtive_prompt + ) + use_sep = f'_sepneg' if use_seperated_negtive_prompt 
else '' + visualize_masks(image, masks, mask_prompts, os.path.join(out_dir, f"{name}_{seed}{use_sep}.png")) diff --git a/examples/EntityControl/entity_inpaint_flux.py b/examples/EntityControl/entity_inpaint_flux.py new file mode 100644 index 0000000..5027563 --- /dev/null +++ b/examples/EntityControl/entity_inpaint_flux.py @@ -0,0 +1,59 @@ +import torch +from diffsynth import ModelManager, FluxImagePipeline, download_customized_models, FluxImageLoraPipeline +from examples.EntityControl.utils import visualize_masks +import os +import json +from PIL import Image + +# lora_path = download_customized_models( +# model_id="DiffSynth-Studio/ArtAug-lora-FLUX.1dev-v1", +# origin_file_path="merged_lora.safetensors", +# local_dir="models/lora" +# )[0] + +lora_path = '/root/model_bf16.safetensors' +model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cuda") +model_manager.load_models([ + "t2i_models/FLUX/FLUX.1-dev/text_encoder/model.safetensors", + "t2i_models/FLUX/FLUX.1-dev/text_encoder_2", + "t2i_models/FLUX/FLUX.1-dev/ae.safetensors", + "t2i_models/FLUX/FLUX.1-dev/flux1-dev.safetensors" +]) +model_manager.load_lora(lora_path, lora_alpha=1.) + +pipe = FluxImagePipeline.from_model_manager(model_manager) + +mask_dir = '/mnt/nas1/zhanghong/DiffSynth-Studio/workdirs/tmp_mask' +image_shape = 1024 +guidance = 3.5 +cfg = 3.0 +negative_prompt = "worst quality, low quality, monochrome, zombie, interlocked fingers, Aissist, cleavage, nsfw," +names = ['inpaint2'] +seeds = [0] +use_seperated_negtive_prompt = False +for name, seed in zip(names, seeds): + out_dir = f'workdirs/paper_app/inpaint/elc/{name}' + os.makedirs(out_dir, exist_ok=True) + cur_dir = os.path.join(mask_dir, name) + metas = json.load(open(os.path.join(mask_dir, name, 'prompts.json'))) + inpaint_input = Image.open(os.path.join(cur_dir, 'input.png')).convert('RGB') + prompt = metas['global_prompt'] + prompt = 'A person with a dog walking on the cloud. 
A rocket in the sky' + mask_prompts = metas['mask_prompts'] + masks = [Image.open(os.path.join(mask_dir, name, f"{mask_idx}.png")).resize((image_shape, image_shape), resample=Image.NEAREST) for mask_idx in range(len(mask_prompts))] + torch.manual_seed(seed) + image = pipe( + prompt=prompt, + cfg_scale=cfg, + negative_prompt=negative_prompt, + num_inference_steps=50, + embedded_guidance=guidance, + height=image_shape, + width=image_shape, + entity_prompts=mask_prompts, + entity_masks=masks, + inpaint_input=inpaint_input, + use_seperated_negtive_prompt=use_seperated_negtive_prompt, + ) + use_sep = f'_sepneg' if use_seperated_negtive_prompt else '' + visualize_masks(image, masks, mask_prompts, os.path.join(out_dir, f"{name}_{seed}{use_sep}.png")) diff --git a/examples/EntityControl/utils.py b/examples/EntityControl/utils.py new file mode 100644 index 0000000..573c8a6 --- /dev/null +++ b/examples/EntityControl/utils.py @@ -0,0 +1,59 @@ +from PIL import Image, ImageDraw, ImageFont +import random + +def visualize_masks(image, masks, mask_prompts, output_path, font_size=35, use_random_colors=False): + # Create a blank image for overlays + overlay = Image.new('RGBA', image.size, (0, 0, 0, 0)) + + colors = [ + (165, 238, 173, 80), + (76, 102, 221, 80), + (221, 160, 77, 80), + (204, 93, 71, 80), + (145, 187, 149, 80), + (134, 141, 172, 80), + (157, 137, 109, 80), + (153, 104, 95, 80), + (165, 238, 173, 80), + (76, 102, 221, 80), + (221, 160, 77, 80), + (204, 93, 71, 80), + (145, 187, 149, 80), + (134, 141, 172, 80), + (157, 137, 109, 80), + (153, 104, 95, 80), + ] + # Generate random colors for each mask + if use_random_colors: + colors = [(random.randint(0, 255), random.randint(0, 255), random.randint(0, 255), 80) for _ in range(len(masks))] + + # Font settings + try: + font = ImageFont.truetype("arial", font_size) # Adjust as needed + except IOError: + font = ImageFont.load_default(font_size) + + # Overlay each mask onto the overlay image + for mask, mask_prompt, color in 
zip(masks, mask_prompts, colors): + # Convert mask to RGBA mode + mask_rgba = mask.convert('RGBA') + mask_data = mask_rgba.getdata() + new_data = [(color if item[:3] == (255, 255, 255) else (0, 0, 0, 0)) for item in mask_data] + mask_rgba.putdata(new_data) + + # Draw the mask prompt text on the mask + draw = ImageDraw.Draw(mask_rgba) + mask_bbox = mask.getbbox() # Get the bounding box of the mask + text_position = (mask_bbox[0] + 10, mask_bbox[1] + 10) # Adjust text position based on mask position + draw.text(text_position, mask_prompt, fill=(255, 255, 255, 255), font=font) + + # Alpha composite the overlay with this mask + overlay = Image.alpha_composite(overlay, mask_rgba) + + # Composite the overlay onto the original image + result = Image.alpha_composite(image.convert('RGBA'), overlay) + + # Save or display the resulting image + result.save(output_path) + + return result From b6620f3dde192f994b094c71d61119f8f6e9addd Mon Sep 17 00:00:00 2001 From: mi804 <1576993271@qq.com> Date: Tue, 31 Dec 2024 14:04:28 +0800 Subject: [PATCH 2/8] update_example entity control --- apps/gradio/entity_level_control.py | 320 ++++++++++++++++++ examples/EntityControl/README.md | 1 + examples/EntityControl/entity_control.py | 57 ++++ examples/EntityControl/entity_control_flux.py | 54 --- .../EntityControl/entity_control_ipadapter.py | 51 +++ examples/EntityControl/entity_inpaint.py | 58 ++++ examples/EntityControl/entity_inpaint_flux.py | 59 ---- 7 files changed, 487 insertions(+), 113 deletions(-) create mode 100644 apps/gradio/entity_level_control.py create mode 100644 examples/EntityControl/README.md create mode 100644 examples/EntityControl/entity_control.py delete mode 100644 examples/EntityControl/entity_control_flux.py create mode 100644 examples/EntityControl/entity_control_ipadapter.py create mode 100644 examples/EntityControl/entity_inpaint.py delete mode 100644 examples/EntityControl/entity_inpaint_flux.py diff --git a/apps/gradio/entity_level_control.py 
b/apps/gradio/entity_level_control.py new file mode 100644 index 0000000..d914cd0 --- /dev/null +++ b/apps/gradio/entity_level_control.py @@ -0,0 +1,320 @@ +import gradio as gr +from diffsynth import ModelManager, FluxImagePipeline +import os, torch +from PIL import Image +import numpy as np +from PIL import ImageDraw, ImageFont +import random +import json + +lora_checkpoint_path = 'models/lora/entity_control/model_bf16.safetensors' +save_masks_dir = 'workdirs/tmp_mask' + +def save_mask_prompts(masks, mask_prompts, global_prompt, seed=0, random_dir='0000000'): + save_dir = os.path.join(save_masks_dir, random_dir) + os.makedirs(save_dir, exist_ok=True) + for i, mask in enumerate(masks): + save_path = os.path.join(save_dir, f'{i}.png') + mask.save(save_path) + sample = { + "global_prompt": global_prompt, + "mask_prompts": mask_prompts, + "seed": seed, + } + with open(os.path.join(save_dir, f"prompts.json"), 'w') as f: + json.dump(sample, f, indent=4) + +def visualize_masks(image, masks, mask_prompts, font_size=35, use_random_colors=False): + # Create a blank image for overlays + overlay = Image.new('RGBA', image.size, (0, 0, 0, 0)) + colors = [ + (165, 238, 173, 80), + (76, 102, 221, 80), + (221, 160, 77, 80), + (204, 93, 71, 80), + (145, 187, 149, 80), + (134, 141, 172, 80), + (157, 137, 109, 80), + (153, 104, 95, 80), + (165, 238, 173, 80), + (76, 102, 221, 80), + (221, 160, 77, 80), + (204, 93, 71, 80), + (145, 187, 149, 80), + (134, 141, 172, 80), + (157, 137, 109, 80), + (153, 104, 95, 80), + ] + # Generate random colors for each mask + if use_random_colors: + colors = [(random.randint(0, 255), random.randint(0, 255), random.randint(0, 255), 80) for _ in range(len(masks))] + # Font settings + try: + font = ImageFont.truetype("arial", font_size) # Adjust as needed + except IOError: + font = ImageFont.load_default(font_size) + # Overlay each mask onto the overlay image + for mask, mask_prompt, color in zip(masks, mask_prompts, colors): + # Convert mask to RGBA 
mode + mask_rgba = mask.convert('RGBA') + mask_data = mask_rgba.getdata() + new_data = [(color if item[:3] == (255, 255, 255) else (0, 0, 0, 0)) for item in mask_data] + mask_rgba.putdata(new_data) + # Draw the mask prompt text on the mask + draw = ImageDraw.Draw(mask_rgba) + mask_bbox = mask.getbbox() # Get the bounding box of the mask + text_position = (mask_bbox[0] + 10, mask_bbox[1] + 10) # Adjust text position based on mask position + draw.text(text_position, mask_prompt, fill=(255, 255, 255, 255), font=font) + # Alpha composite the overlay with this mask + overlay = Image.alpha_composite(overlay, mask_rgba) + # Composite the overlay onto the original image + result = Image.alpha_composite(image.convert('RGBA'), overlay) + return result + +config = { + "model_config": { + "FLUX": { + "model_folder": "models/FLUX", + "pipeline_class": FluxImagePipeline, + "default_parameters": { + "cfg_scale": 3.0, + "embedded_guidance": 3.5, + "num_inference_steps": 30, + } + }, + }, + "max_num_painter_layers": 8, + "max_num_model_cache": 1, +} + + +def load_model_list(model_type): + if model_type is None: + return [] + folder = config["model_config"][model_type]["model_folder"] + file_list = [i for i in os.listdir(folder) if i.endswith(".safetensors")] + if model_type in ["HunyuanDiT", "Kolors", "FLUX"]: + file_list += [i for i in os.listdir(folder) if os.path.isdir(os.path.join(folder, i))] + file_list = sorted(file_list) + return file_list + + +model_dict = {} + +def load_model(model_type, model_path): + global model_dict + model_key = f"{model_type}:{model_path}" + if model_key in model_dict: + return model_dict[model_key] + model_path = os.path.join(config["model_config"][model_type]["model_folder"], model_path) + model_manager = ModelManager() + if model_type == "FLUX": + model_manager.torch_dtype = torch.bfloat16 + file_list = [ + os.path.join(model_path, "text_encoder/model.safetensors"), + os.path.join(model_path, "text_encoder_2"), + ] + for file_name in 
os.listdir(model_path): + if file_name.endswith(".safetensors"): + file_list.append(os.path.join(model_path, file_name)) + model_manager.load_models(file_list) + model_manager.load_lora(lora_checkpoint_path, lora_alpha=1.) + + else: + model_manager.load_model(model_path) + pipe = config["model_config"][model_type]["pipeline_class"].from_model_manager(model_manager) + while len(model_dict) + 1 > config["max_num_model_cache"]: + key = next(iter(model_dict.keys())) + model_manager_to_release, _ = model_dict[key] + model_manager_to_release.to("cpu") + del model_dict[key] + torch.cuda.empty_cache() + model_dict[model_key] = model_manager, pipe + return model_manager, pipe + + +with gr.Blocks() as app: + gr.Markdown(""" + # 实体级控制文生图模型EliGen + **UI说明** + 1. 点击Load model读取模型,然后左侧界面为文生图输入参数;右侧Painter为局部控制区域绘制区域,每个局部控制条件由其Local prompt和绘制的mask组成,支持精准控制文生图和Inpainting两种模式。 + 2. **精准控制生图模式:** 输入Globalprompt;激活并绘制一个或多个局部控制条件,点击Generate生成图像; Global Prompt推荐包含每个Local Prompt。 + 3. **Inpainting模式:** 你可以上传图像,或者将上一步生成的图像设置为Inpaint Input Image,采用类似的方式输入局部控制条件,进行局部重绘。 + 4. 尽情创造! + """) + gr.Markdown(""" + # Entity-Level Controlled Text-to-Image Model: EliGen + **UI Instructions** + 1. Click "Load model" to load the model. The left interface is for text-to-image input parameters; the right "Painter" is the area for drawing local control regions. Each local control condition consists of its Local Prompt and the drawn mask, supporting both precise control text-to-image and Inpainting modes. + 2. **Precise Control Image Generation Mode:** Enter the Global Prompt; activate and draw one or more local control conditions, then click "Generate" to create the image. It is recommended that the Global Prompt includes all Local Prompts. + 3. **Inpainting Mode:** You can upload an image or set the image generated in the previous step as the "Inpaint Input Image." Use a similar method to input local control conditions for local redrawing. + 4. Enjoy! 
+ """) + with gr.Row(): + random_mask_dir = gr.State('') + with gr.Column(scale=382, min_width=100): + model_type = gr.State('FLUX') + model_path = gr.State('FLUX.1-dev') + with gr.Accordion(label="Model"): + load_model_button = gr.Button(value="Load model") + + with gr.Accordion(label="Global prompt"): + prompt = gr.Textbox(label="Prompt", lines=3) + negative_prompt = gr.Textbox(label="Negative prompt", value="worst quality, low quality, monochrome, zombie, interlocked fingers, Aissist, cleavage, nsfw,", lines=1) + cfg_scale = gr.Slider(minimum=1.0, maximum=10.0, value=7.0, step=0.1, interactive=True, label="Classifier-free guidance scale") + embedded_guidance = gr.Slider(minimum=0.0, maximum=10.0, value=0.0, step=0.1, interactive=True, label="Embedded guidance scale") + + with gr.Accordion(label="Inference Options"): + num_inference_steps = gr.Slider(minimum=1, maximum=100, value=20, step=1, interactive=True, label="Inference steps") + height = gr.Slider(minimum=64, maximum=2048, value=1024, step=64, interactive=True, label="Height") + width = gr.Slider(minimum=64, maximum=2048, value=1024, step=64, interactive=True, label="Width") + return_with_mask = gr.Checkbox(value=True, interactive=True, label="show result with mask painting") + with gr.Column(): + use_fixed_seed = gr.Checkbox(value=True, interactive=False, label="Use fixed seed") + seed = gr.Number(minimum=0, maximum=10**9, value=0, interactive=True, label="Random seed", show_label=False) + with gr.Accordion(label="Inpaint Input Image (Testing)"): + input_image = gr.Image(sources=None, show_label=False, interactive=True, type="pil") + background_weight = gr.Slider(minimum=0.0, maximum=1000., value=0., step=1, interactive=False, label="background_weight") + + with gr.Column(): + reset_input_button = gr.Button(value="Reset Inpaint Input") + send_input_to_painter = gr.Button(value="Set as painter's background") + @gr.on(inputs=[input_image], outputs=[input_image], triggers=reset_input_button.click) + def 
reset_input_image(input_image): + return None + @gr.on( + inputs=[model_type, model_path, prompt, negative_prompt, cfg_scale, embedded_guidance, num_inference_steps, height, width, return_with_mask], + outputs=[prompt, negative_prompt, cfg_scale, embedded_guidance, num_inference_steps, height, width, return_with_mask, load_model_button, random_mask_dir], + triggers=load_model_button.click + ) + def model_path_to_default_params(model_type, model_path, prompt, negative_prompt, cfg_scale, embedded_guidance, num_inference_steps, height, width, return_with_mask): + load_model(model_type, model_path) + cfg_scale = config["model_config"][model_type]["default_parameters"].get("cfg_scale", cfg_scale) + embedded_guidance = config["model_config"][model_type]["default_parameters"].get("embedded_guidance", embedded_guidance) + num_inference_steps = config["model_config"][model_type]["default_parameters"].get("num_inference_steps", num_inference_steps) + height = config["model_config"][model_type]["default_parameters"].get("height", height) + width = config["model_config"][model_type]["default_parameters"].get("width", width) + return_with_mask = config["model_config"][model_type]["default_parameters"].get("return_with_mask", return_with_mask) + return prompt, negative_prompt, cfg_scale, embedded_guidance, num_inference_steps, height, width, return_with_mask, gr.update(value="Loaded FLUX"), gr.State(f'{random.randint(0, 1000000):08d}') + + + with gr.Column(scale=618, min_width=100): + with gr.Accordion(label="Painter"): + enable_local_prompt_list = [] + local_prompt_list = [] + mask_scale_list = [] + canvas_list = [] + for painter_layer_id in range(config["max_num_painter_layers"]): + with gr.Tab(label=f"Layer {painter_layer_id}"): + enable_local_prompt = gr.Checkbox(label="Enable", value=False, key=f"enable_local_prompt_{painter_layer_id}") + local_prompt = gr.Textbox(label="Local prompt", key=f"local_prompt_{painter_layer_id}") + mask_scale = gr.Slider(minimum=0.0, 
maximum=5.0, value=1.0, step=0.1, interactive=True, label="Mask scale", key=f"mask_scale_{painter_layer_id}") + canvas = gr.ImageEditor(canvas_size=(512, 1), sources=None, layers=False, interactive=True, image_mode="RGBA", + brush=gr.Brush(default_size=30, default_color="#000000", colors=["#000000"]), + label="Painter", key=f"canvas_{painter_layer_id}") + @gr.on(inputs=[height, width, canvas], outputs=canvas, triggers=[height.change, width.change, canvas.clear, enable_local_prompt.change], show_progress="hidden") + def resize_canvas(height, width, canvas): + h, w = canvas["background"].shape[:2] + if h != height or width != w: + return np.ones((height, width, 3), dtype=np.uint8) * 255 + else: + return canvas + + enable_local_prompt_list.append(enable_local_prompt) + local_prompt_list.append(local_prompt) + mask_scale_list.append(mask_scale) + canvas_list.append(canvas) + with gr.Accordion(label="Results"): + run_button = gr.Button(value="Generate", variant="primary") + output_image = gr.Image(sources=None, show_label=False, interactive=False, type="pil") + with gr.Row(): + with gr.Column(): + output_to_painter_button = gr.Button(value="Set as painter's background") + with gr.Column(): + output_to_input_button = gr.Button(value="Set as input image") + real_output = gr.State(None) + mask_out = gr.State(None) + + @gr.on( + inputs=[model_type, model_path, prompt, negative_prompt, cfg_scale, embedded_guidance, num_inference_steps, height, width, return_with_mask, seed, input_image, background_weight, random_mask_dir] + enable_local_prompt_list + local_prompt_list + mask_scale_list + canvas_list, + outputs=[output_image, real_output, mask_out], + triggers=run_button.click + ) + def generate_image(model_type, model_path, prompt, negative_prompt, cfg_scale, embedded_guidance, num_inference_steps, height, width, return_with_mask, seed, input_image, background_weight, random_mask_dir, *args, progress=gr.Progress()): + print("save to:", random_mask_dir.value) + _, pipe = 
load_model(model_type, model_path) + input_params = { + "prompt": prompt, + "negative_prompt": negative_prompt, + "cfg_scale": cfg_scale, + "num_inference_steps": num_inference_steps, + "height": height, + "width": width, + "progress_bar_cmd": progress.tqdm, + } + if isinstance(pipe, FluxImagePipeline): + input_params["embedded_guidance"] = embedded_guidance + if input_image is not None: + input_params["inpaint_input"] = input_image.resize((width, height)).convert("RGB") + + enable_local_prompt_list, local_prompt_list, mask_scale_list, canvas_list = ( + args[0 * config["max_num_painter_layers"]: 1 * config["max_num_painter_layers"]], + args[1 * config["max_num_painter_layers"]: 2 * config["max_num_painter_layers"]], + args[2 * config["max_num_painter_layers"]: 3 * config["max_num_painter_layers"]], + args[3 * config["max_num_painter_layers"]: 4 * config["max_num_painter_layers"]] + ) + local_prompts, masks, mask_scales = [], [], [] + for enable_local_prompt, local_prompt, mask_scale, canvas in zip( + enable_local_prompt_list, local_prompt_list, mask_scale_list, canvas_list + ): + if enable_local_prompt: + local_prompts.append(local_prompt) + masks.append(Image.fromarray(canvas["layers"][0][:, :, -1]).convert("RGB")) + mask_scales.append(mask_scale) + entity_masks = None if len(masks) == 0 else masks + entity_prompts = None if len(local_prompts) == 0 else local_prompts + input_params.update({ + "entity_prompts": entity_prompts, + "entity_masks": entity_masks, + }) + torch.manual_seed(seed) + image = pipe(**input_params) + # visualize masks + masks = [mask.resize(image.size) for mask in masks] + image_with_mask = visualize_masks(image, masks, local_prompts) + # save_mask_prompts(masks, local_prompts, prompt, seed, random_mask_dir.value) + + real_output = gr.State(image) + mask_out = gr.State(image_with_mask) + + if return_with_mask: + return image_with_mask, real_output, mask_out + return image, real_output, mask_out + + @gr.on(inputs=[input_image] + canvas_list, 
outputs=canvas_list, triggers=send_input_to_painter.click) + def send_input_to_painter_background(input_image, *canvas_list): + if input_image is None: + return tuple(canvas_list) + for canvas in canvas_list: + h, w = canvas["background"].shape[:2] + canvas["background"] = input_image.resize((w, h)) + return tuple(canvas_list) + @gr.on(inputs=[real_output] + canvas_list, outputs=canvas_list, triggers=output_to_painter_button.click) + def send_output_to_painter_background(real_output, *canvas_list): + if real_output is None: + return tuple(canvas_list) + for canvas in canvas_list: + h, w = canvas["background"].shape[:2] + canvas["background"] = real_output.value.resize((w, h)) + return tuple(canvas_list) + @gr.on(inputs=[return_with_mask, real_output, mask_out], outputs=[output_image], triggers=[return_with_mask.change], show_progress="hidden") + def show_output(return_with_mask, real_output, mask_out): + if return_with_mask: + return mask_out.value + else: + return real_output.value + @gr.on(inputs=[real_output], outputs=[input_image], triggers=output_to_input_button.click) + def send_output_to_pipe_input(real_output): + return real_output.value + +app.launch() diff --git a/examples/EntityControl/README.md b/examples/EntityControl/README.md new file mode 100644 index 0000000..784c590 --- /dev/null +++ b/examples/EntityControl/README.md @@ -0,0 +1 @@ +# EliGen: Entity-Level Controlled Image Generation diff --git a/examples/EntityControl/entity_control.py b/examples/EntityControl/entity_control.py new file mode 100644 index 0000000..2735682 --- /dev/null +++ b/examples/EntityControl/entity_control.py @@ -0,0 +1,57 @@ +import torch +from diffsynth import ModelManager, FluxImagePipeline, download_customized_models +from examples.EntityControl.utils import visualize_masks +from PIL import Image +import requests +from io import BytesIO + +# download and load model +lora_path = download_customized_models( + model_id="DiffSynth-Studio/Eligen", + 
origin_file_path="model_bf16.safetensors", + local_dir="models/lora/entity_control" +)[0] +model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cuda", model_id_list=["FLUX.1-dev"]) +model_manager.load_lora(lora_path, lora_alpha=1.) +pipe = FluxImagePipeline.from_model_manager(model_manager) + +# prepare inputs +image_shape = 1024 +seed = 4 +# set True to apply regional attention in negative prompt prediction for better results with more time +use_seperated_negtive_prompt = False +mask_urls = [ + 'https://github.com/user-attachments/assets/02905f6e-40c2-4482-9abe-b1ce50ccabbf', + 'https://github.com/user-attachments/assets/a4cf4361-abf7-4556-ba94-74683eda4cb7', + 'https://github.com/user-attachments/assets/b6595ff4-7269-4d8f-acf0-5df40bd6c59f', + 'https://github.com/user-attachments/assets/941d39a7-3aa1-437f-8b2a-4adb15d2fb3e', + 'https://github.com/user-attachments/assets/400c4086-5398-4291-b1b5-22d8483c08d9', + 'https://github.com/user-attachments/assets/ce324c77-fa1d-4aad-a5cb-698f0d5eca70', + 'https://github.com/user-attachments/assets/4e62325f-a60c-44f7-b53b-6da0869bb9db' +] +# prepare entity masks, entity prompts, global prompt and negative prompt +masks = [] +for url in mask_urls: + response = requests.get(url) + mask = Image.open(BytesIO(response.content)).resize((image_shape, image_shape), resample=Image.NEAREST) + masks.append(mask) +entity_prompts = ["A beautiful woman", "mirror", "necklace", "glasses", "earring", "white dress", "jewelry headpiece"] +global_prompt = "A beautiful woman wearing white dress, holding a mirror, with a warm light background;" +negative_prompt = "worst quality, low quality, monochrome, zombie, interlocked fingers, Aissist, cleavage, nsfw" + +# generate image +torch.manual_seed(seed) +image = pipe( + prompt=global_prompt, + cfg_scale=3.0, + negative_prompt=negative_prompt, + num_inference_steps=50, + embedded_guidance=3.5, + height=image_shape, + width=image_shape, + entity_prompts=entity_prompts, + 
entity_masks=masks, + use_seperated_negtive_prompt=use_seperated_negtive_prompt, +) +image.save(f"entity_control.png") +visualize_masks(image, masks, entity_prompts, f"entity_control_with_mask.png") diff --git a/examples/EntityControl/entity_control_flux.py b/examples/EntityControl/entity_control_flux.py deleted file mode 100644 index 5e74e0b..0000000 --- a/examples/EntityControl/entity_control_flux.py +++ /dev/null @@ -1,54 +0,0 @@ -import torch -from diffsynth import ModelManager, FluxImagePipeline, download_customized_models, FluxImageLoraPipeline -from examples.EntityControl.utils import visualize_masks -import os -import json -from PIL import Image - -# lora_path = download_customized_models( -# model_id="DiffSynth-Studio/ArtAug-lora-FLUX.1dev-v1", -# origin_file_path="merged_lora.safetensors", -# local_dir="models/lora" -# )[0] - -lora_path = '/root/model_bf16.safetensors' -model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cuda") -model_manager.load_models([ - "t2i_models/FLUX/FLUX.1-dev/text_encoder/model.safetensors", - "t2i_models/FLUX/FLUX.1-dev/text_encoder_2", - "t2i_models/FLUX/FLUX.1-dev/ae.safetensors", - "t2i_models/FLUX/FLUX.1-dev/flux1-dev.safetensors" -]) -model_manager.load_lora(lora_path, lora_alpha=1.) 
- -pipe = FluxImagePipeline.from_model_manager(model_manager) - -mask_dir = '/mnt/nas1/zhanghong/DiffSynth-Studio/workdirs/tmp_mask' -image_shape = 1024 -guidance = 3.5 -cfg = 3.0 -negative_prompt = "worst quality, low quality, monochrome, zombie, interlocked fingers, Aissist, cleavage, nsfw," -names = ['row_2_1'] -seeds = [0] -# use this to apply regional attention in negative prompt prediction for better results with more time -use_seperated_negtive_prompt = False -for name, seed in zip(names, seeds): - out_dir = f'workdirs/entity_control/{name}' - os.makedirs(out_dir, exist_ok=True) - cur_dir = os.path.join(mask_dir, name) - metas = json.load(open(os.path.join(mask_dir, name, 'prompts.json'))) - for seed in range(3, 10): - prompt = metas['global_prompt'] - mask_prompts = metas['mask_prompts'] - masks = [Image.open(os.path.join(mask_dir, name, f"{mask_idx}.png")).resize((image_shape, image_shape), resample=Image.NEAREST) for mask_idx in range(len(mask_prompts))] - torch.manual_seed(seed) - image = pipe( - prompt=prompt, - cfg_scale=cfg, - negative_prompt=negative_prompt, - num_inference_steps=50, embedded_guidance=guidance, height=image_shape, width=image_shape, - entity_prompts=mask_prompts, entity_masks=masks, - use_seperated_negtive_prompt=use_seperated_negtive_prompt - ) - use_sep = f'_sepneg' if use_seperated_negtive_prompt else '' - visualize_masks(image, masks, mask_prompts, os.path.join(out_dir, f"{name}_{seed}{use_sep}.png")) diff --git a/examples/EntityControl/entity_control_ipadapter.py b/examples/EntityControl/entity_control_ipadapter.py new file mode 100644 index 0000000..76ae869 --- /dev/null +++ b/examples/EntityControl/entity_control_ipadapter.py @@ -0,0 +1,51 @@ +import torch +from diffsynth import ModelManager, FluxImagePipeline, download_customized_models +from examples.EntityControl.utils import visualize_masks +from PIL import Image +import requests +from io import BytesIO + +lora_path = download_customized_models( + 
model_id="DiffSynth-Studio/Eligen", + origin_file_path="model_bf16.safetensors", + local_dir="models/lora/entity_control" +)[0] +model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cuda", model_id_list=["FLUX.1-dev", "InstantX/FLUX.1-dev-IP-Adapter"]) +model_manager.load_lora(lora_path, lora_alpha=1.) +pipe = FluxImagePipeline.from_model_manager(model_manager) + +# prepare inputs +image_shape = 1024 +seed = 4 +# set True to apply regional attention in negative prompt prediction for better results with more time +use_seperated_negtive_prompt = False +mask_urls = [ + 'https://github.com/user-attachments/assets/e6745b3f-ab2b-4612-9bb5-b7235474a9a4', + 'https://github.com/user-attachments/assets/5ddf9a89-32fa-4540-89ad-e956130942b3', + 'https://github.com/user-attachments/assets/9d8a0bb0-6817-497e-af85-44f2512afe79' +] +# prepare entity masks, entity prompts, global prompt and negative prompt +masks = [] +for url in mask_urls: + response = requests.get(url) + mask = Image.open(BytesIO(response.content)).resize((image_shape, image_shape), resample=Image.NEAREST) + masks.append(mask) +entity_prompts = ['A girl', 'hat', 'sunset'] +global_prompt = "A girl wearing a hat, looking at the sunset" +negative_prompt = "worst quality, low quality, monochrome, zombie, interlocked fingers, Aissist, cleavage, nsfw" + +response = requests.get('https://github.com/user-attachments/assets/019bbfaa-04b3-4de6-badb-32b67c29a1bc') +reference_img = Image.open(BytesIO(response.content)).convert('RGB').resize((image_shape, image_shape)) + +torch.manual_seed(seed) +image = pipe( + prompt=global_prompt, + cfg_scale=3.0, + negative_prompt=negative_prompt, + num_inference_steps=50, embedded_guidance=3.5, height=image_shape, width=image_shape, + entity_prompts=entity_prompts, entity_masks=masks, + use_seperated_negtive_prompt=use_seperated_negtive_prompt, + ipadapter_images=[reference_img], ipadapter_scale=0.7 +) +image.save(f"styled_entity_control.png") +visualize_masks(image, masks, 
entity_prompts, f"styled_entity_control_with_mask.png") diff --git a/examples/EntityControl/entity_inpaint.py b/examples/EntityControl/entity_inpaint.py new file mode 100644 index 0000000..eae4641 --- /dev/null +++ b/examples/EntityControl/entity_inpaint.py @@ -0,0 +1,58 @@ +import torch +from diffsynth import ModelManager, FluxImagePipeline, download_customized_models, FluxImageLoraPipeline +from examples.EntityControl.utils import visualize_masks +import os +import json +from PIL import Image +import requests +from io import BytesIO + +# download and load model +lora_path = download_customized_models( + model_id="DiffSynth-Studio/Eligen", + origin_file_path="model_bf16.safetensors", + local_dir="models/lora/entity_control" +)[0] +model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cuda", model_id_list=["FLUX.1-dev"]) +model_manager.load_lora(lora_path, lora_alpha=1.) +pipe = FluxImagePipeline.from_model_manager(model_manager) + +# prepare inputs +image_shape = 1024 +seed = 0 +# set True to apply regional attention in negative prompt prediction for better results with more time +use_seperated_negtive_prompt = False +mask_urls = [ + 'https://github.com/user-attachments/assets/0cf78663-5314-4280-a065-31ded7a24a46', + 'https://github.com/user-attachments/assets/bd3938b8-72a8-4d56-814f-f6445971b91d' +] +# prepare entity masks, entity prompts, global prompt and negative prompt +masks = [] +for url in mask_urls: + response = requests.get(url) + mask = Image.open(BytesIO(response.content)).resize((image_shape, image_shape), resample=Image.NEAREST) + masks.append(mask) +entity_prompts = ["A person wear red shirt", "Airplane"] +global_prompt = "A person walking on the path in front of a house; An airplane in the sky" +negative_prompt = "worst quality, low quality, monochrome, zombie, interlocked fingers, Aissist, cleavage, nsfw, blur" + +response = requests.get('https://github.com/user-attachments/assets/fa4d6ba5-08fd-4fc7-adbb-19898d839364') +inpaint_input = 
Image.open(BytesIO(response.content)).convert('RGB').resize((image_shape, image_shape)) + +# generate image +torch.manual_seed(seed) +image = pipe( + prompt=global_prompt, + cfg_scale=3.0, + negative_prompt=negative_prompt, + num_inference_steps=50, + embedded_guidance=3.5, + height=image_shape, + width=image_shape, + entity_prompts=entity_prompts, + entity_masks=masks, + inpaint_input=inpaint_input, + use_seperated_negtive_prompt=use_seperated_negtive_prompt, +) +image.save(f"entity_inpaint.png") +visualize_masks(image, masks, entity_prompts, f"entity_inpaint_with_mask.png") \ No newline at end of file diff --git a/examples/EntityControl/entity_inpaint_flux.py b/examples/EntityControl/entity_inpaint_flux.py deleted file mode 100644 index 5027563..0000000 --- a/examples/EntityControl/entity_inpaint_flux.py +++ /dev/null @@ -1,59 +0,0 @@ -import torch -from diffsynth import ModelManager, FluxImagePipeline, download_customized_models, FluxImageLoraPipeline -from examples.EntityControl.utils import visualize_masks -import os -import json -from PIL import Image - -# lora_path = download_customized_models( -# model_id="DiffSynth-Studio/ArtAug-lora-FLUX.1dev-v1", -# origin_file_path="merged_lora.safetensors", -# local_dir="models/lora" -# )[0] - -lora_path = '/root/model_bf16.safetensors' -model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cuda") -model_manager.load_models([ - "t2i_models/FLUX/FLUX.1-dev/text_encoder/model.safetensors", - "t2i_models/FLUX/FLUX.1-dev/text_encoder_2", - "t2i_models/FLUX/FLUX.1-dev/ae.safetensors", - "t2i_models/FLUX/FLUX.1-dev/flux1-dev.safetensors" -]) -model_manager.load_lora(lora_path, lora_alpha=1.) 
- -pipe = FluxImagePipeline.from_model_manager(model_manager) - -mask_dir = '/mnt/nas1/zhanghong/DiffSynth-Studio/workdirs/tmp_mask' -image_shape = 1024 -guidance = 3.5 -cfg = 3.0 -negative_prompt = "worst quality, low quality, monochrome, zombie, interlocked fingers, Aissist, cleavage, nsfw," -names = ['inpaint2'] -seeds = [0] -use_seperated_negtive_prompt = False -for name, seed in zip(names, seeds): - out_dir = f'workdirs/paper_app/inpaint/elc/{name}' - os.makedirs(out_dir, exist_ok=True) - cur_dir = os.path.join(mask_dir, name) - metas = json.load(open(os.path.join(mask_dir, name, 'prompts.json'))) - inpaint_input = Image.open(os.path.join(cur_dir, 'input.png')).convert('RGB') - prompt = metas['global_prompt'] - prompt = 'A person with a dog walking on the cloud. A rocket in the sky' - mask_prompts = metas['mask_prompts'] - masks = [Image.open(os.path.join(mask_dir, name, f"{mask_idx}.png")).resize((image_shape, image_shape), resample=Image.NEAREST) for mask_idx in range(len(mask_prompts))] - torch.manual_seed(seed) - image = pipe( - prompt=prompt, - cfg_scale=cfg, - negative_prompt=negative_prompt, - num_inference_steps=50, - embedded_guidance=guidance, - height=image_shape, - width=image_shape, - entity_prompts=mask_prompts, - entity_masks=masks, - inpaint_input=inpaint_input, - use_seperated_negtive_prompt=use_seperated_negtive_prompt, - ) - use_sep = f'_sepneg' if use_seperated_negtive_prompt else '' - visualize_masks(image, masks, mask_prompts, os.path.join(out_dir, f"{name}_{seed}{use_sep}.png")) From c087f68d74abdd106c407f5b8f3786bc34088b53 Mon Sep 17 00:00:00 2001 From: mi804 <1576993271@qq.com> Date: Tue, 31 Dec 2024 17:08:44 +0800 Subject: [PATCH 3/8] update readme --- examples/EntityControl/README.md | 53 ++++++++++++++++++++++++++++++++ 1 file changed, 53 insertions(+) diff --git a/examples/EntityControl/README.md b/examples/EntityControl/README.md index 784c590..14d716a 100644 --- a/examples/EntityControl/README.md +++ 
b/examples/EntityControl/README.md @@ -1 +1,54 @@ # EliGen: Entity-Level Controlled Image Generation + +## Introduction + +We propose EliGen, a novel approach that leverages fine-grained entity-level information to enable precise and controllable text-to-image generation. EliGen excels in tasks such as entity-level controlled image generation and image inpainting, while its applicability is not limited to these areas. Additionally, it can be seamlessly integrated with existing community models, such as the IP-Adapter. + +* Paper: Coming soon +* Github: https://github.com/modelscope/DiffSynth-Studio +* Model: [ModelScope](https://www.modelscope.cn/models/DiffSynth-Studio/Eligen) +* Training dataset: Coming soon + +## Methodology + +![regional-attention](https://github.com/user-attachments/assets/9a147201-15ab-421f-a6c5-701075754478) + +We introduce a regional attention mechanism within the DiT framework to effectively process the conditions of each entity. This mechanism enables the local prompt associated with each entity to semantically influence specific regions through regional attention. To further enhance the layout control capabilities of EliGen, we meticulously curate an entity-annotated dataset and fine-tune the model using the LoRA framework. + +1. **Regional Attention**: Regional attention is shown in the figure above, which can be easily applied to other text-to-image models. Its core principle involves transforming the positional information of each entity into an attention mask, ensuring that the mechanism only affects the designated regions. + +2. **Dataset with Entity Annotation**: To curate a dedicated entity control dataset, we start by randomly selecting captions from DiffusionDB and generating the corresponding source image using Flux. Next, we employ Qwen2-VL 72B, recognized for its advanced grounding capabilities among MLLMs, to randomly identify entities within the image.
These entities are annotated with local prompts and bounding boxes for precise localization, forming the foundation of our dataset for further training. + +3. **Training**: We apply LoRA and DeepSpeed to fine-tune regional attention with the curated dataset, enabling EliGen to perform effective entity-level control. + +## Usage +1. **Entity-Level Controlled Image Generation** +See [./entity_control.py](./entity_control.py) for usage. +2. **Image Inpainting** + To apply EliGen to the image inpainting task, we propose an inpainting fusion pipeline to preserve the non-painting areas while enabling precise, entity-level modifications over the inpainting regions. + See [./entity_inpaint.py](./entity_inpaint.py) for usage. +3. **Styled Entity Control** + EliGen can be seamlessly integrated with existing community models. We have provided an example of how to integrate it with the IP-Adapter. See [./entity_control_ipadapter.py](./entity_control_ipadapter.py) for usage. +4. **Play with EliGen using UI** + Download the checkpoint of EliGen from [ModelScope](https://www.modelscope.cn/models/DiffSynth-Studio/Eligen) to `models/lora/entity_control` and run the following command to try the interactive UI: + ```bash + python apps/gradio/entity_level_control.py + ``` +## Examples +### Entity-Level Controlled Image Generation + +|![image_1_base](https://github.com/user-attachments/assets/b8564b28-19b5-424f-bf3c-6476f2923ff9)|![image_1_base](https://github.com/user-attachments/assets/20793715-42d3-46f7-8d62-0cb4cacef38d)| +|-|-| +|![image_1_base](https://github.com/user-attachments/assets/70ef12fe-d300-4b52-9d11-eabc9b5464a8)|![image_1_enhance](https://github.com/user-attachments/assets/7645ce0e-4aa7-4b1e-b7a7-bccfd9796461)| +|![image_2_base](https://github.com/user-attachments/assets/2f1e44e1-8f1f-4c6e-ab7a-1b6861a33a69)|![image_2_enhance](https://github.com/user-attachments/assets/faf78498-57ba-41bd-b516-570c86984515)|
+|![image_3_base](https://github.com/user-attachments/assets/206d1cef-2e96-4469-aed5-cdeb06ab9e99)|![image_3_enhance](https://github.com/user-attachments/assets/75d784d6-d5a1-474f-a5d5-ef8074135f35)| +### Image Inpainting +|Inpainting Input|Inpainting Output| +|-|-| +|![image_1_base](https://github.com/user-attachments/assets/5f74c710-bf30-4db1-ae40-a1e1995ccef6)|![image_1_enhance](https://github.com/user-attachments/assets/1cd71177-e956-46d3-86ce-06f774c96efd)| +|![image_2_base](https://github.com/user-attachments/assets/5ef499f3-3d8a-49cc-8ceb-86af7f5cb9f8)|![image_2_enhance](https://github.com/user-attachments/assets/fb967035-7b28-466c-a753-c00135559121)| +### Styled Entity Control +|Style Reference|Entity Control Variance 1|Entity Control Variance 2|Entity Control Variance 3| +|-|-|-|-| +|![image_1_base](https://github.com/user-attachments/assets/5e2dd3ab-37d3-4f58-8e02-ee2f9b238604)|![image_1_enhance](https://github.com/user-attachments/assets/0f6711a2-572a-41b3-938a-95deff6d732d)|![image_1_enhance](https://github.com/user-attachments/assets/ce2e66e5-1fdf-44e8-bca7-555d805a50b1)|![image_1_enhance](https://github.com/user-attachments/assets/ad2da233-2f7c-4065-ab57-b2d84dc2c0e2)| +|![image_2_base](https://github.com/user-attachments/assets/77cf7ceb-48e3-442d-8ffc-5fa4a10fe81a)|![image_2_enhance](https://github.com/user-attachments/assets/59a4f3c2-e59d-40c7-886c-0768f14fcc89)|![image_2_enhance](https://github.com/user-attachments/assets/a9187fb0-489a-49c9-a52f-56b1bd96faf7)|![image_2_enhance](https://github.com/user-attachments/assets/a62caee4-3863-4b56-96ff-e0785c6d93bb)| \ No newline at end of file From fd6e661203616d7b36bf20826f179eb15367b56f Mon Sep 17 00:00:00 2001 From: mi804 <1576993271@qq.com> Date: Tue, 31 Dec 2024 17:50:20 +0800 Subject: [PATCH 4/8] update readme --- README.md | 6 +++++- examples/EntityControl/README.md | 10 ++++++++++ 2 files changed, 15 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 3053a92..9ed8664 100644 --- 
a/README.md +++ b/README.md @@ -34,7 +34,11 @@ Until now, DiffSynth Studio has supported the following models: * [Stable Diffusion](https://huggingface.co/runwayml/stable-diffusion-v1-5) ## News - +- **December 31, 2024** We propose EliGen, a novel framework for precise entity-level controlled text-to-image generation, complemented by an inpainting fusion pipeline to extend its capabilities to image inpainting tasks. EliGen seamlessly integrates with existing community models, such as IP-Adapter, enhancing its versatility. For more details, see [./examples/EntityControl](./examples/EntityControl/README.md). + * Paper: Comming soon + * Github: https://github.com/modelscope/DiffSynth-Studio + * Model: [ModelScope](https://www.modelscope.cn/models/DiffSynth-Studio/Eligen) + * Training dataset: Coming soon - **December 19, 2024** We implement advanced VRAM management for HunyuanVideo, making it possible to generate videos at a resolution of 129x720x1280 using 24GB of VRAM, or at 129x512x384 resolution with just 6GB of VRAM. Please refer to [./examples/HunyuanVideo/](./examples/HunyuanVideo/) for more details. diff --git a/examples/EntityControl/README.md b/examples/EntityControl/README.md index 14d716a..5a71375 100644 --- a/examples/EntityControl/README.md +++ b/examples/EntityControl/README.md @@ -37,6 +37,16 @@ See [./entity_control.py](./entity_control.py) for usage. ## Examples ### Entity-Level Controlled Image Generation +1. The effect of generating images with continuously changing entity positions. +
+ +
+ +2. The image generation effect of complex Entity combinations, demonstrating the strong generalization of EliGen. + |![image_1_base](https://github.com/user-attachments/assets/b8564b28-19b5-424f-bf3c-6476f2923ff9)|![image_1_base](https://github.com/user-attachments/assets/20793715-42d3-46f7-8d62-0cb4cacef38d)| |-|-| |![image_1_base](https://github.com/user-attachments/assets/70ef12fe-d300-4b52-9d11-eabc9b5464a8)|![image_1_enhance](https://github.com/user-attachments/assets/7645ce0e-4aa7-4b1e-b7a7-bccfd9796461)| From 9853f8345402a3a95d51c80a183ab51e90097a2c Mon Sep 17 00:00:00 2001 From: mi804 <1576993271@qq.com> Date: Tue, 31 Dec 2024 18:02:49 +0800 Subject: [PATCH 5/8] update readme video --- examples/EntityControl/README.md | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/examples/EntityControl/README.md b/examples/EntityControl/README.md index 5a71375..f8c21d4 100644 --- a/examples/EntityControl/README.md +++ b/examples/EntityControl/README.md @@ -38,12 +38,10 @@ See [./entity_control.py](./entity_control.py) for usage. ### Entity-Level Controlled Image Generation 1. The effect of generating images with continuously changing entity positions. -
- -
+ 2. The image generation effect of complex Entity combinations, demonstrating the strong generalization of EliGen. From 2872fdaf489e7fbdc0b13e09fc7a0f4c9ae701ad Mon Sep 17 00:00:00 2001 From: mi804 <1576993271@qq.com> Date: Tue, 31 Dec 2024 18:09:29 +0800 Subject: [PATCH 6/8] update video of entity control --- examples/EntityControl/README.md | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/examples/EntityControl/README.md b/examples/EntityControl/README.md index f8c21d4..94c4cb5 100644 --- a/examples/EntityControl/README.md +++ b/examples/EntityControl/README.md @@ -38,10 +38,8 @@ See [./entity_control.py](./entity_control.py) for usage. ### Entity-Level Controlled Image Generation 1. The effect of generating images with continuously changing entity positions. - + +https://github.com/user-attachments/assets/4fc76df1-b26a-46e8-a950-865cdf02a38d 2. The image generation effect of complex Entity combinations, demonstrating the strong generalization of EliGen. From 6f743fc4b6c39aa56a2ebed55a46f25fac48c585 Mon Sep 17 00:00:00 2001 From: Artiprocher Date: Thu, 2 Jan 2025 19:54:09 +0800 Subject: [PATCH 7/8] refine code --- diffsynth/models/flux_dit.py | 64 +++-- diffsynth/models/model_manager.py | 26 +- diffsynth/pipelines/flux_image.py | 240 ++++++++++-------- examples/EntityControl/entity_control.py | 56 ++-- .../EntityControl/entity_control_ipadapter.py | 61 ++--- examples/EntityControl/entity_inpaint.py | 63 ++--- 6 files changed, 263 insertions(+), 247 deletions(-) diff --git a/diffsynth/models/flux_dit.py b/diffsynth/models/flux_dit.py index 8deea3b..d592e61 100644 --- a/diffsynth/models/flux_dit.py +++ b/diffsynth/models/flux_dit.py @@ -337,6 +337,7 @@ class FluxDiT(torch.nn.Module): ) return hidden_states + def construct_mask(self, entity_masks, prompt_seq_len, image_seq_len): N = len(entity_masks) batch_size = entity_masks[0].shape[0] @@ -371,11 +372,41 @@ class FluxDiT(torch.nn.Module): attention_mask[attention_mask == 1] = 0 return 
attention_mask + + def process_entity_masks(self, hidden_states, prompt_emb, entity_prompt_emb, entity_masks, text_ids, image_ids): + repeat_dim = hidden_states.shape[1] + max_masks = 0 + attention_mask = None + prompt_embs = [prompt_emb] + if entity_masks is not None: + # entity_masks + batch_size, max_masks = entity_masks.shape[0], entity_masks.shape[1] + entity_masks = entity_masks.repeat(1, 1, repeat_dim, 1, 1) + entity_masks = [entity_masks[:, i, None].squeeze(1) for i in range(max_masks)] + # global mask + global_mask = torch.ones_like(entity_masks[0]).to(device=hidden_states.device, dtype=hidden_states.dtype) + entity_masks = entity_masks + [global_mask] # append global to last + # attention mask + attention_mask = self.construct_mask(entity_masks, prompt_emb.shape[1], hidden_states.shape[1]) + attention_mask = attention_mask.to(device=hidden_states.device, dtype=hidden_states.dtype) + attention_mask = attention_mask.unsqueeze(1) + # embds: n_masks * b * seq * d + local_embs = [entity_prompt_emb[:, i, None].squeeze(1) for i in range(max_masks)] + prompt_embs = local_embs + prompt_embs # append global to last + prompt_embs = [self.context_embedder(prompt_emb) for prompt_emb in prompt_embs] + prompt_emb = torch.cat(prompt_embs, dim=1) + + # positional embedding + text_ids = torch.cat([text_ids] * (max_masks + 1), dim=1) + image_rotary_emb = self.pos_embedder(torch.cat((text_ids, image_ids), dim=1)) + return prompt_emb, image_rotary_emb, attention_mask + + def forward( self, hidden_states, timestep, prompt_emb, pooled_prompt_emb, guidance, text_ids, image_ids=None, - tiled=False, tile_size=128, tile_stride=64, entity_prompts=None, entity_masks=None, + tiled=False, tile_size=128, tile_stride=64, entity_prompt_emb=None, entity_masks=None, use_gradient_checkpointing=False, **kwargs ): @@ -395,35 +426,16 @@ class FluxDiT(torch.nn.Module): guidance = guidance * 1000 conditioning = conditioning + self.guidance_embedder(guidance, hidden_states.dtype) - repeat_dim = 
hidden_states.shape[1] height, width = hidden_states.shape[-2:] hidden_states = self.patchify(hidden_states) hidden_states = self.x_embedder(hidden_states) - max_masks = 0 - attention_mask = None - prompt_embs = [prompt_emb] - if entity_masks is not None: - # entity_masks - batch_size, max_masks = entity_masks.shape[0], entity_masks.shape[1] - entity_masks = entity_masks.repeat(1, 1, repeat_dim, 1, 1) - entity_masks = [entity_masks[:, i, None].squeeze(1) for i in range(max_masks)] - # global mask - global_mask = torch.ones_like(entity_masks[0]).to(device=hidden_states.device, dtype=hidden_states.dtype) - entity_masks = entity_masks + [global_mask] # append global to last - # attention mask - attention_mask = self.construct_mask(entity_masks, prompt_emb.shape[1], hidden_states.shape[1]) - attention_mask = attention_mask.to(device=hidden_states.device, dtype=hidden_states.dtype) - attention_mask = attention_mask.unsqueeze(1) - # embds: n_masks * b * seq * d - local_embs = [entity_prompts[:, i, None].squeeze(1) for i in range(max_masks)] - prompt_embs = local_embs + prompt_embs # append global to last - prompt_embs = [self.context_embedder(prompt_emb) for prompt_emb in prompt_embs] - prompt_emb = torch.cat(prompt_embs, dim=1) - - # positional embedding - text_ids = torch.cat([text_ids] * (max_masks + 1), dim=1) - image_rotary_emb = self.pos_embedder(torch.cat((text_ids, image_ids), dim=1)) + if entity_prompt_emb is not None and entity_masks is not None: + prompt_emb, image_rotary_emb, attention_mask = self.process_entity_masks(hidden_states, prompt_emb, entity_prompt_emb, entity_masks, text_ids, image_ids) + else: + prompt_emb = self.context_embedder(prompt_emb) + image_rotary_emb = self.pos_embedder(torch.cat((text_ids, image_ids), dim=1)) + attention_mask = None def create_custom_forward(module): def custom_forward(*inputs): diff --git a/diffsynth/models/model_manager.py b/diffsynth/models/model_manager.py index dcee6d3..96ee86a 100644 --- 
a/diffsynth/models/model_manager.py +++ b/diffsynth/models/model_manager.py @@ -366,17 +366,21 @@ class ModelManager: def load_lora(self, file_path="", state_dict={}, lora_alpha=1.0): - print(f"Loading LoRA models from file: {file_path}") - if len(state_dict) == 0: - state_dict = load_state_dict(file_path) - for model_name, model, model_path in zip(self.model_name, self.model, self.model_path): - for lora in get_lora_loaders(): - match_results = lora.match(model, state_dict) - if match_results is not None: - print(f" Adding LoRA to {model_name} ({model_path}).") - lora_prefix, model_resource = match_results - lora.load(model, state_dict, lora_prefix, alpha=lora_alpha, model_resource=model_resource) - break + if isinstance(file_path, list): + for file_path_ in file_path: + self.load_lora(file_path_, state_dict=state_dict, lora_alpha=lora_alpha) + else: + print(f"Loading LoRA models from file: {file_path}") + if len(state_dict) == 0: + state_dict = load_state_dict(file_path) + for model_name, model, model_path in zip(self.model_name, self.model, self.model_path): + for lora in get_lora_loaders(): + match_results = lora.match(model, state_dict) + if match_results is not None: + print(f" Adding LoRA to {model_name} ({model_path}).") + lora_prefix, model_resource = match_results + lora.load(model, state_dict, lora_prefix, alpha=lora_alpha, model_resource=model_resource) + break def load_model(self, file_path, model_names=None, device=None, torch_dtype=None): diff --git a/diffsynth/pipelines/flux_image.py b/diffsynth/pipelines/flux_image.py index c012f09..b6fac68 100644 --- a/diffsynth/pipelines/flux_image.py +++ b/diffsynth/pipelines/flux_image.py @@ -10,6 +10,7 @@ import numpy as np from PIL import Image from ..models.tiler import FastTileWorker from transformers import SiglipVisionModel +from copy import deepcopy class FluxImagePipeline(BasePipeline): @@ -59,6 +60,7 @@ class FluxImagePipeline(BasePipeline): self.ipadapter = model_manager.fetch_model("flux_ipadapter") 
self.ipadapter_image_encoder = model_manager.fetch_model("siglip_vision_model") + @staticmethod def from_model_manager(model_manager: ModelManager, controlnet_config_units: List[ControlNetConfigUnit]=[], prompt_refiner_classes=[], prompt_extender_classes=[], device=None): pipe = FluxImagePipeline( @@ -133,12 +135,14 @@ class FluxImagePipeline(BasePipeline): # store it controlnet_frames.append(image) return controlnet_frames - + + def prepare_ipadapter_inputs(self, images, height=384, width=384): images = [image.convert("RGB").resize((width, height), resample=3) for image in images] images = [self.preprocess_image(image).to(device=self.device, dtype=self.torch_dtype) for image in images] return torch.cat(images, dim=0) + def inpaint_fusion(self, latents, inpaint_latents, pred_noise, fg_mask, bg_mask, progress_id, background_weight=0.): # inpaint noise inpaint_noise = (latents - inpaint_latents) / self.scheduler.sigmas[progress_id] @@ -150,6 +154,7 @@ class FluxImagePipeline(BasePipeline): inpaint_noise /= weight return inpaint_noise + def preprocess_masks(self, masks, height, width, dim): out_masks = [] for mask in masks: @@ -158,10 +163,10 @@ class FluxImagePipeline(BasePipeline): out_masks.append(mask) return out_masks - def prepare_entity_inputs(self, entity_prompts, entity_masks, width, height, t5_sequence_length=512, inpaint_input=None): + + def prepare_entity_inputs(self, entity_prompts, entity_masks, width, height, t5_sequence_length=512, enable_eligen_inpaint=False): fg_mask, bg_mask = None, None - if inpaint_input is not None: - from copy import deepcopy + if enable_eligen_inpaint: masks_ = deepcopy(entity_masks) fg_masks = torch.cat([self.preprocess_image(mask.resize((width//8, height//8))).mean(dim=1, keepdim=True) for mask in masks_]) fg_masks = (fg_masks > 0).float() @@ -172,35 +177,114 @@ class FluxImagePipeline(BasePipeline): entity_prompts = self.encode_prompt(entity_prompts, t5_sequence_length=t5_sequence_length)['prompt_emb'].unsqueeze(0) return 
entity_prompts, entity_masks, fg_mask, bg_mask + + def prepare_latents(self, input_image, height, width, seed, tiled, tile_size, tile_stride): + if input_image is not None: + self.load_models_to_device(['vae_encoder']) + image = self.preprocess_image(input_image).to(device=self.device, dtype=self.torch_dtype) + input_latents = self.encode_image(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride) + noise = self.generate_noise((1, 16, height//8, width//8), seed=seed, device=self.device, dtype=self.torch_dtype) + latents = self.scheduler.add_noise(input_latents, noise, timestep=self.scheduler.timesteps[0]) + else: + latents = self.generate_noise((1, 16, height//8, width//8), seed=seed, device=self.device, dtype=self.torch_dtype) + input_latents = None + return latents, input_latents + + + def prepare_ipadapter(self, ipadapter_images, ipadapter_scale): + if ipadapter_images is not None: + self.load_models_to_device(['ipadapter_image_encoder']) + ipadapter_images = self.prepare_ipadapter_inputs(ipadapter_images) + ipadapter_image_encoding = self.ipadapter_image_encoder(ipadapter_images).pooler_output + self.load_models_to_device(['ipadapter']) + ipadapter_kwargs_list_posi = {"ipadapter_kwargs_list": self.ipadapter(ipadapter_image_encoding, scale=ipadapter_scale)} + ipadapter_kwargs_list_nega = {"ipadapter_kwargs_list": self.ipadapter(torch.zeros_like(ipadapter_image_encoding))} + else: + ipadapter_kwargs_list_posi, ipadapter_kwargs_list_nega = {"ipadapter_kwargs_list": {}}, {"ipadapter_kwargs_list": {}} + return ipadapter_kwargs_list_posi, ipadapter_kwargs_list_nega + + + def prepare_controlnet(self, controlnet_image, masks, controlnet_inpaint_mask, tiler_kwargs, enable_controlnet_on_negative): + if controlnet_image is not None: + self.load_models_to_device(['vae_encoder']) + controlnet_kwargs_posi = {"controlnet_frames": self.prepare_controlnet_input(controlnet_image, controlnet_inpaint_mask, tiler_kwargs)} + if len(masks) > 0 and controlnet_inpaint_mask 
is not None: + print("The controlnet_inpaint_mask will be overridden by masks.") + local_controlnet_kwargs = [{"controlnet_frames": self.prepare_controlnet_input(controlnet_image, mask, tiler_kwargs)} for mask in masks] + else: + local_controlnet_kwargs = None + else: + controlnet_kwargs_posi, local_controlnet_kwargs = {"controlnet_frames": None}, [{}] * len(masks) + controlnet_kwargs_nega = controlnet_kwargs_posi if enable_controlnet_on_negative else {} + return controlnet_kwargs_posi, controlnet_kwargs_nega, local_controlnet_kwargs + + + def prepare_eligen(self, prompt_emb_nega, eligen_entity_prompts, eligen_entity_masks, width, height, t5_sequence_length, enable_eligen_inpaint, enable_eligen_on_negative, cfg_scale): + if eligen_entity_masks is not None: + entity_prompt_emb_posi, entity_masks_posi, fg_mask, bg_mask = self.prepare_entity_inputs(eligen_entity_prompts, eligen_entity_masks, width, height, t5_sequence_length, enable_eligen_inpaint) + if enable_eligen_on_negative and cfg_scale != 1.0: + entity_prompt_emb_nega = prompt_emb_nega['prompt_emb'].unsqueeze(1).repeat(1, eligen_entity_masks.shape[1], 1, 1) + entity_masks_nega = eligen_entity_masks + else: + entity_prompt_emb_nega, entity_masks_nega = None, None + else: + entity_prompt_emb_posi, entity_masks_posi, entity_prompt_emb_nega, entity_masks_nega = None, None, None, None + fg_mask, bg_mask = None, None + eligen_kwargs_posi = {"entity_prompt_emb": entity_prompt_emb_posi, "entity_masks": entity_masks_posi} + eligen_kwargs_nega = {"entity_prompt_emb": entity_prompt_emb_nega, "entity_masks": entity_masks_nega} + return eligen_kwargs_posi, eligen_kwargs_nega, fg_mask, bg_mask + + + def prepare_prompts(self, prompt, local_prompts, masks, mask_scales, t5_sequence_length, negative_prompt, cfg_scale): + # Extend prompt + self.load_models_to_device(['text_encoder_1', 'text_encoder_2']) + prompt, local_prompts, masks, mask_scales = self.extend_prompt(prompt, local_prompts, masks, mask_scales) + + # Encode prompts 
+ prompt_emb_posi = self.encode_prompt(prompt, t5_sequence_length=t5_sequence_length) + prompt_emb_nega = self.encode_prompt(negative_prompt, positive=False, t5_sequence_length=t5_sequence_length) if cfg_scale != 1.0 else None + prompt_emb_locals = [self.encode_prompt(prompt_local, t5_sequence_length=t5_sequence_length) for prompt_local in local_prompts] + return prompt_emb_posi, prompt_emb_nega, prompt_emb_locals + + @torch.no_grad() def __call__( self, + # Prompt prompt, - local_prompts=None, - masks=None, - mask_scales=None, negative_prompt="", cfg_scale=1.0, embedded_guidance=3.5, + t5_sequence_length=512, + # Image input_image=None, - ipadapter_images=None, - ipadapter_scale=1.0, - controlnet_image=None, - controlnet_inpaint_mask=None, - enable_controlnet_on_negative=False, denoising_strength=1.0, height=1024, width=1024, + seed=None, + # Steps num_inference_steps=30, - t5_sequence_length=512, - inpaint_input=None, - entity_prompts=None, - entity_masks=None, - use_seperated_negtive_prompt=True, + # local prompts + local_prompts=(), + masks=(), + mask_scales=(), + # ControlNet + controlnet_image=None, + controlnet_inpaint_mask=None, + enable_controlnet_on_negative=False, + # IP-Adapter + ipadapter_images=None, + ipadapter_scale=1.0, + # EliGen + eligen_entity_prompts=None, + eligen_entity_masks=None, + enable_eligen_on_negative=False, + enable_eligen_inpaint=False, + # Tile tiled=False, tile_size=128, tile_stride=64, - seed=None, + # Progress bar progress_bar_cmd=tqdm, progress_bar_st=None, ): @@ -213,83 +297,50 @@ class FluxImagePipeline(BasePipeline): self.scheduler.set_timesteps(num_inference_steps, denoising_strength) # Prepare latent tensors - if input_image is not None or inpaint_input is not None: - input_image = input_image or inpaint_input - self.load_models_to_device(['vae_encoder']) - image = self.preprocess_image(input_image).to(device=self.device, dtype=self.torch_dtype) - input_latents = self.encode_image(image, **tiler_kwargs) - noise = 
self.generate_noise((1, 16, height//8, width//8), seed=seed, device=self.device, dtype=self.torch_dtype) - latents = self.scheduler.add_noise(input_latents, noise, timestep=self.scheduler.timesteps[0]) - else: - latents = self.generate_noise((1, 16, height//8, width//8), seed=seed, device=self.device, dtype=self.torch_dtype) + latents, input_latents = self.prepare_latents(input_image, height, width, seed, tiled, tile_size, tile_stride) - # Extend prompt - self.load_models_to_device(['text_encoder_1', 'text_encoder_2']) - prompt, local_prompts, masks, mask_scales = self.extend_prompt(prompt, local_prompts, masks, mask_scales) + # Prompt + prompt_emb_posi, prompt_emb_nega, prompt_emb_locals = self.prepare_prompts(prompt, local_prompts, masks, mask_scales, t5_sequence_length, negative_prompt, cfg_scale) - # Encode prompts - prompt_emb_posi = self.encode_prompt(prompt, t5_sequence_length=t5_sequence_length) - if cfg_scale != 1.0: - prompt_emb_nega = self.encode_prompt(negative_prompt, positive=False, t5_sequence_length=t5_sequence_length) - prompt_emb_locals = [self.encode_prompt(prompt_local, t5_sequence_length=t5_sequence_length) for prompt_local in local_prompts] - - # Entity control - negative_entity_prompts = None - negative_masks = None - if entity_masks is not None: - entity_prompts, entity_masks, fg_mask, bg_mask = self.prepare_entity_inputs(entity_prompts, entity_masks, width, height, t5_sequence_length, inpaint_input) - if use_seperated_negtive_prompt and cfg_scale != 1.0: - negative_entity_prompts = prompt_emb_nega['prompt_emb'].unsqueeze(1).repeat(1, entity_masks.shape[1], 1, 1) - negative_masks = entity_masks # Extra input extra_input = self.prepare_extra_input(latents, guidance=embedded_guidance) - # IP-Adapter - if ipadapter_images is not None: - self.load_models_to_device(['ipadapter_image_encoder']) - ipadapter_images = self.prepare_ipadapter_inputs(ipadapter_images) - ipadapter_image_encoding = 
self.ipadapter_image_encoder(ipadapter_images).pooler_output - self.load_models_to_device(['ipadapter']) - ipadapter_kwargs_list_posi = {"ipadapter_kwargs_list": self.ipadapter(ipadapter_image_encoding, scale=ipadapter_scale)} - ipadapter_kwargs_list_nega = {"ipadapter_kwargs_list": self.ipadapter(torch.zeros_like(ipadapter_image_encoding))} - else: - ipadapter_kwargs_list_posi, ipadapter_kwargs_list_nega = {"ipadapter_kwargs_list": {}}, {"ipadapter_kwargs_list": {}} + # Entity control + eligen_kwargs_posi, eligen_kwargs_nega, fg_mask, bg_mask = self.prepare_eligen(prompt_emb_nega, eligen_entity_prompts, eligen_entity_masks, width, height, t5_sequence_length, enable_eligen_inpaint, enable_eligen_on_negative, cfg_scale) - # Prepare ControlNets - if controlnet_image is not None: - self.load_models_to_device(['vae_encoder']) - controlnet_kwargs = {"controlnet_frames": self.prepare_controlnet_input(controlnet_image, controlnet_inpaint_mask, tiler_kwargs)} - if len(masks) > 0 and controlnet_inpaint_mask is not None: - print("The controlnet_inpaint_mask will be overridden by masks.") - local_controlnet_kwargs = [{"controlnet_frames": self.prepare_controlnet_input(controlnet_image, mask, tiler_kwargs)} for mask in masks] - else: - local_controlnet_kwargs = None - else: - controlnet_kwargs, local_controlnet_kwargs = {"controlnet_frames": None}, [{}] * len(masks) + # IP-Adapter + ipadapter_kwargs_list_posi, ipadapter_kwargs_list_nega = self.prepare_ipadapter(ipadapter_images, ipadapter_scale) + + # ControlNets + controlnet_kwargs_posi, controlnet_kwargs_nega, local_controlnet_kwargs = self.prepare_controlnet(controlnet_image, masks, controlnet_inpaint_mask, tiler_kwargs, enable_controlnet_on_negative) # Denoise self.load_models_to_device(['dit', 'controlnet']) for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)): timestep = timestep.unsqueeze(0).to(self.device) - # Classifier-free guidance + # Positive side inference_callback = lambda 
prompt_emb_posi, controlnet_kwargs: lets_dance_flux( dit=self.dit, controlnet=self.controlnet, - hidden_states=latents, timestep=timestep, entity_prompts=entity_prompts, entity_masks=entity_masks, - **prompt_emb_posi, **tiler_kwargs, **extra_input, **controlnet_kwargs, **ipadapter_kwargs_list_posi, + hidden_states=latents, timestep=timestep, + **prompt_emb_posi, **tiler_kwargs, **extra_input, **controlnet_kwargs, **ipadapter_kwargs_list_posi, **eligen_kwargs_posi, ) noise_pred_posi = self.control_noise_via_local_prompts( prompt_emb_posi, prompt_emb_locals, masks, mask_scales, inference_callback, - special_kwargs=controlnet_kwargs, special_local_kwargs_list=local_controlnet_kwargs + special_kwargs=controlnet_kwargs_posi, special_local_kwargs_list=local_controlnet_kwargs ) - if inpaint_input: + + # Inpaint + if enable_eligen_inpaint: noise_pred_posi = self.inpaint_fusion(latents, input_latents, noise_pred_posi, fg_mask, bg_mask, progress_id) + + # Classifier-free guidance if cfg_scale != 1.0: - negative_controlnet_kwargs = controlnet_kwargs if enable_controlnet_on_negative else {} + # Negative side noise_pred_nega = lets_dance_flux( dit=self.dit, controlnet=self.controlnet, - hidden_states=latents, timestep=timestep, entity_prompts=negative_entity_prompts, entity_masks=negative_masks, - **prompt_emb_nega, **tiler_kwargs, **extra_input, **negative_controlnet_kwargs, **ipadapter_kwargs_list_nega, + hidden_states=latents, timestep=timestep, + **prompt_emb_nega, **tiler_kwargs, **extra_input, **controlnet_kwargs_nega, **ipadapter_kwargs_list_nega, **eligen_kwargs_nega, ) noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega) else: @@ -304,7 +355,7 @@ class FluxImagePipeline(BasePipeline): # Decode image self.load_models_to_device(['vae_decoder']) - image = self.decode_image(latents, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride) + image = self.decode_image(latents, **tiler_kwargs) # Offload all models self.load_models_to_device([]) @@ 
-326,7 +377,7 @@ def lets_dance_flux( tiled=False, tile_size=128, tile_stride=64, - entity_prompts=None, + entity_prompt_emb=None, entity_masks=None, ipadapter_kwargs_list={}, **kwargs @@ -384,36 +435,16 @@ def lets_dance_flux( guidance = guidance * 1000 conditioning = conditioning + dit.guidance_embedder(guidance, hidden_states.dtype) - repeat_dim = hidden_states.shape[1] height, width = hidden_states.shape[-2:] hidden_states = dit.patchify(hidden_states) hidden_states = dit.x_embedder(hidden_states) - # Entity Control - max_masks = 0 - attention_mask = None - prompt_embs = [prompt_emb] - if entity_masks is not None: - # entity_masks - batch_size, max_masks = entity_masks.shape[0], entity_masks.shape[1] - entity_masks = entity_masks.repeat(1, 1, repeat_dim, 1, 1) - entity_masks = [entity_masks[:, i, None].squeeze(1) for i in range(max_masks)] - # global mask - global_mask = torch.ones_like(entity_masks[0]).to(device=hidden_states.device, dtype=hidden_states.dtype) - entity_masks = entity_masks + [global_mask] # append global to last - # attention mask - attention_mask = dit.construct_mask(entity_masks, prompt_emb.shape[1], hidden_states.shape[1]) - attention_mask = attention_mask.to(device=hidden_states.device, dtype=hidden_states.dtype) - attention_mask = attention_mask.unsqueeze(1) - # embds: n_masks * b * seq * d - local_embs = [entity_prompts[:, i, None].squeeze(1) for i in range(max_masks)] - prompt_embs = local_embs + prompt_embs # append global to last - prompt_embs = [dit.context_embedder(prompt_emb) for prompt_emb in prompt_embs] - prompt_emb = torch.cat(prompt_embs, dim=1) - - # positional embedding - text_ids = torch.cat([text_ids] * (max_masks + 1), dim=1) - image_rotary_emb = dit.pos_embedder(torch.cat((text_ids, image_ids), dim=1)) + if entity_prompt_emb is not None and entity_masks is not None: + prompt_emb, image_rotary_emb, attention_mask = dit.process_entity_masks(hidden_states, prompt_emb, entity_prompt_emb, entity_masks, text_ids, image_ids) + 
else: + prompt_emb = dit.context_embedder(prompt_emb) + image_rotary_emb = dit.pos_embedder(torch.cat((text_ids, image_ids), dim=1)) + attention_mask = None # Joint Blocks for block_id, block in enumerate(dit.blocks): @@ -423,7 +454,8 @@ def lets_dance_flux( conditioning, image_rotary_emb, attention_mask, - ipadapter_kwargs_list=ipadapter_kwargs_list.get(block_id, None)) + ipadapter_kwargs_list=ipadapter_kwargs_list.get(block_id, None) + ) # ControlNet if controlnet is not None and controlnet_frames is not None: hidden_states = hidden_states + controlnet_res_stack[block_id] @@ -438,8 +470,8 @@ def lets_dance_flux( conditioning, image_rotary_emb, attention_mask, - ipadapter_kwargs_list=ipadapter_kwargs_list.get( - block_id + num_joint_blocks, None)) + ipadapter_kwargs_list=ipadapter_kwargs_list.get(block_id + num_joint_blocks, None) + ) # ControlNet if controlnet is not None and controlnet_frames is not None: hidden_states[:, prompt_emb.shape[1]:] = hidden_states[:, prompt_emb.shape[1]:] + controlnet_single_res_stack[block_id] diff --git a/examples/EntityControl/entity_control.py b/examples/EntityControl/entity_control.py index 2735682..f505b94 100644 --- a/examples/EntityControl/entity_control.py +++ b/examples/EntityControl/entity_control.py @@ -1,57 +1,43 @@ -import torch from diffsynth import ModelManager, FluxImagePipeline, download_customized_models +from modelscope import dataset_snapshot_download from examples.EntityControl.utils import visualize_masks from PIL import Image -import requests -from io import BytesIO +import torch + # download and load model -lora_path = download_customized_models( - model_id="DiffSynth-Studio/Eligen", - origin_file_path="model_bf16.safetensors", - local_dir="models/lora/entity_control" -)[0] model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cuda", model_id_list=["FLUX.1-dev"]) -model_manager.load_lora(lora_path, lora_alpha=1.) 
+model_manager.load_lora( + download_customized_models( + model_id="DiffSynth-Studio/Eligen", + origin_file_path="model_bf16.safetensors", + local_dir="models/lora/entity_control" + ), + lora_alpha=1 +) pipe = FluxImagePipeline.from_model_manager(model_manager) -# prepare inputs -image_shape = 1024 -seed = 4 -# set True to apply regional attention in negative prompt prediction for better results with more time -use_seperated_negtive_prompt = False -mask_urls = [ - 'https://github.com/user-attachments/assets/02905f6e-40c2-4482-9abe-b1ce50ccabbf', - 'https://github.com/user-attachments/assets/a4cf4361-abf7-4556-ba94-74683eda4cb7', - 'https://github.com/user-attachments/assets/b6595ff4-7269-4d8f-acf0-5df40bd6c59f', - 'https://github.com/user-attachments/assets/941d39a7-3aa1-437f-8b2a-4adb15d2fb3e', - 'https://github.com/user-attachments/assets/400c4086-5398-4291-b1b5-22d8483c08d9', - 'https://github.com/user-attachments/assets/ce324c77-fa1d-4aad-a5cb-698f0d5eca70', - 'https://github.com/user-attachments/assets/4e62325f-a60c-44f7-b53b-6da0869bb9db' -] -# prepare entity masks, entity prompts, global prompt and negative prompt -masks = [] -for url in mask_urls: - response = requests.get(url) - mask = Image.open(BytesIO(response.content)).resize((image_shape, image_shape), resample=Image.NEAREST) - masks.append(mask) +# download and load mask images +dataset_snapshot_download(dataset_id="DiffSynth-Studio/examples_in_diffsynth", local_dir="./", allow_file_pattern="data/examples/eligen/mask*") +masks = [Image.open(f"./data/examples/eligen/mask{i}.png") for i in range(1, 8)] + entity_prompts = ["A beautiful woman", "mirror", "necklace", "glasses", "earring", "white dress", "jewelry headpiece"] global_prompt = "A beautiful woman wearing white dress, holding a mirror, with a warm light background;" negative_prompt = "worst quality, low quality, monochrome, zombie, interlocked fingers, Aissist, cleavage, nsfw" # generate image -torch.manual_seed(seed) image = pipe( 
prompt=global_prompt, cfg_scale=3.0, negative_prompt=negative_prompt, num_inference_steps=50, embedded_guidance=3.5, - height=image_shape, - width=image_shape, - entity_prompts=entity_prompts, - entity_masks=masks, - use_seperated_negtive_prompt=use_seperated_negtive_prompt, + seed=4, + height=1024, + width=1024, + eligen_entity_prompts=entity_prompts, + eligen_entity_masks=masks, + enable_eligen_on_negative=False, ) image.save(f"entity_control.png") visualize_masks(image, masks, entity_prompts, f"entity_control_with_mask.png") diff --git a/examples/EntityControl/entity_control_ipadapter.py b/examples/EntityControl/entity_control_ipadapter.py index 76ae869..c604bad 100644 --- a/examples/EntityControl/entity_control_ipadapter.py +++ b/examples/EntityControl/entity_control_ipadapter.py @@ -1,51 +1,46 @@ -import torch from diffsynth import ModelManager, FluxImagePipeline, download_customized_models +from modelscope import dataset_snapshot_download from examples.EntityControl.utils import visualize_masks from PIL import Image -import requests -from io import BytesIO +import torch -lora_path = download_customized_models( - model_id="DiffSynth-Studio/Eligen", - origin_file_path="model_bf16.safetensors", - local_dir="models/lora/entity_control" -)[0] + +# download and load model model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cuda", model_id_list=["FLUX.1-dev", "InstantX/FLUX.1-dev-IP-Adapter"]) -model_manager.load_lora(lora_path, lora_alpha=1.) 
+model_manager.load_lora( + download_customized_models( + model_id="DiffSynth-Studio/Eligen", + origin_file_path="model_bf16.safetensors", + local_dir="models/lora/entity_control" + ), + lora_alpha=1 +) pipe = FluxImagePipeline.from_model_manager(model_manager) -# prepare inputs -image_shape = 1024 -seed = 4 -# set True to apply regional attention in negative prompt prediction for better results with more time -use_seperated_negtive_prompt = False -mask_urls = [ - 'https://github.com/user-attachments/assets/e6745b3f-ab2b-4612-9bb5-b7235474a9a4', - 'https://github.com/user-attachments/assets/5ddf9a89-32fa-4540-89ad-e956130942b3', - 'https://github.com/user-attachments/assets/9d8a0bb0-6817-497e-af85-44f2512afe79' -] -# prepare entity masks, entity prompts, global prompt and negative prompt -masks = [] -for url in mask_urls: - response = requests.get(url) - mask = Image.open(BytesIO(response.content)).resize((image_shape, image_shape), resample=Image.NEAREST) - masks.append(mask) +# download and load mask images +dataset_snapshot_download(dataset_id="DiffSynth-Studio/examples_in_diffsynth", local_dir="./", allow_file_pattern="data/examples/eligen/ipadapter*") +masks = [Image.open(f"./data/examples/eligen/ipadapter_mask_{i}.png") for i in range(1, 4)] + entity_prompts = ['A girl', 'hat', 'sunset'] global_prompt = "A girl wearing a hat, looking at the sunset" negative_prompt = "worst quality, low quality, monochrome, zombie, interlocked fingers, Aissist, cleavage, nsfw" +reference_img = Image.open("./data/examples/eligen/ipadapter_image.png") -response = requests.get('https://github.com/user-attachments/assets/019bbfaa-04b3-4de6-badb-32b67c29a1bc') -reference_img = Image.open(BytesIO(response.content)).convert('RGB').resize((image_shape, image_shape)) - -torch.manual_seed(seed) +# generate image image = pipe( prompt=global_prompt, cfg_scale=3.0, negative_prompt=negative_prompt, - num_inference_steps=50, embedded_guidance=3.5, height=image_shape, width=image_shape, - 
entity_prompts=entity_prompts, entity_masks=masks, - use_seperated_negtive_prompt=use_seperated_negtive_prompt, - ipadapter_images=[reference_img], ipadapter_scale=0.7 + num_inference_steps=50, + embedded_guidance=3.5, + seed=4, + height=1024, + width=1024, + eligen_entity_prompts=entity_prompts, + eligen_entity_masks=masks, + enable_eligen_on_negative=False, + ipadapter_images=[reference_img], + ipadapter_scale=0.7 ) image.save(f"styled_entity_control.png") visualize_masks(image, masks, entity_prompts, f"styled_entity_control_with_mask.png") diff --git a/examples/EntityControl/entity_inpaint.py b/examples/EntityControl/entity_inpaint.py index eae4641..d62da4f 100644 --- a/examples/EntityControl/entity_inpaint.py +++ b/examples/EntityControl/entity_inpaint.py @@ -1,58 +1,45 @@ -import torch -from diffsynth import ModelManager, FluxImagePipeline, download_customized_models, FluxImageLoraPipeline +from diffsynth import ModelManager, FluxImagePipeline, download_customized_models +from modelscope import dataset_snapshot_download from examples.EntityControl.utils import visualize_masks -import os -import json from PIL import Image -import requests -from io import BytesIO +import torch # download and load model -lora_path = download_customized_models( - model_id="DiffSynth-Studio/Eligen", - origin_file_path="model_bf16.safetensors", - local_dir="models/lora/entity_control" -)[0] model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cuda", model_id_list=["FLUX.1-dev"]) -model_manager.load_lora(lora_path, lora_alpha=1.) 
+model_manager.load_lora( + download_customized_models( + model_id="DiffSynth-Studio/Eligen", + origin_file_path="model_bf16.safetensors", + local_dir="models/lora/entity_control" + ), + lora_alpha=1 +) pipe = FluxImagePipeline.from_model_manager(model_manager) -# prepare inputs -image_shape = 1024 -seed = 0 -# set True to apply regional attention in negative prompt prediction for better results with more time -use_seperated_negtive_prompt = False -mask_urls = [ - 'https://github.com/user-attachments/assets/0cf78663-5314-4280-a065-31ded7a24a46', - 'https://github.com/user-attachments/assets/bd3938b8-72a8-4d56-814f-f6445971b91d' -] -# prepare entity masks, entity prompts, global prompt and negative prompt -masks = [] -for url in mask_urls: - response = requests.get(url) - mask = Image.open(BytesIO(response.content)).resize((image_shape, image_shape), resample=Image.NEAREST) - masks.append(mask) +# download and load mask images +dataset_snapshot_download(dataset_id="DiffSynth-Studio/examples_in_diffsynth", local_dir="./", allow_file_pattern="data/examples/eligen/inpaint*") +masks = [Image.open(f"./data/examples/eligen/inpaint_mask_{i}.png") for i in range(1, 3)] +input_image = Image.open("./data/examples/eligen/inpaint_image.jpg") + entity_prompts = ["A person wear red shirt", "Airplane"] global_prompt = "A person walking on the path in front of a house; An airplane in the sky" negative_prompt = "worst quality, low quality, monochrome, zombie, interlocked fingers, Aissist, cleavage, nsfw, blur" -response = requests.get('https://github.com/user-attachments/assets/fa4d6ba5-08fd-4fc7-adbb-19898d839364') -inpaint_input = Image.open(BytesIO(response.content)).convert('RGB').resize((image_shape, image_shape)) - # generate image -torch.manual_seed(seed) image = pipe( prompt=global_prompt, + input_image=input_image, cfg_scale=3.0, negative_prompt=negative_prompt, num_inference_steps=50, embedded_guidance=3.5, - height=image_shape, - width=image_shape, - 
entity_prompts=entity_prompts, - entity_masks=masks, - inpaint_input=inpaint_input, - use_seperated_negtive_prompt=use_seperated_negtive_prompt, + seed=0, + height=1024, + width=1024, + eligen_entity_prompts=entity_prompts, + eligen_entity_masks=masks, + enable_eligen_on_negative=False, + enable_eligen_inpaint=True, ) image.save(f"entity_inpaint.png") -visualize_masks(image, masks, entity_prompts, f"entity_inpaint_with_mask.png") \ No newline at end of file +visualize_masks(image, masks, entity_prompts, f"entity_inpaint_with_mask.png") From 8cf3422688da3f2329aa1cbaca178721270a3922 Mon Sep 17 00:00:00 2001 From: mi804 <1576993271@qq.com> Date: Fri, 3 Jan 2025 10:37:34 +0800 Subject: [PATCH 8/8] update eligen ui --- .../{entity_level_control.py => eligen_ui.py} | 41 ++++++++----------- 1 file changed, 18 insertions(+), 23 deletions(-) rename apps/gradio/{entity_level_control.py => eligen_ui.py} (91%) diff --git a/apps/gradio/entity_level_control.py b/apps/gradio/eligen_ui.py similarity index 91% rename from apps/gradio/entity_level_control.py rename to apps/gradio/eligen_ui.py index d914cd0..4287ad0 100644 --- a/apps/gradio/entity_level_control.py +++ b/apps/gradio/eligen_ui.py @@ -1,5 +1,5 @@ import gradio as gr -from diffsynth import ModelManager, FluxImagePipeline +from diffsynth import ModelManager, FluxImagePipeline, download_customized_models import os, torch from PIL import Image import numpy as np @@ -7,11 +7,8 @@ from PIL import ImageDraw, ImageFont import random import json -lora_checkpoint_path = 'models/lora/entity_control/model_bf16.safetensors' -save_masks_dir = 'workdirs/tmp_mask' - def save_mask_prompts(masks, mask_prompts, global_prompt, seed=0, random_dir='0000000'): - save_dir = os.path.join(save_masks_dir, random_dir) + save_dir = os.path.join('workdirs/tmp_mask', random_dir) os.makedirs(save_dir, exist_ok=True) for i, mask in enumerate(masks): save_path = os.path.join(save_dir, f'{i}.png') @@ -79,7 +76,7 @@ config = { "default_parameters": { 
"cfg_scale": 3.0, "embedded_guidance": 3.5, - "num_inference_steps": 30, + "num_inference_steps": 50, } }, }, @@ -109,17 +106,15 @@ def load_model(model_type, model_path): model_path = os.path.join(config["model_config"][model_type]["model_folder"], model_path) model_manager = ModelManager() if model_type == "FLUX": - model_manager.torch_dtype = torch.bfloat16 - file_list = [ - os.path.join(model_path, "text_encoder/model.safetensors"), - os.path.join(model_path, "text_encoder_2"), - ] - for file_name in os.listdir(model_path): - if file_name.endswith(".safetensors"): - file_list.append(os.path.join(model_path, file_name)) - model_manager.load_models(file_list) - model_manager.load_lora(lora_checkpoint_path, lora_alpha=1.) - + model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cuda", model_id_list=["FLUX.1-dev"]) + model_manager.load_lora( + download_customized_models( + model_id="DiffSynth-Studio/Eligen", + origin_file_path="model_bf16.safetensors", + local_dir="models/lora/entity_control", + ), + lora_alpha=1, + ) else: model_manager.load_model(model_path) pipe = config["model_config"][model_type]["pipeline_class"].from_model_manager(model_manager) @@ -137,7 +132,7 @@ with gr.Blocks() as app: gr.Markdown(""" # 实体级控制文生图模型EliGen **UI说明** - 1. 点击Load model读取模型,然后左侧界面为文生图输入参数;右侧Painter为局部控制区域绘制区域,每个局部控制条件由其Local prompt和绘制的mask组成,支持精准控制文生图和Inpainting两种模式。 + 1. **点击Load model读取模型**,然后左侧界面为文生图输入参数;右侧Painter为局部控制区域绘制区域,每个局部控制条件由其Local prompt和绘制的mask组成,支持精准控制文生图和Inpainting两种模式。 2. **精准控制生图模式:** 输入Globalprompt;激活并绘制一个或多个局部控制条件,点击Generate生成图像; Global Prompt推荐包含每个Local Prompt。 3. **Inpainting模式:** 你可以上传图像,或者将上一步生成的图像设置为Inpaint Input Image,采用类似的方式输入局部控制条件,进行局部重绘。 4. 尽情创造! @@ -145,7 +140,7 @@ with gr.Blocks() as app: gr.Markdown(""" # Entity-Level Controlled Text-to-Image Model: EliGen **UI Instructions** - 1. Click "Load model" to load the model. 
The left interface is for text-to-image input parameters; the right "Painter" is the area for drawing local control regions. Each local control condition consists of its Local Prompt and the drawn mask, supporting both precise control text-to-image and Inpainting modes. + 1. **Click "Load model" to load the model.** The left interface is for text-to-image input parameters; the right "Painter" is the area for drawing local control regions. Each local control condition consists of its Local Prompt and the drawn mask, supporting both precise control text-to-image and Inpainting modes. 2. **Precise Control Image Generation Mode:** Enter the Global Prompt; activate and draw one or more local control conditions, then click "Generate" to create the image. It is recommended that the Global Prompt includes all Local Prompts. 3. **Inpainting Mode:** You can upload an image or set the image generated in the previous step as the "Inpaint Input Image." Use a similar method to input local control conditions for local redrawing. 4. Enjoy! 
@@ -241,7 +236,6 @@ with gr.Blocks() as app: triggers=run_button.click ) def generate_image(model_type, model_path, prompt, negative_prompt, cfg_scale, embedded_guidance, num_inference_steps, height, width, return_with_mask, seed, input_image, background_weight, random_mask_dir, *args, progress=gr.Progress()): - print("save to:", random_mask_dir.value) _, pipe = load_model(model_type, model_path) input_params = { "prompt": prompt, @@ -255,7 +249,8 @@ with gr.Blocks() as app: if isinstance(pipe, FluxImagePipeline): input_params["embedded_guidance"] = embedded_guidance if input_image is not None: - input_params["inpaint_input"] = input_image.resize((width, height)).convert("RGB") + input_params["input_image"] = input_image.resize((width, height)).convert("RGB") + input_params["enable_eligen_inpaint"] = True enable_local_prompt_list, local_prompt_list, mask_scale_list, canvas_list = ( args[0 * config["max_num_painter_layers"]: 1 * config["max_num_painter_layers"]], @@ -274,8 +269,8 @@ with gr.Blocks() as app: entity_masks = None if len(masks) == 0 else masks entity_prompts = None if len(local_prompts) == 0 else local_prompts input_params.update({ - "entity_prompts": entity_prompts, - "entity_masks": entity_masks, + "eligen_entity_prompts": entity_prompts, + "eligen_entity_masks": entity_masks, }) torch.manual_seed(seed) image = pipe(**input_params)