Mirror of https://github.com/modelscope/DiffSynth-Studio.git (synced 2026-03-18 22:08:13 +00:00)

Commit: release ExVideo
@@ -110,7 +110,11 @@ class ModelManager:
         param_name = "quant_conv.weight"
         return param_name in state_dict
 
-    def load_stable_video_diffusion(self, state_dict, components=None, file_path=""):
+    def is_ExVideo_StableVideoDiffusion(self, state_dict):
+        param_name = "blocks.185.positional_embedding.embeddings"
+        return param_name in state_dict
+
+    def load_stable_video_diffusion(self, state_dict, components=None, file_path="", add_positional_conv=None):
         component_dict = {
             "image_encoder": SVDImageEncoder,
             "unet": SVDUNet,
@@ -120,8 +124,12 @@ class ModelManager:
         if components is None:
             components = ["image_encoder", "unet", "vae_decoder", "vae_encoder"]
         for component in components:
-            self.model[component] = component_dict[component]()
-            self.model[component].load_state_dict(self.model[component].state_dict_converter().from_civitai(state_dict))
+            if component == "unet":
+                self.model[component] = component_dict[component](add_positional_conv=add_positional_conv)
+                self.model[component].load_state_dict(self.model[component].state_dict_converter().from_civitai(state_dict, add_positional_conv=add_positional_conv), strict=False)
+            else:
+                self.model[component] = component_dict[component]()
+                self.model[component].load_state_dict(self.model[component].state_dict_converter().from_civitai(state_dict))
             self.model[component].to(self.torch_dtype).to(self.device)
             self.model_path[component] = file_path
@@ -305,6 +313,16 @@ class ModelManager:
             self.model[component] = model
             self.model_path[component] = file_path
 
+    def load_ExVideo_StableVideoDiffusion(self, state_dict, file_path=""):
+        unet_state_dict = self.model["unet"].state_dict()
+        self.model["unet"].to("cpu")
+        del self.model["unet"]
+        add_positional_conv = state_dict["blocks.185.positional_embedding.embeddings"].shape[0]
+        self.model["unet"] = SVDUNet(add_positional_conv=add_positional_conv)
+        self.model["unet"].load_state_dict(unet_state_dict, strict=False)
+        self.model["unet"].load_state_dict(state_dict, strict=False)
+        self.model["unet"].to(self.torch_dtype).to(self.device)
+
     def search_for_embeddings(self, state_dict):
         embeddings = []
         for k in state_dict:
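Note: load_ExVideo_StableVideoDiffusion layers the ExVideo weights on top of an already-loaded base SVD UNet: it snapshots the base weights, rebuilds the UNet with add_positional_conv inferred from the checkpoint's positional-embedding shape, then applies two non-strict load_state_dict passes. A minimal standalone sketch of that layering pattern, using hypothetical toy modules rather than the real SVDUNet:

import torch

# Toy stand-ins: a "base" module and an "extended" module with one extra layer,
# mimicking SVDUNet() vs. SVDUNet(add_positional_conv=...).
class Base(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)

class Extended(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)
        self.extra = torch.nn.Linear(4, 4)  # analogous to the new positional layers

base_state = Base().state_dict()                  # pretrained base weights
extra_state = Extended().state_dict()             # stands in for the patch checkpoint

model = Extended()
model.load_state_dict(base_state, strict=False)   # strict=False: "extra.*" keys are missing here
model.load_state_dict(extra_state, strict=False)  # second pass overwrites/adds the patch weights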
@@ -370,6 +388,8 @@ class ModelManager:
             self.load_hunyuan_dit(state_dict, file_path=file_path)
         elif self.is_diffusers_vae(state_dict):
             self.load_diffusers_vae(state_dict, file_path=file_path)
+        elif self.is_ExVideo_StableVideoDiffusion(state_dict):
+            self.load_ExVideo_StableVideoDiffusion(state_dict, file_path=file_path)
 
     def load_models(self, file_path_list, lora_alphas=[]):
         for file_path in file_path_list:
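Note: with this dispatch in place, an ExVideo checkpoint can be passed to the ordinary load_models entry point alongside the base SVD weights; each file is routed by probing its state-dict keys, and the base UNet must come first because the ExVideo loader reads self.model["unet"]. A hypothetical usage sketch (file paths are illustrative, not shipped with the repo):

from diffsynth import ModelManager

model_manager = ModelManager()
model_manager.load_models([
    "models/stable_video_diffusion/svd_xt.safetensors",               # base SVD weights (hypothetical path)
    "models/stable_video_diffusion/ExVideo-SVD-128f-v1.safetensors",  # ExVideo patch (hypothetical path)
])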
@@ -95,19 +95,59 @@ class TemporalTimesteps(torch.nn.Module):
             downscale_freq_shift=self.downscale_freq_shift,
         )
         return t_emb
 
 
+class TrainableTemporalTimesteps(torch.nn.Module):
+    def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float, num_frames: int):
+        super().__init__()
+        timesteps = PositionalID()(num_frames)
+        embeddings = get_timestep_embedding(timesteps, num_channels, flip_sin_to_cos, downscale_freq_shift)
+        self.embeddings = torch.nn.Parameter(embeddings)
+
+    def forward(self, timesteps):
+        t_emb = self.embeddings[timesteps]
+        return t_emb
+
+
+class PositionalID(torch.nn.Module):
+    def __init__(self, max_id=25, repeat_length=20):
+        super().__init__()
+        self.max_id = max_id
+        self.repeat_length = repeat_length
+
+    def frame_id_to_position_id(self, frame_id):
+        if frame_id < self.max_id:
+            position_id = frame_id
+        else:
+            position_id = (frame_id - self.max_id) % (self.repeat_length * 2)
+            if position_id < self.repeat_length:
+                position_id = self.max_id - 2 - position_id
+            else:
+                position_id = self.max_id - 2 * self.repeat_length + position_id
+        return position_id
+
+    def forward(self, num_frames, pivot_frame_id=0):
+        position_ids = [self.frame_id_to_position_id(abs(i-pivot_frame_id)) for i in range(num_frames)]
+        position_ids = torch.IntTensor(position_ids)
+        return position_ids
+
+
 class TemporalAttentionBlock(torch.nn.Module):
 
-    def __init__(self, num_attention_heads, attention_head_dim, in_channels, cross_attention_dim=None):
+    def __init__(self, num_attention_heads, attention_head_dim, in_channels, cross_attention_dim=None, add_positional_conv=None):
         super().__init__()
 
-        self.positional_embedding = TemporalTimesteps(in_channels, True, 0)
         self.positional_embedding_proj = torch.nn.Sequential(
             torch.nn.Linear(in_channels, in_channels * 4),
             torch.nn.SiLU(),
             torch.nn.Linear(in_channels * 4, in_channels)
         )
+        if add_positional_conv is not None:
+            self.positional_embedding = TrainableTemporalTimesteps(in_channels, True, 0, add_positional_conv)
+            self.positional_conv = torch.nn.Conv3d(in_channels, in_channels, kernel_size=3, padding=1, padding_mode="reflect")
+        else:
+            self.positional_embedding = TemporalTimesteps(in_channels, True, 0)
+            self.positional_conv = None
 
         self.norm_in = torch.nn.LayerNorm(in_channels)
         self.act_fn_in = GEGLU(in_channels, in_channels * 4)
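Note: PositionalID lets a UNet trained with a fixed number of temporal positions address arbitrarily many frames: the first max_id frames get their own ids, and later frames walk back and forth over the trained range. A quick check of the mapping, assuming the PositionalID class above is in scope (default max_id=25, repeat_length=20):

# Values derived directly from frame_id_to_position_id above.
pid = PositionalID()
ids = pid(70).tolist()
assert ids[:25] == list(range(25))            # identity for the first max_id frames
assert ids[25:45] == list(range(23, 3, -1))   # reflected: 23, 22, ..., 4
assert ids[45:65] == list(range(5, 25))       # forward again: 5, 6, ..., 24
assert ids[65:70] == [23, 22, 21, 20, 19]     # and reflects once more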
@@ -137,14 +177,14 @@ class TemporalAttentionBlock(torch.nn.Module):
     def forward(self, hidden_states, time_emb, text_emb, res_stack, **kwargs):
 
         batch, inner_dim, height, width = hidden_states.shape
-        pos_emb = torch.arange(batch, dtype=hidden_states.dtype)
-        if batch > 25:
-            pos_emb *= 25 / batch
+        pos_emb = torch.arange(batch)
         pos_emb = self.positional_embedding(pos_emb).to(dtype=hidden_states.dtype, device=hidden_states.device)
-        pos_emb = self.positional_embedding_proj(pos_emb)[None, :, :]
-
-        hidden_states = hidden_states.permute(2, 3, 0, 1).reshape(height * width, batch, inner_dim)
-        hidden_states = hidden_states + pos_emb
+        pos_emb = self.positional_embedding_proj(pos_emb)
+
+        hidden_states = rearrange(hidden_states, "T C H W -> 1 C T H W") + rearrange(pos_emb, "T C -> 1 C T 1 1")
+        if self.positional_conv is not None:
+            hidden_states = self.positional_conv(hidden_states)
+        hidden_states = rearrange(hidden_states[0], "C T H W -> (H W) T C")
 
         residual = hidden_states
         hidden_states = self.norm_in(hidden_states)
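Note: the reworked forward treats the batch axis as time: frames enter as (T, C, H, W), are lifted to a single (1, C, T, H, W) volume so the reflect-padded Conv3d can mix neighboring frames, and are then flattened to ((H W), T, C) tokens for temporal attention. A standalone shape walk-through with illustrative sizes:

import torch
from einops import rearrange

T, C, H, W = 8, 16, 4, 4
hidden_states = torch.randn(T, C, H, W)   # frames stacked on the leading axis
pos_emb = torch.randn(T, C)               # one embedding per frame

x = rearrange(hidden_states, "T C H W -> 1 C T H W") + rearrange(pos_emb, "T C -> 1 C T 1 1")
conv = torch.nn.Conv3d(C, C, kernel_size=3, padding=1, padding_mode="reflect")
x = conv(x)                               # mixes neighboring frames and pixels
x = rearrange(x[0], "C T H W -> (H W) T C")
print(x.shape)                            # torch.Size([16, 8, 16])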
@@ -193,7 +233,7 @@ class PopMixBlock(torch.nn.Module):
 
 
 class SVDUNet(torch.nn.Module):
-    def __init__(self):
+    def __init__(self, add_positional_conv=None):
         super().__init__()
         self.time_proj = Timesteps(320)
         self.time_embedding = torch.nn.Sequential(
@@ -211,29 +251,29 @@ class SVDUNet(torch.nn.Module):
 
         self.blocks = torch.nn.ModuleList([
             # CrossAttnDownBlockSpatioTemporal
-            ResnetBlock(320, 320, 1280, eps=1e-6), PushBlock(), TemporalResnetBlock(320, 320, 1280, eps=1e-6), PopMixBlock(), PushBlock(),
-            AttentionBlock(5, 64, 320, 1, 1024, need_proj_out=False), PushBlock(), TemporalAttentionBlock(5, 64, 320, 1024), PopMixBlock(320), PushBlock(),
-            ResnetBlock(320, 320, 1280, eps=1e-6), PushBlock(), TemporalResnetBlock(320, 320, 1280, eps=1e-6), PopMixBlock(), PushBlock(),
-            AttentionBlock(5, 64, 320, 1, 1024, need_proj_out=False), PushBlock(), TemporalAttentionBlock(5, 64, 320, 1024), PopMixBlock(320), PushBlock(),
+            ResnetBlock(320, 320, 1280, eps=1e-6), PushBlock(), TemporalResnetBlock(320, 320, 1280, eps=1e-6), PopMixBlock(), PushBlock(),
+            AttentionBlock(5, 64, 320, 1, 1024, need_proj_out=False), PushBlock(), TemporalAttentionBlock(5, 64, 320, 1024, add_positional_conv), PopMixBlock(320), PushBlock(),
+            ResnetBlock(320, 320, 1280, eps=1e-6), PushBlock(), TemporalResnetBlock(320, 320, 1280, eps=1e-6), PopMixBlock(), PushBlock(),
+            AttentionBlock(5, 64, 320, 1, 1024, need_proj_out=False), PushBlock(), TemporalAttentionBlock(5, 64, 320, 1024, add_positional_conv), PopMixBlock(320), PushBlock(),
             DownSampler(320), PushBlock(),
             # CrossAttnDownBlockSpatioTemporal
-            ResnetBlock(320, 640, 1280, eps=1e-6), PushBlock(), TemporalResnetBlock(640, 640, 1280, eps=1e-6), PopMixBlock(), PushBlock(),
-            AttentionBlock(10, 64, 640, 1, 1024, need_proj_out=False), PushBlock(), TemporalAttentionBlock(10, 64, 640, 1024), PopMixBlock(640), PushBlock(),
-            ResnetBlock(640, 640, 1280, eps=1e-6), PushBlock(), TemporalResnetBlock(640, 640, 1280, eps=1e-6), PopMixBlock(), PushBlock(),
-            AttentionBlock(10, 64, 640, 1, 1024, need_proj_out=False), PushBlock(), TemporalAttentionBlock(10, 64, 640, 1024), PopMixBlock(640), PushBlock(),
+            ResnetBlock(320, 640, 1280, eps=1e-6), PushBlock(), TemporalResnetBlock(640, 640, 1280, eps=1e-6), PopMixBlock(), PushBlock(),
+            AttentionBlock(10, 64, 640, 1, 1024, need_proj_out=False), PushBlock(), TemporalAttentionBlock(10, 64, 640, 1024, add_positional_conv), PopMixBlock(640), PushBlock(),
+            ResnetBlock(640, 640, 1280, eps=1e-6), PushBlock(), TemporalResnetBlock(640, 640, 1280, eps=1e-6), PopMixBlock(), PushBlock(),
+            AttentionBlock(10, 64, 640, 1, 1024, need_proj_out=False), PushBlock(), TemporalAttentionBlock(10, 64, 640, 1024, add_positional_conv), PopMixBlock(640), PushBlock(),
             DownSampler(640), PushBlock(),
             # CrossAttnDownBlockSpatioTemporal
-            ResnetBlock(640, 1280, 1280, eps=1e-6), PushBlock(), TemporalResnetBlock(1280, 1280, 1280, eps=1e-6), PopMixBlock(), PushBlock(),
-            AttentionBlock(20, 64, 1280, 1, 1024, need_proj_out=False), PushBlock(), TemporalAttentionBlock(20, 64, 1280, 1024), PopMixBlock(1280), PushBlock(),
-            ResnetBlock(1280, 1280, 1280, eps=1e-6), PushBlock(), TemporalResnetBlock(1280, 1280, 1280, eps=1e-6), PopMixBlock(), PushBlock(),
-            AttentionBlock(20, 64, 1280, 1, 1024, need_proj_out=False), PushBlock(), TemporalAttentionBlock(20, 64, 1280, 1024), PopMixBlock(1280), PushBlock(),
+            ResnetBlock(640, 1280, 1280, eps=1e-6), PushBlock(), TemporalResnetBlock(1280, 1280, 1280, eps=1e-6), PopMixBlock(), PushBlock(),
+            AttentionBlock(20, 64, 1280, 1, 1024, need_proj_out=False), PushBlock(), TemporalAttentionBlock(20, 64, 1280, 1024, add_positional_conv), PopMixBlock(1280), PushBlock(),
+            ResnetBlock(1280, 1280, 1280, eps=1e-6), PushBlock(), TemporalResnetBlock(1280, 1280, 1280, eps=1e-6), PopMixBlock(), PushBlock(),
+            AttentionBlock(20, 64, 1280, 1, 1024, need_proj_out=False), PushBlock(), TemporalAttentionBlock(20, 64, 1280, 1024, add_positional_conv), PopMixBlock(1280), PushBlock(),
             DownSampler(1280), PushBlock(),
             # DownBlockSpatioTemporal
             ResnetBlock(1280, 1280, 1280, eps=1e-5), PushBlock(), TemporalResnetBlock(1280, 1280, 1280, eps=1e-5), PopMixBlock(), PushBlock(),
             ResnetBlock(1280, 1280, 1280, eps=1e-5), PushBlock(), TemporalResnetBlock(1280, 1280, 1280, eps=1e-5), PopMixBlock(), PushBlock(),
             ResnetBlock(1280, 1280, 1280, eps=1e-5), PushBlock(), TemporalResnetBlock(1280, 1280, 1280, eps=1e-5), PopMixBlock(), PushBlock(),
             ResnetBlock(1280, 1280, 1280, eps=1e-5), PushBlock(), TemporalResnetBlock(1280, 1280, 1280, eps=1e-5), PopMixBlock(), PushBlock(),
             # UNetMidBlockSpatioTemporal
-            ResnetBlock(1280, 1280, 1280, eps=1e-5), PushBlock(), TemporalResnetBlock(1280, 1280, 1280, eps=1e-5), PopMixBlock(), PushBlock(),
-            AttentionBlock(20, 64, 1280, 1, 1024, need_proj_out=False), PushBlock(), TemporalAttentionBlock(20, 64, 1280, 1024), PopMixBlock(1280),
+            ResnetBlock(1280, 1280, 1280, eps=1e-5), PushBlock(), TemporalResnetBlock(1280, 1280, 1280, eps=1e-5), PopMixBlock(), PushBlock(),
+            AttentionBlock(20, 64, 1280, 1, 1024, need_proj_out=False), PushBlock(), TemporalAttentionBlock(20, 64, 1280, 1024, add_positional_conv), PopMixBlock(1280),
             ResnetBlock(1280, 1280, 1280, eps=1e-5), PushBlock(), TemporalResnetBlock(1280, 1280, 1280, eps=1e-5), PopMixBlock(),
             # UpBlockSpatioTemporal
             PopBlock(), ResnetBlock(2560, 1280, 1280, eps=1e-6), PushBlock(), TemporalResnetBlock(1280, 1280, 1280, eps=1e-5), PopMixBlock(),
@@ -241,28 +281,28 @@ class SVDUNet(torch.nn.Module):
             PopBlock(), ResnetBlock(2560, 1280, 1280, eps=1e-6), PushBlock(), TemporalResnetBlock(1280, 1280, 1280, eps=1e-5), PopMixBlock(),
             UpSampler(1280),
             # CrossAttnUpBlockSpatioTemporal
-            PopBlock(), ResnetBlock(2560, 1280, 1280, eps=1e-6), PushBlock(), TemporalResnetBlock(1280, 1280, 1280, eps=1e-6), PopMixBlock(), PushBlock(),
-            AttentionBlock(20, 64, 1280, 1, 1024, need_proj_out=False), PushBlock(), TemporalAttentionBlock(20, 64, 1280, 1024), PopMixBlock(1280),
-            PopBlock(), ResnetBlock(2560, 1280, 1280, eps=1e-6), PushBlock(), TemporalResnetBlock(1280, 1280, 1280, eps=1e-6), PopMixBlock(), PushBlock(),
-            AttentionBlock(20, 64, 1280, 1, 1024, need_proj_out=False), PushBlock(), TemporalAttentionBlock(20, 64, 1280, 1024), PopMixBlock(1280),
-            PopBlock(), ResnetBlock(1920, 1280, 1280, eps=1e-6), PushBlock(), TemporalResnetBlock(1280, 1280, 1280, eps=1e-6), PopMixBlock(), PushBlock(),
-            AttentionBlock(20, 64, 1280, 1, 1024, need_proj_out=False), PushBlock(), TemporalAttentionBlock(20, 64, 1280, 1024), PopMixBlock(1280),
+            PopBlock(), ResnetBlock(2560, 1280, 1280, eps=1e-6), PushBlock(), TemporalResnetBlock(1280, 1280, 1280, eps=1e-6), PopMixBlock(), PushBlock(),
+            AttentionBlock(20, 64, 1280, 1, 1024, need_proj_out=False), PushBlock(), TemporalAttentionBlock(20, 64, 1280, 1024, add_positional_conv), PopMixBlock(1280),
+            PopBlock(), ResnetBlock(2560, 1280, 1280, eps=1e-6), PushBlock(), TemporalResnetBlock(1280, 1280, 1280, eps=1e-6), PopMixBlock(), PushBlock(),
+            AttentionBlock(20, 64, 1280, 1, 1024, need_proj_out=False), PushBlock(), TemporalAttentionBlock(20, 64, 1280, 1024, add_positional_conv), PopMixBlock(1280),
+            PopBlock(), ResnetBlock(1920, 1280, 1280, eps=1e-6), PushBlock(), TemporalResnetBlock(1280, 1280, 1280, eps=1e-6), PopMixBlock(), PushBlock(),
+            AttentionBlock(20, 64, 1280, 1, 1024, need_proj_out=False), PushBlock(), TemporalAttentionBlock(20, 64, 1280, 1024, add_positional_conv), PopMixBlock(1280),
             UpSampler(1280),
             # CrossAttnUpBlockSpatioTemporal
-            PopBlock(), ResnetBlock(1920, 640, 1280, eps=1e-6), PushBlock(), TemporalResnetBlock(640, 640, 1280, eps=1e-6), PopMixBlock(), PushBlock(),
-            AttentionBlock(10, 64, 640, 1, 1024, need_proj_out=False), PushBlock(), TemporalAttentionBlock(10, 64, 640, 1024), PopMixBlock(640),
-            PopBlock(), ResnetBlock(1280, 640, 1280, eps=1e-6), PushBlock(), TemporalResnetBlock(640, 640, 1280, eps=1e-6), PopMixBlock(), PushBlock(),
-            AttentionBlock(10, 64, 640, 1, 1024, need_proj_out=False), PushBlock(), TemporalAttentionBlock(10, 64, 640, 1024), PopMixBlock(640),
-            PopBlock(), ResnetBlock(960, 640, 1280, eps=1e-6), PushBlock(), TemporalResnetBlock(640, 640, 1280, eps=1e-6), PopMixBlock(), PushBlock(),
-            AttentionBlock(10, 64, 640, 1, 1024, need_proj_out=False), PushBlock(), TemporalAttentionBlock(10, 64, 640, 1024), PopMixBlock(640),
+            PopBlock(), ResnetBlock(1920, 640, 1280, eps=1e-6), PushBlock(), TemporalResnetBlock(640, 640, 1280, eps=1e-6), PopMixBlock(), PushBlock(),
+            AttentionBlock(10, 64, 640, 1, 1024, need_proj_out=False), PushBlock(), TemporalAttentionBlock(10, 64, 640, 1024, add_positional_conv), PopMixBlock(640),
+            PopBlock(), ResnetBlock(1280, 640, 1280, eps=1e-6), PushBlock(), TemporalResnetBlock(640, 640, 1280, eps=1e-6), PopMixBlock(), PushBlock(),
+            AttentionBlock(10, 64, 640, 1, 1024, need_proj_out=False), PushBlock(), TemporalAttentionBlock(10, 64, 640, 1024, add_positional_conv), PopMixBlock(640),
+            PopBlock(), ResnetBlock(960, 640, 1280, eps=1e-6), PushBlock(), TemporalResnetBlock(640, 640, 1280, eps=1e-6), PopMixBlock(), PushBlock(),
+            AttentionBlock(10, 64, 640, 1, 1024, need_proj_out=False), PushBlock(), TemporalAttentionBlock(10, 64, 640, 1024, add_positional_conv), PopMixBlock(640),
             UpSampler(640),
             # CrossAttnUpBlockSpatioTemporal
-            PopBlock(), ResnetBlock(960, 320, 1280, eps=1e-6), PushBlock(), TemporalResnetBlock(320, 320, 1280, eps=1e-6), PopMixBlock(), PushBlock(),
-            AttentionBlock(5, 64, 320, 1, 1024, need_proj_out=False), PushBlock(), TemporalAttentionBlock(5, 64, 320, 1024), PopMixBlock(320),
-            PopBlock(), ResnetBlock(640, 320, 1280, eps=1e-6), PushBlock(), TemporalResnetBlock(320, 320, 1280, eps=1e-6), PopMixBlock(), PushBlock(),
-            AttentionBlock(5, 64, 320, 1, 1024, need_proj_out=False), PushBlock(), TemporalAttentionBlock(5, 64, 320, 1024), PopMixBlock(320),
-            PopBlock(), ResnetBlock(640, 320, 1280, eps=1e-6), PushBlock(), TemporalResnetBlock(320, 320, 1280, eps=1e-6), PopMixBlock(), PushBlock(),
-            AttentionBlock(5, 64, 320, 1, 1024, need_proj_out=False), PushBlock(), TemporalAttentionBlock(5, 64, 320, 1024), PopMixBlock(320),
+            PopBlock(), ResnetBlock(960, 320, 1280, eps=1e-6), PushBlock(), TemporalResnetBlock(320, 320, 1280, eps=1e-6), PopMixBlock(), PushBlock(),
+            AttentionBlock(5, 64, 320, 1, 1024, need_proj_out=False), PushBlock(), TemporalAttentionBlock(5, 64, 320, 1024, add_positional_conv), PopMixBlock(320),
+            PopBlock(), ResnetBlock(640, 320, 1280, eps=1e-6), PushBlock(), TemporalResnetBlock(320, 320, 1280, eps=1e-6), PopMixBlock(), PushBlock(),
+            AttentionBlock(5, 64, 320, 1, 1024, need_proj_out=False), PushBlock(), TemporalAttentionBlock(5, 64, 320, 1024, add_positional_conv), PopMixBlock(320),
+            PopBlock(), ResnetBlock(640, 320, 1280, eps=1e-6), PushBlock(), TemporalResnetBlock(320, 320, 1280, eps=1e-6), PopMixBlock(), PushBlock(),
+            AttentionBlock(5, 64, 320, 1, 1024, need_proj_out=False), PushBlock(), TemporalAttentionBlock(5, 64, 320, 1024, add_positional_conv), PopMixBlock(320),
         ])
 
         self.conv_norm_out = torch.nn.GroupNorm(32, 320, eps=1e-05, affine=True)
@@ -327,7 +367,7 @@ class SVDUNet(torch.nn.Module):
         return values
 
 
-    def forward(self, sample, timestep, encoder_hidden_states, add_time_id, **kwargs):
+    def forward(self, sample, timestep, encoder_hidden_states, add_time_id, use_gradient_checkpointing=False, **kwargs):
         # 1. time
         timestep = torch.tensor((timestep,)).to(sample.device)
         t_emb = self.time_proj(timestep).to(sample.dtype)
@@ -346,8 +386,19 @@ class SVDUNet(torch.nn.Module):
         res_stack = [hidden_states]
 
         # 3. blocks
+        def create_custom_forward(module):
+            def custom_forward(*inputs):
+                return module(*inputs)
+            return custom_forward
         for i, block in enumerate(self.blocks):
-            hidden_states, time_emb, text_emb, res_stack = block(hidden_states, time_emb, text_emb, res_stack)
+            if self.training and use_gradient_checkpointing and not (isinstance(block, PushBlock) or isinstance(block, PopBlock) or isinstance(block, PopMixBlock)):
+                hidden_states, time_emb, text_emb, res_stack = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(block),
+                    hidden_states, time_emb, text_emb, res_stack,
+                    use_reentrant=False,
+                )
+            else:
+                hidden_states, time_emb, text_emb, res_stack = block(hidden_states, time_emb, text_emb, res_stack)
 
         # 4. output
         hidden_states = self.conv_norm_out(hidden_states)
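Note: the create_custom_forward wrapper plus torch.utils.checkpoint.checkpoint trades compute for memory: activations inside each checkpointed block are dropped during the forward pass and recomputed in backward. The stack-manipulation blocks (PushBlock, PopBlock, PopMixBlock) are excluded, presumably because they do little computation. A minimal sketch of the mechanism:

import torch
from torch.utils.checkpoint import checkpoint

block = torch.nn.Sequential(torch.nn.Linear(64, 64), torch.nn.SiLU(), torch.nn.Linear(64, 64))
x = torch.randn(2, 64, requires_grad=True)

y = checkpoint(block, x, use_reentrant=False)  # intermediate activations are not stored
y.sum().backward()                             # block's forward runs again here to rebuild them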
@@ -497,7 +548,7 @@ class SVDUNetStateDictConverter:
         return state_dict_
 
 
-    def from_civitai(self, state_dict):
+    def from_civitai(self, state_dict, add_positional_conv=None):
         rename_dict = {
             "model.diffusion_model.input_blocks.0.0.bias": "conv_in.bias",
             "model.diffusion_model.input_blocks.0.0.weight": "conv_in.weight",
@@ -1935,4 +1986,18 @@ class SVDUNetStateDictConverter:
             if ".proj_in." in name or ".proj_out." in name:
                 param = param.squeeze()
             state_dict_[rename_dict[name]] = param
+        if add_positional_conv is not None:
+            extra_names = [
+                "blocks.7.positional_conv", "blocks.17.positional_conv", "blocks.29.positional_conv", "blocks.39.positional_conv",
+                "blocks.51.positional_conv", "blocks.61.positional_conv", "blocks.83.positional_conv", "blocks.113.positional_conv",
+                "blocks.123.positional_conv", "blocks.133.positional_conv", "blocks.144.positional_conv", "blocks.154.positional_conv",
+                "blocks.164.positional_conv", "blocks.175.positional_conv", "blocks.185.positional_conv", "blocks.195.positional_conv",
+            ]
+            extra_channels = [320, 320, 640, 640, 1280, 1280, 1280, 1280, 1280, 1280, 640, 640, 640, 320, 320, 320]
+            for name, channels in zip(extra_names, extra_channels):
+                weight = torch.zeros((channels, channels, 3, 3, 3))
+                weight[:,:,1,1,1] = torch.eye(channels, channels)
+                bias = torch.zeros((channels,))
+                state_dict_[name + ".weight"] = weight
+                state_dict_[name + ".bias"] = bias
         return state_dict_
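Note: the converter fabricates weights for the 16 new positional_conv layers so that each starts as an exact identity: every tap is zero except the spatio-temporal center, which holds an identity matrix over channels. The patched UNet therefore reproduces the base model's behavior until these layers are trained. A small check of that initialization trick:

import torch

channels = 8
conv = torch.nn.Conv3d(channels, channels, kernel_size=3, padding=1, padding_mode="reflect")
with torch.no_grad():
    conv.weight.zero_()
    conv.weight[:, :, 1, 1, 1] = torch.eye(channels)  # center tap = identity over channels
    conv.bias.zero_()

x = torch.randn(1, channels, 5, 6, 6)
assert torch.allclose(conv(x), x, atol=1e-6)  # the layer is a no-op at initialization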