diff --git a/README.md b/README.md index cf42413..6e398d3 100644 --- a/README.md +++ b/README.md @@ -399,6 +399,7 @@ Example code for Qwen-Image is available at: [/examples/qwen_image/](/examples/q |[Qwen/Qwen-Image-Edit](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Edit.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Edit.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit.py)| |[Qwen/Qwen-Image-Edit-2509](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit-2509)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit-2509.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-2509.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Edit-2509.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit-2509.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Edit-2509.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit-2509.py)| |[Qwen/Qwen-Image-Edit-2511](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit-2511)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit-2511.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-2511.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Edit-2511.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit-2511.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Edit-2511.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit-2511.py)| +|[Qwen/Qwen-Image-Layered](https://www.modelscope.cn/models/Qwen/Qwen-Image-Layered)|[code](/examples/qwen_image/model_inference/Qwen-Image-Layered.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Layered.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Layered.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Layered.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Layered.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Layered.py)| |[DiffSynth-Studio/Qwen-Image-EliGen](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen)|[code](/examples/qwen_image/model_inference/Qwen-Image-EliGen.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen.py)|-|-|[code](/examples/qwen_image/model_training/lora/Qwen-Image-EliGen.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen.py)| |[DiffSynth-Studio/Qwen-Image-EliGen-V2](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-V2)|[code](/examples/qwen_image/model_inference/Qwen-Image-EliGen-V2.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen-V2.py)|-|-|[code](/examples/qwen_image/model_training/lora/Qwen-Image-EliGen.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen.py)| 
|[DiffSynth-Studio/Qwen-Image-EliGen-Poster](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-Poster)|[code](/examples/qwen_image/model_inference/Qwen-Image-EliGen-Poster.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen-Poster.py)|-|-|[code](/examples/qwen_image/model_training/lora/Qwen-Image-EliGen-Poster.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen-Poster.py)| diff --git a/README_zh.md b/README_zh.md index a71d8c8..13799b5 100644 --- a/README_zh.md +++ b/README_zh.md @@ -399,6 +399,7 @@ Qwen-Image 的示例代码位于:[/examples/qwen_image/](/examples/qwen_image/ |[Qwen/Qwen-Image-Edit](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Edit.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Edit.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit.py)| |[Qwen/Qwen-Image-Edit-2509](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit-2509)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit-2509.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-2509.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Edit-2509.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit-2509.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Edit-2509.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit-2509.py)| |[Qwen/Qwen-Image-Edit-2511](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit-2511)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit-2511.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-2511.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Edit-2511.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit-2511.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Edit-2511.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit-2511.py)| +|[Qwen/Qwen-Image-Layered](https://www.modelscope.cn/models/Qwen/Qwen-Image-Layered)|[code](/examples/qwen_image/model_inference/Qwen-Image-Layered.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Layered.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Layered.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Layered.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Layered.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Layered.py)| |[DiffSynth-Studio/Qwen-Image-EliGen](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen)|[code](/examples/qwen_image/model_inference/Qwen-Image-EliGen.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen.py)|-|-|[code](/examples/qwen_image/model_training/lora/Qwen-Image-EliGen.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen.py)| 
|[DiffSynth-Studio/Qwen-Image-EliGen-V2](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-V2)|[code](/examples/qwen_image/model_inference/Qwen-Image-EliGen-V2.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen-V2.py)|-|-|[code](/examples/qwen_image/model_training/lora/Qwen-Image-EliGen.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen.py)| |[DiffSynth-Studio/Qwen-Image-EliGen-Poster](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-Poster)|[code](/examples/qwen_image/model_inference/Qwen-Image-EliGen-Poster.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen-Poster.py)|-|-|[code](/examples/qwen_image/model_training/lora/Qwen-Image-EliGen-Poster.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen-Poster.py)| diff --git a/diffsynth/configs/model_configs.py b/diffsynth/configs/model_configs.py index dca078a..08ec023 100644 --- a/diffsynth/configs/model_configs.py +++ b/diffsynth/configs/model_configs.py @@ -63,6 +63,20 @@ qwen_image_series = [ "model_class": "diffsynth.models.qwen_image_image2lora.QwenImageImage2LoRAModel", "extra_kwargs": {"compress_dim": 64, "use_residual": False} }, + { + # Example: ModelConfig(model_id="Qwen/Qwen-Image-Layered", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors") + "model_hash": "8dc8cda05de16c73afa755e2c1ce2839", + "model_name": "qwen_image_dit", + "model_class": "diffsynth.models.qwen_image_dit.QwenImageDiT", + "extra_kwargs": {"use_layer3d_rope": True, "use_additional_t_cond": True} + }, + { + # Example: ModelConfig(model_id="Qwen/Qwen-Image-Layered", origin_file_pattern="vae/diffusion_pytorch_model.safetensors") + "model_hash": "44b39ddc499e027cfb24f7878d7416b9", + "model_name": "qwen_image_vae", + "model_class": "diffsynth.models.qwen_image_vae.QwenImageVAE", + "extra_kwargs": {"image_channels": 4} + }, ] wan_series = [ diff --git a/diffsynth/configs/vram_management_module_maps.py b/diffsynth/configs/vram_management_module_maps.py index 958dad4..5f1b595 100644 --- a/diffsynth/configs/vram_management_module_maps.py +++ b/diffsynth/configs/vram_management_module_maps.py @@ -13,6 +13,7 @@ VRAM_MANAGEMENT_MODULE_MAPS = { "diffsynth.models.qwen_image_dit.QwenImageDiT": { "diffsynth.models.qwen_image_dit.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule", "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear", + "torch.nn.Embedding": "diffsynth.core.vram.layers.AutoWrappedModule", }, "diffsynth.models.qwen_image_text_encoder.QwenImageTextEncoder": { "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear", diff --git a/diffsynth/core/data/operators.py b/diffsynth/core/data/operators.py index c14944e..b7a0e7e 100644 --- a/diffsynth/core/data/operators.py +++ b/diffsynth/core/data/operators.py @@ -53,12 +53,14 @@ class ToStr(DataProcessingOperator): class LoadImage(DataProcessingOperator): - def __init__(self, convert_RGB=True): + def __init__(self, convert_RGB=True, convert_RGBA=False): self.convert_RGB = convert_RGB + self.convert_RGBA = convert_RGBA def __call__(self, data: str): image = Image.open(data) if self.convert_RGB: image = image.convert("RGB") + if self.convert_RGBA: image = image.convert("RGBA") return image diff --git a/diffsynth/models/general_modules.py b/diffsynth/models/general_modules.py index 216247c..3266715 100644 --- a/diffsynth/models/general_modules.py +++ b/diffsynth/models/general_modules.py @@ -19,7 +19,7 @@ def get_timestep_embedding( ) exponent 
= exponent / (half_dim - downscale_freq_shift)
-    emb = torch.exp(exponent).to(timesteps.device)
+    emb = torch.exp(exponent)
     if align_dtype_to_timestep:
         emb = emb.to(timesteps.dtype)
     emb = timesteps[:, None].float() * emb[None, :]
@@ -78,7 +78,7 @@ class DiffusersCompatibleTimestepProj(torch.nn.Module):
 
 
 class TimestepEmbeddings(torch.nn.Module):
-    def __init__(self, dim_in, dim_out, computation_device=None, diffusers_compatible_format=False, scale=1, align_dtype_to_timestep=False):
+    def __init__(self, dim_in, dim_out, computation_device=None, diffusers_compatible_format=False, scale=1, align_dtype_to_timestep=False, use_additional_t_cond=False):
         super().__init__()
         self.time_proj = TemporalTimesteps(num_channels=dim_in, flip_sin_to_cos=True, downscale_freq_shift=0, computation_device=computation_device, scale=scale, align_dtype_to_timestep=align_dtype_to_timestep)
         if diffusers_compatible_format:
@@ -87,10 +87,16 @@ class TimestepEmbeddings(torch.nn.Module):
             self.timestep_embedder = torch.nn.Sequential(
                 torch.nn.Linear(dim_in, dim_out), torch.nn.SiLU(), torch.nn.Linear(dim_out, dim_out)
             )
+        if use_additional_t_cond:
+            self.addition_t_embedding = torch.nn.Embedding(2, dim_out)
 
-    def forward(self, timestep, dtype):
+    def forward(self, timestep, dtype, addition_t_cond=None):
         time_emb = self.time_proj(timestep).to(dtype)
         time_emb = self.timestep_embedder(time_emb)
+        if addition_t_cond is not None:
+            addition_t_emb = self.addition_t_embedding(addition_t_cond)
+            addition_t_emb = addition_t_emb.to(dtype=dtype)
+            time_emb = time_emb + addition_t_emb
         return time_emb
 
 
diff --git a/diffsynth/models/qwen_image_dit.py b/diffsynth/models/qwen_image_dit.py
index 1720f67..2dd5143 100644
--- a/diffsynth/models/qwen_image_dit.py
+++ b/diffsynth/models/qwen_image_dit.py
@@ -1,4 +1,4 @@
-import torch, math
+import torch, math, functools
 import torch.nn as nn
 from typing import Tuple, Optional, Union, List
 from einops import rearrange
@@ -225,6 +225,121 @@ class QwenEmbedRope(nn.Module):
         return vid_freqs, txt_freqs
 
 
+class QwenEmbedLayer3DRope(nn.Module):
+    def __init__(self, theta: int, axes_dim: List[int], scale_rope=False):
+        super().__init__()
+        self.theta = theta
+        self.axes_dim = axes_dim
+        pos_index = torch.arange(4096)
+        neg_index = torch.arange(4096).flip(0) * -1 - 1
+        self.pos_freqs = torch.cat(
+            [
+                self.rope_params(pos_index, self.axes_dim[0], self.theta),
+                self.rope_params(pos_index, self.axes_dim[1], self.theta),
+                self.rope_params(pos_index, self.axes_dim[2], self.theta),
+            ],
+            dim=1,
+        )
+        self.neg_freqs = torch.cat(
+            [
+                self.rope_params(neg_index, self.axes_dim[0], self.theta),
+                self.rope_params(neg_index, self.axes_dim[1], self.theta),
+                self.rope_params(neg_index, self.axes_dim[2], self.theta),
+            ],
+            dim=1,
+        )
+
+        self.scale_rope = scale_rope
+
+    def rope_params(self, index, dim, theta=10000):
+        """
+        Args:
+            index: 1D tensor of token position indices, e.g. torch.arange(4096)
+        """
+        assert dim % 2 == 0
+        freqs = torch.outer(index, 1.0 / torch.pow(theta, torch.arange(0, dim, 2).to(torch.float32).div(dim)))
+        freqs = torch.polar(torch.ones_like(freqs), freqs)
+        return freqs
+
+    def forward(self, video_fhw, txt_seq_lens, device):
+        """
+        video_fhw: a list of (frame, height, width) tuples, one per layer; the last entry is treated as the condition image
+        txt_seq_lens: [bs] a list of integers, the text sequence length of each sample
+        """
+        if self.pos_freqs.device != device:
+            self.pos_freqs = self.pos_freqs.to(device)
+            self.neg_freqs = self.neg_freqs.to(device)
+
+        video_fhw = [video_fhw]
+        if 
isinstance(video_fhw, list): + video_fhw = video_fhw[0] + if not isinstance(video_fhw, list): + video_fhw = [video_fhw] + + vid_freqs = [] + max_vid_index = 0 + layer_num = len(video_fhw) - 1 + for idx, fhw in enumerate(video_fhw): + frame, height, width = fhw + if idx != layer_num: + video_freq = self._compute_video_freqs(frame, height, width, idx) + else: + ### For the condition image, we set the layer index to -1 + video_freq = self._compute_condition_freqs(frame, height, width) + video_freq = video_freq.to(device) + vid_freqs.append(video_freq) + + if self.scale_rope: + max_vid_index = max(height // 2, width // 2, max_vid_index) + else: + max_vid_index = max(height, width, max_vid_index) + + max_vid_index = max(max_vid_index, layer_num) + max_len = max(txt_seq_lens) + txt_freqs = self.pos_freqs[max_vid_index : max_vid_index + max_len, ...] + vid_freqs = torch.cat(vid_freqs, dim=0) + + return vid_freqs, txt_freqs + + @functools.lru_cache(maxsize=None) + def _compute_video_freqs(self, frame, height, width, idx=0): + seq_lens = frame * height * width + freqs_pos = self.pos_freqs.split([x // 2 for x in self.axes_dim], dim=1) + freqs_neg = self.neg_freqs.split([x // 2 for x in self.axes_dim], dim=1) + + freqs_frame = freqs_pos[0][idx : idx + frame].view(frame, 1, 1, -1).expand(frame, height, width, -1) + if self.scale_rope: + freqs_height = torch.cat([freqs_neg[1][-(height - height // 2) :], freqs_pos[1][: height // 2]], dim=0) + freqs_height = freqs_height.view(1, height, 1, -1).expand(frame, height, width, -1) + freqs_width = torch.cat([freqs_neg[2][-(width - width // 2) :], freqs_pos[2][: width // 2]], dim=0) + freqs_width = freqs_width.view(1, 1, width, -1).expand(frame, height, width, -1) + else: + freqs_height = freqs_pos[1][:height].view(1, height, 1, -1).expand(frame, height, width, -1) + freqs_width = freqs_pos[2][:width].view(1, 1, width, -1).expand(frame, height, width, -1) + + freqs = torch.cat([freqs_frame, freqs_height, freqs_width], dim=-1).reshape(seq_lens, -1) + return freqs.clone().contiguous() + + @functools.lru_cache(maxsize=None) + def _compute_condition_freqs(self, frame, height, width): + seq_lens = frame * height * width + freqs_pos = self.pos_freqs.split([x // 2 for x in self.axes_dim], dim=1) + freqs_neg = self.neg_freqs.split([x // 2 for x in self.axes_dim], dim=1) + + freqs_frame = freqs_neg[0][-1:].view(frame, 1, 1, -1).expand(frame, height, width, -1) + if self.scale_rope: + freqs_height = torch.cat([freqs_neg[1][-(height - height // 2) :], freqs_pos[1][: height // 2]], dim=0) + freqs_height = freqs_height.view(1, height, 1, -1).expand(frame, height, width, -1) + freqs_width = torch.cat([freqs_neg[2][-(width - width // 2) :], freqs_pos[2][: width // 2]], dim=0) + freqs_width = freqs_width.view(1, 1, width, -1).expand(frame, height, width, -1) + else: + freqs_height = freqs_pos[1][:height].view(1, height, 1, -1).expand(frame, height, width, -1) + freqs_width = freqs_pos[2][:width].view(1, 1, width, -1).expand(frame, height, width, -1) + + freqs = torch.cat([freqs_frame, freqs_height, freqs_width], dim=-1).reshape(seq_lens, -1) + return freqs.clone().contiguous() + + class QwenFeedForward(nn.Module): def __init__( self, @@ -437,12 +552,17 @@ class QwenImageDiT(torch.nn.Module): def __init__( self, num_layers: int = 60, + use_layer3d_rope: bool = False, + use_additional_t_cond: bool = False, ): super().__init__() - self.pos_embed = QwenEmbedRope(theta=10000, axes_dim=[16,56,56], scale_rope=True) + if not use_layer3d_rope: + self.pos_embed = 
QwenEmbedRope(theta=10000, axes_dim=[16,56,56], scale_rope=True) + else: + self.pos_embed = QwenEmbedLayer3DRope(theta=10000, axes_dim=[16,56,56], scale_rope=True) - self.time_text_embed = TimestepEmbeddings(256, 3072, diffusers_compatible_format=True, scale=1000, align_dtype_to_timestep=True) + self.time_text_embed = TimestepEmbeddings(256, 3072, diffusers_compatible_format=True, scale=1000, align_dtype_to_timestep=False, use_additional_t_cond=use_additional_t_cond) self.txt_norm = RMSNorm(3584, eps=1e-6) self.img_in = nn.Linear(64, 3072) diff --git a/diffsynth/models/qwen_image_vae.py b/diffsynth/models/qwen_image_vae.py index cb04713..2845354 100644 --- a/diffsynth/models/qwen_image_vae.py +++ b/diffsynth/models/qwen_image_vae.py @@ -366,6 +366,7 @@ class QwenImageEncoder3d(nn.Module): temperal_downsample=[True, True, False], dropout=0.0, non_linearity: str = "silu", + image_channels=3 ): super().__init__() self.dim = dim @@ -381,7 +382,7 @@ class QwenImageEncoder3d(nn.Module): scale = 1.0 # init block - self.conv_in = QwenImageCausalConv3d(3, dims[0], 3, padding=1) + self.conv_in = QwenImageCausalConv3d(image_channels, dims[0], 3, padding=1) # downsample blocks self.down_blocks = torch.nn.ModuleList([]) @@ -544,6 +545,7 @@ class QwenImageDecoder3d(nn.Module): temperal_upsample=[False, True, True], dropout=0.0, non_linearity: str = "silu", + image_channels=3, ): super().__init__() self.dim = dim @@ -594,7 +596,7 @@ class QwenImageDecoder3d(nn.Module): # output blocks self.norm_out = QwenImageRMS_norm(out_dim, images=False) - self.conv_out = QwenImageCausalConv3d(out_dim, 3, 3, padding=1) + self.conv_out = QwenImageCausalConv3d(out_dim, image_channels, 3, padding=1) self.gradient_checkpointing = False @@ -647,6 +649,7 @@ class QwenImageVAE(torch.nn.Module): attn_scales: List[float] = [], temperal_downsample: List[bool] = [False, True, True], dropout: float = 0.0, + image_channels: int = 3, ) -> None: super().__init__() @@ -655,13 +658,13 @@ class QwenImageVAE(torch.nn.Module): self.temperal_upsample = temperal_downsample[::-1] self.encoder = QwenImageEncoder3d( - base_dim, z_dim * 2, dim_mult, num_res_blocks, attn_scales, self.temperal_downsample, dropout + base_dim, z_dim * 2, dim_mult, num_res_blocks, attn_scales, self.temperal_downsample, dropout, image_channels=image_channels, ) self.quant_conv = QwenImageCausalConv3d(z_dim * 2, z_dim * 2, 1) self.post_quant_conv = QwenImageCausalConv3d(z_dim, z_dim, 1) self.decoder = QwenImageDecoder3d( - base_dim, z_dim, dim_mult, num_res_blocks, attn_scales, self.temperal_upsample, dropout + base_dim, z_dim, dim_mult, num_res_blocks, attn_scales, self.temperal_upsample, dropout, image_channels=image_channels, ) mean = [ diff --git a/diffsynth/pipelines/qwen_image.py b/diffsynth/pipelines/qwen_image.py index 3956f41..ab0fa8b 100644 --- a/diffsynth/pipelines/qwen_image.py +++ b/diffsynth/pipelines/qwen_image.py @@ -48,6 +48,7 @@ class QwenImagePipeline(BasePipeline): QwenImageUnit_InputImageEmbedder(), QwenImageUnit_Inpaint(), QwenImageUnit_EditImageEmbedder(), + QwenImageUnit_LayerInputImageEmbedder(), QwenImageUnit_ContextImageEmbedder(), QwenImageUnit_PromptEmbedder(), QwenImageUnit_EntityControl(), @@ -128,6 +129,9 @@ class QwenImagePipeline(BasePipeline): edit_rope_interpolation: bool = False, # Qwen-Image-Edit-2511 zero_cond_t: bool = False, + # Qwen-Image-Layered + layer_input_image: Image.Image = None, + layer_num: int = None, # In-context control context_image: Image.Image = None, # Tile @@ -160,6 +164,8 @@ class 
QwenImagePipeline(BasePipeline): "edit_image": edit_image, "edit_image_auto_resize": edit_image_auto_resize, "edit_rope_interpolation": edit_rope_interpolation, "context_image": context_image, "zero_cond_t": zero_cond_t, + "layer_input_image": layer_input_image, + "layer_num": layer_num, } for unit in self.units: inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega) @@ -179,7 +185,10 @@ class QwenImagePipeline(BasePipeline): # Decode self.load_models_to_device(['vae']) image = self.vae.decode(inputs_shared["latents"], device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride) - image = self.vae_output_to_image(image) + if layer_num is None: + image = self.vae_output_to_image(image) + else: + image = [self.vae_output_to_image(i, pattern="C H W") for i in image] self.load_models_to_device([]) return image @@ -230,12 +239,15 @@ class QwenImageUnit_ShapeChecker(PipelineUnit): class QwenImageUnit_NoiseInitializer(PipelineUnit): def __init__(self): super().__init__( - input_params=("height", "width", "seed", "rand_device"), + input_params=("height", "width", "seed", "rand_device", "layer_num"), output_params=("noise",), ) - def process(self, pipe: QwenImagePipeline, height, width, seed, rand_device): - noise = pipe.generate_noise((1, 16, height//8, width//8), seed=seed, rand_device=rand_device, rand_torch_dtype=pipe.torch_dtype) + def process(self, pipe: QwenImagePipeline, height, width, seed, rand_device, layer_num): + if layer_num is None: + noise = pipe.generate_noise((1, 16, height//8, width//8), seed=seed, rand_device=rand_device, rand_torch_dtype=pipe.torch_dtype) + else: + noise = pipe.generate_noise((layer_num + 1, 16, height//8, width//8), seed=seed, rand_device=rand_device, rand_torch_dtype=pipe.torch_dtype) return {"noise": noise} @@ -252,8 +264,15 @@ class QwenImageUnit_InputImageEmbedder(PipelineUnit): if input_image is None: return {"latents": noise, "input_latents": None} pipe.load_models_to_device(['vae']) - image = pipe.preprocess_image(input_image).to(device=pipe.device, dtype=pipe.torch_dtype) - input_latents = pipe.vae.encode(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride) + if isinstance(input_image, list): + input_latents = [] + for image in input_image: + image = pipe.preprocess_image(image).to(device=pipe.device, dtype=pipe.torch_dtype) + input_latents.append(pipe.vae.encode(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)) + input_latents = torch.concat(input_latents, dim=0) + else: + image = pipe.preprocess_image(input_image).to(device=pipe.device, dtype=pipe.torch_dtype) + input_latents = pipe.vae.encode(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride) if pipe.scheduler.training: return {"latents": noise, "input_latents": input_latents} else: @@ -261,6 +280,22 @@ class QwenImageUnit_InputImageEmbedder(PipelineUnit): return {"latents": latents, "input_latents": input_latents} +class QwenImageUnit_LayerInputImageEmbedder(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("layer_input_image", "tiled", "tile_size", "tile_stride"), + output_params=("layer_input_latents",), + onload_model_names=("vae",) + ) + + def process(self, pipe: QwenImagePipeline, layer_input_image, tiled, tile_size, tile_stride): + if layer_input_image is None: + return {} + pipe.load_models_to_device(['vae']) + image = pipe.preprocess_image(layer_input_image).to(device=pipe.device, dtype=pipe.torch_dtype) + latents = pipe.vae.encode(image, 
tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
+        return {"layer_input_latents": latents}
+
 
 class QwenImageUnit_Inpaint(PipelineUnit):
     def __init__(self):
