mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-18 22:08:13 +00:00
support direct distill
This commit is contained in:
@@ -150,6 +150,17 @@ class QwenImagePipeline(BasePipeline):
|
||||
return loss
|
||||
|
||||
|
||||
def direct_distill_loss(self, **inputs):
|
||||
self.scheduler.set_timesteps(inputs["num_inference_steps"])
|
||||
models = {name: getattr(self, name) for name in self.in_iteration_models}
|
||||
for progress_id, timestep in enumerate(self.scheduler.timesteps):
|
||||
timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device)
|
||||
noise_pred = self.model_fn(**models, **inputs, timestep=timestep, progress_id=progress_id)
|
||||
inputs["latents"] = self.step(self.scheduler, progress_id=progress_id, noise_pred=noise_pred, **inputs)
|
||||
loss = torch.nn.functional.mse_loss(inputs["latents"].float(), inputs["input_latents"].float())
|
||||
return loss
|
||||
|
||||
|
||||
def _enable_fp8_lora_training(self, dtype):
|
||||
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLRotaryEmbedding, Qwen2RMSNorm, Qwen2_5_VisionPatchEmbed, Qwen2_5_VisionRotaryEmbedding
|
||||
from ..models.qwen_image_dit import RMSNorm
|
||||
|
||||
Reference in New Issue
Block a user