From e3d89cec0c8511e9e1e59508650355e0ce89c20a Mon Sep 17 00:00:00 2001 From: mi804 <1576993271@qq.com> Date: Wed, 25 Dec 2024 17:19:31 +0800 Subject: [PATCH] temp commit for entity control --- diffsynth/models/flux_dit.py | 128 +++++++++++++----- diffsynth/pipelines/flux_image.py | 93 +++++++++++-- examples/EntityControl/entity_control_flux.py | 54 ++++++++ examples/EntityControl/entity_inpaint_flux.py | 59 ++++++++ examples/EntityControl/utils.py | 59 ++++++++ 5 files changed, 349 insertions(+), 44 deletions(-) create mode 100644 examples/EntityControl/entity_control_flux.py create mode 100644 examples/EntityControl/entity_inpaint_flux.py create mode 100644 examples/EntityControl/utils.py diff --git a/diffsynth/models/flux_dit.py b/diffsynth/models/flux_dit.py index faf58cd..8deea3b 100644 --- a/diffsynth/models/flux_dit.py +++ b/diffsynth/models/flux_dit.py @@ -40,7 +40,7 @@ class RoPEEmbedding(torch.nn.Module): n_axes = ids.shape[-1] emb = torch.cat([self.rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)], dim=-3) return emb.unsqueeze(1) - + class FluxJointAttention(torch.nn.Module): @@ -70,7 +70,7 @@ class FluxJointAttention(torch.nn.Module): xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1] return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk) - def forward(self, hidden_states_a, hidden_states_b, image_rotary_emb, ipadapter_kwargs_list=None): + def forward(self, hidden_states_a, hidden_states_b, image_rotary_emb, attn_mask=None, ipadapter_kwargs_list=None): batch_size = hidden_states_a.shape[0] # Part A @@ -91,7 +91,7 @@ class FluxJointAttention(torch.nn.Module): q, k = self.apply_rope(q, k, image_rotary_emb) - hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v) + hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask) hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim) hidden_states = hidden_states.to(q.dtype) hidden_states_b, hidden_states_a = hidden_states[:, :hidden_states_b.shape[1]], hidden_states[:, hidden_states_b.shape[1]:] @@ -103,7 +103,7 @@ class FluxJointAttention(torch.nn.Module): else: hidden_states_b = self.b_to_out(hidden_states_b) return hidden_states_a, hidden_states_b - + class FluxJointTransformerBlock(torch.nn.Module): @@ -129,12 +129,12 @@ class FluxJointTransformerBlock(torch.nn.Module): ) - def forward(self, hidden_states_a, hidden_states_b, temb, image_rotary_emb, ipadapter_kwargs_list=None): + def forward(self, hidden_states_a, hidden_states_b, temb, image_rotary_emb, attn_mask=None, ipadapter_kwargs_list=None): norm_hidden_states_a, gate_msa_a, shift_mlp_a, scale_mlp_a, gate_mlp_a = self.norm1_a(hidden_states_a, emb=temb) norm_hidden_states_b, gate_msa_b, shift_mlp_b, scale_mlp_b, gate_mlp_b = self.norm1_b(hidden_states_b, emb=temb) # Attention - attn_output_a, attn_output_b = self.attn(norm_hidden_states_a, norm_hidden_states_b, image_rotary_emb, ipadapter_kwargs_list) + attn_output_a, attn_output_b = self.attn(norm_hidden_states_a, norm_hidden_states_b, image_rotary_emb, attn_mask, ipadapter_kwargs_list) # Part A hidden_states_a = hidden_states_a + gate_msa_a * attn_output_a @@ -147,7 +147,7 @@ class FluxJointTransformerBlock(torch.nn.Module): hidden_states_b = hidden_states_b + gate_mlp_b * self.ff_b(norm_hidden_states_b) return hidden_states_a, hidden_states_b - + class FluxSingleAttention(torch.nn.Module): @@ -184,7 +184,7 @@ class FluxSingleAttention(torch.nn.Module): hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim) hidden_states = hidden_states.to(q.dtype) return hidden_states - + class AdaLayerNormSingle(torch.nn.Module): @@ -200,7 +200,7 @@ class AdaLayerNormSingle(torch.nn.Module): shift_msa, scale_msa, gate_msa = emb.chunk(3, dim=1) x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None] return x, gate_msa - + class FluxSingleTransformerBlock(torch.nn.Module): @@ -225,8 +225,8 @@ class FluxSingleTransformerBlock(torch.nn.Module): xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1] return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk) - - def process_attention(self, hidden_states, image_rotary_emb, ipadapter_kwargs_list=None): + + def process_attention(self, hidden_states, image_rotary_emb, attn_mask=None, ipadapter_kwargs_list=None): batch_size = hidden_states.shape[0] qkv = hidden_states.view(batch_size, -1, 3 * self.num_heads, self.head_dim).transpose(1, 2) @@ -235,7 +235,7 @@ class FluxSingleTransformerBlock(torch.nn.Module): q, k = self.apply_rope(q, k, image_rotary_emb) - hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v) + hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask) hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim) hidden_states = hidden_states.to(q.dtype) if ipadapter_kwargs_list is not None: @@ -243,21 +243,21 @@ class FluxSingleTransformerBlock(torch.nn.Module): return hidden_states - def forward(self, hidden_states_a, hidden_states_b, temb, image_rotary_emb, ipadapter_kwargs_list=None): + def forward(self, hidden_states_a, hidden_states_b, temb, image_rotary_emb, attn_mask=None, ipadapter_kwargs_list=None): residual = hidden_states_a norm_hidden_states, gate = self.norm(hidden_states_a, emb=temb) hidden_states_a = self.to_qkv_mlp(norm_hidden_states) attn_output, mlp_hidden_states = hidden_states_a[:, :, :self.dim * 3], hidden_states_a[:, :, self.dim * 3:] - attn_output = self.process_attention(attn_output, image_rotary_emb, ipadapter_kwargs_list) + attn_output = self.process_attention(attn_output, image_rotary_emb, attn_mask, ipadapter_kwargs_list) mlp_hidden_states = torch.nn.functional.gelu(mlp_hidden_states, approximate="tanh") hidden_states_a = torch.cat([attn_output, mlp_hidden_states], dim=2) hidden_states_a = gate.unsqueeze(1) * self.proj_out(hidden_states_a) hidden_states_a = residual + hidden_states_a - + return hidden_states_a, hidden_states_b - + class AdaLayerNormContinuous(torch.nn.Module): @@ -300,7 +300,7 @@ class FluxDiT(torch.nn.Module): def unpatchify(self, hidden_states, height, width): hidden_states = rearrange(hidden_states, "B (H W) (C P Q) -> B C (H P) (W Q)", P=2, Q=2, H=height//2, W=width//2) return hidden_states - + def prepare_image_ids(self, latents): batch_size, _, height, width = latents.shape @@ -317,7 +317,7 @@ class FluxDiT(torch.nn.Module): latent_image_ids = latent_image_ids.to(device=latents.device, dtype=latents.dtype) return latent_image_ids - + def tiled_forward( self, @@ -337,12 +337,45 @@ class FluxDiT(torch.nn.Module): ) return hidden_states + def construct_mask(self, entity_masks, prompt_seq_len, image_seq_len): + N = len(entity_masks) + batch_size = entity_masks[0].shape[0] + total_seq_len = N * prompt_seq_len + image_seq_len + patched_masks = [self.patchify(entity_masks[i]) for i in range(N)] + attention_mask = torch.ones((batch_size, total_seq_len, total_seq_len), dtype=torch.bool).to(device=entity_masks[0].device) + + image_start = N * prompt_seq_len + image_end = N * prompt_seq_len + image_seq_len + # prompt-image mask + for i in range(N): + prompt_start = i * prompt_seq_len + prompt_end = (i + 1) * prompt_seq_len + image_mask = torch.sum(patched_masks[i], dim=-1) > 0 + image_mask = image_mask.unsqueeze(1).repeat(1, prompt_seq_len, 1) + # prompt update with image + attention_mask[:, prompt_start:prompt_end, image_start:image_end] = image_mask + # image update with prompt + attention_mask[:, image_start:image_end, prompt_start:prompt_end] = image_mask.transpose(1, 2) + # prompt-prompt mask + for i in range(N): + for j in range(N): + if i != j: + prompt_start_i = i * prompt_seq_len + prompt_end_i = (i + 1) * prompt_seq_len + prompt_start_j = j * prompt_seq_len + prompt_end_j = (j + 1) * prompt_seq_len + attention_mask[:, prompt_start_i:prompt_end_i, prompt_start_j:prompt_end_j] = False + + attention_mask = attention_mask.float() + attention_mask[attention_mask == 0] = float('-inf') + attention_mask[attention_mask == 1] = 0 + return attention_mask def forward( self, hidden_states, timestep, prompt_emb, pooled_prompt_emb, guidance, text_ids, image_ids=None, - tiled=False, tile_size=128, tile_stride=64, + tiled=False, tile_size=128, tile_stride=64, entity_prompts=None, entity_masks=None, use_gradient_checkpointing=False, **kwargs ): @@ -353,46 +386,70 @@ class FluxDiT(torch.nn.Module): tile_size=tile_size, tile_stride=tile_stride, **kwargs ) - + if image_ids is None: image_ids = self.prepare_image_ids(hidden_states) - + conditioning = self.time_embedder(timestep, hidden_states.dtype) + self.pooled_text_embedder(pooled_prompt_emb) if self.guidance_embedder is not None: guidance = guidance * 1000 conditioning = conditioning + self.guidance_embedder(guidance, hidden_states.dtype) - prompt_emb = self.context_embedder(prompt_emb) - image_rotary_emb = self.pos_embedder(torch.cat((text_ids, image_ids), dim=1)) + repeat_dim = hidden_states.shape[1] height, width = hidden_states.shape[-2:] hidden_states = self.patchify(hidden_states) hidden_states = self.x_embedder(hidden_states) - + + max_masks = 0 + attention_mask = None + prompt_embs = [prompt_emb] + if entity_masks is not None: + # entity_masks + batch_size, max_masks = entity_masks.shape[0], entity_masks.shape[1] + entity_masks = entity_masks.repeat(1, 1, repeat_dim, 1, 1) + entity_masks = [entity_masks[:, i, None].squeeze(1) for i in range(max_masks)] + # global mask + global_mask = torch.ones_like(entity_masks[0]).to(device=hidden_states.device, dtype=hidden_states.dtype) + entity_masks = entity_masks + [global_mask] # append global to last + # attention mask + attention_mask = self.construct_mask(entity_masks, prompt_emb.shape[1], hidden_states.shape[1]) + attention_mask = attention_mask.to(device=hidden_states.device, dtype=hidden_states.dtype) + attention_mask = attention_mask.unsqueeze(1) + # embds: n_masks * b * seq * d + local_embs = [entity_prompts[:, i, None].squeeze(1) for i in range(max_masks)] + prompt_embs = local_embs + prompt_embs # append global to last + prompt_embs = [self.context_embedder(prompt_emb) for prompt_emb in prompt_embs] + prompt_emb = torch.cat(prompt_embs, dim=1) + + # positional embedding + text_ids = torch.cat([text_ids] * (max_masks + 1), dim=1) + image_rotary_emb = self.pos_embedder(torch.cat((text_ids, image_ids), dim=1)) + def create_custom_forward(module): def custom_forward(*inputs): return module(*inputs) return custom_forward - + for block in self.blocks: if self.training and use_gradient_checkpointing: hidden_states, prompt_emb = torch.utils.checkpoint.checkpoint( create_custom_forward(block), - hidden_states, prompt_emb, conditioning, image_rotary_emb, + hidden_states, prompt_emb, conditioning, image_rotary_emb, attention_mask, use_reentrant=False, ) else: - hidden_states, prompt_emb = block(hidden_states, prompt_emb, conditioning, image_rotary_emb) + hidden_states, prompt_emb = block(hidden_states, prompt_emb, conditioning, image_rotary_emb, attention_mask) hidden_states = torch.cat([prompt_emb, hidden_states], dim=1) for block in self.single_blocks: if self.training and use_gradient_checkpointing: hidden_states, prompt_emb = torch.utils.checkpoint.checkpoint( create_custom_forward(block), - hidden_states, prompt_emb, conditioning, image_rotary_emb, + hidden_states, prompt_emb, conditioning, image_rotary_emb, attention_mask, use_reentrant=False, ) else: - hidden_states, prompt_emb = block(hidden_states, prompt_emb, conditioning, image_rotary_emb) + hidden_states, prompt_emb = block(hidden_states, prompt_emb, conditioning, image_rotary_emb, attention_mask) hidden_states = hidden_states[:, prompt_emb.shape[1]:] hidden_states = self.final_norm_out(hidden_states, conditioning) @@ -400,7 +457,7 @@ class FluxDiT(torch.nn.Module): hidden_states = self.unpatchify(hidden_states, height, width) return hidden_states - + def quantize(self): def cast_to(weight, dtype=None, device=None, copy=False): @@ -440,16 +497,16 @@ class FluxDiT(torch.nn.Module): class Linear(torch.nn.Linear): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - + def forward(self,input,**kwargs): weight,bias= cast_bias_weight(self,input) return torch.nn.functional.linear(input,weight,bias) - + class RMSNorm(torch.nn.Module): def __init__(self, module): super().__init__() self.module = module - + def forward(self,hidden_states,**kwargs): weight= cast_weight(self.module,hidden_states) input_dtype = hidden_states.dtype @@ -457,7 +514,7 @@ class FluxDiT(torch.nn.Module): hidden_states = hidden_states * torch.rsqrt(variance + self.module.eps) hidden_states = hidden_states.to(input_dtype) * weight return hidden_states - + def replace_layer(model): for name, module in model.named_children(): if isinstance(module, torch.nn.Linear): @@ -483,7 +540,6 @@ class FluxDiT(torch.nn.Module): @staticmethod def state_dict_converter(): return FluxDiTStateDictConverter() - class FluxDiTStateDictConverter: @@ -587,7 +643,7 @@ class FluxDiTStateDictConverter: state_dict_.pop(name.replace(f".{component}_to_q.", f".{component}_to_k.")) state_dict_.pop(name.replace(f".{component}_to_q.", f".{component}_to_v.")) return state_dict_ - + def from_civitai(self, state_dict): rename_dict = { "time_in.in_layer.bias": "time_embedder.timestep_embedder.0.bias", diff --git a/diffsynth/pipelines/flux_image.py b/diffsynth/pipelines/flux_image.py index 42d142c..c012f09 100644 --- a/diffsynth/pipelines/flux_image.py +++ b/diffsynth/pipelines/flux_image.py @@ -139,6 +139,39 @@ class FluxImagePipeline(BasePipeline): images = [self.preprocess_image(image).to(device=self.device, dtype=self.torch_dtype) for image in images] return torch.cat(images, dim=0) + def inpaint_fusion(self, latents, inpaint_latents, pred_noise, fg_mask, bg_mask, progress_id, background_weight=0.): + # inpaint noise + inpaint_noise = (latents - inpaint_latents) / self.scheduler.sigmas[progress_id] + # merge noise + weight = torch.ones_like(inpaint_noise) + inpaint_noise[fg_mask] = pred_noise[fg_mask] + inpaint_noise[bg_mask] += pred_noise[bg_mask] * background_weight + weight[bg_mask] += background_weight + inpaint_noise /= weight + return inpaint_noise + + def preprocess_masks(self, masks, height, width, dim): + out_masks = [] + for mask in masks: + mask = self.preprocess_image(mask.resize((width, height), resample=Image.NEAREST)).mean(dim=1, keepdim=True) > 0 + mask = mask.repeat(1, dim, 1, 1).to(device=self.device, dtype=self.torch_dtype) + out_masks.append(mask) + return out_masks + + def prepare_entity_inputs(self, entity_prompts, entity_masks, width, height, t5_sequence_length=512, inpaint_input=None): + fg_mask, bg_mask = None, None + if inpaint_input is not None: + from copy import deepcopy + masks_ = deepcopy(entity_masks) + fg_masks = torch.cat([self.preprocess_image(mask.resize((width//8, height//8))).mean(dim=1, keepdim=True) for mask in masks_]) + fg_masks = (fg_masks > 0).float() + fg_mask = fg_masks.sum(dim=0, keepdim=True).repeat(1, 16, 1, 1) > 0 + bg_mask = ~fg_mask + entity_masks = self.preprocess_masks(entity_masks, height//8, width//8, 1) + entity_masks = torch.cat(entity_masks, dim=0).unsqueeze(0) # b, n_mask, c, h, w + entity_prompts = self.encode_prompt(entity_prompts, t5_sequence_length=t5_sequence_length)['prompt_emb'].unsqueeze(0) + return entity_prompts, entity_masks, fg_mask, bg_mask + @torch.no_grad() def __call__( self, @@ -160,6 +193,10 @@ class FluxImagePipeline(BasePipeline): width=1024, num_inference_steps=30, t5_sequence_length=512, + inpaint_input=None, + entity_prompts=None, + entity_masks=None, + use_seperated_negtive_prompt=True, tiled=False, tile_size=128, tile_stride=64, @@ -176,12 +213,13 @@ class FluxImagePipeline(BasePipeline): self.scheduler.set_timesteps(num_inference_steps, denoising_strength) # Prepare latent tensors - if input_image is not None: + if input_image is not None or inpaint_input is not None: + input_image = input_image or inpaint_input self.load_models_to_device(['vae_encoder']) image = self.preprocess_image(input_image).to(device=self.device, dtype=self.torch_dtype) - latents = self.encode_image(image, **tiler_kwargs) + input_latents = self.encode_image(image, **tiler_kwargs) noise = self.generate_noise((1, 16, height//8, width//8), seed=seed, device=self.device, dtype=self.torch_dtype) - latents = self.scheduler.add_noise(latents, noise, timestep=self.scheduler.timesteps[0]) + latents = self.scheduler.add_noise(input_latents, noise, timestep=self.scheduler.timesteps[0]) else: latents = self.generate_noise((1, 16, height//8, width//8), seed=seed, device=self.device, dtype=self.torch_dtype) @@ -195,6 +233,14 @@ class FluxImagePipeline(BasePipeline): prompt_emb_nega = self.encode_prompt(negative_prompt, positive=False, t5_sequence_length=t5_sequence_length) prompt_emb_locals = [self.encode_prompt(prompt_local, t5_sequence_length=t5_sequence_length) for prompt_local in local_prompts] + # Entity control + negative_entity_prompts = None + negative_masks = None + if entity_masks is not None: + entity_prompts, entity_masks, fg_mask, bg_mask = self.prepare_entity_inputs(entity_prompts, entity_masks, width, height, t5_sequence_length, inpaint_input) + if use_seperated_negtive_prompt and cfg_scale != 1.0: + negative_entity_prompts = prompt_emb_nega['prompt_emb'].unsqueeze(1).repeat(1, entity_masks.shape[1], 1, 1) + negative_masks = entity_masks # Extra input extra_input = self.prepare_extra_input(latents, guidance=embedded_guidance) @@ -229,18 +275,20 @@ class FluxImagePipeline(BasePipeline): # Classifier-free guidance inference_callback = lambda prompt_emb_posi, controlnet_kwargs: lets_dance_flux( dit=self.dit, controlnet=self.controlnet, - hidden_states=latents, timestep=timestep, + hidden_states=latents, timestep=timestep, entity_prompts=entity_prompts, entity_masks=entity_masks, **prompt_emb_posi, **tiler_kwargs, **extra_input, **controlnet_kwargs, **ipadapter_kwargs_list_posi, ) noise_pred_posi = self.control_noise_via_local_prompts( prompt_emb_posi, prompt_emb_locals, masks, mask_scales, inference_callback, special_kwargs=controlnet_kwargs, special_local_kwargs_list=local_controlnet_kwargs ) + if inpaint_input: + noise_pred_posi = self.inpaint_fusion(latents, input_latents, noise_pred_posi, fg_mask, bg_mask, progress_id) if cfg_scale != 1.0: negative_controlnet_kwargs = controlnet_kwargs if enable_controlnet_on_negative else {} noise_pred_nega = lets_dance_flux( dit=self.dit, controlnet=self.controlnet, - hidden_states=latents, timestep=timestep, + hidden_states=latents, timestep=timestep, entity_prompts=negative_entity_prompts, entity_masks=negative_masks, **prompt_emb_nega, **tiler_kwargs, **extra_input, **negative_controlnet_kwargs, **ipadapter_kwargs_list_nega, ) noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega) @@ -278,6 +326,8 @@ def lets_dance_flux( tiled=False, tile_size=128, tile_stride=64, + entity_prompts=None, + entity_masks=None, ipadapter_kwargs_list={}, **kwargs ): @@ -333,13 +383,38 @@ def lets_dance_flux( if dit.guidance_embedder is not None: guidance = guidance * 1000 conditioning = conditioning + dit.guidance_embedder(guidance, hidden_states.dtype) - prompt_emb = dit.context_embedder(prompt_emb) - image_rotary_emb = dit.pos_embedder(torch.cat((text_ids, image_ids), dim=1)) + repeat_dim = hidden_states.shape[1] height, width = hidden_states.shape[-2:] hidden_states = dit.patchify(hidden_states) hidden_states = dit.x_embedder(hidden_states) - + + # Entity Control + max_masks = 0 + attention_mask = None + prompt_embs = [prompt_emb] + if entity_masks is not None: + # entity_masks + batch_size, max_masks = entity_masks.shape[0], entity_masks.shape[1] + entity_masks = entity_masks.repeat(1, 1, repeat_dim, 1, 1) + entity_masks = [entity_masks[:, i, None].squeeze(1) for i in range(max_masks)] + # global mask + global_mask = torch.ones_like(entity_masks[0]).to(device=hidden_states.device, dtype=hidden_states.dtype) + entity_masks = entity_masks + [global_mask] # append global to last + # attention mask + attention_mask = dit.construct_mask(entity_masks, prompt_emb.shape[1], hidden_states.shape[1]) + attention_mask = attention_mask.to(device=hidden_states.device, dtype=hidden_states.dtype) + attention_mask = attention_mask.unsqueeze(1) + # embds: n_masks * b * seq * d + local_embs = [entity_prompts[:, i, None].squeeze(1) for i in range(max_masks)] + prompt_embs = local_embs + prompt_embs # append global to last + prompt_embs = [dit.context_embedder(prompt_emb) for prompt_emb in prompt_embs] + prompt_emb = torch.cat(prompt_embs, dim=1) + + # positional embedding + text_ids = torch.cat([text_ids] * (max_masks + 1), dim=1) + image_rotary_emb = dit.pos_embedder(torch.cat((text_ids, image_ids), dim=1)) + # Joint Blocks for block_id, block in enumerate(dit.blocks): hidden_states, prompt_emb = block( @@ -347,6 +422,7 @@ def lets_dance_flux( prompt_emb, conditioning, image_rotary_emb, + attention_mask, ipadapter_kwargs_list=ipadapter_kwargs_list.get(block_id, None)) # ControlNet if controlnet is not None and controlnet_frames is not None: @@ -361,6 +437,7 @@ def lets_dance_flux( prompt_emb, conditioning, image_rotary_emb, + attention_mask, ipadapter_kwargs_list=ipadapter_kwargs_list.get( block_id + num_joint_blocks, None)) # ControlNet diff --git a/examples/EntityControl/entity_control_flux.py b/examples/EntityControl/entity_control_flux.py new file mode 100644 index 0000000..5e74e0b --- /dev/null +++ b/examples/EntityControl/entity_control_flux.py @@ -0,0 +1,54 @@ +import torch +from diffsynth import ModelManager, FluxImagePipeline, download_customized_models, FluxImageLoraPipeline +from examples.EntityControl.utils import visualize_masks +import os +import json +from PIL import Image + +# lora_path = download_customized_models( +# model_id="DiffSynth-Studio/ArtAug-lora-FLUX.1dev-v1", +# origin_file_path="merged_lora.safetensors", +# local_dir="models/lora" +# )[0] + +lora_path = '/root/model_bf16.safetensors' +model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cuda") +model_manager.load_models([ + "t2i_models/FLUX/FLUX.1-dev/text_encoder/model.safetensors", + "t2i_models/FLUX/FLUX.1-dev/text_encoder_2", + "t2i_models/FLUX/FLUX.1-dev/ae.safetensors", + "t2i_models/FLUX/FLUX.1-dev/flux1-dev.safetensors" +]) +model_manager.load_lora(lora_path, lora_alpha=1.) + +pipe = FluxImagePipeline.from_model_manager(model_manager) + +mask_dir = '/mnt/nas1/zhanghong/DiffSynth-Studio/workdirs/tmp_mask' +image_shape = 1024 +guidance = 3.5 +cfg = 3.0 +negative_prompt = "worst quality, low quality, monochrome, zombie, interlocked fingers, Aissist, cleavage, nsfw," +names = ['row_2_1'] +seeds = [0] +# use this to apply regional attention in negative prompt prediction for better results with more time +use_seperated_negtive_prompt = False +for name, seed in zip(names, seeds): + out_dir = f'workdirs/entity_control/{name}' + os.makedirs(out_dir, exist_ok=True) + cur_dir = os.path.join(mask_dir, name) + metas = json.load(open(os.path.join(mask_dir, name, 'prompts.json'))) + for seed in range(3, 10): + prompt = metas['global_prompt'] + mask_prompts = metas['mask_prompts'] + masks = [Image.open(os.path.join(mask_dir, name, f"{mask_idx}.png")).resize((image_shape, image_shape), resample=Image.NEAREST) for mask_idx in range(len(mask_prompts))] + torch.manual_seed(seed) + image = pipe( + prompt=prompt, + cfg_scale=cfg, + negative_prompt=negative_prompt, + num_inference_steps=50, embedded_guidance=guidance, height=image_shape, width=image_shape, + entity_prompts=mask_prompts, entity_masks=masks, + use_seperated_negtive_prompt=use_seperated_negtive_prompt + ) + use_sep = f'_sepneg' if use_seperated_negtive_prompt else '' + visualize_masks(image, masks, mask_prompts, os.path.join(out_dir, f"{name}_{seed}{use_sep}.png")) diff --git a/examples/EntityControl/entity_inpaint_flux.py b/examples/EntityControl/entity_inpaint_flux.py new file mode 100644 index 0000000..5027563 --- /dev/null +++ b/examples/EntityControl/entity_inpaint_flux.py @@ -0,0 +1,59 @@ +import torch +from diffsynth import ModelManager, FluxImagePipeline, download_customized_models, FluxImageLoraPipeline +from examples.EntityControl.utils import visualize_masks +import os +import json +from PIL import Image + +# lora_path = download_customized_models( +# model_id="DiffSynth-Studio/ArtAug-lora-FLUX.1dev-v1", +# origin_file_path="merged_lora.safetensors", +# local_dir="models/lora" +# )[0] + +lora_path = '/root/model_bf16.safetensors' +model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cuda") +model_manager.load_models([ + "t2i_models/FLUX/FLUX.1-dev/text_encoder/model.safetensors", + "t2i_models/FLUX/FLUX.1-dev/text_encoder_2", + "t2i_models/FLUX/FLUX.1-dev/ae.safetensors", + "t2i_models/FLUX/FLUX.1-dev/flux1-dev.safetensors" +]) +model_manager.load_lora(lora_path, lora_alpha=1.) + +pipe = FluxImagePipeline.from_model_manager(model_manager) + +mask_dir = '/mnt/nas1/zhanghong/DiffSynth-Studio/workdirs/tmp_mask' +image_shape = 1024 +guidance = 3.5 +cfg = 3.0 +negative_prompt = "worst quality, low quality, monochrome, zombie, interlocked fingers, Aissist, cleavage, nsfw," +names = ['inpaint2'] +seeds = [0] +use_seperated_negtive_prompt = False +for name, seed in zip(names, seeds): + out_dir = f'workdirs/paper_app/inpaint/elc/{name}' + os.makedirs(out_dir, exist_ok=True) + cur_dir = os.path.join(mask_dir, name) + metas = json.load(open(os.path.join(mask_dir, name, 'prompts.json'))) + inpaint_input = Image.open(os.path.join(cur_dir, 'input.png')).convert('RGB') + prompt = metas['global_prompt'] + prompt = 'A person with a dog walking on the cloud. A rocket in the sky' + mask_prompts = metas['mask_prompts'] + masks = [Image.open(os.path.join(mask_dir, name, f"{mask_idx}.png")).resize((image_shape, image_shape), resample=Image.NEAREST) for mask_idx in range(len(mask_prompts))] + torch.manual_seed(seed) + image = pipe( + prompt=prompt, + cfg_scale=cfg, + negative_prompt=negative_prompt, + num_inference_steps=50, + embedded_guidance=guidance, + height=image_shape, + width=image_shape, + entity_prompts=mask_prompts, + entity_masks=masks, + inpaint_input=inpaint_input, + use_seperated_negtive_prompt=use_seperated_negtive_prompt, + ) + use_sep = f'_sepneg' if use_seperated_negtive_prompt else '' + visualize_masks(image, masks, mask_prompts, os.path.join(out_dir, f"{name}_{seed}{use_sep}.png")) diff --git a/examples/EntityControl/utils.py b/examples/EntityControl/utils.py new file mode 100644 index 0000000..573c8a6 --- /dev/null +++ b/examples/EntityControl/utils.py @@ -0,0 +1,59 @@ +from PIL import Image, ImageDraw, ImageFont +import random + +def visualize_masks(image, masks, mask_prompts, output_path, font_size=35, use_random_colors=False): + # Create a blank image for overlays + overlay = Image.new('RGBA', image.size, (0, 0, 0, 0)) + + colors = [ + (165, 238, 173, 80), + (76, 102, 221, 80), + (221, 160, 77, 80), + (204, 93, 71, 80), + (145, 187, 149, 80), + (134, 141, 172, 80), + (157, 137, 109, 80), + (153, 104, 95, 80), + (165, 238, 173, 80), + (76, 102, 221, 80), + (221, 160, 77, 80), + (204, 93, 71, 80), + (145, 187, 149, 80), + (134, 141, 172, 80), + (157, 137, 109, 80), + (153, 104, 95, 80), + ] + # Generate random colors for each mask + if use_random_colors: + colors = [(random.randint(0, 255), random.randint(0, 255), random.randint(0, 255), 80) for _ in range(len(masks))] + + # Font settings + try: + font = ImageFont.truetype("arial", font_size) # Adjust as needed + except IOError: + font = ImageFont.load_default(font_size) + + # Overlay each mask onto the overlay image + for mask, mask_prompt, color in zip(masks, mask_prompts, colors): + # Convert mask to RGBA mode + mask_rgba = mask.convert('RGBA') + mask_data = mask_rgba.getdata() + new_data = [(color if item[:3] == (255, 255, 255) else (0, 0, 0, 0)) for item in mask_data] + mask_rgba.putdata(new_data) + + # Draw the mask prompt text on the mask + draw = ImageDraw.Draw(mask_rgba) + mask_bbox = mask.getbbox() # Get the bounding box of the mask + text_position = (mask_bbox[0] + 10, mask_bbox[1] + 10) # Adjust text position based on mask position + draw.text(text_position, mask_prompt, fill=(255, 255, 255, 255), font=font) + + # Alpha composite the overlay with this mask + overlay = Image.alpha_composite(overlay, mask_rgba) + + # Composite the overlay onto the original image + result = Image.alpha_composite(image.convert('RGBA'), overlay) + + # Save or display the resulting image + result.save(output_path) + + return result