# Mirror of https://github.com/modelscope/DiffSynth-Studio.git
# Synced 2026-03-19 06:48:12 +00:00
from .base_pipeline import BasePipeline
|
|
import torch
|
|
|
|
|
|
def FlowMatchSFTLoss(pipe: BasePipeline, **inputs):
    """Flow-matching supervised fine-tuning loss.

    Draws one random timestep from the configured fraction of the scheduler's
    timestep range, noises the clean latents, runs the model once, and returns
    a timestep-weighted MSE between the model prediction and the scheduler's
    training target. When ``first_frame_latents`` is supplied, the first
    latent frame is overwritten with the clean conditioning frame and that
    frame is excluded from the loss.

    Expected ``inputs`` keys (others are forwarded to ``pipe.model_fn``):
        input_latents: clean latent tensor to be noised.
        max_timestep_boundary / min_timestep_boundary: optional fractions in
            [0, 1] bounding the sampled timestep index (defaults 1 and 0).
        first_frame_latents: optional clean first-frame conditioning latents.

    Returns a scalar loss tensor.
    """
    num_timesteps = len(pipe.scheduler.timesteps)
    upper = int(inputs.get("max_timestep_boundary", 1) * num_timesteps)
    lower = int(inputs.get("min_timestep_boundary", 0) * num_timesteps)

    # Sample one timestep index uniformly from [lower, upper).
    timestep_id = torch.randint(lower, upper, (1,))
    timestep = pipe.scheduler.timesteps[timestep_id].to(
        dtype=pipe.torch_dtype, device=pipe.device
    )

    clean_latents = inputs["input_latents"]
    noise = torch.randn_like(clean_latents)
    inputs["latents"] = pipe.scheduler.add_noise(clean_latents, noise, timestep)
    training_target = pipe.scheduler.training_target(clean_latents, noise, timestep)

    # First-frame conditioning: the conditioning frame stays noise-free.
    has_first_frame = "first_frame_latents" in inputs
    if has_first_frame:
        inputs["latents"][:, :, 0:1] = inputs["first_frame_latents"]

    models = {name: getattr(pipe, name) for name in pipe.in_iteration_models}
    noise_pred = pipe.model_fn(**models, **inputs, timestep=timestep)

    if has_first_frame:
        # The clean conditioning frame carries no learning signal; drop it
        # from both the prediction and the target.
        noise_pred = noise_pred[:, :, 1:]
        training_target = training_target[:, :, 1:]

    loss = torch.nn.functional.mse_loss(noise_pred.float(), training_target.float())
    return loss * pipe.scheduler.training_weight(timestep)
|
|
|
|
|
|
def FlowMatchSFTAudioVideoLoss(pipe: BasePipeline, **inputs):
    """Flow-matching SFT loss for joint audio/video models.

    One timestep is sampled and shared by both modalities. The video branch
    is always trained; the audio branch is trained only when
    ``audio_input_latents`` is present and not None. Each MSE term is scaled
    by the scheduler's timestep weight, and the two terms are summed.

    Expected ``inputs`` keys (others are forwarded to ``pipe.model_fn``):
        input_latents: clean video latents.
        audio_input_latents: optional clean audio latents.
        max_timestep_boundary / min_timestep_boundary: optional fractions in
            [0, 1] bounding the sampled timestep index (defaults 1 and 0).

    Returns a scalar loss tensor.
    """
    num_timesteps = len(pipe.scheduler.timesteps)
    upper = int(inputs.get("max_timestep_boundary", 1) * num_timesteps)
    lower = int(inputs.get("min_timestep_boundary", 0) * num_timesteps)
    timestep_id = torch.randint(lower, upper, (1,))
    timestep = pipe.scheduler.timesteps[timestep_id].to(
        dtype=pipe.torch_dtype, device=pipe.device
    )

    # Video branch: noise the clean latents and build the regression target.
    video_clean = inputs["input_latents"]
    video_noise = torch.randn_like(video_clean)
    inputs["video_latents"] = pipe.scheduler.add_noise(video_clean, video_noise, timestep)
    training_target = pipe.scheduler.training_target(video_clean, video_noise, timestep)

    # Audio branch (optional): mirrors the video branch.
    has_audio = inputs.get("audio_input_latents") is not None
    if has_audio:
        audio_clean = inputs["audio_input_latents"]
        audio_noise = torch.randn_like(audio_clean)
        inputs["audio_latents"] = pipe.scheduler.add_noise(audio_clean, audio_noise, timestep)
        training_target_audio = pipe.scheduler.training_target(audio_clean, audio_noise, timestep)

    models = {name: getattr(pipe, name) for name in pipe.in_iteration_models}
    noise_pred, noise_pred_audio = pipe.model_fn(**models, **inputs, timestep=timestep)

    loss = torch.nn.functional.mse_loss(noise_pred.float(), training_target.float())
    loss = loss * pipe.scheduler.training_weight(timestep)
    if has_audio:
        loss_audio = torch.nn.functional.mse_loss(noise_pred_audio.float(), training_target_audio.float())
        loss = loss + loss_audio * pipe.scheduler.training_weight(timestep)
    return loss
