From 3b010043def28bfaeae5e3ef298ccaa0d89924f6 Mon Sep 17 00:00:00 2001 From: Zhongjie Duan <35051019+Artiprocher@users.noreply.github.com> Date: Thu, 2 Jan 2025 14:23:02 +0800 Subject: [PATCH] Update text_to_image.py --- diffsynth/trainers/text_to_image.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/diffsynth/trainers/text_to_image.py b/diffsynth/trainers/text_to_image.py index 5f7c723..e2cab59 100644 --- a/diffsynth/trainers/text_to_image.py +++ b/diffsynth/trainers/text_to_image.py @@ -71,7 +71,10 @@ class LightningModelForT2ILoRA(pl.LightningModule): # Prepare input parameters self.pipe.device = self.device prompt_emb = self.pipe.encode_prompt(text, positive=True) - latents = self.pipe.vae_encoder(image.to(dtype=self.pipe.torch_dtype, device=self.device)) + if "latents" in batch: + latents = batch["latents"].to(dtype=self.pipe.torch_dtype, device=self.device) + else: + latents = self.pipe.vae_encoder(image.to(dtype=self.pipe.torch_dtype, device=self.device)) noise = torch.randn_like(latents) timestep_id = torch.randint(0, self.pipe.scheduler.num_train_timesteps, (1,)) timestep = self.pipe.scheduler.timesteps[timestep_id].to(self.device)