mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-18 22:08:13 +00:00
qwen-image
This commit is contained in:
357
diffsynth/models/qwen_image_dit.py
Normal file
357
diffsynth/models/qwen_image_dit.py
Normal file
@@ -0,0 +1,357 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from typing import Tuple, Optional, Union, List
|
||||
from einops import rearrange
|
||||
from .sd3_dit import TimestepEmbeddings, RMSNorm
|
||||
from .flux_dit import AdaLayerNorm
|
||||
|
||||
|
||||
class ApproximateGELU(nn.Module):
    """Linear projection followed by a sigmoid-based GELU approximation.

    The output is ``y * sigmoid(1.702 * y)`` with ``y = proj(x)``.

    Args:
        dim_in: Size of the last dimension of the input.
        dim_out: Size of the last dimension of the output.
        bias: Whether the projection carries a bias term.
    """

    def __init__(self, dim_in: int, dim_out: int, bias: bool = True):
        super().__init__()
        self.proj = nn.Linear(dim_in, dim_out, bias=bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Project ``x`` and gate the projection with its own scaled sigmoid."""
        projected = self.proj(x)
        gate = torch.sigmoid(1.702 * projected)
        return projected * gate
|
||||
|
||||
def apply_rotary_emb_qwen(
    x: torch.Tensor,
    freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]]
):
    """Rotate consecutive feature pairs of ``x`` by complex factors.

    The last dimension of ``x`` is read as (real, imag) pairs, multiplied by
    ``freqs_cis`` in float32 complex arithmetic, and flattened back from
    dimension 3 onward.

    Args:
        x: Tensor whose last dimension has even size.
        freqs_cis: Complex tensor broadcastable against the paired view of ``x``.

    Returns:
        The rotated tensor, cast back to the dtype of ``x``.
    """
    pair_shape = (*x.shape[:-1], -1, 2)
    as_complex = torch.view_as_complex(x.float().reshape(pair_shape))
    rotated = torch.view_as_real(as_complex * freqs_cis)
    return rotated.flatten(3).type_as(x)
|
||||
|
||||
|
||||
class QwenEmbedRope(nn.Module):
    """Rotary position factors for a (frame, height, width) grid plus text.

    Two tables of complex unit factors are precomputed for 1024 positions:
    ``pos_freqs`` for indices ``0..1023`` and ``neg_freqs`` for indices
    ``-1024..-1``. Each table concatenates the three axes along dim 1.

    Args:
        theta: Base of the rotary frequencies.
        axes_dim: Rotary feature sizes for the frame, height and width axes.
            Each must be even.
        scale_rope: When true, height and width positions are centred on
            zero (negative indices for the first half, non-negative for the
            second half).
    """

    def __init__(self, theta: int, axes_dim: list[int], scale_rope=False):
        super().__init__()
        self.theta = theta
        self.axes_dim = axes_dim
        pos_index = torch.arange(1024)
        neg_index = torch.arange(1024).flip(0) * -1 - 1
        self.pos_freqs = self._axis_tables(pos_index)
        self.neg_freqs = self._axis_tables(neg_index)
        # Per-grid cache of video factors, keyed by "frame_height_width".
        self.rope_cache = {}
        self.scale_rope = scale_rope

    def _axis_tables(self, index):
        """Concatenate the per-axis factor tables for ``index`` along dim 1."""
        dims = (self.axes_dim[0], self.axes_dim[1], self.axes_dim[2])
        return torch.cat(
            [self.rope_params(index, dim, self.theta) for dim in dims], dim=1
        )

    def rope_params(self, index, dim, theta=10000):
        """Return complex unit factors for every (position, frequency) pair.

        Args:
            index: [0, 1, 2, 3] 1D Tensor representing the position index of the token
            dim: Rotary feature size for this axis; must be even.
            theta: Base of the rotary frequencies.

        Returns:
            Complex tensor of shape ``(len(index), dim // 2)``.
        """
        assert dim % 2 == 0
        inv_freq = 1.0 / torch.pow(theta, torch.arange(0, dim, 2).to(torch.float32).div(dim))
        angles = torch.outer(index, inv_freq)
        return torch.polar(torch.ones_like(angles), angles)

    def _build_video_freqs(self, frame, height, width):
        """Build the flattened ``(frame * height * width, D)`` factor table."""
        half_dims = [dim // 2 for dim in self.axes_dim]
        freqs_pos = self.pos_freqs.split(half_dims, dim=1)
        freqs_neg = self.neg_freqs.split(half_dims, dim=1)

        freqs_frame = freqs_pos[0][:frame]
        if self.scale_rope:
            # Centre the spatial axes: the tail of the negative table covers
            # the first half, the head of the positive table the second half.
            freqs_height = torch.cat(
                [freqs_neg[1][-(height - height // 2):], freqs_pos[1][:height // 2]],
                dim=0,
            )
            freqs_width = torch.cat(
                [freqs_neg[2][-(width - width // 2):], freqs_pos[2][:width // 2]],
                dim=0,
            )
        else:
            freqs_height = freqs_pos[1][:height]
            freqs_width = freqs_pos[2][:width]

        freqs_frame = freqs_frame.view(frame, 1, 1, -1).expand(frame, height, width, -1)
        freqs_height = freqs_height.view(1, height, 1, -1).expand(frame, height, width, -1)
        freqs_width = freqs_width.view(1, 1, width, -1).expand(frame, height, width, -1)

        freqs = torch.cat([freqs_frame, freqs_height, freqs_width], dim=-1)
        return freqs.reshape(frame * height * width, -1).clone().contiguous()

    def forward(self, video_fhw, txt_seq_lens, device):
        """Return rotary factors for the video grid and for the text tokens.

        Args:
            video_fhw: ``(frame, height, width)`` or a list whose first item
                is such a triple.
            txt_seq_lens: Per-sample text lengths; only the maximum is used.
            device: Device the factor tables should live on.

        Returns:
            ``(vid_freqs, txt_freqs)``: complex tensors of shape
            ``(frame * height * width, D)`` and ``(max(txt_seq_lens), D)``.
        """
        # ``Tensor.to`` returns the tensor itself when nothing changes, so
        # comparing devices afterwards also handles strings and index-less
        # devices such as "cuda".
        moved = self.pos_freqs.to(device)
        if moved.device != self.pos_freqs.device:
            self.pos_freqs = moved
            self.neg_freqs = self.neg_freqs.to(device)
            # Cached grids were built on the previous device.
            self.rope_cache.clear()

        if isinstance(video_fhw, list):
            video_fhw = video_fhw[0]
        frame, height, width = video_fhw
        rope_key = f"{frame}_{height}_{width}"

        if rope_key not in self.rope_cache:
            self.rope_cache[rope_key] = self._build_video_freqs(frame, height, width)
        vid_freqs = self.rope_cache[rope_key]

        # Text positions start right after the largest spatial index in use.
        if self.scale_rope:
            max_vid_index = max(height // 2, width // 2)
        else:
            max_vid_index = max(height, width)

        # Lengths derived from a float mask arrive as floats; slicing needs ints.
        max_len = int(max(txt_seq_lens))
        txt_freqs = self.pos_freqs[max_vid_index: max_vid_index + max_len, ...]
        return vid_freqs, txt_freqs
|
||||
|
||||
|
||||
class QwenFeedForward(nn.Module):
    """Feed-forward block: ApproximateGELU projection, dropout, linear.

    Args:
        dim: Input feature size. The hidden size is ``4 * dim``.
        dim_out: Output feature size. Defaults to ``dim`` when ``None``.
        dropout: Dropout probability applied after the activation.
    """

    def __init__(
        self,
        dim: int,
        dim_out: Optional[int] = None,
        dropout: float = 0.0,
    ):
        super().__init__()
        # The declared default (None) cannot be passed to nn.Linear as a
        # size; fall back to a shape-preserving output.
        if dim_out is None:
            dim_out = dim
        inner_dim = int(dim * 4)
        self.net = nn.ModuleList([])
        self.net.append(ApproximateGELU(dim, inner_dim))
        self.net.append(nn.Dropout(dropout))
        self.net.append(nn.Linear(inner_dim, dim_out))

    def forward(self, hidden_states: torch.Tensor, *args, **kwargs) -> torch.Tensor:
        """Apply the sub-modules in order; extra arguments are ignored."""
        for module in self.net:
            hidden_states = module(hidden_states)
        return hidden_states
|
||||
|
||||
class QwenDoubleStreamAttention(nn.Module):
    """Joint attention over an image stream and a text stream.

    Each stream has its own query/key/value projections, per-head RMS norms
    on queries and keys, and its own output projection. Attention itself is
    computed once over the concatenation ``[text, image]`` along the
    sequence axis.

    Args:
        dim_a: Feature size of the image stream.
        dim_b: Feature size of the text stream.
        num_heads: Number of attention heads.
        head_dim: Feature size per head (used for the q/k norms).
    """

    def __init__(
        self,
        dim_a,
        dim_b,
        num_heads,
        head_dim,
    ):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = head_dim

        # Image stream.
        self.to_q = nn.Linear(dim_a, dim_a)
        self.to_k = nn.Linear(dim_a, dim_a)
        self.to_v = nn.Linear(dim_a, dim_a)
        self.norm_q = RMSNorm(head_dim, eps=1e-6)
        self.norm_k = RMSNorm(head_dim, eps=1e-6)

        # Text ("added") stream.
        self.add_q_proj = nn.Linear(dim_b, dim_b)
        self.add_k_proj = nn.Linear(dim_b, dim_b)
        self.add_v_proj = nn.Linear(dim_b, dim_b)
        self.norm_added_q = RMSNorm(head_dim, eps=1e-6)
        self.norm_added_k = RMSNorm(head_dim, eps=1e-6)

        # Output projections.
        self.to_out = torch.nn.Sequential(nn.Linear(dim_a, dim_a))
        self.to_add_out = nn.Linear(dim_b, dim_b)

    def forward(
        self,
        image: torch.FloatTensor,
        text: torch.FloatTensor,
        image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None
    ) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
        """Attend jointly over text and image tokens.

        Args:
            image: Image tokens, ``(batch, seq_img, dim_a)``.
            text: Text tokens, ``(batch, seq_txt, dim_b)``.
            image_rotary_emb: Optional ``(img_freqs, txt_freqs)`` pair applied
                to the image and text queries/keys respectively.

        Returns:
            ``(img_attn_output, txt_attn_output)``, in that order.
        """
        def split_heads(tensor):
            return rearrange(tensor, 'b s (h d) -> b h s d', h=self.num_heads)

        img_q = self.norm_q(split_heads(self.to_q(image)))
        img_k = self.norm_k(split_heads(self.to_k(image)))
        img_v = split_heads(self.to_v(image))

        txt_q = self.norm_added_q(split_heads(self.add_q_proj(text)))
        txt_k = self.norm_added_k(split_heads(self.add_k_proj(text)))
        txt_v = split_heads(self.add_v_proj(text))

        # Sequence axis is dim 2 after the head split.
        seq_txt = txt_q.shape[2]

        if image_rotary_emb is not None:
            img_freqs, txt_freqs = image_rotary_emb
            img_q = apply_rotary_emb_qwen(img_q, img_freqs)
            img_k = apply_rotary_emb_qwen(img_k, img_freqs)
            txt_q = apply_rotary_emb_qwen(txt_q, txt_freqs)
            txt_k = apply_rotary_emb_qwen(txt_k, txt_freqs)

        # Text tokens come first in the joint sequence.
        joint_q = torch.cat([txt_q, img_q], dim=2)
        joint_k = torch.cat([txt_k, img_k], dim=2)
        joint_v = torch.cat([txt_v, img_v], dim=2)

        attended = torch.nn.functional.scaled_dot_product_attention(joint_q, joint_k, joint_v)
        attended = rearrange(attended, 'b h s d -> b s (h d)').to(joint_q.dtype)

        txt_attn_output = self.to_add_out(attended[:, :seq_txt, :])
        img_attn_output = self.to_out(attended[:, seq_txt:, :])
        return img_attn_output, txt_attn_output
|
||||
|
||||
|
||||
class QwenImageTransformerBlock(nn.Module):
    """Dual-stream transformer block with timestep-conditioned modulation.

    Both streams follow the same recipe: LayerNorm, shift/scale modulation,
    joint attention with a gated residual, then LayerNorm, modulation and
    an MLP with a gated residual. The modulation parameters come from
    ``temb`` through a separate SiLU + Linear per stream.

    Args:
        dim: Feature size of both streams.
        num_attention_heads: Number of attention heads.
        attention_head_dim: Feature size per head.
        eps: Epsilon of the (non-affine) LayerNorms.
    """

    def __init__(
        self,
        dim: int,
        num_attention_heads: int,
        attention_head_dim: int,
        eps: float = 1e-6,
    ):
        super().__init__()

        self.dim = dim
        self.num_attention_heads = num_attention_heads
        self.attention_head_dim = attention_head_dim

        # Image stream (the attention module serves both streams).
        self.img_mod = nn.Sequential(nn.SiLU(), nn.Linear(dim, 6 * dim))
        self.img_norm1 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps)
        self.attn = QwenDoubleStreamAttention(
            dim_a=dim,
            dim_b=dim,
            num_heads=num_attention_heads,
            head_dim=attention_head_dim,
        )
        self.img_norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps)
        self.img_mlp = QwenFeedForward(dim=dim, dim_out=dim)

        # Text stream.
        self.txt_mod = nn.Sequential(nn.SiLU(), nn.Linear(dim, 6 * dim, bias=True))
        self.txt_norm1 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps)
        self.txt_norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps)
        self.txt_mlp = QwenFeedForward(dim=dim, dim_out=dim)

    def _modulate(self, x, mod_params):
        """Split ``mod_params`` into shift/scale/gate and modulate ``x``.

        Returns:
            ``(x * (1 + scale) + shift, gate)``, with shift, scale and gate
            each given a singleton axis at position 1.
        """
        shift, scale, gate = (part.unsqueeze(1) for part in mod_params.chunk(3, dim=-1))
        modulated = x * (1 + scale) + shift
        return modulated, gate

    def forward(
        self,
        image: torch.Tensor,
        text: torch.Tensor,
        temb: torch.Tensor,
        image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run one block.

        Returns:
            ``(text, image)`` — note that text comes first.
        """
        # First half of each modulation vector drives attention, second the MLP.
        img_attn_mod, img_mlp_mod = self.img_mod(temb).chunk(2, dim=-1)
        txt_attn_mod, txt_mlp_mod = self.txt_mod(temb).chunk(2, dim=-1)

        # Attention sub-layer.
        img_in, img_attn_gate = self._modulate(self.img_norm1(image), img_attn_mod)
        txt_in, txt_attn_gate = self._modulate(self.txt_norm1(text), txt_attn_mod)
        img_attn_out, txt_attn_out = self.attn(
            image=img_in,
            text=txt_in,
            image_rotary_emb=image_rotary_emb,
        )
        image = image + img_attn_gate * img_attn_out
        text = text + txt_attn_gate * txt_attn_out

        # MLP sub-layer.
        img_in, img_mlp_gate = self._modulate(self.img_norm2(image), img_mlp_mod)
        txt_in, txt_mlp_gate = self._modulate(self.txt_norm2(text), txt_mlp_mod)
        image = image + img_mlp_gate * self.img_mlp(img_in)
        text = text + txt_mlp_gate * self.txt_mlp(txt_in)

        return text, image
|
||||
|
||||
|
||||
class QwenImageDiT(torch.nn.Module):
    """Qwen-Image diffusion transformer.

    Latents are split into 2x2 patches (64 features per token), projected to
    3072 features, run through ``num_layers`` dual-stream blocks together
    with the projected prompt embedding, and projected back to 64 features
    per token before being un-patchified.

    Args:
        num_layers: Number of ``QwenImageTransformerBlock`` layers.
    """

    def __init__(
        self,
        num_layers: int = 60,
    ):
        super().__init__()

        self.pos_embed = QwenEmbedRope(theta=10000, axes_dim=[16,56,56], scale_rope=True)

        self.time_text_embed = TimestepEmbeddings(256, 3072, diffusers_compatible_format=True, scale=1000, align_dtype_to_timestep=True)
        self.txt_norm = RMSNorm(3584, eps=1e-6)

        self.img_in = nn.Linear(64, 3072)
        self.txt_in = nn.Linear(3584, 3072)

        self.transformer_blocks = nn.ModuleList(
            [
                QwenImageTransformerBlock(
                    dim=3072,
                    num_attention_heads=24,
                    attention_head_dim=128,
                )
                for _ in range(num_layers)
            ]
        )
        self.norm_out = AdaLayerNorm(3072, single=True)
        self.proj_out = nn.Linear(3072, 64)

    def forward(
        self,
        latents=None,
        timestep=None,
        prompt_emb=None,
        prompt_emb_mask=None,
        height=None,
        width=None,
    ):
        """Run the transformer on a batch of latents.

        Args:
            latents: Tensor laid out as ``(B, C, H, W)`` with even ``H``/``W``.
            timestep: Timestep input passed to ``time_text_embed``.
            prompt_emb: Prompt embedding with 3584 features per token.
            prompt_emb_mask: Mask summed along dim 1 to get per-sample text
                lengths. When ``None`` every prompt token is counted.
            height: Pixel height; the patch grid is ``height // 16``. When
                ``None`` the grid is taken from ``latents``.
            width: Pixel width, handled like ``height``.

        Returns:
            Tensor with the same ``(B, C, H, W)`` layout as ``latents``.
        """
        grid_h = latents.shape[2] // 2 if height is None else height // 16
        grid_w = latents.shape[3] // 2 if width is None else width // 16

        # A single image is one frame. Using the batch size as the frame
        # count would make the rotary table B times longer than the token
        # sequence and break broadcasting for batch sizes above one.
        img_shapes = [(1, latents.shape[2]//2, latents.shape[3]//2)]
        if prompt_emb_mask is None:
            txt_seq_lens = [prompt_emb.shape[1]]
        else:
            txt_seq_lens = prompt_emb_mask.sum(dim=1).tolist()

        image = rearrange(latents, "B C (H P) (W Q) -> B (H W) (P Q C)", H=grid_h, W=grid_w, P=2, Q=2)
        image = self.img_in(image)
        text = self.txt_in(self.txt_norm(prompt_emb))

        conditioning = self.time_text_embed(timestep, image.dtype)

        image_rotary_emb = self.pos_embed(img_shapes, txt_seq_lens, device=latents.device)

        for block in self.transformer_blocks:
            text, image = block(
                image=image,
                text=text,
                temb=conditioning,
                image_rotary_emb=image_rotary_emb,
            )

        image = self.norm_out(image, conditioning)
        image = self.proj_out(image)

        # Undo the patchify step and return the result. The previous code
        # computed this tensor and then returned the token sequence instead.
        latents = rearrange(image, "B (H W) (P Q C) -> B C (H P) (W Q)", H=grid_h, W=grid_w, P=2, Q=2)
        return latents

    @staticmethod
    def state_dict_converter():
        """Return the state-dict converter for this model."""
        return QwenImageDiTStateDictConverter()
|
||||
|
||||
|
||||
|
||||
class QwenImageDiTStateDictConverter:
    """State-dict converter for ``QwenImageDiT``; keys need no renaming."""

    def from_civitai(self, state_dict):
        """Return ``state_dict`` unchanged (the same object)."""
        return state_dict
|
||||
255
diffsynth/models/qwen_image_text_encoder.py
Normal file
255
diffsynth/models/qwen_image_text_encoder.py
Normal file
@@ -0,0 +1,255 @@
|
||||
from transformers import Qwen2_5_VLModel
|
||||
import torch
|
||||
from typing import Optional, Union
|
||||
|
||||
|
||||
class QwenImageTextEncoder(torch.nn.Module):
    """Qwen2.5-VL model built from an inline config, used as a text encoder.

    The module holds a ``Qwen2_5_VLModel`` plus a bias-free ``lm_head``
    linear layer; ``forward`` does not call ``lm_head``.
    """

    def __init__(self):
        super().__init__()
        from transformers import Qwen2_5_VLConfig

        def rope_scaling():
            # Fresh dict per use so the two config levels never share state.
            return {
                "mrope_section": [16, 24, 24],
                "rope_type": "default",
                "type": "default",
            }

        text_config = {
            "architectures": ["Qwen2_5_VLForConditionalGeneration"],
            "attention_dropout": 0.0,
            "bos_token_id": 151643,
            "eos_token_id": 151645,
            "hidden_act": "silu",
            "hidden_size": 3584,
            "image_token_id": None,
            "initializer_range": 0.02,
            "intermediate_size": 18944,
            "layer_types": ["full_attention"] * 28,
            "max_position_embeddings": 128000,
            "max_window_layers": 28,
            "model_type": "qwen2_5_vl_text",
            "num_attention_heads": 28,
            "num_hidden_layers": 28,
            "num_key_value_heads": 4,
            "rms_norm_eps": 1e-06,
            "rope_scaling": rope_scaling(),
            "rope_theta": 1000000.0,
            "sliding_window": None,
            "torch_dtype": "float32",
            "use_cache": True,
            "use_sliding_window": False,
            "video_token_id": None,
            "vision_end_token_id": 151653,
            "vision_start_token_id": 151652,
            "vision_token_id": 151654,
            "vocab_size": 152064,
        }

        vision_config = {
            "depth": 32,
            "fullatt_block_indexes": [7, 15, 23, 31],
            "hidden_act": "silu",
            "hidden_size": 1280,
            "in_channels": 3,
            "in_chans": 3,
            "initializer_range": 0.02,
            "intermediate_size": 3420,
            "model_type": "qwen2_5_vl",
            "num_heads": 16,
            "out_hidden_size": 3584,
            "patch_size": 14,
            "spatial_merge_size": 2,
            "spatial_patch_size": 14,
            "temporal_patch_size": 2,
            "tokens_per_second": 2,
            "torch_dtype": "float32",
            "window_size": 112,
        }

        config = Qwen2_5_VLConfig(
            architectures=["Qwen2_5_VLForConditionalGeneration"],
            attention_dropout=0.0,
            bos_token_id=151643,
            eos_token_id=151645,
            hidden_act="silu",
            hidden_size=3584,
            image_token_id=151655,
            initializer_range=0.02,
            intermediate_size=18944,
            max_position_embeddings=128000,
            max_window_layers=28,
            model_type="qwen2_5_vl",
            num_attention_heads=28,
            num_hidden_layers=28,
            num_key_value_heads=4,
            rms_norm_eps=1e-06,
            rope_scaling=rope_scaling(),
            rope_theta=1000000.0,
            sliding_window=32768,
            text_config=text_config,
            tie_word_embeddings=False,
            torch_dtype="float32",
            transformers_version="4.54.0",
            use_cache=True,
            use_sliding_window=False,
            video_token_id=151656,
            vision_config=vision_config,
            vision_end_token_id=151653,
            vision_start_token_id=151652,
            vision_token_id=151654,
            vocab_size=152064,
        )
        self.model = Qwen2_5_VLModel(config)
        self.lm_head = torch.nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False)

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        pixel_values: Optional[torch.Tensor] = None,
        pixel_values_videos: Optional[torch.FloatTensor] = None,
        image_grid_thw: Optional[torch.LongTensor] = None,
        video_grid_thw: Optional[torch.LongTensor] = None,
        rope_deltas: Optional[torch.LongTensor] = None,
        cache_position: Optional[torch.LongTensor] = None,
        second_per_grid_ts: Optional[torch.Tensor] = None,
        logits_to_keep: Union[int, torch.Tensor] = 0,
        **kwargs,
    ):
        r"""Run the wrapped model and return its hidden states.

        The signature mirrors the Hugging Face Qwen2.5-VL forward for call
        compatibility, but this wrapper behaves differently:

        - ``output_attentions`` is always forced to ``False`` and
          ``output_hidden_states`` to ``True``; the values passed in are
          ignored.
        - ``labels``, ``rope_deltas`` and ``logits_to_keep`` are accepted
          but never used. No loss or logits are computed.
        - Every other argument, including ``**kwargs``, is forwarded to
          ``self.model`` with ``return_dict=True``.

        Returns:
            The ``hidden_states`` attribute of the wrapped model's output.
        """
        outputs = self.model(
            input_ids=input_ids,
            pixel_values=pixel_values,
            pixel_values_videos=pixel_values_videos,
            image_grid_thw=image_grid_thw,
            video_grid_thw=video_grid_thw,
            second_per_grid_ts=second_per_grid_ts,
            position_ids=position_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=False,
            output_hidden_states=True,
            return_dict=True,
            cache_position=cache_position,
            **kwargs,
        )
        return outputs.hidden_states

    @staticmethod
    def state_dict_converter():
        """Return the state-dict converter for this model."""
        return QwenImageTextEncoderStateDictConverter()
|
||||
|
||||
|
||||
|
||||
class QwenImageTextEncoderStateDictConverter():
    """Rename checkpoint keys to match ``QwenImageTextEncoder``."""

    def __init__(self):
        pass

    def from_diffusers(self, state_dict):
        """Return a new dict with keys remapped by prefix.

        - ``visual.*`` becomes ``model.visual.*``
        - ``model.*`` becomes ``model.language_model.*``
        - every other key is kept as is

        Only the leading prefix is rewritten; a later ``model.`` inside a
        key (for example ``model.sub_model.weight``) is left untouched.

        Args:
            state_dict: Mapping from parameter name to value.

        Returns:
            A new dict holding the same values under the remapped names.
        """
        visual_prefix = "visual."
        model_prefix = "model."
        state_dict_ = {}
        for k, v in state_dict.items():
            if k.startswith(visual_prefix):
                k = "model." + k
            elif k.startswith(model_prefix):
                k = "model.language_model." + k[len(model_prefix):]
            state_dict_[k] = v
        return state_dict_
|
||||
736
diffsynth/models/qwen_image_vae.py
Normal file
736
diffsynth/models/qwen_image_vae.py
Normal file
@@ -0,0 +1,736 @@
|
||||
import torch
|
||||
from typing import List, Optional, Tuple, Union
|
||||
from torch import nn
|
||||
|
||||
|
||||
# Number of trailing frames (along dim 2) sliced off and stored in the feature cache.
CACHE_T = 2
|
||||
|
||||
class QwenImageCausalConv3d(torch.nn.Conv3d):
    r"""
    3D convolution that is causal along the time axis, with optional caching.

    The convolution's own padding is disabled. Instead the input is padded
    explicitly: symmetrically in height and width, and only on the leading
    side in time (twice the requested temporal padding). Frames passed as
    ``cache_x`` are prepended and replace that many frames of leading time
    padding.

    Args:
        in_channels (int): Number of channels in the input image
        out_channels (int): Number of channels produced by the convolution
        kernel_size (int or tuple): Size of the convolving kernel
        stride (int or tuple, optional): Stride of the convolution. Default: 1
        padding (int or tuple, optional): Requested padding per (time, height, width). Default: 0
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: Union[int, Tuple[int, int, int]],
        stride: Union[int, Tuple[int, int, int]] = 1,
        padding: Union[int, Tuple[int, int, int]] = 0,
    ) -> None:
        super().__init__(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
        )

        # F.pad order is (w_left, w_right, h_top, h_bottom, t_front, t_back).
        pad_t, pad_h, pad_w = self.padding[0], self.padding[1], self.padding[2]
        self._padding = (pad_w, pad_w, pad_h, pad_h, 2 * pad_t, 0)
        # Padding is applied by hand in forward(), so disable the built-in one.
        self.padding = (0, 0, 0)

    def forward(self, x, cache_x=None):
        """Pad ``x`` causally (using ``cache_x`` when given) and convolve."""
        pad = list(self._padding)
        has_time_padding = self._padding[4] > 0
        if cache_x is not None and has_time_padding:
            cached = cache_x.to(x.device)
            x = torch.cat([cached, x], dim=2)
            # Cached frames stand in for the same number of padded frames.
            pad[4] -= cached.shape[2]
        return super().forward(torch.nn.functional.pad(x, pad))
|
||||
|
||||
|
||||
|
||||
class QwenImageRMS_norm(nn.Module):
    r"""
    RMS-style normalisation with a learnable gain and optional bias.

    The input is L2-normalised along the channel axis, multiplied by
    ``sqrt(dim)`` and by ``gamma``, and then ``bias`` is added.

    Args:
        dim (int): Number of channels being normalised.
        channel_first (bool, optional): If True, normalise along dim 1 and shape the
            parameters as ``(dim, 1, ...)``; otherwise normalise along the last dim
            with parameters of shape ``(dim,)``. Default is True.
        images (bool, optional): With ``channel_first``, use two trailing singleton
            dims when True and three when False. Default is True.
        bias (bool, optional): Whether to include a learnable bias term. Default is False.
    """

    def __init__(self, dim: int, channel_first: bool = True, images: bool = True, bias: bool = False) -> None:
        super().__init__()
        if channel_first:
            trailing = (1, 1) if images else (1, 1, 1)
            param_shape = (dim, *trailing)
        else:
            param_shape = (dim,)

        self.channel_first = channel_first
        self.scale = dim**0.5
        self.gamma = nn.Parameter(torch.ones(param_shape))
        # A plain 0.0 stands in for the bias when it is disabled.
        self.bias = nn.Parameter(torch.zeros(param_shape)) if bias else 0.0

    def forward(self, x):
        """Normalise ``x`` along the channel axis, then apply scale, gain and bias."""
        channel_axis = 1 if self.channel_first else -1
        unit = torch.nn.functional.normalize(x, dim=channel_axis)
        return unit * self.scale * self.gamma + self.bias
|
||||
|
||||
|
||||
|
||||
class QwenImageResidualBlock(nn.Module):
    r"""
    Residual block: two (RMS norm, SiLU, causal 3D conv) stages plus a shortcut.

    Args:
        in_dim (int): Number of input channels.
        out_dim (int): Number of output channels.
        dropout (float, optional): Dropout rate applied before the second conv. Default is 0.0.
        non_linearity (str, optional): Accepted for interface compatibility; the
            block always uses SiLU regardless of this value.
    """

    def __init__(
        self,
        in_dim: int,
        out_dim: int,
        dropout: float = 0.0,
        non_linearity: str = "silu",
    ) -> None:
        super().__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.nonlinearity = torch.nn.SiLU()

        # layers
        self.norm1 = QwenImageRMS_norm(in_dim, images=False)
        self.conv1 = QwenImageCausalConv3d(in_dim, out_dim, 3, padding=1)
        self.norm2 = QwenImageRMS_norm(out_dim, images=False)
        self.dropout = nn.Dropout(dropout)
        self.conv2 = QwenImageCausalConv3d(out_dim, out_dim, 3, padding=1)
        # 1x1x1 conv on the shortcut only when the channel count changes.
        self.conv_shortcut = QwenImageCausalConv3d(in_dim, out_dim, 1) if in_dim != out_dim else nn.Identity()

    @staticmethod
    def _cached_conv(conv, x, feat_cache, feat_idx):
        """Apply ``conv`` to ``x``, reading and updating one feature-cache slot.

        Without a cache this is a plain ``conv(x)``. With a cache, the slot
        at ``feat_idx[0]`` supplies the leading frames for the convolution,
        is then overwritten with the last ``CACHE_T`` frames of ``x``, and
        the index is advanced by one.
        """
        if feat_cache is None:
            return conv(x)

        idx = feat_idx[0]
        cache_x = x[:, :, -CACHE_T:, :, :].clone()
        if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
            # Only one new frame: keep the last cached frame alongside it so
            # the slot still holds two frames.
            previous_last = feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device)
            cache_x = torch.cat([previous_last, cache_x], dim=2)

        out = conv(x, feat_cache[idx])
        feat_cache[idx] = cache_x
        feat_idx[0] += 1
        return out

    def forward(self, x, feat_cache=None, feat_idx=None):
        """Run the block.

        Args:
            x: Input laid out as ``(batch, channels, time, height, width)``.
            feat_cache: Optional list of per-convolution cache slots,
                updated in place.
            feat_idx: Single-element list holding the next cache slot index,
                advanced in place. Defaults to a fresh ``[0]`` on each call.

        Returns:
            The block output plus the shortcut branch.
        """
        # A fresh counter per call: a shared mutable default would keep
        # advancing across unrelated calls and index past the cache.
        if feat_idx is None:
            feat_idx = [0]

        # Apply shortcut connection
        h = self.conv_shortcut(x)

        # First normalization, activation and convolution
        x = self.nonlinearity(self.norm1(x))
        x = self._cached_conv(self.conv1, x, feat_cache, feat_idx)

        # Second normalization, activation, dropout and convolution
        x = self.nonlinearity(self.norm2(x))
        x = self.dropout(x)
        x = self._cached_conv(self.conv2, x, feat_cache, feat_idx)

        # Add residual connection
        return x + h
|
||||
|
||||
|
||||
|
||||
class QwenImageAttentionBlock(nn.Module):
    r"""
    Single-head self-attention over the spatial positions of each frame.

    Frames are folded into the batch axis, so attention never mixes
    different time steps. The block output is added to its input.

    Args:
        dim (int): The number of channels in the input tensor.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

        # layers
        self.norm = QwenImageRMS_norm(dim)
        self.to_qkv = nn.Conv2d(dim, dim * 3, 1)
        self.proj = nn.Conv2d(dim, dim, 1)

    def forward(self, x):
        """Attend within each frame of ``x`` (``[b, c, t, h, w]``) and add the residual."""
        residual = x
        batch_size, channels, time, height, width = x.size()
        n_frames = batch_size * time

        # [b, c, t, h, w] -> [(b*t), c, h, w]
        frames = x.permute(0, 2, 1, 3, 4).reshape(n_frames, channels, height, width)
        frames = self.norm(frames)

        # Joint q/k/v projection, then [(b*t), 1, h*w, 3c] with one head.
        qkv = self.to_qkv(frames).reshape(n_frames, 1, channels * 3, -1)
        qkv = qkv.permute(0, 1, 3, 2).contiguous()
        query, key, value = qkv.chunk(3, dim=-1)

        attended = torch.nn.functional.scaled_dot_product_attention(query, key, value)
        attended = attended.squeeze(1).permute(0, 2, 1).reshape(n_frames, channels, height, width)

        projected = self.proj(attended)

        # [(b*t), c, h, w] -> [b, c, t, h, w]
        projected = projected.view(batch_size, time, channels, height, width)
        projected = projected.permute(0, 2, 1, 3, 4)

        return projected + residual
|
||||
|
||||
|
||||
|
||||
class QwenImageUpsample(nn.Upsample):
    r"""
    ``nn.Upsample`` that computes in float32 and keeps the input dtype.

    The input is cast to float32 for the interpolation and the result is
    cast back to the dtype of the input.
    """

    def forward(self, x):
        """Upsample ``x`` in float32 and return the result in ``x``'s dtype."""
        upsampled = super().forward(x.float())
        return upsampled.type_as(x)
|
||||
|
||||
|
||||
|
||||
class QwenImageResample(nn.Module):
    r"""
    A custom resampling module for 2D and 3D data.

    Spatial resampling is applied frame by frame with 2D layers. The
    ``*3d`` modes add a causal temporal convolution (``time_conv``) that
    only runs when a feature cache is supplied.

    Args:
        dim (int): The number of input/output channels.
        mode (str): The resampling mode. Must be one of:
            - 'none': No resampling (identity operation).
            - 'upsample2d': 2D upsampling with nearest-exact interpolation and convolution.
            - 'upsample3d': 3D upsampling with nearest-exact interpolation, convolution, and causal 3D convolution.
            - 'downsample2d': 2D downsampling with zero-padding and convolution.
            - 'downsample3d': 3D downsampling with zero-padding, convolution, and causal 3D convolution.
            Any other value behaves like 'none'.
    """

    def __init__(self, dim: int, mode: str) -> None:
        super().__init__()
        self.dim = dim
        self.mode = mode

        # layers
        if mode == "upsample2d":
            self.resample = nn.Sequential(
                QwenImageUpsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), nn.Conv2d(dim, dim // 2, 3, padding=1)
            )
        elif mode == "upsample3d":
            self.resample = nn.Sequential(
                QwenImageUpsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), nn.Conv2d(dim, dim // 2, 3, padding=1)
            )
            # Doubles the channels; the two halves become interleaved frames.
            self.time_conv = QwenImageCausalConv3d(dim, dim * 2, (3, 1, 1), padding=(1, 0, 0))

        elif mode == "downsample2d":
            self.resample = nn.Sequential(nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2)))
        elif mode == "downsample3d":
            self.resample = nn.Sequential(nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2)))
            self.time_conv = QwenImageCausalConv3d(dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0))

        else:
            self.resample = nn.Identity()

    def _upsample_time(self, x, feat_cache, feat_idx):
        """Temporal upsampling step of 'upsample3d' (cache required).

        On the first visit of a cache slot the slot is marked with the
        string ``"Rep"`` and ``x`` is returned unchanged. On later visits
        ``time_conv`` doubles the channels, and the two channel halves are
        interleaved along the time axis to double the frame count.
        """
        b, c, t, h, w = x.size()
        idx = feat_idx[0]
        previous = feat_cache[idx]

        if previous is None:
            feat_cache[idx] = "Rep"
            feat_idx[0] += 1
            return x

        # Test for the sentinel explicitly instead of comparing a tensor
        # against a string.
        is_rep = isinstance(previous, str) and previous == "Rep"

        cache_x = x[:, :, -CACHE_T:, :, :].clone()
        if cache_x.shape[2] < 2:
            if is_rep:
                # No real earlier frame yet: pad the cache with zeros.
                cache_x = torch.cat([torch.zeros_like(cache_x).to(cache_x.device), cache_x], dim=2)
            else:
                # cache last frame of last two chunk
                cache_x = torch.cat(
                    [previous[:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2
                )

        x = self.time_conv(x) if is_rep else self.time_conv(x, previous)
        feat_cache[idx] = cache_x
        feat_idx[0] += 1

        # [b, 2c, t, h, w] -> [b, c, 2t, h, w], interleaving the two halves.
        x = x.reshape(b, 2, c, t, h, w)
        x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]), 3)
        return x.reshape(b, c, t * 2, h, w)

    def _downsample_time(self, x, feat_cache, feat_idx):
        """Temporal downsampling step of 'downsample3d' (cache required).

        On the first visit of a cache slot ``x`` is stored and returned
        unchanged. On later visits the last cached frame is prepended to
        ``x`` before ``time_conv``, and the slot keeps the last frame of ``x``.
        """
        idx = feat_idx[0]
        if feat_cache[idx] is None:
            feat_cache[idx] = x.clone()
            feat_idx[0] += 1
            return x

        cache_x = x[:, :, -1:, :, :].clone()
        x = self.time_conv(torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2))
        feat_cache[idx] = cache_x
        feat_idx[0] += 1
        return x

    def forward(self, x, feat_cache=None, feat_idx=None):
        """Resample ``x`` laid out as ``[b, c, t, h, w]``.

        Args:
            x: Input tensor.
            feat_cache: Optional list of cache slots, updated in place. The
                temporal convolutions are skipped when it is ``None``.
            feat_idx: Single-element list holding the next cache slot index,
                advanced in place. Defaults to a fresh ``[0]`` on each call.

        Returns:
            The resampled tensor in ``[b, c', t', h', w']`` layout.
        """
        # A fresh counter per call: a shared mutable default would keep
        # advancing across unrelated calls.
        if feat_idx is None:
            feat_idx = [0]

        b, c, t, h, w = x.size()
        if self.mode == "upsample3d" and feat_cache is not None:
            x = self._upsample_time(x, feat_cache, feat_idx)

        # Spatial resampling, frame by frame.
        t = x.shape[2]
        x = x.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w)
        x = self.resample(x)
        x = x.view(b, t, x.size(1), x.size(2), x.size(3)).permute(0, 2, 1, 3, 4)

        if self.mode == "downsample3d" and feat_cache is not None:
            x = self._downsample_time(x, feat_cache, feat_idx)
        return x
