Update text_to_image.py

This commit is contained in:
Zhongjie Duan
2025-01-02 14:23:02 +08:00
committed by GitHub
parent 088ea29e6e
commit 3b010043de

View File

@@ -71,7 +71,10 @@ class LightningModelForT2ILoRA(pl.LightningModule):
# Prepare input parameters
self.pipe.device = self.device
prompt_emb = self.pipe.encode_prompt(text, positive=True)
latents = self.pipe.vae_encoder(image.to(dtype=self.pipe.torch_dtype, device=self.device))
if "latents" in batch:
latents = batch["latents"].to(dtype=self.pipe.torch_dtype, device=self.device)
else:
latents = self.pipe.vae_encoder(image.to(dtype=self.pipe.torch_dtype, device=self.device))
noise = torch.randn_like(latents)
timestep_id = torch.randint(0, self.pipe.scheduler.num_train_timesteps, (1,))
timestep = self.pipe.scheduler.timesteps[timestep_id].to(self.device)