import torch

from .base_pipeline import BasePipeline


def _in_iteration_models(pipe: BasePipeline) -> dict:
    """Map each name in ``pipe.in_iteration_models`` to ``getattr(pipe, name)``.

    The result is splatted into ``pipe.model_fn`` as keyword arguments.
    """
    return {attr: getattr(pipe, attr) for attr in pipe.in_iteration_models}


def FlowMatchSFTLoss(pipe: BasePipeline, **inputs):
    """Compute a single-timestep supervised loss for one training step.

    A timestep index is drawn uniformly from ``[low, high)``. The bounds are
    ``inputs["min_timestep_boundary"]`` (default 0) and
    ``inputs["max_timestep_boundary"]`` (default 1), each multiplied by
    ``pipe.scheduler.num_train_timesteps`` and truncated with ``int``.

    ``inputs["input_latents"]`` is noised with fresh Gaussian noise through
    ``scheduler.add_noise``. The result is stored as ``inputs["latents"]``
    (only in this call's kwargs dict). ``pipe.model_fn`` is then run once.
    The prediction is compared, in float32, with ``scheduler.training_target``.

    Args:
        pipe: Pipeline exposing ``scheduler``, ``model_fn``,
            ``in_iteration_models``, ``torch_dtype`` and ``device``.
        **inputs: Keyword inputs forwarded to ``model_fn``. Must contain
            ``"input_latents"``. Any ``"latents"`` entry is overwritten.

    Returns:
        The mean-reduced MSE scaled by ``scheduler.training_weight(timestep)``.

    Note:
        ``torch.randint`` raises if the computed low bound is not strictly
        below the high bound (e.g. equal min/max boundary fractions).
        NOTE(review): this assumes ``scheduler.timesteps`` has at least
        ``high`` entries (presumably ``num_train_timesteps``); confirm.
    """
    scheduler = pipe.scheduler
    num_train = scheduler.num_train_timesteps

    # Boundaries are fractions of the training schedule, turned into indices.
    high = int(inputs.get("max_timestep_boundary", 1) * num_train)
    low = int(inputs.get("min_timestep_boundary", 0) * num_train)
    sampled_index = torch.randint(low, high, (1,))
    timestep = scheduler.timesteps[sampled_index].to(
        dtype=pipe.torch_dtype, device=pipe.device
    )

    clean = inputs["input_latents"]
    noise = torch.randn_like(clean)
    inputs["latents"] = scheduler.add_noise(clean, noise, timestep)
    target = scheduler.training_target(clean, noise, timestep)

    prediction = pipe.model_fn(
        **_in_iteration_models(pipe), **inputs, timestep=timestep
    )
    # Cast to float32 so the loss is not computed in a reduced-precision dtype.
    unweighted = torch.nn.functional.mse_loss(prediction.float(), target.float())
    return unweighted * scheduler.training_weight(timestep)


def DirectDistillLoss(pipe: BasePipeline, **inputs):
    """Run the full sampling loop and compare the final latents with the targets.

    Calls ``pipe.scheduler.set_timesteps(inputs["num_inference_steps"])``.
    This mutates the shared scheduler, and the change persists after this call.
    For each resulting timestep the function does the following:

    1. Call ``pipe.model_fn``.
    2. Feed the prediction to ``pipe.step``.
    3. Write the returned value back into ``inputs["latents"]``.

    Nothing is detached inside the loop.

    Args:
        pipe: Pipeline exposing ``scheduler``, ``model_fn``, ``step``,
            ``in_iteration_models``, ``torch_dtype`` and ``device``.
        **inputs: Keyword inputs forwarded to ``model_fn`` and ``step``.
            Must contain ``"num_inference_steps"`` and ``"input_latents"``.
            The initial ``"latents"`` value comes from the caller (presumably
            the starting noise; confirm with callers).

    Returns:
        Mean-reduced MSE, in float32, between the final ``inputs["latents"]``
        and ``inputs["input_latents"]``.
    """
    scheduler = pipe.scheduler
    scheduler.set_timesteps(inputs["num_inference_steps"])
    models = _in_iteration_models(pipe)

    for step_index, raw_timestep in enumerate(scheduler.timesteps):
        step_timestep = raw_timestep.unsqueeze(0).to(
            dtype=pipe.torch_dtype, device=pipe.device
        )
        prediction = pipe.model_fn(
            **models, **inputs, timestep=step_timestep, progress_id=step_index
        )
        # Store back so the next iteration's **inputs carries the updated latents.
        inputs["latents"] = pipe.step(
            scheduler, progress_id=step_index, noise_pred=prediction, **inputs
        )

    final_latents = inputs["latents"].float()
    reference = inputs["input_latents"].float()
    return torch.nn.functional.mse_loss(final_latents, reference)