mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-18 22:08:13 +00:00
support qwen-image-layered
This commit is contained in:
@@ -19,7 +19,7 @@ def get_timestep_embedding(
|
||||
)
|
||||
exponent = exponent / (half_dim - downscale_freq_shift)
|
||||
|
||||
emb = torch.exp(exponent).to(timesteps.device)
|
||||
emb = torch.exp(exponent)
|
||||
if align_dtype_to_timestep:
|
||||
emb = emb.to(timesteps.dtype)
|
||||
emb = timesteps[:, None].float() * emb[None, :]
|
||||
@@ -78,7 +78,7 @@ class DiffusersCompatibleTimestepProj(torch.nn.Module):
|
||||
|
||||
|
||||
class TimestepEmbeddings(torch.nn.Module):
|
||||
def __init__(self, dim_in, dim_out, computation_device=None, diffusers_compatible_format=False, scale=1, align_dtype_to_timestep=False):
|
||||
def __init__(self, dim_in, dim_out, computation_device=None, diffusers_compatible_format=False, scale=1, align_dtype_to_timestep=False, use_additional_t_cond=False):
|
||||
super().__init__()
|
||||
self.time_proj = TemporalTimesteps(num_channels=dim_in, flip_sin_to_cos=True, downscale_freq_shift=0, computation_device=computation_device, scale=scale, align_dtype_to_timestep=align_dtype_to_timestep)
|
||||
if diffusers_compatible_format:
|
||||
@@ -87,10 +87,16 @@ class TimestepEmbeddings(torch.nn.Module):
|
||||
self.timestep_embedder = torch.nn.Sequential(
|
||||
torch.nn.Linear(dim_in, dim_out), torch.nn.SiLU(), torch.nn.Linear(dim_out, dim_out)
|
||||
)
|
||||
if use_additional_t_cond:
|
||||
self.addition_t_embedding = torch.nn.Embedding(2, dim_out)
|
||||
|
||||
def forward(self, timestep, dtype):
|
||||
def forward(self, timestep, dtype, addition_t_cond=None):
|
||||
time_emb = self.time_proj(timestep).to(dtype)
|
||||
time_emb = self.timestep_embedder(time_emb)
|
||||
if addition_t_cond is not None:
|
||||
addition_t_emb = self.addition_t_embedding(addition_t_cond)
|
||||
addition_t_emb = addition_t_emb.to(dtype=dtype)
|
||||
time_emb = time_emb + addition_t_emb
|
||||
return time_emb
|
||||
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
import torch, math
|
||||
import torch, math, functools
|
||||
import torch.nn as nn
|
||||
from typing import Tuple, Optional, Union, List
|
||||
from einops import rearrange
|
||||
@@ -225,6 +225,121 @@ class QwenEmbedRope(nn.Module):
|
||||
return vid_freqs, txt_freqs
|
||||
|
||||
|
||||
class QwenEmbedLayer3DRope(nn.Module):
|
||||
def __init__(self, theta: int, axes_dim: List[int], scale_rope=False):
|
||||
super().__init__()
|
||||
self.theta = theta
|
||||
self.axes_dim = axes_dim
|
||||
pos_index = torch.arange(4096)
|
||||
neg_index = torch.arange(4096).flip(0) * -1 - 1
|
||||
self.pos_freqs = torch.cat(
|
||||
[
|
||||
self.rope_params(pos_index, self.axes_dim[0], self.theta),
|
||||
self.rope_params(pos_index, self.axes_dim[1], self.theta),
|
||||
self.rope_params(pos_index, self.axes_dim[2], self.theta),
|
||||
],
|
||||
dim=1,
|
||||
)
|
||||
self.neg_freqs = torch.cat(
|
||||
[
|
||||
self.rope_params(neg_index, self.axes_dim[0], self.theta),
|
||||
self.rope_params(neg_index, self.axes_dim[1], self.theta),
|
||||
self.rope_params(neg_index, self.axes_dim[2], self.theta),
|
||||
],
|
||||
dim=1,
|
||||
)
|
||||
|
||||
self.scale_rope = scale_rope
|
||||
|
||||
def rope_params(self, index, dim, theta=10000):
|
||||
"""
|
||||
Args:
|
||||
index: [0, 1, 2, 3] 1D Tensor representing the position index of the token
|
||||
"""
|
||||
assert dim % 2 == 0
|
||||
freqs = torch.outer(index, 1.0 / torch.pow(theta, torch.arange(0, dim, 2).to(torch.float32).div(dim)))
|
||||
freqs = torch.polar(torch.ones_like(freqs), freqs)
|
||||
return freqs
|
||||
|
||||
def forward(self, video_fhw, txt_seq_lens, device):
|
||||
"""
|
||||
Args: video_fhw: [frame, height, width] a list of 3 integers representing the shape of the video Args:
|
||||
txt_length: [bs] a list of 1 integers representing the length of the text
|
||||
"""
|
||||
if self.pos_freqs.device != device:
|
||||
self.pos_freqs = self.pos_freqs.to(device)
|
||||
self.neg_freqs = self.neg_freqs.to(device)
|
||||
|
||||
video_fhw = [video_fhw]
|
||||
if isinstance(video_fhw, list):
|
||||
video_fhw = video_fhw[0]
|
||||
if not isinstance(video_fhw, list):
|
||||
video_fhw = [video_fhw]
|
||||
|
||||
vid_freqs = []
|
||||
max_vid_index = 0
|
||||
layer_num = len(video_fhw) - 1
|
||||
for idx, fhw in enumerate(video_fhw):
|
||||
frame, height, width = fhw
|
||||
if idx != layer_num:
|
||||
video_freq = self._compute_video_freqs(frame, height, width, idx)
|
||||
else:
|
||||
### For the condition image, we set the layer index to -1
|
||||
video_freq = self._compute_condition_freqs(frame, height, width)
|
||||
video_freq = video_freq.to(device)
|
||||
vid_freqs.append(video_freq)
|
||||
|
||||
if self.scale_rope:
|
||||
max_vid_index = max(height // 2, width // 2, max_vid_index)
|
||||
else:
|
||||
max_vid_index = max(height, width, max_vid_index)
|
||||
|
||||
max_vid_index = max(max_vid_index, layer_num)
|
||||
max_len = max(txt_seq_lens)
|
||||
txt_freqs = self.pos_freqs[max_vid_index : max_vid_index + max_len, ...]
|
||||
vid_freqs = torch.cat(vid_freqs, dim=0)
|
||||
|
||||
return vid_freqs, txt_freqs
|
||||
|
||||
@functools.lru_cache(maxsize=None)
|
||||
def _compute_video_freqs(self, frame, height, width, idx=0):
|
||||
seq_lens = frame * height * width
|
||||
freqs_pos = self.pos_freqs.split([x // 2 for x in self.axes_dim], dim=1)
|
||||
freqs_neg = self.neg_freqs.split([x // 2 for x in self.axes_dim], dim=1)
|
||||
|
||||
freqs_frame = freqs_pos[0][idx : idx + frame].view(frame, 1, 1, -1).expand(frame, height, width, -1)
|
||||
if self.scale_rope:
|
||||
freqs_height = torch.cat([freqs_neg[1][-(height - height // 2) :], freqs_pos[1][: height // 2]], dim=0)
|
||||
freqs_height = freqs_height.view(1, height, 1, -1).expand(frame, height, width, -1)
|
||||
freqs_width = torch.cat([freqs_neg[2][-(width - width // 2) :], freqs_pos[2][: width // 2]], dim=0)
|
||||
freqs_width = freqs_width.view(1, 1, width, -1).expand(frame, height, width, -1)
|
||||
else:
|
||||
freqs_height = freqs_pos[1][:height].view(1, height, 1, -1).expand(frame, height, width, -1)
|
||||
freqs_width = freqs_pos[2][:width].view(1, 1, width, -1).expand(frame, height, width, -1)
|
||||
|
||||
freqs = torch.cat([freqs_frame, freqs_height, freqs_width], dim=-1).reshape(seq_lens, -1)
|
||||
return freqs.clone().contiguous()
|
||||
|
||||
@functools.lru_cache(maxsize=None)
|
||||
def _compute_condition_freqs(self, frame, height, width):
|
||||
seq_lens = frame * height * width
|
||||
freqs_pos = self.pos_freqs.split([x // 2 for x in self.axes_dim], dim=1)
|
||||
freqs_neg = self.neg_freqs.split([x // 2 for x in self.axes_dim], dim=1)
|
||||
|
||||
freqs_frame = freqs_neg[0][-1:].view(frame, 1, 1, -1).expand(frame, height, width, -1)
|
||||
if self.scale_rope:
|
||||
freqs_height = torch.cat([freqs_neg[1][-(height - height // 2) :], freqs_pos[1][: height // 2]], dim=0)
|
||||
freqs_height = freqs_height.view(1, height, 1, -1).expand(frame, height, width, -1)
|
||||
freqs_width = torch.cat([freqs_neg[2][-(width - width // 2) :], freqs_pos[2][: width // 2]], dim=0)
|
||||
freqs_width = freqs_width.view(1, 1, width, -1).expand(frame, height, width, -1)
|
||||
else:
|
||||
freqs_height = freqs_pos[1][:height].view(1, height, 1, -1).expand(frame, height, width, -1)
|
||||
freqs_width = freqs_pos[2][:width].view(1, 1, width, -1).expand(frame, height, width, -1)
|
||||
|
||||
freqs = torch.cat([freqs_frame, freqs_height, freqs_width], dim=-1).reshape(seq_lens, -1)
|
||||
return freqs.clone().contiguous()
|
||||
|
||||
|
||||
class QwenFeedForward(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
@@ -437,12 +552,17 @@ class QwenImageDiT(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
num_layers: int = 60,
|
||||
use_layer3d_rope: bool = False,
|
||||
use_additional_t_cond: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.pos_embed = QwenEmbedRope(theta=10000, axes_dim=[16,56,56], scale_rope=True)
|
||||
if not use_layer3d_rope:
|
||||
self.pos_embed = QwenEmbedRope(theta=10000, axes_dim=[16,56,56], scale_rope=True)
|
||||
else:
|
||||
self.pos_embed = QwenEmbedLayer3DRope(theta=10000, axes_dim=[16,56,56], scale_rope=True)
|
||||
|
||||
self.time_text_embed = TimestepEmbeddings(256, 3072, diffusers_compatible_format=True, scale=1000, align_dtype_to_timestep=True)
|
||||
self.time_text_embed = TimestepEmbeddings(256, 3072, diffusers_compatible_format=True, scale=1000, align_dtype_to_timestep=False, use_additional_t_cond=use_additional_t_cond)
|
||||
self.txt_norm = RMSNorm(3584, eps=1e-6)
|
||||
|
||||
self.img_in = nn.Linear(64, 3072)
|
||||
|
||||
@@ -366,6 +366,7 @@ class QwenImageEncoder3d(nn.Module):
|
||||
temperal_downsample=[True, True, False],
|
||||
dropout=0.0,
|
||||
non_linearity: str = "silu",
|
||||
image_channels=3
|
||||
):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
@@ -381,7 +382,7 @@ class QwenImageEncoder3d(nn.Module):
|
||||
scale = 1.0
|
||||
|
||||
# init block
|
||||
self.conv_in = QwenImageCausalConv3d(3, dims[0], 3, padding=1)
|
||||
self.conv_in = QwenImageCausalConv3d(image_channels, dims[0], 3, padding=1)
|
||||
|
||||
# downsample blocks
|
||||
self.down_blocks = torch.nn.ModuleList([])
|
||||
@@ -544,6 +545,7 @@ class QwenImageDecoder3d(nn.Module):
|
||||
temperal_upsample=[False, True, True],
|
||||
dropout=0.0,
|
||||
non_linearity: str = "silu",
|
||||
image_channels=3,
|
||||
):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
@@ -594,7 +596,7 @@ class QwenImageDecoder3d(nn.Module):
|
||||
|
||||
# output blocks
|
||||
self.norm_out = QwenImageRMS_norm(out_dim, images=False)
|
||||
self.conv_out = QwenImageCausalConv3d(out_dim, 3, 3, padding=1)
|
||||
self.conv_out = QwenImageCausalConv3d(out_dim, image_channels, 3, padding=1)
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
@@ -647,6 +649,7 @@ class QwenImageVAE(torch.nn.Module):
|
||||
attn_scales: List[float] = [],
|
||||
temperal_downsample: List[bool] = [False, True, True],
|
||||
dropout: float = 0.0,
|
||||
image_channels: int = 3,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
@@ -655,13 +658,13 @@ class QwenImageVAE(torch.nn.Module):
|
||||
self.temperal_upsample = temperal_downsample[::-1]
|
||||
|
||||
self.encoder = QwenImageEncoder3d(
|
||||
base_dim, z_dim * 2, dim_mult, num_res_blocks, attn_scales, self.temperal_downsample, dropout
|
||||
base_dim, z_dim * 2, dim_mult, num_res_blocks, attn_scales, self.temperal_downsample, dropout, image_channels=image_channels,
|
||||
)
|
||||
self.quant_conv = QwenImageCausalConv3d(z_dim * 2, z_dim * 2, 1)
|
||||
self.post_quant_conv = QwenImageCausalConv3d(z_dim, z_dim, 1)
|
||||
|
||||
self.decoder = QwenImageDecoder3d(
|
||||
base_dim, z_dim, dim_mult, num_res_blocks, attn_scales, self.temperal_upsample, dropout
|
||||
base_dim, z_dim, dim_mult, num_res_blocks, attn_scales, self.temperal_upsample, dropout, image_channels=image_channels,
|
||||
)
|
||||
|
||||
mean = [
|
||||
|
||||
Reference in New Issue
Block a user