release ExVideo

This commit is contained in:
Artiprocher
2024-06-20 16:01:51 +08:00
parent b78ec9e086
commit 6e25864a3d
6 changed files with 259 additions and 50 deletions

View File

@@ -110,7 +110,11 @@ class ModelManager:
param_name = "quant_conv.weight"
return param_name in state_dict
def load_stable_video_diffusion(self, state_dict, components=None, file_path=""):
def is_ExVideo_StableVideoDiffusion(self, state_dict):
param_name = "blocks.185.positional_embedding.embeddings"
return param_name in state_dict
def load_stable_video_diffusion(self, state_dict, components=None, file_path="", add_positional_conv=None):
component_dict = {
"image_encoder": SVDImageEncoder,
"unet": SVDUNet,
@@ -120,8 +124,12 @@ class ModelManager:
if components is None:
components = ["image_encoder", "unet", "vae_decoder", "vae_encoder"]
for component in components:
self.model[component] = component_dict[component]()
self.model[component].load_state_dict(self.model[component].state_dict_converter().from_civitai(state_dict))
if component == "unet":
self.model[component] = component_dict[component](add_positional_conv=add_positional_conv)
self.model[component].load_state_dict(self.model[component].state_dict_converter().from_civitai(state_dict, add_positional_conv=add_positional_conv), strict=False)
else:
self.model[component] = component_dict[component]()
self.model[component].load_state_dict(self.model[component].state_dict_converter().from_civitai(state_dict))
self.model[component].to(self.torch_dtype).to(self.device)
self.model_path[component] = file_path
@@ -305,6 +313,16 @@ class ModelManager:
self.model[component] = model
self.model_path[component] = file_path
def load_ExVideo_StableVideoDiffusion(self, state_dict, file_path=""):
unet_state_dict = self.model["unet"].state_dict()
self.model["unet"].to("cpu")
del self.model["unet"]
add_positional_conv = state_dict["blocks.185.positional_embedding.embeddings"].shape[0]
self.model["unet"] = SVDUNet(add_positional_conv=add_positional_conv)
self.model["unet"].load_state_dict(unet_state_dict, strict=False)
self.model["unet"].load_state_dict(state_dict, strict=False)
self.model["unet"].to(self.torch_dtype).to(self.device)
def search_for_embeddings(self, state_dict):
embeddings = []
for k in state_dict:
@@ -370,6 +388,8 @@ class ModelManager:
self.load_hunyuan_dit(state_dict, file_path=file_path)
elif self.is_diffusers_vae(state_dict):
self.load_diffusers_vae(state_dict, file_path=file_path)
elif self.is_ExVideo_StableVideoDiffusion(state_dict):
self.load_ExVideo_StableVideoDiffusion(state_dict, file_path=file_path)
def load_models(self, file_path_list, lora_alphas=[]):
for file_path in file_path_list:

View File

@@ -95,19 +95,59 @@ class TemporalTimesteps(torch.nn.Module):
downscale_freq_shift=self.downscale_freq_shift,
)
return t_emb
class TrainableTemporalTimesteps(torch.nn.Module):
def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float, num_frames: int):
super().__init__()
timesteps = PositionalID()(num_frames)
embeddings = get_timestep_embedding(timesteps, num_channels, flip_sin_to_cos, downscale_freq_shift)
self.embeddings = torch.nn.Parameter(embeddings)
def forward(self, timesteps):
t_emb = self.embeddings[timesteps]
return t_emb
class PositionalID(torch.nn.Module):
def __init__(self, max_id=25, repeat_length=20):
super().__init__()
self.max_id = max_id
self.repeat_length = repeat_length
def frame_id_to_position_id(self, frame_id):
if frame_id < self.max_id:
position_id = frame_id
else:
position_id = (frame_id - self.max_id) % (self.repeat_length * 2)
if position_id < self.repeat_length:
position_id = self.max_id - 2 - position_id
else:
position_id = self.max_id - 2 * self.repeat_length + position_id
return position_id
def forward(self, num_frames, pivot_frame_id=0):
position_ids = [self.frame_id_to_position_id(abs(i-pivot_frame_id)) for i in range(num_frames)]
position_ids = torch.IntTensor(position_ids)
return position_ids
class TemporalAttentionBlock(torch.nn.Module):
def __init__(self, num_attention_heads, attention_head_dim, in_channels, cross_attention_dim=None):
def __init__(self, num_attention_heads, attention_head_dim, in_channels, cross_attention_dim=None, add_positional_conv=None):
super().__init__()
self.positional_embedding = TemporalTimesteps(in_channels, True, 0)
self.positional_embedding_proj = torch.nn.Sequential(
torch.nn.Linear(in_channels, in_channels * 4),
torch.nn.SiLU(),
torch.nn.Linear(in_channels * 4, in_channels)
)
if add_positional_conv is not None:
self.positional_embedding = TrainableTemporalTimesteps(in_channels, True, 0, add_positional_conv)
self.positional_conv = torch.nn.Conv3d(in_channels, in_channels, kernel_size=3, padding=1, padding_mode="reflect")
else:
self.positional_embedding = TemporalTimesteps(in_channels, True, 0)
self.positional_conv = None
self.norm_in = torch.nn.LayerNorm(in_channels)
self.act_fn_in = GEGLU(in_channels, in_channels * 4)
@@ -137,14 +177,14 @@ class TemporalAttentionBlock(torch.nn.Module):
def forward(self, hidden_states, time_emb, text_emb, res_stack, **kwargs):
batch, inner_dim, height, width = hidden_states.shape
pos_emb = torch.arange(batch, dtype=hidden_states.dtype)
if batch > 25:
pos_emb *= 25 / batch
pos_emb = torch.arange(batch)
pos_emb = self.positional_embedding(pos_emb).to(dtype=hidden_states.dtype, device=hidden_states.device)
pos_emb = self.positional_embedding_proj(pos_emb)[None, :, :]
hidden_states = hidden_states.permute(2, 3, 0, 1).reshape(height * width, batch, inner_dim)
hidden_states = hidden_states + pos_emb
pos_emb = self.positional_embedding_proj(pos_emb)
hidden_states = rearrange(hidden_states, "T C H W -> 1 C T H W") + rearrange(pos_emb, "T C -> 1 C T 1 1")
if self.positional_conv is not None:
hidden_states = self.positional_conv(hidden_states)
hidden_states = rearrange(hidden_states[0], "C T H W -> (H W) T C")
residual = hidden_states
hidden_states = self.norm_in(hidden_states)
@@ -193,7 +233,7 @@ class PopMixBlock(torch.nn.Module):
class SVDUNet(torch.nn.Module):
def __init__(self):
def __init__(self, add_positional_conv=None):
super().__init__()
self.time_proj = Timesteps(320)
self.time_embedding = torch.nn.Sequential(
@@ -211,29 +251,29 @@ class SVDUNet(torch.nn.Module):
self.blocks = torch.nn.ModuleList([
# CrossAttnDownBlockSpatioTemporal
ResnetBlock(320, 320, 1280, eps=1e-6), PushBlock(), TemporalResnetBlock(320, 320, 1280, eps=1e-6), PopMixBlock(), PushBlock(),
AttentionBlock(5, 64, 320, 1, 1024, need_proj_out=False), PushBlock(), TemporalAttentionBlock(5, 64, 320, 1024), PopMixBlock(320), PushBlock(),
ResnetBlock(320, 320, 1280, eps=1e-6), PushBlock(), TemporalResnetBlock(320, 320, 1280, eps=1e-6), PopMixBlock(), PushBlock(),
AttentionBlock(5, 64, 320, 1, 1024, need_proj_out=False), PushBlock(), TemporalAttentionBlock(5, 64, 320, 1024), PopMixBlock(320), PushBlock(),
ResnetBlock(320, 320, 1280, eps=1e-6), PushBlock(), TemporalResnetBlock(320, 320, 1280, eps=1e-6), PopMixBlock(), PushBlock(),
AttentionBlock(5, 64, 320, 1, 1024, need_proj_out=False), PushBlock(), TemporalAttentionBlock(5, 64, 320, 1024, add_positional_conv), PopMixBlock(320), PushBlock(),
ResnetBlock(320, 320, 1280, eps=1e-6), PushBlock(), TemporalResnetBlock(320, 320, 1280, eps=1e-6), PopMixBlock(), PushBlock(),
AttentionBlock(5, 64, 320, 1, 1024, need_proj_out=False), PushBlock(), TemporalAttentionBlock(5, 64, 320, 1024, add_positional_conv), PopMixBlock(320), PushBlock(),
DownSampler(320), PushBlock(),
# CrossAttnDownBlockSpatioTemporal
ResnetBlock(320, 640, 1280, eps=1e-6), PushBlock(), TemporalResnetBlock(640, 640, 1280, eps=1e-6), PopMixBlock(), PushBlock(),
AttentionBlock(10, 64, 640, 1, 1024, need_proj_out=False), PushBlock(), TemporalAttentionBlock(10, 64, 640, 1024), PopMixBlock(640), PushBlock(),
ResnetBlock(640, 640, 1280, eps=1e-6), PushBlock(), TemporalResnetBlock(640, 640, 1280, eps=1e-6), PopMixBlock(), PushBlock(),
AttentionBlock(10, 64, 640, 1, 1024, need_proj_out=False), PushBlock(), TemporalAttentionBlock(10, 64, 640, 1024), PopMixBlock(640), PushBlock(),
ResnetBlock(320, 640, 1280, eps=1e-6), PushBlock(), TemporalResnetBlock(640, 640, 1280, eps=1e-6), PopMixBlock(), PushBlock(),
AttentionBlock(10, 64, 640, 1, 1024, need_proj_out=False), PushBlock(), TemporalAttentionBlock(10, 64, 640, 1024, add_positional_conv), PopMixBlock(640), PushBlock(),
ResnetBlock(640, 640, 1280, eps=1e-6), PushBlock(), TemporalResnetBlock(640, 640, 1280, eps=1e-6), PopMixBlock(), PushBlock(),
AttentionBlock(10, 64, 640, 1, 1024, need_proj_out=False), PushBlock(), TemporalAttentionBlock(10, 64, 640, 1024, add_positional_conv), PopMixBlock(640), PushBlock(),
DownSampler(640), PushBlock(),
# CrossAttnDownBlockSpatioTemporal
ResnetBlock(640, 1280, 1280, eps=1e-6), PushBlock(), TemporalResnetBlock(1280, 1280, 1280, eps=1e-6), PopMixBlock(), PushBlock(),
AttentionBlock(20, 64, 1280, 1, 1024, need_proj_out=False), PushBlock(), TemporalAttentionBlock(20, 64, 1280, 1024), PopMixBlock(1280), PushBlock(),
ResnetBlock(1280, 1280, 1280, eps=1e-6), PushBlock(), TemporalResnetBlock(1280, 1280, 1280, eps=1e-6), PopMixBlock(), PushBlock(),
AttentionBlock(20, 64, 1280, 1, 1024, need_proj_out=False), PushBlock(), TemporalAttentionBlock(20, 64, 1280, 1024), PopMixBlock(1280), PushBlock(),
ResnetBlock(640, 1280, 1280, eps=1e-6), PushBlock(), TemporalResnetBlock(1280, 1280, 1280, eps=1e-6), PopMixBlock(), PushBlock(),
AttentionBlock(20, 64, 1280, 1, 1024, need_proj_out=False), PushBlock(), TemporalAttentionBlock(20, 64, 1280, 1024, add_positional_conv), PopMixBlock(1280), PushBlock(),
ResnetBlock(1280, 1280, 1280, eps=1e-6), PushBlock(), TemporalResnetBlock(1280, 1280, 1280, eps=1e-6), PopMixBlock(), PushBlock(),
AttentionBlock(20, 64, 1280, 1, 1024, need_proj_out=False), PushBlock(), TemporalAttentionBlock(20, 64, 1280, 1024, add_positional_conv), PopMixBlock(1280), PushBlock(),
DownSampler(1280), PushBlock(),
# DownBlockSpatioTemporal
ResnetBlock(1280, 1280, 1280, eps=1e-5), PushBlock(), TemporalResnetBlock(1280, 1280, 1280, eps=1e-5), PopMixBlock(), PushBlock(),
ResnetBlock(1280, 1280, 1280, eps=1e-5), PushBlock(), TemporalResnetBlock(1280, 1280, 1280, eps=1e-5), PopMixBlock(), PushBlock(),
ResnetBlock(1280, 1280, 1280, eps=1e-5), PushBlock(), TemporalResnetBlock(1280, 1280, 1280, eps=1e-5), PopMixBlock(), PushBlock(),
ResnetBlock(1280, 1280, 1280, eps=1e-5), PushBlock(), TemporalResnetBlock(1280, 1280, 1280, eps=1e-5), PopMixBlock(), PushBlock(),
# UNetMidBlockSpatioTemporal
ResnetBlock(1280, 1280, 1280, eps=1e-5), PushBlock(), TemporalResnetBlock(1280, 1280, 1280, eps=1e-5), PopMixBlock(), PushBlock(),
AttentionBlock(20, 64, 1280, 1, 1024, need_proj_out=False), PushBlock(), TemporalAttentionBlock(20, 64, 1280, 1024), PopMixBlock(1280),
ResnetBlock(1280, 1280, 1280, eps=1e-5), PushBlock(), TemporalResnetBlock(1280, 1280, 1280, eps=1e-5), PopMixBlock(), PushBlock(),
AttentionBlock(20, 64, 1280, 1, 1024, need_proj_out=False), PushBlock(), TemporalAttentionBlock(20, 64, 1280, 1024, add_positional_conv), PopMixBlock(1280),
ResnetBlock(1280, 1280, 1280, eps=1e-5), PushBlock(), TemporalResnetBlock(1280, 1280, 1280, eps=1e-5), PopMixBlock(),
# UpBlockSpatioTemporal
PopBlock(), ResnetBlock(2560, 1280, 1280, eps=1e-6), PushBlock(), TemporalResnetBlock(1280, 1280, 1280, eps=1e-5), PopMixBlock(),
@@ -241,28 +281,28 @@ class SVDUNet(torch.nn.Module):
PopBlock(), ResnetBlock(2560, 1280, 1280, eps=1e-6), PushBlock(), TemporalResnetBlock(1280, 1280, 1280, eps=1e-5), PopMixBlock(),
UpSampler(1280),
# CrossAttnUpBlockSpatioTemporal
PopBlock(), ResnetBlock(2560, 1280, 1280, eps=1e-6), PushBlock(), TemporalResnetBlock(1280, 1280, 1280, eps=1e-6), PopMixBlock(), PushBlock(),
AttentionBlock(20, 64, 1280, 1, 1024, need_proj_out=False), PushBlock(), TemporalAttentionBlock(20, 64, 1280, 1024), PopMixBlock(1280),
PopBlock(), ResnetBlock(2560, 1280, 1280, eps=1e-6), PushBlock(), TemporalResnetBlock(1280, 1280, 1280, eps=1e-6), PopMixBlock(), PushBlock(),
AttentionBlock(20, 64, 1280, 1, 1024, need_proj_out=False), PushBlock(), TemporalAttentionBlock(20, 64, 1280, 1024), PopMixBlock(1280),
PopBlock(), ResnetBlock(1920, 1280, 1280, eps=1e-6), PushBlock(), TemporalResnetBlock(1280, 1280, 1280, eps=1e-6), PopMixBlock(), PushBlock(),
AttentionBlock(20, 64, 1280, 1, 1024, need_proj_out=False), PushBlock(), TemporalAttentionBlock(20, 64, 1280, 1024), PopMixBlock(1280),
PopBlock(), ResnetBlock(2560, 1280, 1280, eps=1e-6), PushBlock(), TemporalResnetBlock(1280, 1280, 1280, eps=1e-6), PopMixBlock(), PushBlock(),
AttentionBlock(20, 64, 1280, 1, 1024, need_proj_out=False), PushBlock(), TemporalAttentionBlock(20, 64, 1280, 1024, add_positional_conv), PopMixBlock(1280),
PopBlock(), ResnetBlock(2560, 1280, 1280, eps=1e-6), PushBlock(), TemporalResnetBlock(1280, 1280, 1280, eps=1e-6), PopMixBlock(), PushBlock(),
AttentionBlock(20, 64, 1280, 1, 1024, need_proj_out=False), PushBlock(), TemporalAttentionBlock(20, 64, 1280, 1024, add_positional_conv), PopMixBlock(1280),
PopBlock(), ResnetBlock(1920, 1280, 1280, eps=1e-6), PushBlock(), TemporalResnetBlock(1280, 1280, 1280, eps=1e-6), PopMixBlock(), PushBlock(),
AttentionBlock(20, 64, 1280, 1, 1024, need_proj_out=False), PushBlock(), TemporalAttentionBlock(20, 64, 1280, 1024, add_positional_conv), PopMixBlock(1280),
UpSampler(1280),
# CrossAttnUpBlockSpatioTemporal
PopBlock(), ResnetBlock(1920, 640, 1280, eps=1e-6), PushBlock(), TemporalResnetBlock(640, 640, 1280, eps=1e-6), PopMixBlock(), PushBlock(),
AttentionBlock(10, 64, 640, 1, 1024, need_proj_out=False), PushBlock(), TemporalAttentionBlock(10, 64, 640, 1024), PopMixBlock(640),
PopBlock(), ResnetBlock(1280, 640, 1280, eps=1e-6), PushBlock(), TemporalResnetBlock(640, 640, 1280, eps=1e-6), PopMixBlock(), PushBlock(),
AttentionBlock(10, 64, 640, 1, 1024, need_proj_out=False), PushBlock(), TemporalAttentionBlock(10, 64, 640, 1024), PopMixBlock(640),
PopBlock(), ResnetBlock(960, 640, 1280, eps=1e-6), PushBlock(), TemporalResnetBlock(640, 640, 1280, eps=1e-6), PopMixBlock(), PushBlock(),
AttentionBlock(10, 64, 640, 1, 1024, need_proj_out=False), PushBlock(), TemporalAttentionBlock(10, 64, 640, 1024), PopMixBlock(640),
PopBlock(), ResnetBlock(1920, 640, 1280, eps=1e-6), PushBlock(), TemporalResnetBlock(640, 640, 1280, eps=1e-6), PopMixBlock(), PushBlock(),
AttentionBlock(10, 64, 640, 1, 1024, need_proj_out=False), PushBlock(), TemporalAttentionBlock(10, 64, 640, 1024, add_positional_conv), PopMixBlock(640),
PopBlock(), ResnetBlock(1280, 640, 1280, eps=1e-6), PushBlock(), TemporalResnetBlock(640, 640, 1280, eps=1e-6), PopMixBlock(), PushBlock(),
AttentionBlock(10, 64, 640, 1, 1024, need_proj_out=False), PushBlock(), TemporalAttentionBlock(10, 64, 640, 1024, add_positional_conv), PopMixBlock(640),
PopBlock(), ResnetBlock(960, 640, 1280, eps=1e-6), PushBlock(), TemporalResnetBlock(640, 640, 1280, eps=1e-6), PopMixBlock(), PushBlock(),
AttentionBlock(10, 64, 640, 1, 1024, need_proj_out=False), PushBlock(), TemporalAttentionBlock(10, 64, 640, 1024, add_positional_conv), PopMixBlock(640),
UpSampler(640),
# CrossAttnUpBlockSpatioTemporal
PopBlock(), ResnetBlock(960, 320, 1280, eps=1e-6), PushBlock(), TemporalResnetBlock(320, 320, 1280, eps=1e-6), PopMixBlock(), PushBlock(),
AttentionBlock(5, 64, 320, 1, 1024, need_proj_out=False), PushBlock(), TemporalAttentionBlock(5, 64, 320, 1024), PopMixBlock(320),
PopBlock(), ResnetBlock(640, 320, 1280, eps=1e-6), PushBlock(), TemporalResnetBlock(320, 320, 1280, eps=1e-6), PopMixBlock(), PushBlock(),
AttentionBlock(5, 64, 320, 1, 1024, need_proj_out=False), PushBlock(), TemporalAttentionBlock(5, 64, 320, 1024), PopMixBlock(320),
PopBlock(), ResnetBlock(640, 320, 1280, eps=1e-6), PushBlock(), TemporalResnetBlock(320, 320, 1280, eps=1e-6), PopMixBlock(), PushBlock(),
AttentionBlock(5, 64, 320, 1, 1024, need_proj_out=False), PushBlock(), TemporalAttentionBlock(5, 64, 320, 1024), PopMixBlock(320),
PopBlock(), ResnetBlock(960, 320, 1280, eps=1e-6), PushBlock(), TemporalResnetBlock(320, 320, 1280, eps=1e-6), PopMixBlock(), PushBlock(),
AttentionBlock(5, 64, 320, 1, 1024, need_proj_out=False), PushBlock(), TemporalAttentionBlock(5, 64, 320, 1024, add_positional_conv), PopMixBlock(320),
PopBlock(), ResnetBlock(640, 320, 1280, eps=1e-6), PushBlock(), TemporalResnetBlock(320, 320, 1280, eps=1e-6), PopMixBlock(), PushBlock(),
AttentionBlock(5, 64, 320, 1, 1024, need_proj_out=False), PushBlock(), TemporalAttentionBlock(5, 64, 320, 1024, add_positional_conv), PopMixBlock(320),
PopBlock(), ResnetBlock(640, 320, 1280, eps=1e-6), PushBlock(), TemporalResnetBlock(320, 320, 1280, eps=1e-6), PopMixBlock(), PushBlock(),
AttentionBlock(5, 64, 320, 1, 1024, need_proj_out=False), PushBlock(), TemporalAttentionBlock(5, 64, 320, 1024, add_positional_conv), PopMixBlock(320),
])
self.conv_norm_out = torch.nn.GroupNorm(32, 320, eps=1e-05, affine=True)
@@ -327,7 +367,7 @@ class SVDUNet(torch.nn.Module):
return values
def forward(self, sample, timestep, encoder_hidden_states, add_time_id, **kwargs):
def forward(self, sample, timestep, encoder_hidden_states, add_time_id, use_gradient_checkpointing=False, **kwargs):
# 1. time
timestep = torch.tensor((timestep,)).to(sample.device)
t_emb = self.time_proj(timestep).to(sample.dtype)
@@ -346,8 +386,19 @@ class SVDUNet(torch.nn.Module):
res_stack = [hidden_states]
# 3. blocks
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs)
return custom_forward
for i, block in enumerate(self.blocks):
hidden_states, time_emb, text_emb, res_stack = block(hidden_states, time_emb, text_emb, res_stack)
if self.training and use_gradient_checkpointing and not (isinstance(block, PushBlock) or isinstance(block, PopBlock) or isinstance(block, PopMixBlock)):
hidden_states, time_emb, text_emb, res_stack = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
hidden_states, time_emb, text_emb, res_stack,
use_reentrant=False,
)
else:
hidden_states, time_emb, text_emb, res_stack = block(hidden_states, time_emb, text_emb, res_stack)
# 4. output
hidden_states = self.conv_norm_out(hidden_states)
@@ -497,7 +548,7 @@ class SVDUNetStateDictConverter:
return state_dict_
def from_civitai(self, state_dict):
def from_civitai(self, state_dict, add_positional_conv=None):
rename_dict = {
"model.diffusion_model.input_blocks.0.0.bias": "conv_in.bias",
"model.diffusion_model.input_blocks.0.0.weight": "conv_in.weight",
@@ -1935,4 +1986,18 @@ class SVDUNetStateDictConverter:
if ".proj_in." in name or ".proj_out." in name:
param = param.squeeze()
state_dict_[rename_dict[name]] = param
if add_positional_conv is not None:
extra_names = [
"blocks.7.positional_conv", "blocks.17.positional_conv", "blocks.29.positional_conv", "blocks.39.positional_conv",
"blocks.51.positional_conv", "blocks.61.positional_conv", "blocks.83.positional_conv", "blocks.113.positional_conv",
"blocks.123.positional_conv", "blocks.133.positional_conv", "blocks.144.positional_conv", "blocks.154.positional_conv",
"blocks.164.positional_conv", "blocks.175.positional_conv", "blocks.185.positional_conv", "blocks.195.positional_conv",
]
extra_channels = [320, 320, 640, 640, 1280, 1280, 1280, 1280, 1280, 1280, 640, 640, 640, 320, 320, 320]
for name, channels in zip(extra_names, extra_channels):
weight = torch.zeros((channels, channels, 3, 3, 3))
weight[:,:,1,1,1] = torch.eye(channels, channels)
bias = torch.zeros((channels,))
state_dict_[name + ".weight"] = weight
state_dict_[name + ".bias"] = bias
return state_dict_

View File

@@ -107,6 +107,14 @@ class SVDVideoPipeline(torch.nn.Module):
noise_pred = noise_pred_nega + cfg_scales * (noise_pred_posi - noise_pred_nega)
return noise_pred
def post_process_latents(self, latents, post_normalize=True, contrast_enhance_scale=1.0):
if post_normalize:
mean, std = latents.mean(), latents.std()
latents = (latents - latents.mean(dim=[1, 2, 3], keepdim=True)) / latents.std(dim=[1, 2, 3], keepdim=True) * std + mean
latents = latents * contrast_enhance_scale
return latents
@torch.no_grad()
@@ -126,6 +134,8 @@ class SVDVideoPipeline(torch.nn.Module):
motion_bucket_id=127,
noise_aug_strength=0.02,
num_inference_steps=20,
post_normalize=True,
contrast_enhance_scale=1.2,
progress_bar_cmd=tqdm,
progress_bar_st=None,
):
@@ -178,6 +188,7 @@ class SVDVideoPipeline(torch.nn.Module):
progress_bar_st.progress(progress_id / len(self.scheduler.timesteps))
# Decode image
latents = self.post_process_latents(latents, post_normalize=post_normalize, contrast_enhance_scale=contrast_enhance_scale)
video = self.vae_decoder.decode_video(latents, progress_bar=progress_bar_cmd)
video = self.tensor2video(video)

View File

@@ -43,3 +43,17 @@ class ContinuousODEScheduler():
sigma = self.sigmas[timestep_id]
sample = (original_samples + noise * sigma) / (sigma*sigma + 1).sqrt()
return sample
def training_target(self, sample, noise, timestep):
timestep_id = torch.argmin((self.timesteps - timestep).abs())
sigma = self.sigmas[timestep_id]
target = (-(sigma*sigma + 1).sqrt() / sigma + 1 / (sigma*sigma + 1).sqrt() / sigma) * sample + 1 / (sigma*sigma + 1).sqrt() * noise
return target
def training_weight(self, timestep):
timestep_id = torch.argmin((self.timesteps - timestep).abs())
sigma = self.sigmas[timestep_id]
weight = (1 + sigma*sigma).sqrt() / sigma
return weight