mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-22 08:40:47 +00:00
compatibility update
This commit is contained in:
@@ -36,10 +36,17 @@ class MultiControlNetManager:
|
||||
], dim=0)
|
||||
return processed_image
|
||||
|
||||
def __call__(self, sample, timestep, encoder_hidden_states, conditionings):
|
||||
def __call__(
|
||||
self,
|
||||
sample, timestep, encoder_hidden_states, conditionings,
|
||||
tiled=False, tile_size=64, tile_stride=32
|
||||
):
|
||||
res_stack = None
|
||||
for conditioning, model, scale in zip(conditionings, self.models, self.scales):
|
||||
res_stack_ = model(sample, timestep, encoder_hidden_states, conditioning)
|
||||
res_stack_ = model(
|
||||
sample, timestep, encoder_hidden_states, conditioning,
|
||||
tiled=tiled, tile_size=tile_size, tile_stride=tile_stride
|
||||
)
|
||||
res_stack_ = [res * scale for res in res_stack_]
|
||||
if res_stack is None:
|
||||
res_stack = res_stack_
|
||||
|
||||
@@ -12,7 +12,7 @@ Processor_id: TypeAlias = Literal[
|
||||
]
|
||||
|
||||
class Annotator:
|
||||
def __init__(self, processor_id: Processor_id, model_path="models/Annotators", detect_resolution=512):
|
||||
def __init__(self, processor_id: Processor_id, model_path="models/Annotators", detect_resolution=None):
|
||||
if processor_id == "canny":
|
||||
self.processor = CannyDetector()
|
||||
elif processor_id == "depth":
|
||||
@@ -44,7 +44,8 @@ class Annotator:
|
||||
else:
|
||||
kwargs = {}
|
||||
if self.processor is not None:
|
||||
image = self.processor(image, detect_resolution=self.detect_resolution, image_resolution=min(width, height), **kwargs)
|
||||
detect_resolution = self.detect_resolution if self.detect_resolution is not None else min(width, height)
|
||||
image = self.processor(image, detect_resolution=detect_resolution, image_resolution=min(width, height), **kwargs)
|
||||
image = image.resize((width, height))
|
||||
return image
|
||||
|
||||
|
||||
@@ -15,6 +15,8 @@ from .sd_controlnet import SDControlNet
|
||||
|
||||
from .sd_motion import SDMotionModel
|
||||
|
||||
from transformers import AutoModelForCausalLM
|
||||
|
||||
|
||||
class ModelManager:
|
||||
def __init__(self, torch_dtype=torch.float16, device="cuda"):
|
||||
@@ -24,12 +26,19 @@ class ModelManager:
|
||||
self.model_path = {}
|
||||
self.textual_inversion_dict = {}
|
||||
|
||||
def is_beautiful_prompt(self, state_dict):
|
||||
param_name = "transformer.h.9.self_attention.query_key_value.weight"
|
||||
return param_name in state_dict
|
||||
|
||||
def is_stabe_diffusion_xl(self, state_dict):
|
||||
param_name = "conditioner.embedders.0.transformer.text_model.embeddings.position_embedding.weight"
|
||||
return param_name in state_dict
|
||||
|
||||
def is_stable_diffusion(self, state_dict):
|
||||
return True
|
||||
if self.is_stabe_diffusion_xl(state_dict):
|
||||
return False
|
||||
param_name = "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.norm3.weight"
|
||||
return param_name in state_dict
|
||||
|
||||
def is_controlnet(self, state_dict):
|
||||
param_name = "control_model.time_embed.0.weight"
|
||||
@@ -74,7 +83,6 @@ class ModelManager:
|
||||
"unet": SDXLUNet,
|
||||
"vae_decoder": SDXLVAEDecoder,
|
||||
"vae_encoder": SDXLVAEEncoder,
|
||||
"refiner": SDXLUNet,
|
||||
}
|
||||
if components is None:
|
||||
components = ["text_encoder", "text_encoder_2", "unet", "vae_decoder", "vae_encoder"]
|
||||
@@ -109,6 +117,15 @@ class ModelManager:
|
||||
self.model[component] = model
|
||||
self.model_path[component] = file_path
|
||||
|
||||
def load_beautiful_prompt(self, state_dict, file_path=""):
|
||||
component = "beautiful_prompt"
|
||||
model_folder = os.path.dirname(file_path)
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
model_folder, state_dict=state_dict, local_files_only=True, torch_dtype=self.torch_dtype
|
||||
).to(self.device).eval()
|
||||
self.model[component] = model
|
||||
self.model_path[component] = file_path
|
||||
|
||||
def search_for_embeddings(self, state_dict):
|
||||
embeddings = []
|
||||
for k in state_dict:
|
||||
@@ -144,6 +161,8 @@ class ModelManager:
|
||||
self.load_stable_diffusion_xl(state_dict, components=components, file_path=file_path)
|
||||
elif self.is_stable_diffusion(state_dict):
|
||||
self.load_stable_diffusion(state_dict, components=components, file_path=file_path)
|
||||
elif self.is_beautiful_prompt(state_dict):
|
||||
self.load_beautiful_prompt(state_dict, file_path=file_path)
|
||||
|
||||
def load_models(self, file_path_list):
|
||||
for file_path in file_path_list:
|
||||
|
||||
@@ -41,7 +41,7 @@ class Attention(torch.nn.Module):
|
||||
v = v.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
|
||||
hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
|
||||
hidden_states = hidden_states.transpose(1, 2).view(batch_size, -1, self.num_heads * self.head_dim)
|
||||
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim)
|
||||
hidden_states = hidden_states.to(q.dtype)
|
||||
|
||||
hidden_states = self.to_out(hidden_states)
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import torch
|
||||
from .sd_unet import Timesteps, ResnetBlock, AttentionBlock, PushBlock, PopBlock, DownSampler, UpSampler
|
||||
from .sd_unet import Timesteps, ResnetBlock, AttentionBlock, PushBlock, DownSampler
|
||||
from .tiler import TileWorker
|
||||
|
||||
|
||||
class ControlNetConditioningLayer(torch.nn.Module):
|
||||
@@ -92,20 +93,37 @@ class SDControlNet(torch.nn.Module):
|
||||
|
||||
self.global_pool = global_pool
|
||||
|
||||
def forward(self, sample, timestep, encoder_hidden_states, conditioning):
|
||||
def forward(
|
||||
self,
|
||||
sample, timestep, encoder_hidden_states, conditioning,
|
||||
tiled=False, tile_size=64, tile_stride=32,
|
||||
):
|
||||
# 1. time
|
||||
time_emb = self.time_proj(timestep[None]).to(sample.dtype)
|
||||
time_emb = self.time_embedding(time_emb)
|
||||
time_emb = time_emb.repeat(sample.shape[0], 1)
|
||||
|
||||
# 2. pre-process
|
||||
height, width = sample.shape[2], sample.shape[3]
|
||||
hidden_states = self.conv_in(sample) + self.controlnet_conv_in(conditioning)
|
||||
text_emb = encoder_hidden_states
|
||||
res_stack = [hidden_states]
|
||||
|
||||
# 3. blocks
|
||||
for i, block in enumerate(self.blocks):
|
||||
hidden_states, time_emb, text_emb, res_stack = block(hidden_states, time_emb, text_emb, res_stack)
|
||||
if tiled and not isinstance(block, PushBlock):
|
||||
_, _, inter_height, _ = hidden_states.shape
|
||||
resize_scale = inter_height / height
|
||||
hidden_states = TileWorker().tiled_forward(
|
||||
lambda x: block(x, time_emb, text_emb, res_stack)[0],
|
||||
hidden_states,
|
||||
int(tile_size * resize_scale),
|
||||
int(tile_stride * resize_scale),
|
||||
tile_device=hidden_states.device,
|
||||
tile_dtype=hidden_states.dtype
|
||||
)
|
||||
else:
|
||||
hidden_states, _, _, _ = block(hidden_states, time_emb, text_emb, res_stack)
|
||||
|
||||
# 4. ControlNet blocks
|
||||
controlnet_res_stack = [block(res) for block, res in zip(self.controlnet_blocks, res_stack)]
|
||||
|
||||
@@ -279,31 +279,19 @@ class SDUNet(torch.nn.Module):
|
||||
self.conv_act = torch.nn.SiLU()
|
||||
self.conv_out = torch.nn.Conv2d(320, 4, kernel_size=3, padding=1)
|
||||
|
||||
def forward(self, sample, timestep, encoder_hidden_states, tiled=False, tile_size=64, tile_stride=8, additional_res_stack=None, **kwargs):
|
||||
def forward(self, sample, timestep, encoder_hidden_states, **kwargs):
|
||||
# 1. time
|
||||
time_emb = self.time_proj(timestep[None]).to(sample.dtype)
|
||||
time_emb = self.time_embedding(time_emb)
|
||||
time_emb = time_emb.repeat(sample.shape[0], 1)
|
||||
|
||||
# 2. pre-process
|
||||
height, width = sample.shape[2], sample.shape[3]
|
||||
hidden_states = self.conv_in(sample)
|
||||
text_emb = encoder_hidden_states
|
||||
res_stack = [hidden_states]
|
||||
|
||||
# 3. blocks
|
||||
for i, block in enumerate(self.blocks):
|
||||
if additional_res_stack is not None and i==31:
|
||||
hidden_states += additional_res_stack.pop()
|
||||
res_stack = [res + additional_res for res, additional_res in zip(res_stack, additional_res_stack)]
|
||||
additional_res_stack = None
|
||||
if tiled:
|
||||
hidden_states, time_emb, text_emb, res_stack = self.tiled_inference(
|
||||
block, hidden_states, time_emb, text_emb, res_stack,
|
||||
height, width, tile_size, tile_stride
|
||||
)
|
||||
else:
|
||||
hidden_states, time_emb, text_emb, res_stack = block(hidden_states, time_emb, text_emb, res_stack)
|
||||
hidden_states, time_emb, text_emb, res_stack = block(hidden_states, time_emb, text_emb, res_stack)
|
||||
|
||||
# 4. output
|
||||
hidden_states = self.conv_norm_out(hidden_states)
|
||||
@@ -312,23 +300,6 @@ class SDUNet(torch.nn.Module):
|
||||
|
||||
return hidden_states
|
||||
|
||||
def tiled_inference(self, block, hidden_states, time_emb, text_emb, res_stack, height, width, tile_size, tile_stride):
|
||||
if block.__class__.__name__ in ["ResnetBlock", "AttentionBlock", "DownSampler", "UpSampler"]:
|
||||
batch_size, inter_channel, inter_height, inter_width = hidden_states.shape
|
||||
resize_scale = inter_height / height
|
||||
|
||||
hidden_states = Tiler()(
|
||||
lambda x: block(x, time_emb, text_emb, res_stack)[0],
|
||||
hidden_states,
|
||||
int(tile_size * resize_scale),
|
||||
int(tile_stride * resize_scale),
|
||||
inter_device=hidden_states.device,
|
||||
inter_dtype=hidden_states.dtype
|
||||
)
|
||||
else:
|
||||
hidden_states, time_emb, text_emb, res_stack = block(hidden_states, time_emb, text_emb, res_stack)
|
||||
return hidden_states, time_emb, text_emb, res_stack
|
||||
|
||||
def state_dict_converter(self):
|
||||
return SDUNetStateDictConverter()
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import torch
|
||||
from .attention import Attention
|
||||
from .sd_unet import ResnetBlock, UpSampler
|
||||
from .tiler import Tiler
|
||||
from .tiler import TileWorker
|
||||
|
||||
|
||||
class VAEAttentionBlock(torch.nn.Module):
|
||||
@@ -79,11 +79,13 @@ class SDVAEDecoder(torch.nn.Module):
|
||||
self.conv_out = torch.nn.Conv2d(128, 3, kernel_size=3, padding=1)
|
||||
|
||||
def tiled_forward(self, sample, tile_size=64, tile_stride=32):
|
||||
hidden_states = Tiler()(
|
||||
hidden_states = TileWorker().tiled_forward(
|
||||
lambda x: self.forward(x),
|
||||
sample,
|
||||
tile_size,
|
||||
tile_stride
|
||||
tile_stride,
|
||||
tile_device=sample.device,
|
||||
tile_dtype=sample.dtype
|
||||
)
|
||||
return hidden_states
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import torch
|
||||
from .sd_unet import ResnetBlock, DownSampler
|
||||
from .sd_vae_decoder import VAEAttentionBlock
|
||||
from .tiler import Tiler
|
||||
from .tiler import TileWorker
|
||||
|
||||
|
||||
class SDVAEEncoder(torch.nn.Module):
|
||||
@@ -38,11 +38,13 @@ class SDVAEEncoder(torch.nn.Module):
|
||||
self.conv_out = torch.nn.Conv2d(512, 8, kernel_size=3, padding=1)
|
||||
|
||||
def tiled_forward(self, sample, tile_size=64, tile_stride=32):
|
||||
hidden_states = Tiler()(
|
||||
hidden_states = TileWorker().tiled_forward(
|
||||
lambda x: self.forward(x),
|
||||
sample,
|
||||
tile_size,
|
||||
tile_stride
|
||||
tile_stride,
|
||||
tile_device=sample.device,
|
||||
tile_dtype=sample.dtype
|
||||
)
|
||||
return hidden_states
|
||||
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import torch
|
||||
from einops import rearrange, repeat
|
||||
|
||||
|
||||
class Tiler(torch.nn.Module):
|
||||
@@ -70,6 +71,106 @@ class Tiler(torch.nn.Module):
|
||||
|
||||
return x
|
||||
|
||||
|
||||
|
||||
|
||||
class TileWorker:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
|
||||
def mask(self, height, width, border_width):
|
||||
# Create a mask with shape (height, width).
|
||||
# The centre area is filled with 1, and the border line is filled with values in range (0, 1].
|
||||
x = torch.arange(height).repeat(width, 1).T
|
||||
y = torch.arange(width).repeat(height, 1)
|
||||
mask = torch.stack([x + 1, height - x, y + 1, width - y]).min(dim=0).values
|
||||
mask = (mask / border_width).clip(0, 1)
|
||||
return mask
|
||||
|
||||
|
||||
def tile(self, model_input, tile_size, tile_stride, tile_device, tile_dtype):
|
||||
# Convert a tensor (b, c, h, w) to (b, c, tile_size, tile_size, tile_num)
|
||||
batch_size, channel, _, _ = model_input.shape
|
||||
model_input = model_input.to(device=tile_device, dtype=tile_dtype)
|
||||
unfold_operator = torch.nn.Unfold(
|
||||
kernel_size=(tile_size, tile_size),
|
||||
stride=(tile_stride, tile_stride)
|
||||
)
|
||||
model_input = unfold_operator(model_input)
|
||||
model_input = model_input.view((batch_size, channel, tile_size, tile_size, -1))
|
||||
|
||||
return model_input
|
||||
|
||||
|
||||
def tiled_inference(self, forward_fn, model_input, tile_batch_size, inference_device, inference_dtype, tile_device, tile_dtype):
|
||||
# Call y=forward_fn(x) for each tile
|
||||
tile_num = model_input.shape[-1]
|
||||
model_output_stack = []
|
||||
|
||||
for tile_id in range(0, tile_num, tile_batch_size):
|
||||
|
||||
# process input
|
||||
tile_id_ = min(tile_id + tile_batch_size, tile_num)
|
||||
x = model_input[:, :, :, :, tile_id: tile_id_]
|
||||
x = x.to(device=inference_device, dtype=inference_dtype)
|
||||
x = rearrange(x, "b c h w n -> (n b) c h w")
|
||||
|
||||
# process output
|
||||
y = forward_fn(x)
|
||||
y = rearrange(y, "(n b) c h w -> b c h w n", n=tile_id_-tile_id)
|
||||
y = y.to(device=tile_device, dtype=tile_dtype)
|
||||
model_output_stack.append(y)
|
||||
|
||||
model_output = torch.concat(model_output_stack, dim=-1)
|
||||
return model_output
|
||||
|
||||
|
||||
def io_scale(self, model_output, tile_size):
|
||||
# Determine the size modification happend in forward_fn
|
||||
# We only consider the same scale on height and width.
|
||||
io_scale = model_output.shape[2] / tile_size
|
||||
return io_scale
|
||||
|
||||
|
||||
def untile(self, model_output, height, width, tile_size, tile_stride, border_width, tile_device, tile_dtype):
|
||||
# The reversed function of tile
|
||||
mask = self.mask(tile_size, tile_size, border_width)
|
||||
mask = mask.to(device=tile_device, dtype=tile_dtype)
|
||||
mask = rearrange(mask, "h w -> 1 1 h w 1")
|
||||
model_output = model_output * mask
|
||||
|
||||
fold_operator = torch.nn.Fold(
|
||||
output_size=(height, width),
|
||||
kernel_size=(tile_size, tile_size),
|
||||
stride=(tile_stride, tile_stride)
|
||||
)
|
||||
mask = repeat(mask[0, 0, :, :, 0], "h w -> 1 (h w) n", n=model_output.shape[-1])
|
||||
model_output = rearrange(model_output, "b c h w n -> b (c h w) n")
|
||||
model_output = fold_operator(model_output) / fold_operator(mask)
|
||||
|
||||
return model_output
|
||||
|
||||
|
||||
def tiled_forward(self, forward_fn, model_input, tile_size, tile_stride, tile_batch_size=1, tile_device="cpu", tile_dtype=torch.float32, border_width=None):
|
||||
# Prepare
|
||||
inference_device, inference_dtype = model_input.device, model_input.dtype
|
||||
height, width = model_input.shape[2], model_input.shape[3]
|
||||
border_width = int(tile_stride*0.5) if border_width is None else border_width
|
||||
|
||||
# tile
|
||||
model_input = self.tile(model_input, tile_size, tile_stride, tile_device, tile_dtype)
|
||||
|
||||
# inference
|
||||
model_output = self.tiled_inference(forward_fn, model_input, tile_batch_size, inference_device, inference_dtype, tile_device, tile_dtype)
|
||||
|
||||
# resize
|
||||
io_scale = self.io_scale(model_output, tile_size)
|
||||
height, width = int(height*io_scale), int(width*io_scale)
|
||||
tile_size, tile_stride = int(tile_size*io_scale), int(tile_stride*io_scale)
|
||||
border_width = int(border_width*io_scale)
|
||||
|
||||
# untile
|
||||
model_output = self.untile(model_output, height, width, tile_size, tile_stride, border_width, tile_device, tile_dtype)
|
||||
|
||||
# Done!
|
||||
model_output = model_output.to(device=inference_device, dtype=inference_dtype)
|
||||
return model_output
|
||||
@@ -1,3 +1,3 @@
|
||||
from .stable_diffusion import SDPipeline
|
||||
from .stable_diffusion_xl import SDXLPipeline
|
||||
from .stable_diffusion import SDImagePipeline
|
||||
from .stable_diffusion_xl import SDXLImagePipeline
|
||||
from .stable_diffusion_video import SDVideoPipeline
|
||||
|
||||
113
diffsynth/pipelines/dancer.py
Normal file
113
diffsynth/pipelines/dancer.py
Normal file
@@ -0,0 +1,113 @@
|
||||
import torch
|
||||
from ..models import SDUNet, SDMotionModel
|
||||
from ..models.sd_unet import PushBlock, PopBlock
|
||||
from ..models.tiler import TileWorker
|
||||
from ..controlnets import MultiControlNetManager
|
||||
|
||||
|
||||
def lets_dance(
|
||||
unet: SDUNet,
|
||||
motion_modules: SDMotionModel = None,
|
||||
controlnet: MultiControlNetManager = None,
|
||||
sample = None,
|
||||
timestep = None,
|
||||
encoder_hidden_states = None,
|
||||
controlnet_frames = None,
|
||||
unet_batch_size = 1,
|
||||
controlnet_batch_size = 1,
|
||||
tiled=False,
|
||||
tile_size=64,
|
||||
tile_stride=32,
|
||||
device = "cuda",
|
||||
vram_limit_level = 0,
|
||||
):
|
||||
# 1. ControlNet
|
||||
# This part will be repeated on overlapping frames if animatediff_batch_size > animatediff_stride.
|
||||
# I leave it here because I intend to do something interesting on the ControlNets.
|
||||
controlnet_insert_block_id = 30
|
||||
if controlnet is not None and controlnet_frames is not None:
|
||||
res_stacks = []
|
||||
# process controlnet frames with batch
|
||||
for batch_id in range(0, sample.shape[0], controlnet_batch_size):
|
||||
batch_id_ = min(batch_id + controlnet_batch_size, sample.shape[0])
|
||||
res_stack = controlnet(
|
||||
sample[batch_id: batch_id_],
|
||||
timestep,
|
||||
encoder_hidden_states[batch_id: batch_id_],
|
||||
controlnet_frames[:, batch_id: batch_id_],
|
||||
tiled=tiled, tile_size=tile_size, tile_stride=tile_stride
|
||||
)
|
||||
if vram_limit_level >= 1:
|
||||
res_stack = [res.cpu() for res in res_stack]
|
||||
res_stacks.append(res_stack)
|
||||
# concat the residual
|
||||
additional_res_stack = []
|
||||
for i in range(len(res_stacks[0])):
|
||||
res = torch.concat([res_stack[i] for res_stack in res_stacks], dim=0)
|
||||
additional_res_stack.append(res)
|
||||
else:
|
||||
additional_res_stack = None
|
||||
|
||||
# 2. time
|
||||
time_emb = unet.time_proj(timestep[None]).to(sample.dtype)
|
||||
time_emb = unet.time_embedding(time_emb)
|
||||
|
||||
# 3. pre-process
|
||||
height, width = sample.shape[2], sample.shape[3]
|
||||
hidden_states = unet.conv_in(sample)
|
||||
text_emb = encoder_hidden_states
|
||||
res_stack = [hidden_states.cpu() if vram_limit_level>=1 else hidden_states]
|
||||
|
||||
# 4. blocks
|
||||
for block_id, block in enumerate(unet.blocks):
|
||||
# 4.1 UNet
|
||||
if isinstance(block, PushBlock):
|
||||
hidden_states, time_emb, text_emb, res_stack = block(hidden_states, time_emb, text_emb, res_stack)
|
||||
if vram_limit_level>=1:
|
||||
res_stack[-1] = res_stack[-1].cpu()
|
||||
elif isinstance(block, PopBlock):
|
||||
if vram_limit_level>=1:
|
||||
res_stack[-1] = res_stack[-1].to(device)
|
||||
hidden_states, time_emb, text_emb, res_stack = block(hidden_states, time_emb, text_emb, res_stack)
|
||||
else:
|
||||
hidden_states_input = hidden_states
|
||||
hidden_states_output = []
|
||||
for batch_id in range(0, sample.shape[0], unet_batch_size):
|
||||
batch_id_ = min(batch_id + unet_batch_size, sample.shape[0])
|
||||
if tiled:
|
||||
_, _, inter_height, _ = hidden_states.shape
|
||||
resize_scale = inter_height / height
|
||||
hidden_states = TileWorker().tiled_forward(
|
||||
lambda x: block(x, time_emb, text_emb[batch_id: batch_id_], res_stack)[0],
|
||||
hidden_states_input[batch_id: batch_id_],
|
||||
int(tile_size * resize_scale),
|
||||
int(tile_stride * resize_scale),
|
||||
tile_device=hidden_states.device,
|
||||
tile_dtype=hidden_states.dtype
|
||||
)
|
||||
else:
|
||||
hidden_states, _, _, _ = block(hidden_states_input[batch_id: batch_id_], time_emb, text_emb[batch_id: batch_id_], res_stack)
|
||||
hidden_states_output.append(hidden_states)
|
||||
hidden_states = torch.concat(hidden_states_output, dim=0)
|
||||
# 4.2 AnimateDiff
|
||||
if motion_modules is not None:
|
||||
if block_id in motion_modules.call_block_id:
|
||||
motion_module_id = motion_modules.call_block_id[block_id]
|
||||
hidden_states, time_emb, text_emb, res_stack = motion_modules.motion_modules[motion_module_id](
|
||||
hidden_states, time_emb, text_emb, res_stack,
|
||||
batch_size=1
|
||||
)
|
||||
# 4.3 ControlNet
|
||||
if block_id == controlnet_insert_block_id and additional_res_stack is not None:
|
||||
hidden_states += additional_res_stack.pop().to(device)
|
||||
if vram_limit_level>=1:
|
||||
res_stack = [(res.to(device) + additional_res.to(device)).cpu() for res, additional_res in zip(res_stack, additional_res_stack)]
|
||||
else:
|
||||
res_stack = [res + additional_res for res, additional_res in zip(res_stack, additional_res_stack)]
|
||||
|
||||
# 5. output
|
||||
hidden_states = unet.conv_norm_out(hidden_states)
|
||||
hidden_states = unet.conv_act(hidden_states)
|
||||
hidden_states = unet.conv_out(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
@@ -1,14 +1,16 @@
|
||||
from ..models import ModelManager, SDTextEncoder, SDUNet, SDVAEDecoder, SDVAEEncoder
|
||||
from ..controlnets.controlnet_unit import MultiControlNetManager
|
||||
from ..controlnets import MultiControlNetManager, ControlNetUnit, ControlNetConfigUnit, Annotator
|
||||
from ..prompts import SDPrompter
|
||||
from ..schedulers import EnhancedDDIMScheduler
|
||||
from .dancer import lets_dance
|
||||
from typing import List
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
from PIL import Image
|
||||
import numpy as np
|
||||
|
||||
|
||||
class SDPipeline(torch.nn.Module):
|
||||
class SDImagePipeline(torch.nn.Module):
|
||||
|
||||
def __init__(self, device="cuda", torch_dtype=torch.float16):
|
||||
super().__init__()
|
||||
@@ -23,6 +25,7 @@ class SDPipeline(torch.nn.Module):
|
||||
self.vae_encoder: SDVAEEncoder = None
|
||||
self.controlnet: MultiControlNetManager = None
|
||||
|
||||
|
||||
def fetch_main_models(self, model_manager: ModelManager):
|
||||
self.text_encoder = model_manager.text_encoder
|
||||
self.unet = model_manager.unet
|
||||
@@ -31,13 +34,48 @@ class SDPipeline(torch.nn.Module):
|
||||
# load textual inversion
|
||||
self.prompter.load_textual_inversion(model_manager.textual_inversion_dict)
|
||||
|
||||
def fetch_controlnet_models(self, controlnet_units=[]):
|
||||
|
||||
def fetch_controlnet_models(self, model_manager: ModelManager, controlnet_config_units: List[ControlNetConfigUnit]=[]):
|
||||
controlnet_units = []
|
||||
for config in controlnet_config_units:
|
||||
controlnet_unit = ControlNetUnit(
|
||||
Annotator(config.processor_id),
|
||||
model_manager.get_model_with_model_path(config.model_path),
|
||||
config.scale
|
||||
)
|
||||
controlnet_units.append(controlnet_unit)
|
||||
self.controlnet = MultiControlNetManager(controlnet_units)
|
||||
|
||||
|
||||
def fetch_beautiful_prompt(self, model_manager: ModelManager):
|
||||
if "beautiful_prompt" in model_manager.model:
|
||||
self.prompter.load_beautiful_prompt(model_manager.model["beautiful_prompt"], model_manager.model_path["beautiful_prompt"])
|
||||
|
||||
|
||||
@staticmethod
|
||||
def from_model_manager(model_manager: ModelManager, controlnet_config_units: List[ControlNetConfigUnit]=[]):
|
||||
pipe = SDImagePipeline(
|
||||
device=model_manager.device,
|
||||
torch_dtype=model_manager.torch_dtype,
|
||||
)
|
||||
pipe.fetch_main_models(model_manager)
|
||||
pipe.fetch_beautiful_prompt(model_manager)
|
||||
pipe.fetch_controlnet_models(model_manager, controlnet_config_units)
|
||||
return pipe
|
||||
|
||||
|
||||
def preprocess_image(self, image):
|
||||
image = torch.Tensor(np.array(image, dtype=np.float32) * (2 / 255) - 1).permute(2, 0, 1).unsqueeze(0)
|
||||
return image
|
||||
|
||||
|
||||
def decode_image(self, latent, tiled=False, tile_size=64, tile_stride=32):
|
||||
image = self.vae_decoder(latent.to(self.device), tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)[0]
|
||||
image = image.cpu().permute(1, 2, 0).numpy()
|
||||
image = Image.fromarray(((image / 2 + 0.5).clip(0, 1) * 255).astype("uint8"))
|
||||
return image
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(
|
||||
self,
|
||||
@@ -45,7 +83,7 @@ class SDPipeline(torch.nn.Module):
|
||||
negative_prompt="",
|
||||
cfg_scale=7.5,
|
||||
clip_skip=1,
|
||||
init_image=None,
|
||||
input_image=None,
|
||||
controlnet_image=None,
|
||||
denoising_strength=1.0,
|
||||
height=512,
|
||||
@@ -57,48 +95,43 @@ class SDPipeline(torch.nn.Module):
|
||||
progress_bar_cmd=tqdm,
|
||||
progress_bar_st=None,
|
||||
):
|
||||
# Encode prompts
|
||||
prompt_emb_posi = self.prompter.encode_prompt(self.text_encoder, prompt, clip_skip=clip_skip, device=self.device)
|
||||
prompt_emb_nega = self.prompter.encode_prompt(self.text_encoder, negative_prompt, clip_skip=clip_skip, device=self.device)
|
||||
|
||||
# Prepare scheduler
|
||||
self.scheduler.set_timesteps(num_inference_steps, denoising_strength)
|
||||
|
||||
# Prepare latent tensors
|
||||
if init_image is not None:
|
||||
image = self.preprocess_image(init_image).to(device=self.device, dtype=self.torch_dtype)
|
||||
if input_image is not None:
|
||||
image = self.preprocess_image(input_image).to(device=self.device, dtype=self.torch_dtype)
|
||||
latents = self.vae_encoder(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
|
||||
noise = torch.randn((1, 4, height//8, width//8), device=self.device, dtype=self.torch_dtype)
|
||||
latents = self.scheduler.add_noise(latents, noise, timestep=self.scheduler.timesteps[0])
|
||||
else:
|
||||
latents = torch.randn((1, 4, height//8, width//8), device=self.device, dtype=self.torch_dtype)
|
||||
|
||||
# Encode prompts
|
||||
prompt_emb_posi = self.prompter.encode_prompt(self.text_encoder, prompt, clip_skip=clip_skip, device=self.device, positive=True)
|
||||
prompt_emb_nega = self.prompter.encode_prompt(self.text_encoder, negative_prompt, clip_skip=clip_skip, device=self.device, positive=False)
|
||||
|
||||
# Prepare ControlNets
|
||||
if controlnet_image is not None:
|
||||
controlnet_image = self.controlnet.process_image(controlnet_image).to(device=self.device, dtype=self.torch_dtype)
|
||||
controlnet_image = controlnet_image.unsqueeze(1)
|
||||
|
||||
# Denoise
|
||||
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
|
||||
timestep = torch.IntTensor((timestep,))[0].to(self.device)
|
||||
|
||||
# ControlNet
|
||||
if controlnet_image is not None:
|
||||
additional_res_stack_posi = self.controlnet(latents, timestep, prompt_emb_posi, controlnet_image)
|
||||
additional_res_stack_nega = self.controlnet(latents, timestep, prompt_emb_nega, controlnet_image)
|
||||
else:
|
||||
additional_res_stack_posi = None
|
||||
additional_res_stack_nega = None
|
||||
|
||||
# Classifier-free guidance
|
||||
noise_pred_posi = self.unet(
|
||||
latents, timestep, prompt_emb_posi,
|
||||
additional_res_stack=additional_res_stack_posi,
|
||||
tiled=tiled, tile_size=tile_size, tile_stride=tile_stride
|
||||
noise_pred_posi = lets_dance(
|
||||
self.unet, motion_modules=None, controlnet=self.controlnet,
|
||||
sample=latents, timestep=timestep, encoder_hidden_states=prompt_emb_posi, controlnet_frames=controlnet_image,
|
||||
tiled=tiled, tile_size=tile_size, tile_stride=tile_stride,
|
||||
device=self.device, vram_limit_level=0
|
||||
)
|
||||
noise_pred_nega = self.unet(
|
||||
latents, timestep, prompt_emb_nega,
|
||||
additional_res_stack=additional_res_stack_nega,
|
||||
tiled=tiled, tile_size=tile_size, tile_stride=tile_stride
|
||||
noise_pred_nega = lets_dance(
|
||||
self.unet, motion_modules=None, controlnet=self.controlnet,
|
||||
sample=latents, timestep=timestep, encoder_hidden_states=prompt_emb_nega, controlnet_frames=controlnet_image,
|
||||
tiled=tiled, tile_size=tile_size, tile_stride=tile_stride,
|
||||
device=self.device, vram_limit_level=0
|
||||
)
|
||||
noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
|
||||
|
||||
@@ -110,8 +143,6 @@ class SDPipeline(torch.nn.Module):
|
||||
progress_bar_st.progress(progress_id / len(self.scheduler.timesteps))
|
||||
|
||||
# Decode image
|
||||
image = self.vae_decoder(latents, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)[0]
|
||||
image = image.cpu().permute(1, 2, 0).numpy()
|
||||
image = Image.fromarray(((image / 2 + 0.5).clip(0, 1) * 255).astype("uint8"))
|
||||
image = self.decode_image(latents, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
|
||||
|
||||
return image
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
from ..models import ModelManager, SDTextEncoder, SDUNet, SDVAEDecoder, SDVAEEncoder, SDMotionModel
|
||||
from ..models.sd_unet import PushBlock, PopBlock
|
||||
from ..controlnets import MultiControlNetManager, ControlNetUnit, ControlNetConfigUnit, Annotator
|
||||
from ..prompts import SDPrompter
|
||||
from ..schedulers import EnhancedDDIMScheduler
|
||||
from .dancer import lets_dance
|
||||
from typing import List
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
@@ -10,97 +10,6 @@ from PIL import Image
|
||||
import numpy as np
|
||||
|
||||
|
||||
def lets_dance(
|
||||
unet: SDUNet,
|
||||
motion_modules: SDMotionModel = None,
|
||||
controlnet: MultiControlNetManager = None,
|
||||
sample = None,
|
||||
timestep = None,
|
||||
encoder_hidden_states = None,
|
||||
controlnet_frames = None,
|
||||
unet_batch_size = 1,
|
||||
controlnet_batch_size = 1,
|
||||
device = "cuda",
|
||||
vram_limit_level = 0,
|
||||
):
|
||||
# 1. ControlNet
|
||||
# This part will be repeated on overlapping frames if animatediff_batch_size > animatediff_stride.
|
||||
# I leave it here because I intend to do something interesting on the ControlNets.
|
||||
controlnet_insert_block_id = 30
|
||||
if controlnet is not None and controlnet_frames is not None:
|
||||
res_stacks = []
|
||||
# process controlnet frames with batch
|
||||
for batch_id in range(0, sample.shape[0], controlnet_batch_size):
|
||||
batch_id_ = min(batch_id + controlnet_batch_size, sample.shape[0])
|
||||
res_stack = controlnet(
|
||||
sample[batch_id: batch_id_],
|
||||
timestep,
|
||||
encoder_hidden_states[batch_id: batch_id_],
|
||||
controlnet_frames[:, batch_id: batch_id_]
|
||||
)
|
||||
if vram_limit_level >= 1:
|
||||
res_stack = [res.cpu() for res in res_stack]
|
||||
res_stacks.append(res_stack)
|
||||
# concat the residual
|
||||
additional_res_stack = []
|
||||
for i in range(len(res_stacks[0])):
|
||||
res = torch.concat([res_stack[i] for res_stack in res_stacks], dim=0)
|
||||
additional_res_stack.append(res)
|
||||
else:
|
||||
additional_res_stack = None
|
||||
|
||||
# 2. time
|
||||
time_emb = unet.time_proj(timestep[None]).to(sample.dtype)
|
||||
time_emb = unet.time_embedding(time_emb)
|
||||
|
||||
# 3. pre-process
|
||||
hidden_states = unet.conv_in(sample)
|
||||
text_emb = encoder_hidden_states
|
||||
res_stack = [hidden_states.cpu() if vram_limit_level>=1 else hidden_states]
|
||||
|
||||
# 4. blocks
|
||||
for block_id, block in enumerate(unet.blocks):
|
||||
# 4.1 UNet
|
||||
if isinstance(block, PushBlock):
|
||||
hidden_states, time_emb, text_emb, res_stack = block(hidden_states, time_emb, text_emb, res_stack)
|
||||
if vram_limit_level>=1:
|
||||
res_stack[-1] = res_stack[-1].cpu()
|
||||
elif isinstance(block, PopBlock):
|
||||
if vram_limit_level>=1:
|
||||
res_stack[-1] = res_stack[-1].to(device)
|
||||
hidden_states, time_emb, text_emb, res_stack = block(hidden_states, time_emb, text_emb, res_stack)
|
||||
else:
|
||||
hidden_states_input = hidden_states
|
||||
hidden_states_output = []
|
||||
for batch_id in range(0, sample.shape[0], unet_batch_size):
|
||||
batch_id_ = min(batch_id + unet_batch_size, sample.shape[0])
|
||||
hidden_states, _, _, _ = block(hidden_states_input[batch_id: batch_id_], time_emb, text_emb[batch_id: batch_id_], res_stack)
|
||||
hidden_states_output.append(hidden_states)
|
||||
hidden_states = torch.concat(hidden_states_output, dim=0)
|
||||
# 4.2 AnimateDiff
|
||||
if motion_modules is not None:
|
||||
if block_id in motion_modules.call_block_id:
|
||||
motion_module_id = motion_modules.call_block_id[block_id]
|
||||
hidden_states, time_emb, text_emb, res_stack = motion_modules.motion_modules[motion_module_id](
|
||||
hidden_states, time_emb, text_emb, res_stack,
|
||||
batch_size=1
|
||||
)
|
||||
# 4.3 ControlNet
|
||||
if block_id == controlnet_insert_block_id and additional_res_stack is not None:
|
||||
hidden_states += additional_res_stack.pop().to(device)
|
||||
if vram_limit_level>=1:
|
||||
res_stack = [(res.to(device) + additional_res.to(device)).cpu() for res, additional_res in zip(res_stack, additional_res_stack)]
|
||||
else:
|
||||
res_stack = [res + additional_res for res, additional_res in zip(res_stack, additional_res_stack)]
|
||||
|
||||
# 5. output
|
||||
hidden_states = unet.conv_norm_out(hidden_states)
|
||||
hidden_states = unet.conv_act(hidden_states)
|
||||
hidden_states = unet.conv_out(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
def lets_dance_with_long_video(
|
||||
unet: SDUNet,
|
||||
motion_modules: SDMotionModel = None,
|
||||
@@ -187,6 +96,11 @@ class SDVideoPipeline(torch.nn.Module):
|
||||
self.motion_modules = model_manager.motion_modules
|
||||
|
||||
|
||||
def fetch_beautiful_prompt(self, model_manager: ModelManager):
|
||||
if "beautiful_prompt" in model_manager.model:
|
||||
self.prompter.load_beautiful_prompt(model_manager.model["beautiful_prompt"], model_manager.model_path["beautiful_prompt"])
|
||||
|
||||
|
||||
@staticmethod
|
||||
def from_model_manager(model_manager: ModelManager, controlnet_config_units: List[ControlNetConfigUnit]=[]):
|
||||
pipe = SDVideoPipeline(
|
||||
@@ -196,6 +110,7 @@ class SDVideoPipeline(torch.nn.Module):
|
||||
)
|
||||
pipe.fetch_main_models(model_manager)
|
||||
pipe.fetch_motion_modules(model_manager)
|
||||
pipe.fetch_beautiful_prompt(model_manager)
|
||||
pipe.fetch_controlnet_models(model_manager, controlnet_config_units)
|
||||
return pipe
|
||||
|
||||
@@ -248,12 +163,6 @@ class SDVideoPipeline(torch.nn.Module):
|
||||
progress_bar_cmd=tqdm,
|
||||
progress_bar_st=None,
|
||||
):
|
||||
# Encode prompts
|
||||
prompt_emb_posi = self.prompter.encode_prompt(self.text_encoder, prompt, clip_skip=clip_skip, device=self.device).cpu()
|
||||
prompt_emb_nega = self.prompter.encode_prompt(self.text_encoder, negative_prompt, clip_skip=clip_skip, device=self.device).cpu()
|
||||
prompt_emb_posi = prompt_emb_posi.repeat(num_frames, 1, 1)
|
||||
prompt_emb_nega = prompt_emb_nega.repeat(num_frames, 1, 1)
|
||||
|
||||
# Prepare scheduler
|
||||
self.scheduler.set_timesteps(num_inference_steps, denoising_strength)
|
||||
|
||||
@@ -265,6 +174,12 @@ class SDVideoPipeline(torch.nn.Module):
|
||||
latents = self.encode_images(input_frames)
|
||||
latents = self.scheduler.add_noise(latents, noise, timestep=self.scheduler.timesteps[0])
|
||||
|
||||
# Encode prompts
|
||||
prompt_emb_posi = self.prompter.encode_prompt(self.text_encoder, prompt, clip_skip=clip_skip, device=self.device, positive=True).cpu()
|
||||
prompt_emb_nega = self.prompter.encode_prompt(self.text_encoder, negative_prompt, clip_skip=clip_skip, device=self.device, positive=False).cpu()
|
||||
prompt_emb_posi = prompt_emb_posi.repeat(num_frames, 1, 1)
|
||||
prompt_emb_nega = prompt_emb_nega.repeat(num_frames, 1, 1)
|
||||
|
||||
# Prepare ControlNets
|
||||
if controlnet_frames is not None:
|
||||
controlnet_frames = torch.stack([
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from ..models import ModelManager
|
||||
from ..models import ModelManager, SDXLTextEncoder, SDXLTextEncoder2, SDXLUNet, SDXLVAEDecoder, SDXLVAEEncoder
|
||||
# TODO: SDXL ControlNet
|
||||
from ..prompts import SDXLPrompter
|
||||
from ..schedulers import EnhancedDDIMScheduler
|
||||
import torch
|
||||
@@ -7,29 +8,77 @@ from PIL import Image
|
||||
import numpy as np
|
||||
|
||||
|
||||
class SDXLPipeline(torch.nn.Module):
|
||||
class SDXLImagePipeline(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self, device="cuda", torch_dtype=torch.float16):
|
||||
super().__init__()
|
||||
self.scheduler = EnhancedDDIMScheduler()
|
||||
self.prompter = SDXLPrompter()
|
||||
self.device = device
|
||||
self.torch_dtype = torch_dtype
|
||||
# models
|
||||
self.text_encoder: SDXLTextEncoder = None
|
||||
self.text_encoder_2: SDXLTextEncoder2 = None
|
||||
self.unet: SDXLUNet = None
|
||||
self.vae_decoder: SDXLVAEDecoder = None
|
||||
self.vae_encoder: SDXLVAEEncoder = None
|
||||
# TODO: SDXL ControlNet
|
||||
|
||||
def fetch_main_models(self, model_manager: ModelManager):
|
||||
self.text_encoder = model_manager.text_encoder
|
||||
self.text_encoder_2 = model_manager.text_encoder_2
|
||||
self.unet = model_manager.unet
|
||||
self.vae_decoder = model_manager.vae_decoder
|
||||
self.vae_encoder = model_manager.vae_encoder
|
||||
# load textual inversion
|
||||
self.prompter.load_textual_inversion(model_manager.textual_inversion_dict)
|
||||
|
||||
|
||||
def fetch_controlnet_models(self, model_manager: ModelManager, **kwargs):
|
||||
# TODO: SDXL ControlNet
|
||||
pass
|
||||
|
||||
|
||||
def fetch_beautiful_prompt(self, model_manager: ModelManager):
|
||||
if "beautiful_prompt" in model_manager.model:
|
||||
self.prompter.load_beautiful_prompt(model_manager.model["beautiful_prompt"], model_manager.model_path["beautiful_prompt"])
|
||||
|
||||
|
||||
@staticmethod
|
||||
def from_model_manager(model_manager: ModelManager, controlnet_config_units = [], **kwargs):
|
||||
pipe = SDXLImagePipeline(
|
||||
device=model_manager.device,
|
||||
torch_dtype=model_manager.torch_dtype,
|
||||
)
|
||||
pipe.fetch_main_models(model_manager)
|
||||
pipe.fetch_beautiful_prompt(model_manager)
|
||||
pipe.fetch_controlnet_models(model_manager, controlnet_config_units=controlnet_config_units)
|
||||
return pipe
|
||||
|
||||
|
||||
def preprocess_image(self, image):
|
||||
image = torch.Tensor(np.array(image, dtype=np.float32) * (2 / 255) - 1).permute(2, 0, 1).unsqueeze(0)
|
||||
return image
|
||||
|
||||
|
||||
def decode_image(self, latent, tiled=False, tile_size=64, tile_stride=32):
|
||||
image = self.vae_decoder(latent.to(self.device), tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)[0]
|
||||
image = image.cpu().permute(1, 2, 0).numpy()
|
||||
image = Image.fromarray(((image / 2 + 0.5).clip(0, 1) * 255).astype("uint8"))
|
||||
return image
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(
|
||||
self,
|
||||
model_manager: ModelManager,
|
||||
prompter: SDXLPrompter,
|
||||
prompt,
|
||||
negative_prompt="",
|
||||
cfg_scale=7.5,
|
||||
clip_skip=1,
|
||||
clip_skip_2=2,
|
||||
init_image=None,
|
||||
input_image=None,
|
||||
controlnet_image=None,
|
||||
denoising_strength=1.0,
|
||||
refining_strength=0.0,
|
||||
height=1024,
|
||||
width=1024,
|
||||
num_inference_steps=20,
|
||||
@@ -39,76 +88,62 @@ class SDXLPipeline(torch.nn.Module):
|
||||
progress_bar_cmd=tqdm,
|
||||
progress_bar_st=None,
|
||||
):
|
||||
# Encode prompts
|
||||
add_text_embeds, prompt_emb = prompter.encode_prompt(
|
||||
model_manager.text_encoder,
|
||||
model_manager.text_encoder_2,
|
||||
prompt,
|
||||
clip_skip=clip_skip, clip_skip_2=clip_skip_2,
|
||||
device=model_manager.device
|
||||
)
|
||||
if cfg_scale != 1.0:
|
||||
negative_add_text_embeds, negative_prompt_emb = prompter.encode_prompt(
|
||||
model_manager.text_encoder,
|
||||
model_manager.text_encoder_2,
|
||||
negative_prompt,
|
||||
clip_skip=clip_skip, clip_skip_2=clip_skip_2,
|
||||
device=model_manager.device
|
||||
)
|
||||
|
||||
# Prepare scheduler
|
||||
self.scheduler.set_timesteps(num_inference_steps, denoising_strength)
|
||||
|
||||
# Prepare latent tensors
|
||||
if init_image is not None:
|
||||
image = self.preprocess_image(init_image).to(
|
||||
device=model_manager.device, dtype=model_manager.torch_type
|
||||
)
|
||||
latents = model_manager.vae_encoder(
|
||||
image.to(torch.float32),
|
||||
tiled=tiled, tile_size=tile_size, tile_stride=tile_stride
|
||||
)
|
||||
noise = torch.randn(
|
||||
(1, 4, height//8, width//8),
|
||||
device=model_manager.device, dtype=model_manager.torch_type
|
||||
)
|
||||
latents = self.scheduler.add_noise(
|
||||
latents.to(model_manager.torch_type),
|
||||
noise,
|
||||
timestep=self.scheduler.timesteps[0]
|
||||
)
|
||||
if input_image is not None:
|
||||
image = self.preprocess_image(input_image).to(device=self.device, dtype=self.torch_dtype)
|
||||
latents = self.vae_encoder(image.to(torch.float32), tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(self.torch_dtype)
|
||||
noise = torch.randn((1, 4, height//8, width//8), device=self.device, dtype=self.torch_dtype)
|
||||
latents = self.scheduler.add_noise(latents, noise, timestep=self.scheduler.timesteps[0])
|
||||
else:
|
||||
latents = torch.randn((1, 4, height//8, width//8), device=model_manager.device, dtype=model_manager.torch_type)
|
||||
latents = torch.randn((1, 4, height//8, width//8), device=self.device, dtype=self.torch_dtype)
|
||||
|
||||
# Encode prompts
|
||||
add_prompt_emb_posi, prompt_emb_posi = self.prompter.encode_prompt(
|
||||
self.text_encoder,
|
||||
self.text_encoder_2,
|
||||
prompt,
|
||||
clip_skip=clip_skip, clip_skip_2=clip_skip_2,
|
||||
device=self.device
|
||||
)
|
||||
if cfg_scale != 1.0:
|
||||
add_prompt_emb_nega, prompt_emb_nega = self.prompter.encode_prompt(
|
||||
self.text_encoder,
|
||||
self.text_encoder_2,
|
||||
negative_prompt,
|
||||
clip_skip=clip_skip, clip_skip_2=clip_skip_2,
|
||||
device=self.device
|
||||
)
|
||||
|
||||
# Prepare scheduler
|
||||
self.scheduler.set_timesteps(num_inference_steps, denoising_strength)
|
||||
|
||||
# Prepare positional id
|
||||
add_time_id = torch.tensor([height, width, 0, 0, height, width], device=model_manager.device)
|
||||
add_time_id = torch.tensor([height, width, 0, 0, height, width], device=self.device)
|
||||
|
||||
# Denoise
|
||||
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
|
||||
timestep = torch.IntTensor((timestep,))[0].to(model_manager.device)
|
||||
timestep = torch.IntTensor((timestep,))[0].to(self.device)
|
||||
|
||||
# Classifier-free guidance
|
||||
if timestep >= 1000 * refining_strength:
|
||||
denoising_model = model_manager.unet
|
||||
else:
|
||||
denoising_model = model_manager.refiner
|
||||
|
||||
if cfg_scale != 1.0:
|
||||
noise_pred_cond = denoising_model(
|
||||
latents, timestep, prompt_emb,
|
||||
add_time_id=add_time_id, add_text_embeds=add_text_embeds,
|
||||
noise_pred_posi = self.unet(
|
||||
latents, timestep, prompt_emb_posi,
|
||||
add_time_id=add_time_id, add_text_embeds=add_prompt_emb_posi,
|
||||
tiled=tiled, tile_size=tile_size, tile_stride=tile_stride
|
||||
)
|
||||
noise_pred_uncond = denoising_model(
|
||||
latents, timestep, negative_prompt_emb,
|
||||
add_time_id=add_time_id, add_text_embeds=negative_add_text_embeds,
|
||||
noise_pred_nega = self.unet(
|
||||
latents, timestep, prompt_emb_nega,
|
||||
add_time_id=add_time_id, add_text_embeds=add_prompt_emb_nega,
|
||||
tiled=tiled, tile_size=tile_size, tile_stride=tile_stride
|
||||
)
|
||||
noise_pred = noise_pred_uncond + cfg_scale * (noise_pred_cond - noise_pred_uncond)
|
||||
noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
|
||||
else:
|
||||
noise_pred = denoising_model(
|
||||
latents, timestep, prompt_emb,
|
||||
add_time_id=add_time_id, add_text_embeds=add_text_embeds,
|
||||
noise_pred = self.unet(
|
||||
latents, timestep, prompt_emb_posi,
|
||||
add_time_id=add_time_id, add_text_embeds=add_prompt_emb_posi,
|
||||
tiled=tiled, tile_size=tile_size, tile_stride=tile_stride
|
||||
)
|
||||
|
||||
@@ -118,9 +153,6 @@ class SDXLPipeline(torch.nn.Module):
|
||||
progress_bar_st.progress(progress_id / len(self.scheduler.timesteps))
|
||||
|
||||
# Decode image
|
||||
latents = latents.to(torch.float32)
|
||||
image = model_manager.vae_decoder(latents, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)[0]
|
||||
image = image.cpu().permute(1, 2, 0).numpy()
|
||||
image = Image.fromarray(((image / 2 + 0.5).clip(0, 1) * 255).astype("uint8"))
|
||||
image = self.decode_image(latents.to(torch.float32), tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
|
||||
|
||||
return image
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from transformers import CLIPTokenizer
|
||||
from ..models import SDTextEncoder, SDXLTextEncoder, SDXLTextEncoder2, load_state_dict
|
||||
from transformers import CLIPTokenizer, AutoTokenizer
|
||||
from ..models import SDTextEncoder, SDXLTextEncoder, SDXLTextEncoder2
|
||||
import torch, os
|
||||
|
||||
|
||||
@@ -35,14 +35,30 @@ def tokenize_long_prompt(tokenizer, prompt):
|
||||
return input_ids
|
||||
|
||||
|
||||
def search_for_embeddings(state_dict):
|
||||
embeddings = []
|
||||
for k in state_dict:
|
||||
if isinstance(state_dict[k], torch.Tensor):
|
||||
embeddings.append(state_dict[k])
|
||||
elif isinstance(state_dict[k], dict):
|
||||
embeddings += search_for_embeddings(state_dict[k])
|
||||
return embeddings
|
||||
class BeautifulPrompt:
|
||||
def __init__(self, tokenizer_path="configs/beautiful_prompt/tokenizer", model=None):
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
|
||||
self.model = model
|
||||
self.template = 'Instruction: Give a simple description of the image to generate a drawing prompt.\nInput: {raw_prompt}\nOutput:'
|
||||
|
||||
def __call__(self, raw_prompt):
|
||||
model_input = self.template.format(raw_prompt=raw_prompt)
|
||||
input_ids = self.tokenizer.encode(model_input, return_tensors='pt').to(self.model.device)
|
||||
outputs = self.model.generate(
|
||||
input_ids,
|
||||
max_new_tokens=384,
|
||||
do_sample=True,
|
||||
temperature=0.9,
|
||||
top_k=50,
|
||||
top_p=0.95,
|
||||
repetition_penalty=1.1,
|
||||
num_return_sequences=1
|
||||
)
|
||||
prompt = raw_prompt + ", " + self.tokenizer.batch_decode(
|
||||
outputs[:, input_ids.size(1):],
|
||||
skip_special_tokens=True
|
||||
)[0].strip()
|
||||
return prompt
|
||||
|
||||
|
||||
class SDPrompter:
|
||||
@@ -50,11 +66,20 @@ class SDPrompter:
|
||||
# We use the tokenizer implemented by transformers
|
||||
self.tokenizer = CLIPTokenizer.from_pretrained(tokenizer_path)
|
||||
self.keyword_dict = {}
|
||||
self.beautiful_prompt: BeautifulPrompt = None
|
||||
|
||||
def encode_prompt(self, text_encoder: SDTextEncoder, prompt, clip_skip=1, device="cuda"):
|
||||
|
||||
def encode_prompt(self, text_encoder: SDTextEncoder, prompt, clip_skip=1, device="cuda", positive=True):
|
||||
# Textual Inversion
|
||||
for keyword in self.keyword_dict:
|
||||
if keyword in prompt:
|
||||
prompt = prompt.replace(keyword, self.keyword_dict[keyword])
|
||||
|
||||
# Beautiful Prompt
|
||||
if positive and self.beautiful_prompt is not None:
|
||||
prompt = self.beautiful_prompt(prompt)
|
||||
print(f"Your prompt is refined by BeautifulPrompt: \"{prompt}\"")
|
||||
|
||||
input_ids = tokenize_long_prompt(self.tokenizer, prompt).to(device)
|
||||
prompt_emb = text_encoder(input_ids, clip_skip=clip_skip)
|
||||
prompt_emb = prompt_emb.reshape((1, prompt_emb.shape[0]*prompt_emb.shape[1], -1))
|
||||
@@ -70,6 +95,16 @@ class SDPrompter:
|
||||
self.keyword_dict[keyword] = " " + " ".join(tokens) + " "
|
||||
self.tokenizer.add_tokens(additional_tokens)
|
||||
|
||||
def load_beautiful_prompt(self, model, model_path):
|
||||
model_folder = os.path.dirname(model_path)
|
||||
self.beautiful_prompt = BeautifulPrompt(tokenizer_path=model_folder, model=model)
|
||||
if model_folder.endswith("v2"):
|
||||
self.beautiful_prompt.template = """Converts a simple image description into a prompt. \
|
||||
Prompts are formatted as multiple related tags separated by commas, plus you can use () to increase the weight, [] to decrease the weight, \
|
||||
or use a number to specify the weight. You should add appropriate words to make the images described in the prompt more aesthetically pleasing, \
|
||||
but make sure there is a correlation between the input and output.\n\
|
||||
### Input: {raw_prompt}\n### Output:"""
|
||||
|
||||
|
||||
class SDXLPrompter:
|
||||
def __init__(
|
||||
@@ -80,6 +115,8 @@ class SDXLPrompter:
|
||||
# We use the tokenizer implemented by transformers
|
||||
self.tokenizer = CLIPTokenizer.from_pretrained(tokenizer_path)
|
||||
self.tokenizer_2 = CLIPTokenizer.from_pretrained(tokenizer_2_path)
|
||||
self.keyword_dict = {}
|
||||
self.beautiful_prompt: BeautifulPrompt = None
|
||||
|
||||
def encode_prompt(
|
||||
self,
|
||||
@@ -88,8 +125,19 @@ class SDXLPrompter:
|
||||
prompt,
|
||||
clip_skip=1,
|
||||
clip_skip_2=2,
|
||||
positive=True,
|
||||
device="cuda"
|
||||
):
|
||||
# Textual Inversion
|
||||
for keyword in self.keyword_dict:
|
||||
if keyword in prompt:
|
||||
prompt = prompt.replace(keyword, self.keyword_dict[keyword])
|
||||
|
||||
# Beautiful Prompt
|
||||
if positive and self.beautiful_prompt is not None:
|
||||
prompt = self.beautiful_prompt(prompt)
|
||||
print(f"Your prompt is refined by BeautifulPrompt: \"{prompt}\"")
|
||||
|
||||
# 1
|
||||
input_ids = tokenize_long_prompt(self.tokenizer, prompt).to(device)
|
||||
prompt_emb_1 = text_encoder(input_ids, clip_skip=clip_skip)
|
||||
@@ -105,3 +153,22 @@ class SDXLPrompter:
|
||||
add_text_embeds = add_text_embeds[0:1]
|
||||
prompt_emb = prompt_emb.reshape((1, prompt_emb.shape[0]*prompt_emb.shape[1], -1))
|
||||
return add_text_embeds, prompt_emb
|
||||
|
||||
def load_textual_inversion(self, textual_inversion_dict):
|
||||
self.keyword_dict = {}
|
||||
additional_tokens = []
|
||||
for keyword in textual_inversion_dict:
|
||||
tokens, _ = textual_inversion_dict[keyword]
|
||||
additional_tokens += tokens
|
||||
self.keyword_dict[keyword] = " " + " ".join(tokens) + " "
|
||||
self.tokenizer.add_tokens(additional_tokens)
|
||||
|
||||
def load_beautiful_prompt(self, model, model_path):
|
||||
model_folder = os.path.dirname(model_path)
|
||||
self.beautiful_prompt = BeautifulPrompt(tokenizer_path=model_folder, model=model)
|
||||
if model_folder.endswith("v2"):
|
||||
self.beautiful_prompt.template = """Converts a simple image description into a prompt. \
|
||||
Prompts are formatted as multiple related tags separated by commas, plus you can use () to increase the weight, [] to decrease the weight, \
|
||||
or use a number to specify the weight. You should add appropriate words to make the images described in the prompt more aesthetically pleasing, \
|
||||
but make sure there is a correlation between the input and output.\n\
|
||||
### Input: {raw_prompt}\n### Output:"""
|
||||
|
||||
@@ -1,37 +0,0 @@
|
||||
from transformers import CLIPTokenizer
|
||||
|
||||
class SDTokenizer:
|
||||
def __init__(self, tokenizer_path="configs/stable_diffusion/tokenizer"):
|
||||
# We use the tokenizer implemented by transformers
|
||||
self.tokenizer = CLIPTokenizer.from_pretrained(tokenizer_path)
|
||||
|
||||
def __call__(self, prompt):
|
||||
# Get model_max_length from self.tokenizer
|
||||
length = self.tokenizer.model_max_length
|
||||
|
||||
# To avoid the warning. set self.tokenizer.model_max_length to +oo.
|
||||
self.tokenizer.model_max_length = 99999999
|
||||
|
||||
# Tokenize it!
|
||||
input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids
|
||||
|
||||
# Determine the real length.
|
||||
max_length = (input_ids.shape[1] + length - 1) // length * length
|
||||
|
||||
# Restore self.tokenizer.model_max_length
|
||||
self.tokenizer.model_max_length = length
|
||||
|
||||
# Tokenize it again with fixed length.
|
||||
input_ids = self.tokenizer(
|
||||
prompt,
|
||||
return_tensors="pt",
|
||||
padding="max_length",
|
||||
max_length=max_length,
|
||||
truncation=True
|
||||
).input_ids
|
||||
|
||||
# Reshape input_ids to fit the text encoder.
|
||||
num_sentence = input_ids.shape[1] // length
|
||||
input_ids = input_ids.reshape((num_sentence, length))
|
||||
|
||||
return input_ids
|
||||
@@ -1,45 +0,0 @@
|
||||
from transformers import CLIPTokenizer
|
||||
from .sd_tokenizer import SDTokenizer
|
||||
|
||||
|
||||
class SDXLTokenizer(SDTokenizer):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
|
||||
class SDXLTokenizer2:
|
||||
def __init__(self, tokenizer_path="configs/stable_diffusion_xl/tokenizer_2"):
|
||||
# We use the tokenizer implemented by transformers
|
||||
self.tokenizer = CLIPTokenizer.from_pretrained(tokenizer_path)
|
||||
|
||||
def __call__(self, prompt):
|
||||
# Get model_max_length from self.tokenizer
|
||||
length = self.tokenizer.model_max_length
|
||||
|
||||
# To avoid the warning. set self.tokenizer.model_max_length to +oo.
|
||||
self.tokenizer.model_max_length = 99999999
|
||||
|
||||
# Tokenize it!
|
||||
input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids
|
||||
|
||||
# Determine the real length.
|
||||
max_length = (input_ids.shape[1] + length - 1) // length * length
|
||||
|
||||
# Restore self.tokenizer.model_max_length
|
||||
self.tokenizer.model_max_length = length
|
||||
|
||||
# Tokenize it again with fixed length.
|
||||
input_ids = self.tokenizer(
|
||||
prompt,
|
||||
return_tensors="pt",
|
||||
padding="max_length",
|
||||
max_length=max_length,
|
||||
truncation=True
|
||||
).input_ids
|
||||
|
||||
# Reshape input_ids to fit the text encoder.
|
||||
num_sentence = input_ids.shape[1] // length
|
||||
input_ids = input_ids.reshape((num_sentence, length))
|
||||
|
||||
return input_ids
|
||||
|
||||
@@ -18,7 +18,7 @@ class EnhancedDDIMScheduler():
|
||||
def set_timesteps(self, num_inference_steps, denoising_strength=1.0):
|
||||
# The timesteps are aligned to 999...0, which is different from other implementations,
|
||||
# but I think this implementation is more reasonable in theory.
|
||||
max_timestep = round(self.num_train_timesteps * denoising_strength) - 1
|
||||
max_timestep = max(round(self.num_train_timesteps * denoising_strength) - 1, 0)
|
||||
num_inference_steps = min(num_inference_steps, max_timestep + 1)
|
||||
if num_inference_steps == 1:
|
||||
self.timesteps = [max_timestep]
|
||||
|
||||
Reference in New Issue
Block a user