|
||||
|
||||
|
||||
|
||||
class QwenImageMidBlock(nn.Module):
    """
    Middle block: one residual block followed by ``num_layers`` (attention, residual) pairs.

    Args:
        dim (int): Number of input/output channels.
        dropout (float): Dropout rate forwarded to the residual blocks.
        non_linearity (str): Non-linearity name forwarded to the residual blocks.
        num_layers (int): Number of (attention, residual) pairs after the first residual block.
    """

    def __init__(self, dim: int, dropout: float = 0.0, non_linearity: str = "silu", num_layers: int = 1):
        super().__init__()
        self.dim = dim

        # Create the components
        resnets = [QwenImageResidualBlock(dim, dim, dropout, non_linearity)]
        attentions = []
        for _ in range(num_layers):
            attentions.append(QwenImageAttentionBlock(dim))
            resnets.append(QwenImageResidualBlock(dim, dim, dropout, non_linearity))
        self.attentions = nn.ModuleList(attentions)
        self.resnets = nn.ModuleList(resnets)

        self.gradient_checkpointing = False

    def forward(self, x, feat_cache=None, feat_idx=None):
        """
        Apply the first residual block, then each (attention, residual) pair in order.

        Args:
            x (torch.Tensor): Input tensor.
            feat_cache (list, optional): Feature cache forwarded to the residual blocks.
            feat_idx (list, optional): Single-element cache index list forwarded to the residual
                blocks. Defaults to a fresh ``[0]``.

        Returns:
            torch.Tensor: Output tensor.
        """
        if feat_idx is None:
            # Fresh counter per call: a shared ``[0]`` default would be mutated by the children
            # and keep its value between calls.
            feat_idx = [0]

        # First residual block
        x = self.resnets[0](x, feat_cache, feat_idx)

        # Process through attention and residual blocks
        for attn, resnet in zip(self.attentions, self.resnets[1:]):
            if attn is not None:
                x = attn(x)

            x = resnet(x, feat_cache, feat_idx)

        return x
