From b0abdaffb408727ec4d19bc1208a53c2144a2ebd Mon Sep 17 00:00:00 2001 From: Zhongjie Duan <35051019+Artiprocher@users.noreply.github.com> Date: Wed, 17 Sep 2025 20:53:46 +0800 Subject: [PATCH] Qwen image split training Bug Fix (#926) * bugfix --- diffsynth/pipelines/qwen_image.py | 5 +++-- diffsynth/trainers/utils.py | 4 +++- examples/qwen_image/model_training/train.py | 6 ++++-- 3 files changed, 10 insertions(+), 5 deletions(-) diff --git a/diffsynth/pipelines/qwen_image.py b/diffsynth/pipelines/qwen_image.py index 1d3c66b..83ff290 100644 --- a/diffsynth/pipelines/qwen_image.py +++ b/diffsynth/pipelines/qwen_image.py @@ -140,8 +140,9 @@ class QwenImagePipeline(BasePipeline): timestep_id = torch.randint(0, self.scheduler.num_train_timesteps, (1,)) timestep = self.scheduler.timesteps[timestep_id].to(dtype=self.torch_dtype, device=self.device) - inputs["latents"] = self.scheduler.add_noise(inputs["input_latents"], inputs["noise"], timestep) - training_target = self.scheduler.training_target(inputs["input_latents"], inputs["noise"], timestep) + noise = torch.randn_like(inputs["input_latents"]) + inputs["latents"] = self.scheduler.add_noise(inputs["input_latents"], noise, timestep) + training_target = self.scheduler.training_target(inputs["input_latents"], noise, timestep) noise_pred = self.model_fn(**inputs, timestep=timestep) diff --git a/diffsynth/trainers/utils.py b/diffsynth/trainers/utils.py index 0711176..3262d15 100644 --- a/diffsynth/trainers/utils.py +++ b/diffsynth/trainers/utils.py @@ -421,10 +421,12 @@ class DiffusionTrainingModule(torch.nn.Module): return state_dict - def transfer_data_to_device(self, data, device): + def transfer_data_to_device(self, data, device, torch_float_dtype=None): for key in data: if isinstance(data[key], torch.Tensor): data[key] = data[key].to(device) + if torch_float_dtype is not None and data[key].dtype in [torch.float, torch.float16, torch.bfloat16]: + data[key] = data[key].to(torch_float_dtype) return data diff --git a/examples/qwen_image/model_training/train.py b/examples/qwen_image/model_training/train.py index 1b4faec..f39c11c 100644 --- a/examples/qwen_image/model_training/train.py +++ b/examples/qwen_image/model_training/train.py @@ -85,8 +85,10 @@ class QwenImageTrainingModule(DiffusionTrainingModule): def forward(self, data, inputs=None, return_inputs=False): # Inputs - if inputs is None: inputs = self.forward_preprocess(data) - else: inputs = self.transfer_data_to_device(inputs, self.pipe.device) + if inputs is None: + inputs = self.forward_preprocess(data) + else: + inputs = self.transfer_data_to_device(inputs, self.pipe.device, self.pipe.torch_dtype) if return_inputs: return inputs # Loss