mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-18 22:08:13 +00:00
@@ -40,7 +40,7 @@ class RoPEEmbedding(torch.nn.Module):
|
||||
n_axes = ids.shape[-1]
|
||||
emb = torch.cat([self.rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)], dim=-3)
|
||||
return emb.unsqueeze(1)
|
||||
|
||||
|
||||
|
||||
|
||||
class FluxJointAttention(torch.nn.Module):
|
||||
@@ -70,7 +70,7 @@ class FluxJointAttention(torch.nn.Module):
|
||||
xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
|
||||
return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
|
||||
|
||||
def forward(self, hidden_states_a, hidden_states_b, image_rotary_emb, ipadapter_kwargs_list=None):
|
||||
def forward(self, hidden_states_a, hidden_states_b, image_rotary_emb, attn_mask=None, ipadapter_kwargs_list=None):
|
||||
batch_size = hidden_states_a.shape[0]
|
||||
|
||||
# Part A
|
||||
@@ -91,7 +91,7 @@ class FluxJointAttention(torch.nn.Module):
|
||||
|
||||
q, k = self.apply_rope(q, k, image_rotary_emb)
|
||||
|
||||
hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v)
|
||||
hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
|
||||
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim)
|
||||
hidden_states = hidden_states.to(q.dtype)
|
||||
hidden_states_b, hidden_states_a = hidden_states[:, :hidden_states_b.shape[1]], hidden_states[:, hidden_states_b.shape[1]:]
|
||||
@@ -103,7 +103,7 @@ class FluxJointAttention(torch.nn.Module):
|
||||
else:
|
||||
hidden_states_b = self.b_to_out(hidden_states_b)
|
||||
return hidden_states_a, hidden_states_b
|
||||
|
||||
|
||||
|
||||
|
||||
class FluxJointTransformerBlock(torch.nn.Module):
|
||||
@@ -129,12 +129,12 @@ class FluxJointTransformerBlock(torch.nn.Module):
|
||||
)
|
||||
|
||||
|
||||
def forward(self, hidden_states_a, hidden_states_b, temb, image_rotary_emb, ipadapter_kwargs_list=None):
|
||||
def forward(self, hidden_states_a, hidden_states_b, temb, image_rotary_emb, attn_mask=None, ipadapter_kwargs_list=None):
|
||||
norm_hidden_states_a, gate_msa_a, shift_mlp_a, scale_mlp_a, gate_mlp_a = self.norm1_a(hidden_states_a, emb=temb)
|
||||
norm_hidden_states_b, gate_msa_b, shift_mlp_b, scale_mlp_b, gate_mlp_b = self.norm1_b(hidden_states_b, emb=temb)
|
||||
|
||||
# Attention
|
||||
attn_output_a, attn_output_b = self.attn(norm_hidden_states_a, norm_hidden_states_b, image_rotary_emb, ipadapter_kwargs_list)
|
||||
attn_output_a, attn_output_b = self.attn(norm_hidden_states_a, norm_hidden_states_b, image_rotary_emb, attn_mask, ipadapter_kwargs_list)
|
||||
|
||||
# Part A
|
||||
hidden_states_a = hidden_states_a + gate_msa_a * attn_output_a
|
||||
@@ -147,7 +147,7 @@ class FluxJointTransformerBlock(torch.nn.Module):
|
||||
hidden_states_b = hidden_states_b + gate_mlp_b * self.ff_b(norm_hidden_states_b)
|
||||
|
||||
return hidden_states_a, hidden_states_b
|
||||
|
||||
|
||||
|
||||
|
||||
class FluxSingleAttention(torch.nn.Module):
|
||||
@@ -184,7 +184,7 @@ class FluxSingleAttention(torch.nn.Module):
|
||||
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim)
|
||||
hidden_states = hidden_states.to(q.dtype)
|
||||
return hidden_states
|
||||
|
||||
|
||||
|
||||
|
||||
class AdaLayerNormSingle(torch.nn.Module):
|
||||
@@ -200,7 +200,7 @@ class AdaLayerNormSingle(torch.nn.Module):
|
||||
shift_msa, scale_msa, gate_msa = emb.chunk(3, dim=1)
|
||||
x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
|
||||
return x, gate_msa
|
||||
|
||||
|
||||
|
||||
|
||||
class FluxSingleTransformerBlock(torch.nn.Module):
|
||||
@@ -225,8 +225,8 @@ class FluxSingleTransformerBlock(torch.nn.Module):
|
||||
xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
|
||||
return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
|
||||
|
||||
|
||||
def process_attention(self, hidden_states, image_rotary_emb, ipadapter_kwargs_list=None):
|
||||
|
||||
def process_attention(self, hidden_states, image_rotary_emb, attn_mask=None, ipadapter_kwargs_list=None):
|
||||
batch_size = hidden_states.shape[0]
|
||||
|
||||
qkv = hidden_states.view(batch_size, -1, 3 * self.num_heads, self.head_dim).transpose(1, 2)
|
||||
@@ -235,7 +235,7 @@ class FluxSingleTransformerBlock(torch.nn.Module):
|
||||
|
||||
q, k = self.apply_rope(q, k, image_rotary_emb)
|
||||
|
||||
hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v)
|
||||
hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
|
||||
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim)
|
||||
hidden_states = hidden_states.to(q.dtype)
|
||||
if ipadapter_kwargs_list is not None:
|
||||
@@ -243,21 +243,21 @@ class FluxSingleTransformerBlock(torch.nn.Module):
|
||||
return hidden_states
|
||||
|
||||
|
||||
def forward(self, hidden_states_a, hidden_states_b, temb, image_rotary_emb, ipadapter_kwargs_list=None):
|
||||
def forward(self, hidden_states_a, hidden_states_b, temb, image_rotary_emb, attn_mask=None, ipadapter_kwargs_list=None):
|
||||
residual = hidden_states_a
|
||||
norm_hidden_states, gate = self.norm(hidden_states_a, emb=temb)
|
||||
hidden_states_a = self.to_qkv_mlp(norm_hidden_states)
|
||||
attn_output, mlp_hidden_states = hidden_states_a[:, :, :self.dim * 3], hidden_states_a[:, :, self.dim * 3:]
|
||||
|
||||
attn_output = self.process_attention(attn_output, image_rotary_emb, ipadapter_kwargs_list)
|
||||
attn_output = self.process_attention(attn_output, image_rotary_emb, attn_mask, ipadapter_kwargs_list)
|
||||
mlp_hidden_states = torch.nn.functional.gelu(mlp_hidden_states, approximate="tanh")
|
||||
|
||||
hidden_states_a = torch.cat([attn_output, mlp_hidden_states], dim=2)
|
||||
hidden_states_a = gate.unsqueeze(1) * self.proj_out(hidden_states_a)
|
||||
hidden_states_a = residual + hidden_states_a
|
||||
|
||||
|
||||
return hidden_states_a, hidden_states_b
|
||||
|
||||
|
||||
|
||||
|
||||
class AdaLayerNormContinuous(torch.nn.Module):
|
||||
@@ -300,7 +300,7 @@ class FluxDiT(torch.nn.Module):
|
||||
def unpatchify(self, hidden_states, height, width):
|
||||
hidden_states = rearrange(hidden_states, "B (H W) (C P Q) -> B C (H P) (W Q)", P=2, Q=2, H=height//2, W=width//2)
|
||||
return hidden_states
|
||||
|
||||
|
||||
|
||||
def prepare_image_ids(self, latents):
|
||||
batch_size, _, height, width = latents.shape
|
||||
@@ -317,7 +317,7 @@ class FluxDiT(torch.nn.Module):
|
||||
latent_image_ids = latent_image_ids.to(device=latents.device, dtype=latents.dtype)
|
||||
|
||||
return latent_image_ids
|
||||
|
||||
|
||||
|
||||
def tiled_forward(
|
||||
self,
|
||||
@@ -338,11 +338,75 @@ class FluxDiT(torch.nn.Module):
|
||||
return hidden_states
|
||||
|
||||
|
||||
def construct_mask(self, entity_masks, prompt_seq_len, image_seq_len):
|
||||
N = len(entity_masks)
|
||||
batch_size = entity_masks[0].shape[0]
|
||||
total_seq_len = N * prompt_seq_len + image_seq_len
|
||||
patched_masks = [self.patchify(entity_masks[i]) for i in range(N)]
|
||||
attention_mask = torch.ones((batch_size, total_seq_len, total_seq_len), dtype=torch.bool).to(device=entity_masks[0].device)
|
||||
|
||||
image_start = N * prompt_seq_len
|
||||
image_end = N * prompt_seq_len + image_seq_len
|
||||
# prompt-image mask
|
||||
for i in range(N):
|
||||
prompt_start = i * prompt_seq_len
|
||||
prompt_end = (i + 1) * prompt_seq_len
|
||||
image_mask = torch.sum(patched_masks[i], dim=-1) > 0
|
||||
image_mask = image_mask.unsqueeze(1).repeat(1, prompt_seq_len, 1)
|
||||
# prompt update with image
|
||||
attention_mask[:, prompt_start:prompt_end, image_start:image_end] = image_mask
|
||||
# image update with prompt
|
||||
attention_mask[:, image_start:image_end, prompt_start:prompt_end] = image_mask.transpose(1, 2)
|
||||
# prompt-prompt mask
|
||||
for i in range(N):
|
||||
for j in range(N):
|
||||
if i != j:
|
||||
prompt_start_i = i * prompt_seq_len
|
||||
prompt_end_i = (i + 1) * prompt_seq_len
|
||||
prompt_start_j = j * prompt_seq_len
|
||||
prompt_end_j = (j + 1) * prompt_seq_len
|
||||
attention_mask[:, prompt_start_i:prompt_end_i, prompt_start_j:prompt_end_j] = False
|
||||
|
||||
attention_mask = attention_mask.float()
|
||||
attention_mask[attention_mask == 0] = float('-inf')
|
||||
attention_mask[attention_mask == 1] = 0
|
||||
return attention_mask
|
||||
|
||||
|
||||
def process_entity_masks(self, hidden_states, prompt_emb, entity_prompt_emb, entity_masks, text_ids, image_ids):
|
||||
repeat_dim = hidden_states.shape[1]
|
||||
max_masks = 0
|
||||
attention_mask = None
|
||||
prompt_embs = [prompt_emb]
|
||||
if entity_masks is not None:
|
||||
# entity_masks
|
||||
batch_size, max_masks = entity_masks.shape[0], entity_masks.shape[1]
|
||||
entity_masks = entity_masks.repeat(1, 1, repeat_dim, 1, 1)
|
||||
entity_masks = [entity_masks[:, i, None].squeeze(1) for i in range(max_masks)]
|
||||
# global mask
|
||||
global_mask = torch.ones_like(entity_masks[0]).to(device=hidden_states.device, dtype=hidden_states.dtype)
|
||||
entity_masks = entity_masks + [global_mask] # append global to last
|
||||
# attention mask
|
||||
attention_mask = self.construct_mask(entity_masks, prompt_emb.shape[1], hidden_states.shape[1])
|
||||
attention_mask = attention_mask.to(device=hidden_states.device, dtype=hidden_states.dtype)
|
||||
attention_mask = attention_mask.unsqueeze(1)
|
||||
# embds: n_masks * b * seq * d
|
||||
local_embs = [entity_prompt_emb[:, i, None].squeeze(1) for i in range(max_masks)]
|
||||
prompt_embs = local_embs + prompt_embs # append global to last
|
||||
prompt_embs = [self.context_embedder(prompt_emb) for prompt_emb in prompt_embs]
|
||||
prompt_emb = torch.cat(prompt_embs, dim=1)
|
||||
|
||||
# positional embedding
|
||||
text_ids = torch.cat([text_ids] * (max_masks + 1), dim=1)
|
||||
image_rotary_emb = self.pos_embedder(torch.cat((text_ids, image_ids), dim=1))
|
||||
return prompt_emb, image_rotary_emb, attention_mask
|
||||
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states,
|
||||
timestep, prompt_emb, pooled_prompt_emb, guidance, text_ids, image_ids=None,
|
||||
tiled=False, tile_size=128, tile_stride=64,
|
||||
tiled=False, tile_size=128, tile_stride=64, entity_prompt_emb=None, entity_masks=None,
|
||||
use_gradient_checkpointing=False,
|
||||
**kwargs
|
||||
):
|
||||
@@ -353,46 +417,51 @@ class FluxDiT(torch.nn.Module):
|
||||
tile_size=tile_size, tile_stride=tile_stride,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
|
||||
if image_ids is None:
|
||||
image_ids = self.prepare_image_ids(hidden_states)
|
||||
|
||||
|
||||
conditioning = self.time_embedder(timestep, hidden_states.dtype) + self.pooled_text_embedder(pooled_prompt_emb)
|
||||
if self.guidance_embedder is not None:
|
||||
guidance = guidance * 1000
|
||||
conditioning = conditioning + self.guidance_embedder(guidance, hidden_states.dtype)
|
||||
prompt_emb = self.context_embedder(prompt_emb)
|
||||
image_rotary_emb = self.pos_embedder(torch.cat((text_ids, image_ids), dim=1))
|
||||
|
||||
height, width = hidden_states.shape[-2:]
|
||||
hidden_states = self.patchify(hidden_states)
|
||||
hidden_states = self.x_embedder(hidden_states)
|
||||
|
||||
|
||||
if entity_prompt_emb is not None and entity_masks is not None:
|
||||
prompt_emb, image_rotary_emb, attention_mask = self.process_entity_masks(hidden_states, prompt_emb, entity_prompt_emb, entity_masks, text_ids, image_ids)
|
||||
else:
|
||||
prompt_emb = self.context_embedder(prompt_emb)
|
||||
image_rotary_emb = self.pos_embedder(torch.cat((text_ids, image_ids), dim=1))
|
||||
attention_mask = None
|
||||
|
||||
def create_custom_forward(module):
|
||||
def custom_forward(*inputs):
|
||||
return module(*inputs)
|
||||
return custom_forward
|
||||
|
||||
|
||||
for block in self.blocks:
|
||||
if self.training and use_gradient_checkpointing:
|
||||
hidden_states, prompt_emb = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(block),
|
||||
hidden_states, prompt_emb, conditioning, image_rotary_emb,
|
||||
hidden_states, prompt_emb, conditioning, image_rotary_emb, attention_mask,
|
||||
use_reentrant=False,
|
||||
)
|
||||
else:
|
||||
hidden_states, prompt_emb = block(hidden_states, prompt_emb, conditioning, image_rotary_emb)
|
||||
hidden_states, prompt_emb = block(hidden_states, prompt_emb, conditioning, image_rotary_emb, attention_mask)
|
||||
|
||||
hidden_states = torch.cat([prompt_emb, hidden_states], dim=1)
|
||||
for block in self.single_blocks:
|
||||
if self.training and use_gradient_checkpointing:
|
||||
hidden_states, prompt_emb = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(block),
|
||||
hidden_states, prompt_emb, conditioning, image_rotary_emb,
|
||||
hidden_states, prompt_emb, conditioning, image_rotary_emb, attention_mask,
|
||||
use_reentrant=False,
|
||||
)
|
||||
else:
|
||||
hidden_states, prompt_emb = block(hidden_states, prompt_emb, conditioning, image_rotary_emb)
|
||||
hidden_states, prompt_emb = block(hidden_states, prompt_emb, conditioning, image_rotary_emb, attention_mask)
|
||||
hidden_states = hidden_states[:, prompt_emb.shape[1]:]
|
||||
|
||||
hidden_states = self.final_norm_out(hidden_states, conditioning)
|
||||
@@ -400,7 +469,7 @@ class FluxDiT(torch.nn.Module):
|
||||
hidden_states = self.unpatchify(hidden_states, height, width)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
|
||||
def quantize(self):
|
||||
def cast_to(weight, dtype=None, device=None, copy=False):
|
||||
@@ -440,16 +509,16 @@ class FluxDiT(torch.nn.Module):
|
||||
class Linear(torch.nn.Linear):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
|
||||
def forward(self,input,**kwargs):
|
||||
weight,bias= cast_bias_weight(self,input)
|
||||
return torch.nn.functional.linear(input,weight,bias)
|
||||
|
||||
|
||||
class RMSNorm(torch.nn.Module):
|
||||
def __init__(self, module):
|
||||
super().__init__()
|
||||
self.module = module
|
||||
|
||||
|
||||
def forward(self,hidden_states,**kwargs):
|
||||
weight= cast_weight(self.module,hidden_states)
|
||||
input_dtype = hidden_states.dtype
|
||||
@@ -457,7 +526,7 @@ class FluxDiT(torch.nn.Module):
|
||||
hidden_states = hidden_states * torch.rsqrt(variance + self.module.eps)
|
||||
hidden_states = hidden_states.to(input_dtype) * weight
|
||||
return hidden_states
|
||||
|
||||
|
||||
def replace_layer(model):
|
||||
for name, module in model.named_children():
|
||||
if isinstance(module, torch.nn.Linear):
|
||||
@@ -483,7 +552,6 @@ class FluxDiT(torch.nn.Module):
|
||||
@staticmethod
|
||||
def state_dict_converter():
|
||||
return FluxDiTStateDictConverter()
|
||||
|
||||
|
||||
|
||||
class FluxDiTStateDictConverter:
|
||||
@@ -587,7 +655,7 @@ class FluxDiTStateDictConverter:
|
||||
state_dict_.pop(name.replace(f".{component}_to_q.", f".{component}_to_k."))
|
||||
state_dict_.pop(name.replace(f".{component}_to_q.", f".{component}_to_v."))
|
||||
return state_dict_
|
||||
|
||||
|
||||
def from_civitai(self, state_dict):
|
||||
rename_dict = {
|
||||
"time_in.in_layer.bias": "time_embedder.timestep_embedder.0.bias",
|
||||
|
||||
@@ -366,17 +366,21 @@ class ModelManager:
|
||||
|
||||
|
||||
def load_lora(self, file_path="", state_dict={}, lora_alpha=1.0):
|
||||
print(f"Loading LoRA models from file: {file_path}")
|
||||
if len(state_dict) == 0:
|
||||
state_dict = load_state_dict(file_path)
|
||||
for model_name, model, model_path in zip(self.model_name, self.model, self.model_path):
|
||||
for lora in get_lora_loaders():
|
||||
match_results = lora.match(model, state_dict)
|
||||
if match_results is not None:
|
||||
print(f" Adding LoRA to {model_name} ({model_path}).")
|
||||
lora_prefix, model_resource = match_results
|
||||
lora.load(model, state_dict, lora_prefix, alpha=lora_alpha, model_resource=model_resource)
|
||||
break
|
||||
if isinstance(file_path, list):
|
||||
for file_path_ in file_path:
|
||||
self.load_lora(file_path_, state_dict=state_dict, lora_alpha=lora_alpha)
|
||||
else:
|
||||
print(f"Loading LoRA models from file: {file_path}")
|
||||
if len(state_dict) == 0:
|
||||
state_dict = load_state_dict(file_path)
|
||||
for model_name, model, model_path in zip(self.model_name, self.model, self.model_path):
|
||||
for lora in get_lora_loaders():
|
||||
match_results = lora.match(model, state_dict)
|
||||
if match_results is not None:
|
||||
print(f" Adding LoRA to {model_name} ({model_path}).")
|
||||
lora_prefix, model_resource = match_results
|
||||
lora.load(model, state_dict, lora_prefix, alpha=lora_alpha, model_resource=model_resource)
|
||||
break
|
||||
|
||||
|
||||
def load_model(self, file_path, model_names=None, device=None, torch_dtype=None):
|
||||
|
||||
Reference in New Issue
Block a user