|
||||
|
||||
|
||||
|
||||
class QwenImageEncoder3d(nn.Module):
    r"""
    A 3D encoder module.

    Layout: ``conv_in`` -> down blocks (residual, optional attention, resample) -> ``mid_block`` ->
    ``norm_out`` -> SiLU -> ``conv_out``.

    Args:
        dim (int): The base number of channels in the first layer.
        z_dim (int): Number of output channels of ``conv_out``.
        dim_mult (sequence of int): Multipliers for the number of channels in each block.
        num_res_blocks (int): Number of residual blocks in each block.
        attn_scales (list of float): Scales at which to apply attention mechanisms.
        temperal_downsample (list of bool): Whether to downsample temporally in each block.
        dropout (float): Dropout rate for the dropout layers.
        non_linearity (str): Type of non-linearity, forwarded to the middle block.
    """

    def __init__(
        self,
        dim=128,
        z_dim=4,
        dim_mult=[1, 2, 4, 4],
        num_res_blocks=2,
        attn_scales=[],
        temperal_downsample=[True, True, False],
        dropout=0.0,
        non_linearity: str = "silu",
    ):
        super().__init__()
        self.dim = dim
        self.z_dim = z_dim
        self.dim_mult = dim_mult
        self.num_res_blocks = num_res_blocks
        self.attn_scales = attn_scales
        self.temperal_downsample = temperal_downsample
        self.nonlinearity = torch.nn.SiLU()

        # dimensions (unpacking accepts tuples as well as lists for ``dim_mult``)
        dims = [dim * u for u in [1, *dim_mult]]
        scale = 1.0

        # init block
        self.conv_in = QwenImageCausalConv3d(3, dims[0], 3, padding=1)

        # downsample blocks
        self.down_blocks = torch.nn.ModuleList([])
        for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
            # residual (+attention) blocks
            for _ in range(num_res_blocks):
                self.down_blocks.append(QwenImageResidualBlock(in_dim, out_dim, dropout))
                if scale in attn_scales:
                    self.down_blocks.append(QwenImageAttentionBlock(out_dim))
                in_dim = out_dim

            # downsample block (every stage except the last)
            if i != len(dim_mult) - 1:
                mode = "downsample3d" if temperal_downsample[i] else "downsample2d"
                self.down_blocks.append(QwenImageResample(out_dim, mode=mode))
                scale /= 2.0

        # middle blocks
        self.mid_block = QwenImageMidBlock(out_dim, dropout, non_linearity, num_layers=1)

        # output blocks
        self.norm_out = QwenImageRMS_norm(out_dim, images=False)
        self.conv_out = QwenImageCausalConv3d(out_dim, z_dim, 3, padding=1)

        self.gradient_checkpointing = False

    def _conv_with_cache(self, conv, x, feat_cache, feat_idx):
        """Run ``conv`` on ``x``; with a cache, pass the slot's content to ``conv`` and refresh the slot."""
        if feat_cache is None:
            return conv(x)
        idx = feat_idx[0]
        cache_x = x[:, :, -CACHE_T:, :, :].clone()
        if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
            # cache last frame of last two chunk
            cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
        x = conv(x, feat_cache[idx])
        feat_cache[idx] = cache_x
        feat_idx[0] += 1
        return x

    def forward(self, x, feat_cache=None, feat_idx=None):
        """
        Encode ``x``.

        Args:
            x (torch.Tensor): Input tensor, indexed as 5D ``[:, :, t, :, :]`` on the cached path.
            feat_cache (list, optional): Feature cache; slots are consumed in layer order.
            feat_idx (list, optional): Single-element list with the next cache slot index,
                incremented in place. Defaults to a fresh ``[0]``.

        Returns:
            torch.Tensor: Output of ``conv_out``.
        """
        if feat_idx is None:
            # Fresh counter per call: a shared ``[0]`` default would keep its increments between calls.
            feat_idx = [0]

        x = self._conv_with_cache(self.conv_in, x, feat_cache, feat_idx)

        ## downsamples
        # NOTE(review): with a cache every layer receives (x, feat_cache, feat_idx), including any
        # QwenImageAttentionBlock added via attn_scales - confirm that block accepts these arguments.
        for layer in self.down_blocks:
            if feat_cache is not None:
                x = layer(x, feat_cache, feat_idx)
            else:
                x = layer(x)

        ## middle
        x = self.mid_block(x, feat_cache, feat_idx)

        ## head
        x = self.norm_out(x)
        x = self.nonlinearity(x)
        x = self._conv_with_cache(self.conv_out, x, feat_cache, feat_idx)
        return x
