support kolors! (#106)

This commit is contained in:
Zhongjie Duan
2024-07-11 21:43:45 +08:00
committed by GitHub
parent 2a4709e572
commit 9c6607f78d
20 changed files with 2510 additions and 281 deletions

View File

@@ -216,7 +216,7 @@ class HunyuanDiTImagePipeline(torch.nn.Module):
# Prepare latent tensors
noise = torch.randn((1, 4, height//8, width//8), device=self.device, dtype=self.torch_dtype)
if input_image is not None:
image = self.preprocess_image(input_image).to(device=self.device, dtype=self.torch_dtype)
image = self.preprocess_image(input_image).to(device=self.device, dtype=torch.float32)
latents = self.vae_encoder(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(self.torch_dtype)
latents = self.scheduler.add_noise(latents, noise, timestep=self.scheduler.timesteps[0])
else:
@@ -293,6 +293,6 @@ class HunyuanDiTImagePipeline(torch.nn.Module):
progress_bar_st.progress(progress_id / len(self.scheduler.timesteps))
# Decode image
image = self.decode_image(latents, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
image = self.decode_image(latents.to(torch.float32), tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
return image