mirror of https://github.com/modelscope/DiffSynth-Studio.git
z image distill
@@ -101,7 +101,7 @@ class FlowMatchScheduler():
         return sigmas, timesteps

     @staticmethod
-    def set_timesteps_z_image(num_inference_steps=100, denoising_strength=1.0, shift=None):
+    def set_timesteps_z_image(num_inference_steps=100, denoising_strength=1.0, shift=None, target_timesteps=None):
         sigma_min = 0.0
         sigma_max = 1.0
         shift = 3 if shift is None else shift
@@ -110,6 +110,11 @@ class FlowMatchScheduler():
         sigmas = torch.linspace(sigma_start, sigma_min, num_inference_steps + 1)[:-1]
         sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)
         timesteps = sigmas * num_train_timesteps
+        if target_timesteps is not None:
+            target_timesteps = target_timesteps.to(dtype=timesteps.dtype, device=timesteps.device)
+            for timestep in target_timesteps:
+                timestep_id = torch.argmin((timesteps - timestep).abs())
+                timesteps[timestep_id] = timestep
         return sigmas, timesteps

     def set_training_weight(self):
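The new target_timesteps argument lets a caller force the freshly built schedule to pass exactly through a given set of timesteps: each target replaces the nearest entry of the dense schedule. A minimal standalone sketch of that snapping step (plain torch, hypothetical values, outside the scheduler):

import torch

timesteps = torch.linspace(1000.0, 0.0, 51)[:-1]               # e.g. a dense 50-step teacher schedule
target_timesteps = torch.tensor([875.0, 625.0, 375.0, 125.0])  # e.g. a 4-step student schedule
for timestep in target_timesteps:
    # replace the closest schedule entry with the target timestep, as in the hunk above
    timestep_id = torch.argmin((timesteps - timestep).abs())
    timesteps[timestep_id] = timestep
# the dense schedule now visits every student timestep exactly, so teacher and
# student trajectories can be compared at matching points

This is what TrajectoryImitationLoss.fetch_trajectory below relies on when it passes the student timesteps as target_timesteps while building the teacher schedule.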
@@ -118,6 +123,10 @@ class FlowMatchScheduler():
         y = torch.exp(-2 * ((x - steps / 2) / steps) ** 2)
         y_shifted = y - y.min()
         bsmntw_weighing = y_shifted * (steps / y_shifted.sum())
+        if len(self.timesteps) != 1000:
+            # This is an empirical formula.
+            bsmntw_weighing = bsmntw_weighing * (len(self.timesteps) / steps)
+            bsmntw_weighing = bsmntw_weighing + bsmntw_weighing[1]
         self.linear_timesteps_weights = bsmntw_weighing

     def set_timesteps(self, num_inference_steps=100, denoising_strength=1.0, training=False, **kwargs):
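The new branch rescales the bell-shaped per-timestep training weights when the schedule is shorter than the usual 1000 steps. A hedged sketch of the arithmetic, assuming steps is the training-timestep count (1000) and x is the current timestep grid (neither is shown in this hunk):

import torch

steps = 1000
timesteps = torch.linspace(1000.0, 0.0, 9)[:-1]        # e.g. an 8-step distilled schedule
x = timesteps
y = torch.exp(-2 * ((x - steps / 2) / steps) ** 2)
y_shifted = y - y.min()
bsmntw_weighing = y_shifted * (steps / y_shifted.sum())
if len(timesteps) != 1000:
    # empirical formula from the diff: shrink the weights so they sum to roughly
    # len(timesteps) instead of `steps`, then lift everything by the second weight
    # so the entry where y_shifted hit zero still gets a non-zero weight
    bsmntw_weighing = bsmntw_weighing * (len(timesteps) / steps)
    bsmntw_weighing = bsmntw_weighing + bsmntw_weighing[1]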
@@ -31,3 +31,89 @@ def DirectDistillLoss(pipe: BasePipeline, **inputs):
         inputs["latents"] = pipe.step(pipe.scheduler, progress_id=progress_id, noise_pred=noise_pred, **inputs)
     loss = torch.nn.functional.mse_loss(inputs["latents"].float(), inputs["input_latents"].float())
     return loss
+
+
+class TrajectoryImitationLoss(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.initialized = False
+
+    def initialize(self, device):
+        import lpips # TODO: remove it
+        self.loss_fn = lpips.LPIPS(net='alex').to(device)
+        self.initialized = True
+
+    def fetch_trajectory(self, pipe: BasePipeline, timesteps_student, inputs_shared, inputs_posi, inputs_nega, num_inference_steps, cfg_scale):
+        trajectory = [inputs_shared["latents"].clone()]
+
+        pipe.scheduler.set_timesteps(num_inference_steps, target_timesteps=timesteps_student)
+        models = {name: getattr(pipe, name) for name in pipe.in_iteration_models}
+        for progress_id, timestep in enumerate(pipe.scheduler.timesteps):
+            timestep = timestep.unsqueeze(0).to(dtype=pipe.torch_dtype, device=pipe.device)
+            noise_pred = pipe.cfg_guided_model_fn(
+                pipe.model_fn, cfg_scale,
+                inputs_shared, inputs_posi, inputs_nega,
+                **models, timestep=timestep, progress_id=progress_id
+            )
+            inputs_shared["latents"] = pipe.step(pipe.scheduler, progress_id=progress_id, noise_pred=noise_pred.detach(), **inputs_shared)
+
+            trajectory.append(inputs_shared["latents"].clone())
+        return pipe.scheduler.timesteps, trajectory
+
+    def align_trajectory(self, pipe: BasePipeline, timesteps_teacher, trajectory_teacher, inputs_shared, inputs_posi, inputs_nega, num_inference_steps, cfg_scale):
+        loss = 0
+        pipe.scheduler.set_timesteps(num_inference_steps, training=True)
+        models = {name: getattr(pipe, name) for name in pipe.in_iteration_models}
+        for progress_id, timestep in enumerate(pipe.scheduler.timesteps):
+            timestep = timestep.unsqueeze(0).to(dtype=pipe.torch_dtype, device=pipe.device)
+
+            progress_id_teacher = torch.argmin((timesteps_teacher - timestep).abs())
+            inputs_shared["latents"] = trajectory_teacher[progress_id_teacher]
+
+            noise_pred = pipe.cfg_guided_model_fn(
+                pipe.model_fn, cfg_scale,
+                inputs_shared, inputs_posi, inputs_nega,
+                **models, timestep=timestep, progress_id=progress_id
+            )
+
+            sigma = pipe.scheduler.sigmas[progress_id]
+            sigma_ = 0 if progress_id + 1 >= len(pipe.scheduler.timesteps) else pipe.scheduler.sigmas[progress_id + 1]
+            if progress_id + 1 >= len(pipe.scheduler.timesteps):
+                latents_ = trajectory_teacher[-1]
+            else:
+                progress_id_teacher = torch.argmin((timesteps_teacher - pipe.scheduler.timesteps[progress_id + 1]).abs())
+                latents_ = trajectory_teacher[progress_id_teacher]
+
+            target = (latents_ - inputs_shared["latents"]) / (sigma_ - sigma)
+            loss = loss + torch.nn.functional.mse_loss(noise_pred.float(), target.float()) * pipe.scheduler.training_weight(timestep)
+        return loss
+
+    def compute_regularization(self, pipe: BasePipeline, trajectory_teacher, inputs_shared, inputs_posi, inputs_nega, num_inference_steps, cfg_scale):
+        inputs_shared["latents"] = trajectory_teacher[0]
+        pipe.scheduler.set_timesteps(num_inference_steps)
+        models = {name: getattr(pipe, name) for name in pipe.in_iteration_models}
+        for progress_id, timestep in enumerate(pipe.scheduler.timesteps):
+            timestep = timestep.unsqueeze(0).to(dtype=pipe.torch_dtype, device=pipe.device)
+            noise_pred = pipe.cfg_guided_model_fn(
+                pipe.model_fn, cfg_scale,
+                inputs_shared, inputs_posi, inputs_nega,
+                **models, timestep=timestep, progress_id=progress_id
+            )
+            inputs_shared["latents"] = pipe.step(pipe.scheduler, progress_id=progress_id, noise_pred=noise_pred.detach(), **inputs_shared)
+
+        image_pred = pipe.vae_decoder(inputs_shared["latents"])
+        image_real = pipe.vae_decoder(trajectory_teacher[-1])
+        loss = self.loss_fn(image_pred.float(), image_real.float())
+        return loss
+
+    def forward(self, pipe: BasePipeline, inputs_shared, inputs_posi, inputs_nega):
+        if not self.initialized:
+            self.initialize(pipe.device)
+        with torch.no_grad():
+            pipe.scheduler.set_timesteps(8)
+            timesteps_teacher, trajectory_teacher = self.fetch_trajectory(inputs_shared["teacher"], pipe.scheduler.timesteps, inputs_shared, inputs_posi, inputs_nega, 50, 2)
+            timesteps_teacher = timesteps_teacher.to(dtype=pipe.torch_dtype, device=pipe.device)
+        loss_1 = self.align_trajectory(pipe, timesteps_teacher, trajectory_teacher, inputs_shared, inputs_posi, inputs_nega, 8, 1)
+        loss_2 = self.compute_regularization(pipe, trajectory_teacher, inputs_shared, inputs_posi, inputs_nega, 8, 1)
+        loss = loss_1 + loss_2
+        return loss
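Taken together, forward distills an 8-step student against a 50-step, CFG-guided teacher in three stages: fetch_trajectory rolls the teacher out with its schedule snapped onto the student timesteps (via target_timesteps) and records the latents after every step; align_trajectory restarts the student from each recorded teacher latent and regresses its velocity prediction onto the finite-difference velocity of the teacher trajectory; compute_regularization re-runs the student from the teacher's starting latent and penalizes the LPIPS distance between the decoded student image and the teacher's final image (hence the lazy import lpips, still flagged with a TODO). A toy, self-contained sketch of the align_trajectory regression target, assuming the teacher trajectory was recorded at the student sigmas (made-up shapes and values, not the pipeline API):

import torch

sigmas = torch.tensor([1.0, 0.66, 0.33])                          # a hypothetical 3-step student schedule
trajectory_teacher = [torch.randn(1, 4, 8, 8) for _ in range(4)]  # teacher latents at each sigma plus the endpoint

progress_id = 1
latents = trajectory_teacher[progress_id]      # restart the student from the teacher's latent
sigma = sigmas[progress_id]
sigma_ = sigmas[progress_id + 1] if progress_id + 1 < len(sigmas) else 0
latents_ = trajectory_teacher[progress_id + 1]

# flow-matching Euler step: x' = x + (sigma' - sigma) * v, so the target velocity
# is the finite difference between consecutive teacher latents
target = (latents_ - latents) / (sigma_ - sigma)
# the training loss is mse(noise_pred, target), weighted by the scheduler's
# per-timestep training weight and summed over the student steps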