refine code

This commit is contained in:
Artiprocher
2025-01-02 19:54:09 +08:00
parent 2872fdaf48
commit 6f743fc4b6
6 changed files with 263 additions and 247 deletions

View File

@@ -337,6 +337,7 @@ class FluxDiT(torch.nn.Module):
) )
return hidden_states return hidden_states
def construct_mask(self, entity_masks, prompt_seq_len, image_seq_len): def construct_mask(self, entity_masks, prompt_seq_len, image_seq_len):
N = len(entity_masks) N = len(entity_masks)
batch_size = entity_masks[0].shape[0] batch_size = entity_masks[0].shape[0]
@@ -371,11 +372,41 @@ class FluxDiT(torch.nn.Module):
attention_mask[attention_mask == 1] = 0 attention_mask[attention_mask == 1] = 0
return attention_mask return attention_mask
def process_entity_masks(self, hidden_states, prompt_emb, entity_prompt_emb, entity_masks, text_ids, image_ids):
repeat_dim = hidden_states.shape[1]
max_masks = 0
attention_mask = None
prompt_embs = [prompt_emb]
if entity_masks is not None:
# entity_masks
batch_size, max_masks = entity_masks.shape[0], entity_masks.shape[1]
entity_masks = entity_masks.repeat(1, 1, repeat_dim, 1, 1)
entity_masks = [entity_masks[:, i, None].squeeze(1) for i in range(max_masks)]
# global mask
global_mask = torch.ones_like(entity_masks[0]).to(device=hidden_states.device, dtype=hidden_states.dtype)
entity_masks = entity_masks + [global_mask] # append global to last
# attention mask
attention_mask = self.construct_mask(entity_masks, prompt_emb.shape[1], hidden_states.shape[1])
attention_mask = attention_mask.to(device=hidden_states.device, dtype=hidden_states.dtype)
attention_mask = attention_mask.unsqueeze(1)
# embds: n_masks * b * seq * d
local_embs = [entity_prompt_emb[:, i, None].squeeze(1) for i in range(max_masks)]
prompt_embs = local_embs + prompt_embs # append global to last
prompt_embs = [self.context_embedder(prompt_emb) for prompt_emb in prompt_embs]
prompt_emb = torch.cat(prompt_embs, dim=1)
# positional embedding
text_ids = torch.cat([text_ids] * (max_masks + 1), dim=1)
image_rotary_emb = self.pos_embedder(torch.cat((text_ids, image_ids), dim=1))
return prompt_emb, image_rotary_emb, attention_mask
def forward( def forward(
self, self,
hidden_states, hidden_states,
timestep, prompt_emb, pooled_prompt_emb, guidance, text_ids, image_ids=None, timestep, prompt_emb, pooled_prompt_emb, guidance, text_ids, image_ids=None,
tiled=False, tile_size=128, tile_stride=64, entity_prompts=None, entity_masks=None, tiled=False, tile_size=128, tile_stride=64, entity_prompt_emb=None, entity_masks=None,
use_gradient_checkpointing=False, use_gradient_checkpointing=False,
**kwargs **kwargs
): ):
@@ -395,35 +426,16 @@ class FluxDiT(torch.nn.Module):
guidance = guidance * 1000 guidance = guidance * 1000
conditioning = conditioning + self.guidance_embedder(guidance, hidden_states.dtype) conditioning = conditioning + self.guidance_embedder(guidance, hidden_states.dtype)
repeat_dim = hidden_states.shape[1]
height, width = hidden_states.shape[-2:] height, width = hidden_states.shape[-2:]
hidden_states = self.patchify(hidden_states) hidden_states = self.patchify(hidden_states)
hidden_states = self.x_embedder(hidden_states) hidden_states = self.x_embedder(hidden_states)
max_masks = 0 if entity_prompt_emb is not None and entity_masks is not None:
attention_mask = None prompt_emb, image_rotary_emb, attention_mask = self.process_entity_masks(hidden_states, prompt_emb, entity_prompt_emb, entity_masks, text_ids, image_ids)
prompt_embs = [prompt_emb] else:
if entity_masks is not None: prompt_emb = self.context_embedder(prompt_emb)
# entity_masks image_rotary_emb = self.pos_embedder(torch.cat((text_ids, image_ids), dim=1))
batch_size, max_masks = entity_masks.shape[0], entity_masks.shape[1] attention_mask = None
entity_masks = entity_masks.repeat(1, 1, repeat_dim, 1, 1)
entity_masks = [entity_masks[:, i, None].squeeze(1) for i in range(max_masks)]
# global mask
global_mask = torch.ones_like(entity_masks[0]).to(device=hidden_states.device, dtype=hidden_states.dtype)
entity_masks = entity_masks + [global_mask] # append global to last
# attention mask
attention_mask = self.construct_mask(entity_masks, prompt_emb.shape[1], hidden_states.shape[1])
attention_mask = attention_mask.to(device=hidden_states.device, dtype=hidden_states.dtype)
attention_mask = attention_mask.unsqueeze(1)
# embds: n_masks * b * seq * d
local_embs = [entity_prompts[:, i, None].squeeze(1) for i in range(max_masks)]
prompt_embs = local_embs + prompt_embs # append global to last
prompt_embs = [self.context_embedder(prompt_emb) for prompt_emb in prompt_embs]
prompt_emb = torch.cat(prompt_embs, dim=1)
# positional embedding
text_ids = torch.cat([text_ids] * (max_masks + 1), dim=1)
image_rotary_emb = self.pos_embedder(torch.cat((text_ids, image_ids), dim=1))
def create_custom_forward(module): def create_custom_forward(module):
def custom_forward(*inputs): def custom_forward(*inputs):

