Qwen image split training Bug Fix (#926)

* bugfix
This commit is contained in:
Zhongjie Duan
2025-09-17 20:53:46 +08:00
committed by GitHub
parent e9f29bc402
commit b0abdaffb4
3 changed files with 10 additions and 5 deletions

View File

@@ -85,8 +85,10 @@ class QwenImageTrainingModule(DiffusionTrainingModule):
def forward(self, data, inputs=None, return_inputs=False):
# Inputs
if inputs is None: inputs = self.forward_preprocess(data)
else: inputs = self.transfer_data_to_device(inputs, self.pipe.device)
if inputs is None:
inputs = self.forward_preprocess(data)
else:
inputs = self.transfer_data_to_device(inputs, self.pipe.device, self.pipe.torch_dtype)
if return_inputs: return inputs
# Loss