diff --git a/diffsynth/trainers/text_to_image.py b/diffsynth/trainers/text_to_image.py index 5f7c723..e2cab59 100644 --- a/diffsynth/trainers/text_to_image.py +++ b/diffsynth/trainers/text_to_image.py @@ -71,7 +71,10 @@ class LightningModelForT2ILoRA(pl.LightningModule): # Prepare input parameters self.pipe.device = self.device prompt_emb = self.pipe.encode_prompt(text, positive=True) - latents = self.pipe.vae_encoder(image.to(dtype=self.pipe.torch_dtype, device=self.device)) + if "latents" in batch: + latents = batch["latents"].to(dtype=self.pipe.torch_dtype, device=self.device) + else: + latents = self.pipe.vae_encoder(image.to(dtype=self.pipe.torch_dtype, device=self.device)) noise = torch.randn_like(latents) timestep_id = torch.randint(0, self.pipe.scheduler.num_train_timesteps, (1,)) timestep = self.pipe.scheduler.timesteps[timestep_id].to(self.device)