|
||||
|
||||
|
||||
|
||||
class QwenImageUpBlock(nn.Module):
    """
    Decoder stage: ``num_res_blocks + 1`` residual blocks followed by an optional upsampler.

    Args:
        in_dim (int): Input dimension
        out_dim (int): Output dimension of the residual blocks
        num_res_blocks (int): Number of residual blocks (one extra block is always added)
        dropout (float): Dropout rate
        upsample_mode (str, optional): Mode for upsampling ('upsample2d' or 'upsample3d');
            ``None`` disables upsampling
        non_linearity (str): Type of non-linearity to use
    """

    def __init__(
        self,
        in_dim: int,
        out_dim: int,
        num_res_blocks: int,
        dropout: float = 0.0,
        upsample_mode: Optional[str] = None,
        non_linearity: str = "silu",
    ):
        super().__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim

        # Create layers list
        resnets = []
        # Only the first residual block changes the channel count (in_dim -> out_dim)
        current_dim = in_dim
        for _ in range(num_res_blocks + 1):
            resnets.append(QwenImageResidualBlock(current_dim, out_dim, dropout, non_linearity))
            current_dim = out_dim

        self.resnets = nn.ModuleList(resnets)

        # Add upsampling layer if needed
        self.upsamplers = None
        if upsample_mode is not None:
            self.upsamplers = nn.ModuleList([QwenImageResample(out_dim, mode=upsample_mode)])

        self.gradient_checkpointing = False

    def forward(self, x, feat_cache=None, feat_idx=None):
        """
        Forward pass through the upsampling block.

        Args:
            x (torch.Tensor): Input tensor
            feat_cache (list, optional): Feature cache for causal convolutions
            feat_idx (list, optional): Single-element cache index list, forwarded to the children
                when a cache is used. Defaults to a fresh ``[0]``.

        Returns:
            torch.Tensor: Output tensor
        """
        if feat_idx is None:
            # Fresh counter per call: a shared ``[0]`` default would be mutated by the children
            # and keep its value between calls.
            feat_idx = [0]

        for resnet in self.resnets:
            if feat_cache is not None:
                x = resnet(x, feat_cache, feat_idx)
            else:
                x = resnet(x)

        if self.upsamplers is not None:
            if feat_cache is not None:
                x = self.upsamplers[0](x, feat_cache, feat_idx)
            else:
                x = self.upsamplers[0](x)
        return x