|
|
|
|
|
|
def DirectDistillLoss(pipe: BasePipeline, **inputs):
    """Direct distillation loss over a full sampling rollout.

    Sets the scheduler to ``inputs["num_inference_steps"]`` steps in training
    mode, runs the complete denoising loop, and returns the MSE between the
    final latents and the reference ``input_latents``.

    Expected ``inputs`` keys (others are forwarded to ``pipe.model_fn`` and
    ``pipe.step``):
        num_inference_steps: number of sampling steps.
        latents: initial latents for the rollout.
        input_latents: reference latents the rollout should reach.

    Returns a scalar loss tensor.
    """
    pipe.scheduler.set_timesteps(inputs["num_inference_steps"])
    pipe.scheduler.training = True
    model_kwargs = {name: getattr(pipe, name) for name in pipe.in_iteration_models}
    for step_index, t in enumerate(pipe.scheduler.timesteps):
        t = t.unsqueeze(0).to(dtype=pipe.torch_dtype, device=pipe.device)
        prediction = pipe.model_fn(**model_kwargs, **inputs, timestep=t, progress_id=step_index)
        # The stepped latents feed the next iteration through `inputs`.
        inputs["latents"] = pipe.step(pipe.scheduler, progress_id=step_index, noise_pred=prediction, **inputs)
    return torch.nn.functional.mse_loss(inputs["latents"].float(), inputs["input_latents"].float())
