Temporary commit: add entity (regional) control to the FLUX pipeline — per-entity prompt/mask attention masking, inpaint fusion, and example scripts.

This commit is contained in:
mi804
2024-12-25 17:19:31 +08:00
parent 1b6e96a820
commit e3d89cec0c
5 changed files with 349 additions and 44 deletions

View File

@@ -40,7 +40,7 @@ class RoPEEmbedding(torch.nn.Module):
n_axes = ids.shape[-1]
emb = torch.cat([self.rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)], dim=-3)
return emb.unsqueeze(1)
class FluxJointAttention(torch.nn.Module):
@@ -70,7 +70,7 @@ class FluxJointAttention(torch.nn.Module):
xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
def forward(self, hidden_states_a, hidden_states_b, image_rotary_emb, ipadapter_kwargs_list=None):
def forward(self, hidden_states_a, hidden_states_b, image_rotary_emb, attn_mask=None, ipadapter_kwargs_list=None):
batch_size = hidden_states_a.shape[0]
# Part A
@@ -91,7 +91,7 @@ class FluxJointAttention(torch.nn.Module):
q, k = self.apply_rope(q, k, image_rotary_emb)
hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v)
hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim)
hidden_states = hidden_states.to(q.dtype)
hidden_states_b, hidden_states_a = hidden_states[:, :hidden_states_b.shape[1]], hidden_states[:, hidden_states_b.shape[1]:]
@@ -103,7 +103,7 @@ class FluxJointAttention(torch.nn.Module):
else:
hidden_states_b = self.b_to_out(hidden_states_b)
return hidden_states_a, hidden_states_b
class FluxJointTransformerBlock(torch.nn.Module):
@@ -129,12 +129,12 @@ class FluxJointTransformerBlock(torch.nn.Module):
)
def forward(self, hidden_states_a, hidden_states_b, temb, image_rotary_emb, ipadapter_kwargs_list=None):
def forward(self, hidden_states_a, hidden_states_b, temb, image_rotary_emb, attn_mask=None, ipadapter_kwargs_list=None):
norm_hidden_states_a, gate_msa_a, shift_mlp_a, scale_mlp_a, gate_mlp_a = self.norm1_a(hidden_states_a, emb=temb)
norm_hidden_states_b, gate_msa_b, shift_mlp_b, scale_mlp_b, gate_mlp_b = self.norm1_b(hidden_states_b, emb=temb)
# Attention
attn_output_a, attn_output_b = self.attn(norm_hidden_states_a, norm_hidden_states_b, image_rotary_emb, ipadapter_kwargs_list)
attn_output_a, attn_output_b = self.attn(norm_hidden_states_a, norm_hidden_states_b, image_rotary_emb, attn_mask, ipadapter_kwargs_list)
# Part A
hidden_states_a = hidden_states_a + gate_msa_a * attn_output_a
@@ -147,7 +147,7 @@ class FluxJointTransformerBlock(torch.nn.Module):
hidden_states_b = hidden_states_b + gate_mlp_b * self.ff_b(norm_hidden_states_b)
return hidden_states_a, hidden_states_b
class FluxSingleAttention(torch.nn.Module):
@@ -184,7 +184,7 @@ class FluxSingleAttention(torch.nn.Module):
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim)
hidden_states = hidden_states.to(q.dtype)
return hidden_states
class AdaLayerNormSingle(torch.nn.Module):
@@ -200,7 +200,7 @@ class AdaLayerNormSingle(torch.nn.Module):
shift_msa, scale_msa, gate_msa = emb.chunk(3, dim=1)
x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
return x, gate_msa
class FluxSingleTransformerBlock(torch.nn.Module):
@@ -225,8 +225,8 @@ class FluxSingleTransformerBlock(torch.nn.Module):
xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
def process_attention(self, hidden_states, image_rotary_emb, ipadapter_kwargs_list=None):
def process_attention(self, hidden_states, image_rotary_emb, attn_mask=None, ipadapter_kwargs_list=None):
batch_size = hidden_states.shape[0]
qkv = hidden_states.view(batch_size, -1, 3 * self.num_heads, self.head_dim).transpose(1, 2)
@@ -235,7 +235,7 @@ class FluxSingleTransformerBlock(torch.nn.Module):
q, k = self.apply_rope(q, k, image_rotary_emb)
hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v)
hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim)
hidden_states = hidden_states.to(q.dtype)
if ipadapter_kwargs_list is not None:
@@ -243,21 +243,21 @@ class FluxSingleTransformerBlock(torch.nn.Module):
return hidden_states
def forward(self, hidden_states_a, hidden_states_b, temb, image_rotary_emb, ipadapter_kwargs_list=None):
def forward(self, hidden_states_a, hidden_states_b, temb, image_rotary_emb, attn_mask=None, ipadapter_kwargs_list=None):
residual = hidden_states_a
norm_hidden_states, gate = self.norm(hidden_states_a, emb=temb)
hidden_states_a = self.to_qkv_mlp(norm_hidden_states)
attn_output, mlp_hidden_states = hidden_states_a[:, :, :self.dim * 3], hidden_states_a[:, :, self.dim * 3:]
attn_output = self.process_attention(attn_output, image_rotary_emb, ipadapter_kwargs_list)
attn_output = self.process_attention(attn_output, image_rotary_emb, attn_mask, ipadapter_kwargs_list)
mlp_hidden_states = torch.nn.functional.gelu(mlp_hidden_states, approximate="tanh")
hidden_states_a = torch.cat([attn_output, mlp_hidden_states], dim=2)
hidden_states_a = gate.unsqueeze(1) * self.proj_out(hidden_states_a)
hidden_states_a = residual + hidden_states_a
return hidden_states_a, hidden_states_b
class AdaLayerNormContinuous(torch.nn.Module):
@@ -300,7 +300,7 @@ class FluxDiT(torch.nn.Module):
def unpatchify(self, hidden_states, height, width):
hidden_states = rearrange(hidden_states, "B (H W) (C P Q) -> B C (H P) (W Q)", P=2, Q=2, H=height//2, W=width//2)
return hidden_states
def prepare_image_ids(self, latents):
batch_size, _, height, width = latents.shape
@@ -317,7 +317,7 @@ class FluxDiT(torch.nn.Module):
latent_image_ids = latent_image_ids.to(device=latents.device, dtype=latents.dtype)
return latent_image_ids
def tiled_forward(
self,
@@ -337,12 +337,45 @@ class FluxDiT(torch.nn.Module):
)
return hidden_states
def construct_mask(self, entity_masks, prompt_seq_len, image_seq_len):
    """Build an additive attention bias for entity (regional) control.

    Sequence layout is [prompt_0 ... prompt_{N-1}, image]: each of the N
    entity prompts occupies `prompt_seq_len` tokens, followed by
    `image_seq_len` patchified image tokens.

    Rules encoded in the mask:
      * prompt_i <-> image: only image tokens covered by entity mask i may
        attend to / be attended by prompt i's tokens.
      * prompt_i <-> prompt_j (i != j): fully blocked.
      * everything else (image <-> image, within one prompt): allowed.

    Args:
        entity_masks: list of N tensors (batch dim first), each compatible
            with ``self.patchify`` on the latent image.
        prompt_seq_len: token length of one entity prompt.
        image_seq_len: number of patchified image tokens.

    Returns:
        Float tensor of shape (batch, total, total): 0 where attention is
        allowed, -inf where blocked (additive bias for SDPA).
    """
    num_entities = len(entity_masks)
    device = entity_masks[0].device
    batch_size = entity_masks[0].shape[0]
    total_seq_len = num_entities * prompt_seq_len + image_seq_len
    image_start = num_entities * prompt_seq_len
    image_end = total_seq_len

    # Start from "everything attends to everything", then carve out blocks.
    # Allocate directly on the target device instead of CPU + .to().
    allowed = torch.ones(
        (batch_size, total_seq_len, total_seq_len),
        dtype=torch.bool, device=device,
    )

    # prompt <-> image: restrict each entity prompt to its own image region.
    for i in range(num_entities):
        prompt_start = i * prompt_seq_len
        prompt_end = prompt_start + prompt_seq_len
        # An image token belongs to entity i if its patch overlaps mask i.
        token_in_entity = torch.sum(self.patchify(entity_masks[i]), dim=-1) > 0
        token_in_entity = token_in_entity.unsqueeze(1).repeat(1, prompt_seq_len, 1)
        allowed[:, prompt_start:prompt_end, image_start:image_end] = token_in_entity
        allowed[:, image_start:image_end, prompt_start:prompt_end] = token_in_entity.transpose(1, 2)

    # prompt <-> prompt: different entity prompts never see each other.
    for i in range(num_entities):
        for j in range(num_entities):
            if i == j:
                continue
            allowed[:, i * prompt_seq_len:(i + 1) * prompt_seq_len,
                    j * prompt_seq_len:(j + 1) * prompt_seq_len] = False

    # Convert the boolean mask to an additive bias: 0 = keep, -inf = block.
    # masked_fill avoids the fragile float-equality rewrite of the original
    # (mask[mask == 0] = -inf; mask[mask == 1] = 0).
    attention_mask = torch.zeros(allowed.shape, dtype=torch.float32, device=device)
    attention_mask = attention_mask.masked_fill(~allowed, float("-inf"))
    return attention_mask
def forward(
self,
hidden_states,
timestep, prompt_emb, pooled_prompt_emb, guidance, text_ids, image_ids=None,
tiled=False, tile_size=128, tile_stride=64,
tiled=False, tile_size=128, tile_stride=64, entity_prompts=None, entity_masks=None,
use_gradient_checkpointing=False,
**kwargs
):
@@ -353,46 +386,70 @@ class FluxDiT(torch.nn.Module):
tile_size=tile_size, tile_stride=tile_stride,
**kwargs
)
if image_ids is None:
image_ids = self.prepare_image_ids(hidden_states)
conditioning = self.time_embedder(timestep, hidden_states.dtype) + self.pooled_text_embedder(pooled_prompt_emb)
if self.guidance_embedder is not None:
guidance = guidance * 1000
conditioning = conditioning + self.guidance_embedder(guidance, hidden_states.dtype)
prompt_emb = self.context_embedder(prompt_emb)
image_rotary_emb = self.pos_embedder(torch.cat((text_ids, image_ids), dim=1))
repeat_dim = hidden_states.shape[1]
height, width = hidden_states.shape[-2:]
hidden_states = self.patchify(hidden_states)
hidden_states = self.x_embedder(hidden_states)
max_masks = 0
attention_mask = None
prompt_embs = [prompt_emb]
if entity_masks is not None:
# entity_masks
batch_size, max_masks = entity_masks.shape[0], entity_masks.shape[1]
entity_masks = entity_masks.repeat(1, 1, repeat_dim, 1, 1)
entity_masks = [entity_masks[:, i, None].squeeze(1) for i in range(max_masks)]
# global mask
global_mask = torch.ones_like(entity_masks[0]).to(device=hidden_states.device, dtype=hidden_states.dtype)
entity_masks = entity_masks + [global_mask] # append global to last
# attention mask
attention_mask = self.construct_mask(entity_masks, prompt_emb.shape[1], hidden_states.shape[1])
attention_mask = attention_mask.to(device=hidden_states.device, dtype=hidden_states.dtype)
attention_mask = attention_mask.unsqueeze(1)
# embds: n_masks * b * seq * d
local_embs = [entity_prompts[:, i, None].squeeze(1) for i in range(max_masks)]
prompt_embs = local_embs + prompt_embs # append global to last
prompt_embs = [self.context_embedder(prompt_emb) for prompt_emb in prompt_embs]
prompt_emb = torch.cat(prompt_embs, dim=1)
# positional embedding
text_ids = torch.cat([text_ids] * (max_masks + 1), dim=1)
image_rotary_emb = self.pos_embedder(torch.cat((text_ids, image_ids), dim=1))
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs)
return custom_forward
for block in self.blocks:
if self.training and use_gradient_checkpointing:
hidden_states, prompt_emb = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
hidden_states, prompt_emb, conditioning, image_rotary_emb,
hidden_states, prompt_emb, conditioning, image_rotary_emb, attention_mask,
use_reentrant=False,
)
else:
hidden_states, prompt_emb = block(hidden_states, prompt_emb, conditioning, image_rotary_emb)
hidden_states, prompt_emb = block(hidden_states, prompt_emb, conditioning, image_rotary_emb, attention_mask)
hidden_states = torch.cat([prompt_emb, hidden_states], dim=1)
for block in self.single_blocks:
if self.training and use_gradient_checkpointing:
hidden_states, prompt_emb = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
hidden_states, prompt_emb, conditioning, image_rotary_emb,
hidden_states, prompt_emb, conditioning, image_rotary_emb, attention_mask,
use_reentrant=False,
)
else:
hidden_states, prompt_emb = block(hidden_states, prompt_emb, conditioning, image_rotary_emb)
hidden_states, prompt_emb = block(hidden_states, prompt_emb, conditioning, image_rotary_emb, attention_mask)
hidden_states = hidden_states[:, prompt_emb.shape[1]:]
hidden_states = self.final_norm_out(hidden_states, conditioning)
@@ -400,7 +457,7 @@ class FluxDiT(torch.nn.Module):
hidden_states = self.unpatchify(hidden_states, height, width)
return hidden_states
def quantize(self):
def cast_to(weight, dtype=None, device=None, copy=False):
@@ -440,16 +497,16 @@ class FluxDiT(torch.nn.Module):
class Linear(torch.nn.Linear):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def forward(self,input,**kwargs):
weight,bias= cast_bias_weight(self,input)
return torch.nn.functional.linear(input,weight,bias)
class RMSNorm(torch.nn.Module):
def __init__(self, module):
super().__init__()
self.module = module
def forward(self,hidden_states,**kwargs):
weight= cast_weight(self.module,hidden_states)
input_dtype = hidden_states.dtype
@@ -457,7 +514,7 @@ class FluxDiT(torch.nn.Module):
hidden_states = hidden_states * torch.rsqrt(variance + self.module.eps)
hidden_states = hidden_states.to(input_dtype) * weight
return hidden_states
def replace_layer(model):
for name, module in model.named_children():
if isinstance(module, torch.nn.Linear):
@@ -483,7 +540,6 @@ class FluxDiT(torch.nn.Module):
@staticmethod
def state_dict_converter():
return FluxDiTStateDictConverter()
class FluxDiTStateDictConverter:
@@ -587,7 +643,7 @@ class FluxDiTStateDictConverter:
state_dict_.pop(name.replace(f".{component}_to_q.", f".{component}_to_k."))
state_dict_.pop(name.replace(f".{component}_to_q.", f".{component}_to_v."))
return state_dict_
def from_civitai(self, state_dict):
rename_dict = {
"time_in.in_layer.bias": "time_embedder.timestep_embedder.0.bias",

View File

@@ -139,6 +139,39 @@ class FluxImagePipeline(BasePipeline):
images = [self.preprocess_image(image).to(device=self.device, dtype=self.torch_dtype) for image in images]
return torch.cat(images, dim=0)
def inpaint_fusion(self, latents, inpaint_latents, pred_noise, fg_mask, bg_mask, progress_id, background_weight=0.):
    """Merge model-predicted noise with the noise implied by inpaint latents.

    The noise that would reconstruct ``inpaint_latents`` from ``latents`` is
    recovered via the scheduler sigma at ``progress_id``. Foreground pixels
    take the model prediction outright; background pixels blend it in with
    ``background_weight`` and are renormalized.
    """
    sigma = self.scheduler.sigmas[progress_id]
    merged = (latents - inpaint_latents) / sigma
    weight = torch.ones_like(merged)
    # Foreground: trust the model prediction entirely.
    merged[fg_mask] = pred_noise[fg_mask]
    # Background: mix in the prediction, then renormalize by total weight.
    merged[bg_mask] = merged[bg_mask] + pred_noise[bg_mask] * background_weight
    weight[bg_mask] = weight[bg_mask] + background_weight
    merged = merged / weight
    return merged
def preprocess_masks(self, masks, height, width, dim):
    """Convert PIL masks into binary tensors of shape (1, dim, height, width).

    Each mask is resized with nearest-neighbor interpolation, thresholded
    (any channel mean > 0), repeated across ``dim`` channels, and moved to
    the pipeline's device/dtype.
    """
    processed = []
    for raw_mask in masks:
        resized = raw_mask.resize((width, height), resample=Image.NEAREST)
        # Collapse channels first, then threshold to a boolean mask.
        binary = self.preprocess_image(resized).mean(dim=1, keepdim=True) > 0
        binary = binary.repeat(1, dim, 1, 1)
        processed.append(binary.to(device=self.device, dtype=self.torch_dtype))
    return processed
def prepare_entity_inputs(self, entity_prompts, entity_masks, width, height, t5_sequence_length=512, inpaint_input=None):
    """Encode entity prompts and rasterize entity masks for regional control.

    Returns a 4-tuple ``(entity_prompts, entity_masks, fg_mask, bg_mask)``:
      * entity_prompts: (1, n_mask, seq, d) T5 embeddings of the prompts.
      * entity_masks:   (1, n_mask, 1, h, w) binary latent-space masks.
      * fg_mask/bg_mask: only when ``inpaint_input`` is given — union of all
        entity masks over 16 latent channels, and its complement; otherwise
        both are None.
    """
    fg_mask = None
    bg_mask = None
    if inpaint_input is not None:
        from copy import deepcopy
        # Work on copies so the original PIL masks stay untouched below.
        mask_copies = deepcopy(entity_masks)
        channel_means = [
            self.preprocess_image(m.resize((width // 8, height // 8))).mean(dim=1, keepdim=True)
            for m in mask_copies
        ]
        stacked = (torch.cat(channel_means) > 0).float()
        # Union of all entity masks, broadcast over the 16 latent channels.
        fg_mask = stacked.sum(dim=0, keepdim=True).repeat(1, 16, 1, 1) > 0
        bg_mask = ~fg_mask
    latent_masks = self.preprocess_masks(entity_masks, height // 8, width // 8, 1)
    latent_masks = torch.cat(latent_masks, dim=0).unsqueeze(0)  # (b, n_mask, c, h, w)
    encoded_prompts = self.encode_prompt(entity_prompts, t5_sequence_length=t5_sequence_length)['prompt_emb'].unsqueeze(0)
    return encoded_prompts, latent_masks, fg_mask, bg_mask
@torch.no_grad()
def __call__(
self,
@@ -160,6 +193,10 @@ class FluxImagePipeline(BasePipeline):
width=1024,
num_inference_steps=30,
t5_sequence_length=512,
inpaint_input=None,
entity_prompts=None,
entity_masks=None,
use_seperated_negtive_prompt=True,
tiled=False,
tile_size=128,
tile_stride=64,
@@ -176,12 +213,13 @@ class FluxImagePipeline(BasePipeline):
self.scheduler.set_timesteps(num_inference_steps, denoising_strength)
# Prepare latent tensors
if input_image is not None:
if input_image is not None or inpaint_input is not None:
input_image = input_image or inpaint_input
self.load_models_to_device(['vae_encoder'])
image = self.preprocess_image(input_image).to(device=self.device, dtype=self.torch_dtype)
latents = self.encode_image(image, **tiler_kwargs)
input_latents = self.encode_image(image, **tiler_kwargs)
noise = self.generate_noise((1, 16, height//8, width//8), seed=seed, device=self.device, dtype=self.torch_dtype)
latents = self.scheduler.add_noise(latents, noise, timestep=self.scheduler.timesteps[0])
latents = self.scheduler.add_noise(input_latents, noise, timestep=self.scheduler.timesteps[0])
else:
latents = self.generate_noise((1, 16, height//8, width//8), seed=seed, device=self.device, dtype=self.torch_dtype)
@@ -195,6 +233,14 @@ class FluxImagePipeline(BasePipeline):
prompt_emb_nega = self.encode_prompt(negative_prompt, positive=False, t5_sequence_length=t5_sequence_length)
prompt_emb_locals = [self.encode_prompt(prompt_local, t5_sequence_length=t5_sequence_length) for prompt_local in local_prompts]
# Entity control
negative_entity_prompts = None
negative_masks = None
if entity_masks is not None:
entity_prompts, entity_masks, fg_mask, bg_mask = self.prepare_entity_inputs(entity_prompts, entity_masks, width, height, t5_sequence_length, inpaint_input)
if use_seperated_negtive_prompt and cfg_scale != 1.0:
negative_entity_prompts = prompt_emb_nega['prompt_emb'].unsqueeze(1).repeat(1, entity_masks.shape[1], 1, 1)
negative_masks = entity_masks
# Extra input
extra_input = self.prepare_extra_input(latents, guidance=embedded_guidance)
@@ -229,18 +275,20 @@ class FluxImagePipeline(BasePipeline):
# Classifier-free guidance
inference_callback = lambda prompt_emb_posi, controlnet_kwargs: lets_dance_flux(
dit=self.dit, controlnet=self.controlnet,
hidden_states=latents, timestep=timestep,
hidden_states=latents, timestep=timestep, entity_prompts=entity_prompts, entity_masks=entity_masks,
**prompt_emb_posi, **tiler_kwargs, **extra_input, **controlnet_kwargs, **ipadapter_kwargs_list_posi,
)
noise_pred_posi = self.control_noise_via_local_prompts(
prompt_emb_posi, prompt_emb_locals, masks, mask_scales, inference_callback,
special_kwargs=controlnet_kwargs, special_local_kwargs_list=local_controlnet_kwargs
)
if inpaint_input:
noise_pred_posi = self.inpaint_fusion(latents, input_latents, noise_pred_posi, fg_mask, bg_mask, progress_id)
if cfg_scale != 1.0:
negative_controlnet_kwargs = controlnet_kwargs if enable_controlnet_on_negative else {}
noise_pred_nega = lets_dance_flux(
dit=self.dit, controlnet=self.controlnet,
hidden_states=latents, timestep=timestep,
hidden_states=latents, timestep=timestep, entity_prompts=negative_entity_prompts, entity_masks=negative_masks,
**prompt_emb_nega, **tiler_kwargs, **extra_input, **negative_controlnet_kwargs, **ipadapter_kwargs_list_nega,
)
noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
@@ -278,6 +326,8 @@ def lets_dance_flux(
tiled=False,
tile_size=128,
tile_stride=64,
entity_prompts=None,
entity_masks=None,
ipadapter_kwargs_list={},
**kwargs
):
@@ -333,13 +383,38 @@ def lets_dance_flux(
if dit.guidance_embedder is not None:
guidance = guidance * 1000
conditioning = conditioning + dit.guidance_embedder(guidance, hidden_states.dtype)
prompt_emb = dit.context_embedder(prompt_emb)
image_rotary_emb = dit.pos_embedder(torch.cat((text_ids, image_ids), dim=1))
repeat_dim = hidden_states.shape[1]
height, width = hidden_states.shape[-2:]
hidden_states = dit.patchify(hidden_states)
hidden_states = dit.x_embedder(hidden_states)
# Entity Control
max_masks = 0
attention_mask = None
prompt_embs = [prompt_emb]
if entity_masks is not None:
# entity_masks
batch_size, max_masks = entity_masks.shape[0], entity_masks.shape[1]
entity_masks = entity_masks.repeat(1, 1, repeat_dim, 1, 1)
entity_masks = [entity_masks[:, i, None].squeeze(1) for i in range(max_masks)]
# global mask
global_mask = torch.ones_like(entity_masks[0]).to(device=hidden_states.device, dtype=hidden_states.dtype)
entity_masks = entity_masks + [global_mask] # append global to last
# attention mask
attention_mask = dit.construct_mask(entity_masks, prompt_emb.shape[1], hidden_states.shape[1])
attention_mask = attention_mask.to(device=hidden_states.device, dtype=hidden_states.dtype)
attention_mask = attention_mask.unsqueeze(1)
# embds: n_masks * b * seq * d
local_embs = [entity_prompts[:, i, None].squeeze(1) for i in range(max_masks)]
prompt_embs = local_embs + prompt_embs # append global to last
prompt_embs = [dit.context_embedder(prompt_emb) for prompt_emb in prompt_embs]
prompt_emb = torch.cat(prompt_embs, dim=1)
# positional embedding
text_ids = torch.cat([text_ids] * (max_masks + 1), dim=1)
image_rotary_emb = dit.pos_embedder(torch.cat((text_ids, image_ids), dim=1))
# Joint Blocks
for block_id, block in enumerate(dit.blocks):
hidden_states, prompt_emb = block(
@@ -347,6 +422,7 @@ def lets_dance_flux(
prompt_emb,
conditioning,
image_rotary_emb,
attention_mask,
ipadapter_kwargs_list=ipadapter_kwargs_list.get(block_id, None))
# ControlNet
if controlnet is not None and controlnet_frames is not None:
@@ -361,6 +437,7 @@ def lets_dance_flux(
prompt_emb,
conditioning,
image_rotary_emb,
attention_mask,
ipadapter_kwargs_list=ipadapter_kwargs_list.get(
block_id + num_joint_blocks, None))
# ControlNet

View File

@@ -0,0 +1,54 @@
import torch
from diffsynth import ModelManager, FluxImagePipeline, download_customized_models, FluxImageLoraPipeline
from examples.EntityControl.utils import visualize_masks
import os
import json
from PIL import Image

# Entity-control demo: renders images where each mask region is driven by its
# own prompt (regional attention inside the FLUX DiT).

# lora_path = download_customized_models(
#     model_id="DiffSynth-Studio/ArtAug-lora-FLUX.1dev-v1",
#     origin_file_path="merged_lora.safetensors",
#     local_dir="models/lora"
# )[0]
# NOTE(review): hard-coded local checkpoint path; restore the download above for portability.
lora_path = '/root/model_bf16.safetensors'

model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cuda")
model_manager.load_models([
    "t2i_models/FLUX/FLUX.1-dev/text_encoder/model.safetensors",
    "t2i_models/FLUX/FLUX.1-dev/text_encoder_2",
    "t2i_models/FLUX/FLUX.1-dev/ae.safetensors",
    "t2i_models/FLUX/FLUX.1-dev/flux1-dev.safetensors"
])
model_manager.load_lora(lora_path, lora_alpha=1.)
pipe = FluxImagePipeline.from_model_manager(model_manager)

# Inputs: each subdirectory of mask_dir holds prompts.json plus {i}.png entity masks.
mask_dir = '/mnt/nas1/zhanghong/DiffSynth-Studio/workdirs/tmp_mask'
image_shape = 1024
guidance = 3.5  # embedded (distilled) guidance passed to the FLUX DiT
cfg = 3.0  # classifier-free guidance scale
negative_prompt = "worst quality, low quality, monochrome, zombie, interlocked fingers, Aissist, cleavage, nsfw,"
names = ['row_2_1']
seeds = [0]
# use this to apply regional attention in negative prompt prediction for better results with more time
use_seperated_negtive_prompt = False

for name, seed in zip(names, seeds):
    out_dir = f'workdirs/entity_control/{name}'
    os.makedirs(out_dir, exist_ok=True)
    cur_dir = os.path.join(mask_dir, name)  # NOTE(review): assigned but never used below
    metas = json.load(open(os.path.join(mask_dir, name, 'prompts.json')))
    # NOTE(review): this inner loop shadows the `seed` taken from zip(names, seeds)
    # above, so the seeds list is effectively ignored — confirm this is intended.
    for seed in range(3, 10):
        prompt = metas['global_prompt']
        mask_prompts = metas['mask_prompts']
        masks = [Image.open(os.path.join(mask_dir, name, f"{mask_idx}.png")).resize((image_shape, image_shape), resample=Image.NEAREST) for mask_idx in range(len(mask_prompts))]
        torch.manual_seed(seed)
        image = pipe(
            prompt=prompt,
            cfg_scale=cfg,
            negative_prompt=negative_prompt,
            num_inference_steps=50, embedded_guidance=guidance, height=image_shape, width=image_shape,
            entity_prompts=mask_prompts, entity_masks=masks,
            use_seperated_negtive_prompt=use_seperated_negtive_prompt
        )
        use_sep = f'_sepneg' if use_seperated_negtive_prompt else ''
        visualize_masks(image, masks, mask_prompts, os.path.join(out_dir, f"{name}_{seed}{use_sep}.png"))

View File

@@ -0,0 +1,59 @@
import torch
from diffsynth import ModelManager, FluxImagePipeline, download_customized_models, FluxImageLoraPipeline
from examples.EntityControl.utils import visualize_masks
import os
import json
from PIL import Image

# Entity-control + inpainting demo: regional prompts constrained to masks,
# fused with an input image via inpaint_fusion inside the pipeline.

# lora_path = download_customized_models(
#     model_id="DiffSynth-Studio/ArtAug-lora-FLUX.1dev-v1",
#     origin_file_path="merged_lora.safetensors",
#     local_dir="models/lora"
# )[0]
# NOTE(review): hard-coded local checkpoint path; restore the download above for portability.
lora_path = '/root/model_bf16.safetensors'

model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cuda")
model_manager.load_models([
    "t2i_models/FLUX/FLUX.1-dev/text_encoder/model.safetensors",
    "t2i_models/FLUX/FLUX.1-dev/text_encoder_2",
    "t2i_models/FLUX/FLUX.1-dev/ae.safetensors",
    "t2i_models/FLUX/FLUX.1-dev/flux1-dev.safetensors"
])
model_manager.load_lora(lora_path, lora_alpha=1.)
pipe = FluxImagePipeline.from_model_manager(model_manager)

# Inputs: each subdirectory of mask_dir holds prompts.json, input.png, and {i}.png masks.
mask_dir = '/mnt/nas1/zhanghong/DiffSynth-Studio/workdirs/tmp_mask'
image_shape = 1024
guidance = 3.5  # embedded (distilled) guidance passed to the FLUX DiT
cfg = 3.0  # classifier-free guidance scale
negative_prompt = "worst quality, low quality, monochrome, zombie, interlocked fingers, Aissist, cleavage, nsfw,"
names = ['inpaint2']
seeds = [0]
use_seperated_negtive_prompt = False

for name, seed in zip(names, seeds):
    out_dir = f'workdirs/paper_app/inpaint/elc/{name}'
    os.makedirs(out_dir, exist_ok=True)
    cur_dir = os.path.join(mask_dir, name)
    metas = json.load(open(os.path.join(mask_dir, name, 'prompts.json')))
    inpaint_input = Image.open(os.path.join(cur_dir, 'input.png')).convert('RGB')
    prompt = metas['global_prompt']
    # NOTE(review): debug override — the line below discards the prompt just
    # loaded from prompts.json; remove it to use the metadata prompt.
    prompt = 'A person with a dog walking on the cloud. A rocket in the sky'
    mask_prompts = metas['mask_prompts']
    masks = [Image.open(os.path.join(mask_dir, name, f"{mask_idx}.png")).resize((image_shape, image_shape), resample=Image.NEAREST) for mask_idx in range(len(mask_prompts))]
    torch.manual_seed(seed)
    image = pipe(
        prompt=prompt,
        cfg_scale=cfg,
        negative_prompt=negative_prompt,
        num_inference_steps=50,
        embedded_guidance=guidance,
        height=image_shape,
        width=image_shape,
        entity_prompts=mask_prompts,
        entity_masks=masks,
        inpaint_input=inpaint_input,
        use_seperated_negtive_prompt=use_seperated_negtive_prompt,
    )
    use_sep = f'_sepneg' if use_seperated_negtive_prompt else ''
    visualize_masks(image, masks, mask_prompts, os.path.join(out_dir, f"{name}_{seed}{use_sep}.png"))

View File

@@ -0,0 +1,59 @@
from PIL import Image, ImageDraw, ImageFont
import random
def visualize_masks(image, masks, mask_prompts, output_path, font_size=35, use_random_colors=False):
    """Overlay entity masks and their prompts onto an image and save it.

    Args:
        image: PIL image being annotated.
        masks: PIL masks; pure-white pixels mark the entity region.
        mask_prompts: one label string per mask, drawn near the mask's bbox.
        output_path: where the composited RGBA result is saved.
        font_size: label font size.
        use_random_colors: if True, draw each mask in a random color instead
            of the fixed palette.

    Returns:
        The composited PIL RGBA image (also written to ``output_path``).
    """
    # Blank RGBA canvas the mask layers are composited onto.
    overlay = Image.new('RGBA', image.size, (0, 0, 0, 0))
    # Fixed palette (RGBA, alpha=80); the 8 colors repeat so up to 16 masks
    # are covered. zip() below silently drops masks beyond the palette length.
    colors = [
        (165, 238, 173, 80),
        (76, 102, 221, 80),
        (221, 160, 77, 80),
        (204, 93, 71, 80),
        (145, 187, 149, 80),
        (134, 141, 172, 80),
        (157, 137, 109, 80),
        (153, 104, 95, 80),
        (165, 238, 173, 80),
        (76, 102, 221, 80),
        (221, 160, 77, 80),
        (204, 93, 71, 80),
        (145, 187, 149, 80),
        (134, 141, 172, 80),
        (157, 137, 109, 80),
        (153, 104, 95, 80),
    ]
    if use_random_colors:
        colors = [(random.randint(0, 255), random.randint(0, 255), random.randint(0, 255), 80) for _ in range(len(masks))]
    # Font settings
    try:
        font = ImageFont.truetype("arial", font_size)  # Adjust as needed
    except IOError:
        font = ImageFont.load_default(font_size)
    # Overlay each mask onto the overlay image.
    for mask, mask_prompt, color in zip(masks, mask_prompts, colors):
        mask_rgba = mask.convert('RGBA')
        mask_data = mask_rgba.getdata()
        # Pure-white pixels become the entity color; everything else transparent.
        new_data = [(color if item[:3] == (255, 255, 255) else (0, 0, 0, 0)) for item in mask_data]
        mask_rgba.putdata(new_data)
        # Draw the mask prompt text near the mask's bounding box.
        draw = ImageDraw.Draw(mask_rgba)
        mask_bbox = mask.getbbox()
        if mask_bbox is None:
            # getbbox() is None for an all-black (empty) mask; fall back to
            # the top-left corner instead of crashing on the subscript below.
            mask_bbox = (0, 0, mask.width, mask.height)
        text_position = (mask_bbox[0] + 10, mask_bbox[1] + 10)
        draw.text(text_position, mask_prompt, fill=(255, 255, 255, 255), font=font)
        # Alpha composite the overlay with this mask layer.
        overlay = Image.alpha_composite(overlay, mask_rgba)
    # Composite the overlay onto the original image and save.
    result = Image.alpha_composite(image.convert('RGBA'), overlay)
    result.save(output_path)
    return result