|
||||
|
||||
|
||||
|
||||
class QwenImageDecoder3d(nn.Module):
    r"""
    A 3D decoder module.

    Layout: ``conv_in`` -> ``mid_block`` -> up blocks -> ``norm_out`` -> SiLU -> ``conv_out``
    (3 output channels).

    Args:
        dim (int): The base number of channels in the first layer.
        z_dim (int): Number of input channels of ``conv_in``.
        dim_mult (sequence of int): Multipliers for the number of channels in each block.
        num_res_blocks (int): Number of residual blocks in each block.
        attn_scales (list of float): Stored on the module; not used when building the layers here.
        temperal_upsample (list of bool): Whether to upsample temporally in each block.
        dropout (float): Dropout rate for the dropout layers.
        non_linearity (str): Type of non-linearity, forwarded to the middle and up blocks.
    """

    def __init__(
        self,
        dim=128,
        z_dim=4,
        dim_mult=[1, 2, 4, 4],
        num_res_blocks=2,
        attn_scales=[],
        temperal_upsample=[False, True, True],
        dropout=0.0,
        non_linearity: str = "silu",
    ):
        super().__init__()
        self.dim = dim
        self.z_dim = z_dim
        self.dim_mult = dim_mult
        self.num_res_blocks = num_res_blocks
        self.attn_scales = attn_scales
        self.temperal_upsample = temperal_upsample

        self.nonlinearity = torch.nn.SiLU()

        # dimensions (unpacking accepts tuples as well as lists for ``dim_mult``)
        dims = [dim * u for u in [dim_mult[-1], *dim_mult[::-1]]]

        # init block
        self.conv_in = QwenImageCausalConv3d(z_dim, dims[0], 3, padding=1)

        # middle blocks
        self.mid_block = QwenImageMidBlock(dims[0], dropout, non_linearity, num_layers=1)

        # upsample blocks
        self.up_blocks = nn.ModuleList([])
        for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
            # residual (+attention) blocks
            if i > 0:
                # Every block before the last ends with an upsampler, so a later block receives
                # half the previous block's channels
                in_dim = in_dim // 2

            # Determine if we need upsampling (every stage except the last)
            upsample_mode = None
            if i != len(dim_mult) - 1:
                upsample_mode = "upsample3d" if temperal_upsample[i] else "upsample2d"

            # Create and add the upsampling block
            up_block = QwenImageUpBlock(
                in_dim=in_dim,
                out_dim=out_dim,
                num_res_blocks=num_res_blocks,
                dropout=dropout,
                upsample_mode=upsample_mode,
                non_linearity=non_linearity,
            )
            self.up_blocks.append(up_block)

        # output blocks
        self.norm_out = QwenImageRMS_norm(out_dim, images=False)
        self.conv_out = QwenImageCausalConv3d(out_dim, 3, 3, padding=1)

        self.gradient_checkpointing = False

    def _conv_with_cache(self, conv, x, feat_cache, feat_idx):
        """Run ``conv`` on ``x``; with a cache, pass the slot's content to ``conv`` and refresh the slot."""
        if feat_cache is None:
            return conv(x)
        idx = feat_idx[0]
        cache_x = x[:, :, -CACHE_T:, :, :].clone()
        if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
            # cache last frame of last two chunk
            cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
        x = conv(x, feat_cache[idx])
        feat_cache[idx] = cache_x
        feat_idx[0] += 1
        return x

    def forward(self, x, feat_cache=None, feat_idx=None):
        """
        Decode ``x``.

        Args:
            x (torch.Tensor): Input tensor, indexed as 5D ``[:, :, t, :, :]`` on the cached path.
            feat_cache (list, optional): Feature cache; slots are consumed in layer order.
            feat_idx (list, optional): Single-element list with the next cache slot index,
                incremented in place. Defaults to a fresh ``[0]``.

        Returns:
            torch.Tensor: Output of ``conv_out``.
        """
        if feat_idx is None:
            # Fresh counter per call: a shared ``[0]`` default would keep its increments between calls.
            feat_idx = [0]

        ## conv1
        x = self._conv_with_cache(self.conv_in, x, feat_cache, feat_idx)

        ## middle
        x = self.mid_block(x, feat_cache, feat_idx)

        ## upsamples
        for up_block in self.up_blocks:
            x = up_block(x, feat_cache, feat_idx)

        ## head
        x = self.norm_out(x)
        x = self.nonlinearity(x)
        x = self._conv_with_cache(self.conv_out, x, feat_cache, feat_idx)
        return x
