mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-24 18:28:10 +00:00
temp commit for entity control
This commit is contained in:
@@ -70,7 +70,7 @@ class FluxJointAttention(torch.nn.Module):
|
|||||||
xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
|
xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
|
||||||
return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
|
return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
|
||||||
|
|
||||||
def forward(self, hidden_states_a, hidden_states_b, image_rotary_emb, ipadapter_kwargs_list=None):
|
def forward(self, hidden_states_a, hidden_states_b, image_rotary_emb, attn_mask=None, ipadapter_kwargs_list=None):
|
||||||
batch_size = hidden_states_a.shape[0]
|
batch_size = hidden_states_a.shape[0]
|
||||||
|
|
||||||
# Part A
|
# Part A
|
||||||
@@ -91,7 +91,7 @@ class FluxJointAttention(torch.nn.Module):
|
|||||||
|
|
||||||
q, k = self.apply_rope(q, k, image_rotary_emb)
|
q, k = self.apply_rope(q, k, image_rotary_emb)
|
||||||
|
|
||||||
hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v)
|
hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
|
||||||
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim)
|
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim)
|
||||||
hidden_states = hidden_states.to(q.dtype)
|
hidden_states = hidden_states.to(q.dtype)
|
||||||
hidden_states_b, hidden_states_a = hidden_states[:, :hidden_states_b.shape[1]], hidden_states[:, hidden_states_b.shape[1]:]
|
hidden_states_b, hidden_states_a = hidden_states[:, :hidden_states_b.shape[1]], hidden_states[:, hidden_states_b.shape[1]:]
|
||||||
@@ -129,12 +129,12 @@ class FluxJointTransformerBlock(torch.nn.Module):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def forward(self, hidden_states_a, hidden_states_b, temb, image_rotary_emb, ipadapter_kwargs_list=None):
|
def forward(self, hidden_states_a, hidden_states_b, temb, image_rotary_emb, attn_mask=None, ipadapter_kwargs_list=None):
|
||||||
norm_hidden_states_a, gate_msa_a, shift_mlp_a, scale_mlp_a, gate_mlp_a = self.norm1_a(hidden_states_a, emb=temb)
|
norm_hidden_states_a, gate_msa_a, shift_mlp_a, scale_mlp_a, gate_mlp_a = self.norm1_a(hidden_states_a, emb=temb)
|
||||||
norm_hidden_states_b, gate_msa_b, shift_mlp_b, scale_mlp_b, gate_mlp_b = self.norm1_b(hidden_states_b, emb=temb)
|
norm_hidden_states_b, gate_msa_b, shift_mlp_b, scale_mlp_b, gate_mlp_b = self.norm1_b(hidden_states_b, emb=temb)
|
||||||
|
|
||||||
# Attention
|
# Attention
|
||||||
attn_output_a, attn_output_b = self.attn(norm_hidden_states_a, norm_hidden_states_b, image_rotary_emb, ipadapter_kwargs_list)
|
attn_output_a, attn_output_b = self.attn(norm_hidden_states_a, norm_hidden_states_b, image_rotary_emb, attn_mask, ipadapter_kwargs_list)
|
||||||
|
|
||||||
# Part A
|
# Part A
|
||||||
hidden_states_a = hidden_states_a + gate_msa_a * attn_output_a
|
hidden_states_a = hidden_states_a + gate_msa_a * attn_output_a
|
||||||
@@ -226,7 +226,7 @@ class FluxSingleTransformerBlock(torch.nn.Module):
|
|||||||
return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
|
return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
|
||||||
|
|
||||||
|
|
||||||
def process_attention(self, hidden_states, image_rotary_emb, ipadapter_kwargs_list=None):
|
def process_attention(self, hidden_states, image_rotary_emb, attn_mask=None, ipadapter_kwargs_list=None):
|
||||||
batch_size = hidden_states.shape[0]
|
batch_size = hidden_states.shape[0]
|
||||||
|
|
||||||
qkv = hidden_states.view(batch_size, -1, 3 * self.num_heads, self.head_dim).transpose(1, 2)
|
qkv = hidden_states.view(batch_size, -1, 3 * self.num_heads, self.head_dim).transpose(1, 2)
|
||||||
@@ -235,7 +235,7 @@ class FluxSingleTransformerBlock(torch.nn.Module):
|
|||||||
|
|
||||||
q, k = self.apply_rope(q, k, image_rotary_emb)
|
q, k = self.apply_rope(q, k, image_rotary_emb)
|
||||||
|
|
||||||
hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v)
|
hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
|
||||||
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim)
|
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim)
|
||||||
hidden_states = hidden_states.to(q.dtype)
|
hidden_states = hidden_states.to(q.dtype)
|
||||||
if ipadapter_kwargs_list is not None:
|
if ipadapter_kwargs_list is not None:
|
||||||
@@ -243,13 +243,13 @@ class FluxSingleTransformerBlock(torch.nn.Module):
|
|||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
def forward(self, hidden_states_a, hidden_states_b, temb, image_rotary_emb, ipadapter_kwargs_list=None):
|
def forward(self, hidden_states_a, hidden_states_b, temb, image_rotary_emb, attn_mask=None, ipadapter_kwargs_list=None):
|
||||||
residual = hidden_states_a
|
residual = hidden_states_a
|
||||||
norm_hidden_states, gate = self.norm(hidden_states_a, emb=temb)
|
norm_hidden_states, gate = self.norm(hidden_states_a, emb=temb)
|
||||||
hidden_states_a = self.to_qkv_mlp(norm_hidden_states)
|
hidden_states_a = self.to_qkv_mlp(norm_hidden_states)
|
||||||
attn_output, mlp_hidden_states = hidden_states_a[:, :, :self.dim * 3], hidden_states_a[:, :, self.dim * 3:]
|
attn_output, mlp_hidden_states = hidden_states_a[:, :, :self.dim * 3], hidden_states_a[:, :, self.dim * 3:]
|
||||||
|
|
||||||
attn_output = self.process_attention(attn_output, image_rotary_emb, ipadapter_kwargs_list)
|
attn_output = self.process_attention(attn_output, image_rotary_emb, attn_mask, ipadapter_kwargs_list)
|
||||||
mlp_hidden_states = torch.nn.functional.gelu(mlp_hidden_states, approximate="tanh")
|
mlp_hidden_states = torch.nn.functional.gelu(mlp_hidden_states, approximate="tanh")
|
||||||
|
|
||||||
hidden_states_a = torch.cat([attn_output, mlp_hidden_states], dim=2)
|
hidden_states_a = torch.cat([attn_output, mlp_hidden_states], dim=2)
|
||||||
@@ -337,12 +337,45 @@ class FluxDiT(torch.nn.Module):
|
|||||||
)
|
)
|
||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
|
def construct_mask(self, entity_masks, prompt_seq_len, image_seq_len):
    """Build an additive attention bias for entity-controlled generation.

    The joint sequence is laid out as ``len(entity_masks)`` prompt segments
    of ``prompt_seq_len`` tokens each, followed by ``image_seq_len`` image
    tokens.

    Rules encoded in the result:
      * a prompt segment and an image token attend to each other only when
        the patchified entity mask for that segment is non-zero at that
        image token;
      * different prompt segments never attend to each other;
      * everything else (each prompt segment with itself, image with image)
        is allowed.

    Args:
        entity_masks: sequence of mask tensors, one per prompt segment; each
            is passed to ``self.patchify``. The patchified result must have
            ``image_seq_len`` entries on dim 1 for the assignment to fit.
        prompt_seq_len: number of tokens in every prompt segment.
        image_seq_len: number of image tokens.

    Returns:
        float32 tensor of shape ``(batch, total, total)`` holding ``0`` where
        attention is allowed and ``-inf`` where it is blocked.
    """
    num_segments = len(entity_masks)
    reference = entity_masks[0]
    text_len = num_segments * prompt_seq_len
    seq_len = text_len + image_seq_len

    patch_masks = [self.patchify(entity_mask) for entity_mask in entity_masks]
    allowed = torch.ones(
        (reference.shape[0], seq_len, seq_len), dtype=torch.bool
    ).to(device=reference.device)

    image_span = slice(text_len, seq_len)
    prompt_spans = [
        slice(k * prompt_seq_len, (k + 1) * prompt_seq_len)
        for k in range(num_segments)
    ]

    # Prompt <-> image visibility, one entity at a time.
    for span, patches in zip(prompt_spans, patch_masks):
        # An image token belongs to the entity if any value in its patch is set.
        covered = (torch.sum(patches, dim=-1) > 0).unsqueeze(1).repeat(1, prompt_seq_len, 1)
        allowed[:, span, image_span] = covered
        allowed[:, image_span, span] = covered.transpose(1, 2)

    # Prompt <-> prompt: block every pair of distinct segments.
    for row, row_span in enumerate(prompt_spans):
        for col, col_span in enumerate(prompt_spans):
            if row != col:
                allowed[:, row_span, col_span] = False

    # Convert the boolean "allowed" map into an additive bias (0 / -inf).
    bias = torch.zeros(allowed.shape, dtype=torch.float32, device=allowed.device)
    return bias.masked_fill(~allowed, float('-inf'))
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
hidden_states,
|
hidden_states,
|
||||||
timestep, prompt_emb, pooled_prompt_emb, guidance, text_ids, image_ids=None,
|
timestep, prompt_emb, pooled_prompt_emb, guidance, text_ids, image_ids=None,
|
||||||
tiled=False, tile_size=128, tile_stride=64,
|
tiled=False, tile_size=128, tile_stride=64, entity_prompts=None, entity_masks=None,
|
||||||
use_gradient_checkpointing=False,
|
use_gradient_checkpointing=False,
|
||||||
**kwargs
|
**kwargs
|
||||||
):
|
):
|
||||||
@@ -361,13 +394,37 @@ class FluxDiT(torch.nn.Module):
|
|||||||
if self.guidance_embedder is not None:
|
if self.guidance_embedder is not None:
|
||||||
guidance = guidance * 1000
|
guidance = guidance * 1000
|
||||||
conditioning = conditioning + self.guidance_embedder(guidance, hidden_states.dtype)
|
conditioning = conditioning + self.guidance_embedder(guidance, hidden_states.dtype)
|
||||||
prompt_emb = self.context_embedder(prompt_emb)
|
|
||||||
image_rotary_emb = self.pos_embedder(torch.cat((text_ids, image_ids), dim=1))
|
|
||||||
|
|
||||||
|
repeat_dim = hidden_states.shape[1]
|
||||||
height, width = hidden_states.shape[-2:]
|
height, width = hidden_states.shape[-2:]
|
||||||
hidden_states = self.patchify(hidden_states)
|
hidden_states = self.patchify(hidden_states)
|
||||||
hidden_states = self.x_embedder(hidden_states)
|
hidden_states = self.x_embedder(hidden_states)
|
||||||
|
|
||||||
|
max_masks = 0
|
||||||
|
attention_mask = None
|
||||||
|
prompt_embs = [prompt_emb]
|
||||||
|
if entity_masks is not None:
|
||||||
|
# entity_masks
|
||||||
|
batch_size, max_masks = entity_masks.shape[0], entity_masks.shape[1]
|
||||||
|
entity_masks = entity_masks.repeat(1, 1, repeat_dim, 1, 1)
|
||||||
|
entity_masks = [entity_masks[:, i, None].squeeze(1) for i in range(max_masks)]
|
||||||
|
# global mask
|
||||||
|
global_mask = torch.ones_like(entity_masks[0]).to(device=hidden_states.device, dtype=hidden_states.dtype)
|
||||||
|
entity_masks = entity_masks + [global_mask] # append global to last
|
||||||
|
# attention mask
|
||||||
|
attention_mask = self.construct_mask(entity_masks, prompt_emb.shape[1], hidden_states.shape[1])
|
||||||
|
attention_mask = attention_mask.to(device=hidden_states.device, dtype=hidden_states.dtype)
|
||||||
|
attention_mask = attention_mask.unsqueeze(1)
|
||||||
|
# embds: n_masks * b * seq * d
|
||||||
|
local_embs = [entity_prompts[:, i, None].squeeze(1) for i in range(max_masks)]
|
||||||
|
prompt_embs = local_embs + prompt_embs # append global to last
|
||||||
|
prompt_embs = [self.context_embedder(prompt_emb) for prompt_emb in prompt_embs]
|
||||||
|
prompt_emb = torch.cat(prompt_embs, dim=1)
|
||||||
|
|
||||||
|
# positional embedding
|
||||||
|
text_ids = torch.cat([text_ids] * (max_masks + 1), dim=1)
|
||||||
|
image_rotary_emb = self.pos_embedder(torch.cat((text_ids, image_ids), dim=1))
|
||||||
|
|
||||||
def create_custom_forward(module):
|
def create_custom_forward(module):
|
||||||
def custom_forward(*inputs):
|
def custom_forward(*inputs):
|
||||||
return module(*inputs)
|
return module(*inputs)
|
||||||
@@ -377,22 +434,22 @@ class FluxDiT(torch.nn.Module):
|
|||||||
if self.training and use_gradient_checkpointing:
|
if self.training and use_gradient_checkpointing:
|
||||||
hidden_states, prompt_emb = torch.utils.checkpoint.checkpoint(
|
hidden_states, prompt_emb = torch.utils.checkpoint.checkpoint(
|
||||||
create_custom_forward(block),
|
create_custom_forward(block),
|
||||||
hidden_states, prompt_emb, conditioning, image_rotary_emb,
|
hidden_states, prompt_emb, conditioning, image_rotary_emb, attention_mask,
|
||||||
use_reentrant=False,
|
use_reentrant=False,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
hidden_states, prompt_emb = block(hidden_states, prompt_emb, conditioning, image_rotary_emb)
|
hidden_states, prompt_emb = block(hidden_states, prompt_emb, conditioning, image_rotary_emb, attention_mask)
|
||||||
|
|
||||||
hidden_states = torch.cat([prompt_emb, hidden_states], dim=1)
|
hidden_states = torch.cat([prompt_emb, hidden_states], dim=1)
|
||||||
for block in self.single_blocks:
|
for block in self.single_blocks:
|
||||||
if self.training and use_gradient_checkpointing:
|
if self.training and use_gradient_checkpointing:
|
||||||
hidden_states, prompt_emb = torch.utils.checkpoint.checkpoint(
|
hidden_states, prompt_emb = torch.utils.checkpoint.checkpoint(
|
||||||
create_custom_forward(block),
|
create_custom_forward(block),
|
||||||
hidden_states, prompt_emb, conditioning, image_rotary_emb,
|
hidden_states, prompt_emb, conditioning, image_rotary_emb, attention_mask,
|
||||||
use_reentrant=False,
|
use_reentrant=False,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
hidden_states, prompt_emb = block(hidden_states, prompt_emb, conditioning, image_rotary_emb)
|
hidden_states, prompt_emb = block(hidden_states, prompt_emb, conditioning, image_rotary_emb, attention_mask)
|
||||||
hidden_states = hidden_states[:, prompt_emb.shape[1]:]
|
hidden_states = hidden_states[:, prompt_emb.shape[1]:]
|
||||||
|
|
||||||
hidden_states = self.final_norm_out(hidden_states, conditioning)
|
hidden_states = self.final_norm_out(hidden_states, conditioning)
|
||||||
@@ -485,7 +542,6 @@ class FluxDiT(torch.nn.Module):
|
|||||||
return FluxDiTStateDictConverter()
|
return FluxDiTStateDictConverter()
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class FluxDiTStateDictConverter:
|
class FluxDiTStateDictConverter:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
pass
|
pass
|
||||||
|
|||||||
@@ -139,6 +139,39 @@ class FluxImagePipeline(BasePipeline):
|
|||||||
images = [self.preprocess_image(image).to(device=self.device, dtype=self.torch_dtype) for image in images]
|
images = [self.preprocess_image(image).to(device=self.device, dtype=self.torch_dtype) for image in images]
|
||||||
return torch.cat(images, dim=0)
|
return torch.cat(images, dim=0)
|
||||||
|
|
||||||
|
def inpaint_fusion(self, latents, inpaint_latents, pred_noise, fg_mask, bg_mask, progress_id, background_weight=0.):
    """Fuse the model's noise prediction with the noise implied by a known image.

    The "known" noise is ``(latents - inpaint_latents) / sigma`` with
    ``sigma = self.scheduler.sigmas[progress_id]``. Where ``fg_mask`` is set
    the model prediction replaces it; where ``bg_mask`` is set the result is
    the weighted average
    ``(known + background_weight * pred) / (1 + background_weight)``.

    Args:
        latents: current noisy latents.
        inpaint_latents: latents of the reference image being inpainted.
        pred_noise: noise predicted by the model for ``latents``.
        fg_mask: boolean mask selecting positions taken from ``pred_noise``.
        bg_mask: boolean mask selecting positions blended with the known noise.
        progress_id: index into ``self.scheduler.sigmas`` for the current step.
        background_weight: weight of ``pred_noise`` in the background blend;
            ``0`` keeps the known noise untouched there.

    Returns:
        A new tensor with the fused noise; the inputs are not modified.
    """
    sigma = self.scheduler.sigmas[progress_id]
    fused = (latents - inpaint_latents) / sigma

    # Foreground: trust the model completely.
    fused[fg_mask] = pred_noise[fg_mask]

    # Background: accumulate the weighted prediction and its normaliser.
    fused[bg_mask] = fused[bg_mask] + pred_noise[bg_mask] * background_weight
    normaliser = torch.ones_like(fused)
    normaliser[bg_mask] = normaliser[bg_mask] + background_weight

    return fused.div_(normaliser)
|
||||||
|
|
||||||
|
def preprocess_masks(self, masks, height, width, dim):
    """Turn mask images into binary tensors on the pipeline device.

    Each mask is resized to ``(width, height)`` with nearest-neighbour
    resampling, passed through ``self.preprocess_image``, averaged over
    dim 1 and thresholded at ``> 0``. The single-channel result is then
    repeated ``dim`` times along dim 1 and cast to ``self.torch_dtype`` on
    ``self.device``.

    Args:
        masks: iterable of mask images exposing a PIL-style ``resize``.
        height: target height in pixels.
        width: target width in pixels.
        dim: number of channel repeats for each output mask.

    Returns:
        A list with one tensor per input mask, in input order.
    """
    def _binarise(pil_mask):
        resized = pil_mask.resize((width, height), resample=Image.NEAREST)
        is_set = self.preprocess_image(resized).mean(dim=1, keepdim=True) > 0
        return is_set.repeat(1, dim, 1, 1).to(device=self.device, dtype=self.torch_dtype)

    return [_binarise(pil_mask) for pil_mask in masks]
|
||||||
|
|
||||||
|
def prepare_entity_inputs(self, entity_prompts, entity_masks, width, height, t5_sequence_length=512, inpaint_input=None):
    """Encode entity prompts and masks into the tensors the DiT expects.

    Args:
        entity_prompts: per-entity prompts, forwarded to ``self.encode_prompt``.
        entity_masks: per-entity mask images (PIL-style, must support ``resize``).
        width: output image width in pixels; masks are built at ``width // 8``.
        height: output image height in pixels; masks are built at ``height // 8``.
        t5_sequence_length: forwarded to ``self.encode_prompt``.
        inpaint_input: when not ``None``, foreground/background masks for
            inpainting are also produced.

    Returns:
        ``(entity_prompts, entity_masks, fg_mask, bg_mask)`` where
        ``entity_prompts`` is the ``'prompt_emb'`` output with a leading
        batch axis, ``entity_masks`` is the concatenated masks with a leading
        batch axis, and ``fg_mask`` / ``bg_mask`` are complementary boolean
        masks repeated over 16 channels (``None`` when ``inpaint_input`` is
        ``None``).
    """
    latent_height, latent_width = height // 8, width // 8

    fg_mask, bg_mask = None, None
    if inpaint_input is not None:
        # ``resize`` returns a new image, so no defensive copy of the masks
        # is needed before building the foreground map.
        # NOTE(review): this resize uses the default resampling filter while
        # preprocess_masks uses Image.NEAREST, so the two masks may disagree
        # at entity borders - confirm whether that is intended.
        fg_masks = torch.cat([
            self.preprocess_image(mask.resize((latent_width, latent_height))).mean(dim=1, keepdim=True)
            for mask in entity_masks
        ])
        # Foreground = union of all entity masks, repeated over the 16 latent channels.
        fg_mask = (fg_masks > 0).any(dim=0, keepdim=True).repeat(1, 16, 1, 1)
        # Keep the masks on the pipeline device so they can index the latents
        # directly during denoising.
        fg_mask = fg_mask.to(device=self.device)
        bg_mask = ~fg_mask

    entity_masks = self.preprocess_masks(entity_masks, latent_height, latent_width, 1)
    entity_masks = torch.cat(entity_masks, dim=0).unsqueeze(0)  # b, n_mask, c, h, w
    entity_prompts = self.encode_prompt(entity_prompts, t5_sequence_length=t5_sequence_length)['prompt_emb'].unsqueeze(0)
    return entity_prompts, entity_masks, fg_mask, bg_mask
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
@@ -160,6 +193,10 @@ class FluxImagePipeline(BasePipeline):
|
|||||||
width=1024,
|
width=1024,
|
||||||
num_inference_steps=30,
|
num_inference_steps=30,
|
||||||
t5_sequence_length=512,
|
t5_sequence_length=512,
|
||||||
|
inpaint_input=None,
|
||||||
|
entity_prompts=None,
|
||||||
|
entity_masks=None,
|
||||||
|
use_seperated_negtive_prompt=True,
|
||||||
tiled=False,
|
tiled=False,
|
||||||
tile_size=128,
|
tile_size=128,
|
||||||
tile_stride=64,
|
tile_stride=64,
|
||||||
@@ -176,12 +213,13 @@ class FluxImagePipeline(BasePipeline):
|
|||||||
self.scheduler.set_timesteps(num_inference_steps, denoising_strength)
|
self.scheduler.set_timesteps(num_inference_steps, denoising_strength)
|
||||||
|
|
||||||
# Prepare latent tensors
|
# Prepare latent tensors
|
||||||
if input_image is not None:
|
if input_image is not None or inpaint_input is not None:
|
||||||
|
input_image = input_image or inpaint_input
|
||||||
self.load_models_to_device(['vae_encoder'])
|
self.load_models_to_device(['vae_encoder'])
|
||||||
image = self.preprocess_image(input_image).to(device=self.device, dtype=self.torch_dtype)
|
image = self.preprocess_image(input_image).to(device=self.device, dtype=self.torch_dtype)
|
||||||
latents = self.encode_image(image, **tiler_kwargs)
|
input_latents = self.encode_image(image, **tiler_kwargs)
|
||||||
noise = self.generate_noise((1, 16, height//8, width//8), seed=seed, device=self.device, dtype=self.torch_dtype)
|
noise = self.generate_noise((1, 16, height//8, width//8), seed=seed, device=self.device, dtype=self.torch_dtype)
|
||||||
latents = self.scheduler.add_noise(latents, noise, timestep=self.scheduler.timesteps[0])
|
latents = self.scheduler.add_noise(input_latents, noise, timestep=self.scheduler.timesteps[0])
|
||||||
else:
|
else:
|
||||||
latents = self.generate_noise((1, 16, height//8, width//8), seed=seed, device=self.device, dtype=self.torch_dtype)
|
latents = self.generate_noise((1, 16, height//8, width//8), seed=seed, device=self.device, dtype=self.torch_dtype)
|
||||||
|
|
||||||
@@ -195,6 +233,14 @@ class FluxImagePipeline(BasePipeline):
|
|||||||
prompt_emb_nega = self.encode_prompt(negative_prompt, positive=False, t5_sequence_length=t5_sequence_length)
|
prompt_emb_nega = self.encode_prompt(negative_prompt, positive=False, t5_sequence_length=t5_sequence_length)
|
||||||
prompt_emb_locals = [self.encode_prompt(prompt_local, t5_sequence_length=t5_sequence_length) for prompt_local in local_prompts]
|
prompt_emb_locals = [self.encode_prompt(prompt_local, t5_sequence_length=t5_sequence_length) for prompt_local in local_prompts]
|
||||||
|
|
||||||
|
# Entity control
|
||||||
|
negative_entity_prompts = None
|
||||||
|
negative_masks = None
|
||||||
|
if entity_masks is not None:
|
||||||
|
entity_prompts, entity_masks, fg_mask, bg_mask = self.prepare_entity_inputs(entity_prompts, entity_masks, width, height, t5_sequence_length, inpaint_input)
|
||||||
|
if use_seperated_negtive_prompt and cfg_scale != 1.0:
|
||||||
|
negative_entity_prompts = prompt_emb_nega['prompt_emb'].unsqueeze(1).repeat(1, entity_masks.shape[1], 1, 1)
|
||||||
|
negative_masks = entity_masks
|
||||||
# Extra input
|
# Extra input
|
||||||
extra_input = self.prepare_extra_input(latents, guidance=embedded_guidance)
|
extra_input = self.prepare_extra_input(latents, guidance=embedded_guidance)
|
||||||
|
|
||||||
@@ -229,18 +275,20 @@ class FluxImagePipeline(BasePipeline):
|
|||||||
# Classifier-free guidance
|
# Classifier-free guidance
|
||||||
inference_callback = lambda prompt_emb_posi, controlnet_kwargs: lets_dance_flux(
|
inference_callback = lambda prompt_emb_posi, controlnet_kwargs: lets_dance_flux(
|
||||||
dit=self.dit, controlnet=self.controlnet,
|
dit=self.dit, controlnet=self.controlnet,
|
||||||
hidden_states=latents, timestep=timestep,
|
hidden_states=latents, timestep=timestep, entity_prompts=entity_prompts, entity_masks=entity_masks,
|
||||||
**prompt_emb_posi, **tiler_kwargs, **extra_input, **controlnet_kwargs, **ipadapter_kwargs_list_posi,
|
**prompt_emb_posi, **tiler_kwargs, **extra_input, **controlnet_kwargs, **ipadapter_kwargs_list_posi,
|
||||||
)
|
)
|
||||||
noise_pred_posi = self.control_noise_via_local_prompts(
|
noise_pred_posi = self.control_noise_via_local_prompts(
|
||||||
prompt_emb_posi, prompt_emb_locals, masks, mask_scales, inference_callback,
|
prompt_emb_posi, prompt_emb_locals, masks, mask_scales, inference_callback,
|
||||||
special_kwargs=controlnet_kwargs, special_local_kwargs_list=local_controlnet_kwargs
|
special_kwargs=controlnet_kwargs, special_local_kwargs_list=local_controlnet_kwargs
|
||||||
)
|
)
|
||||||
|
if inpaint_input:
|
||||||
|
noise_pred_posi = self.inpaint_fusion(latents, input_latents, noise_pred_posi, fg_mask, bg_mask, progress_id)
|
||||||
if cfg_scale != 1.0:
|
if cfg_scale != 1.0:
|
||||||
negative_controlnet_kwargs = controlnet_kwargs if enable_controlnet_on_negative else {}
|
negative_controlnet_kwargs = controlnet_kwargs if enable_controlnet_on_negative else {}
|
||||||
noise_pred_nega = lets_dance_flux(
|
noise_pred_nega = lets_dance_flux(
|
||||||
dit=self.dit, controlnet=self.controlnet,
|
dit=self.dit, controlnet=self.controlnet,
|
||||||
hidden_states=latents, timestep=timestep,
|
hidden_states=latents, timestep=timestep, entity_prompts=negative_entity_prompts, entity_masks=negative_masks,
|
||||||
**prompt_emb_nega, **tiler_kwargs, **extra_input, **negative_controlnet_kwargs, **ipadapter_kwargs_list_nega,
|
**prompt_emb_nega, **tiler_kwargs, **extra_input, **negative_controlnet_kwargs, **ipadapter_kwargs_list_nega,
|
||||||
)
|
)
|
||||||
noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
|
noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
|
||||||
@@ -278,6 +326,8 @@ def lets_dance_flux(
|
|||||||
tiled=False,
|
tiled=False,
|
||||||
tile_size=128,
|
tile_size=128,
|
||||||
tile_stride=64,
|
tile_stride=64,
|
||||||
|
entity_prompts=None,
|
||||||
|
entity_masks=None,
|
||||||
ipadapter_kwargs_list={},
|
ipadapter_kwargs_list={},
|
||||||
**kwargs
|
**kwargs
|
||||||
):
|
):
|
||||||
@@ -333,13 +383,38 @@ def lets_dance_flux(
|
|||||||
if dit.guidance_embedder is not None:
|
if dit.guidance_embedder is not None:
|
||||||
guidance = guidance * 1000
|
guidance = guidance * 1000
|
||||||
conditioning = conditioning + dit.guidance_embedder(guidance, hidden_states.dtype)
|
conditioning = conditioning + dit.guidance_embedder(guidance, hidden_states.dtype)
|
||||||
prompt_emb = dit.context_embedder(prompt_emb)
|
|
||||||
image_rotary_emb = dit.pos_embedder(torch.cat((text_ids, image_ids), dim=1))
|
|
||||||
|
|
||||||
|
repeat_dim = hidden_states.shape[1]
|
||||||
height, width = hidden_states.shape[-2:]
|
height, width = hidden_states.shape[-2:]
|
||||||
hidden_states = dit.patchify(hidden_states)
|
hidden_states = dit.patchify(hidden_states)
|
||||||
hidden_states = dit.x_embedder(hidden_states)
|
hidden_states = dit.x_embedder(hidden_states)
|
||||||
|
|
||||||
|
# Entity Control
|
||||||
|
max_masks = 0
|
||||||
|
attention_mask = None
|
||||||
|
prompt_embs = [prompt_emb]
|
||||||
|
if entity_masks is not None:
|
||||||
|
# entity_masks
|
||||||
|
batch_size, max_masks = entity_masks.shape[0], entity_masks.shape[1]
|
||||||
|
entity_masks = entity_masks.repeat(1, 1, repeat_dim, 1, 1)
|
||||||
|
entity_masks = [entity_masks[:, i, None].squeeze(1) for i in range(max_masks)]
|
||||||
|
# global mask
|
||||||
|
global_mask = torch.ones_like(entity_masks[0]).to(device=hidden_states.device, dtype=hidden_states.dtype)
|
||||||
|
entity_masks = entity_masks + [global_mask] # append global to last
|
||||||
|
# attention mask
|
||||||
|
attention_mask = dit.construct_mask(entity_masks, prompt_emb.shape[1], hidden_states.shape[1])
|
||||||
|
attention_mask = attention_mask.to(device=hidden_states.device, dtype=hidden_states.dtype)
|
||||||
|
attention_mask = attention_mask.unsqueeze(1)
|
||||||
|
# embds: n_masks * b * seq * d
|
||||||
|
local_embs = [entity_prompts[:, i, None].squeeze(1) for i in range(max_masks)]
|
||||||
|
prompt_embs = local_embs + prompt_embs # append global to last
|
||||||
|
prompt_embs = [dit.context_embedder(prompt_emb) for prompt_emb in prompt_embs]
|
||||||
|
prompt_emb = torch.cat(prompt_embs, dim=1)
|
||||||
|
|
||||||
|
# positional embedding
|
||||||
|
text_ids = torch.cat([text_ids] * (max_masks + 1), dim=1)
|
||||||
|
image_rotary_emb = dit.pos_embedder(torch.cat((text_ids, image_ids), dim=1))
|
||||||
|
|
||||||
# Joint Blocks
|
# Joint Blocks
|
||||||
for block_id, block in enumerate(dit.blocks):
|
for block_id, block in enumerate(dit.blocks):
|
||||||
hidden_states, prompt_emb = block(
|
hidden_states, prompt_emb = block(
|
||||||
@@ -347,6 +422,7 @@ def lets_dance_flux(
|
|||||||
prompt_emb,
|
prompt_emb,
|
||||||
conditioning,
|
conditioning,
|
||||||
image_rotary_emb,
|
image_rotary_emb,
|
||||||
|
attention_mask,
|
||||||
ipadapter_kwargs_list=ipadapter_kwargs_list.get(block_id, None))
|
ipadapter_kwargs_list=ipadapter_kwargs_list.get(block_id, None))
|
||||||
# ControlNet
|
# ControlNet
|
||||||
if controlnet is not None and controlnet_frames is not None:
|
if controlnet is not None and controlnet_frames is not None:
|
||||||
@@ -361,6 +437,7 @@ def lets_dance_flux(
|
|||||||
prompt_emb,
|
prompt_emb,
|
||||||
conditioning,
|
conditioning,
|
||||||
image_rotary_emb,
|
image_rotary_emb,
|
||||||
|
attention_mask,
|
||||||
ipadapter_kwargs_list=ipadapter_kwargs_list.get(
|
ipadapter_kwargs_list=ipadapter_kwargs_list.get(
|
||||||
block_id + num_joint_blocks, None))
|
block_id + num_joint_blocks, None))
|
||||||
# ControlNet
|
# ControlNet
|
||||||
|
|||||||
54
examples/EntityControl/entity_control_flux.py
Normal file
54
examples/EntityControl/entity_control_flux.py
Normal file
@@ -0,0 +1,54 @@
|
|||||||
|
import torch
|
||||||
|
from diffsynth import ModelManager, FluxImagePipeline, download_customized_models, FluxImageLoraPipeline
|
||||||
|
from examples.EntityControl.utils import visualize_masks
|
||||||
|
import os
|
||||||
|
import json
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
# lora_path = download_customized_models(
|
||||||
|
# model_id="DiffSynth-Studio/ArtAug-lora-FLUX.1dev-v1",
|
||||||
|
# origin_file_path="merged_lora.safetensors",
|
||||||
|
# local_dir="models/lora"
|
||||||
|
# )[0]
|
||||||
|
|
||||||
|
lora_path = '/root/model_bf16.safetensors'
model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cuda")
model_manager.load_models([
    "t2i_models/FLUX/FLUX.1-dev/text_encoder/model.safetensors",
    "t2i_models/FLUX/FLUX.1-dev/text_encoder_2",
    "t2i_models/FLUX/FLUX.1-dev/ae.safetensors",
    "t2i_models/FLUX/FLUX.1-dev/flux1-dev.safetensors"
])
model_manager.load_lora(lora_path, lora_alpha=1.)

pipe = FluxImagePipeline.from_model_manager(model_manager)

mask_dir = '/mnt/nas1/zhanghong/DiffSynth-Studio/workdirs/tmp_mask'
image_shape = 1024
guidance = 3.5
cfg = 3.0
negative_prompt = "worst quality, low quality, monochrome, zombie, interlocked fingers, Aissist, cleavage, nsfw,"
names = ['row_2_1']
# Seeds rendered for every name. Previously an inner ``for seed in range(3, 10)``
# shadowed the seed taken from a separate ``seeds`` list, which was therefore
# never used; the effective seed range is now stated explicitly.
seeds = range(3, 10)
# use this to apply regional attention in negative prompt prediction for better results with more time
use_seperated_negtive_prompt = False
for name in names:
    out_dir = f'workdirs/entity_control/{name}'
    os.makedirs(out_dir, exist_ok=True)
    cur_dir = os.path.join(mask_dir, name)
    # Close the metadata file deterministically instead of leaking the handle.
    with open(os.path.join(cur_dir, 'prompts.json')) as meta_file:
        metas = json.load(meta_file)
    for seed in seeds:
        prompt = metas['global_prompt']
        mask_prompts = metas['mask_prompts']
        # One mask image per entity prompt, stored as "<index>.png" next to prompts.json.
        masks = [
            Image.open(os.path.join(cur_dir, f"{mask_idx}.png")).resize((image_shape, image_shape), resample=Image.NEAREST)
            for mask_idx in range(len(mask_prompts))
        ]
        torch.manual_seed(seed)
        image = pipe(
            prompt=prompt,
            cfg_scale=cfg,
            negative_prompt=negative_prompt,
            num_inference_steps=50, embedded_guidance=guidance, height=image_shape, width=image_shape,
            entity_prompts=mask_prompts, entity_masks=masks,
            use_seperated_negtive_prompt=use_seperated_negtive_prompt
        )
        use_sep = '_sepneg' if use_seperated_negtive_prompt else ''
        visualize_masks(image, masks, mask_prompts, os.path.join(out_dir, f"{name}_{seed}{use_sep}.png"))
|
||||||
59
examples/EntityControl/entity_inpaint_flux.py
Normal file
59
examples/EntityControl/entity_inpaint_flux.py
Normal file
@@ -0,0 +1,59 @@
|
|||||||
|
import torch
|
||||||
|
from diffsynth import ModelManager, FluxImagePipeline, download_customized_models, FluxImageLoraPipeline
|
||||||
|
from examples.EntityControl.utils import visualize_masks
|
||||||
|
import os
|
||||||
|
import json
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
# lora_path = download_customized_models(
|
||||||
|
# model_id="DiffSynth-Studio/ArtAug-lora-FLUX.1dev-v1",
|
||||||
|
# origin_file_path="merged_lora.safetensors",
|
||||||
|
# local_dir="models/lora"
|
||||||
|
# )[0]
|
||||||
|
|
||||||
|
lora_path = '/root/model_bf16.safetensors'
model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cuda")
model_manager.load_models([
    "t2i_models/FLUX/FLUX.1-dev/text_encoder/model.safetensors",
    "t2i_models/FLUX/FLUX.1-dev/text_encoder_2",
    "t2i_models/FLUX/FLUX.1-dev/ae.safetensors",
    "t2i_models/FLUX/FLUX.1-dev/flux1-dev.safetensors"
])
model_manager.load_lora(lora_path, lora_alpha=1.)

pipe = FluxImagePipeline.from_model_manager(model_manager)

mask_dir = '/mnt/nas1/zhanghong/DiffSynth-Studio/workdirs/tmp_mask'
image_shape = 1024
guidance = 3.5
cfg = 3.0
negative_prompt = "worst quality, low quality, monochrome, zombie, interlocked fingers, Aissist, cleavage, nsfw,"
names = ['inpaint2']
seeds = [0]
use_seperated_negtive_prompt = False
for name, seed in zip(names, seeds):
    out_dir = f'workdirs/paper_app/inpaint/elc/{name}'
    os.makedirs(out_dir, exist_ok=True)
    cur_dir = os.path.join(mask_dir, name)
    # Close the metadata file deterministically instead of leaking the handle.
    with open(os.path.join(cur_dir, 'prompts.json')) as meta_file:
        metas = json.load(meta_file)
    inpaint_input = Image.open(os.path.join(cur_dir, 'input.png')).convert('RGB')
    # The global prompt is hard-coded here; the 'global_prompt' entry of
    # prompts.json was read and then immediately overwritten, so that dead
    # assignment is dropped.
    prompt = 'A person with a dog walking on the cloud. A rocket in the sky'
    mask_prompts = metas['mask_prompts']
    # One mask image per entity prompt, stored as "<index>.png" next to prompts.json.
    masks = [
        Image.open(os.path.join(cur_dir, f"{mask_idx}.png")).resize((image_shape, image_shape), resample=Image.NEAREST)
        for mask_idx in range(len(mask_prompts))
    ]
    torch.manual_seed(seed)
    image = pipe(
        prompt=prompt,
        cfg_scale=cfg,
        negative_prompt=negative_prompt,
        num_inference_steps=50,
        embedded_guidance=guidance,
        height=image_shape,
        width=image_shape,
        entity_prompts=mask_prompts,
        entity_masks=masks,
        inpaint_input=inpaint_input,
        use_seperated_negtive_prompt=use_seperated_negtive_prompt,
    )
    use_sep = '_sepneg' if use_seperated_negtive_prompt else ''
    visualize_masks(image, masks, mask_prompts, os.path.join(out_dir, f"{name}_{seed}{use_sep}.png"))
|
||||||
59
examples/EntityControl/utils.py
Normal file
59
examples/EntityControl/utils.py
Normal file
@@ -0,0 +1,59 @@
|
|||||||
|
from PIL import Image, ImageDraw, ImageFont
|
||||||
|
import random
|
||||||
|
|
||||||
|
def visualize_masks(image, masks, mask_prompts, output_path, font_size=35, use_random_colors=False):
    """Overlay translucent entity masks and their prompts on an image and save it.

    Args:
        image: base PIL image.
        masks: PIL mask images; pixels whose RGB value is exactly
            ``(255, 255, 255)`` are treated as inside the mask. Each mask must
            have the same size as ``image`` for compositing to succeed.
        mask_prompts: one label per mask, drawn near the top-left corner of
            the mask's bounding box.
        output_path: where the composited image is written.
        font_size: label font size.
        use_random_colors: draw each mask in a random colour instead of the
            built-in palette.

    Returns:
        The composited RGBA image (also saved to ``output_path``).
    """
    # Built-in palette; it is cycled, so any number of masks gets a colour
    # (a fixed-length list combined with ``zip`` would silently drop the rest).
    palette = [
        (165, 238, 173, 80),
        (76, 102, 221, 80),
        (221, 160, 77, 80),
        (204, 93, 71, 80),
        (145, 187, 149, 80),
        (134, 141, 172, 80),
        (157, 137, 109, 80),
        (153, 104, 95, 80),
    ]
    if use_random_colors:
        palette = [(random.randint(0, 255), random.randint(0, 255), random.randint(0, 255), 80) for _ in range(len(masks))]

    # Font settings: prefer Arial, fall back to Pillow's default font.
    try:
        font = ImageFont.truetype("arial", font_size)
    except OSError:
        try:
            font = ImageFont.load_default(font_size)
        except TypeError:
            # Pillow versions whose load_default() takes no size argument.
            font = ImageFont.load_default()

    transparent = (0, 0, 0, 0)
    overlay = Image.new('RGBA', image.size, transparent)

    for index, (mask, mask_prompt) in enumerate(zip(masks, mask_prompts)):
        color = palette[index % len(palette)]

        # Paint the mask region with its colour; everything else stays transparent.
        mask_rgba = mask.convert('RGBA')
        mask_rgba.putdata([
            color if pixel[:3] == (255, 255, 255) else transparent
            for pixel in mask_rgba.getdata()
        ])

        # getbbox() returns None for an empty mask; there is nowhere sensible
        # to put the label in that case, so skip it instead of crashing.
        mask_bbox = mask.getbbox()
        if mask_bbox is not None:
            draw = ImageDraw.Draw(mask_rgba)
            text_position = (mask_bbox[0] + 10, mask_bbox[1] + 10)
            draw.text(text_position, mask_prompt, fill=(255, 255, 255, 255), font=font)

        overlay = Image.alpha_composite(overlay, mask_rgba)

    # Composite the overlay onto the original image and persist it.
    result = Image.alpha_composite(image.convert('RGBA'), overlay)
    result.save(output_path)

    return result
|
||||||
Reference in New Issue
Block a user