mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-23 00:58:11 +00:00
fix wan i2v train bug
This commit is contained in:
@@ -13,9 +13,16 @@ def FlowMatchSFTLoss(pipe: BasePipeline, **inputs):
|
|||||||
inputs["latents"] = pipe.scheduler.add_noise(inputs["input_latents"], noise, timestep)
|
inputs["latents"] = pipe.scheduler.add_noise(inputs["input_latents"], noise, timestep)
|
||||||
training_target = pipe.scheduler.training_target(inputs["input_latents"], noise, timestep)
|
training_target = pipe.scheduler.training_target(inputs["input_latents"], noise, timestep)
|
||||||
|
|
||||||
|
if "first_frame_latents" in inputs:
|
||||||
|
inputs["latents"][:, :, 0:1] = inputs["first_frame_latents"]
|
||||||
|
|
||||||
models = {name: getattr(pipe, name) for name in pipe.in_iteration_models}
|
models = {name: getattr(pipe, name) for name in pipe.in_iteration_models}
|
||||||
noise_pred = pipe.model_fn(**models, **inputs, timestep=timestep)
|
noise_pred = pipe.model_fn(**models, **inputs, timestep=timestep)
|
||||||
|
|
||||||
|
if "first_frame_latents" in inputs:
|
||||||
|
noise_pred = noise_pred[:, :, 1:]
|
||||||
|
training_target = training_target[:, :, 1:]
|
||||||
|
|
||||||
loss = torch.nn.functional.mse_loss(noise_pred.float(), training_target.float())
|
loss = torch.nn.functional.mse_loss(noise_pred.float(), training_target.float())
|
||||||
loss = loss * pipe.scheduler.training_weight(timestep)
|
loss = loss * pipe.scheduler.training_weight(timestep)
|
||||||
return loss
|
return loss
|
||||||
|
|||||||
Reference in New Issue
Block a user