|
||||
|
||||
|
||||
|
||||
class QwenImageVAE(torch.nn.Module):
    """
    VAE wrapper pairing a ``QwenImageEncoder3d`` with a ``QwenImageDecoder3d``.

    ``encode`` and ``decode`` take tensors without a frame axis: a singleton axis is inserted at
    index 2 before the 3D networks run and removed again afterwards. Latents are shifted and scaled
    with fixed 16-channel statistics. Note that ``self.std`` stores the *reciprocal* of the
    standard deviation.
    """

    def __init__(
        self,
        base_dim: int = 96,
        z_dim: int = 16,
        dim_mult: Tuple[int] = [1, 2, 4, 4],
        num_res_blocks: int = 2,
        attn_scales: List[float] = [],
        temperal_downsample: List[bool] = [False, True, True],
        dropout: float = 0.0,
    ) -> None:
        super().__init__()

        self.z_dim = z_dim
        self.temperal_downsample = temperal_downsample
        # The decoder walks the stages in the opposite order.
        self.temperal_upsample = temperal_downsample[::-1]

        # The encoder and quant_conv work on twice the latent channel count.
        doubled_z = z_dim * 2
        self.encoder = QwenImageEncoder3d(
            base_dim, doubled_z, dim_mult, num_res_blocks, attn_scales, self.temperal_downsample, dropout
        )
        self.quant_conv = QwenImageCausalConv3d(doubled_z, doubled_z, 1)
        self.post_quant_conv = QwenImageCausalConv3d(z_dim, z_dim, 1)

        self.decoder = QwenImageDecoder3d(
            base_dim, z_dim, dim_mult, num_res_blocks, attn_scales, self.temperal_upsample, dropout
        )

        # Fixed per-channel latent statistics (16 channels).
        channel_mean = [
            -0.7571, -0.7089, -0.9113, 0.1075, -0.1745, 0.9653, -0.1517, 1.5508,
            0.4134, -0.0715, 0.5517, -0.3632, -0.1922, -0.9497, 0.2503, -0.2921,
        ]
        channel_std = [
            2.8184, 1.4541, 2.3275, 2.6558, 1.2196, 1.7708, 2.6052, 2.0743,
            3.2687, 2.1526, 2.8652, 1.5579, 1.6382, 1.1253, 2.8251, 1.9160,
        ]
        stats_shape = (1, 16, 1, 1, 1)
        self.mean = torch.tensor(channel_mean).view(*stats_shape)
        # Stored inverted, so encode multiplies by it and decode divides by it.
        self.std = 1 / torch.tensor(channel_std).view(*stats_shape)

    def encode(self, x, **kwargs):
        """Encode ``x`` and return normalized latents (first 16 channels of the ``quant_conv`` output)."""
        hidden = self.encoder(x.unsqueeze(2))
        latents = self.quant_conv(hidden)[:, :16]
        shift = self.mean.to(dtype=latents.dtype, device=latents.device)
        inv_std = self.std.to(dtype=latents.dtype, device=latents.device)
        latents = (latents - shift) * inv_std
        return latents.squeeze(2)

    def decode(self, x, **kwargs):
        """Undo the latent normalization of ``x`` and run it through the decoder."""
        latents = x.unsqueeze(2)
        shift = self.mean.to(dtype=latents.dtype, device=latents.device)
        inv_std = self.std.to(dtype=latents.dtype, device=latents.device)
        latents = latents / inv_std + shift
        decoded = self.decoder(self.post_quant_conv(latents))
        return decoded.squeeze(2)

    @staticmethod
    def state_dict_converter():
        """Return the state-dict converter used for this model."""
        return QwenImageVAEStateDictConverter()
