diffusion skills framework

This commit is contained in:
Artiprocher
2026-03-17 13:34:25 +08:00
parent 7a80f10fa4
commit f88b99cb4f
11 changed files with 422 additions and 138 deletions

View File

@@ -364,78 +364,7 @@ class Flux2FeedForward(nn.Module):
return x
class Flux2AttnProcessor:
_attention_backend = None
_parallel_config = None
def __init__(self):
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError(f"{self.__class__.__name__} requires PyTorch 2.0. Please upgrade your pytorch version.")
def __call__(
self,
attn: "Flux2Attention",
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor = None,
attention_mask: Optional[torch.Tensor] = None,
image_rotary_emb: Optional[torch.Tensor] = None,
) -> torch.Tensor:
query, key, value, encoder_query, encoder_key, encoder_value = _get_qkv_projections(
attn, hidden_states, encoder_hidden_states
)
query = query.unflatten(-1, (attn.heads, -1))
key = key.unflatten(-1, (attn.heads, -1))
value = value.unflatten(-1, (attn.heads, -1))
query = attn.norm_q(query)
key = attn.norm_k(key)
if attn.added_kv_proj_dim is not None:
encoder_query = encoder_query.unflatten(-1, (attn.heads, -1))
encoder_key = encoder_key.unflatten(-1, (attn.heads, -1))
encoder_value = encoder_value.unflatten(-1, (attn.heads, -1))
encoder_query = attn.norm_added_q(encoder_query)
encoder_key = attn.norm_added_k(encoder_key)
query = torch.cat([encoder_query, query], dim=1)
key = torch.cat([encoder_key, key], dim=1)
value = torch.cat([encoder_value, value], dim=1)
if image_rotary_emb is not None:
query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1)
key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1)
query, key, value = query.to(hidden_states.dtype), key.to(hidden_states.dtype), value.to(hidden_states.dtype)
hidden_states = attention_forward(
query,
key,
value,
q_pattern="b s n d", k_pattern="b s n d", v_pattern="b s n d", out_pattern="b s n d",
)
hidden_states = hidden_states.flatten(2, 3)
hidden_states = hidden_states.to(query.dtype)
if encoder_hidden_states is not None:
encoder_hidden_states, hidden_states = hidden_states.split_with_sizes(
[encoder_hidden_states.shape[1], hidden_states.shape[1] - encoder_hidden_states.shape[1]], dim=1
)
encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
hidden_states = attn.to_out[0](hidden_states)
hidden_states = attn.to_out[1](hidden_states)
if encoder_hidden_states is not None:
return hidden_states, encoder_hidden_states
else:
return hidden_states
class Flux2Attention(torch.nn.Module):
_default_processor_cls = Flux2AttnProcessor
_available_processors = [Flux2AttnProcessor]
def __init__(
self,
query_dim: int,
@@ -449,7 +378,6 @@ class Flux2Attention(torch.nn.Module):
eps: float = 1e-5,
out_dim: int = None,
elementwise_affine: bool = True,
processor=None,
):
super().__init__()
@@ -485,59 +413,45 @@ class Flux2Attention(torch.nn.Module):
self.add_v_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
self.to_add_out = torch.nn.Linear(self.inner_dim, query_dim, bias=out_bias)
if processor is None:
processor = self._default_processor_cls()
self.processor = processor
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
image_rotary_emb: Optional[torch.Tensor] = None,
kv_cache = None,
**kwargs,
) -> torch.Tensor:
attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys())
kwargs = {k: w for k, w in kwargs.items() if k in attn_parameters}
return self.processor(self, hidden_states, encoder_hidden_states, attention_mask, image_rotary_emb, **kwargs)
class Flux2ParallelSelfAttnProcessor:
_attention_backend = None
_parallel_config = None
def __init__(self):
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError(f"{self.__class__.__name__} requires PyTorch 2.0. Please upgrade your pytorch version.")
def __call__(
self,
attn: "Flux2ParallelSelfAttention",
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
image_rotary_emb: Optional[torch.Tensor] = None,
) -> torch.Tensor:
# Parallel in (QKV + MLP in) projection
hidden_states = attn.to_qkv_mlp_proj(hidden_states)
qkv, mlp_hidden_states = torch.split(
hidden_states, [3 * attn.inner_dim, attn.mlp_hidden_dim * attn.mlp_mult_factor], dim=-1
query, key, value, encoder_query, encoder_key, encoder_value = _get_qkv_projections(
self, hidden_states, encoder_hidden_states
)
# Handle the attention logic
query, key, value = qkv.chunk(3, dim=-1)
query = query.unflatten(-1, (self.heads, -1))
key = key.unflatten(-1, (self.heads, -1))
value = value.unflatten(-1, (self.heads, -1))
query = query.unflatten(-1, (attn.heads, -1))
key = key.unflatten(-1, (attn.heads, -1))
value = value.unflatten(-1, (attn.heads, -1))
query = self.norm_q(query)
key = self.norm_k(key)
query = attn.norm_q(query)
key = attn.norm_k(key)
if self.added_kv_proj_dim is not None:
encoder_query = encoder_query.unflatten(-1, (self.heads, -1))
encoder_key = encoder_key.unflatten(-1, (self.heads, -1))
encoder_value = encoder_value.unflatten(-1, (self.heads, -1))
encoder_query = self.norm_added_q(encoder_query)
encoder_key = self.norm_added_k(encoder_key)
query = torch.cat([encoder_query, query], dim=1)
key = torch.cat([encoder_key, key], dim=1)
value = torch.cat([encoder_value, value], dim=1)
if image_rotary_emb is not None:
query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1)
key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1)
query, key, value = query.to(hidden_states.dtype), key.to(hidden_states.dtype), value.to(hidden_states.dtype)
if kv_cache is not None:
key = torch.concat([key, kv_cache[0]], dim=1)
value = torch.concat([value, kv_cache[1]], dim=1)
hidden_states = attention_forward(
query,
key,
@@ -547,30 +461,22 @@ class Flux2ParallelSelfAttnProcessor:
hidden_states = hidden_states.flatten(2, 3)
hidden_states = hidden_states.to(query.dtype)
# Handle the feedforward (FF) logic
mlp_hidden_states = attn.mlp_act_fn(mlp_hidden_states)
if encoder_hidden_states is not None:
encoder_hidden_states, hidden_states = hidden_states.split_with_sizes(
[encoder_hidden_states.shape[1], hidden_states.shape[1] - encoder_hidden_states.shape[1]], dim=1
)
encoder_hidden_states = self.to_add_out(encoder_hidden_states)
# Concatenate and parallel output projection
hidden_states = torch.cat([hidden_states, mlp_hidden_states], dim=-1)
hidden_states = attn.to_out(hidden_states)
hidden_states = self.to_out[0](hidden_states)
hidden_states = self.to_out[1](hidden_states)
return hidden_states
if encoder_hidden_states is not None:
return hidden_states, encoder_hidden_states
else:
return hidden_states
class Flux2ParallelSelfAttention(torch.nn.Module):
"""
Flux 2 parallel self-attention for the Flux 2 single-stream transformer blocks.
This implements a parallel transformer block, where the attention QKV projections are fused to the feedforward (FF)
input projections, and the attention output projections are fused to the FF output projections. See the [ViT-22B
paper](https://arxiv.org/abs/2302.05442) for a visual depiction of this type of transformer block.
"""
_default_processor_cls = Flux2ParallelSelfAttnProcessor
_available_processors = [Flux2ParallelSelfAttnProcessor]
# Does not support QKV fusion as the QKV projections are always fused
_supports_qkv_fusion = False
def __init__(
self,
query_dim: int,
@@ -614,20 +520,54 @@ class Flux2ParallelSelfAttention(torch.nn.Module):
# Fused attention output projection + MLP output projection
self.to_out = torch.nn.Linear(self.inner_dim + self.mlp_hidden_dim, self.out_dim, bias=out_bias)
if processor is None:
processor = self._default_processor_cls()
self.processor = processor
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
image_rotary_emb: Optional[torch.Tensor] = None,
kv_cache = None,
**kwargs,
) -> torch.Tensor:
attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys())
kwargs = {k: w for k, w in kwargs.items() if k in attn_parameters}
return self.processor(self, hidden_states, attention_mask, image_rotary_emb, **kwargs)
# Parallel in (QKV + MLP in) projection
hidden_states = self.to_qkv_mlp_proj(hidden_states)
qkv, mlp_hidden_states = torch.split(
hidden_states, [3 * self.inner_dim, self.mlp_hidden_dim * self.mlp_mult_factor], dim=-1
)
# Handle the attention logic
query, key, value = qkv.chunk(3, dim=-1)
query = query.unflatten(-1, (self.heads, -1))
key = key.unflatten(-1, (self.heads, -1))
value = value.unflatten(-1, (self.heads, -1))
query = self.norm_q(query)
key = self.norm_k(key)
if image_rotary_emb is not None:
query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1)
key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1)
if kv_cache is not None:
key = torch.concat([key, kv_cache[0]], dim=1)
value = torch.concat([value, kv_cache[1]], dim=1)
hidden_states = attention_forward(
query,
key,
value,
q_pattern="b s n d", k_pattern="b s n d", v_pattern="b s n d", out_pattern="b s n d",
)
hidden_states = hidden_states.flatten(2, 3)
hidden_states = hidden_states.to(query.dtype)
# Handle the feedforward (FF) logic
mlp_hidden_states = self.mlp_act_fn(mlp_hidden_states)
# Concatenate and parallel output projection
hidden_states = torch.cat([hidden_states, mlp_hidden_states], dim=-1)
hidden_states = self.to_out(hidden_states)
return hidden_states
class Flux2SingleTransformerBlock(nn.Module):
@@ -657,7 +597,6 @@ class Flux2SingleTransformerBlock(nn.Module):
eps=eps,
mlp_ratio=mlp_ratio,
mlp_mult_factor=2,
processor=Flux2ParallelSelfAttnProcessor(),
)
def forward(
@@ -669,6 +608,7 @@ class Flux2SingleTransformerBlock(nn.Module):
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
split_hidden_states: bool = False,
text_seq_len: Optional[int] = None,
kv_cache = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
# If encoder_hidden_states is None, hidden_states is assumed to have encoder_hidden_states already
# concatenated
@@ -685,6 +625,7 @@ class Flux2SingleTransformerBlock(nn.Module):
attn_output = self.attn(
hidden_states=norm_hidden_states,
image_rotary_emb=image_rotary_emb,
kv_cache=kv_cache,
**joint_attention_kwargs,
)
@@ -725,7 +666,6 @@ class Flux2TransformerBlock(nn.Module):
added_proj_bias=bias,
out_bias=bias,
eps=eps,
processor=Flux2AttnProcessor(),
)
self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps)
@@ -742,6 +682,7 @@ class Flux2TransformerBlock(nn.Module):
temb_mod_params_txt: Tuple[Tuple[torch.Tensor, torch.Tensor, torch.Tensor], ...],
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
kv_cache = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
joint_attention_kwargs = joint_attention_kwargs or {}
@@ -762,6 +703,7 @@ class Flux2TransformerBlock(nn.Module):
hidden_states=norm_hidden_states,
encoder_hidden_states=norm_encoder_hidden_states,
image_rotary_emb=image_rotary_emb,
kv_cache=kv_cache,
**joint_attention_kwargs,
)
@@ -969,6 +911,7 @@ class Flux2DiT(torch.nn.Module):
txt_ids: torch.Tensor = None,
guidance: torch.Tensor = None,
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
kv_cache = None,
use_gradient_checkpointing=False,
use_gradient_checkpointing_offload=False,
):
@@ -1013,7 +956,7 @@ class Flux2DiT(torch.nn.Module):
)
# 4. Double Stream Transformer Blocks
for index_block, block in enumerate(self.transformer_blocks):
for block_id, block in enumerate(self.transformer_blocks):
encoder_hidden_states, hidden_states = gradient_checkpoint_forward(
block,
use_gradient_checkpointing=use_gradient_checkpointing,
@@ -1024,12 +967,13 @@ class Flux2DiT(torch.nn.Module):
temb_mod_params_txt=double_stream_mod_txt,
image_rotary_emb=concat_rotary_emb,
joint_attention_kwargs=joint_attention_kwargs,
kv_cache=None if kv_cache is None else kv_cache.get(f"double_{block_id}"),
)
# Concatenate text and image streams for single-block inference
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
# 5. Single Stream Transformer Blocks
for index_block, block in enumerate(self.single_transformer_blocks):
for block_id, block in enumerate(self.single_transformer_blocks):
hidden_states = gradient_checkpoint_forward(
block,
use_gradient_checkpointing=use_gradient_checkpointing,
@@ -1039,6 +983,7 @@ class Flux2DiT(torch.nn.Module):
temb_mod_params=single_stream_mod,
image_rotary_emb=concat_rotary_emb,
joint_attention_kwargs=joint_attention_kwargs,
kv_cache=None if kv_cache is None else kv_cache.get(f"single_{block_id}"),
)
# Remove text tokens from concatenated stream
hidden_states = hidden_states[:, num_txt_tokens:, ...]