@@ -677,6 +712,8 @@ def model_fn_qwen_image(
     entity_prompt_emb_mask=None,
     entity_masks=None,
     edit_latents=None,
+    layer_input_latents=None,
+    layer_num=None,
     context_latents=None,
     enable_fp8_attention=False,
     use_gradient_checkpointing=False,
@@ -685,11 +722,16 @@ def model_fn_qwen_image(
     zero_cond_t=False,
     **kwargs
 ):
-    img_shapes = [(latents.shape[0], latents.shape[2]//2, latents.shape[3]//2)]
+    if layer_num is None:
+        num_layer_images = 1
+        img_shapes = [(1, latents.shape[2]//2, latents.shape[3]//2)]
+    else:
+        num_layer_images = layer_num + 1
+        img_shapes = [(1, latents.shape[2]//2, latents.shape[3]//2)] * num_layer_images
     txt_seq_lens = prompt_emb_mask.sum(dim=1).tolist()
     timestep = timestep / 1000
 
-    image = rearrange(latents, "B C (H P) (W Q) -> B (H W) (C P Q)", H=height//16, W=width//16, P=2, Q=2)
+    image = rearrange(latents, "(B N) C (H P) (W Q) -> B (N H W) (C P Q)", H=height//16, W=width//16, P=2, Q=2, N=num_layer_images)
     image_seq_len = image.shape[1]
 
     if context_latents is not None:
@@ -701,6 +743,11 @@ def model_fn_qwen_image(
         img_shapes += [(e.shape[0], e.shape[2]//2, e.shape[3]//2) for e in edit_latents_list]
         edit_image = [rearrange(e, "B C (H P) (W Q) -> B (H W) (C P Q)", H=e.shape[2]//2, W=e.shape[3]//2, P=2, Q=2) for e in edit_latents_list]
         image = torch.cat([image] + edit_image, dim=1)
+    if layer_input_latents is not None:
+        num_layer_images = num_layer_images + 1
+        img_shapes += [(layer_input_latents.shape[0], layer_input_latents.shape[2]//2, layer_input_latents.shape[3]//2)]
+        layer_input_latents = rearrange(layer_input_latents, "B C (H P) (W Q) -> B (H W) (C P Q)", P=2, Q=2)
+        image = torch.cat([image, layer_input_latents], dim=1)
     image = dit.img_in(image)
 
     if zero_cond_t:
@@ -712,7 +759,11 @@ def model_fn_qwen_image(
         )
     else:
         modulate_index = None
-    conditioning = dit.time_text_embed(timestep, image.dtype)
+    conditioning = dit.time_text_embed(
+        timestep,
+        image.dtype,
+        addition_t_cond=None if layer_num is None else torch.tensor([0]).to(device=image.device, dtype=torch.long)
+    )
 
     if entity_prompt_emb is not None:
         text, image_rotary_emb, attention_mask = dit.process_entity_masks(
@@ -759,5 +810,5 @@ def model_fn_qwen_image(
     image = dit.proj_out(image)
     image = image[:, :image_seq_len]
 
-    latents = rearrange(image, "B (H W) (C P Q) -> B C (H P) (W Q)", H=height//16, W=width//16, P=2, Q=2)
+    latents = rearrange(image, "B (N H W) (C P Q) -> (B N) C (H P) (W Q)", H=height//16, W=width//16, P=2, Q=2, B=1)
     return latents
diff --git a/docs/en/Model_Details/Qwen-Image.md b/docs/en/Model_Details/Qwen-Image.md
index 41676fa..9062d2a 100644
--- a/docs/en/Model_Details/Qwen-Image.md
+++ b/docs/en/Model_Details/Qwen-Image.md
@@ -84,6 +84,7 @@ graph LR;
 | [Qwen/Qwen-Image-Edit](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit) | [code](/examples/qwen_image/model_inference/Qwen-Image-Edit.py) | [code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit.py) | [code](/examples/qwen_image/model_training/full/Qwen-Image-Edit.sh) | [code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit.py) | [code](/examples/qwen_image/model_training/lora/Qwen-Image-Edit.sh) | [code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit.py) |
 | [Qwen/Qwen-Image-Edit-2509](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit-2509) | [code](/examples/qwen_image/model_inference/Qwen-Image-Edit-2509.py) | 
[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-2509.py) | [code](/examples/qwen_image/model_training/full/Qwen-Image-Edit-2509.sh) | [code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit-2509.py) | [code](/examples/qwen_image/model_training/lora/Qwen-Image-Edit-2509.sh) | [code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit-2509.py) | |[Qwen/Qwen-Image-Edit-2511](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit-2511)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit-2511.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-2511.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Edit-2511.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit-2511.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Edit-2511.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit-2511.py)| +|[Qwen/Qwen-Image-Layered](https://www.modelscope.cn/models/Qwen/Qwen-Image-Layered)|[code](/examples/qwen_image/model_inference/Qwen-Image-Layered.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Layered.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Layered.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Layered.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Layered.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Layered.py)| | [DiffSynth-Studio/Qwen-Image-EliGen](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen) | [code](/examples/qwen_image/model_inference/Qwen-Image-EliGen.py) | [code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen.py) | - | - | [code](/examples/qwen_image/model_training/lora/Qwen-Image-EliGen.sh) | [code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen.py) | | [DiffSynth-Studio/Qwen-Image-EliGen-V2](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-V2) | [code](/examples/qwen_image/model_inference/Qwen-Image-EliGen-V2.py) | [code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen-V2.py) | - | - | [code](/examples/qwen_image/model_training/lora/Qwen-Image-EliGen.sh) | [code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen.py) | | [DiffSynth-Studio/Qwen-Image-EliGen-Poster](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-Poster) | [code](/examples/qwen_image/model_inference/Qwen-Image-EliGen-Poster.py) | [code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen-Poster.py) | - | - | [code](/examples/qwen_image/model_training/lora/Qwen-Image-EliGen-Poster.sh) | [code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen-Poster.py) | diff --git a/docs/zh/Model_Details/Qwen-Image.md b/docs/zh/Model_Details/Qwen-Image.md index 02d44aa..87ccba0 100644 --- a/docs/zh/Model_Details/Qwen-Image.md +++ b/docs/zh/Model_Details/Qwen-Image.md @@ -84,6 +84,7 @@ graph LR; |[Qwen/Qwen-Image-Edit](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Edit.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Edit.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit.py)| 
|[Qwen/Qwen-Image-Edit-2509](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit-2509)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit-2509.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-2509.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Edit-2509.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit-2509.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Edit-2509.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit-2509.py)| |[Qwen/Qwen-Image-Edit-2511](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit-2511)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit-2511.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-2511.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Edit-2511.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit-2511.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Edit-2511.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit-2511.py)| +|[Qwen/Qwen-Image-Layered](https://www.modelscope.cn/models/Qwen/Qwen-Image-Layered)|[code](/examples/qwen_image/model_inference/Qwen-Image-Layered.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Layered.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Layered.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Layered.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Layered.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Layered.py)| |[DiffSynth-Studio/Qwen-Image-EliGen](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen)|[code](/examples/qwen_image/model_inference/Qwen-Image-EliGen.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen.py)|-|-|[code](/examples/qwen_image/model_training/lora/Qwen-Image-EliGen.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen.py)| |[DiffSynth-Studio/Qwen-Image-EliGen-V2](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-V2)|[code](/examples/qwen_image/model_inference/Qwen-Image-EliGen-V2.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen-V2.py)|-|-|[code](/examples/qwen_image/model_training/lora/Qwen-Image-EliGen.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen.py)| |[DiffSynth-Studio/Qwen-Image-EliGen-Poster](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-Poster)|[code](/examples/qwen_image/model_inference/Qwen-Image-EliGen-Poster.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen-Poster.py)|-|-|[code](/examples/qwen_image/model_training/lora/Qwen-Image-EliGen-Poster.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen-Poster.py)| diff --git a/examples/qwen_image/model_inference/Qwen-Image-Layered.py b/examples/qwen_image/model_inference/Qwen-Image-Layered.py new file mode 100644 index 0000000..95aa475 --- /dev/null +++ b/examples/qwen_image/model_inference/Qwen-Image-Layered.py @@ -0,0 +1,36 @@ +from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig +from modelscope import dataset_snapshot_download +from PIL import Image +import torch + + +pipe = QwenImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Qwen/Qwen-Image-Layered", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors"), + 
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"), + ModelConfig(model_id="Qwen/Qwen-Image-Layered", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"), + ], + processor_config=ModelConfig(model_id="Qwen/Qwen-Image-Edit", origin_file_pattern="processor/"), +) + +dataset_snapshot_download( + "DiffSynth-Studio/example_image_dataset", + allow_patterns="layer/image.png", + local_dir="data/example_image_dataset" +) + +# Prompt should be provided to the pipeline. Our pipeline will not generate the prompt. +prompt = 'A cheerful child with brown hair is waving enthusiastically under a bright blue sky filled with colorful confetti and balloons. The word "HELLO!" is prominently displayed in bold red letters above the child, while "Have a Great Day!" appears in elegant cursive at the bottom right corner. The scene is vibrant and festive, with a mix of pastel colors and dynamic shapes creating a joyful atmosphere.' +# Height and width should be consistent with input_image and be divided evenly by 16 +input_image = Image.open("data/example_image_dataset/layer/image.png").convert("RGBA").resize((864, 480)) +images = pipe( + prompt, + seed=1, num_inference_steps=50, + height=480, width=864, + layer_input_image=input_image, layer_num=3, +) +for i, image in enumerate(images): + if i == 0: continue # The first image is the input image. + image.save(f"image_{i}.png") diff --git a/examples/qwen_image/model_inference_low_vram/Qwen-Image-Layered.py b/examples/qwen_image/model_inference_low_vram/Qwen-Image-Layered.py new file mode 100644 index 0000000..4f3438d --- /dev/null +++ b/examples/qwen_image/model_inference_low_vram/Qwen-Image-Layered.py @@ -0,0 +1,46 @@ +from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig +from modelscope import dataset_snapshot_download +from PIL import Image +import torch + + +vram_config = { + "offload_dtype": "disk", + "offload_device": "disk", + "onload_dtype": torch.float8_e4m3fn, + "onload_device": "cpu", + "preparing_dtype": torch.float8_e4m3fn, + "preparing_device": "cuda", + "computation_dtype": torch.bfloat16, + "computation_device": "cuda", +} +pipe = QwenImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Qwen/Qwen-Image-Layered", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors", **vram_config), + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors", **vram_config), + ModelConfig(model_id="Qwen/Qwen-Image-Layered", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config), + ], + processor_config=ModelConfig(model_id="Qwen/Qwen-Image-Edit", origin_file_pattern="processor/"), +) + +dataset_snapshot_download( + "DiffSynth-Studio/example_image_dataset", + allow_patterns="layer/image.png", + local_dir="data/example_image_dataset" +) + +# Prompt should be provided to the pipeline. Our pipeline will not generate the prompt. +prompt = 'A cheerful child with brown hair is waving enthusiastically under a bright blue sky filled with colorful confetti and balloons. The word "HELLO!" is prominently displayed in bold red letters above the child, while "Have a Great Day!" appears in elegant cursive at the bottom right corner. The scene is vibrant and festive, with a mix of pastel colors and dynamic shapes creating a joyful atmosphere.' 
+# Height and width must match input_image and be divisible by 16.
+input_image = Image.open("data/example_image_dataset/layer/image.png").convert("RGBA").resize((864, 480))
+images = pipe(
+    prompt,
+    seed=1, num_inference_steps=50,
+    height=480, width=864,
+    layer_input_image=input_image, layer_num=3,
+)
+for i, image in enumerate(images):
+    if i == 0: continue  # The first image is the reconstructed input image.
+    image.save(f"image_{i}.png")
diff --git a/examples/qwen_image/model_training/full/Qwen-Image-Layered.sh b/examples/qwen_image/model_training/full/Qwen-Image-Layered.sh
new file mode 100644
index 0000000..91cdb5e
--- /dev/null
+++ b/examples/qwen_image/model_training/full/Qwen-Image-Layered.sh
@@ -0,0 +1,18 @@
+# Example Dataset: https://modelscope.cn/datasets/DiffSynth-Studio/example_image_dataset/tree/master/layer
+
+accelerate launch --config_file examples/qwen_image/model_training/full/accelerate_config_zero2offload.yaml examples/qwen_image/model_training/train.py \
+  --dataset_base_path data/example_image_dataset/layer \
+  --dataset_metadata_path data/example_image_dataset/layer/metadata_layered.json \
+  --data_file_keys "image,layer_input_image" \
+  --max_pixels 1048576 \
+  --dataset_repeat 50 \
+  --model_id_with_origin_paths "Qwen/Qwen-Image-Layered:transformer/diffusion_pytorch_model*.safetensors,Qwen/Qwen-Image:text_encoder/model*.safetensors,Qwen/Qwen-Image-Layered:vae/diffusion_pytorch_model.safetensors" \
+  --learning_rate 1e-5 \
+  --num_epochs 2 \
+  --remove_prefix_in_ckpt "pipe.dit." \
+  --output_path "./models/train/Qwen-Image-Layered_full" \
+  --trainable_models "dit" \
+  --extra_inputs "layer_num,layer_input_image" \
+  --use_gradient_checkpointing \
+  --dataset_num_workers 8 \
+  --find_unused_parameters
diff --git a/examples/qwen_image/model_training/lora/Qwen-Image-Layered.sh b/examples/qwen_image/model_training/lora/Qwen-Image-Layered.sh
new file mode 100644
index 0000000..75a23f9
--- /dev/null
+++ b/examples/qwen_image/model_training/lora/Qwen-Image-Layered.sh
@@ -0,0 +1,20 @@
+# Example Dataset: https://modelscope.cn/datasets/DiffSynth-Studio/example_image_dataset/tree/master/layer
+
+accelerate launch examples/qwen_image/model_training/train.py \
+  --dataset_base_path data/example_image_dataset/layer \
+  --dataset_metadata_path data/example_image_dataset/layer/metadata_layered.json \
+  --data_file_keys "image,layer_input_image" \
+  --max_pixels 1048576 \
+  --dataset_repeat 50 \
+  --model_id_with_origin_paths "Qwen/Qwen-Image-Layered:transformer/diffusion_pytorch_model*.safetensors,Qwen/Qwen-Image:text_encoder/model*.safetensors,Qwen/Qwen-Image-Layered:vae/diffusion_pytorch_model.safetensors" \
+  --learning_rate 1e-4 \
+  --num_epochs 5 \
+  --remove_prefix_in_ckpt "pipe.dit." 
\ + --output_path "./models/train/Qwen-Image-Layered_lora" \ + --lora_base_model "dit" \ + --lora_target_modules "to_q,to_k,to_v,add_q_proj,add_k_proj,add_v_proj,to_out.0,to_add_out,img_mlp.net.2,img_mod.1,txt_mlp.net.2,txt_mod.1" \ + --lora_rank 32 \ + --extra_inputs "layer_num,layer_input_image" \ + --use_gradient_checkpointing \ + --dataset_num_workers 8 \ + --find_unused_parameters diff --git a/examples/qwen_image/model_training/train.py b/examples/qwen_image/model_training/train.py index 1c24f1d..8f38d04 100644 --- a/examples/qwen_image/model_training/train.py +++ b/examples/qwen_image/model_training/train.py @@ -2,6 +2,7 @@ import torch, os, argparse, accelerate from diffsynth.core import UnifiedDataset from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig from diffsynth.diffusion import * +from diffsynth.core.data.operators import * os.environ["TOKENIZERS_PARALLELISM"] = "false" @@ -58,11 +59,6 @@ class QwenImageTrainingModule(DiffusionTrainingModule): inputs_posi = {"prompt": data["prompt"]} inputs_nega = {"negative_prompt": ""} inputs_shared = { - # Assume you are using this pipeline for inference, - # please fill in the input parameters. - "input_image": data["image"], - "height": data["image"].size[1], - "width": data["image"].size[0], # Please do not modify the following parameters # unless you clearly know what this will cause. "cfg_scale": 1, @@ -72,6 +68,20 @@ class QwenImageTrainingModule(DiffusionTrainingModule): "edit_image_auto_resize": True, "zero_cond_t": self.zero_cond_t, } + # Assume you are using this pipeline for inference, + # please fill in the input parameters. + if isinstance(data["image"], list): + inputs_shared.update({ + "input_image": data["image"], + "height": data["image"][0].size[1], + "width": data["image"][0].size[0], + }) + else: + inputs_shared.update({ + "input_image": data["image"], + "height": data["image"].size[1], + "width": data["image"].size[0], + }) inputs_shared = self.parse_extra_inputs(data, self.extra_inputs, inputs_shared) return inputs_shared, inputs_posi, inputs_nega @@ -113,7 +123,15 @@ if __name__ == "__main__": width=args.width, height_division_factor=16, width_division_factor=16, - ) + ), + special_operator_map={ + # Qwen-Image-Layered + "layer_input_image": ToAbsolutePath(args.dataset_base_path) >> LoadImage(convert_RGB=False, convert_RGBA=True) >> ImageCropAndResize(args.height, args.width, args.max_pixels, 16, 16), + "image": RouteByType(operator_map=[ + (str, ToAbsolutePath(args.dataset_base_path) >> LoadImage() >> ImageCropAndResize(args.height, args.width, args.max_pixels, 16, 16)), + (list, SequencialProcess(ToAbsolutePath(args.dataset_base_path) >> LoadImage(convert_RGB=False, convert_RGBA=True) >> ImageCropAndResize(args.height, args.width, args.max_pixels, 16, 16))), + ]) + } ) model = QwenImageTrainingModule( model_paths=args.model_paths, diff --git a/examples/qwen_image/model_training/validate_full/Qwen-Image-Layered.py b/examples/qwen_image/model_training/validate_full/Qwen-Image-Layered.py new file mode 100644 index 0000000..7318d4f --- /dev/null +++ b/examples/qwen_image/model_training/validate_full/Qwen-Image-Layered.py @@ -0,0 +1,26 @@ +from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig +from diffsynth import load_state_dict +from PIL import Image +import torch + + +pipe = QwenImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Qwen/Qwen-Image-Layered", 
origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors"), + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"), + ModelConfig(model_id="Qwen/Qwen-Image-Layered", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"), + ], + tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"), +) +state_dict = load_state_dict("models/train/Qwen-Image-Layered_full/epoch-1.safetensors") +pipe.dit.load_state_dict(state_dict) +prompt = "a poster" +input_image = Image.open("data/example_image_dataset/layer/image.png").convert("RGBA").resize((864, 480)) +image = pipe( + prompt, seed=0, + height=480, width=864, + layer_input_image=input_image, layer_num=3, +) +image.save("image.jpg") diff --git a/examples/qwen_image/model_training/validate_lora/Qwen-Image-Layered.py b/examples/qwen_image/model_training/validate_lora/Qwen-Image-Layered.py new file mode 100644 index 0000000..4a3949a --- /dev/null +++ b/examples/qwen_image/model_training/validate_lora/Qwen-Image-Layered.py @@ -0,0 +1,25 @@ +from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig +from diffsynth import load_state_dict +from PIL import Image +import torch + + +pipe = QwenImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Qwen/Qwen-Image-Layered", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors"), + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"), + ModelConfig(model_id="Qwen/Qwen-Image-Layered", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"), + ], + tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"), +) +pipe.load_lora(pipe.dit, "models/train/Qwen-Image-Layered_lora/epoch-4.safetensors") +prompt = "a poster" +input_image = Image.open("data/example_image_dataset/layer/image.png").convert("RGBA").resize((864, 480)) +image = pipe( + prompt, seed=0, + height=480, width=864, + layer_input_image=input_image, layer_num=3, +) +image.save("image.jpg")