|
||||
|
||||
|
||||
|
||||
class QwenImageVAEStateDictConverter:
    """State-dict converter for ``QwenImageVAE``; diffusers checkpoints are used unchanged."""

    def __init__(self):
        """Nothing to configure."""

    def from_diffusers(self, state_dict):
        """Return ``state_dict`` as-is (the very same object, no renaming or copying)."""
        converted = state_dict
        return converted
|
||||
@@ -50,14 +50,30 @@ class PatchEmbed(torch.nn.Module):
|
||||
return latent + pos_embed
|
||||
|
||||
|
||||
class DiffusersCompatibleTimestepProj(torch.nn.Module):
    """
    Two-layer projection: ``Linear(dim_in, dim_out) -> SiLU -> Linear(dim_out, dim_out)``.

    The layers are stored as named attributes (``linear_1``, ``act``, ``linear_2``), so the
    parameter keys are ``linear_1.*`` / ``linear_2.*``.
    """

    def __init__(self, dim_in, dim_out):
        super().__init__()
        self.linear_1 = torch.nn.Linear(dim_in, dim_out)
        self.act = torch.nn.SiLU()
        self.linear_2 = torch.nn.Linear(dim_out, dim_out)

    def forward(self, x):
        """Project ``x`` through both linear layers with the activation in between."""
        hidden = self.act(self.linear_1(x))
        return self.linear_2(hidden)
