Mirror of https://github.com/modelscope/DiffSynth-Studio.git (synced 2026-03-22 08:40:47 +00:00)
compatibility update

@@ -15,6 +15,8 @@ from .sd_controlnet import SDControlNet
 
 from .sd_motion import SDMotionModel
 
+from transformers import AutoModelForCausalLM
+
 
 class ModelManager:
     def __init__(self, torch_dtype=torch.float16, device="cuda"):
@@ -24,12 +26,19 @@ class ModelManager:
         self.model_path = {}
         self.textual_inversion_dict = {}
 
+    def is_beautiful_prompt(self, state_dict):
+        param_name = "transformer.h.9.self_attention.query_key_value.weight"
+        return param_name in state_dict
+
     def is_stabe_diffusion_xl(self, state_dict):
         param_name = "conditioner.embedders.0.transformer.text_model.embeddings.position_embedding.weight"
         return param_name in state_dict
 
     def is_stable_diffusion(self, state_dict):
-        return True
+        if self.is_stabe_diffusion_xl(state_dict):
+            return False
+        param_name = "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.norm3.weight"
+        return param_name in state_dict
 
     def is_controlnet(self, state_dict):
         param_name = "control_model.time_embed.0.weight"

@@ -74,7 +83,6 @@ class ModelManager:
             "unet": SDXLUNet,
             "vae_decoder": SDXLVAEDecoder,
             "vae_encoder": SDXLVAEEncoder,
-            "refiner": SDXLUNet,
         }
         if components is None:
             components = ["text_encoder", "text_encoder_2", "unet", "vae_decoder", "vae_encoder"]
@@ -109,6 +117,15 @@ class ModelManager:
         self.model[component] = model
         self.model_path[component] = file_path
 
+    def load_beautiful_prompt(self, state_dict, file_path=""):
+        component = "beautiful_prompt"
+        model_folder = os.path.dirname(file_path)
+        model = AutoModelForCausalLM.from_pretrained(
+            model_folder, state_dict=state_dict, local_files_only=True, torch_dtype=self.torch_dtype
+        ).to(self.device).eval()
+        self.model[component] = model
+        self.model_path[component] = file_path
+
     def search_for_embeddings(self, state_dict):
         embeddings = []
         for k in state_dict:
@@ -144,6 +161,8 @@ class ModelManager:
             self.load_stable_diffusion_xl(state_dict, components=components, file_path=file_path)
         elif self.is_stable_diffusion(state_dict):
             self.load_stable_diffusion(state_dict, components=components, file_path=file_path)
+        elif self.is_beautiful_prompt(state_dict):
+            self.load_beautiful_prompt(state_dict, file_path=file_path)
 
     def load_models(self, file_path_list):
         for file_path in file_path_list:
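
load_model then dispatches on those probes in order: SDXL first, then SD, then the beautiful-prompt language model. A minimal usage sketch (the import path and checkpoint path are assumptions):

    from diffsynth.models import ModelManager  # assumed import path

    model_manager = ModelManager(torch_dtype=torch.float16, device="cuda")
    # Each file is probed with the is_* checks and routed to the matching loader.
    model_manager.load_models(["models/checkpoint.safetensors"])  # placeholder path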

@@ -41,7 +41,7 @@ class Attention(torch.nn.Module):
         v = v.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
 
         hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
-        hidden_states = hidden_states.transpose(1, 2).view(batch_size, -1, self.num_heads * self.head_dim)
+        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim)
         hidden_states = hidden_states.to(q.dtype)
 
         hidden_states = self.to_out(hidden_states)
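
Context for this one-line change: after transpose(1, 2) the tensor is no longer contiguous in memory, and Tensor.view raises a RuntimeError on non-contiguous input, while Tensor.reshape falls back to a copy when needed. A self-contained demonstration (shapes are illustrative):

    import torch

    x = torch.randn(2, 8, 16, 40)   # (batch, heads, seq_len, head_dim)
    t = x.transpose(1, 2)           # non-contiguous view of x
    # t.view(2, 16, 320)            # would raise: view size is not compatible ...
    y = t.reshape(2, 16, 320)       # copies if necessary, always succeeds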

@@ -1,5 +1,6 @@
 import torch
-from .sd_unet import Timesteps, ResnetBlock, AttentionBlock, PushBlock, PopBlock, DownSampler, UpSampler
+from .sd_unet import Timesteps, ResnetBlock, AttentionBlock, PushBlock, DownSampler
+from .tiler import TileWorker
 
 
 class ControlNetConditioningLayer(torch.nn.Module):
@@ -92,20 +93,37 @@ class SDControlNet(torch.nn.Module):
 
         self.global_pool = global_pool
 
-    def forward(self, sample, timestep, encoder_hidden_states, conditioning):
+    def forward(
+        self,
+        sample, timestep, encoder_hidden_states, conditioning,
+        tiled=False, tile_size=64, tile_stride=32,
+    ):
         # 1. time
         time_emb = self.time_proj(timestep[None]).to(sample.dtype)
         time_emb = self.time_embedding(time_emb)
         time_emb = time_emb.repeat(sample.shape[0], 1)
 
         # 2. pre-process
+        height, width = sample.shape[2], sample.shape[3]
         hidden_states = self.conv_in(sample) + self.controlnet_conv_in(conditioning)
         text_emb = encoder_hidden_states
         res_stack = [hidden_states]
 
         # 3. blocks
         for i, block in enumerate(self.blocks):
-            hidden_states, time_emb, text_emb, res_stack = block(hidden_states, time_emb, text_emb, res_stack)
+            if tiled and not isinstance(block, PushBlock):
+                _, _, inter_height, _ = hidden_states.shape
+                resize_scale = inter_height / height
+                hidden_states = TileWorker().tiled_forward(
+                    lambda x: block(x, time_emb, text_emb, res_stack)[0],
+                    hidden_states,
+                    int(tile_size * resize_scale),
+                    int(tile_stride * resize_scale),
+                    tile_device=hidden_states.device,
+                    tile_dtype=hidden_states.dtype
+                )
+            else:
+                hidden_states, _, _, _ = block(hidden_states, time_emb, text_emb, res_stack)
 
         # 4. ControlNet blocks
         controlnet_res_stack = [block(res) for block, res in zip(self.controlnet_blocks, res_stack)]
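
In the tiled path, each ResnetBlock/AttentionBlock/DownSampler is wrapped by TileWorker so the block only ever sees tile_size x tile_size feature maps; PushBlock just pushes hidden states onto res_stack, so it stays on the plain path. A hedged call sketch (the controlnet instance, shapes, and control-signal layout are illustrative assumptions, not repo code):

    sample = torch.randn(1, 4, 128, 128)          # latent at 1/8 image resolution
    timestep = torch.tensor(500)
    text_emb = torch.randn(1, 77, 768)
    conditioning = torch.randn(1, 3, 1024, 1024)  # assumed image-space control signal
    res_stack = controlnet(sample, timestep, text_emb, conditioning,
                           tiled=True, tile_size=64, tile_stride=32)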

@@ -279,31 +279,19 @@ class SDUNet(torch.nn.Module):
         self.conv_act = torch.nn.SiLU()
         self.conv_out = torch.nn.Conv2d(320, 4, kernel_size=3, padding=1)
 
-    def forward(self, sample, timestep, encoder_hidden_states, tiled=False, tile_size=64, tile_stride=8, additional_res_stack=None, **kwargs):
+    def forward(self, sample, timestep, encoder_hidden_states, **kwargs):
         # 1. time
         time_emb = self.time_proj(timestep[None]).to(sample.dtype)
         time_emb = self.time_embedding(time_emb)
         time_emb = time_emb.repeat(sample.shape[0], 1)
 
         # 2. pre-process
-        height, width = sample.shape[2], sample.shape[3]
         hidden_states = self.conv_in(sample)
         text_emb = encoder_hidden_states
         res_stack = [hidden_states]
 
         # 3. blocks
         for i, block in enumerate(self.blocks):
-            if additional_res_stack is not None and i==31:
-                hidden_states += additional_res_stack.pop()
-                res_stack = [res + additional_res for res, additional_res in zip(res_stack, additional_res_stack)]
-                additional_res_stack = None
-            if tiled:
-                hidden_states, time_emb, text_emb, res_stack = self.tiled_inference(
-                    block, hidden_states, time_emb, text_emb, res_stack,
-                    height, width, tile_size, tile_stride
-                )
-            else:
-                hidden_states, time_emb, text_emb, res_stack = block(hidden_states, time_emb, text_emb, res_stack)
+            hidden_states, time_emb, text_emb, res_stack = block(hidden_states, time_emb, text_emb, res_stack)
 
         # 4. output
         hidden_states = self.conv_norm_out(hidden_states)
@@ -312,23 +300,6 @@ class SDUNet(torch.nn.Module):
 
         return hidden_states
 
-    def tiled_inference(self, block, hidden_states, time_emb, text_emb, res_stack, height, width, tile_size, tile_stride):
-        if block.__class__.__name__ in ["ResnetBlock", "AttentionBlock", "DownSampler", "UpSampler"]:
-            batch_size, inter_channel, inter_height, inter_width = hidden_states.shape
-            resize_scale = inter_height / height
-
-            hidden_states = Tiler()(
-                lambda x: block(x, time_emb, text_emb, res_stack)[0],
-                hidden_states,
-                int(tile_size * resize_scale),
-                int(tile_stride * resize_scale),
-                inter_device=hidden_states.device,
-                inter_dtype=hidden_states.dtype
-            )
-        else:
-            hidden_states, time_emb, text_emb, res_stack = block(hidden_states, time_emb, text_emb, res_stack)
-        return hidden_states, time_emb, text_emb, res_stack
-
     def state_dict_converter(self):
         return SDUNetStateDictConverter()
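
The tiling and additional-residual logic is removed from SDUNet, but the signature keeps **kwargs, so existing call sites that still pass tiled / tile_size / tile_stride keep working while those arguments are silently ignored. A minimal illustration of that pattern (not repo code):

    def forward(sample, timestep, encoder_hidden_states, **kwargs):
        # kwargs absorbs now-unused arguments such as tiled=True
        return sample

    forward("x", 0, "emb", tiled=True, tile_size=64, tile_stride=8)  # still accepted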

@@ -1,7 +1,7 @@
 import torch
 from .attention import Attention
 from .sd_unet import ResnetBlock, UpSampler
-from .tiler import Tiler
+from .tiler import TileWorker
 
 
 class VAEAttentionBlock(torch.nn.Module):
@@ -79,11 +79,13 @@ class SDVAEDecoder(torch.nn.Module):
         self.conv_out = torch.nn.Conv2d(128, 3, kernel_size=3, padding=1)
 
     def tiled_forward(self, sample, tile_size=64, tile_stride=32):
-        hidden_states = Tiler()(
+        hidden_states = TileWorker().tiled_forward(
            lambda x: self.forward(x),
            sample,
            tile_size,
-           tile_stride
+           tile_stride,
+           tile_device=sample.device,
+           tile_dtype=sample.dtype
        )
        return hidden_states
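
Both VAE codecs now delegate tiling to the shared TileWorker, and the new tile_device / tile_dtype arguments keep the full tile set on the sample's own device and dtype while each tile runs through the model. A hedged usage sketch (construction and shapes are assumptions):

    decoder = SDVAEDecoder()                 # weights omitted in this sketch
    latents = torch.randn(1, 4, 128, 128)    # latent for a 1024x1024 image
    image = decoder.tiled_forward(latents, tile_size=64, tile_stride=32)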

@@ -1,7 +1,7 @@
 import torch
 from .sd_unet import ResnetBlock, DownSampler
 from .sd_vae_decoder import VAEAttentionBlock
-from .tiler import Tiler
+from .tiler import TileWorker
 
 
 class SDVAEEncoder(torch.nn.Module):
@@ -38,11 +38,13 @@ class SDVAEEncoder(torch.nn.Module):
         self.conv_out = torch.nn.Conv2d(512, 8, kernel_size=3, padding=1)
 
     def tiled_forward(self, sample, tile_size=64, tile_stride=32):
-        hidden_states = Tiler()(
+        hidden_states = TileWorker().tiled_forward(
            lambda x: self.forward(x),
            sample,
            tile_size,
-           tile_stride
+           tile_stride,
+           tile_device=sample.device,
+           tile_dtype=sample.dtype
        )
        return hidden_states

@@ -1,4 +1,5 @@
 import torch
+from einops import rearrange, repeat
 
 
 class Tiler(torch.nn.Module):
@@ -70,6 +71,106 @@ class Tiler(torch.nn.Module):
 
         return x
 
 
+class TileWorker:
+    def __init__(self):
+        pass
+
+
+    def mask(self, height, width, border_width):
+        # Create a mask with shape (height, width).
+        # The centre area is filled with 1, and the border line is filled with values in range (0, 1].
+        x = torch.arange(height).repeat(width, 1).T
+        y = torch.arange(width).repeat(height, 1)
+        mask = torch.stack([x + 1, height - x, y + 1, width - y]).min(dim=0).values
+        mask = (mask / border_width).clip(0, 1)
+        return mask
+
+
+    def tile(self, model_input, tile_size, tile_stride, tile_device, tile_dtype):
+        # Convert a tensor (b, c, h, w) to (b, c, tile_size, tile_size, tile_num)
+        batch_size, channel, _, _ = model_input.shape
+        model_input = model_input.to(device=tile_device, dtype=tile_dtype)
+        unfold_operator = torch.nn.Unfold(
+            kernel_size=(tile_size, tile_size),
+            stride=(tile_stride, tile_stride)
+        )
+        model_input = unfold_operator(model_input)
+        model_input = model_input.view((batch_size, channel, tile_size, tile_size, -1))
+
+        return model_input
+
+
+    def tiled_inference(self, forward_fn, model_input, tile_batch_size, inference_device, inference_dtype, tile_device, tile_dtype):
+        # Call y=forward_fn(x) for each tile
+        tile_num = model_input.shape[-1]
+        model_output_stack = []
+
+        for tile_id in range(0, tile_num, tile_batch_size):
+
+            # process input
+            tile_id_ = min(tile_id + tile_batch_size, tile_num)
+            x = model_input[:, :, :, :, tile_id: tile_id_]
+            x = x.to(device=inference_device, dtype=inference_dtype)
+            x = rearrange(x, "b c h w n -> (n b) c h w")
+
+            # process output
+            y = forward_fn(x)
+            y = rearrange(y, "(n b) c h w -> b c h w n", n=tile_id_-tile_id)
+            y = y.to(device=tile_device, dtype=tile_dtype)
+            model_output_stack.append(y)
+
+        model_output = torch.concat(model_output_stack, dim=-1)
+        return model_output
+
+
+    def io_scale(self, model_output, tile_size):
+        # Determine the size modification that happened in forward_fn
+        # We only consider the same scale on height and width.
+        io_scale = model_output.shape[2] / tile_size
+        return io_scale
+
+
+    def untile(self, model_output, height, width, tile_size, tile_stride, border_width, tile_device, tile_dtype):
+        # The reverse operation of tile
+        mask = self.mask(tile_size, tile_size, border_width)
+        mask = mask.to(device=tile_device, dtype=tile_dtype)
+        mask = rearrange(mask, "h w -> 1 1 h w 1")
+        model_output = model_output * mask
+
+        fold_operator = torch.nn.Fold(
+            output_size=(height, width),
+            kernel_size=(tile_size, tile_size),
+            stride=(tile_stride, tile_stride)
+        )
+        mask = repeat(mask[0, 0, :, :, 0], "h w -> 1 (h w) n", n=model_output.shape[-1])
+        model_output = rearrange(model_output, "b c h w n -> b (c h w) n")
+        model_output = fold_operator(model_output) / fold_operator(mask)
+
+        return model_output
+
+
+    def tiled_forward(self, forward_fn, model_input, tile_size, tile_stride, tile_batch_size=1, tile_device="cpu", tile_dtype=torch.float32, border_width=None):
+        # Prepare
+        inference_device, inference_dtype = model_input.device, model_input.dtype
+        height, width = model_input.shape[2], model_input.shape[3]
+        border_width = int(tile_stride*0.5) if border_width is None else border_width
+
+        # tile
+        model_input = self.tile(model_input, tile_size, tile_stride, tile_device, tile_dtype)
+
+        # inference
+        model_output = self.tiled_inference(forward_fn, model_input, tile_batch_size, inference_device, inference_dtype, tile_device, tile_dtype)
+
+        # resize
+        io_scale = self.io_scale(model_output, tile_size)
+        height, width = int(height*io_scale), int(width*io_scale)
+        tile_size, tile_stride = int(tile_size*io_scale), int(tile_stride*io_scale)
+        border_width = int(border_width*io_scale)
+
+        # untile
+        model_output = self.untile(model_output, height, width, tile_size, tile_stride, border_width, tile_device, tile_dtype)
+
+        # Done!
+        model_output = model_output.to(device=inference_device, dtype=inference_dtype)
+        return model_output
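
Because untile divides the mask-weighted fold of the outputs by the fold of the mask itself, overlapping tiles blend with feathered borders, and an identity forward_fn reconstructs the input exactly. A hedged end-to-end check (import path assumed; sizes chosen so (96 - 64) is a multiple of the 32-pixel stride):

    import torch
    from diffsynth.models.tiler import TileWorker  # assumed module path

    x = torch.randn(1, 4, 96, 96)
    y = TileWorker().tiled_forward(lambda t: t, x, tile_size=64, tile_stride=32)
    print(torch.allclose(x, y, atol=1e-5))  # True: tiles blend back to the input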