support qwen-image-layered

This commit is contained in:
Artiprocher
2025-12-19 19:06:37 +08:00
parent 11315d7a40
commit c6722b3f56
18 changed files with 417 additions and 27 deletions

View File

@@ -63,6 +63,20 @@ qwen_image_series = [
"model_class": "diffsynth.models.qwen_image_image2lora.QwenImageImage2LoRAModel",
"extra_kwargs": {"compress_dim": 64, "use_residual": False}
},
{
# Example: ModelConfig(model_id="Qwen/Qwen-Image-Layered", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors")
"model_hash": "8dc8cda05de16c73afa755e2c1ce2839",
"model_name": "qwen_image_dit",
"model_class": "diffsynth.models.qwen_image_dit.QwenImageDiT",
"extra_kwargs": {"use_layer3d_rope": True, "use_additional_t_cond": True}
},
{
# Example: ModelConfig(model_id="Qwen/Qwen-Image-Layered", origin_file_pattern="vae/diffusion_pytorch_model.safetensors")
"model_hash": "44b39ddc499e027cfb24f7878d7416b9",
"model_name": "qwen_image_vae",
"model_class": "diffsynth.models.qwen_image_vae.QwenImageVAE",
"extra_kwargs": {"image_channels": 4}
},
]
wan_series = [

View File

@@ -13,6 +13,7 @@ VRAM_MANAGEMENT_MODULE_MAPS = {
"diffsynth.models.qwen_image_dit.QwenImageDiT": {
"diffsynth.models.qwen_image_dit.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
"torch.nn.Embedding": "diffsynth.core.vram.layers.AutoWrappedModule",
},
"diffsynth.models.qwen_image_text_encoder.QwenImageTextEncoder": {
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",

View File

@@ -53,12 +53,14 @@ class ToStr(DataProcessingOperator):
class LoadImage(DataProcessingOperator):
def __init__(self, convert_RGB=True):
def __init__(self, convert_RGB=True, convert_RGBA=False):
self.convert_RGB = convert_RGB
self.convert_RGBA = convert_RGBA
def __call__(self, data: str):
image = Image.open(data)
if self.convert_RGB: image = image.convert("RGB")
if self.convert_RGBA: image = image.convert("RGBA")
return image

View File

@@ -19,7 +19,7 @@ def get_timestep_embedding(
)
exponent = exponent / (half_dim - downscale_freq_shift)
emb = torch.exp(exponent).to(timesteps.device)
emb = torch.exp(exponent)
if align_dtype_to_timestep:
emb = emb.to(timesteps.dtype)
emb = timesteps[:, None].float() * emb[None, :]
@@ -78,7 +78,7 @@ class DiffusersCompatibleTimestepProj(torch.nn.Module):
class TimestepEmbeddings(torch.nn.Module):
def __init__(self, dim_in, dim_out, computation_device=None, diffusers_compatible_format=False, scale=1, align_dtype_to_timestep=False):
def __init__(self, dim_in, dim_out, computation_device=None, diffusers_compatible_format=False, scale=1, align_dtype_to_timestep=False, use_additional_t_cond=False):
super().__init__()
self.time_proj = TemporalTimesteps(num_channels=dim_in, flip_sin_to_cos=True, downscale_freq_shift=0, computation_device=computation_device, scale=scale, align_dtype_to_timestep=align_dtype_to_timestep)
if diffusers_compatible_format:
@@ -87,10 +87,16 @@ class TimestepEmbeddings(torch.nn.Module):
self.timestep_embedder = torch.nn.Sequential(
torch.nn.Linear(dim_in, dim_out), torch.nn.SiLU(), torch.nn.Linear(dim_out, dim_out)
)
if use_additional_t_cond:
self.addition_t_embedding = torch.nn.Embedding(2, dim_out)
def forward(self, timestep, dtype):
def forward(self, timestep, dtype, addition_t_cond=None):
time_emb = self.time_proj(timestep).to(dtype)
time_emb = self.timestep_embedder(time_emb)
if addition_t_cond is not None:
addition_t_emb = self.addition_t_embedding(addition_t_cond)
addition_t_emb = addition_t_emb.to(dtype=dtype)
time_emb = time_emb + addition_t_emb
return time_emb

View File

@@ -1,4 +1,4 @@
import torch, math
import torch, math, functools
import torch.nn as nn
from typing import Tuple, Optional, Union, List
from einops import rearrange
@@ -225,6 +225,121 @@ class QwenEmbedRope(nn.Module):
return vid_freqs, txt_freqs
class QwenEmbedLayer3DRope(nn.Module):
def __init__(self, theta: int, axes_dim: List[int], scale_rope=False):
super().__init__()
self.theta = theta
self.axes_dim = axes_dim
pos_index = torch.arange(4096)
neg_index = torch.arange(4096).flip(0) * -1 - 1
self.pos_freqs = torch.cat(
[
self.rope_params(pos_index, self.axes_dim[0], self.theta),
self.rope_params(pos_index, self.axes_dim[1], self.theta),
self.rope_params(pos_index, self.axes_dim[2], self.theta),
],
dim=1,
)
self.neg_freqs = torch.cat(
[
self.rope_params(neg_index, self.axes_dim[0], self.theta),
self.rope_params(neg_index, self.axes_dim[1], self.theta),
self.rope_params(neg_index, self.axes_dim[2], self.theta),
],
dim=1,
)
self.scale_rope = scale_rope
def rope_params(self, index, dim, theta=10000):
"""
Args:
index: [0, 1, 2, 3] 1D Tensor representing the position index of the token
"""
assert dim % 2 == 0
freqs = torch.outer(index, 1.0 / torch.pow(theta, torch.arange(0, dim, 2).to(torch.float32).div(dim)))
freqs = torch.polar(torch.ones_like(freqs), freqs)
return freqs
def forward(self, video_fhw, txt_seq_lens, device):
"""
Args: video_fhw: [frame, height, width] a list of 3 integers representing the shape of the video Args:
txt_length: [bs] a list of 1 integers representing the length of the text
"""
if self.pos_freqs.device != device:
self.pos_freqs = self.pos_freqs.to(device)
self.neg_freqs = self.neg_freqs.to(device)
video_fhw = [video_fhw]
if isinstance(video_fhw, list):
video_fhw = video_fhw[0]
if not isinstance(video_fhw, list):
video_fhw = [video_fhw]
vid_freqs = []
max_vid_index = 0
layer_num = len(video_fhw) - 1
for idx, fhw in enumerate(video_fhw):
frame, height, width = fhw
if idx != layer_num:
video_freq = self._compute_video_freqs(frame, height, width, idx)
else:
### For the condition image, we set the layer index to -1
video_freq = self._compute_condition_freqs(frame, height, width)
video_freq = video_freq.to(device)
vid_freqs.append(video_freq)
if self.scale_rope:
max_vid_index = max(height // 2, width // 2, max_vid_index)
else:
max_vid_index = max(height, width, max_vid_index)
max_vid_index = max(max_vid_index, layer_num)
max_len = max(txt_seq_lens)
txt_freqs = self.pos_freqs[max_vid_index : max_vid_index + max_len, ...]
vid_freqs = torch.cat(vid_freqs, dim=0)
return vid_freqs, txt_freqs
@functools.lru_cache(maxsize=None)
def _compute_video_freqs(self, frame, height, width, idx=0):
seq_lens = frame * height * width
freqs_pos = self.pos_freqs.split([x // 2 for x in self.axes_dim], dim=1)
freqs_neg = self.neg_freqs.split([x // 2 for x in self.axes_dim], dim=1)
freqs_frame = freqs_pos[0][idx : idx + frame].view(frame, 1, 1, -1).expand(frame, height, width, -1)
if self.scale_rope:
freqs_height = torch.cat([freqs_neg[1][-(height - height // 2) :], freqs_pos[1][: height // 2]], dim=0)
freqs_height = freqs_height.view(1, height, 1, -1).expand(frame, height, width, -1)
freqs_width = torch.cat([freqs_neg[2][-(width - width // 2) :], freqs_pos[2][: width // 2]], dim=0)
freqs_width = freqs_width.view(1, 1, width, -1).expand(frame, height, width, -1)
else:
freqs_height = freqs_pos[1][:height].view(1, height, 1, -1).expand(frame, height, width, -1)
freqs_width = freqs_pos[2][:width].view(1, 1, width, -1).expand(frame, height, width, -1)
freqs = torch.cat([freqs_frame, freqs_height, freqs_width], dim=-1).reshape(seq_lens, -1)
return freqs.clone().contiguous()
@functools.lru_cache(maxsize=None)
def _compute_condition_freqs(self, frame, height, width):
seq_lens = frame * height * width
freqs_pos = self.pos_freqs.split([x // 2 for x in self.axes_dim], dim=1)
freqs_neg = self.neg_freqs.split([x // 2 for x in self.axes_dim], dim=1)
freqs_frame = freqs_neg[0][-1:].view(frame, 1, 1, -1).expand(frame, height, width, -1)
if self.scale_rope:
freqs_height = torch.cat([freqs_neg[1][-(height - height // 2) :], freqs_pos[1][: height // 2]], dim=0)
freqs_height = freqs_height.view(1, height, 1, -1).expand(frame, height, width, -1)
freqs_width = torch.cat([freqs_neg[2][-(width - width // 2) :], freqs_pos[2][: width // 2]], dim=0)
freqs_width = freqs_width.view(1, 1, width, -1).expand(frame, height, width, -1)
else:
freqs_height = freqs_pos[1][:height].view(1, height, 1, -1).expand(frame, height, width, -1)
freqs_width = freqs_pos[2][:width].view(1, 1, width, -1).expand(frame, height, width, -1)
freqs = torch.cat([freqs_frame, freqs_height, freqs_width], dim=-1).reshape(seq_lens, -1)
return freqs.clone().contiguous()
class QwenFeedForward(nn.Module):
def __init__(
self,
@@ -437,12 +552,17 @@ class QwenImageDiT(torch.nn.Module):
def __init__(
self,
num_layers: int = 60,
use_layer3d_rope: bool = False,
use_additional_t_cond: bool = False,
):
super().__init__()
self.pos_embed = QwenEmbedRope(theta=10000, axes_dim=[16,56,56], scale_rope=True)
if not use_layer3d_rope:
self.pos_embed = QwenEmbedRope(theta=10000, axes_dim=[16,56,56], scale_rope=True)
else:
self.pos_embed = QwenEmbedLayer3DRope(theta=10000, axes_dim=[16,56,56], scale_rope=True)
self.time_text_embed = TimestepEmbeddings(256, 3072, diffusers_compatible_format=True, scale=1000, align_dtype_to_timestep=True)
self.time_text_embed = TimestepEmbeddings(256, 3072, diffusers_compatible_format=True, scale=1000, align_dtype_to_timestep=False, use_additional_t_cond=use_additional_t_cond)
self.txt_norm = RMSNorm(3584, eps=1e-6)
self.img_in = nn.Linear(64, 3072)

View File

@@ -366,6 +366,7 @@ class QwenImageEncoder3d(nn.Module):
temperal_downsample=[True, True, False],
dropout=0.0,
non_linearity: str = "silu",
image_channels=3
):
super().__init__()
self.dim = dim
@@ -381,7 +382,7 @@ class QwenImageEncoder3d(nn.Module):
scale = 1.0
# init block
self.conv_in = QwenImageCausalConv3d(3, dims[0], 3, padding=1)
self.conv_in = QwenImageCausalConv3d(image_channels, dims[0], 3, padding=1)
# downsample blocks
self.down_blocks = torch.nn.ModuleList([])
@@ -544,6 +545,7 @@ class QwenImageDecoder3d(nn.Module):
temperal_upsample=[False, True, True],
dropout=0.0,
non_linearity: str = "silu",
image_channels=3,
):
super().__init__()
self.dim = dim
@@ -594,7 +596,7 @@ class QwenImageDecoder3d(nn.Module):
# output blocks
self.norm_out = QwenImageRMS_norm(out_dim, images=False)
self.conv_out = QwenImageCausalConv3d(out_dim, 3, 3, padding=1)
self.conv_out = QwenImageCausalConv3d(out_dim, image_channels, 3, padding=1)
self.gradient_checkpointing = False
@@ -647,6 +649,7 @@ class QwenImageVAE(torch.nn.Module):
attn_scales: List[float] = [],
temperal_downsample: List[bool] = [False, True, True],
dropout: float = 0.0,
image_channels: int = 3,
) -> None:
super().__init__()
@@ -655,13 +658,13 @@ class QwenImageVAE(torch.nn.Module):
self.temperal_upsample = temperal_downsample[::-1]
self.encoder = QwenImageEncoder3d(
base_dim, z_dim * 2, dim_mult, num_res_blocks, attn_scales, self.temperal_downsample, dropout
base_dim, z_dim * 2, dim_mult, num_res_blocks, attn_scales, self.temperal_downsample, dropout, image_channels=image_channels,
)
self.quant_conv = QwenImageCausalConv3d(z_dim * 2, z_dim * 2, 1)
self.post_quant_conv = QwenImageCausalConv3d(z_dim, z_dim, 1)
self.decoder = QwenImageDecoder3d(
base_dim, z_dim, dim_mult, num_res_blocks, attn_scales, self.temperal_upsample, dropout
base_dim, z_dim, dim_mult, num_res_blocks, attn_scales, self.temperal_upsample, dropout, image_channels=image_channels,
)
mean = [

View File

@@ -48,6 +48,7 @@ class QwenImagePipeline(BasePipeline):
QwenImageUnit_InputImageEmbedder(),
QwenImageUnit_Inpaint(),
QwenImageUnit_EditImageEmbedder(),
QwenImageUnit_LayerInputImageEmbedder(),
QwenImageUnit_ContextImageEmbedder(),
QwenImageUnit_PromptEmbedder(),
QwenImageUnit_EntityControl(),
@@ -128,6 +129,9 @@ class QwenImagePipeline(BasePipeline):
edit_rope_interpolation: bool = False,
# Qwen-Image-Edit-2511
zero_cond_t: bool = False,
# Qwen-Image-Layered
layer_input_image: Image.Image = None,
layer_num: int = None,
# In-context control
context_image: Image.Image = None,
# Tile
@@ -160,6 +164,8 @@ class QwenImagePipeline(BasePipeline):
"edit_image": edit_image, "edit_image_auto_resize": edit_image_auto_resize, "edit_rope_interpolation": edit_rope_interpolation,
"context_image": context_image,
"zero_cond_t": zero_cond_t,
"layer_input_image": layer_input_image,
"layer_num": layer_num,
}
for unit in self.units:
inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega)
@@ -179,7 +185,10 @@ class QwenImagePipeline(BasePipeline):
# Decode
self.load_models_to_device(['vae'])
image = self.vae.decode(inputs_shared["latents"], device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
image = self.vae_output_to_image(image)
if layer_num is None:
image = self.vae_output_to_image(image)
else:
image = [self.vae_output_to_image(i, pattern="C H W") for i in image]
self.load_models_to_device([])
return image
@@ -230,12 +239,15 @@ class QwenImageUnit_ShapeChecker(PipelineUnit):
class QwenImageUnit_NoiseInitializer(PipelineUnit):
def __init__(self):
super().__init__(
input_params=("height", "width", "seed", "rand_device"),
input_params=("height", "width", "seed", "rand_device", "layer_num"),
output_params=("noise",),
)
def process(self, pipe: QwenImagePipeline, height, width, seed, rand_device):
noise = pipe.generate_noise((1, 16, height//8, width//8), seed=seed, rand_device=rand_device, rand_torch_dtype=pipe.torch_dtype)
def process(self, pipe: QwenImagePipeline, height, width, seed, rand_device, layer_num):
if layer_num is None:
noise = pipe.generate_noise((1, 16, height//8, width//8), seed=seed, rand_device=rand_device, rand_torch_dtype=pipe.torch_dtype)
else:
noise = pipe.generate_noise((layer_num + 1, 16, height//8, width//8), seed=seed, rand_device=rand_device, rand_torch_dtype=pipe.torch_dtype)
return {"noise": noise}
@@ -252,8 +264,15 @@ class QwenImageUnit_InputImageEmbedder(PipelineUnit):
if input_image is None:
return {"latents": noise, "input_latents": None}
pipe.load_models_to_device(['vae'])
image = pipe.preprocess_image(input_image).to(device=pipe.device, dtype=pipe.torch_dtype)
input_latents = pipe.vae.encode(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
if isinstance(input_image, list):
input_latents = []
for image in input_image:
image = pipe.preprocess_image(image).to(device=pipe.device, dtype=pipe.torch_dtype)
input_latents.append(pipe.vae.encode(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride))
input_latents = torch.concat(input_latents, dim=0)
else:
image = pipe.preprocess_image(input_image).to(device=pipe.device, dtype=pipe.torch_dtype)
input_latents = pipe.vae.encode(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
if pipe.scheduler.training:
return {"latents": noise, "input_latents": input_latents}
else:
@@ -261,6 +280,22 @@ class QwenImageUnit_InputImageEmbedder(PipelineUnit):
return {"latents": latents, "input_latents": input_latents}
class QwenImageUnit_LayerInputImageEmbedder(PipelineUnit):
def __init__(self):
super().__init__(
input_params=("layer_input_image", "tiled", "tile_size", "tile_stride"),
output_params=("layer_input_latents",),
onload_model_names=("vae",)
)
def process(self, pipe: QwenImagePipeline, layer_input_image, tiled, tile_size, tile_stride):
if layer_input_image is None:
return {}
pipe.load_models_to_device(['vae'])
image = pipe.preprocess_image(layer_input_image).to(device=pipe.device, dtype=pipe.torch_dtype)
latents = pipe.vae.encode(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
return {"layer_input_latents": latents}
class QwenImageUnit_Inpaint(PipelineUnit):
def __init__(self):
@@ -677,6 +712,8 @@ def model_fn_qwen_image(
entity_prompt_emb_mask=None,
entity_masks=None,
edit_latents=None,
layer_input_latents=None,
layer_num=None,
context_latents=None,
enable_fp8_attention=False,
use_gradient_checkpointing=False,
@@ -685,11 +722,16 @@ def model_fn_qwen_image(
zero_cond_t=False,
**kwargs
):
img_shapes = [(latents.shape[0], latents.shape[2]//2, latents.shape[3]//2)]
if layer_num is None:
layer_num = 1
img_shapes = [(1, latents.shape[2]//2, latents.shape[3]//2)]
else:
layer_num = layer_num + 1
img_shapes = [(1, latents.shape[2]//2, latents.shape[3]//2)] * layer_num
txt_seq_lens = prompt_emb_mask.sum(dim=1).tolist()
timestep = timestep / 1000
image = rearrange(latents, "B C (H P) (W Q) -> B (H W) (C P Q)", H=height//16, W=width//16, P=2, Q=2)
image = rearrange(latents, "(B N) C (H P) (W Q) -> B (N H W) (C P Q)", H=height//16, W=width//16, P=2, Q=2, N=layer_num)
image_seq_len = image.shape[1]
if context_latents is not None:
@@ -701,6 +743,11 @@ def model_fn_qwen_image(
img_shapes += [(e.shape[0], e.shape[2]//2, e.shape[3]//2) for e in edit_latents_list]
edit_image = [rearrange(e, "B C (H P) (W Q) -> B (H W) (C P Q)", H=e.shape[2]//2, W=e.shape[3]//2, P=2, Q=2) for e in edit_latents_list]
image = torch.cat([image] + edit_image, dim=1)
if layer_input_latents is not None:
layer_num = layer_num + 1
img_shapes += [(layer_input_latents.shape[0], layer_input_latents.shape[2]//2, layer_input_latents.shape[3]//2)]
layer_input_latents = rearrange(layer_input_latents, "B C (H P) (W Q) -> B (H W) (C P Q)", P=2, Q=2)
image = torch.cat([image, layer_input_latents], dim=1)
image = dit.img_in(image)
if zero_cond_t:
@@ -712,7 +759,11 @@ def model_fn_qwen_image(
)
else:
modulate_index = None
conditioning = dit.time_text_embed(timestep, image.dtype)
conditioning = dit.time_text_embed(
timestep,
image.dtype,
addition_t_cond=None if layer_num is None else torch.tensor([0]).to(device=image.device, dtype=torch.long)
)
if entity_prompt_emb is not None:
text, image_rotary_emb, attention_mask = dit.process_entity_masks(
@@ -759,5 +810,5 @@ def model_fn_qwen_image(
image = dit.proj_out(image)
image = image[:, :image_seq_len]
latents = rearrange(image, "B (H W) (C P Q) -> B C (H P) (W Q)", H=height//16, W=width//16, P=2, Q=2)
latents = rearrange(image, "B (N H W) (C P Q) -> (B N) C (H P) (W Q)", H=height//16, W=width//16, P=2, Q=2, B=1)
return latents