|
||||
|
||||
|
||||
class TimestepEmbeddings(torch.nn.Module):
|
||||
def __init__(self, dim_in, dim_out, computation_device=None):
|
||||
def __init__(self, dim_in, dim_out, computation_device=None, diffusers_compatible_format=False, scale=1, align_dtype_to_timestep=False):
|
||||
super().__init__()
|
||||
self.time_proj = TemporalTimesteps(num_channels=dim_in, flip_sin_to_cos=True, downscale_freq_shift=0, computation_device=computation_device)
|
||||
self.timestep_embedder = torch.nn.Sequential(
|
||||
torch.nn.Linear(dim_in, dim_out), torch.nn.SiLU(), torch.nn.Linear(dim_out, dim_out)
|
||||
)
|
||||
self.time_proj = TemporalTimesteps(num_channels=dim_in, flip_sin_to_cos=True, downscale_freq_shift=0, computation_device=computation_device, scale=scale, align_dtype_to_timestep=align_dtype_to_timestep)
|
||||
if diffusers_compatible_format:
|
||||
self.timestep_embedder = DiffusersCompatibleTimestepProj(dim_in, dim_out)
|
||||
else:
|
||||
self.timestep_embedder = torch.nn.Sequential(
|
||||
torch.nn.Linear(dim_in, dim_out), torch.nn.SiLU(), torch.nn.Linear(dim_out, dim_out)
|
||||
)
|
||||
|
||||
def forward(self, timestep, dtype):
|
||||
time_emb = self.time_proj(timestep).to(dtype)
|
||||
|
||||
@@ -45,6 +45,7 @@ def get_timestep_embedding(
|
||||
scale: float = 1,
|
||||
max_period: int = 10000,
|
||||
computation_device = None,
|
||||
align_dtype_to_timestep = False,
|
||||
):
|
||||
"""
|
||||
This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings.
|
||||
@@ -63,6 +64,8 @@ def get_timestep_embedding(
|
||||
exponent = exponent / (half_dim - downscale_freq_shift)
|
||||
|
||||
emb = torch.exp(exponent).to(timesteps.device)
|
||||
if align_dtype_to_timestep:
|
||||
emb = emb.to(timesteps.dtype)
|
||||
emb = timesteps[:, None].float() * emb[None, :]
|
||||
|
||||
# scale embeddings
|
||||
@@ -82,12 +85,14 @@ def get_timestep_embedding(
|
||||
|
||||
|
||||
class TemporalTimesteps(torch.nn.Module):
|
||||
def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float, computation_device = None):
|
||||
def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float, computation_device = None, scale=1, align_dtype_to_timestep=False):
|
||||
super().__init__()
|
||||
self.num_channels = num_channels
|
||||
self.flip_sin_to_cos = flip_sin_to_cos
|
||||
self.downscale_freq_shift = downscale_freq_shift
|
||||
self.computation_device = computation_device
|
||||
self.scale = scale
|
||||
self.align_dtype_to_timestep = align_dtype_to_timestep
|
||||
|
||||
def forward(self, timesteps):
|
||||
t_emb = get_timestep_embedding(
|
||||
@@ -96,6 +101,8 @@ class TemporalTimesteps(torch.nn.Module):
|
||||
flip_sin_to_cos=self.flip_sin_to_cos,
|
||||
downscale_freq_shift=self.downscale_freq_shift,
|
||||
computation_device=self.computation_device,
|
||||
scale=self.scale,
|
||||
align_dtype_to_timestep=self.align_dtype_to_timestep,
|
||||
)
|
||||
return t_emb
|
||||
|
||||
|
||||
Reference in New Issue
Block a user