mirror of https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-24 18:28:10 +00:00
support klein 4b models
@@ -510,6 +510,13 @@ flux2_series = [
         "model_name": "flux2_vae",
         "model_class": "diffsynth.models.flux2_vae.Flux2VAE",
     },
+    {
+        # Example: ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="transformer/*.safetensors")
+        "model_hash": "3bde7b817fec8143028b6825a63180df",
+        "model_name": "flux2_dit",
+        "model_class": "diffsynth.models.flux2_dit.Flux2DiT",
+        "extra_kwargs": {"guidance_embeds": False, "joint_attention_dim": 7680, "num_attention_heads": 24, "num_layers": 5, "num_single_layers": 20}
+    },
 ]
 
 z_image_series = [
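Note: a minimal sketch of loading the klein 4B DiT through this new registry entry, following the ModelConfig call quoted in the comment above. The import path is inferred from `from ..core import ModelConfig` later in this diff; the loader-side behavior (hash matching, splatting extra_kwargs into the constructor) is inferred from the entry's shape and is not shown here.

    from diffsynth.core import ModelConfig  # import path inferred, not shown in this diff

    # Matches the example quoted in the new registry entry above.
    dit_config = ModelConfig(
        model_id="black-forest-labs/FLUX.2-klein-4B",
        origin_file_pattern="transformer/*.safetensors",
    )
    # On load, the checkpoint is presumably identified by "model_hash" and built as
    # diffsynth.models.flux2_dit.Flux2DiT with the "extra_kwargs" above, i.e.
    # guidance_embeds=False, num_layers=5, num_single_layers=20, ...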
@@ -823,7 +823,13 @@ class Flux2PosEmbed(nn.Module):
 
 
 class Flux2TimestepGuidanceEmbeddings(nn.Module):
-    def __init__(self, in_channels: int = 256, embedding_dim: int = 6144, bias: bool = False):
+    def __init__(
+        self,
+        in_channels: int = 256,
+        embedding_dim: int = 6144,
+        bias: bool = False,
+        guidance_embeds: bool = True,
+    ):
         super().__init__()
 
         self.time_proj = Timesteps(num_channels=in_channels, flip_sin_to_cos=True, downscale_freq_shift=0)
@@ -831,20 +837,24 @@ class Flux2TimestepGuidanceEmbeddings(nn.Module):
             in_channels=in_channels, time_embed_dim=embedding_dim, sample_proj_bias=bias
         )
 
-        self.guidance_embedder = TimestepEmbedding(
-            in_channels=in_channels, time_embed_dim=embedding_dim, sample_proj_bias=bias
-        )
+        if guidance_embeds:
+            self.guidance_embedder = TimestepEmbedding(
+                in_channels=in_channels, time_embed_dim=embedding_dim, sample_proj_bias=bias
+            )
+        else:
+            self.guidance_embedder = None
 
     def forward(self, timestep: torch.Tensor, guidance: torch.Tensor) -> torch.Tensor:
         timesteps_proj = self.time_proj(timestep)
         timesteps_emb = self.timestep_embedder(timesteps_proj.to(timestep.dtype))  # (N, D)
 
-        guidance_proj = self.time_proj(guidance)
-        guidance_emb = self.guidance_embedder(guidance_proj.to(guidance.dtype))  # (N, D)
-        time_guidance_emb = timesteps_emb + guidance_emb
-        return time_guidance_emb
+        if guidance is not None and self.guidance_embedder is not None:
+            guidance_proj = self.time_proj(guidance)
+            guidance_emb = self.guidance_embedder(guidance_proj.to(guidance.dtype))  # (N, D)
+            time_guidance_emb = timesteps_emb + guidance_emb
+            return time_guidance_emb
+        else:
+            return timesteps_emb
 
 
 class Flux2Modulation(nn.Module):
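Note: the two hunks above make the guidance pathway optional. With guidance_embeds=False (the klein 4B configuration) no guidance_embedder is built, and forward returns the timestep embedding alone. A self-contained sketch of the same control flow, using plain Linear layers as stand-ins for diffusers' Timesteps/TimestepEmbedding:

    import torch
    import torch.nn as nn

    class TimestepGuidanceSketch(nn.Module):
        # Mirrors the branching above; the embedders are simplified stand-ins.
        def __init__(self, in_channels=256, embedding_dim=6144, guidance_embeds=True):
            super().__init__()
            self.timestep_embedder = nn.Linear(in_channels, embedding_dim)
            # klein 4B passes guidance_embeds=False, so no guidance branch exists
            self.guidance_embedder = nn.Linear(in_channels, embedding_dim) if guidance_embeds else None

        def forward(self, timesteps_proj, guidance_proj=None):
            emb = self.timestep_embedder(timesteps_proj)
            if guidance_proj is not None and self.guidance_embedder is not None:
                emb = emb + self.guidance_embedder(guidance_proj)  # (N, D)
            return emb

    sketch = TimestepGuidanceSketch(guidance_embeds=False)
    out = sketch(torch.randn(2, 256))  # runs with no guidance input at all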
@@ -882,6 +892,7 @@ class Flux2DiT(torch.nn.Module):
         axes_dims_rope: Tuple[int, ...] = (32, 32, 32, 32),
         rope_theta: int = 2000,
         eps: float = 1e-6,
+        guidance_embeds: bool = True,
     ):
         super().__init__()
         self.out_channels = out_channels or in_channels
@@ -892,7 +903,10 @@ class Flux2DiT(torch.nn.Module):
 
         # 2. Combined timestep + guidance embedding
         self.time_guidance_embed = Flux2TimestepGuidanceEmbeddings(
-            in_channels=timestep_guidance_channels, embedding_dim=self.inner_dim, bias=False
+            in_channels=timestep_guidance_channels,
+            embedding_dim=self.inner_dim,
+            bias=False,
+            guidance_embeds=guidance_embeds,
         )
 
         # 3. Modulation (double stream and single stream blocks share modulation parameters, resp.)
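Note: guidance_embeds is plumbed from the Flux2DiT constructor down into Flux2TimestepGuidanceEmbeddings, so the registry's extra_kwargs control the guidance pathway end to end. Assuming the loader splats extra_kwargs into the constructor and the remaining arguments keep their defaults (neither is shown in this diff), the klein 4B DiT would be built as:

    from diffsynth.models.flux2_dit import Flux2DiT

    # The klein 4B entry's extra_kwargs from the registry hunk above.
    dit = Flux2DiT(
        guidance_embeds=False, joint_attention_dim=7680,
        num_attention_heads=24, num_layers=5, num_single_layers=20,
    )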
@@ -953,34 +967,9 @@ class Flux2DiT(torch.nn.Module):
         txt_ids: torch.Tensor = None,
         guidance: torch.Tensor = None,
         joint_attention_kwargs: Optional[Dict[str, Any]] = None,
-        return_dict: bool = True,
         use_gradient_checkpointing=False,
         use_gradient_checkpointing_offload=False,
-    ) -> Union[torch.Tensor]:
-        """
-        The [`FluxTransformer2DModel`] forward method.
-
-        Args:
-            hidden_states (`torch.Tensor` of shape `(batch_size, image_sequence_length, in_channels)`):
-                Input `hidden_states`.
-            encoder_hidden_states (`torch.Tensor` of shape `(batch_size, text_sequence_length, joint_attention_dim)`):
-                Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
-            timestep ( `torch.LongTensor`):
-                Used to indicate denoising step.
-            block_controlnet_hidden_states: (`list` of `torch.Tensor`):
-                A list of tensors that if specified are added to the residuals of transformer blocks.
-            joint_attention_kwargs (`dict`, *optional*):
-                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
-                `self.processor` in
-                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
-            return_dict (`bool`, *optional*, defaults to `True`):
-                Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
-                tuple.
-
-        Returns:
-            If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
-            `tuple` where the first element is the sample tensor.
-        """
+    ):
         # 0. Handle input arguments
         if joint_attention_kwargs is not None:
             joint_attention_kwargs = joint_attention_kwargs.copy()
@@ -992,7 +981,9 @@ class Flux2DiT(torch.nn.Module):
 
         # 1. Calculate timestep embedding and modulation parameters
         timestep = timestep.to(hidden_states.dtype) * 1000
-        guidance = guidance.to(hidden_states.dtype) * 1000
+        if guidance is not None:
+            guidance = guidance.to(hidden_states.dtype) * 1000
+
         temb = self.time_guidance_embed(timestep, guidance)
 
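Note: at call time, guidance may now be None (the klein 4B path), so the *1000 scaling is guarded the same way the embedder is. A tiny sketch of the caller-side invariant:

    import torch

    timestep = torch.tensor([0.5]) * 1000  # scheduler value mapped to the 0..1000 range
    guidance = None                        # klein 4B: no guidance embedding
    if guidance is not None:
        guidance = guidance * 1000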
@@ -10,7 +10,7 @@ from ..diffusion import FlowMatchScheduler
 from ..core import ModelConfig, gradient_checkpoint_forward
 from ..diffusion.base_pipeline import BasePipeline, PipelineUnit, ControlNetInput
 
-from transformers import AutoProcessor
+from transformers import AutoProcessor, AutoTokenizer
 from ..models.flux2_text_encoder import Flux2TextEncoder
 from ..models.flux2_dit import Flux2DiT
 from ..models.flux2_vae import Flux2VAE
@@ -53,11 +53,12 @@ class Flux2ImagePipeline(BasePipeline):
 
         # Fetch models
         pipe.text_encoder = model_pool.fetch_model("flux2_text_encoder")
+        pipe.text_encoder_qwen3 = model_pool.fetch_model("z_image_text_encoder")
         pipe.dit = model_pool.fetch_model("flux2_dit")
         pipe.vae = model_pool.fetch_model("flux2_vae")
         if tokenizer_config is not None:
             tokenizer_config.download_if_necessary()
-            pipe.tokenizer = AutoProcessor.from_pretrained(tokenizer_config.path)
+            pipe.tokenizer = AutoTokenizer.from_pretrained(tokenizer_config.path)
 
         # VRAM Management
         pipe.vram_management_enabled = pipe.check_vram_management_state()
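Note: the pipeline now also fetches the text encoder registered under "z_image_text_encoder" and loads the tokenizer with AutoTokenizer instead of AutoProcessor, presumably because plain text tokenization suffices here rather than a full multimodal processor. A minimal sketch of the new tokenizer path; the directory path is illustrative:

    from transformers import AutoTokenizer

    # In the pipeline this is tokenizer_config.path; any tokenizer directory works.
    tokenizer = AutoTokenizer.from_pretrained("/path/to/tokenizer")
    input_ids = tokenizer("a cat wearing a hat", return_tensors="pt").input_ids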