mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-18 22:08:13 +00:00
Merge branch 'main' into cuda_replace
This commit is contained in:
@@ -823,7 +823,13 @@ class Flux2PosEmbed(nn.Module):
|
||||
|
||||
|
||||
class Flux2TimestepGuidanceEmbeddings(nn.Module):
|
||||
def __init__(self, in_channels: int = 256, embedding_dim: int = 6144, bias: bool = False):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int = 256,
|
||||
embedding_dim: int = 6144,
|
||||
bias: bool = False,
|
||||
guidance_embeds: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.time_proj = Timesteps(num_channels=in_channels, flip_sin_to_cos=True, downscale_freq_shift=0)
|
||||
@@ -831,20 +837,24 @@ class Flux2TimestepGuidanceEmbeddings(nn.Module):
|
||||
in_channels=in_channels, time_embed_dim=embedding_dim, sample_proj_bias=bias
|
||||
)
|
||||
|
||||
self.guidance_embedder = TimestepEmbedding(
|
||||
in_channels=in_channels, time_embed_dim=embedding_dim, sample_proj_bias=bias
|
||||
)
|
||||
if guidance_embeds:
|
||||
self.guidance_embedder = TimestepEmbedding(
|
||||
in_channels=in_channels, time_embed_dim=embedding_dim, sample_proj_bias=bias
|
||||
)
|
||||
else:
|
||||
self.guidance_embedder = None
|
||||
|
||||
def forward(self, timestep: torch.Tensor, guidance: torch.Tensor) -> torch.Tensor:
|
||||
timesteps_proj = self.time_proj(timestep)
|
||||
timesteps_emb = self.timestep_embedder(timesteps_proj.to(timestep.dtype)) # (N, D)
|
||||
|
||||
guidance_proj = self.time_proj(guidance)
|
||||
guidance_emb = self.guidance_embedder(guidance_proj.to(guidance.dtype)) # (N, D)
|
||||
|
||||
time_guidance_emb = timesteps_emb + guidance_emb
|
||||
|
||||
return time_guidance_emb
|
||||
if guidance is not None and self.guidance_embedder is not None:
|
||||
guidance_proj = self.time_proj(guidance)
|
||||
guidance_emb = self.guidance_embedder(guidance_proj.to(guidance.dtype)) # (N, D)
|
||||
time_guidance_emb = timesteps_emb + guidance_emb
|
||||
return time_guidance_emb
|
||||
else:
|
||||
return timesteps_emb
|
||||
|
||||
|
||||
class Flux2Modulation(nn.Module):
|
||||
@@ -882,6 +892,7 @@ class Flux2DiT(torch.nn.Module):
|
||||
axes_dims_rope: Tuple[int, ...] = (32, 32, 32, 32),
|
||||
rope_theta: int = 2000,
|
||||
eps: float = 1e-6,
|
||||
guidance_embeds: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
self.out_channels = out_channels or in_channels
|
||||
@@ -892,7 +903,10 @@ class Flux2DiT(torch.nn.Module):
|
||||
|
||||
# 2. Combined timestep + guidance embedding
|
||||
self.time_guidance_embed = Flux2TimestepGuidanceEmbeddings(
|
||||
in_channels=timestep_guidance_channels, embedding_dim=self.inner_dim, bias=False
|
||||
in_channels=timestep_guidance_channels,
|
||||
embedding_dim=self.inner_dim,
|
||||
bias=False,
|
||||
guidance_embeds=guidance_embeds,
|
||||
)
|
||||
|
||||
# 3. Modulation (double stream and single stream blocks share modulation parameters, resp.)
|
||||
@@ -953,34 +967,9 @@ class Flux2DiT(torch.nn.Module):
|
||||
txt_ids: torch.Tensor = None,
|
||||
guidance: torch.Tensor = None,
|
||||
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
return_dict: bool = True,
|
||||
use_gradient_checkpointing=False,
|
||||
use_gradient_checkpointing_offload=False,
|
||||
) -> Union[torch.Tensor]:
|
||||
"""
|
||||
The [`FluxTransformer2DModel`] forward method.
|
||||
|
||||
Args:
|
||||
hidden_states (`torch.Tensor` of shape `(batch_size, image_sequence_length, in_channels)`):
|
||||
Input `hidden_states`.
|
||||
encoder_hidden_states (`torch.Tensor` of shape `(batch_size, text_sequence_length, joint_attention_dim)`):
|
||||
Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
|
||||
timestep ( `torch.LongTensor`):
|
||||
Used to indicate denoising step.
|
||||
block_controlnet_hidden_states: (`list` of `torch.Tensor`):
|
||||
A list of tensors that if specified are added to the residuals of transformer blocks.
|
||||
joint_attention_kwargs (`dict`, *optional*):
|
||||
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
||||
`self.processor` in
|
||||
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
|
||||
tuple.
|
||||
|
||||
Returns:
|
||||
If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
|
||||
`tuple` where the first element is the sample tensor.
|
||||
"""
|
||||
):
|
||||
# 0. Handle input arguments
|
||||
if joint_attention_kwargs is not None:
|
||||
joint_attention_kwargs = joint_attention_kwargs.copy()
|
||||
@@ -992,7 +981,9 @@ class Flux2DiT(torch.nn.Module):
|
||||
|
||||
# 1. Calculate timestep embedding and modulation parameters
|
||||
timestep = timestep.to(hidden_states.dtype) * 1000
|
||||
guidance = guidance.to(hidden_states.dtype) * 1000
|
||||
|
||||
if guidance is not None:
|
||||
guidance = guidance.to(hidden_states.dtype) * 1000
|
||||
|
||||
temb = self.time_guidance_embed(timestep, guidance)
|
||||
|
||||
|
||||
@@ -5,6 +5,7 @@ import math
|
||||
from typing import Tuple, Optional
|
||||
from einops import rearrange
|
||||
from .wan_video_camera_controller import SimpleAdapter
|
||||
|
||||
try:
|
||||
import flash_attn_interface
|
||||
FLASH_ATTN_3_AVAILABLE = True
|
||||
@@ -92,6 +93,7 @@ def rope_apply(x, freqs, num_heads):
|
||||
x = rearrange(x, "b s (n d) -> b s n d", n=num_heads)
|
||||
x_out = torch.view_as_complex(x.to(torch.float64).reshape(
|
||||
x.shape[0], x.shape[1], x.shape[2], -1, 2))
|
||||
freqs = freqs.to(torch.complex64) if freqs.device == "npu" else freqs
|
||||
x_out = torch.view_as_real(x_out * freqs).flatten(2)
|
||||
return x_out.to(x.dtype)
|
||||
|
||||
|
||||
@@ -6,7 +6,7 @@ import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
|
||||
from torch.nn import RMSNorm
|
||||
from .general_modules import RMSNorm
|
||||
from ..core.attention import attention_forward
|
||||
from ..core.device.npu_compatible_device import IS_NPU_AVAILABLE, get_device_type
|
||||
from ..core.gradient import gradient_checkpoint_forward
|
||||
|
||||
@@ -3,38 +3,71 @@ import torch
|
||||
|
||||
|
||||
class ZImageTextEncoder(torch.nn.Module):
|
||||
def __init__(self):
|
||||
def __init__(self, model_size="4B"):
|
||||
super().__init__()
|
||||
config = Qwen3Config(**{
|
||||
"architectures": [
|
||||
"Qwen3ForCausalLM"
|
||||
],
|
||||
"attention_bias": False,
|
||||
"attention_dropout": 0.0,
|
||||
"bos_token_id": 151643,
|
||||
"eos_token_id": 151645,
|
||||
"head_dim": 128,
|
||||
"hidden_act": "silu",
|
||||
"hidden_size": 2560,
|
||||
"initializer_range": 0.02,
|
||||
"intermediate_size": 9728,
|
||||
"max_position_embeddings": 40960,
|
||||
"max_window_layers": 36,
|
||||
"model_type": "qwen3",
|
||||
"num_attention_heads": 32,
|
||||
"num_hidden_layers": 36,
|
||||
"num_key_value_heads": 8,
|
||||
"rms_norm_eps": 1e-06,
|
||||
"rope_scaling": None,
|
||||
"rope_theta": 1000000,
|
||||
"sliding_window": None,
|
||||
"tie_word_embeddings": True,
|
||||
"torch_dtype": "bfloat16",
|
||||
"transformers_version": "4.51.0",
|
||||
"use_cache": True,
|
||||
"use_sliding_window": False,
|
||||
"vocab_size": 151936
|
||||
})
|
||||
config_dict = {
|
||||
"4B": Qwen3Config(**{
|
||||
"architectures": [
|
||||
"Qwen3ForCausalLM"
|
||||
],
|
||||
"attention_bias": False,
|
||||
"attention_dropout": 0.0,
|
||||
"bos_token_id": 151643,
|
||||
"eos_token_id": 151645,
|
||||
"head_dim": 128,
|
||||
"hidden_act": "silu",
|
||||
"hidden_size": 2560,
|
||||
"initializer_range": 0.02,
|
||||
"intermediate_size": 9728,
|
||||
"max_position_embeddings": 40960,
|
||||
"max_window_layers": 36,
|
||||
"model_type": "qwen3",
|
||||
"num_attention_heads": 32,
|
||||
"num_hidden_layers": 36,
|
||||
"num_key_value_heads": 8,
|
||||
"rms_norm_eps": 1e-06,
|
||||
"rope_scaling": None,
|
||||
"rope_theta": 1000000,
|
||||
"sliding_window": None,
|
||||
"tie_word_embeddings": True,
|
||||
"torch_dtype": "bfloat16",
|
||||
"transformers_version": "4.51.0",
|
||||
"use_cache": True,
|
||||
"use_sliding_window": False,
|
||||
"vocab_size": 151936
|
||||
}),
|
||||
"8B": Qwen3Config(**{
|
||||
"architectures": [
|
||||
"Qwen3ForCausalLM"
|
||||
],
|
||||
"attention_bias": False,
|
||||
"attention_dropout": 0.0,
|
||||
"bos_token_id": 151643,
|
||||
"dtype": "bfloat16",
|
||||
"eos_token_id": 151645,
|
||||
"head_dim": 128,
|
||||
"hidden_act": "silu",
|
||||
"hidden_size": 4096,
|
||||
"initializer_range": 0.02,
|
||||
"intermediate_size": 12288,
|
||||
"max_position_embeddings": 40960,
|
||||
"max_window_layers": 36,
|
||||
"model_type": "qwen3",
|
||||
"num_attention_heads": 32,
|
||||
"num_hidden_layers": 36,
|
||||
"num_key_value_heads": 8,
|
||||
"rms_norm_eps": 1e-06,
|
||||
"rope_scaling": None,
|
||||
"rope_theta": 1000000,
|
||||
"sliding_window": None,
|
||||
"tie_word_embeddings": False,
|
||||
"transformers_version": "4.56.1",
|
||||
"use_cache": True,
|
||||
"use_sliding_window": False,
|
||||
"vocab_size": 151936
|
||||
})
|
||||
}
|
||||
config = config_dict[model_size]
|
||||
self.model = Qwen3Model(config)
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
|
||||
Reference in New Issue
Block a user