compatibility update

This commit is contained in:
Artiprocher
2023-12-23 20:13:41 +08:00
parent b30d0fa412
commit 66b3e995c2
27 changed files with 1051 additions and 398 deletions

View File

@@ -36,10 +36,17 @@ class MultiControlNetManager:
], dim=0)
return processed_image
def __call__(self, sample, timestep, encoder_hidden_states, conditionings):
def __call__(
self,
sample, timestep, encoder_hidden_states, conditionings,
tiled=False, tile_size=64, tile_stride=32
):
res_stack = None
for conditioning, model, scale in zip(conditionings, self.models, self.scales):
res_stack_ = model(sample, timestep, encoder_hidden_states, conditioning)
res_stack_ = model(
sample, timestep, encoder_hidden_states, conditioning,
tiled=tiled, tile_size=tile_size, tile_stride=tile_stride
)
res_stack_ = [res * scale for res in res_stack_]
if res_stack is None:
res_stack = res_stack_

View File

@@ -12,7 +12,7 @@ Processor_id: TypeAlias = Literal[
]
class Annotator:
def __init__(self, processor_id: Processor_id, model_path="models/Annotators", detect_resolution=512):
def __init__(self, processor_id: Processor_id, model_path="models/Annotators", detect_resolution=None):
if processor_id == "canny":
self.processor = CannyDetector()
elif processor_id == "depth":
@@ -44,7 +44,8 @@ class Annotator:
else:
kwargs = {}
if self.processor is not None:
image = self.processor(image, detect_resolution=self.detect_resolution, image_resolution=min(width, height), **kwargs)
detect_resolution = self.detect_resolution if self.detect_resolution is not None else min(width, height)
image = self.processor(image, detect_resolution=detect_resolution, image_resolution=min(width, height), **kwargs)
image = image.resize((width, height))
return image

View File

@@ -15,6 +15,8 @@ from .sd_controlnet import SDControlNet
from .sd_motion import SDMotionModel
from transformers import AutoModelForCausalLM
class ModelManager:
def __init__(self, torch_dtype=torch.float16, device="cuda"):
@@ -24,12 +26,19 @@ class ModelManager:
self.model_path = {}
self.textual_inversion_dict = {}
def is_beautiful_prompt(self, state_dict):
    """Return True if state_dict looks like a 'beautiful prompt' checkpoint.

    Heuristic: the checkpoint is identified purely by the presence of one
    parameter name that is characteristic of that model's architecture.
    """
    param_name = "transformer.h.9.self_attention.query_key_value.weight"
    return param_name in state_dict
def is_stabe_diffusion_xl(self, state_dict):
    """Return True if state_dict looks like a Stable Diffusion XL checkpoint.

    Heuristic: probes for a conditioner/embedder key that is present in SDXL
    checkpoints.

    NOTE(review): the method name contains a typo ("stabe" instead of
    "stable"); it is kept as-is because callers use this spelling.
    """
    param_name = "conditioner.embedders.0.transformer.text_model.embeddings.position_embedding.weight"
    return param_name in state_dict
def is_stable_diffusion(self, state_dict):
return True
if self.is_stabe_diffusion_xl(state_dict):
return False
param_name = "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.norm3.weight"
return param_name in state_dict
def is_controlnet(self, state_dict):
param_name = "control_model.time_embed.0.weight"
@@ -74,7 +83,6 @@ class ModelManager:
"unet": SDXLUNet,
"vae_decoder": SDXLVAEDecoder,
"vae_encoder": SDXLVAEEncoder,
"refiner": SDXLUNet,
}
if components is None:
components = ["text_encoder", "text_encoder_2", "unet", "vae_decoder", "vae_encoder"]
@@ -109,6 +117,15 @@ class ModelManager:
self.model[component] = model
self.model_path[component] = file_path
def load_beautiful_prompt(self, state_dict, file_path=""):
    """Load a 'beautiful prompt' causal language model and register it.

    The model class/config is read from the folder that contains
    ``file_path``, while the weights are supplied directly via
    ``state_dict``. The resulting model is stored under the
    ``"beautiful_prompt"`` key of ``self.model`` (and its path under
    ``self.model_path``), moved to ``self.device`` and set to eval mode.
    """
    component = "beautiful_prompt"
    model_folder = os.path.dirname(file_path)
    # local_files_only=True prevents any network download; only the local
    # config in model_folder plus the provided state_dict are used.
    model = AutoModelForCausalLM.from_pretrained(
        model_folder, state_dict=state_dict, local_files_only=True, torch_dtype=self.torch_dtype
    ).to(self.device).eval()
    self.model[component] = model
    self.model_path[component] = file_path
def search_for_embeddings(self, state_dict):
embeddings = []
for k in state_dict:
@@ -144,6 +161,8 @@ class ModelManager:
self.load_stable_diffusion_xl(state_dict, components=components, file_path=file_path)
elif self.is_stable_diffusion(state_dict):
self.load_stable_diffusion(state_dict, components=components, file_path=file_path)
elif self.is_beautiful_prompt(state_dict):
self.load_beautiful_prompt(state_dict, file_path=file_path)
def load_models(self, file_path_list):
for file_path in file_path_list:

View File

@@ -41,7 +41,7 @@ class Attention(torch.nn.Module):
v = v.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
hidden_states = hidden_states.transpose(1, 2).view(batch_size, -1, self.num_heads * self.head_dim)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim)
hidden_states = hidden_states.to(q.dtype)
hidden_states = self.to_out(hidden_states)

View File

@@ -1,5 +1,6 @@
import torch
from .sd_unet import Timesteps, ResnetBlock, AttentionBlock, PushBlock, PopBlock, DownSampler, UpSampler
from .sd_unet import Timesteps, ResnetBlock, AttentionBlock, PushBlock, DownSampler
from .tiler import TileWorker
class ControlNetConditioningLayer(torch.nn.Module):
@@ -92,20 +93,37 @@ class SDControlNet(torch.nn.Module):
self.global_pool = global_pool
def forward(self, sample, timestep, encoder_hidden_states, conditioning):
def forward(
self,
sample, timestep, encoder_hidden_states, conditioning,
tiled=False, tile_size=64, tile_stride=32,
):
# 1. time
time_emb = self.time_proj(timestep[None]).to(sample.dtype)
time_emb = self.time_embedding(time_emb)
time_emb = time_emb.repeat(sample.shape[0], 1)
# 2. pre-process
height, width = sample.shape[2], sample.shape[3]
hidden_states = self.conv_in(sample) + self.controlnet_conv_in(conditioning)
text_emb = encoder_hidden_states
res_stack = [hidden_states]
# 3. blocks
for i, block in enumerate(self.blocks):
hidden_states, time_emb, text_emb, res_stack = block(hidden_states, time_emb, text_emb, res_stack)
if tiled and not isinstance(block, PushBlock):
_, _, inter_height, _ = hidden_states.shape
resize_scale = inter_height / height
hidden_states = TileWorker().tiled_forward(
lambda x: block(x, time_emb, text_emb, res_stack)[0],
hidden_states,
int(tile_size * resize_scale),
int(tile_stride * resize_scale),
tile_device=hidden_states.device,
tile_dtype=hidden_states.dtype
)
else:
hidden_states, _, _, _ = block(hidden_states, time_emb, text_emb, res_stack)
# 4. ControlNet blocks
controlnet_res_stack = [block(res) for block, res in zip(self.controlnet_blocks, res_stack)]

