update noise generate

This commit is contained in:
Qing112
2024-10-21 15:09:21 +08:00
parent 72ed76e89e
commit 747572e62c
8 changed files with 23 additions and 14 deletions

View File

@@ -87,6 +87,7 @@ class SD3ImagePipeline(BasePipeline):
tiled=False,
tile_size=128,
tile_stride=64,
seed=None,
progress_bar_cmd=tqdm,
progress_bar_st=None,
):
@@ -101,10 +102,10 @@ class SD3ImagePipeline(BasePipeline):
self.load_models_to_device(['vae_encoder'])
image = self.preprocess_image(input_image).to(device=self.device, dtype=self.torch_dtype)
latents = self.encode_image(image, **tiler_kwargs)
noise = torch.randn((1, 16, height//8, width//8), device=self.device, dtype=self.torch_dtype)
noise = self.generate_noise((1, 16, height//8, width//8), seed=seed, device=self.device, dtype=self.torch_dtype)
latents = self.scheduler.add_noise(latents, noise, timestep=self.scheduler.timesteps[0])
else:
latents = torch.randn((1, 16, height//8, width//8), device=self.device, dtype=self.torch_dtype)
latents = self.generate_noise((1, 16, height//8, width//8), seed=seed, device=self.device, dtype=self.torch_dtype)
# Encode prompts
self.load_models_to_device(['text_encoder_1', 'text_encoder_2', 'text_encoder_3'])