mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-18 22:08:13 +00:00
support ltx-2 training
This commit is contained in:
@@ -28,6 +28,36 @@ def FlowMatchSFTLoss(pipe: BasePipeline, **inputs):
|
||||
return loss
|
||||
|
||||
|
||||
def FlowMatchSFTAudioVideoLoss(pipe: BasePipeline, **inputs):
|
||||
max_timestep_boundary = int(inputs.get("max_timestep_boundary", 1) * len(pipe.scheduler.timesteps))
|
||||
min_timestep_boundary = int(inputs.get("min_timestep_boundary", 0) * len(pipe.scheduler.timesteps))
|
||||
|
||||
timestep_id = torch.randint(min_timestep_boundary, max_timestep_boundary, (1,))
|
||||
timestep = pipe.scheduler.timesteps[timestep_id].to(dtype=pipe.torch_dtype, device=pipe.device)
|
||||
|
||||
# video
|
||||
noise = torch.randn_like(inputs["input_latents"])
|
||||
inputs["video_latents"] = pipe.scheduler.add_noise(inputs["input_latents"], noise, timestep)
|
||||
training_target = pipe.scheduler.training_target(inputs["input_latents"], noise, timestep)
|
||||
|
||||
# audio
|
||||
if inputs.get("audio_input_latents") is not None:
|
||||
audio_noise = torch.randn_like(inputs["audio_input_latents"])
|
||||
inputs["audio_latents"] = pipe.scheduler.add_noise(inputs["audio_input_latents"], audio_noise, timestep)
|
||||
training_target_audio = pipe.scheduler.training_target(inputs["audio_input_latents"], audio_noise, timestep)
|
||||
|
||||
models = {name: getattr(pipe, name) for name in pipe.in_iteration_models}
|
||||
noise_pred, noise_pred_audio = pipe.model_fn(**models, **inputs, timestep=timestep)
|
||||
|
||||
loss = torch.nn.functional.mse_loss(noise_pred.float(), training_target.float())
|
||||
loss = loss * pipe.scheduler.training_weight(timestep)
|
||||
if inputs.get("audio_input_latents") is not None:
|
||||
loss_audio = torch.nn.functional.mse_loss(noise_pred_audio.float(), training_target_audio.float())
|
||||
loss_audio = loss_audio * pipe.scheduler.training_weight(timestep)
|
||||
loss = loss + loss_audio
|
||||
return loss
|
||||
|
||||
|
||||
def DirectDistillLoss(pipe: BasePipeline, **inputs):
|
||||
pipe.scheduler.set_timesteps(inputs["num_inference_steps"])
|
||||
pipe.scheduler.training = True
|
||||
|
||||
Reference in New Issue
Block a user