temp commit for entity control

This commit is contained in:
mi804
2024-12-25 17:19:31 +08:00
parent 1b6e96a820
commit e3d89cec0c
5 changed files with 349 additions and 44 deletions

View File

@@ -139,6 +139,39 @@ class FluxImagePipeline(BasePipeline):
images = [self.preprocess_image(image).to(device=self.device, dtype=self.torch_dtype) for image in images]
return torch.cat(images, dim=0)
def inpaint_fusion(self, latents, inpaint_latents, pred_noise, fg_mask, bg_mask, progress_id, background_weight=0.):
# inpaint noise
inpaint_noise = (latents - inpaint_latents) / self.scheduler.sigmas[progress_id]
# merge noise
weight = torch.ones_like(inpaint_noise)
inpaint_noise[fg_mask] = pred_noise[fg_mask]
inpaint_noise[bg_mask] += pred_noise[bg_mask] * background_weight
weight[bg_mask] += background_weight
inpaint_noise /= weight
return inpaint_noise
def preprocess_masks(self, masks, height, width, dim):
out_masks = []
for mask in masks:
mask = self.preprocess_image(mask.resize((width, height), resample=Image.NEAREST)).mean(dim=1, keepdim=True) > 0
mask = mask.repeat(1, dim, 1, 1).to(device=self.device, dtype=self.torch_dtype)
out_masks.append(mask)
return out_masks
def prepare_entity_inputs(self, entity_prompts, entity_masks, width, height, t5_sequence_length=512, inpaint_input=None):
fg_mask, bg_mask = None, None
if inpaint_input is not None:
from copy import deepcopy
masks_ = deepcopy(entity_masks)
fg_masks = torch.cat([self.preprocess_image(mask.resize((width//8, height//8))).mean(dim=1, keepdim=True) for mask in masks_])
fg_masks = (fg_masks > 0).float()
fg_mask = fg_masks.sum(dim=0, keepdim=True).repeat(1, 16, 1, 1) > 0
bg_mask = ~fg_mask
entity_masks = self.preprocess_masks(entity_masks, height//8, width//8, 1)
entity_masks = torch.cat(entity_masks, dim=0).unsqueeze(0) # b, n_mask, c, h, w
entity_prompts = self.encode_prompt(entity_prompts, t5_sequence_length=t5_sequence_length)['prompt_emb'].unsqueeze(0)
return entity_prompts, entity_masks, fg_mask, bg_mask
@torch.no_grad()
def __call__(
self,
@@ -160,6 +193,10 @@ class FluxImagePipeline(BasePipeline):
width=1024,
num_inference_steps=30,
t5_sequence_length=512,
inpaint_input=None,
entity_prompts=None,
entity_masks=None,
use_seperated_negtive_prompt=True,
tiled=False,
tile_size=128,
tile_stride=64,
@@ -176,12 +213,13 @@ class FluxImagePipeline(BasePipeline):
self.scheduler.set_timesteps(num_inference_steps, denoising_strength)
# Prepare latent tensors
if input_image is not None:
if input_image is not None or inpaint_input is not None:
input_image = input_image or inpaint_input
self.load_models_to_device(['vae_encoder'])
image = self.preprocess_image(input_image).to(device=self.device, dtype=self.torch_dtype)
latents = self.encode_image(image, **tiler_kwargs)
input_latents = self.encode_image(image, **tiler_kwargs)
noise = self.generate_noise((1, 16, height//8, width//8), seed=seed, device=self.device, dtype=self.torch_dtype)
latents = self.scheduler.add_noise(latents, noise, timestep=self.scheduler.timesteps[0])
latents = self.scheduler.add_noise(input_latents, noise, timestep=self.scheduler.timesteps[0])
else:
latents = self.generate_noise((1, 16, height//8, width//8), seed=seed, device=self.device, dtype=self.torch_dtype)
@@ -195,6 +233,14 @@ class FluxImagePipeline(BasePipeline):
prompt_emb_nega = self.encode_prompt(negative_prompt, positive=False, t5_sequence_length=t5_sequence_length)
prompt_emb_locals = [self.encode_prompt(prompt_local, t5_sequence_length=t5_sequence_length) for prompt_local in local_prompts]
# Entity control
negative_entity_prompts = None
negative_masks = None
if entity_masks is not None:
entity_prompts, entity_masks, fg_mask, bg_mask = self.prepare_entity_inputs(entity_prompts, entity_masks, width, height, t5_sequence_length, inpaint_input)
if use_seperated_negtive_prompt and cfg_scale != 1.0:
negative_entity_prompts = prompt_emb_nega['prompt_emb'].unsqueeze(1).repeat(1, entity_masks.shape[1], 1, 1)
negative_masks = entity_masks
# Extra input
extra_input = self.prepare_extra_input(latents, guidance=embedded_guidance)
@@ -229,18 +275,20 @@ class FluxImagePipeline(BasePipeline):
# Classifier-free guidance
inference_callback = lambda prompt_emb_posi, controlnet_kwargs: lets_dance_flux(
dit=self.dit, controlnet=self.controlnet,
hidden_states=latents, timestep=timestep,
hidden_states=latents, timestep=timestep, entity_prompts=entity_prompts, entity_masks=entity_masks,
**prompt_emb_posi, **tiler_kwargs, **extra_input, **controlnet_kwargs, **ipadapter_kwargs_list_posi,
)
noise_pred_posi = self.control_noise_via_local_prompts(
prompt_emb_posi, prompt_emb_locals, masks, mask_scales, inference_callback,
special_kwargs=controlnet_kwargs, special_local_kwargs_list=local_controlnet_kwargs
)
if inpaint_input:
noise_pred_posi = self.inpaint_fusion(latents, input_latents, noise_pred_posi, fg_mask, bg_mask, progress_id)
if cfg_scale != 1.0:
negative_controlnet_kwargs = controlnet_kwargs if enable_controlnet_on_negative else {}
noise_pred_nega = lets_dance_flux(
dit=self.dit, controlnet=self.controlnet,
hidden_states=latents, timestep=timestep,
hidden_states=latents, timestep=timestep, entity_prompts=negative_entity_prompts, entity_masks=negative_masks,
**prompt_emb_nega, **tiler_kwargs, **extra_input, **negative_controlnet_kwargs, **ipadapter_kwargs_list_nega,
)
noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
@@ -278,6 +326,8 @@ def lets_dance_flux(
tiled=False,
tile_size=128,
tile_stride=64,
entity_prompts=None,
entity_masks=None,
ipadapter_kwargs_list={},
**kwargs
):
@@ -333,13 +383,38 @@ def lets_dance_flux(
if dit.guidance_embedder is not None:
guidance = guidance * 1000
conditioning = conditioning + dit.guidance_embedder(guidance, hidden_states.dtype)
prompt_emb = dit.context_embedder(prompt_emb)
image_rotary_emb = dit.pos_embedder(torch.cat((text_ids, image_ids), dim=1))
repeat_dim = hidden_states.shape[1]
height, width = hidden_states.shape[-2:]
hidden_states = dit.patchify(hidden_states)
hidden_states = dit.x_embedder(hidden_states)
# Entity Control
max_masks = 0
attention_mask = None
prompt_embs = [prompt_emb]
if entity_masks is not None:
# entity_masks
batch_size, max_masks = entity_masks.shape[0], entity_masks.shape[1]
entity_masks = entity_masks.repeat(1, 1, repeat_dim, 1, 1)
entity_masks = [entity_masks[:, i, None].squeeze(1) for i in range(max_masks)]
# global mask
global_mask = torch.ones_like(entity_masks[0]).to(device=hidden_states.device, dtype=hidden_states.dtype)
entity_masks = entity_masks + [global_mask] # append global to last
# attention mask
attention_mask = dit.construct_mask(entity_masks, prompt_emb.shape[1], hidden_states.shape[1])
attention_mask = attention_mask.to(device=hidden_states.device, dtype=hidden_states.dtype)
attention_mask = attention_mask.unsqueeze(1)
# embds: n_masks * b * seq * d
local_embs = [entity_prompts[:, i, None].squeeze(1) for i in range(max_masks)]
prompt_embs = local_embs + prompt_embs # append global to last
prompt_embs = [dit.context_embedder(prompt_emb) for prompt_emb in prompt_embs]
prompt_emb = torch.cat(prompt_embs, dim=1)
# positional embedding
text_ids = torch.cat([text_ids] * (max_masks + 1), dim=1)
image_rotary_emb = dit.pos_embedder(torch.cat((text_ids, image_ids), dim=1))
# Joint Blocks
for block_id, block in enumerate(dit.blocks):
hidden_states, prompt_emb = block(
@@ -347,6 +422,7 @@ def lets_dance_flux(
prompt_emb,
conditioning,
image_rotary_emb,
attention_mask,
ipadapter_kwargs_list=ipadapter_kwargs_list.get(block_id, None))
# ControlNet
if controlnet is not None and controlnet_frames is not None:
@@ -361,6 +437,7 @@ def lets_dance_flux(
prompt_emb,
conditioning,
image_rotary_emb,
attention_mask,
ipadapter_kwargs_list=ipadapter_kwargs_list.get(
block_id + num_joint_blocks, None))
# ControlNet