Mirror of https://github.com/modelscope/DiffSynth-Studio.git (synced 2026-03-19 06:48:12 +00:00)
diffusion skills framework
@@ -364,78 +364,7 @@ class Flux2FeedForward(nn.Module):
        return x


class Flux2AttnProcessor:
    _attention_backend = None
    _parallel_config = None

    def __init__(self):
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError(f"{self.__class__.__name__} requires PyTorch 2.0. Please upgrade your pytorch version.")

    def __call__(
        self,
        attn: "Flux2Attention",
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        image_rotary_emb: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        query, key, value, encoder_query, encoder_key, encoder_value = _get_qkv_projections(
            attn, hidden_states, encoder_hidden_states
        )

        query = query.unflatten(-1, (attn.heads, -1))
        key = key.unflatten(-1, (attn.heads, -1))
        value = value.unflatten(-1, (attn.heads, -1))

        query = attn.norm_q(query)
        key = attn.norm_k(key)

        if attn.added_kv_proj_dim is not None:
            encoder_query = encoder_query.unflatten(-1, (attn.heads, -1))
            encoder_key = encoder_key.unflatten(-1, (attn.heads, -1))
            encoder_value = encoder_value.unflatten(-1, (attn.heads, -1))

            encoder_query = attn.norm_added_q(encoder_query)
            encoder_key = attn.norm_added_k(encoder_key)
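
            # The text (encoder) tokens are prepended to the image tokens below, so a single
            # attention call attends jointly over the text and image streams.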
            query = torch.cat([encoder_query, query], dim=1)
            key = torch.cat([encoder_key, key], dim=1)
            value = torch.cat([encoder_value, value], dim=1)

        if image_rotary_emb is not None:
            query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1)
            key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1)

        query, key, value = query.to(hidden_states.dtype), key.to(hidden_states.dtype), value.to(hidden_states.dtype)
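        # The "b s n d" pattern strings indicate tensors laid out as
        # (batch, sequence, heads, head_dim) going into and out of attention_forward.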
        hidden_states = attention_forward(
            query,
            key,
            value,
            q_pattern="b s n d", k_pattern="b s n d", v_pattern="b s n d", out_pattern="b s n d",
        )
        hidden_states = hidden_states.flatten(2, 3)
        hidden_states = hidden_states.to(query.dtype)

        if encoder_hidden_states is not None:
            encoder_hidden_states, hidden_states = hidden_states.split_with_sizes(
                [encoder_hidden_states.shape[1], hidden_states.shape[1] - encoder_hidden_states.shape[1]], dim=1
            )
            encoder_hidden_states = attn.to_add_out(encoder_hidden_states)

        hidden_states = attn.to_out[0](hidden_states)
        hidden_states = attn.to_out[1](hidden_states)

        if encoder_hidden_states is not None:
            return hidden_states, encoder_hidden_states
        else:
            return hidden_states


class Flux2Attention(torch.nn.Module):
    _default_processor_cls = Flux2AttnProcessor
    _available_processors = [Flux2AttnProcessor]

    def __init__(
        self,
        query_dim: int,
@@ -449,7 +378,6 @@ class Flux2Attention(torch.nn.Module):
        eps: float = 1e-5,
        out_dim: int = None,
        elementwise_affine: bool = True,
        processor=None,
    ):
        super().__init__()

@@ -485,59 +413,45 @@ class Flux2Attention(torch.nn.Module):
            self.add_v_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
            self.to_add_out = torch.nn.Linear(self.inner_dim, query_dim, bias=out_bias)

        if processor is None:
            processor = self._default_processor_cls()
        self.processor = processor

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        image_rotary_emb: Optional[torch.Tensor] = None,
        kv_cache = None,
        **kwargs,
    ) -> torch.Tensor:
        attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys())
        kwargs = {k: w for k, w in kwargs.items() if k in attn_parameters}
        return self.processor(self, hidden_states, encoder_hidden_states, attention_mask, image_rotary_emb, **kwargs)


class Flux2ParallelSelfAttnProcessor:
    _attention_backend = None
    _parallel_config = None

    def __init__(self):
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError(f"{self.__class__.__name__} requires PyTorch 2.0. Please upgrade your pytorch version.")

    def __call__(
        self,
        attn: "Flux2ParallelSelfAttention",
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        image_rotary_emb: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        # Parallel in (QKV + MLP in) projection
        hidden_states = attn.to_qkv_mlp_proj(hidden_states)
        qkv, mlp_hidden_states = torch.split(
            hidden_states, [3 * attn.inner_dim, attn.mlp_hidden_dim * attn.mlp_mult_factor], dim=-1
        query, key, value, encoder_query, encoder_key, encoder_value = _get_qkv_projections(
            self, hidden_states, encoder_hidden_states
        )

        # Handle the attention logic
        query, key, value = qkv.chunk(3, dim=-1)
        query = query.unflatten(-1, (self.heads, -1))
        key = key.unflatten(-1, (self.heads, -1))
        value = value.unflatten(-1, (self.heads, -1))

        query = query.unflatten(-1, (attn.heads, -1))
        key = key.unflatten(-1, (attn.heads, -1))
        value = value.unflatten(-1, (attn.heads, -1))
        query = self.norm_q(query)
        key = self.norm_k(key)

        query = attn.norm_q(query)
        key = attn.norm_k(key)
        if self.added_kv_proj_dim is not None:
            encoder_query = encoder_query.unflatten(-1, (self.heads, -1))
            encoder_key = encoder_key.unflatten(-1, (self.heads, -1))
            encoder_value = encoder_value.unflatten(-1, (self.heads, -1))

            encoder_query = self.norm_added_q(encoder_query)
            encoder_key = self.norm_added_k(encoder_key)

            query = torch.cat([encoder_query, query], dim=1)
            key = torch.cat([encoder_key, key], dim=1)
            value = torch.cat([encoder_value, value], dim=1)

        if image_rotary_emb is not None:
            query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1)
            key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1)

        query, key, value = query.to(hidden_states.dtype), key.to(hidden_states.dtype), value.to(hidden_states.dtype)
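        # kv_cache, when provided, appears to hold an indexable (key, value) pair of
        # precomputed states that are appended to the current keys/values along the
        # sequence dimension.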
        if kv_cache is not None:
            key = torch.concat([key, kv_cache[0]], dim=1)
            value = torch.concat([value, kv_cache[1]], dim=1)
        hidden_states = attention_forward(
            query,
            key,
@@ -547,30 +461,22 @@ class Flux2ParallelSelfAttnProcessor:
        hidden_states = hidden_states.flatten(2, 3)
        hidden_states = hidden_states.to(query.dtype)

        # Handle the feedforward (FF) logic
        mlp_hidden_states = attn.mlp_act_fn(mlp_hidden_states)
        if encoder_hidden_states is not None:
            encoder_hidden_states, hidden_states = hidden_states.split_with_sizes(
                [encoder_hidden_states.shape[1], hidden_states.shape[1] - encoder_hidden_states.shape[1]], dim=1
            )
            encoder_hidden_states = self.to_add_out(encoder_hidden_states)

        # Concatenate and parallel output projection
        hidden_states = torch.cat([hidden_states, mlp_hidden_states], dim=-1)
        hidden_states = attn.to_out(hidden_states)
        hidden_states = self.to_out[0](hidden_states)
        hidden_states = self.to_out[1](hidden_states)

        return hidden_states
        if encoder_hidden_states is not None:
            return hidden_states, encoder_hidden_states
        else:
            return hidden_states


class Flux2ParallelSelfAttention(torch.nn.Module):
    """
    Flux 2 parallel self-attention for the Flux 2 single-stream transformer blocks.

    This implements a parallel transformer block, where the attention QKV projections are fused to the feedforward (FF)
    input projections, and the attention output projections are fused to the FF output projections. See the [ViT-22B
    paper](https://arxiv.org/abs/2302.05442) for a visual depiction of this type of transformer block.
    """

    _default_processor_cls = Flux2ParallelSelfAttnProcessor
    _available_processors = [Flux2ParallelSelfAttnProcessor]
    # Does not support QKV fusion as the QKV projections are always fused
    _supports_qkv_fusion = False

    def __init__(
        self,
        query_dim: int,
@@ -614,20 +520,54 @@ class Flux2ParallelSelfAttention(torch.nn.Module):
        # Fused attention output projection + MLP output projection
        self.to_out = torch.nn.Linear(self.inner_dim + self.mlp_hidden_dim, self.out_dim, bias=out_bias)

        if processor is None:
            processor = self._default_processor_cls()
        self.processor = processor

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        image_rotary_emb: Optional[torch.Tensor] = None,
        kv_cache = None,
        **kwargs,
    ) -> torch.Tensor:
        attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys())
        kwargs = {k: w for k, w in kwargs.items() if k in attn_parameters}
        return self.processor(self, hidden_states, attention_mask, image_rotary_emb, **kwargs)
        # Parallel in (QKV + MLP in) projection
        hidden_states = self.to_qkv_mlp_proj(hidden_states)
        qkv, mlp_hidden_states = torch.split(
            hidden_states, [3 * self.inner_dim, self.mlp_hidden_dim * self.mlp_mult_factor], dim=-1
        )

        # Handle the attention logic
        query, key, value = qkv.chunk(3, dim=-1)

        query = query.unflatten(-1, (self.heads, -1))
        key = key.unflatten(-1, (self.heads, -1))
        value = value.unflatten(-1, (self.heads, -1))

        query = self.norm_q(query)
        key = self.norm_k(key)

        if image_rotary_emb is not None:
            query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1)
            key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1)

        if kv_cache is not None:
            key = torch.concat([key, kv_cache[0]], dim=1)
            value = torch.concat([value, kv_cache[1]], dim=1)
        hidden_states = attention_forward(
            query,
            key,
            value,
            q_pattern="b s n d", k_pattern="b s n d", v_pattern="b s n d", out_pattern="b s n d",
        )
        hidden_states = hidden_states.flatten(2, 3)
        hidden_states = hidden_states.to(query.dtype)

        # Handle the feedforward (FF) logic
        mlp_hidden_states = self.mlp_act_fn(mlp_hidden_states)

        # Concatenate and parallel output projection
        hidden_states = torch.cat([hidden_states, mlp_hidden_states], dim=-1)
        hidden_states = self.to_out(hidden_states)

        return hidden_states


class Flux2SingleTransformerBlock(nn.Module):
@@ -657,7 +597,6 @@ class Flux2SingleTransformerBlock(nn.Module):
            eps=eps,
            mlp_ratio=mlp_ratio,
            mlp_mult_factor=2,
            processor=Flux2ParallelSelfAttnProcessor(),
        )

    def forward(
@@ -669,6 +608,7 @@ class Flux2SingleTransformerBlock(nn.Module):
        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
        split_hidden_states: bool = False,
        text_seq_len: Optional[int] = None,
        kv_cache = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # If encoder_hidden_states is None, hidden_states is assumed to have encoder_hidden_states already
        # concatenated
@@ -685,6 +625,7 @@ class Flux2SingleTransformerBlock(nn.Module):
        attn_output = self.attn(
            hidden_states=norm_hidden_states,
            image_rotary_emb=image_rotary_emb,
            kv_cache=kv_cache,
            **joint_attention_kwargs,
        )

@@ -725,7 +666,6 @@ class Flux2TransformerBlock(nn.Module):
            added_proj_bias=bias,
            out_bias=bias,
            eps=eps,
            processor=Flux2AttnProcessor(),
        )

        self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps)
@@ -742,6 +682,7 @@ class Flux2TransformerBlock(nn.Module):
        temb_mod_params_txt: Tuple[Tuple[torch.Tensor, torch.Tensor, torch.Tensor], ...],
        image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
        kv_cache = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        joint_attention_kwargs = joint_attention_kwargs or {}

@@ -762,6 +703,7 @@ class Flux2TransformerBlock(nn.Module):
            hidden_states=norm_hidden_states,
            encoder_hidden_states=norm_encoder_hidden_states,
            image_rotary_emb=image_rotary_emb,
            kv_cache=kv_cache,
            **joint_attention_kwargs,
        )

@@ -969,6 +911,7 @@ class Flux2DiT(torch.nn.Module):
        txt_ids: torch.Tensor = None,
        guidance: torch.Tensor = None,
        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
        kv_cache = None,
        use_gradient_checkpointing=False,
        use_gradient_checkpointing_offload=False,
    ):
@@ -1013,7 +956,7 @@ class Flux2DiT(torch.nn.Module):
        )

        # 4. Double Stream Transformer Blocks
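        # Note: when a kv_cache dict is supplied, each block reads its own entry,
        # keyed "double_{block_id}" for double-stream blocks and "single_{block_id}"
        # for single-stream blocks (see the kv_cache.get(...) calls below).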
        for index_block, block in enumerate(self.transformer_blocks):
        for block_id, block in enumerate(self.transformer_blocks):
            encoder_hidden_states, hidden_states = gradient_checkpoint_forward(
                block,
                use_gradient_checkpointing=use_gradient_checkpointing,
@@ -1024,12 +967,13 @@ class Flux2DiT(torch.nn.Module):
                temb_mod_params_txt=double_stream_mod_txt,
                image_rotary_emb=concat_rotary_emb,
                joint_attention_kwargs=joint_attention_kwargs,
                kv_cache=None if kv_cache is None else kv_cache.get(f"double_{block_id}"),
            )
        # Concatenate text and image streams for single-block inference
        hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)

        # 5. Single Stream Transformer Blocks
        for index_block, block in enumerate(self.single_transformer_blocks):
        for block_id, block in enumerate(self.single_transformer_blocks):
            hidden_states = gradient_checkpoint_forward(
                block,
                use_gradient_checkpointing=use_gradient_checkpointing,
@@ -1039,6 +983,7 @@ class Flux2DiT(torch.nn.Module):
                temb_mod_params=single_stream_mod,
                image_rotary_emb=concat_rotary_emb,
                joint_attention_kwargs=joint_attention_kwargs,
                kv_cache=None if kv_cache is None else kv_cache.get(f"single_{block_id}"),
            )
        # Remove text tokens from concatenated stream
        hidden_states = hidden_states[:, num_txt_tokens:, ...]
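
A minimal sketch (not part of this commit) of how the new kv_cache argument could be populated, assuming each transformer block expects an indexable (key, value) pair of precomputed states keyed "double_{i}" / "single_{i}", with tensors laid out as (batch, cached_seq_len, num_heads, head_dim); the block counts and shapes below are hypothetical:

    import torch

    # Hypothetical sizes; real values depend on the Flux2DiT configuration.
    num_double_blocks, num_single_blocks = 8, 38
    batch, cached_len, heads, head_dim = 1, 512, 24, 128

    def empty_entry():
        # One cached (key, value) pair, read inside the attention forward as
        # kv_cache[0] / kv_cache[1] and concatenated along the sequence dimension.
        return (
            torch.zeros(batch, cached_len, heads, head_dim),
            torch.zeros(batch, cached_len, heads, head_dim),
        )

    kv_cache = {f"double_{i}": empty_entry() for i in range(num_double_blocks)}
    kv_cache.update({f"single_{i}": empty_entry() for i in range(num_single_blocks)})

    # The dict would then be passed as Flux2DiT.forward(..., kv_cache=kv_cache); the forward
    # routes each entry to its block via kv_cache.get(f"double_{block_id}") or kv_cache.get(f"single_{block_id}").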