View File

@@ -279,31 +279,19 @@ class SDUNet(torch.nn.Module):
self.conv_act = torch.nn.SiLU()
self.conv_out = torch.nn.Conv2d(320, 4, kernel_size=3, padding=1)
def forward(self, sample, timestep, encoder_hidden_states, tiled=False, tile_size=64, tile_stride=8, additional_res_stack=None, **kwargs):
def forward(self, sample, timestep, encoder_hidden_states, **kwargs):
# 1. time
time_emb = self.time_proj(timestep[None]).to(sample.dtype)
time_emb = self.time_embedding(time_emb)
time_emb = time_emb.repeat(sample.shape[0], 1)
# 2. pre-process
height, width = sample.shape[2], sample.shape[3]
hidden_states = self.conv_in(sample)
text_emb = encoder_hidden_states
res_stack = [hidden_states]
# 3. blocks
for i, block in enumerate(self.blocks):
if additional_res_stack is not None and i==31:
hidden_states += additional_res_stack.pop()
res_stack = [res + additional_res for res, additional_res in zip(res_stack, additional_res_stack)]
additional_res_stack = None
if tiled:
hidden_states, time_emb, text_emb, res_stack = self.tiled_inference(
block, hidden_states, time_emb, text_emb, res_stack,
height, width, tile_size, tile_stride
)
else:
hidden_states, time_emb, text_emb, res_stack = block(hidden_states, time_emb, text_emb, res_stack)
hidden_states, time_emb, text_emb, res_stack = block(hidden_states, time_emb, text_emb, res_stack)
# 4. output
hidden_states = self.conv_norm_out(hidden_states)
@@ -312,23 +300,6 @@ class SDUNet(torch.nn.Module):
return hidden_states
def tiled_inference(self, block, hidden_states, time_emb, text_emb, res_stack, height, width, tile_size, tile_stride):
if block.__class__.__name__ in ["ResnetBlock", "AttentionBlock", "DownSampler", "UpSampler"]:
batch_size, inter_channel, inter_height, inter_width = hidden_states.shape
resize_scale = inter_height / height
hidden_states = Tiler()(
lambda x: block(x, time_emb, text_emb, res_stack)[0],
hidden_states,
int(tile_size * resize_scale),
int(tile_stride * resize_scale),
inter_device=hidden_states.device,
inter_dtype=hidden_states.dtype
)
else:
hidden_states, time_emb, text_emb, res_stack = block(hidden_states, time_emb, text_emb, res_stack)
return hidden_states, time_emb, text_emb, res_stack
def state_dict_converter(self):
return SDUNetStateDictConverter()

View File

