initial version

This commit is contained in:
Artiprocher
2023-12-08 01:03:30 +08:00
parent 5073cf6938
commit b459784171
30 changed files with 193803 additions and 1 deletions

View File

@@ -0,0 +1,75 @@
from ..models import ModelManager
from ..prompts import SDPrompter
from ..schedulers import EnhancedDDIMScheduler
import torch
from tqdm import tqdm
from PIL import Image
import numpy as np
class SDPipeline(torch.nn.Module):
def __init__(self):
super().__init__()
self.scheduler = EnhancedDDIMScheduler()
def preprocess_image(self, image):
image = torch.Tensor(np.array(image, dtype=np.float32) * (2 / 255) - 1).permute(2, 0, 1).unsqueeze(0)
return image
@torch.no_grad()
def __call__(
self,
model_manager: ModelManager,
prompter: SDPrompter,
prompt,
negative_prompt="",
cfg_scale=7.5,
clip_skip=1,
init_image=None,
denoising_strength=1.0,
height=512,
width=512,
num_inference_steps=20,
tiled=False,
tile_size=64,
tile_stride=32,
progress_bar_cmd=tqdm,
progress_bar_st=None,
):
# Encode prompts
prompt_emb = prompter.encode_prompt(model_manager.text_encoder, prompt, clip_skip=clip_skip, device=model_manager.device)
negative_prompt_emb = prompter.encode_prompt(model_manager.text_encoder, negative_prompt, clip_skip=clip_skip, device=model_manager.device)
# Prepare scheduler
self.scheduler.set_timesteps(num_inference_steps, denoising_strength)
# Prepare latent tensors
if init_image is not None:
image = self.preprocess_image(init_image).to(device=model_manager.device, dtype=model_manager.torch_type)
latents = model_manager.vae_encoder(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
noise = torch.randn((1, 4, height//8, width//8), device=model_manager.device, dtype=model_manager.torch_type)
latents = self.scheduler.add_noise(latents, noise, timestep=self.scheduler.timesteps[0])
else:
latents = torch.randn((1, 4, height//8, width//8), device=model_manager.device, dtype=model_manager.torch_type)
# Denoise
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
timestep = torch.IntTensor((timestep,))[0].to(model_manager.device)
# Classifier-free guidance
noise_pred_cond = model_manager.unet(latents, timestep, prompt_emb, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
noise_pred_uncond = model_manager.unet(latents, timestep, negative_prompt_emb, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
noise_pred = noise_pred_uncond + cfg_scale * (noise_pred_cond - noise_pred_uncond)
latents = self.scheduler.step(noise_pred, timestep, latents)
if progress_bar_st is not None:
progress_bar_st.progress(progress_id / len(self.scheduler.timesteps))
# Decode image
image = model_manager.vae_decoder(latents, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)[0]
image = image.cpu().permute(1, 2, 0).numpy()
image = Image.fromarray(((image / 2 + 0.5).clip(0, 1) * 255).astype("uint8"))
return image