mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-20 15:48:20 +00:00
@@ -421,10 +421,12 @@ class DiffusionTrainingModule(torch.nn.Module):
|
||||
return state_dict
|
||||
|
||||
|
||||
def transfer_data_to_device(self, data, device):
|
||||
def transfer_data_to_device(self, data, device, torch_float_dtype=None):
|
||||
for key in data:
|
||||
if isinstance(data[key], torch.Tensor):
|
||||
data[key] = data[key].to(device)
|
||||
if torch_float_dtype is not None and data[key].dtype in [torch.float, torch.float16, torch.bfloat16]:
|
||||
data[key] = data[key].to(torch_float_dtype)
|
||||
return data
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user