@@ -1,7 +1,7 @@
import torch
from .attention import Attention
from .sd_unet import ResnetBlock, UpSampler
from .tiler import Tiler
from .tiler import TileWorker
class VAEAttentionBlock(torch.nn.Module):
@@ -79,11 +79,13 @@ class SDVAEDecoder(torch.nn.Module):
self.conv_out = torch.nn.Conv2d(128, 3, kernel_size=3, padding=1)
def tiled_forward(self, sample, tile_size=64, tile_stride=32):
hidden_states = Tiler()(
hidden_states = TileWorker().tiled_forward(
lambda x: self.forward(x),
sample,
tile_size,
tile_stride
tile_stride,
tile_device=sample.device,
tile_dtype=sample.dtype
)
return hidden_states

View File

@@ -1,7 +1,7 @@
import torch
from .sd_unet import ResnetBlock, DownSampler
from .sd_vae_decoder import VAEAttentionBlock
from .tiler import Tiler
from .tiler import TileWorker
class SDVAEEncoder(torch.nn.Module):
@@ -38,11 +38,13 @@ class SDVAEEncoder(torch.nn.Module):
self.conv_out = torch.nn.Conv2d(512, 8, kernel_size=3, padding=1)
def tiled_forward(self, sample, tile_size=64, tile_stride=32):
hidden_states = Tiler()(
hidden_states = TileWorker().tiled_forward(
lambda x: self.forward(x),
sample,
tile_size,
tile_stride
tile_stride,
tile_device=sample.device,
tile_dtype=sample.dtype
)
return hidden_states

View File

@@ -1,4 +1,5 @@
import torch
from einops import rearrange, repeat
class Tiler(torch.nn.Module):
@@ -70,6 +71,106 @@ class Tiler(torch.nn.Module):
return x
class TileWorker:
    """Tiled execution helper for large spatial tensors.

    Splits a (b, c, h, w) tensor into overlapping square tiles, runs a
    forward function on each tile (optionally keeping the tile buffer on a
    different device/dtype to save VRAM), then folds the tile outputs back
    together, blending overlaps with a soft border mask to hide seams.
    """

    def __init__(self):
        pass

    def mask(self, height, width, border_width):
        # Create a mask with shape (height, width).
        # The centre area is filled with 1, and the border line is filled with values in range (0, 1].
        x = torch.arange(height).repeat(width, 1).T
        y = torch.arange(width).repeat(height, 1)
        # For each pixel: 1-based distance to the nearest tile edge, then
        # normalised by border_width and clipped so the interior saturates at 1.
        mask = torch.stack([x + 1, height - x, y + 1, width - y]).min(dim=0).values
        mask = (mask / border_width).clip(0, 1)
        return mask

    def tile(self, model_input, tile_size, tile_stride, tile_device, tile_dtype):
        # Convert a tensor (b, c, h, w) to (b, c, tile_size, tile_size, tile_num)
        batch_size, channel, _, _ = model_input.shape
        # Move to the tile buffer device/dtype (possibly CPU) before unfolding,
        # so the full set of tiles does not have to live on the inference device.
        model_input = model_input.to(device=tile_device, dtype=tile_dtype)
        unfold_operator = torch.nn.Unfold(
            kernel_size=(tile_size, tile_size),
            stride=(tile_stride, tile_stride)
        )
        model_input = unfold_operator(model_input)
        model_input = model_input.view((batch_size, channel, tile_size, tile_size, -1))
        return model_input

    def tiled_inference(self, forward_fn, model_input, tile_batch_size, inference_device, inference_dtype, tile_device, tile_dtype):
        # Call y=forward_fn(x) for each tile
        tile_num = model_input.shape[-1]
        model_output_stack = []
        for tile_id in range(0, tile_num, tile_batch_size):
            # process input: take up to tile_batch_size tiles and move them to
            # the inference device only for the duration of the forward pass
            tile_id_ = min(tile_id + tile_batch_size, tile_num)
            x = model_input[:, :, :, :, tile_id: tile_id_]
            x = x.to(device=inference_device, dtype=inference_dtype)
            # Tiles are folded into the batch dimension for a single call.
            x = rearrange(x, "b c h w n -> (n b) c h w")
            # process output: restore the tile axis and park the result back
            # on the tile buffer device/dtype
            y = forward_fn(x)
            y = rearrange(y, "(n b) c h w -> b c h w n", n=tile_id_-tile_id)
            y = y.to(device=tile_device, dtype=tile_dtype)
            model_output_stack.append(y)
        model_output = torch.concat(model_output_stack, dim=-1)
        return model_output

    def io_scale(self, model_output, tile_size):
        # Determine the size modification happened in forward_fn
        # We only consider the same scale on height and width.
        io_scale = model_output.shape[2] / tile_size
        return io_scale

    def untile(self, model_output, height, width, tile_size, tile_stride, border_width, tile_device, tile_dtype):
        # The reversed function of tile
        mask = self.mask(tile_size, tile_size, border_width)
        mask = mask.to(device=tile_device, dtype=tile_dtype)
        mask = rearrange(mask, "h w -> 1 1 h w 1")
        model_output = model_output * mask
        fold_operator = torch.nn.Fold(
            output_size=(height, width),
            kernel_size=(tile_size, tile_size),
            stride=(tile_stride, tile_stride)
        )
        # Fold sums overlapping tiles; dividing by the folded mask renormalises
        # every output pixel by the total blend weight it received, so overlap
        # regions are a weighted average rather than a sum.
        mask = repeat(mask[0, 0, :, :, 0], "h w -> 1 (h w) n", n=model_output.shape[-1])
        model_output = rearrange(model_output, "b c h w n -> b (c h w) n")
        model_output = fold_operator(model_output) / fold_operator(mask)
        return model_output

    def tiled_forward(self, forward_fn, model_input, tile_size, tile_stride, tile_batch_size=1, tile_device="cpu", tile_dtype=torch.float32, border_width=None):
        """Run forward_fn over model_input tile-by-tile and return the blended
        result on model_input's original device and dtype.

        border_width defaults to half the tile stride, i.e. half of the
        overlap between neighbouring tiles.
        """
        # Prepare
        inference_device, inference_dtype = model_input.device, model_input.dtype
        height, width = model_input.shape[2], model_input.shape[3]
        border_width = int(tile_stride*0.5) if border_width is None else border_width

        # tile
        model_input = self.tile(model_input, tile_size, tile_stride, tile_device, tile_dtype)

        # inference
        model_output = self.tiled_inference(forward_fn, model_input, tile_batch_size, inference_device, inference_dtype, tile_device, tile_dtype)

        # resize: forward_fn may change spatial resolution (e.g. up/down
        # sampling blocks), so rescale all geometry before untiling.
        io_scale = self.io_scale(model_output, tile_size)
        height, width = int(height*io_scale), int(width*io_scale)
        tile_size, tile_stride = int(tile_size*io_scale), int(tile_stride*io_scale)
        border_width = int(border_width*io_scale)

        # untile
        model_output = self.untile(model_output, height, width, tile_size, tile_stride, border_width, tile_device, tile_dtype)

        # Done!
        model_output = model_output.to(device=inference_device, dtype=inference_dtype)
        return model_output

View File

@@ -1,3 +1,3 @@
from .stable_diffusion import SDPipeline
from .stable_diffusion_xl import SDXLPipeline
from .stable_diffusion import SDImagePipeline
from .stable_diffusion_xl import SDXLImagePipeline
from .stable_diffusion_video import SDVideoPipeline

View File

@@ -0,0 +1,113 @@
import torch
from ..models import SDUNet, SDMotionModel
from ..models.sd_unet import PushBlock, PopBlock
from ..models.tiler import TileWorker
from ..controlnets import MultiControlNetManager
def lets_dance(
    unet: SDUNet,
    motion_modules: SDMotionModel = None,
    controlnet: MultiControlNetManager = None,
    sample = None,
    timestep = None,
    encoder_hidden_states = None,
    controlnet_frames = None,
    unet_batch_size = 1,
    controlnet_batch_size = 1,
    tiled=False,
    tile_size=64,
    tile_stride=32,
    device = "cuda",
    vram_limit_level = 0,
):
    """Run one UNet denoising step, optionally combined with ControlNet
    residuals, AnimateDiff motion modules, tiled inference and CPU
    offloading of the residual stack.

    Args:
        unet: the Stable Diffusion UNet whose blocks are iterated manually.
        motion_modules: optional AnimateDiff modules, invoked after the UNet
            block whose id appears in motion_modules.call_block_id.
        controlnet: optional MultiControlNetManager producing residuals.
        sample: latent tensor of shape (frames/batch, c, h, w).
        timestep: scalar diffusion timestep tensor.
        encoder_hidden_states: per-frame text embeddings, indexed in step
            with sample along dim 0.
        controlnet_frames: conditioning images; sliced on dim 1 per batch.
        unet_batch_size / controlnet_batch_size: frame micro-batch sizes.
        tiled, tile_size, tile_stride: enable TileWorker spatial tiling.
        device: device residuals are moved back to when offloaded.
        vram_limit_level: >= 1 keeps the residual stack on CPU between uses.

    Returns:
        The predicted noise tensor from unet.conv_out.
    """
    # 1. ControlNet
    # This part will be repeated on overlapping frames if animatediff_batch_size > animatediff_stride.
    # I leave it here because I intend to do something interesting on the ControlNets.
    controlnet_insert_block_id = 30
    if controlnet is not None and controlnet_frames is not None:
        res_stacks = []
        # process controlnet frames with batch
        for batch_id in range(0, sample.shape[0], controlnet_batch_size):
            batch_id_ = min(batch_id + controlnet_batch_size, sample.shape[0])
            res_stack = controlnet(
                sample[batch_id: batch_id_],
                timestep,
                encoder_hidden_states[batch_id: batch_id_],
                controlnet_frames[:, batch_id: batch_id_],
                tiled=tiled, tile_size=tile_size, tile_stride=tile_stride
            )
            # Offload ControlNet residuals to CPU when VRAM is constrained.
            if vram_limit_level >= 1:
                res_stack = [res.cpu() for res in res_stack]
            res_stacks.append(res_stack)
        # concat the residual of each micro-batch back along the frame axis
        additional_res_stack = []
        for i in range(len(res_stacks[0])):
            res = torch.concat([res_stack[i] for res_stack in res_stacks], dim=0)
            additional_res_stack.append(res)
    else:
        additional_res_stack = None

    # 2. time
    time_emb = unet.time_proj(timestep[None]).to(sample.dtype)
    time_emb = unet.time_embedding(time_emb)

    # 3. pre-process
    height, width = sample.shape[2], sample.shape[3]
    hidden_states = unet.conv_in(sample)
    text_emb = encoder_hidden_states
    res_stack = [hidden_states.cpu() if vram_limit_level>=1 else hidden_states]

    # 4. blocks
    for block_id, block in enumerate(unet.blocks):
        # 4.1 UNet
        if isinstance(block, PushBlock):
            # PushBlock appends to res_stack; offload the pushed tensor to CPU
            # immediately when VRAM is constrained.
            hidden_states, time_emb, text_emb, res_stack = block(hidden_states, time_emb, text_emb, res_stack)
            if vram_limit_level>=1:
                res_stack[-1] = res_stack[-1].cpu()
        elif isinstance(block, PopBlock):
            # PopBlock consumes res_stack[-1]; bring it back to the compute
            # device first if it was offloaded.
            if vram_limit_level>=1:
                res_stack[-1] = res_stack[-1].to(device)
            hidden_states, time_emb, text_emb, res_stack = block(hidden_states, time_emb, text_emb, res_stack)
        else:
            # Regular blocks are run frame-micro-batch by micro-batch.
            hidden_states_input = hidden_states
            hidden_states_output = []
            for batch_id in range(0, sample.shape[0], unet_batch_size):
                batch_id_ = min(batch_id + unet_batch_size, sample.shape[0])
                if tiled:
                    # NOTE(review): resize_scale is read from hidden_states,
                    # which after the first micro-batch iteration is the
                    # previous batch's OUTPUT, not hidden_states_input —
                    # presumably the spatial size is the same either way, but
                    # confirm; reading hidden_states_input.shape looks safer.
                    _, _, inter_height, _ = hidden_states.shape
                    resize_scale = inter_height / height
                    hidden_states = TileWorker().tiled_forward(
                        lambda x: block(x, time_emb, text_emb[batch_id: batch_id_], res_stack)[0],
                        hidden_states_input[batch_id: batch_id_],
                        int(tile_size * resize_scale),
                        int(tile_stride * resize_scale),
                        tile_device=hidden_states.device,
                        tile_dtype=hidden_states.dtype
                    )
                else:
                    hidden_states, _, _, _ = block(hidden_states_input[batch_id: batch_id_], time_emb, text_emb[batch_id: batch_id_], res_stack)
                hidden_states_output.append(hidden_states)
            hidden_states = torch.concat(hidden_states_output, dim=0)
        # 4.2 AnimateDiff
        if motion_modules is not None:
            if block_id in motion_modules.call_block_id:
                motion_module_id = motion_modules.call_block_id[block_id]
                hidden_states, time_emb, text_emb, res_stack = motion_modules.motion_modules[motion_module_id](
                    hidden_states, time_emb, text_emb, res_stack,
                    batch_size=1
                )
        # 4.3 ControlNet
        # At the insertion block, the deepest ControlNet residual is added to
        # hidden_states and the remaining residuals are merged into res_stack
        # so PopBlocks consume (skip + controlnet) sums on the way up.
        if block_id == controlnet_insert_block_id and additional_res_stack is not None:
            hidden_states += additional_res_stack.pop().to(device)
            if vram_limit_level>=1:
                res_stack = [(res.to(device) + additional_res.to(device)).cpu() for res, additional_res in zip(res_stack, additional_res_stack)]
            else:
                res_stack = [res + additional_res for res, additional_res in zip(res_stack, additional_res_stack)]

    # 5. output
    hidden_states = unet.conv_norm_out(hidden_states)
    hidden_states = unet.conv_act(hidden_states)
    hidden_states = unet.conv_out(hidden_states)
    return hidden_states

View File

@@ -1,14 +1,16 @@
from ..models import ModelManager, SDTextEncoder, SDUNet, SDVAEDecoder, SDVAEEncoder
from ..controlnets.controlnet_unit import MultiControlNetManager
from ..controlnets import MultiControlNetManager, ControlNetUnit, ControlNetConfigUnit, Annotator
from ..prompts import SDPrompter
from ..schedulers import EnhancedDDIMScheduler
from .dancer import lets_dance
from typing import List
import torch
from tqdm import tqdm
from PIL import Image
import numpy as np
class SDPipeline(torch.nn.Module):
class SDImagePipeline(torch.nn.Module):
def __init__(self, device="cuda", torch_dtype=torch.float16):
super().__init__()
@@ -23,6 +25,7 @@ class SDPipeline(torch.nn.Module):
self.vae_encoder: SDVAEEncoder = None
self.controlnet: MultiControlNetManager = None
def fetch_main_models(self, model_manager: ModelManager):
self.text_encoder = model_manager.text_encoder
self.unet = model_manager.unet
@@ -31,13 +34,48 @@ class SDPipeline(torch.nn.Module):
# load textual inversion
self.prompter.load_textual_inversion(model_manager.textual_inversion_dict)
def fetch_controlnet_models(self, controlnet_units=[]):
def fetch_controlnet_models(self, model_manager: ModelManager, controlnet_config_units: List[ControlNetConfigUnit]=[]):
controlnet_units = []
for config in controlnet_config_units:
controlnet_unit = ControlNetUnit(
Annotator(config.processor_id),
model_manager.get_model_with_model_path(config.model_path),
config.scale
)
controlnet_units.append(controlnet_unit)
self.controlnet = MultiControlNetManager(controlnet_units)
def fetch_beautiful_prompt(self, model_manager: ModelManager):
if "beautiful_prompt" in model_manager.model:
self.prompter.load_beautiful_prompt(model_manager.model["beautiful_prompt"], model_manager.model_path["beautiful_prompt"])
@staticmethod
def from_model_manager(model_manager: ModelManager, controlnet_config_units: List[ControlNetConfigUnit]=[]):
pipe = SDImagePipeline(
device=model_manager.device,
torch_dtype=model_manager.torch_dtype,
)
pipe.fetch_main_models(model_manager)
pipe.fetch_beautiful_prompt(model_manager)
pipe.fetch_controlnet_models(model_manager, controlnet_config_units)
return pipe
def preprocess_image(self, image):
image = torch.Tensor(np.array(image, dtype=np.float32) * (2 / 255) - 1).permute(2, 0, 1).unsqueeze(0)
return image
def decode_image(self, latent, tiled=False, tile_size=64, tile_stride=32):
image = self.vae_decoder(latent.to(self.device), tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)[0]
image = image.cpu().permute(1, 2, 0).numpy()
image = Image.fromarray(((image / 2 + 0.5).clip(0, 1) * 255).astype("uint8"))
return image
@torch.no_grad()
def __call__(
self,
@@ -45,7 +83,7 @@ class SDPipeline(torch.nn.Module):
negative_prompt="",
cfg_scale=7.5,
clip_skip=1,
init_image=None,
input_image=None,
controlnet_image=None,
denoising_strength=1.0,
height=512,
@@ -57,48 +95,43 @@ class SDPipeline(torch.nn.Module):
progress_bar_cmd=tqdm,
progress_bar_st=None,
):
# Encode prompts
prompt_emb_posi = self.prompter.encode_prompt(self.text_encoder, prompt, clip_skip=clip_skip, device=self.device)
prompt_emb_nega = self.prompter.encode_prompt(self.text_encoder, negative_prompt, clip_skip=clip_skip, device=self.device)
# Prepare scheduler
self.scheduler.set_timesteps(num_inference_steps, denoising_strength)
# Prepare latent tensors
if init_image is not None:
image = self.preprocess_image(init_image).to(device=self.device, dtype=self.torch_dtype)
if input_image is not None:
image = self.preprocess_image(input_image).to(device=self.device, dtype=self.torch_dtype)
latents = self.vae_encoder(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
noise = torch.randn((1, 4, height//8, width//8), device=self.device, dtype=self.torch_dtype)
latents = self.scheduler.add_noise(latents, noise, timestep=self.scheduler.timesteps[0])
else:
latents = torch.randn((1, 4, height//8, width//8), device=self.device, dtype=self.torch_dtype)
# Encode prompts
prompt_emb_posi = self.prompter.encode_prompt(self.text_encoder, prompt, clip_skip=clip_skip, device=self.device, positive=True)
prompt_emb_nega = self.prompter.encode_prompt(self.text_encoder, negative_prompt, clip_skip=clip_skip, device=self.device, positive=False)
# Prepare ControlNets
if controlnet_image is not None:
controlnet_image = self.controlnet.process_image(controlnet_image).to(device=self.device, dtype=self.torch_dtype)
controlnet_image = controlnet_image.unsqueeze(1)
# Denoise
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
timestep = torch.IntTensor((timestep,))[0].to(self.device)
# ControlNet
if controlnet_image is not None:
additional_res_stack_posi = self.controlnet(latents, timestep, prompt_emb_posi, controlnet_image)
additional_res_stack_nega = self.controlnet(latents, timestep, prompt_emb_nega, controlnet_image)
else:
additional_res_stack_posi = None
additional_res_stack_nega = None
# Classifier-free guidance
noise_pred_posi = self.unet(
latents, timestep, prompt_emb_posi,
additional_res_stack=additional_res_stack_posi,
tiled=tiled, tile_size=tile_size, tile_stride=tile_stride
noise_pred_posi = lets_dance(
self.unet, motion_modules=None, controlnet=self.controlnet,
sample=latents, timestep=timestep, encoder_hidden_states=prompt_emb_posi, controlnet_frames=controlnet_image,
tiled=tiled, tile_size=tile_size, tile_stride=tile_stride,
device=self.device, vram_limit_level=0
)
noise_pred_nega = self.unet(
latents, timestep, prompt_emb_nega,
additional_res_stack=additional_res_stack_nega,
tiled=tiled, tile_size=tile_size, tile_stride=tile_stride
noise_pred_nega = lets_dance(
self.unet, motion_modules=None, controlnet=self.controlnet,
sample=latents, timestep=timestep, encoder_hidden_states=prompt_emb_nega, controlnet_frames=controlnet_image,
tiled=tiled, tile_size=tile_size, tile_stride=tile_stride,
device=self.device, vram_limit_level=0
)
noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
@@ -110,8 +143,6 @@ class SDPipeline(torch.nn.Module):
progress_bar_st.progress(progress_id / len(self.scheduler.timesteps))
# Decode image
image = self.vae_decoder(latents, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)[0]
image = image.cpu().permute(1, 2, 0).numpy()
image = Image.fromarray(((image / 2 + 0.5).clip(0, 1) * 255).astype("uint8"))
image = self.decode_image(latents, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
return image

View File

@@ -1,8 +1,8 @@
from ..models import ModelManager, SDTextEncoder, SDUNet, SDVAEDecoder, SDVAEEncoder, SDMotionModel
from ..models.sd_unet import PushBlock, PopBlock
from ..controlnets import MultiControlNetManager, ControlNetUnit, ControlNetConfigUnit, Annotator
from ..prompts import SDPrompter
from ..schedulers import EnhancedDDIMScheduler
from .dancer import lets_dance
from typing import List
import torch
from tqdm import tqdm
@@ -10,97 +10,6 @@ from PIL import Image
import numpy as np
def lets_dance(
unet: SDUNet,
motion_modules: SDMotionModel = None,
controlnet: MultiControlNetManager = None,
sample = None,
timestep = None,
encoder_hidden_states = None,
controlnet_frames = None,
unet_batch_size = 1,
controlnet_batch_size = 1,
device = "cuda",
vram_limit_level = 0,
):
# 1. ControlNet
# This part will be repeated on overlapping frames if animatediff_batch_size > animatediff_stride.
# I leave it here because I intend to do something interesting on the ControlNets.
controlnet_insert_block_id = 30
if controlnet is not None and controlnet_frames is not None:
res_stacks = []
# process controlnet frames with batch
for batch_id in range(0, sample.shape[0], controlnet_batch_size):
batch_id_ = min(batch_id + controlnet_batch_size, sample.shape[0])
res_stack = controlnet(
sample[batch_id: batch_id_],
timestep,
encoder_hidden_states[batch_id: batch_id_],
controlnet_frames[:, batch_id: batch_id_]
)
if vram_limit_level >= 1:
res_stack = [res.cpu() for res in res_stack]
res_stacks.append(res_stack)
# concat the residual
additional_res_stack = []
for i in range(len(res_stacks[0])):
res = torch.concat([res_stack[i] for res_stack in res_stacks], dim=0)
additional_res_stack.append(res)
else:
additional_res_stack = None
# 2. time
time_emb = unet.time_proj(timestep[None]).to(sample.dtype)
time_emb = unet.time_embedding(time_emb)
# 3. pre-process
hidden_states = unet.conv_in(sample)
text_emb = encoder_hidden_states
res_stack = [hidden_states.cpu() if vram_limit_level>=1 else hidden_states]
# 4. blocks
for block_id, block in enumerate(unet.blocks):
# 4.1 UNet
if isinstance(block, PushBlock):
hidden_states, time_emb, text_emb, res_stack = block(hidden_states, time_emb, text_emb, res_stack)
if vram_limit_level>=1:
res_stack[-1] = res_stack[-1].cpu()
elif isinstance(block, PopBlock):
if vram_limit_level>=1:
res_stack[-1] = res_stack[-1].to(device)
hidden_states, time_emb, text_emb, res_stack = block(hidden_states, time_emb, text_emb, res_stack)
else:
hidden_states_input = hidden_states
hidden_states_output = []
for batch_id in range(0, sample.shape[0], unet_batch_size):
batch_id_ = min(batch_id + unet_batch_size, sample.shape[0])
hidden_states, _, _, _ = block(hidden_states_input[batch_id: batch_id_], time_emb, text_emb[batch_id: batch_id_], res_stack)
hidden_states_output.append(hidden_states)
hidden_states = torch.concat(hidden_states_output, dim=0)
# 4.2 AnimateDiff
if motion_modules is not None:
if block_id in motion_modules.call_block_id:
motion_module_id = motion_modules.call_block_id[block_id]
hidden_states, time_emb, text_emb, res_stack = motion_modules.motion_modules[motion_module_id](
hidden_states, time_emb, text_emb, res_stack,
batch_size=1
)
# 4.3 ControlNet
if block_id == controlnet_insert_block_id and additional_res_stack is not None:
hidden_states += additional_res_stack.pop().to(device)
if vram_limit_level>=1:
res_stack = [(res.to(device) + additional_res.to(device)).cpu() for res, additional_res in zip(res_stack, additional_res_stack)]
else:
res_stack = [res + additional_res for res, additional_res in zip(res_stack, additional_res_stack)]
# 5. output
hidden_states = unet.conv_norm_out(hidden_states)
hidden_states = unet.conv_act(hidden_states)
hidden_states = unet.conv_out(hidden_states)
return hidden_states
def lets_dance_with_long_video(
unet: SDUNet,
motion_modules: SDMotionModel = None,
@@ -187,6 +96,11 @@ class SDVideoPipeline(torch.nn.Module):
self.motion_modules = model_manager.motion_modules
def fetch_beautiful_prompt(self, model_manager: ModelManager):
if "beautiful_prompt" in model_manager.model:
self.prompter.load_beautiful_prompt(model_manager.model["beautiful_prompt"], model_manager.model_path["beautiful_prompt"])
@staticmethod
def from_model_manager(model_manager: ModelManager, controlnet_config_units: List[ControlNetConfigUnit]=[]):
pipe = SDVideoPipeline(
@@ -196,6 +110,7 @@ class SDVideoPipeline(torch.nn.Module):
)
pipe.fetch_main_models(model_manager)
pipe.fetch_motion_modules(model_manager)
pipe.fetch_beautiful_prompt(model_manager)
pipe.fetch_controlnet_models(model_manager, controlnet_config_units)
return pipe
@@ -248,12 +163,6 @@ class SDVideoPipeline(torch.nn.Module):
progress_bar_cmd=tqdm,
progress_bar_st=None,
):
# Encode prompts
prompt_emb_posi = self.prompter.encode_prompt(self.text_encoder, prompt, clip_skip=clip_skip, device=self.device).cpu()
prompt_emb_nega = self.prompter.encode_prompt(self.text_encoder, negative_prompt, clip_skip=clip_skip, device=self.device).cpu()
prompt_emb_posi = prompt_emb_posi.repeat(num_frames, 1, 1)
prompt_emb_nega = prompt_emb_nega.repeat(num_frames, 1, 1)
# Prepare scheduler
self.scheduler.set_timesteps(num_inference_steps, denoising_strength)
@@ -265,6 +174,12 @@ class SDVideoPipeline(torch.nn.Module):
latents = self.encode_images(input_frames)
latents = self.scheduler.add_noise(latents, noise, timestep=self.scheduler.timesteps[0])
# Encode prompts
prompt_emb_posi = self.prompter.encode_prompt(self.text_encoder, prompt, clip_skip=clip_skip, device=self.device, positive=True).cpu()
prompt_emb_nega = self.prompter.encode_prompt(self.text_encoder, negative_prompt, clip_skip=clip_skip, device=self.device, positive=False).cpu()
prompt_emb_posi = prompt_emb_posi.repeat(num_frames, 1, 1)
prompt_emb_nega = prompt_emb_nega.repeat(num_frames, 1, 1)
# Prepare ControlNets
if controlnet_frames is not None:
controlnet_frames = torch.stack([

View File

@@ -1,4 +1,5 @@
from ..models import ModelManager
from ..models import ModelManager, SDXLTextEncoder, SDXLTextEncoder2, SDXLUNet, SDXLVAEDecoder, SDXLVAEEncoder
# TODO: SDXL ControlNet
from ..prompts import SDXLPrompter
from ..schedulers import EnhancedDDIMScheduler
import torch
@@ -7,29 +8,77 @@ from PIL import Image
import numpy as np
class SDXLPipeline(torch.nn.Module):
class SDXLImagePipeline(torch.nn.Module):
def __init__(self):
def __init__(self, device="cuda", torch_dtype=torch.float16):
super().__init__()
self.scheduler = EnhancedDDIMScheduler()
self.prompter = SDXLPrompter()
self.device = device
self.torch_dtype = torch_dtype
# models
self.text_encoder: SDXLTextEncoder = None
self.text_encoder_2: SDXLTextEncoder2 = None
self.unet: SDXLUNet = None
self.vae_decoder: SDXLVAEDecoder = None
self.vae_encoder: SDXLVAEEncoder = None
# TODO: SDXL ControlNet
def fetch_main_models(self, model_manager: ModelManager):
self.text_encoder = model_manager.text_encoder
self.text_encoder_2 = model_manager.text_encoder_2
self.unet = model_manager.unet
self.vae_decoder = model_manager.vae_decoder
self.vae_encoder = model_manager.vae_encoder
# load textual inversion
self.prompter.load_textual_inversion(model_manager.textual_inversion_dict)
def fetch_controlnet_models(self, model_manager: ModelManager, **kwargs):
# TODO: SDXL ControlNet
pass
def fetch_beautiful_prompt(self, model_manager: ModelManager):
if "beautiful_prompt" in model_manager.model:
self.prompter.load_beautiful_prompt(model_manager.model["beautiful_prompt"], model_manager.model_path["beautiful_prompt"])
@staticmethod
def from_model_manager(model_manager: ModelManager, controlnet_config_units = [], **kwargs):
pipe = SDXLImagePipeline(
device=model_manager.device,
torch_dtype=model_manager.torch_dtype,
)
pipe.fetch_main_models(model_manager)
pipe.fetch_beautiful_prompt(model_manager)
pipe.fetch_controlnet_models(model_manager, controlnet_config_units=controlnet_config_units)
return pipe
def preprocess_image(self, image):
"""Convert a PIL image (or HWC array) to a 1xCxHxW float tensor scaled from [0, 255] to [-1, 1]."""
image = torch.Tensor(np.array(image, dtype=np.float32) * (2 / 255) - 1).permute(2, 0, 1).unsqueeze(0)
return image
def decode_image(self, latent, tiled=False, tile_size=64, tile_stride=32):
"""Decode one latent through the VAE decoder into a PIL image (tiling optionally bounds memory)."""
image = self.vae_decoder(latent.to(self.device), tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)[0]
# CHW tensor in [-1, 1] -> HWC uint8 values in [0, 255]
image = image.cpu().permute(1, 2, 0).numpy()
image = Image.fromarray(((image / 2 + 0.5).clip(0, 1) * 255).astype("uint8"))
return image
@torch.no_grad()
def __call__(
self,
model_manager: ModelManager,
prompter: SDXLPrompter,
prompt,
negative_prompt="",
cfg_scale=7.5,
clip_skip=1,
clip_skip_2=2,
init_image=None,
input_image=None,
controlnet_image=None,
denoising_strength=1.0,
refining_strength=0.0,
height=1024,
width=1024,
num_inference_steps=20,
@@ -39,76 +88,62 @@ class SDXLPipeline(torch.nn.Module):
progress_bar_cmd=tqdm,
progress_bar_st=None,
):
# Encode prompts
add_text_embeds, prompt_emb = prompter.encode_prompt(
model_manager.text_encoder,
model_manager.text_encoder_2,
prompt,
clip_skip=clip_skip, clip_skip_2=clip_skip_2,
device=model_manager.device
)
if cfg_scale != 1.0:
negative_add_text_embeds, negative_prompt_emb = prompter.encode_prompt(
model_manager.text_encoder,
model_manager.text_encoder_2,
negative_prompt,
clip_skip=clip_skip, clip_skip_2=clip_skip_2,
device=model_manager.device
)
# Prepare scheduler
self.scheduler.set_timesteps(num_inference_steps, denoising_strength)
# Prepare latent tensors
if init_image is not None:
image = self.preprocess_image(init_image).to(
device=model_manager.device, dtype=model_manager.torch_type
)
latents = model_manager.vae_encoder(
image.to(torch.float32),
tiled=tiled, tile_size=tile_size, tile_stride=tile_stride
)
noise = torch.randn(
(1, 4, height//8, width//8),
device=model_manager.device, dtype=model_manager.torch_type
)
latents = self.scheduler.add_noise(
latents.to(model_manager.torch_type),
noise,
timestep=self.scheduler.timesteps[0]
)
if input_image is not None:
image = self.preprocess_image(input_image).to(device=self.device, dtype=self.torch_dtype)
latents = self.vae_encoder(image.to(torch.float32), tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(self.torch_dtype)
noise = torch.randn((1, 4, height//8, width//8), device=self.device, dtype=self.torch_dtype)
latents = self.scheduler.add_noise(latents, noise, timestep=self.scheduler.timesteps[0])
else:
latents = torch.randn((1, 4, height//8, width//8), device=model_manager.device, dtype=model_manager.torch_type)
latents = torch.randn((1, 4, height//8, width//8), device=self.device, dtype=self.torch_dtype)
# Encode prompts
add_prompt_emb_posi, prompt_emb_posi = self.prompter.encode_prompt(
self.text_encoder,
self.text_encoder_2,
prompt,
clip_skip=clip_skip, clip_skip_2=clip_skip_2,
device=self.device
)
if cfg_scale != 1.0:
add_prompt_emb_nega, prompt_emb_nega = self.prompter.encode_prompt(
self.text_encoder,
self.text_encoder_2,
negative_prompt,
clip_skip=clip_skip, clip_skip_2=clip_skip_2,
device=self.device
)
# Prepare scheduler
self.scheduler.set_timesteps(num_inference_steps, denoising_strength)
# Prepare positional id
add_time_id = torch.tensor([height, width, 0, 0, height, width], device=model_manager.device)
add_time_id = torch.tensor([height, width, 0, 0, height, width], device=self.device)
# Denoise
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
timestep = torch.IntTensor((timestep,))[0].to(model_manager.device)
timestep = torch.IntTensor((timestep,))[0].to(self.device)
# Classifier-free guidance
if timestep >= 1000 * refining_strength:
denoising_model = model_manager.unet
else:
denoising_model = model_manager.refiner
if cfg_scale != 1.0:
noise_pred_cond = denoising_model(
latents, timestep, prompt_emb,
add_time_id=add_time_id, add_text_embeds=add_text_embeds,
noise_pred_posi = self.unet(
latents, timestep, prompt_emb_posi,
add_time_id=add_time_id, add_text_embeds=add_prompt_emb_posi,
tiled=tiled, tile_size=tile_size, tile_stride=tile_stride
)
noise_pred_uncond = denoising_model(
latents, timestep, negative_prompt_emb,
add_time_id=add_time_id, add_text_embeds=negative_add_text_embeds,
noise_pred_nega = self.unet(
latents, timestep, prompt_emb_nega,
add_time_id=add_time_id, add_text_embeds=add_prompt_emb_nega,
tiled=tiled, tile_size=tile_size, tile_stride=tile_stride
)
noise_pred = noise_pred_uncond + cfg_scale * (noise_pred_cond - noise_pred_uncond)
noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
else:
noise_pred = denoising_model(
latents, timestep, prompt_emb,
add_time_id=add_time_id, add_text_embeds=add_text_embeds,
noise_pred = self.unet(
latents, timestep, prompt_emb_posi,
add_time_id=add_time_id, add_text_embeds=add_prompt_emb_posi,
tiled=tiled, tile_size=tile_size, tile_stride=tile_stride
)
@@ -118,9 +153,6 @@ class SDXLPipeline(torch.nn.Module):
progress_bar_st.progress(progress_id / len(self.scheduler.timesteps))
# Decode image
latents = latents.to(torch.float32)
image = model_manager.vae_decoder(latents, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)[0]
image = image.cpu().permute(1, 2, 0).numpy()
image = Image.fromarray(((image / 2 + 0.5).clip(0, 1) * 255).astype("uint8"))
image = self.decode_image(latents.to(torch.float32), tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
return image

View File

@@ -1,5 +1,5 @@
from transformers import CLIPTokenizer
from ..models import SDTextEncoder, SDXLTextEncoder, SDXLTextEncoder2, load_state_dict
from transformers import CLIPTokenizer, AutoTokenizer
from ..models import SDTextEncoder, SDXLTextEncoder, SDXLTextEncoder2
import torch, os
@@ -35,14 +35,30 @@ def tokenize_long_prompt(tokenizer, prompt):
return input_ids
def search_for_embeddings(state_dict):
    """Recursively collect every torch.Tensor stored in a (possibly nested) state dict.

    Non-tensor, non-dict values are ignored; tensors are returned in
    encounter (insertion) order.
    """
    found = []
    for value in state_dict.values():
        if isinstance(value, torch.Tensor):
            found.append(value)
        elif isinstance(value, dict):
            found.extend(search_for_embeddings(value))
    return found
class BeautifulPrompt:
    """Wraps a causal LM that rewrites a short user prompt into a richer drawing prompt."""

    def __init__(self, tokenizer_path="configs/beautiful_prompt/tokenizer", model=None):
        # The tokenizer is loaded from disk; the generation model is injected by the caller.
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
        self.model = model
        self.template = 'Instruction: Give a simple description of the image to generate a drawing prompt.\nInput: {raw_prompt}\nOutput:'

    def __call__(self, raw_prompt):
        """Return *raw_prompt* extended with the model's generated refinement."""
        model_input = self.template.format(raw_prompt=raw_prompt)
        input_ids = self.tokenizer.encode(model_input, return_tensors='pt').to(self.model.device)
        generated_ids = self.model.generate(
            input_ids,
            max_new_tokens=384,
            do_sample=True,
            temperature=0.9,
            top_k=50,
            top_p=0.95,
            repetition_penalty=1.1,
            num_return_sequences=1
        )
        # Decode only the newly generated tokens (drop the instruction prefix).
        continuation = self.tokenizer.batch_decode(
            generated_ids[:, input_ids.size(1):],
            skip_special_tokens=True
        )[0].strip()
        return raw_prompt + ", " + continuation
class SDPrompter:
@@ -50,11 +66,20 @@ class SDPrompter:
# We use the tokenizer implemented by transformers
self.tokenizer = CLIPTokenizer.from_pretrained(tokenizer_path)
self.keyword_dict = {}
self.beautiful_prompt: BeautifulPrompt = None
def encode_prompt(self, text_encoder: SDTextEncoder, prompt, clip_skip=1, device="cuda"):
def encode_prompt(self, text_encoder: SDTextEncoder, prompt, clip_skip=1, device="cuda", positive=True):
# Textual Inversion
for keyword in self.keyword_dict:
if keyword in prompt:
prompt = prompt.replace(keyword, self.keyword_dict[keyword])
# Beautiful Prompt
if positive and self.beautiful_prompt is not None:
prompt = self.beautiful_prompt(prompt)
print(f"Your prompt is refined by BeautifulPrompt: \"{prompt}\"")
input_ids = tokenize_long_prompt(self.tokenizer, prompt).to(device)
prompt_emb = text_encoder(input_ids, clip_skip=clip_skip)
prompt_emb = prompt_emb.reshape((1, prompt_emb.shape[0]*prompt_emb.shape[1], -1))
@@ -70,6 +95,16 @@ class SDPrompter:
self.keyword_dict[keyword] = " " + " ".join(tokens) + " "
self.tokenizer.add_tokens(additional_tokens)
def load_beautiful_prompt(self, model, model_path):
"""Attach a BeautifulPrompt refiner; a checkpoint folder ending in "v2" selects the tag-style template."""
model_folder = os.path.dirname(model_path)
self.beautiful_prompt = BeautifulPrompt(tokenizer_path=model_folder, model=model)
# the v2 checkpoint expects a different instruction template than the default one
if model_folder.endswith("v2"):
self.beautiful_prompt.template = """Converts a simple image description into a prompt. \
Prompts are formatted as multiple related tags separated by commas, plus you can use () to increase the weight, [] to decrease the weight, \
or use a number to specify the weight. You should add appropriate words to make the images described in the prompt more aesthetically pleasing, \
but make sure there is a correlation between the input and output.\n\
### Input: {raw_prompt}\n### Output:"""
class SDXLPrompter:
def __init__(
@@ -80,6 +115,8 @@ class SDXLPrompter:
# We use the tokenizer implemented by transformers
self.tokenizer = CLIPTokenizer.from_pretrained(tokenizer_path)
self.tokenizer_2 = CLIPTokenizer.from_pretrained(tokenizer_2_path)
self.keyword_dict = {}
self.beautiful_prompt: BeautifulPrompt = None
def encode_prompt(
self,
@@ -88,8 +125,19 @@ class SDXLPrompter:
prompt,
clip_skip=1,
clip_skip_2=2,
positive=True,
device="cuda"
):
# Textual Inversion
for keyword in self.keyword_dict:
if keyword in prompt:
prompt = prompt.replace(keyword, self.keyword_dict[keyword])
# Beautiful Prompt
if positive and self.beautiful_prompt is not None:
prompt = self.beautiful_prompt(prompt)
print(f"Your prompt is refined by BeautifulPrompt: \"{prompt}\"")
# 1
input_ids = tokenize_long_prompt(self.tokenizer, prompt).to(device)
prompt_emb_1 = text_encoder(input_ids, clip_skip=clip_skip)
@@ -105,3 +153,22 @@ class SDXLPrompter:
add_text_embeds = add_text_embeds[0:1]
prompt_emb = prompt_emb.reshape((1, prompt_emb.shape[0]*prompt_emb.shape[1], -1))
return add_text_embeds, prompt_emb
def load_textual_inversion(self, textual_inversion_dict):
"""Register textual-inversion keywords so each keyword expands to its placeholder tokens during encoding."""
self.keyword_dict = {}
additional_tokens = []
for keyword in textual_inversion_dict:
tokens, _ = textual_inversion_dict[keyword]
additional_tokens += tokens
# surround with spaces so the expansion stays a separate word inside the prompt
self.keyword_dict[keyword] = " " + " ".join(tokens) + " "
self.tokenizer.add_tokens(additional_tokens)
def load_beautiful_prompt(self, model, model_path):
"""Attach a BeautifulPrompt refiner; a checkpoint folder ending in "v2" selects the tag-style template."""
model_folder = os.path.dirname(model_path)
self.beautiful_prompt = BeautifulPrompt(tokenizer_path=model_folder, model=model)
# the v2 checkpoint expects a different instruction template than the default one
if model_folder.endswith("v2"):
self.beautiful_prompt.template = """Converts a simple image description into a prompt. \
Prompts are formatted as multiple related tags separated by commas, plus you can use () to increase the weight, [] to decrease the weight, \
or use a number to specify the weight. You should add appropriate words to make the images described in the prompt more aesthetically pleasing, \
but make sure there is a correlation between the input and output.\n\
### Input: {raw_prompt}\n### Output:"""

View File

@@ -1,37 +0,0 @@
from transformers import CLIPTokenizer
class SDTokenizer:
    """Tokenizes prompts of arbitrary length into fixed-size chunks for the CLIP text encoder."""

    def __init__(self, tokenizer_path="configs/stable_diffusion/tokenizer"):
        # We use the tokenizer implemented by transformers.
        self.tokenizer = CLIPTokenizer.from_pretrained(tokenizer_path)

    def __call__(self, prompt):
        """Return input ids reshaped to (num_chunks, model_max_length)."""
        context_length = self.tokenizer.model_max_length
        # Temporarily lift the length cap so long prompts tokenize without truncation warnings.
        self.tokenizer.model_max_length = 99999999
        untruncated_ids = self.tokenizer(prompt, return_tensors="pt").input_ids
        # Round the actual token count up to a whole number of context-sized chunks.
        padded_length = (untruncated_ids.shape[1] + context_length - 1) // context_length * context_length
        # Restore the tokenizer's original length cap.
        self.tokenizer.model_max_length = context_length
        # Tokenize again, padded/truncated to the fixed multiple of the context length.
        input_ids = self.tokenizer(
            prompt,
            return_tensors="pt",
            padding="max_length",
            max_length=padded_length,
            truncation=True
        ).input_ids
        # One row per chunk, matching the text encoder's expected input shape.
        num_chunks = input_ids.shape[1] // context_length
        return input_ids.reshape((num_chunks, context_length))

View File

@@ -1,45 +0,0 @@
from transformers import CLIPTokenizer
from .sd_tokenizer import SDTokenizer
class SDXLTokenizer(SDTokenizer):
"""First SDXL tokenizer; identical to the base SD tokenizer with its default tokenizer path."""
def __init__(self):
super().__init__()
class SDXLTokenizer2:
    """Second SDXL tokenizer; chunks prompts of arbitrary length for the second CLIP text encoder."""

    def __init__(self, tokenizer_path="configs/stable_diffusion_xl/tokenizer_2"):
        # We use the tokenizer implemented by transformers.
        self.tokenizer = CLIPTokenizer.from_pretrained(tokenizer_path)

    def __call__(self, prompt):
        """Return input ids reshaped to (num_chunks, model_max_length)."""
        context_length = self.tokenizer.model_max_length
        # Temporarily lift the length cap so long prompts tokenize without truncation warnings.
        self.tokenizer.model_max_length = 99999999
        untruncated_ids = self.tokenizer(prompt, return_tensors="pt").input_ids
        # Round the actual token count up to a whole number of context-sized chunks.
        padded_length = (untruncated_ids.shape[1] + context_length - 1) // context_length * context_length
        # Restore the tokenizer's original length cap.
        self.tokenizer.model_max_length = context_length
        # Tokenize again, padded/truncated to the fixed multiple of the context length.
        input_ids = self.tokenizer(
            prompt,
            return_tensors="pt",
            padding="max_length",
            max_length=padded_length,
            truncation=True
        ).input_ids
        # One row per chunk, matching the text encoder's expected input shape.
        num_chunks = input_ids.shape[1] // context_length
        return input_ids.reshape((num_chunks, context_length))

View File

@@ -18,7 +18,7 @@ class EnhancedDDIMScheduler():
def set_timesteps(self, num_inference_steps, denoising_strength=1.0):
# The timesteps are aligned to 999...0, which is different from other implementations,
# but I think this implementation is more reasonable in theory.
max_timestep = round(self.num_train_timesteps * denoising_strength) - 1
max_timestep = max(round(self.num_train_timesteps * denoising_strength) - 1, 0)
num_inference_steps = min(num_inference_steps, max_timestep + 1)
if num_inference_steps == 1:
self.timesteps = [max_timestep]