View File

@@ -366,17 +366,21 @@ class ModelManager:
def load_lora(self, file_path="", state_dict={}, lora_alpha=1.0): def load_lora(self, file_path="", state_dict={}, lora_alpha=1.0):
print(f"Loading LoRA models from file: {file_path}") if isinstance(file_path, list):
if len(state_dict) == 0: for file_path_ in file_path:
state_dict = load_state_dict(file_path) self.load_lora(file_path_, state_dict=state_dict, lora_alpha=lora_alpha)
for model_name, model, model_path in zip(self.model_name, self.model, self.model_path): else:
for lora in get_lora_loaders(): print(f"Loading LoRA models from file: {file_path}")
match_results = lora.match(model, state_dict) if len(state_dict) == 0:
if match_results is not None: state_dict = load_state_dict(file_path)
print(f" Adding LoRA to {model_name} ({model_path}).") for model_name, model, model_path in zip(self.model_name, self.model, self.model_path):
lora_prefix, model_resource = match_results for lora in get_lora_loaders():
lora.load(model, state_dict, lora_prefix, alpha=lora_alpha, model_resource=model_resource) match_results = lora.match(model, state_dict)
break if match_results is not None:
print(f" Adding LoRA to {model_name} ({model_path}).")
lora_prefix, model_resource = match_results
lora.load(model, state_dict, lora_prefix, alpha=lora_alpha, model_resource=model_resource)
break
def load_model(self, file_path, model_names=None, device=None, torch_dtype=None): def load_model(self, file_path, model_names=None, device=None, torch_dtype=None):

View File

