Summary of this patch (text before the first "diff --git" header is ignored by patch tools):

  * diffsynth/pipelines/qwen_image.py (QwenImagePipeline): the noise used for
    scheduler.add_noise and scheduler.training_target is now drawn with
    torch.randn_like(inputs["input_latents"]) instead of being read from
    inputs["noise"].
  * diffsynth/trainers/utils.py (DiffusionTrainingModule.transfer_data_to_device):
    new optional parameter torch_float_dtype (default None). After a tensor is
    moved to the device, it is additionally cast to torch_float_dtype when its
    dtype is torch.float, torch.float16 or torch.bfloat16; other dtypes are
    left untouched.
  * examples/qwen_image/model_training/train.py (QwenImageTrainingModule.forward):
    pre-supplied inputs are transferred with self.pipe.torch_dtype as the float
    dtype; the one-line if/else is expanded to a regular block.

diff --git a/diffsynth/pipelines/qwen_image.py b/diffsynth/pipelines/qwen_image.py
index 1d3c66b..83ff290 100644
--- a/diffsynth/pipelines/qwen_image.py
+++ b/diffsynth/pipelines/qwen_image.py
@@ -140,8 +140,9 @@ class QwenImagePipeline(BasePipeline):
         timestep_id = torch.randint(0, self.scheduler.num_train_timesteps, (1,))
         timestep = self.scheduler.timesteps[timestep_id].to(dtype=self.torch_dtype, device=self.device)
 
-        inputs["latents"] = self.scheduler.add_noise(inputs["input_latents"], inputs["noise"], timestep)
-        training_target = self.scheduler.training_target(inputs["input_latents"], inputs["noise"], timestep)
+        noise = torch.randn_like(inputs["input_latents"])
+        inputs["latents"] = self.scheduler.add_noise(inputs["input_latents"], noise, timestep)
+        training_target = self.scheduler.training_target(inputs["input_latents"], noise, timestep)
 
         noise_pred = self.model_fn(**inputs, timestep=timestep)
 
diff --git a/diffsynth/trainers/utils.py b/diffsynth/trainers/utils.py
index 0711176..3262d15 100644
--- a/diffsynth/trainers/utils.py
+++ b/diffsynth/trainers/utils.py
@@ -421,10 +421,12 @@ class DiffusionTrainingModule(torch.nn.Module):
         return state_dict
 
 
-    def transfer_data_to_device(self, data, device):
+    def transfer_data_to_device(self, data, device, torch_float_dtype=None):
         for key in data:
             if isinstance(data[key], torch.Tensor):
                 data[key] = data[key].to(device)
+                if torch_float_dtype is not None and data[key].dtype in [torch.float, torch.float16, torch.bfloat16]:
+                    data[key] = data[key].to(torch_float_dtype)
         return data
 
 
diff --git a/examples/qwen_image/model_training/train.py b/examples/qwen_image/model_training/train.py
index 1b4faec..f39c11c 100644
--- a/examples/qwen_image/model_training/train.py
+++ b/examples/qwen_image/model_training/train.py
@@ -85,8 +85,10 @@ class QwenImageTrainingModule(DiffusionTrainingModule):
 
     def forward(self, data, inputs=None, return_inputs=False):
         # Inputs
-        if inputs is None: inputs = self.forward_preprocess(data)
-        else: inputs = self.transfer_data_to_device(inputs, self.pipe.device)
+        if inputs is None:
+            inputs = self.forward_preprocess(data)
+        else:
+            inputs = self.transfer_data_to_device(inputs, self.pipe.device, self.pipe.torch_dtype)
         if return_inputs: return inputs
 
         # Loss