From d93de98a21bcf9940ab32792af890feb5ac75ed1 Mon Sep 17 00:00:00 2001 From: mi804 <1576993271@qq.com> Date: Mon, 18 Aug 2025 17:31:18 +0800 Subject: [PATCH] fix qwen_rope --- diffsynth/models/qwen_image_dit.py | 69 +++++++++---------- diffsynth/pipelines/qwen_image.py | 5 +- diffsynth/trainers/utils.py | 1 - examples/qwen_image/README.md | 1 - examples/qwen_image/README_zh.md | 1 - .../model_training/full/Qwen-Image-Edit.sh | 1 - .../model_training/lora/Qwen-Image-Edit.sh | 1 - examples/qwen_image/model_training/train.py | 11 +-- 8 files changed, 36 insertions(+), 54 deletions(-) diff --git a/diffsynth/models/qwen_image_dit.py b/diffsynth/models/qwen_image_dit.py index 137e4cd..6ff216b 100644 --- a/diffsynth/models/qwen_image_dit.py +++ b/diffsynth/models/qwen_image_dit.py @@ -127,49 +127,42 @@ class QwenEmbedRope(nn.Module): self.pos_freqs = self.pos_freqs.to(device) self.neg_freqs = self.neg_freqs.to(device) - if isinstance(video_fhw, list): - video_fhw = video_fhw[0] - frame, height, width = video_fhw - rope_key = f"{frame}_{height}_{width}" + vid_freqs = [] + max_vid_index = 0 + for idx, fhw in enumerate(video_fhw): + frame, height, width = fhw + rope_key = f"{idx}_{height}_{width}" + + if rope_key not in self.rope_cache: + seq_lens = frame * height * width + freqs_pos = self.pos_freqs.split([x // 2 for x in self.axes_dim], dim=1) + freqs_neg = self.neg_freqs.split([x // 2 for x in self.axes_dim], dim=1) + freqs_frame = freqs_pos[0][idx : idx + frame].view(frame, 1, 1, -1).expand(frame, height, width, -1) + if self.scale_rope: + freqs_height = torch.cat( + [freqs_neg[1][-(height - height // 2) :], freqs_pos[1][: height // 2]], dim=0 + ) + freqs_height = freqs_height.view(1, height, 1, -1).expand(frame, height, width, -1) + freqs_width = torch.cat([freqs_neg[2][-(width - width // 2) :], freqs_pos[2][: width // 2]], dim=0) + freqs_width = freqs_width.view(1, 1, width, -1).expand(frame, height, width, -1) + + else: + freqs_height = freqs_pos[1][:height].view(1, height, 1, -1).expand(frame, height, width, -1) + freqs_width = freqs_pos[2][:width].view(1, 1, width, -1).expand(frame, height, width, -1) + + freqs = torch.cat([freqs_frame, freqs_height, freqs_width], dim=-1).reshape(seq_lens, -1) + self.rope_cache[rope_key] = freqs.clone().contiguous() + vid_freqs.append(self.rope_cache[rope_key]) - if rope_key not in self.rope_cache: - seq_lens = frame * height * width - freqs_pos = self.pos_freqs.split([x // 2 for x in self.axes_dim], dim=1) - freqs_neg = self.neg_freqs.split([x // 2 for x in self.axes_dim], dim=1) - freqs_frame = freqs_pos[0][:frame].view(frame, 1, 1, -1).expand(frame, height, width, -1) if self.scale_rope: - freqs_height = torch.cat( - [ - freqs_neg[1][-(height - height//2):], - freqs_pos[1][:height//2] - ], - dim=0 - ) - freqs_height = freqs_height.view(1, height, 1, -1).expand(frame, height, width, -1) - freqs_width = torch.cat( - [ - freqs_neg[2][-(width - width//2):], - freqs_pos[2][:width//2] - ], - dim=0 - ) - freqs_width = freqs_width.view(1, 1, width, -1).expand(frame, height, width, -1) - + max_vid_index = max(height // 2, width // 2, max_vid_index) else: - freqs_height = freqs_pos[1][:height].view(1, height, 1, -1).expand(frame, height, width, -1) - freqs_width = freqs_pos[2][:width].view(1, 1, width, -1).expand(frame, height, width, -1) - - freqs = torch.cat([freqs_frame, freqs_height, freqs_width], dim=-1).reshape(seq_lens, -1) - self.rope_cache[rope_key] = freqs.clone().contiguous() - vid_freqs = self.rope_cache[rope_key] - - if self.scale_rope: - max_vid_index = max(height // 2, width // 2) - else: - max_vid_index = max(height, width) + max_vid_index = max(height, width, max_vid_index) max_len = max(txt_seq_lens) - txt_freqs = self.pos_freqs[max_vid_index: max_vid_index + max_len, ...] + txt_freqs = self.pos_freqs[max_vid_index : max_vid_index + max_len, ...] + vid_freqs = torch.cat(vid_freqs, dim=0) + return vid_freqs, txt_freqs diff --git a/diffsynth/pipelines/qwen_image.py b/diffsynth/pipelines/qwen_image.py index ff02775..ced70f6 100644 --- a/diffsynth/pipelines/qwen_image.py +++ b/diffsynth/pipelines/qwen_image.py @@ -565,7 +565,6 @@ class QwenImageUnit_EditImageEmbedder(PipelineUnit): def process(self, pipe: QwenImagePipeline, edit_image, height, width, tiled, tile_size, tile_stride): if edit_image is None: return {} - edit_image = edit_image.resize((width, height)) pipe.load_models_to_device(['vae']) edit_image = pipe.preprocess_image(edit_image).to(device=pipe.device, dtype=pipe.torch_dtype) edit_latents = pipe.vae.encode(edit_image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride) @@ -601,8 +600,8 @@ def model_fn_qwen_image( image = rearrange(latents, "B C (H P) (W Q) -> B (H W) (C P Q)", H=height//16, W=width//16, P=2, Q=2) if edit_latents is not None: - img_shapes[0] = (img_shapes[0][0] + edit_latents.shape[0], img_shapes[0][1], img_shapes[0][2]) - edit_image = rearrange(edit_latents, "B C (H P) (W Q) -> B (H W) (C P Q)", H=height//16, W=width//16, P=2, Q=2) + img_shapes += [(edit_latents.shape[0], edit_latents.shape[2]//2, edit_latents.shape[3]//2)] + edit_image = rearrange(edit_latents, "B C (H P) (W Q) -> B (H W) (C P Q)", H=edit_latents.shape[2]//2, W=edit_latents.shape[3]//2, P=2, Q=2) image_seq_len = image.shape[1] image = torch.cat([image, edit_image], dim=1) diff --git a/diffsynth/trainers/utils.py b/diffsynth/trainers/utils.py index 4b1aad6..8358e55 100644 --- a/diffsynth/trainers/utils.py +++ b/diffsynth/trainers/utils.py @@ -552,6 +552,5 @@ def qwen_image_parser(): parser.add_argument("--save_steps", type=int, default=None, help="Number of checkpoint saving invervals. If None, checkpoints will be saved every epoch.") parser.add_argument("--dataset_num_workers", type=int, default=0, help="Number of workers for data loading.") parser.add_argument("--weight_decay", type=float, default=0.01, help="Weight decay.") - parser.add_argument("--edit_model", default=False, action="store_true", help="Whether to use Qwen-Image-Edit. If True, the model will be used for image editing.") parser.add_argument("--processor_path", type=str, default=None, help="Path to the processor. If provided, the processor will be used for image editing.") return parser diff --git a/examples/qwen_image/README.md b/examples/qwen_image/README.md index 570d9cf..d2f371d 100644 --- a/examples/qwen_image/README.md +++ b/examples/qwen_image/README.md @@ -236,7 +236,6 @@ The script includes the following parameters: * `--model_paths`: Model paths to load. In JSON format. * `--model_id_with_origin_paths`: Model ID with original paths, e.g., Qwen/Qwen-Image:transformer/diffusion_pytorch_model*.safetensors. Separate with commas. * `--tokenizer_path`: Tokenizer path. Leave empty to auto-download. - * `--edit_model`: Whether to use Qwen-Image-Edit. If True, the model will be used for image editing. * `--processor_path`: Path to the processor of Qwen-Image-Edit. Leave empty to auto-download. * Training * `--learning_rate`: Learning rate. diff --git a/examples/qwen_image/README_zh.md b/examples/qwen_image/README_zh.md index 26965b8..aa2bd0c 100644 --- a/examples/qwen_image/README_zh.md +++ b/examples/qwen_image/README_zh.md @@ -236,7 +236,6 @@ Qwen-Image 系列模型训练通过统一的 [`./model_training/train.py`](./mod * `--model_paths`: 要加载的模型路径。JSON 格式。 * `--model_id_with_origin_paths`: 带原始路径的模型 ID,例如 Qwen/Qwen-Image:transformer/diffusion_pytorch_model*.safetensors。用逗号分隔。 * `--tokenizer_path`: tokenizer 路径,留空将会自动下载。 - * `--edit_model`:是否使用 Qwen-Image-Edit。若为 True,则将使用该模型进行图像编辑。 * `--processor_path`:Qwen-Image-Edit 的 processor 路径。留空则自动下载。 * 训练 * `--learning_rate`: 学习率。 diff --git a/examples/qwen_image/model_training/full/Qwen-Image-Edit.sh b/examples/qwen_image/model_training/full/Qwen-Image-Edit.sh index fe99f01..bf9fd89 100644 --- a/examples/qwen_image/model_training/full/Qwen-Image-Edit.sh +++ b/examples/qwen_image/model_training/full/Qwen-Image-Edit.sh @@ -1,5 +1,4 @@ accelerate launch --config_file examples/qwen_image/model_training/full/accelerate_config_zero2offload.yaml examples/qwen_image/model_training/train.py \ - --edit_model \ --dataset_base_path data/example_image_dataset \ --dataset_metadata_path data/example_image_dataset/metadata_edit.csv \ --max_pixels 1048576 \ diff --git a/examples/qwen_image/model_training/lora/Qwen-Image-Edit.sh b/examples/qwen_image/model_training/lora/Qwen-Image-Edit.sh index ddde13f..d12363f 100644 --- a/examples/qwen_image/model_training/lora/Qwen-Image-Edit.sh +++ b/examples/qwen_image/model_training/lora/Qwen-Image-Edit.sh @@ -1,5 +1,4 @@ accelerate launch examples/qwen_image/model_training/train.py \ - --edit_model \ --dataset_base_path data/example_image_dataset \ --dataset_metadata_path data/example_image_dataset/metadata_edit.csv \ --max_pixels 1048576 \ diff --git a/examples/qwen_image/model_training/train.py b/examples/qwen_image/model_training/train.py index 5239797..b0ba69e 100644 --- a/examples/qwen_image/model_training/train.py +++ b/examples/qwen_image/model_training/train.py @@ -11,7 +11,7 @@ class QwenImageTrainingModule(DiffusionTrainingModule): def __init__( self, model_paths=None, model_id_with_origin_paths=None, - tokenizer_path=None, processor_path=None, edit_model=False, + tokenizer_path=None, processor_path=None, trainable_models=None, lora_base_model=None, lora_target_modules="", lora_rank=32, lora_checkpoint=None, use_gradient_checkpointing=True, @@ -28,12 +28,8 @@ class QwenImageTrainingModule(DiffusionTrainingModule): model_id_with_origin_paths = model_id_with_origin_paths.split(",") model_configs += [ModelConfig(model_id=i.split(":")[0], origin_file_pattern=i.split(":")[1]) for i in model_id_with_origin_paths] - if edit_model: - tokenizer_config = None - processor_config = ModelConfig(model_id="Qwen/Qwen-Image-Edit", origin_file_pattern="processor/") if processor_path is None else ModelConfig(processor_path) - else: - tokenizer_config = ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/") if tokenizer_path is None else ModelConfig(tokenizer_path) - processor_config = None + tokenizer_config = ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/") if tokenizer_path is None else ModelConfig(tokenizer_path) + processor_config = ModelConfig(model_id="Qwen/Qwen-Image-Edit", origin_file_pattern="processor/") if processor_path is None else ModelConfig(processor_path) self.pipe = QwenImagePipeline.from_pretrained(torch_dtype=torch.bfloat16, device="cpu", model_configs=model_configs, tokenizer_config=tokenizer_config, processor_config=processor_config) # Reset training scheduler (do it in each training step) @@ -120,7 +116,6 @@ if __name__ == "__main__": model_id_with_origin_paths=args.model_id_with_origin_paths, tokenizer_path=args.tokenizer_path, processor_path=args.processor_path, - edit_model=args.edit_model, trainable_models=args.trainable_models, lora_base_model=args.lora_base_model, lora_target_modules=args.lora_target_modules,