@@ -10,6 +10,7 @@ import numpy as np
from PIL import Image from PIL import Image
from ..models.tiler import FastTileWorker from ..models.tiler import FastTileWorker
from transformers import SiglipVisionModel from transformers import SiglipVisionModel
from copy import deepcopy
class FluxImagePipeline(BasePipeline): class FluxImagePipeline(BasePipeline):
@@ -59,6 +60,7 @@ class FluxImagePipeline(BasePipeline):
self.ipadapter = model_manager.fetch_model("flux_ipadapter") self.ipadapter = model_manager.fetch_model("flux_ipadapter")
self.ipadapter_image_encoder = model_manager.fetch_model("siglip_vision_model") self.ipadapter_image_encoder = model_manager.fetch_model("siglip_vision_model")
@staticmethod @staticmethod
def from_model_manager(model_manager: ModelManager, controlnet_config_units: List[ControlNetConfigUnit]=[], prompt_refiner_classes=[], prompt_extender_classes=[], device=None): def from_model_manager(model_manager: ModelManager, controlnet_config_units: List[ControlNetConfigUnit]=[], prompt_refiner_classes=[], prompt_extender_classes=[], device=None):
pipe = FluxImagePipeline( pipe = FluxImagePipeline(
@@ -134,11 +136,13 @@ class FluxImagePipeline(BasePipeline):
controlnet_frames.append(image) controlnet_frames.append(image)
return controlnet_frames return controlnet_frames
def prepare_ipadapter_inputs(self, images, height=384, width=384): def prepare_ipadapter_inputs(self, images, height=384, width=384):
images = [image.convert("RGB").resize((width, height), resample=3) for image in images] images = [image.convert("RGB").resize((width, height), resample=3) for image in images]
images = [self.preprocess_image(image).to(device=self.device, dtype=self.torch_dtype) for image in images] images = [self.preprocess_image(image).to(device=self.device, dtype=self.torch_dtype) for image in images]
return torch.cat(images, dim=0) return torch.cat(images, dim=0)
def inpaint_fusion(self, latents, inpaint_latents, pred_noise, fg_mask, bg_mask, progress_id, background_weight=0.): def inpaint_fusion(self, latents, inpaint_latents, pred_noise, fg_mask, bg_mask, progress_id, background_weight=0.):
# inpaint noise # inpaint noise
inpaint_noise = (latents - inpaint_latents) / self.scheduler.sigmas[progress_id] inpaint_noise = (latents - inpaint_latents) / self.scheduler.sigmas[progress_id]
@@ -150,6 +154,7 @@ class FluxImagePipeline(BasePipeline):
inpaint_noise /= weight inpaint_noise /= weight
return inpaint_noise return inpaint_noise
def preprocess_masks(self, masks, height, width, dim): def preprocess_masks(self, masks, height, width, dim):
out_masks = [] out_masks = []
for mask in masks: for mask in masks:
@@ -158,10 +163,10 @@ class FluxImagePipeline(BasePipeline):
out_masks.append(mask) out_masks.append(mask)
return out_masks return out_masks
def prepare_entity_inputs(self, entity_prompts, entity_masks, width, height, t5_sequence_length=512, inpaint_input=None):
def prepare_entity_inputs(self, entity_prompts, entity_masks, width, height, t5_sequence_length=512, enable_eligen_inpaint=False):
fg_mask, bg_mask = None, None fg_mask, bg_mask = None, None
if inpaint_input is not None: if enable_eligen_inpaint:
from copy import deepcopy
masks_ = deepcopy(entity_masks) masks_ = deepcopy(entity_masks)
fg_masks = torch.cat([self.preprocess_image(mask.resize((width//8, height//8))).mean(dim=1, keepdim=True) for mask in masks_]) fg_masks = torch.cat([self.preprocess_image(mask.resize((width//8, height//8))).mean(dim=1, keepdim=True) for mask in masks_])
fg_masks = (fg_masks > 0).float() fg_masks = (fg_masks > 0).float()
@@ -172,35 +177,114 @@ class FluxImagePipeline(BasePipeline):
entity_prompts = self.encode_prompt(entity_prompts, t5_sequence_length=t5_sequence_length)['prompt_emb'].unsqueeze(0) entity_prompts = self.encode_prompt(entity_prompts, t5_sequence_length=t5_sequence_length)['prompt_emb'].unsqueeze(0)
return entity_prompts, entity_masks, fg_mask, bg_mask return entity_prompts, entity_masks, fg_mask, bg_mask
def prepare_latents(self, input_image, height, width, seed, tiled, tile_size, tile_stride):
if input_image is not None:
self.load_models_to_device(['vae_encoder'])
image = self.preprocess_image(input_image).to(device=self.device, dtype=self.torch_dtype)
input_latents = self.encode_image(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
noise = self.generate_noise((1, 16, height//8, width//8), seed=seed, device=self.device, dtype=self.torch_dtype)
latents = self.scheduler.add_noise(input_latents, noise, timestep=self.scheduler.timesteps[0])
else:
latents = self.generate_noise((1, 16, height//8, width//8), seed=seed, device=self.device, dtype=self.torch_dtype)
input_latents = None
return latents, input_latents
def prepare_ipadapter(self, ipadapter_images, ipadapter_scale):
if ipadapter_images is not None:
self.load_models_to_device(['ipadapter_image_encoder'])
ipadapter_images = self.prepare_ipadapter_inputs(ipadapter_images)
ipadapter_image_encoding = self.ipadapter_image_encoder(ipadapter_images).pooler_output
self.load_models_to_device(['ipadapter'])
ipadapter_kwargs_list_posi = {"ipadapter_kwargs_list": self.ipadapter(ipadapter_image_encoding, scale=ipadapter_scale)}
ipadapter_kwargs_list_nega = {"ipadapter_kwargs_list": self.ipadapter(torch.zeros_like(ipadapter_image_encoding))}
else:
ipadapter_kwargs_list_posi, ipadapter_kwargs_list_nega = {"ipadapter_kwargs_list": {}}, {"ipadapter_kwargs_list": {}}
return ipadapter_kwargs_list_posi, ipadapter_kwargs_list_nega
def prepare_controlnet(self, controlnet_image, masks, controlnet_inpaint_mask, tiler_kwargs, enable_controlnet_on_negative):
if controlnet_image is not None:
self.load_models_to_device(['vae_encoder'])
controlnet_kwargs_posi = {"controlnet_frames": self.prepare_controlnet_input(controlnet_image, controlnet_inpaint_mask, tiler_kwargs)}
if len(masks) > 0 and controlnet_inpaint_mask is not None:
print("The controlnet_inpaint_mask will be overridden by masks.")
local_controlnet_kwargs = [{"controlnet_frames": self.prepare_controlnet_input(controlnet_image, mask, tiler_kwargs)} for mask in masks]
else:
local_controlnet_kwargs = None
else:
controlnet_kwargs_posi, local_controlnet_kwargs = {"controlnet_frames": None}, [{}] * len(masks)
controlnet_kwargs_nega = controlnet_kwargs_posi if enable_controlnet_on_negative else {}
return controlnet_kwargs_posi, controlnet_kwargs_nega, local_controlnet_kwargs
def prepare_eligen(self, prompt_emb_nega, eligen_entity_prompts, eligen_entity_masks, width, height, t5_sequence_length, enable_eligen_inpaint, enable_eligen_on_negative, cfg_scale):
if eligen_entity_masks is not None:
entity_prompt_emb_posi, entity_masks_posi, fg_mask, bg_mask = self.prepare_entity_inputs(eligen_entity_prompts, eligen_entity_masks, width, height, t5_sequence_length, enable_eligen_inpaint)
if enable_eligen_on_negative and cfg_scale != 1.0:
entity_prompt_emb_nega = prompt_emb_nega['prompt_emb'].unsqueeze(1).repeat(1, eligen_entity_masks.shape[1], 1, 1)
entity_masks_nega = eligen_entity_masks
else:
entity_prompt_emb_nega, entity_masks_nega = None, None
else:
entity_prompt_emb_posi, entity_masks_posi, entity_prompt_emb_nega, entity_masks_nega = None, None, None, None
fg_mask, bg_mask = None, None
eligen_kwargs_posi = {"entity_prompt_emb": entity_prompt_emb_posi, "entity_masks": entity_masks_posi}
eligen_kwargs_nega = {"entity_prompt_emb": entity_prompt_emb_nega, "entity_masks": entity_masks_nega}
return eligen_kwargs_posi, eligen_kwargs_nega, fg_mask, bg_mask
def prepare_prompts(self, prompt, local_prompts, masks, mask_scales, t5_sequence_length, negative_prompt, cfg_scale):
# Extend prompt
self.load_models_to_device(['text_encoder_1', 'text_encoder_2'])
prompt, local_prompts, masks, mask_scales = self.extend_prompt(prompt, local_prompts, masks, mask_scales)
# Encode prompts
prompt_emb_posi = self.encode_prompt(prompt, t5_sequence_length=t5_sequence_length)
prompt_emb_nega = self.encode_prompt(negative_prompt, positive=False, t5_sequence_length=t5_sequence_length) if cfg_scale != 1.0 else None
prompt_emb_locals = [self.encode_prompt(prompt_local, t5_sequence_length=t5_sequence_length) for prompt_local in local_prompts]
return prompt_emb_posi, prompt_emb_nega, prompt_emb_locals
@torch.no_grad() @torch.no_grad()
def __call__( def __call__(
self, self,
# Prompt
prompt, prompt,
local_prompts=None,
masks=None,
mask_scales=None,
negative_prompt="", negative_prompt="",
cfg_scale=1.0, cfg_scale=1.0,
embedded_guidance=3.5, embedded_guidance=3.5,
t5_sequence_length=512,
# Image
input_image=None, input_image=None,
ipadapter_images=None,
ipadapter_scale=1.0,
controlnet_image=None,
controlnet_inpaint_mask=None,
enable_controlnet_on_negative=False,
denoising_strength=1.0, denoising_strength=1.0,
height=1024, height=1024,
width=1024, width=1024,
seed=None,
# Steps
num_inference_steps=30, num_inference_steps=30,
t5_sequence_length=512, # local prompts
inpaint_input=None, local_prompts=(),
entity_prompts=None, masks=(),
entity_masks=None, mask_scales=(),
use_seperated_negtive_prompt=True, # ControlNet
controlnet_image=None,
controlnet_inpaint_mask=None,
enable_controlnet_on_negative=False,
# IP-Adapter
ipadapter_images=None,
ipadapter_scale=1.0,
# EliGen
eligen_entity_prompts=None,
eligen_entity_masks=None,
enable_eligen_on_negative=False,
enable_eligen_inpaint=False,
# Tile
tiled=False, tiled=False,
tile_size=128, tile_size=128,
tile_stride=64, tile_stride=64,
seed=None, # Progress bar
progress_bar_cmd=tqdm, progress_bar_cmd=tqdm,
progress_bar_st=None, progress_bar_st=None,
): ):
@@ -213,83 +297,50 @@ class FluxImagePipeline(BasePipeline):
self.scheduler.set_timesteps(num_inference_steps, denoising_strength) self.scheduler.set_timesteps(num_inference_steps, denoising_strength)
# Prepare latent tensors # Prepare latent tensors
if input_image is not None or inpaint_input is not None: latents, input_latents = self.prepare_latents(input_image, height, width, seed, tiled, tile_size, tile_stride)
input_image = input_image or inpaint_input
self.load_models_to_device(['vae_encoder'])
image = self.preprocess_image(input_image).to(device=self.device, dtype=self.torch_dtype)
input_latents = self.encode_image(image, **tiler_kwargs)
noise = self.generate_noise((1, 16, height//8, width//8), seed=seed, device=self.device, dtype=self.torch_dtype)
latents = self.scheduler.add_noise(input_latents, noise, timestep=self.scheduler.timesteps[0])
else:
latents = self.generate_noise((1, 16, height//8, width//8), seed=seed, device=self.device, dtype=self.torch_dtype)
# Extend prompt # Prompt
self.load_models_to_device(['text_encoder_1', 'text_encoder_2']) prompt_emb_posi, prompt_emb_nega, prompt_emb_locals = self.prepare_prompts(prompt, local_prompts, masks, mask_scales, t5_sequence_length, negative_prompt, cfg_scale)
prompt, local_prompts, masks, mask_scales = self.extend_prompt(prompt, local_prompts, masks, mask_scales)
# Encode prompts
prompt_emb_posi = self.encode_prompt(prompt, t5_sequence_length=t5_sequence_length)
if cfg_scale != 1.0:
prompt_emb_nega = self.encode_prompt(negative_prompt, positive=False, t5_sequence_length=t5_sequence_length)
prompt_emb_locals = [self.encode_prompt(prompt_local, t5_sequence_length=t5_sequence_length) for prompt_local in local_prompts]
# Entity control
negative_entity_prompts = None
negative_masks = None
if entity_masks is not None:
entity_prompts, entity_masks, fg_mask, bg_mask = self.prepare_entity_inputs(entity_prompts, entity_masks, width, height, t5_sequence_length, inpaint_input)
if use_seperated_negtive_prompt and cfg_scale != 1.0:
negative_entity_prompts = prompt_emb_nega['prompt_emb'].unsqueeze(1).repeat(1, entity_masks.shape[1], 1, 1)
negative_masks = entity_masks
# Extra input # Extra input
extra_input = self.prepare_extra_input(latents, guidance=embedded_guidance) extra_input = self.prepare_extra_input(latents, guidance=embedded_guidance)
# IP-Adapter # Entity control
if ipadapter_images is not None: eligen_kwargs_posi, eligen_kwargs_nega, fg_mask, bg_mask = self.prepare_eligen(prompt_emb_nega, eligen_entity_prompts, eligen_entity_masks, width, height, t5_sequence_length, enable_eligen_inpaint, enable_eligen_on_negative, cfg_scale)
self.load_models_to_device(['ipadapter_image_encoder'])
ipadapter_images = self.prepare_ipadapter_inputs(ipadapter_images)
ipadapter_image_encoding = self.ipadapter_image_encoder(ipadapter_images).pooler_output
self.load_models_to_device(['ipadapter'])
ipadapter_kwargs_list_posi = {"ipadapter_kwargs_list": self.ipadapter(ipadapter_image_encoding, scale=ipadapter_scale)}
ipadapter_kwargs_list_nega = {"ipadapter_kwargs_list": self.ipadapter(torch.zeros_like(ipadapter_image_encoding))}
else:
ipadapter_kwargs_list_posi, ipadapter_kwargs_list_nega = {"ipadapter_kwargs_list": {}}, {"ipadapter_kwargs_list": {}}
# Prepare ControlNets # IP-Adapter
if controlnet_image is not None: ipadapter_kwargs_list_posi, ipadapter_kwargs_list_nega = self.prepare_ipadapter(ipadapter_images, ipadapter_scale)
self.load_models_to_device(['vae_encoder'])
controlnet_kwargs = {"controlnet_frames": self.prepare_controlnet_input(controlnet_image, controlnet_inpaint_mask, tiler_kwargs)} # ControlNets
if len(masks) > 0 and controlnet_inpaint_mask is not None: controlnet_kwargs_posi, controlnet_kwargs_nega, local_controlnet_kwargs = self.prepare_controlnet(controlnet_image, masks, controlnet_inpaint_mask, tiler_kwargs, enable_controlnet_on_negative)
print("The controlnet_inpaint_mask will be overridden by masks.")
local_controlnet_kwargs = [{"controlnet_frames": self.prepare_controlnet_input(controlnet_image, mask, tiler_kwargs)} for mask in masks]
else:
local_controlnet_kwargs = None
else:
controlnet_kwargs, local_controlnet_kwargs = {"controlnet_frames": None}, [{}] * len(masks)
# Denoise # Denoise
self.load_models_to_device(['dit', 'controlnet']) self.load_models_to_device(['dit', 'controlnet'])
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)): for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
timestep = timestep.unsqueeze(0).to(self.device) timestep = timestep.unsqueeze(0).to(self.device)
# Classifier-free guidance # Positive side
inference_callback = lambda prompt_emb_posi, controlnet_kwargs: lets_dance_flux( inference_callback = lambda prompt_emb_posi, controlnet_kwargs: lets_dance_flux(
dit=self.dit, controlnet=self.controlnet, dit=self.dit, controlnet=self.controlnet,
hidden_states=latents, timestep=timestep, entity_prompts=entity_prompts, entity_masks=entity_masks, hidden_states=latents, timestep=timestep,
**prompt_emb_posi, **tiler_kwargs, **extra_input, **controlnet_kwargs, **ipadapter_kwargs_list_posi, **prompt_emb_posi, **tiler_kwargs, **extra_input, **controlnet_kwargs, **ipadapter_kwargs_list_posi, **eligen_kwargs_posi,
) )
noise_pred_posi = self.control_noise_via_local_prompts( noise_pred_posi = self.control_noise_via_local_prompts(
prompt_emb_posi, prompt_emb_locals, masks, mask_scales, inference_callback, prompt_emb_posi, prompt_emb_locals, masks, mask_scales, inference_callback,
special_kwargs=controlnet_kwargs, special_local_kwargs_list=local_controlnet_kwargs special_kwargs=controlnet_kwargs_posi, special_local_kwargs_list=local_controlnet_kwargs
) )
if inpaint_input:
# Inpaint
if enable_eligen_inpaint:
noise_pred_posi = self.inpaint_fusion(latents, input_latents, noise_pred_posi, fg_mask, bg_mask, progress_id) noise_pred_posi = self.inpaint_fusion(latents, input_latents, noise_pred_posi, fg_mask, bg_mask, progress_id)
# Classifier-free guidance
if cfg_scale != 1.0: if cfg_scale != 1.0:
negative_controlnet_kwargs = controlnet_kwargs if enable_controlnet_on_negative else {} # Negative side
noise_pred_nega = lets_dance_flux( noise_pred_nega = lets_dance_flux(
dit=self.dit, controlnet=self.controlnet, dit=self.dit, controlnet=self.controlnet,
hidden_states=latents, timestep=timestep, entity_prompts=negative_entity_prompts, entity_masks=negative_masks, hidden_states=latents, timestep=timestep,
**prompt_emb_nega, **tiler_kwargs, **extra_input, **negative_controlnet_kwargs, **ipadapter_kwargs_list_nega, **prompt_emb_nega, **tiler_kwargs, **extra_input, **controlnet_kwargs_nega, **ipadapter_kwargs_list_nega, **eligen_kwargs_nega,
) )
noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega) noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
else: else:
@@ -304,7 +355,7 @@ class FluxImagePipeline(BasePipeline):
# Decode image # Decode image
self.load_models_to_device(['vae_decoder']) self.load_models_to_device(['vae_decoder'])
image = self.decode_image(latents, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride) image = self.decode_image(latents, **tiler_kwargs)
# Offload all models # Offload all models
self.load_models_to_device([]) self.load_models_to_device([])
@@ -326,7 +377,7 @@ def lets_dance_flux(
tiled=False, tiled=False,
tile_size=128, tile_size=128,
tile_stride=64, tile_stride=64,
entity_prompts=None, entity_prompt_emb=None,
entity_masks=None, entity_masks=None,
ipadapter_kwargs_list={}, ipadapter_kwargs_list={},
**kwargs **kwargs
@@ -384,36 +435,16 @@ def lets_dance_flux(
guidance = guidance * 1000 guidance = guidance * 1000
conditioning = conditioning + dit.guidance_embedder(guidance, hidden_states.dtype) conditioning = conditioning + dit.guidance_embedder(guidance, hidden_states.dtype)
repeat_dim = hidden_states.shape[1]
height, width = hidden_states.shape[-2:] height, width = hidden_states.shape[-2:]
hidden_states = dit.patchify(hidden_states) hidden_states = dit.patchify(hidden_states)
hidden_states = dit.x_embedder(hidden_states) hidden_states = dit.x_embedder(hidden_states)
# Entity Control if entity_prompt_emb is not None and entity_masks is not None:
max_masks = 0 prompt_emb, image_rotary_emb, attention_mask = dit.process_entity_masks(hidden_states, prompt_emb, entity_prompt_emb, entity_masks, text_ids, image_ids)
attention_mask = None else:
prompt_embs = [prompt_emb] prompt_emb = dit.context_embedder(prompt_emb)
if entity_masks is not None: image_rotary_emb = dit.pos_embedder(torch.cat((text_ids, image_ids), dim=1))
# entity_masks attention_mask = None
batch_size, max_masks = entity_masks.shape[0], entity_masks.shape[1]
entity_masks = entity_masks.repeat(1, 1, repeat_dim, 1, 1)
entity_masks = [entity_masks[:, i, None].squeeze(1) for i in range(max_masks)]
# global mask
global_mask = torch.ones_like(entity_masks[0]).to(device=hidden_states.device, dtype=hidden_states.dtype)
entity_masks = entity_masks + [global_mask] # append global to last
# attention mask
attention_mask = dit.construct_mask(entity_masks, prompt_emb.shape[1], hidden_states.shape[1])
attention_mask = attention_mask.to(device=hidden_states.device, dtype=hidden_states.dtype)
attention_mask = attention_mask.unsqueeze(1)
# embds: n_masks * b * seq * d
local_embs = [entity_prompts[:, i, None].squeeze(1) for i in range(max_masks)]
prompt_embs = local_embs + prompt_embs # append global to last
prompt_embs = [dit.context_embedder(prompt_emb) for prompt_emb in prompt_embs]
prompt_emb = torch.cat(prompt_embs, dim=1)
# positional embedding
text_ids = torch.cat([text_ids] * (max_masks + 1), dim=1)
image_rotary_emb = dit.pos_embedder(torch.cat((text_ids, image_ids), dim=1))
# Joint Blocks # Joint Blocks
for block_id, block in enumerate(dit.blocks): for block_id, block in enumerate(dit.blocks):
@@ -423,7 +454,8 @@ def lets_dance_flux(
conditioning, conditioning,
image_rotary_emb, image_rotary_emb,
attention_mask, attention_mask,
ipadapter_kwargs_list=ipadapter_kwargs_list.get(block_id, None)) ipadapter_kwargs_list=ipadapter_kwargs_list.get(block_id, None)
)
# ControlNet # ControlNet
if controlnet is not None and controlnet_frames is not None: if controlnet is not None and controlnet_frames is not None:
hidden_states = hidden_states + controlnet_res_stack[block_id] hidden_states = hidden_states + controlnet_res_stack[block_id]
@@ -438,8 +470,8 @@ def lets_dance_flux(
conditioning, conditioning,
image_rotary_emb, image_rotary_emb,
attention_mask, attention_mask,
ipadapter_kwargs_list=ipadapter_kwargs_list.get( ipadapter_kwargs_list=ipadapter_kwargs_list.get(block_id + num_joint_blocks, None)
block_id + num_joint_blocks, None)) )
# ControlNet # ControlNet
if controlnet is not None and controlnet_frames is not None: if controlnet is not None and controlnet_frames is not None:
hidden_states[:, prompt_emb.shape[1]:] = hidden_states[:, prompt_emb.shape[1]:] + controlnet_single_res_stack[block_id] hidden_states[:, prompt_emb.shape[1]:] = hidden_states[:, prompt_emb.shape[1]:] + controlnet_single_res_stack[block_id]

View File

@@ -1,57 +1,43 @@
import torch
from diffsynth import ModelManager, FluxImagePipeline, download_customized_models from diffsynth import ModelManager, FluxImagePipeline, download_customized_models
from modelscope import dataset_snapshot_download
from examples.EntityControl.utils import visualize_masks from examples.EntityControl.utils import visualize_masks
from PIL import Image from PIL import Image
import requests import torch
from io import BytesIO
# download and load model # download and load model
lora_path = download_customized_models(
model_id="DiffSynth-Studio/Eligen",
origin_file_path="model_bf16.safetensors",
local_dir="models/lora/entity_control"
)[0]
model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cuda", model_id_list=["FLUX.1-dev"]) model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cuda", model_id_list=["FLUX.1-dev"])
model_manager.load_lora(lora_path, lora_alpha=1.) model_manager.load_lora(
download_customized_models(
model_id="DiffSynth-Studio/Eligen",
origin_file_path="model_bf16.safetensors",
local_dir="models/lora/entity_control"
),
lora_alpha=1
)
pipe = FluxImagePipeline.from_model_manager(model_manager) pipe = FluxImagePipeline.from_model_manager(model_manager)
# prepare inputs # download and load mask images
image_shape = 1024 dataset_snapshot_download(dataset_id="DiffSynth-Studio/examples_in_diffsynth", local_dir="./", allow_file_pattern="data/examples/eligen/mask*")
seed = 4 masks = [Image.open(f"./data/examples/eligen/mask{i}.png") for i in range(1, 8)]
# set True to apply regional attention in negative prompt prediction for better results with more time
use_seperated_negtive_prompt = False
mask_urls = [
'https://github.com/user-attachments/assets/02905f6e-40c2-4482-9abe-b1ce50ccabbf',
'https://github.com/user-attachments/assets/a4cf4361-abf7-4556-ba94-74683eda4cb7',
'https://github.com/user-attachments/assets/b6595ff4-7269-4d8f-acf0-5df40bd6c59f',
'https://github.com/user-attachments/assets/941d39a7-3aa1-437f-8b2a-4adb15d2fb3e',
'https://github.com/user-attachments/assets/400c4086-5398-4291-b1b5-22d8483c08d9',
'https://github.com/user-attachments/assets/ce324c77-fa1d-4aad-a5cb-698f0d5eca70',
'https://github.com/user-attachments/assets/4e62325f-a60c-44f7-b53b-6da0869bb9db'
]
# prepare entity masks, entity prompts, global prompt and negative prompt
masks = []
for url in mask_urls:
response = requests.get(url)
mask = Image.open(BytesIO(response.content)).resize((image_shape, image_shape), resample=Image.NEAREST)
masks.append(mask)
entity_prompts = ["A beautiful woman", "mirror", "necklace", "glasses", "earring", "white dress", "jewelry headpiece"] entity_prompts = ["A beautiful woman", "mirror", "necklace", "glasses", "earring", "white dress", "jewelry headpiece"]
global_prompt = "A beautiful woman wearing white dress, holding a mirror, with a warm light background;" global_prompt = "A beautiful woman wearing white dress, holding a mirror, with a warm light background;"
negative_prompt = "worst quality, low quality, monochrome, zombie, interlocked fingers, Aissist, cleavage, nsfw" negative_prompt = "worst quality, low quality, monochrome, zombie, interlocked fingers, Aissist, cleavage, nsfw"
# generate image # generate image
torch.manual_seed(seed)
image = pipe( image = pipe(
prompt=global_prompt, prompt=global_prompt,
cfg_scale=3.0, cfg_scale=3.0,
negative_prompt=negative_prompt, negative_prompt=negative_prompt,
num_inference_steps=50, num_inference_steps=50,
embedded_guidance=3.5, embedded_guidance=3.5,
height=image_shape, seed=4,
width=image_shape, height=1024,
entity_prompts=entity_prompts, width=1024,
entity_masks=masks, eligen_entity_prompts=entity_prompts,
use_seperated_negtive_prompt=use_seperated_negtive_prompt, eligen_entity_masks=masks,
enable_eligen_on_negative=False,
) )
image.save(f"entity_control.png") image.save(f"entity_control.png")
visualize_masks(image, masks, entity_prompts, f"entity_control_with_mask.png") visualize_masks(image, masks, entity_prompts, f"entity_control_with_mask.png")

View File

@@ -1,51 +1,46 @@
import torch
from diffsynth import ModelManager, FluxImagePipeline, download_customized_models from diffsynth import ModelManager, FluxImagePipeline, download_customized_models
from modelscope import dataset_snapshot_download
from examples.EntityControl.utils import visualize_masks from examples.EntityControl.utils import visualize_masks
from PIL import Image from PIL import Image
import requests import torch
from io import BytesIO
lora_path = download_customized_models(
model_id="DiffSynth-Studio/Eligen", # download and load model
origin_file_path="model_bf16.safetensors",
local_dir="models/lora/entity_control"
)[0]
model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cuda", model_id_list=["FLUX.1-dev", "InstantX/FLUX.1-dev-IP-Adapter"]) model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cuda", model_id_list=["FLUX.1-dev", "InstantX/FLUX.1-dev-IP-Adapter"])
model_manager.load_lora(lora_path, lora_alpha=1.) model_manager.load_lora(
download_customized_models(
model_id="DiffSynth-Studio/Eligen",
origin_file_path="model_bf16.safetensors",
local_dir="models/lora/entity_control"
),
lora_alpha=1
)
pipe = FluxImagePipeline.from_model_manager(model_manager) pipe = FluxImagePipeline.from_model_manager(model_manager)
# prepare inputs # download and load mask images
image_shape = 1024 dataset_snapshot_download(dataset_id="DiffSynth-Studio/examples_in_diffsynth", local_dir="./", allow_file_pattern="data/examples/eligen/ipadapter*")
seed = 4 masks = [Image.open(f"./data/examples/eligen/ipadapter_mask_{i}.png") for i in range(1, 4)]
# set True to apply regional attention in negative prompt prediction for better results with more time
use_seperated_negtive_prompt = False
mask_urls = [
'https://github.com/user-attachments/assets/e6745b3f-ab2b-4612-9bb5-b7235474a9a4',
'https://github.com/user-attachments/assets/5ddf9a89-32fa-4540-89ad-e956130942b3',
'https://github.com/user-attachments/assets/9d8a0bb0-6817-497e-af85-44f2512afe79'
]
# prepare entity masks, entity prompts, global prompt and negative prompt
masks = []
for url in mask_urls:
response = requests.get(url)
mask = Image.open(BytesIO(response.content)).resize((image_shape, image_shape), resample=Image.NEAREST)
masks.append(mask)
entity_prompts = ['A girl', 'hat', 'sunset'] entity_prompts = ['A girl', 'hat', 'sunset']
global_prompt = "A girl wearing a hat, looking at the sunset" global_prompt = "A girl wearing a hat, looking at the sunset"
negative_prompt = "worst quality, low quality, monochrome, zombie, interlocked fingers, Aissist, cleavage, nsfw" negative_prompt = "worst quality, low quality, monochrome, zombie, interlocked fingers, Aissist, cleavage, nsfw"
reference_img = Image.open("./data/examples/eligen/ipadapter_image.png")
response = requests.get('https://github.com/user-attachments/assets/019bbfaa-04b3-4de6-badb-32b67c29a1bc') # generate image
reference_img = Image.open(BytesIO(response.content)).convert('RGB').resize((image_shape, image_shape))
torch.manual_seed(seed)
image = pipe( image = pipe(
prompt=global_prompt, prompt=global_prompt,
cfg_scale=3.0, cfg_scale=3.0,
negative_prompt=negative_prompt, negative_prompt=negative_prompt,
num_inference_steps=50, embedded_guidance=3.5, height=image_shape, width=image_shape, num_inference_steps=50,
entity_prompts=entity_prompts, entity_masks=masks, embedded_guidance=3.5,
use_seperated_negtive_prompt=use_seperated_negtive_prompt, seed=4,
ipadapter_images=[reference_img], ipadapter_scale=0.7 height=1024,
width=1024,
eligen_entity_prompts=entity_prompts,
eligen_entity_masks=masks,
enable_eligen_on_negative=False,
ipadapter_images=[reference_img],
ipadapter_scale=0.7
) )
image.save(f"styled_entity_control.png") image.save(f"styled_entity_control.png")
visualize_masks(image, masks, entity_prompts, f"styled_entity_control_with_mask.png") visualize_masks(image, masks, entity_prompts, f"styled_entity_control_with_mask.png")

View File

@@ -1,58 +1,45 @@
import torch from diffsynth import ModelManager, FluxImagePipeline, download_customized_models
from diffsynth import ModelManager, FluxImagePipeline, download_customized_models, FluxImageLoraPipeline from modelscope import dataset_snapshot_download
from examples.EntityControl.utils import visualize_masks from examples.EntityControl.utils import visualize_masks
import os
import json
from PIL import Image from PIL import Image
import requests import torch
from io import BytesIO
# download and load model # download and load model
lora_path = download_customized_models(
model_id="DiffSynth-Studio/Eligen",
origin_file_path="model_bf16.safetensors",
local_dir="models/lora/entity_control"
)[0]
model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cuda", model_id_list=["FLUX.1-dev"]) model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cuda", model_id_list=["FLUX.1-dev"])
model_manager.load_lora(lora_path, lora_alpha=1.) model_manager.load_lora(
download_customized_models(
model_id="DiffSynth-Studio/Eligen",
origin_file_path="model_bf16.safetensors",
local_dir="models/lora/entity_control"
),
lora_alpha=1
)
pipe = FluxImagePipeline.from_model_manager(model_manager) pipe = FluxImagePipeline.from_model_manager(model_manager)
# prepare inputs # download and load mask images
image_shape = 1024 dataset_snapshot_download(dataset_id="DiffSynth-Studio/examples_in_diffsynth", local_dir="./", allow_file_pattern="data/examples/eligen/inpaint*")
seed = 0 masks = [Image.open(f"./data/examples/eligen/inpaint_mask_{i}.png") for i in range(1, 3)]
# set True to apply regional attention in negative prompt prediction for better results with more time input_image = Image.open("./data/examples/eligen/inpaint_image.jpg")
use_seperated_negtive_prompt = False
mask_urls = [
'https://github.com/user-attachments/assets/0cf78663-5314-4280-a065-31ded7a24a46',
'https://github.com/user-attachments/assets/bd3938b8-72a8-4d56-814f-f6445971b91d'
]
# prepare entity masks, entity prompts, global prompt and negative prompt
masks = []
for url in mask_urls:
response = requests.get(url)
mask = Image.open(BytesIO(response.content)).resize((image_shape, image_shape), resample=Image.NEAREST)
masks.append(mask)
entity_prompts = ["A person wear red shirt", "Airplane"] entity_prompts = ["A person wear red shirt", "Airplane"]
global_prompt = "A person walking on the path in front of a house; An airplane in the sky" global_prompt = "A person walking on the path in front of a house; An airplane in the sky"
negative_prompt = "worst quality, low quality, monochrome, zombie, interlocked fingers, Aissist, cleavage, nsfw, blur" negative_prompt = "worst quality, low quality, monochrome, zombie, interlocked fingers, Aissist, cleavage, nsfw, blur"
response = requests.get('https://github.com/user-attachments/assets/fa4d6ba5-08fd-4fc7-adbb-19898d839364')
inpaint_input = Image.open(BytesIO(response.content)).convert('RGB').resize((image_shape, image_shape))
# generate image # generate image
torch.manual_seed(seed)
image = pipe( image = pipe(
prompt=global_prompt, prompt=global_prompt,
input_image=input_image,
cfg_scale=3.0, cfg_scale=3.0,
negative_prompt=negative_prompt, negative_prompt=negative_prompt,
num_inference_steps=50, num_inference_steps=50,
embedded_guidance=3.5, embedded_guidance=3.5,
height=image_shape, seed=0,
width=image_shape, height=1024,
entity_prompts=entity_prompts, width=1024,
entity_masks=masks, eligen_entity_prompts=entity_prompts,
inpaint_input=inpaint_input, eligen_entity_masks=masks,
use_seperated_negtive_prompt=use_seperated_negtive_prompt, enable_eligen_on_negative=False,
enable_eligen_inpaint=True,
) )
image.save(f"entity_inpaint.png") image.save(f"entity_inpaint.png")
visualize_masks(image, masks, entity_prompts, f"entity_inpaint_with_mask.png") visualize_masks(image, masks, entity_prompts, f"entity_inpaint_with_mask.png")