|
|
|
|
|
|
class TrajectoryImitationLoss(torch.nn.Module):
    """Distillation loss that trains a student pipeline to imitate a
    teacher's denoising trajectory.

    The teacher rolls out a full CFG-guided sampling trajectory (no grad);
    the student is then trained to (1) match a finite-difference velocity
    target computed between teacher latents along that trajectory and
    (2) minimize an LPIPS perceptual distance between the decoded student
    output and the decoded teacher endpoint.
    """

    def __init__(self):
        super().__init__()
        # The LPIPS network is built lazily on first forward (it needs a
        # device, which is only known once a pipeline is passed in).
        self.initialized = False

    def initialize(self, device):
        # Lazily construct the LPIPS perceptual metric on the target device.
        import lpips # TODO: remove it
        self.loss_fn = lpips.LPIPS(net='alex').to(device)
        self.initialized = True

    def fetch_trajectory(self, pipe: BasePipeline, timesteps_student, inputs_shared, inputs_posi, inputs_nega, num_inference_steps, cfg_scale):
        """Roll out `pipe` for `num_inference_steps` CFG-guided steps and
        record the latents after every step.

        Returns ``(timesteps, trajectory)`` where ``trajectory`` holds the
        initial latents plus one entry per step. Predictions are detached,
        so no gradients flow through the rollout. Note: mutates
        ``inputs_shared["latents"]`` in place as the rollout progresses.
        """
        trajectory = [inputs_shared["latents"].clone()]

        # Align the rollout grid with the student's timesteps.
        pipe.scheduler.set_timesteps(num_inference_steps, target_timesteps=timesteps_student)
        models = {name: getattr(pipe, name) for name in pipe.in_iteration_models}
        for progress_id, timestep in enumerate(pipe.scheduler.timesteps):
            timestep = timestep.unsqueeze(0).to(dtype=pipe.torch_dtype, device=pipe.device)
            noise_pred = pipe.cfg_guided_model_fn(
                pipe.model_fn, cfg_scale,
                inputs_shared, inputs_posi, inputs_nega,
                **models, timestep=timestep, progress_id=progress_id
            )
            inputs_shared["latents"] = pipe.step(pipe.scheduler, progress_id=progress_id, noise_pred=noise_pred.detach(), **inputs_shared)

            trajectory.append(inputs_shared["latents"].clone())
        return pipe.scheduler.timesteps, trajectory

    def align_trajectory(self, pipe: BasePipeline, timesteps_teacher, trajectory_teacher, inputs_shared, inputs_posi, inputs_nega, num_inference_steps, cfg_scale):
        """Velocity-matching loss along the teacher trajectory.

        For each student timestep, the student starts from the nearest
        teacher latent and its prediction is regressed onto the
        finite-difference velocity between consecutive teacher latents.
        Returns the accumulated, timestep-weighted MSE.
        """
        loss = 0
        pipe.scheduler.set_timesteps(num_inference_steps, training=True)
        models = {name: getattr(pipe, name) for name in pipe.in_iteration_models}
        for progress_id, timestep in enumerate(pipe.scheduler.timesteps):
            timestep = timestep.unsqueeze(0).to(dtype=pipe.torch_dtype, device=pipe.device)

            # Start this step from the teacher latent nearest in timestep.
            progress_id_teacher = torch.argmin((timesteps_teacher - timestep).abs())
            inputs_shared["latents"] = trajectory_teacher[progress_id_teacher]

            noise_pred = pipe.cfg_guided_model_fn(
                pipe.model_fn, cfg_scale,
                inputs_shared, inputs_posi, inputs_nega,
                **models, timestep=timestep, progress_id=progress_id
            )

            # Sigma gap between this step and the next (0 past the last step).
            sigma = pipe.scheduler.sigmas[progress_id]
            sigma_ = 0 if progress_id + 1 >= len(pipe.scheduler.timesteps) else pipe.scheduler.sigmas[progress_id + 1]
            if progress_id + 1 >= len(pipe.scheduler.timesteps):
                latents_ = trajectory_teacher[-1]
            else:
                # Teacher latent nearest to the *next* student timestep.
                progress_id_teacher = torch.argmin((timesteps_teacher - pipe.scheduler.timesteps[progress_id + 1]).abs())
                latents_ = trajectory_teacher[progress_id_teacher]

            # Finite-difference velocity target; clamp the sigma gap's
            # magnitude (sign-preserving) to avoid division by zero.
            denom = sigma_ - sigma
            denom = torch.sign(denom) * torch.clamp(denom.abs(), min=1e-6)
            target = (latents_ - inputs_shared["latents"]) / denom
            loss = loss + torch.nn.functional.mse_loss(noise_pred.float(), target.float()) * pipe.scheduler.training_weight(timestep)
        return loss

    def compute_regularization(self, pipe: BasePipeline, trajectory_teacher, inputs_shared, inputs_posi, inputs_nega, num_inference_steps, cfg_scale):
        """Perceptual regularization term.

        Runs the student's full sampling loop starting from the teacher's
        initial latents, decodes both the student result and the teacher's
        final latents, and returns their LPIPS distance.
        """
        inputs_shared["latents"] = trajectory_teacher[0]
        pipe.scheduler.set_timesteps(num_inference_steps)
        models = {name: getattr(pipe, name) for name in pipe.in_iteration_models}
        for progress_id, timestep in enumerate(pipe.scheduler.timesteps):
            timestep = timestep.unsqueeze(0).to(dtype=pipe.torch_dtype, device=pipe.device)
            noise_pred = pipe.cfg_guided_model_fn(
                pipe.model_fn, cfg_scale,
                inputs_shared, inputs_posi, inputs_nega,
                **models, timestep=timestep, progress_id=progress_id
            )
            inputs_shared["latents"] = pipe.step(pipe.scheduler, progress_id=progress_id, noise_pred=noise_pred.detach(), **inputs_shared)

        image_pred = pipe.vae_decoder(inputs_shared["latents"])
        image_real = pipe.vae_decoder(trajectory_teacher[-1])
        loss = self.loss_fn(image_pred.float(), image_real.float())
        return loss

    def forward(self, pipe: BasePipeline, inputs_shared, inputs_posi, inputs_nega):
        """Total imitation loss: trajectory alignment + LPIPS regularization.

        The teacher pipeline is taken from ``inputs_shared["teacher"]`` and
        rolled out for 50 steps at cfg_scale 2 (no grad); the student is
        trained on an 8-step grid at cfg_scale 1.
        """
        if not self.initialized:
            self.initialize(pipe.device)
        with torch.no_grad():
            # 8-step student grid; the teacher rollout is aligned to it.
            pipe.scheduler.set_timesteps(8)
            timesteps_teacher, trajectory_teacher = self.fetch_trajectory(inputs_shared["teacher"], pipe.scheduler.timesteps, inputs_shared, inputs_posi, inputs_nega, 50, 2)
            timesteps_teacher = timesteps_teacher.to(dtype=pipe.torch_dtype, device=pipe.device)
        loss_1 = self.align_trajectory(pipe, timesteps_teacher, trajectory_teacher, inputs_shared, inputs_posi, inputs_nega, 8, 1)
        loss_2 = self.compute_regularization(pipe, trajectory_teacher, inputs_shared, inputs_posi, inputs_nega, 8, 1)
        loss = loss_1 + loss_2
        return loss
|