From 6f743fc4b6c39aa56a2ebed55a46f25fac48c585 Mon Sep 17 00:00:00 2001 From: Artiprocher Date: Thu, 2 Jan 2025 19:54:09 +0800 Subject: [PATCH] refine code --- diffsynth/models/flux_dit.py | 64 +++-- diffsynth/models/model_manager.py | 26 +- diffsynth/pipelines/flux_image.py | 240 ++++++++++-------- examples/EntityControl/entity_control.py | 56 ++-- .../EntityControl/entity_control_ipadapter.py | 61 ++--- examples/EntityControl/entity_inpaint.py | 63 ++--- 6 files changed, 263 insertions(+), 247 deletions(-) diff --git a/diffsynth/models/flux_dit.py b/diffsynth/models/flux_dit.py index 8deea3b..d592e61 100644 --- a/diffsynth/models/flux_dit.py +++ b/diffsynth/models/flux_dit.py @@ -337,6 +337,7 @@ class FluxDiT(torch.nn.Module): ) return hidden_states + def construct_mask(self, entity_masks, prompt_seq_len, image_seq_len): N = len(entity_masks) batch_size = entity_masks[0].shape[0] @@ -371,11 +372,41 @@ class FluxDiT(torch.nn.Module): attention_mask[attention_mask == 1] = 0 return attention_mask + + def process_entity_masks(self, hidden_states, prompt_emb, entity_prompt_emb, entity_masks, text_ids, image_ids): + repeat_dim = hidden_states.shape[1] + max_masks = 0 + attention_mask = None + prompt_embs = [prompt_emb] + if entity_masks is not None: + # entity_masks + batch_size, max_masks = entity_masks.shape[0], entity_masks.shape[1] + entity_masks = entity_masks.repeat(1, 1, repeat_dim, 1, 1) + entity_masks = [entity_masks[:, i, None].squeeze(1) for i in range(max_masks)] + # global mask + global_mask = torch.ones_like(entity_masks[0]).to(device=hidden_states.device, dtype=hidden_states.dtype) + entity_masks = entity_masks + [global_mask] # append global to last + # attention mask + attention_mask = self.construct_mask(entity_masks, prompt_emb.shape[1], hidden_states.shape[1]) + attention_mask = attention_mask.to(device=hidden_states.device, dtype=hidden_states.dtype) + attention_mask = attention_mask.unsqueeze(1) + # embds: n_masks * b * seq * d + local_embs = [entity_prompt_emb[:, i, None].squeeze(1) for i in range(max_masks)] + prompt_embs = local_embs + prompt_embs # append global to last + prompt_embs = [self.context_embedder(prompt_emb) for prompt_emb in prompt_embs] + prompt_emb = torch.cat(prompt_embs, dim=1) + + # positional embedding + text_ids = torch.cat([text_ids] * (max_masks + 1), dim=1) + image_rotary_emb = self.pos_embedder(torch.cat((text_ids, image_ids), dim=1)) + return prompt_emb, image_rotary_emb, attention_mask + + def forward( self, hidden_states, timestep, prompt_emb, pooled_prompt_emb, guidance, text_ids, image_ids=None, - tiled=False, tile_size=128, tile_stride=64, entity_prompts=None, entity_masks=None, + tiled=False, tile_size=128, tile_stride=64, entity_prompt_emb=None, entity_masks=None, use_gradient_checkpointing=False, **kwargs ): @@ -395,35 +426,16 @@ class FluxDiT(torch.nn.Module): guidance = guidance * 1000 conditioning = conditioning + self.guidance_embedder(guidance, hidden_states.dtype) - repeat_dim = hidden_states.shape[1] height, width = hidden_states.shape[-2:] hidden_states = self.patchify(hidden_states) hidden_states = self.x_embedder(hidden_states) - max_masks = 0 - attention_mask = None - prompt_embs = [prompt_emb] - if entity_masks is not None: - # entity_masks - batch_size, max_masks = entity_masks.shape[0], entity_masks.shape[1] - entity_masks = entity_masks.repeat(1, 1, repeat_dim, 1, 1) - entity_masks = [entity_masks[:, i, None].squeeze(1) for i in range(max_masks)] - # global mask - global_mask = torch.ones_like(entity_masks[0]).to(device=hidden_states.device, dtype=hidden_states.dtype) - entity_masks = entity_masks + [global_mask] # append global to last - # attention mask - attention_mask = self.construct_mask(entity_masks, prompt_emb.shape[1], hidden_states.shape[1]) - attention_mask = attention_mask.to(device=hidden_states.device, dtype=hidden_states.dtype) - attention_mask = attention_mask.unsqueeze(1) - # embds: n_masks * b * seq * d - local_embs = [entity_prompts[:, i, None].squeeze(1) for i in range(max_masks)] - prompt_embs = local_embs + prompt_embs # append global to last - prompt_embs = [self.context_embedder(prompt_emb) for prompt_emb in prompt_embs] - prompt_emb = torch.cat(prompt_embs, dim=1) - - # positional embedding - text_ids = torch.cat([text_ids] * (max_masks + 1), dim=1) - image_rotary_emb = self.pos_embedder(torch.cat((text_ids, image_ids), dim=1)) + if entity_prompt_emb is not None and entity_masks is not None: + prompt_emb, image_rotary_emb, attention_mask = self.process_entity_masks(hidden_states, prompt_emb, entity_prompt_emb, entity_masks, text_ids, image_ids) + else: + prompt_emb = self.context_embedder(prompt_emb) + image_rotary_emb = self.pos_embedder(torch.cat((text_ids, image_ids), dim=1)) + attention_mask = None def create_custom_forward(module): def custom_forward(*inputs): diff --git a/diffsynth/models/model_manager.py b/diffsynth/models/model_manager.py index dcee6d3..96ee86a 100644 --- a/diffsynth/models/model_manager.py +++ b/diffsynth/models/model_manager.py @@ -366,17 +366,21 @@ class ModelManager: def load_lora(self, file_path="", state_dict={}, lora_alpha=1.0): - print(f"Loading LoRA models from file: {file_path}") - if len(state_dict) == 0: - state_dict = load_state_dict(file_path) - for model_name, model, model_path in zip(self.model_name, self.model, self.model_path): - for lora in get_lora_loaders(): - match_results = lora.match(model, state_dict) - if match_results is not None: - print(f" Adding LoRA to {model_name} ({model_path}).") - lora_prefix, model_resource = match_results - lora.load(model, state_dict, lora_prefix, alpha=lora_alpha, model_resource=model_resource) - break + if isinstance(file_path, list): + for file_path_ in file_path: + self.load_lora(file_path_, state_dict=state_dict, lora_alpha=lora_alpha) + else: + print(f"Loading LoRA models from file: {file_path}") + if len(state_dict) == 0: + state_dict = load_state_dict(file_path) + for model_name, model, model_path in zip(self.model_name, self.model, self.model_path): + for lora in get_lora_loaders(): + match_results = lora.match(model, state_dict) + if match_results is not None: + print(f" Adding LoRA to {model_name} ({model_path}).") + lora_prefix, model_resource = match_results + lora.load(model, state_dict, lora_prefix, alpha=lora_alpha, model_resource=model_resource) + break def load_model(self, file_path, model_names=None, device=None, torch_dtype=None): diff --git a/diffsynth/pipelines/flux_image.py b/diffsynth/pipelines/flux_image.py index c012f09..b6fac68 100644 --- a/diffsynth/pipelines/flux_image.py +++ b/diffsynth/pipelines/flux_image.py @@ -10,6 +10,7 @@ import numpy as np from PIL import Image from ..models.tiler import FastTileWorker from transformers import SiglipVisionModel +from copy import deepcopy class FluxImagePipeline(BasePipeline): @@ -59,6 +60,7 @@ class FluxImagePipeline(BasePipeline): self.ipadapter = model_manager.fetch_model("flux_ipadapter") self.ipadapter_image_encoder = model_manager.fetch_model("siglip_vision_model") + @staticmethod def from_model_manager(model_manager: ModelManager, controlnet_config_units: List[ControlNetConfigUnit]=[], prompt_refiner_classes=[], prompt_extender_classes=[], device=None): pipe = FluxImagePipeline( @@ -133,12 +135,14 @@ class FluxImagePipeline(BasePipeline): # store it controlnet_frames.append(image) return controlnet_frames - + + def prepare_ipadapter_inputs(self, images, height=384, width=384): images = [image.convert("RGB").resize((width, height), resample=3) for image in images] images = [self.preprocess_image(image).to(device=self.device, dtype=self.torch_dtype) for image in images] return torch.cat(images, dim=0) + def inpaint_fusion(self, latents, inpaint_latents, pred_noise, fg_mask, bg_mask, progress_id, background_weight=0.): # inpaint noise inpaint_noise = (latents - inpaint_latents) / self.scheduler.sigmas[progress_id] @@ -150,6 +154,7 @@ class FluxImagePipeline(BasePipeline): inpaint_noise /= weight return inpaint_noise + def preprocess_masks(self, masks, height, width, dim): out_masks = [] for mask in masks: @@ -158,10 +163,10 @@ class FluxImagePipeline(BasePipeline): out_masks.append(mask) return out_masks - def prepare_entity_inputs(self, entity_prompts, entity_masks, width, height, t5_sequence_length=512, inpaint_input=None): + + def prepare_entity_inputs(self, entity_prompts, entity_masks, width, height, t5_sequence_length=512, enable_eligen_inpaint=False): fg_mask, bg_mask = None, None - if inpaint_input is not None: - from copy import deepcopy + if enable_eligen_inpaint: masks_ = deepcopy(entity_masks) fg_masks = torch.cat([self.preprocess_image(mask.resize((width//8, height//8))).mean(dim=1, keepdim=True) for mask in masks_]) fg_masks = (fg_masks > 0).float() @@ -172,35 +177,114 @@ class FluxImagePipeline(BasePipeline): entity_prompts = self.encode_prompt(entity_prompts, t5_sequence_length=t5_sequence_length)['prompt_emb'].unsqueeze(0) return entity_prompts, entity_masks, fg_mask, bg_mask + + def prepare_latents(self, input_image, height, width, seed, tiled, tile_size, tile_stride): + if input_image is not None: + self.load_models_to_device(['vae_encoder']) + image = self.preprocess_image(input_image).to(device=self.device, dtype=self.torch_dtype) + input_latents = self.encode_image(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride) + noise = self.generate_noise((1, 16, height//8, width//8), seed=seed, device=self.device, dtype=self.torch_dtype) + latents = self.scheduler.add_noise(input_latents, noise, timestep=self.scheduler.timesteps[0]) + else: + latents = self.generate_noise((1, 16, height//8, width//8), seed=seed, device=self.device, dtype=self.torch_dtype) + input_latents = None + return latents, input_latents + + + def prepare_ipadapter(self, ipadapter_images, ipadapter_scale): + if ipadapter_images is not None: + self.load_models_to_device(['ipadapter_image_encoder']) + ipadapter_images = self.prepare_ipadapter_inputs(ipadapter_images) + ipadapter_image_encoding = self.ipadapter_image_encoder(ipadapter_images).pooler_output + self.load_models_to_device(['ipadapter']) + ipadapter_kwargs_list_posi = {"ipadapter_kwargs_list": self.ipadapter(ipadapter_image_encoding, scale=ipadapter_scale)} + ipadapter_kwargs_list_nega = {"ipadapter_kwargs_list": self.ipadapter(torch.zeros_like(ipadapter_image_encoding))} + else: + ipadapter_kwargs_list_posi, ipadapter_kwargs_list_nega = {"ipadapter_kwargs_list": {}}, {"ipadapter_kwargs_list": {}} + return ipadapter_kwargs_list_posi, ipadapter_kwargs_list_nega + + + def prepare_controlnet(self, controlnet_image, masks, controlnet_inpaint_mask, tiler_kwargs, enable_controlnet_on_negative): + if controlnet_image is not None: + self.load_models_to_device(['vae_encoder']) + controlnet_kwargs_posi = {"controlnet_frames": self.prepare_controlnet_input(controlnet_image, controlnet_inpaint_mask, tiler_kwargs)} + if len(masks) > 0 and controlnet_inpaint_mask is not None: + print("The controlnet_inpaint_mask will be overridden by masks.") + local_controlnet_kwargs = [{"controlnet_frames": self.prepare_controlnet_input(controlnet_image, mask, tiler_kwargs)} for mask in masks] + else: + local_controlnet_kwargs = None + else: + controlnet_kwargs_posi, local_controlnet_kwargs = {"controlnet_frames": None}, [{}] * len(masks) + controlnet_kwargs_nega = controlnet_kwargs_posi if enable_controlnet_on_negative else {} + return controlnet_kwargs_posi, controlnet_kwargs_nega, local_controlnet_kwargs + + + def prepare_eligen(self, prompt_emb_nega, eligen_entity_prompts, eligen_entity_masks, width, height, t5_sequence_length, enable_eligen_inpaint, enable_eligen_on_negative, cfg_scale): + if eligen_entity_masks is not None: + entity_prompt_emb_posi, entity_masks_posi, fg_mask, bg_mask = self.prepare_entity_inputs(eligen_entity_prompts, eligen_entity_masks, width, height, t5_sequence_length, enable_eligen_inpaint) + if enable_eligen_on_negative and cfg_scale != 1.0: + entity_prompt_emb_nega = prompt_emb_nega['prompt_emb'].unsqueeze(1).repeat(1, eligen_entity_masks.shape[1], 1, 1) + entity_masks_nega = eligen_entity_masks + else: + entity_prompt_emb_nega, entity_masks_nega = None, None + else: + entity_prompt_emb_posi, entity_masks_posi, entity_prompt_emb_nega, entity_masks_nega = None, None, None, None + fg_mask, bg_mask = None, None + eligen_kwargs_posi = {"entity_prompt_emb": entity_prompt_emb_posi, "entity_masks": entity_masks_posi} + eligen_kwargs_nega = {"entity_prompt_emb": entity_prompt_emb_nega, "entity_masks": entity_masks_nega} + return eligen_kwargs_posi, eligen_kwargs_nega, fg_mask, bg_mask + + + def prepare_prompts(self, prompt, local_prompts, masks, mask_scales, t5_sequence_length, negative_prompt, cfg_scale): + # Extend prompt + self.load_models_to_device(['text_encoder_1', 'text_encoder_2']) + prompt, local_prompts, masks, mask_scales = self.extend_prompt(prompt, local_prompts, masks, mask_scales) + + # Encode prompts + prompt_emb_posi = self.encode_prompt(prompt, t5_sequence_length=t5_sequence_length) + prompt_emb_nega = self.encode_prompt(negative_prompt, positive=False, t5_sequence_length=t5_sequence_length) if cfg_scale != 1.0 else None + prompt_emb_locals = [self.encode_prompt(prompt_local, t5_sequence_length=t5_sequence_length) for prompt_local in local_prompts] + return prompt_emb_posi, prompt_emb_nega, prompt_emb_locals + + @torch.no_grad() def __call__( self, + # Prompt prompt, - local_prompts=None, - masks=None, - mask_scales=None, negative_prompt="", cfg_scale=1.0, embedded_guidance=3.5, + t5_sequence_length=512, + # Image input_image=None, - ipadapter_images=None, - ipadapter_scale=1.0, - controlnet_image=None, - controlnet_inpaint_mask=None, - enable_controlnet_on_negative=False, denoising_strength=1.0, height=1024, width=1024, + seed=None, + # Steps num_inference_steps=30, - t5_sequence_length=512, - inpaint_input=None, - entity_prompts=None, - entity_masks=None, - use_seperated_negtive_prompt=True, + # local prompts + local_prompts=(), + masks=(), + mask_scales=(), + # ControlNet + controlnet_image=None, + controlnet_inpaint_mask=None, + enable_controlnet_on_negative=False, + # IP-Adapter + ipadapter_images=None, + ipadapter_scale=1.0, + # EliGen + eligen_entity_prompts=None, + eligen_entity_masks=None, + enable_eligen_on_negative=False, + enable_eligen_inpaint=False, + # Tile tiled=False, tile_size=128, tile_stride=64, - seed=None, + # Progress bar progress_bar_cmd=tqdm, progress_bar_st=None, ): @@ -213,83 +297,50 @@ class FluxImagePipeline(BasePipeline): self.scheduler.set_timesteps(num_inference_steps, denoising_strength) # Prepare latent tensors - if input_image is not None or inpaint_input is not None: - input_image = input_image or inpaint_input - self.load_models_to_device(['vae_encoder']) - image = self.preprocess_image(input_image).to(device=self.device, dtype=self.torch_dtype) - input_latents = self.encode_image(image, **tiler_kwargs) - noise = self.generate_noise((1, 16, height//8, width//8), seed=seed, device=self.device, dtype=self.torch_dtype) - latents = self.scheduler.add_noise(input_latents, noise, timestep=self.scheduler.timesteps[0]) - else: - latents = self.generate_noise((1, 16, height//8, width//8), seed=seed, device=self.device, dtype=self.torch_dtype) + latents, input_latents = self.prepare_latents(input_image, height, width, seed, tiled, tile_size, tile_stride) - # Extend prompt - self.load_models_to_device(['text_encoder_1', 'text_encoder_2']) - prompt, local_prompts, masks, mask_scales = self.extend_prompt(prompt, local_prompts, masks, mask_scales) + # Prompt + prompt_emb_posi, prompt_emb_nega, prompt_emb_locals = self.prepare_prompts(prompt, local_prompts, masks, mask_scales, t5_sequence_length, negative_prompt, cfg_scale) - # Encode prompts - prompt_emb_posi = self.encode_prompt(prompt, t5_sequence_length=t5_sequence_length) - if cfg_scale != 1.0: - prompt_emb_nega = self.encode_prompt(negative_prompt, positive=False, t5_sequence_length=t5_sequence_length) - prompt_emb_locals = [self.encode_prompt(prompt_local, t5_sequence_length=t5_sequence_length) for prompt_local in local_prompts] - - # Entity control - negative_entity_prompts = None - negative_masks = None - if entity_masks is not None: - entity_prompts, entity_masks, fg_mask, bg_mask = self.prepare_entity_inputs(entity_prompts, entity_masks, width, height, t5_sequence_length, inpaint_input) - if use_seperated_negtive_prompt and cfg_scale != 1.0: - negative_entity_prompts = prompt_emb_nega['prompt_emb'].unsqueeze(1).repeat(1, entity_masks.shape[1], 1, 1) - negative_masks = entity_masks # Extra input extra_input = self.prepare_extra_input(latents, guidance=embedded_guidance) - # IP-Adapter - if ipadapter_images is not None: - self.load_models_to_device(['ipadapter_image_encoder']) - ipadapter_images = self.prepare_ipadapter_inputs(ipadapter_images) - ipadapter_image_encoding = self.ipadapter_image_encoder(ipadapter_images).pooler_output - self.load_models_to_device(['ipadapter']) - ipadapter_kwargs_list_posi = {"ipadapter_kwargs_list": self.ipadapter(ipadapter_image_encoding, scale=ipadapter_scale)} - ipadapter_kwargs_list_nega = {"ipadapter_kwargs_list": self.ipadapter(torch.zeros_like(ipadapter_image_encoding))} - else: - ipadapter_kwargs_list_posi, ipadapter_kwargs_list_nega = {"ipadapter_kwargs_list": {}}, {"ipadapter_kwargs_list": {}} + # Entity control + eligen_kwargs_posi, eligen_kwargs_nega, fg_mask, bg_mask = self.prepare_eligen(prompt_emb_nega, eligen_entity_prompts, eligen_entity_masks, width, height, t5_sequence_length, enable_eligen_inpaint, enable_eligen_on_negative, cfg_scale) - # Prepare ControlNets - if controlnet_image is not None: - self.load_models_to_device(['vae_encoder']) - controlnet_kwargs = {"controlnet_frames": self.prepare_controlnet_input(controlnet_image, controlnet_inpaint_mask, tiler_kwargs)} - if len(masks) > 0 and controlnet_inpaint_mask is not None: - print("The controlnet_inpaint_mask will be overridden by masks.") - local_controlnet_kwargs = [{"controlnet_frames": self.prepare_controlnet_input(controlnet_image, mask, tiler_kwargs)} for mask in masks] - else: - local_controlnet_kwargs = None - else: - controlnet_kwargs, local_controlnet_kwargs = {"controlnet_frames": None}, [{}] * len(masks) + # IP-Adapter + ipadapter_kwargs_list_posi, ipadapter_kwargs_list_nega = self.prepare_ipadapter(ipadapter_images, ipadapter_scale) + + # ControlNets + controlnet_kwargs_posi, controlnet_kwargs_nega, local_controlnet_kwargs = self.prepare_controlnet(controlnet_image, masks, controlnet_inpaint_mask, tiler_kwargs, enable_controlnet_on_negative) # Denoise self.load_models_to_device(['dit', 'controlnet']) for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)): timestep = timestep.unsqueeze(0).to(self.device) - # Classifier-free guidance + # Positive side inference_callback = lambda prompt_emb_posi, controlnet_kwargs: lets_dance_flux( dit=self.dit, controlnet=self.controlnet, - hidden_states=latents, timestep=timestep, entity_prompts=entity_prompts, entity_masks=entity_masks, - **prompt_emb_posi, **tiler_kwargs, **extra_input, **controlnet_kwargs, **ipadapter_kwargs_list_posi, + hidden_states=latents, timestep=timestep, + **prompt_emb_posi, **tiler_kwargs, **extra_input, **controlnet_kwargs, **ipadapter_kwargs_list_posi, **eligen_kwargs_posi, ) noise_pred_posi = self.control_noise_via_local_prompts( prompt_emb_posi, prompt_emb_locals, masks, mask_scales, inference_callback, - special_kwargs=controlnet_kwargs, special_local_kwargs_list=local_controlnet_kwargs + special_kwargs=controlnet_kwargs_posi, special_local_kwargs_list=local_controlnet_kwargs ) - if inpaint_input: + + # Inpaint + if enable_eligen_inpaint: noise_pred_posi = self.inpaint_fusion(latents, input_latents, noise_pred_posi, fg_mask, bg_mask, progress_id) + + # Classifier-free guidance if cfg_scale != 1.0: - negative_controlnet_kwargs = controlnet_kwargs if enable_controlnet_on_negative else {} + # Negative side noise_pred_nega = lets_dance_flux( dit=self.dit, controlnet=self.controlnet, - hidden_states=latents, timestep=timestep, entity_prompts=negative_entity_prompts, entity_masks=negative_masks, - **prompt_emb_nega, **tiler_kwargs, **extra_input, **negative_controlnet_kwargs, **ipadapter_kwargs_list_nega, + hidden_states=latents, timestep=timestep, + **prompt_emb_nega, **tiler_kwargs, **extra_input, **controlnet_kwargs_nega, **ipadapter_kwargs_list_nega, **eligen_kwargs_nega, ) noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega) else: @@ -304,7 +355,7 @@ class FluxImagePipeline(BasePipeline): # Decode image self.load_models_to_device(['vae_decoder']) - image = self.decode_image(latents, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride) + image = self.decode_image(latents, **tiler_kwargs) # Offload all models self.load_models_to_device([]) @@ -326,7 +377,7 @@ def lets_dance_flux( tiled=False, tile_size=128, tile_stride=64, - entity_prompts=None, + entity_prompt_emb=None, entity_masks=None, ipadapter_kwargs_list={}, **kwargs @@ -384,36 +435,16 @@ def lets_dance_flux( guidance = guidance * 1000 conditioning = conditioning + dit.guidance_embedder(guidance, hidden_states.dtype) - repeat_dim = hidden_states.shape[1] height, width = hidden_states.shape[-2:] hidden_states = dit.patchify(hidden_states) hidden_states = dit.x_embedder(hidden_states) - # Entity Control - max_masks = 0 - attention_mask = None - prompt_embs = [prompt_emb] - if entity_masks is not None: - # entity_masks - batch_size, max_masks = entity_masks.shape[0], entity_masks.shape[1] - entity_masks = entity_masks.repeat(1, 1, repeat_dim, 1, 1) - entity_masks = [entity_masks[:, i, None].squeeze(1) for i in range(max_masks)] - # global mask - global_mask = torch.ones_like(entity_masks[0]).to(device=hidden_states.device, dtype=hidden_states.dtype) - entity_masks = entity_masks + [global_mask] # append global to last - # attention mask - attention_mask = dit.construct_mask(entity_masks, prompt_emb.shape[1], hidden_states.shape[1]) - attention_mask = attention_mask.to(device=hidden_states.device, dtype=hidden_states.dtype) - attention_mask = attention_mask.unsqueeze(1) - # embds: n_masks * b * seq * d - local_embs = [entity_prompts[:, i, None].squeeze(1) for i in range(max_masks)] - prompt_embs = local_embs + prompt_embs # append global to last - prompt_embs = [dit.context_embedder(prompt_emb) for prompt_emb in prompt_embs] - prompt_emb = torch.cat(prompt_embs, dim=1) - - # positional embedding - text_ids = torch.cat([text_ids] * (max_masks + 1), dim=1) - image_rotary_emb = dit.pos_embedder(torch.cat((text_ids, image_ids), dim=1)) + if entity_prompt_emb is not None and entity_masks is not None: + prompt_emb, image_rotary_emb, attention_mask = dit.process_entity_masks(hidden_states, prompt_emb, entity_prompt_emb, entity_masks, text_ids, image_ids) + else: + prompt_emb = dit.context_embedder(prompt_emb) + image_rotary_emb = dit.pos_embedder(torch.cat((text_ids, image_ids), dim=1)) + attention_mask = None # Joint Blocks for block_id, block in enumerate(dit.blocks): @@ -423,7 +454,8 @@ def lets_dance_flux( conditioning, image_rotary_emb, attention_mask, - ipadapter_kwargs_list=ipadapter_kwargs_list.get(block_id, None)) + ipadapter_kwargs_list=ipadapter_kwargs_list.get(block_id, None) + ) # ControlNet if controlnet is not None and controlnet_frames is not None: hidden_states = hidden_states + controlnet_res_stack[block_id] @@ -438,8 +470,8 @@ def lets_dance_flux( conditioning, image_rotary_emb, attention_mask, - ipadapter_kwargs_list=ipadapter_kwargs_list.get( - block_id + num_joint_blocks, None)) + ipadapter_kwargs_list=ipadapter_kwargs_list.get(block_id + num_joint_blocks, None) + ) # ControlNet if controlnet is not None and controlnet_frames is not None: hidden_states[:, prompt_emb.shape[1]:] = hidden_states[:, prompt_emb.shape[1]:] + controlnet_single_res_stack[block_id] diff --git a/examples/EntityControl/entity_control.py b/examples/EntityControl/entity_control.py index 2735682..f505b94 100644 --- a/examples/EntityControl/entity_control.py +++ b/examples/EntityControl/entity_control.py @@ -1,57 +1,43 @@ -import torch from diffsynth import ModelManager, FluxImagePipeline, download_customized_models +from modelscope import dataset_snapshot_download from examples.EntityControl.utils import visualize_masks from PIL import Image -import requests -from io import BytesIO +import torch + # download and load model -lora_path = download_customized_models( - model_id="DiffSynth-Studio/Eligen", - origin_file_path="model_bf16.safetensors", - local_dir="models/lora/entity_control" -)[0] model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cuda", model_id_list=["FLUX.1-dev"]) -model_manager.load_lora(lora_path, lora_alpha=1.) +model_manager.load_lora( + download_customized_models( + model_id="DiffSynth-Studio/Eligen", + origin_file_path="model_bf16.safetensors", + local_dir="models/lora/entity_control" + ), + lora_alpha=1 +) pipe = FluxImagePipeline.from_model_manager(model_manager) -# prepare inputs -image_shape = 1024 -seed = 4 -# set True to apply regional attention in negative prompt prediction for better results with more time -use_seperated_negtive_prompt = False -mask_urls = [ - 'https://github.com/user-attachments/assets/02905f6e-40c2-4482-9abe-b1ce50ccabbf', - 'https://github.com/user-attachments/assets/a4cf4361-abf7-4556-ba94-74683eda4cb7', - 'https://github.com/user-attachments/assets/b6595ff4-7269-4d8f-acf0-5df40bd6c59f', - 'https://github.com/user-attachments/assets/941d39a7-3aa1-437f-8b2a-4adb15d2fb3e', - 'https://github.com/user-attachments/assets/400c4086-5398-4291-b1b5-22d8483c08d9', - 'https://github.com/user-attachments/assets/ce324c77-fa1d-4aad-a5cb-698f0d5eca70', - 'https://github.com/user-attachments/assets/4e62325f-a60c-44f7-b53b-6da0869bb9db' -] -# prepare entity masks, entity prompts, global prompt and negative prompt -masks = [] -for url in mask_urls: - response = requests.get(url) - mask = Image.open(BytesIO(response.content)).resize((image_shape, image_shape), resample=Image.NEAREST) - masks.append(mask) +# download and load mask images +dataset_snapshot_download(dataset_id="DiffSynth-Studio/examples_in_diffsynth", local_dir="./", allow_file_pattern="data/examples/eligen/mask*") +masks = [Image.open(f"./data/examples/eligen/mask{i}.png") for i in range(1, 8)] + entity_prompts = ["A beautiful woman", "mirror", "necklace", "glasses", "earring", "white dress", "jewelry headpiece"] global_prompt = "A beautiful woman wearing white dress, holding a mirror, with a warm light background;" negative_prompt = "worst quality, low quality, monochrome, zombie, interlocked fingers, Aissist, cleavage, nsfw" # generate image -torch.manual_seed(seed) image = pipe( prompt=global_prompt, cfg_scale=3.0, negative_prompt=negative_prompt, num_inference_steps=50, embedded_guidance=3.5, - height=image_shape, - width=image_shape, - entity_prompts=entity_prompts, - entity_masks=masks, - use_seperated_negtive_prompt=use_seperated_negtive_prompt, + seed=4, + height=1024, + width=1024, + eligen_entity_prompts=entity_prompts, + eligen_entity_masks=masks, + enable_eligen_on_negative=False, ) image.save(f"entity_control.png") visualize_masks(image, masks, entity_prompts, f"entity_control_with_mask.png") diff --git a/examples/EntityControl/entity_control_ipadapter.py b/examples/EntityControl/entity_control_ipadapter.py index 76ae869..c604bad 100644 --- a/examples/EntityControl/entity_control_ipadapter.py +++ b/examples/EntityControl/entity_control_ipadapter.py @@ -1,51 +1,46 @@ -import torch from diffsynth import ModelManager, FluxImagePipeline, download_customized_models +from modelscope import dataset_snapshot_download from examples.EntityControl.utils import visualize_masks from PIL import Image -import requests -from io import BytesIO +import torch -lora_path = download_customized_models( - model_id="DiffSynth-Studio/Eligen", - origin_file_path="model_bf16.safetensors", - local_dir="models/lora/entity_control" -)[0] + +# download and load model model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cuda", model_id_list=["FLUX.1-dev", "InstantX/FLUX.1-dev-IP-Adapter"]) -model_manager.load_lora(lora_path, lora_alpha=1.) +model_manager.load_lora( + download_customized_models( + model_id="DiffSynth-Studio/Eligen", + origin_file_path="model_bf16.safetensors", + local_dir="models/lora/entity_control" + ), + lora_alpha=1 +) pipe = FluxImagePipeline.from_model_manager(model_manager) -# prepare inputs -image_shape = 1024 -seed = 4 -# set True to apply regional attention in negative prompt prediction for better results with more time -use_seperated_negtive_prompt = False -mask_urls = [ - 'https://github.com/user-attachments/assets/e6745b3f-ab2b-4612-9bb5-b7235474a9a4', - 'https://github.com/user-attachments/assets/5ddf9a89-32fa-4540-89ad-e956130942b3', - 'https://github.com/user-attachments/assets/9d8a0bb0-6817-497e-af85-44f2512afe79' -] -# prepare entity masks, entity prompts, global prompt and negative prompt -masks = [] -for url in mask_urls: - response = requests.get(url) - mask = Image.open(BytesIO(response.content)).resize((image_shape, image_shape), resample=Image.NEAREST) - masks.append(mask) +# download and load mask images +dataset_snapshot_download(dataset_id="DiffSynth-Studio/examples_in_diffsynth", local_dir="./", allow_file_pattern="data/examples/eligen/ipadapter*") +masks = [Image.open(f"./data/examples/eligen/ipadapter_mask_{i}.png") for i in range(1, 4)] + entity_prompts = ['A girl', 'hat', 'sunset'] global_prompt = "A girl wearing a hat, looking at the sunset" negative_prompt = "worst quality, low quality, monochrome, zombie, interlocked fingers, Aissist, cleavage, nsfw" +reference_img = Image.open("./data/examples/eligen/ipadapter_image.png") -response = requests.get('https://github.com/user-attachments/assets/019bbfaa-04b3-4de6-badb-32b67c29a1bc') -reference_img = Image.open(BytesIO(response.content)).convert('RGB').resize((image_shape, image_shape)) - -torch.manual_seed(seed) +# generate image image = pipe( prompt=global_prompt, cfg_scale=3.0, negative_prompt=negative_prompt, - num_inference_steps=50, embedded_guidance=3.5, height=image_shape, width=image_shape, - entity_prompts=entity_prompts, entity_masks=masks, - use_seperated_negtive_prompt=use_seperated_negtive_prompt, - ipadapter_images=[reference_img], ipadapter_scale=0.7 + num_inference_steps=50, + embedded_guidance=3.5, + seed=4, + height=1024, + width=1024, + eligen_entity_prompts=entity_prompts, + eligen_entity_masks=masks, + enable_eligen_on_negative=False, + ipadapter_images=[reference_img], + ipadapter_scale=0.7 ) image.save(f"styled_entity_control.png") visualize_masks(image, masks, entity_prompts, f"styled_entity_control_with_mask.png") diff --git a/examples/EntityControl/entity_inpaint.py b/examples/EntityControl/entity_inpaint.py index eae4641..d62da4f 100644 --- a/examples/EntityControl/entity_inpaint.py +++ b/examples/EntityControl/entity_inpaint.py @@ -1,58 +1,45 @@ -import torch -from diffsynth import ModelManager, FluxImagePipeline, download_customized_models, FluxImageLoraPipeline +from diffsynth import ModelManager, FluxImagePipeline, download_customized_models +from modelscope import dataset_snapshot_download from examples.EntityControl.utils import visualize_masks -import os -import json from PIL import Image -import requests -from io import BytesIO +import torch # download and load model -lora_path = download_customized_models( - model_id="DiffSynth-Studio/Eligen", - origin_file_path="model_bf16.safetensors", - local_dir="models/lora/entity_control" -)[0] model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cuda", model_id_list=["FLUX.1-dev"]) -model_manager.load_lora(lora_path, lora_alpha=1.) +model_manager.load_lora( + download_customized_models( + model_id="DiffSynth-Studio/Eligen", + origin_file_path="model_bf16.safetensors", + local_dir="models/lora/entity_control" + ), + lora_alpha=1 +) pipe = FluxImagePipeline.from_model_manager(model_manager) -# prepare inputs -image_shape = 1024 -seed = 0 -# set True to apply regional attention in negative prompt prediction for better results with more time -use_seperated_negtive_prompt = False -mask_urls = [ - 'https://github.com/user-attachments/assets/0cf78663-5314-4280-a065-31ded7a24a46', - 'https://github.com/user-attachments/assets/bd3938b8-72a8-4d56-814f-f6445971b91d' -] -# prepare entity masks, entity prompts, global prompt and negative prompt -masks = [] -for url in mask_urls: - response = requests.get(url) - mask = Image.open(BytesIO(response.content)).resize((image_shape, image_shape), resample=Image.NEAREST) - masks.append(mask) +# download and load mask images +dataset_snapshot_download(dataset_id="DiffSynth-Studio/examples_in_diffsynth", local_dir="./", allow_file_pattern="data/examples/eligen/inpaint*") +masks = [Image.open(f"./data/examples/eligen/inpaint_mask_{i}.png") for i in range(1, 3)] +input_image = Image.open("./data/examples/eligen/inpaint_image.jpg") + entity_prompts = ["A person wear red shirt", "Airplane"] global_prompt = "A person walking on the path in front of a house; An airplane in the sky" negative_prompt = "worst quality, low quality, monochrome, zombie, interlocked fingers, Aissist, cleavage, nsfw, blur" -response = requests.get('https://github.com/user-attachments/assets/fa4d6ba5-08fd-4fc7-adbb-19898d839364') -inpaint_input = Image.open(BytesIO(response.content)).convert('RGB').resize((image_shape, image_shape)) - # generate image -torch.manual_seed(seed) image = pipe( prompt=global_prompt, + input_image=input_image, cfg_scale=3.0, negative_prompt=negative_prompt, num_inference_steps=50, embedded_guidance=3.5, - height=image_shape, - width=image_shape, - entity_prompts=entity_prompts, - entity_masks=masks, - inpaint_input=inpaint_input, - use_seperated_negtive_prompt=use_seperated_negtive_prompt, + seed=0, + height=1024, + width=1024, + eligen_entity_prompts=entity_prompts, + eligen_entity_masks=masks, + enable_eligen_on_negative=False, + enable_eligen_inpaint=True, ) image.save(f"entity_inpaint.png") -visualize_masks(image, masks, entity_prompts, f"entity_inpaint_with_mask.png") \ No newline at end of file +visualize_masks(image, masks, entity_prompts, f"entity_inpaint_with_mask.png")