mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-20 07:18:14 +00:00
wan-series
This commit is contained in:
@@ -3,7 +3,10 @@ import torch
|
||||
|
||||
|
||||
def FlowMatchSFTLoss(pipe: BasePipeline, **inputs):
|
||||
timestep_id = torch.randint(0, pipe.scheduler.num_train_timesteps, (1,))
|
||||
max_timestep_boundary = int(inputs.get("max_timestep_boundary", 1) * pipe.scheduler.num_train_timesteps)
|
||||
min_timestep_boundary = int(inputs.get("min_timestep_boundary", 0) * pipe.scheduler.num_train_timesteps)
|
||||
|
||||
timestep_id = torch.randint(min_timestep_boundary, max_timestep_boundary, (1,))
|
||||
timestep = pipe.scheduler.timesteps[timestep_id].to(dtype=pipe.torch_dtype, device=pipe.device)
|
||||
|
||||
noise = torch.randn_like(inputs["input_latents"])
|
||||
|
||||
Reference in New Issue
Block a user