Qwen image split training Bug Fix (#926)

* bugfix: re-sample fresh noise each training step instead of reusing the cached `inputs["noise"]`, and allow `transfer_data_to_device` to cast floating-point tensors to a target dtype during transfer.
This commit is contained in:
Zhongjie Duan
2025-09-17 20:53:46 +08:00
committed by GitHub
parent e9f29bc402
commit b0abdaffb4
3 changed files with 10 additions and 5 deletions

View File

@@ -140,8 +140,9 @@ class QwenImagePipeline(BasePipeline):
timestep_id = torch.randint(0, self.scheduler.num_train_timesteps, (1,))
timestep = self.scheduler.timesteps[timestep_id].to(dtype=self.torch_dtype, device=self.device)
inputs["latents"] = self.scheduler.add_noise(inputs["input_latents"], inputs["noise"], timestep)
training_target = self.scheduler.training_target(inputs["input_latents"], inputs["noise"], timestep)
noise = torch.randn_like(inputs["input_latents"])
inputs["latents"] = self.scheduler.add_noise(inputs["input_latents"], noise, timestep)
training_target = self.scheduler.training_target(inputs["input_latents"], noise, timestep)
noise_pred = self.model_fn(**inputs, timestep=timestep)

View File

@@ -421,10 +421,12 @@ class DiffusionTrainingModule(torch.nn.Module):
return state_dict
def transfer_data_to_device(self, data, device):
def transfer_data_to_device(self, data, device, torch_float_dtype=None):
    """Move every tensor value in ``data`` to ``device``, optionally casting floats.

    Args:
        data: mapping whose values may include ``torch.Tensor`` objects;
            non-tensor values are left untouched. The dict is mutated in place.
        device: target device forwarded to ``Tensor.to``.
        torch_float_dtype: if not ``None``, tensors whose dtype is
            float32 / float16 / bfloat16 are additionally cast to this dtype.
            NOTE(review): float64 tensors are NOT matched by this check —
            confirm that exclusion is intentional.

    Returns:
        The same ``data`` dict (mutated in place), for call chaining.
    """
    # Hoisted loop invariant: eligible float dtypes for the optional cast.
    # (torch.float is an alias of torch.float32.)
    float_dtypes = (torch.float32, torch.float16, torch.bfloat16)
    for key in data:
        value = data[key]
        if isinstance(value, torch.Tensor):
            # Move first, then (optionally) narrow/cast the floating dtype.
            value = value.to(device)
            if torch_float_dtype is not None and value.dtype in float_dtypes:
                value = value.to(torch_float_dtype)
            data[key] = value
    return data