mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-18 22:08:13 +00:00
@@ -140,8 +140,9 @@ class QwenImagePipeline(BasePipeline):
|
|||||||
timestep_id = torch.randint(0, self.scheduler.num_train_timesteps, (1,))
|
timestep_id = torch.randint(0, self.scheduler.num_train_timesteps, (1,))
|
||||||
timestep = self.scheduler.timesteps[timestep_id].to(dtype=self.torch_dtype, device=self.device)
|
timestep = self.scheduler.timesteps[timestep_id].to(dtype=self.torch_dtype, device=self.device)
|
||||||
|
|
||||||
inputs["latents"] = self.scheduler.add_noise(inputs["input_latents"], inputs["noise"], timestep)
|
noise = torch.randn_like(inputs["input_latents"])
|
||||||
training_target = self.scheduler.training_target(inputs["input_latents"], inputs["noise"], timestep)
|
inputs["latents"] = self.scheduler.add_noise(inputs["input_latents"], noise, timestep)
|
||||||
|
training_target = self.scheduler.training_target(inputs["input_latents"], noise, timestep)
|
||||||
|
|
||||||
noise_pred = self.model_fn(**inputs, timestep=timestep)
|
noise_pred = self.model_fn(**inputs, timestep=timestep)
|
||||||
|
|
||||||
|
|||||||
@@ -421,10 +421,12 @@ class DiffusionTrainingModule(torch.nn.Module):
|
|||||||
return state_dict
|
return state_dict
|
||||||
|
|
||||||
|
|
||||||
def transfer_data_to_device(self, data, device):
|
def transfer_data_to_device(self, data, device, torch_float_dtype=None):
|
||||||
for key in data:
|
for key in data:
|
||||||
if isinstance(data[key], torch.Tensor):
|
if isinstance(data[key], torch.Tensor):
|
||||||
data[key] = data[key].to(device)
|
data[key] = data[key].to(device)
|
||||||
|
if torch_float_dtype is not None and data[key].dtype in [torch.float, torch.float16, torch.bfloat16]:
|
||||||
|
data[key] = data[key].to(torch_float_dtype)
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -85,8 +85,10 @@ class QwenImageTrainingModule(DiffusionTrainingModule):
|
|||||||
|
|
||||||
def forward(self, data, inputs=None, return_inputs=False):
|
def forward(self, data, inputs=None, return_inputs=False):
|
||||||
# Inputs
|
# Inputs
|
||||||
if inputs is None: inputs = self.forward_preprocess(data)
|
if inputs is None:
|
||||||
else: inputs = self.transfer_data_to_device(inputs, self.pipe.device)
|
inputs = self.forward_preprocess(data)
|
||||||
|
else:
|
||||||
|
inputs = self.transfer_data_to_device(inputs, self.pipe.device, self.pipe.torch_dtype)
|
||||||
if return_inputs: return inputs
|
if return_inputs: return inputs
|
||||||
|
|
||||||
# Loss
|
# Loss
|
||||||
|
|||||||
Reference in New Issue
Block a user