release ExVideo

Artiprocher
2024-06-20 16:01:51 +08:00
parent b78ec9e086
commit 6e25864a3d
6 changed files with 259 additions and 50 deletions

View File

@@ -110,7 +110,11 @@ class ModelManager:
param_name = "quant_conv.weight" param_name = "quant_conv.weight"
return param_name in state_dict return param_name in state_dict
def load_stable_video_diffusion(self, state_dict, components=None, file_path=""): def is_ExVideo_StableVideoDiffusion(self, state_dict):
param_name = "blocks.185.positional_embedding.embeddings"
return param_name in state_dict
def load_stable_video_diffusion(self, state_dict, components=None, file_path="", add_positional_conv=None):
component_dict = { component_dict = {
"image_encoder": SVDImageEncoder, "image_encoder": SVDImageEncoder,
"unet": SVDUNet, "unet": SVDUNet,
@@ -120,8 +124,12 @@ class ModelManager:
         if components is None:
             components = ["image_encoder", "unet", "vae_decoder", "vae_encoder"]
         for component in components:
-            self.model[component] = component_dict[component]()
-            self.model[component].load_state_dict(self.model[component].state_dict_converter().from_civitai(state_dict))
+            if component == "unet":
+                self.model[component] = component_dict[component](add_positional_conv=add_positional_conv)
+                self.model[component].load_state_dict(self.model[component].state_dict_converter().from_civitai(state_dict, add_positional_conv=add_positional_conv), strict=False)
+            else:
+                self.model[component] = component_dict[component]()
+                self.model[component].load_state_dict(self.model[component].state_dict_converter().from_civitai(state_dict))
             self.model[component].to(self.torch_dtype).to(self.device)
             self.model_path[component] = file_path
@@ -305,6 +313,16 @@ class ModelManager:
         self.model[component] = model
         self.model_path[component] = file_path

+    def load_ExVideo_StableVideoDiffusion(self, state_dict, file_path=""):
+        unet_state_dict = self.model["unet"].state_dict()
+        self.model["unet"].to("cpu")
+        del self.model["unet"]
+        add_positional_conv = state_dict["blocks.185.positional_embedding.embeddings"].shape[0]
+        self.model["unet"] = SVDUNet(add_positional_conv=add_positional_conv)
+        self.model["unet"].load_state_dict(unet_state_dict, strict=False)
+        self.model["unet"].load_state_dict(state_dict, strict=False)
+        self.model["unet"].to(self.torch_dtype).to(self.device)
+
     def search_for_embeddings(self, state_dict):
         embeddings = []
         for k in state_dict:
@@ -370,6 +388,8 @@ class ModelManager:
             self.load_hunyuan_dit(state_dict, file_path=file_path)
         elif self.is_diffusers_vae(state_dict):
             self.load_diffusers_vae(state_dict, file_path=file_path)
+        elif self.is_ExVideo_StableVideoDiffusion(state_dict):
+            self.load_ExVideo_StableVideoDiffusion(state_dict, file_path=file_path)

     def load_models(self, file_path_list, lora_alphas=[]):
         for file_path in file_path_list:
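With this detection wired in, an ExVideo checkpoint can simply be passed to `load_models` together with the base SVD weights: the base UNet is built first, then upgraded in place by `load_ExVideo_StableVideoDiffusion`. A minimal sketch, using the checkpoint paths from the example script later in this commit:

    from diffsynth import ModelManager, SVDVideoPipeline
    import torch

    model_manager = ModelManager(torch_dtype=torch.float16, device="cuda")
    model_manager.load_models([
        "models/stable_video_diffusion/svd_xt.safetensors",      # base SVD, loaded first
        "models/stable_video_diffusion/model.fp16.safetensors",  # ExVideo weights, detected via blocks.185.positional_embedding.embeddings
    ])
    pipe = SVDVideoPipeline.from_model_manager(model_manager)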

View File

@@ -97,17 +97,57 @@ class TemporalTimesteps(torch.nn.Module):
         return t_emb

+class TrainableTemporalTimesteps(torch.nn.Module):
+    def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float, num_frames: int):
+        super().__init__()
+        timesteps = PositionalID()(num_frames)
+        embeddings = get_timestep_embedding(timesteps, num_channels, flip_sin_to_cos, downscale_freq_shift)
+        self.embeddings = torch.nn.Parameter(embeddings)
+
+    def forward(self, timesteps):
+        t_emb = self.embeddings[timesteps]
+        return t_emb
+
+class PositionalID(torch.nn.Module):
+    def __init__(self, max_id=25, repeat_length=20):
+        super().__init__()
+        self.max_id = max_id
+        self.repeat_length = repeat_length
+
+    def frame_id_to_position_id(self, frame_id):
+        if frame_id < self.max_id:
+            position_id = frame_id
+        else:
+            position_id = (frame_id - self.max_id) % (self.repeat_length * 2)
+            if position_id < self.repeat_length:
+                position_id = self.max_id - 2 - position_id
+            else:
+                position_id = self.max_id - 2 * self.repeat_length + position_id
+        return position_id
+
+    def forward(self, num_frames, pivot_frame_id=0):
+        position_ids = [self.frame_id_to_position_id(abs(i-pivot_frame_id)) for i in range(num_frames)]
+        position_ids = torch.IntTensor(position_ids)
+        return position_ids
+
 class TemporalAttentionBlock(torch.nn.Module):

-    def __init__(self, num_attention_heads, attention_head_dim, in_channels, cross_attention_dim=None):
+    def __init__(self, num_attention_heads, attention_head_dim, in_channels, cross_attention_dim=None, add_positional_conv=None):
         super().__init__()

-        self.positional_embedding = TemporalTimesteps(in_channels, True, 0)
         self.positional_embedding_proj = torch.nn.Sequential(
             torch.nn.Linear(in_channels, in_channels * 4),
             torch.nn.SiLU(),
             torch.nn.Linear(in_channels * 4, in_channels)
         )
+        if add_positional_conv is not None:
+            self.positional_embedding = TrainableTemporalTimesteps(in_channels, True, 0, add_positional_conv)
+            self.positional_conv = torch.nn.Conv3d(in_channels, in_channels, kernel_size=3, padding=1, padding_mode="reflect")
+        else:
+            self.positional_embedding = TemporalTimesteps(in_channels, True, 0)
+            self.positional_conv = None

         self.norm_in = torch.nn.LayerNorm(in_channels)
         self.act_fn_in = GEGLU(in_channels, in_channels * 4)
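`PositionalID` is what lets the extended model address frame indices far beyond the 25 temporal positions the base model was trained with: ids count up to `max_id - 1` and then reflect back and forth inside the trained range, and `TrainableTemporalTimesteps` uses these ids to build the sinusoidal initialization of its trainable embedding table. A standalone restatement of the mapping, for illustration only:

    # Standalone copy of PositionalID.frame_id_to_position_id (illustrative).
    def frame_id_to_position_id(frame_id, max_id=25, repeat_length=20):
        if frame_id < max_id:
            return frame_id
        position_id = (frame_id - max_id) % (repeat_length * 2)
        if position_id < repeat_length:
            return max_id - 2 - position_id
        return max_id - 2 * repeat_length + position_id

    print([frame_id_to_position_id(i) for i in range(30)])
    # [0, 1, ..., 24, 23, 22, 21, 20, 19] -- counts up, then walks back down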
@@ -137,14 +177,14 @@ class TemporalAttentionBlock(torch.nn.Module):
     def forward(self, hidden_states, time_emb, text_emb, res_stack, **kwargs):
         batch, inner_dim, height, width = hidden_states.shape
-        pos_emb = torch.arange(batch, dtype=hidden_states.dtype)
-        if batch > 25:
-            pos_emb *= 25 / batch
+        pos_emb = torch.arange(batch)
         pos_emb = self.positional_embedding(pos_emb).to(dtype=hidden_states.dtype, device=hidden_states.device)
-        pos_emb = self.positional_embedding_proj(pos_emb)[None, :, :]
+        pos_emb = self.positional_embedding_proj(pos_emb)

-        hidden_states = hidden_states.permute(2, 3, 0, 1).reshape(height * width, batch, inner_dim)
-        hidden_states = hidden_states + pos_emb
+        hidden_states = rearrange(hidden_states, "T C H W -> 1 C T H W") + rearrange(pos_emb, "T C -> 1 C T 1 1")
+        if self.positional_conv is not None:
+            hidden_states = self.positional_conv(hidden_states)
+        hidden_states = rearrange(hidden_states[0], "C T H W -> (H W) T C")

         residual = hidden_states
         hidden_states = self.norm_in(hidden_states)
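In this forward pass the batch axis is the frame axis: `hidden_states` arrives as `(T, C, H, W)`, the projected positional embedding is broadcast over the spatial dimensions, the optional `Conv3d` mixes neighboring frames, and the result is flattened to `(H*W, T, C)` for attention over time. A shape walk-through under those assumptions (using einops' `rearrange`, as the diff does):

    import torch
    from einops import rearrange

    T, C, H, W = 16, 320, 32, 32
    hidden_states = torch.randn(T, C, H, W)  # one latent feature map per frame
    pos_emb = torch.randn(T, C)              # projected positional embedding

    x = rearrange(hidden_states, "T C H W -> 1 C T H W") + rearrange(pos_emb, "T C -> 1 C T 1 1")
    positional_conv = torch.nn.Conv3d(C, C, kernel_size=3, padding=1, padding_mode="reflect")
    x = positional_conv(x)                   # mixes adjacent frames; shape unchanged
    x = rearrange(x[0], "C T H W -> (H W) T C")
    print(x.shape)                           # torch.Size([1024, 16, 320])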
@@ -193,7 +233,7 @@ class PopMixBlock(torch.nn.Module):
 class SVDUNet(torch.nn.Module):
-    def __init__(self):
+    def __init__(self, add_positional_conv=None):
         super().__init__()
         self.time_proj = Timesteps(320)
         self.time_embedding = torch.nn.Sequential(
@@ -211,29 +251,29 @@ class SVDUNet(torch.nn.Module):
         self.blocks = torch.nn.ModuleList([
             # CrossAttnDownBlockSpatioTemporal
             ResnetBlock(320, 320, 1280, eps=1e-6), PushBlock(), TemporalResnetBlock(320, 320, 1280, eps=1e-6), PopMixBlock(), PushBlock(),
-            AttentionBlock(5, 64, 320, 1, 1024, need_proj_out=False), PushBlock(), TemporalAttentionBlock(5, 64, 320, 1024), PopMixBlock(320), PushBlock(),
+            AttentionBlock(5, 64, 320, 1, 1024, need_proj_out=False), PushBlock(), TemporalAttentionBlock(5, 64, 320, 1024, add_positional_conv), PopMixBlock(320), PushBlock(),
             ResnetBlock(320, 320, 1280, eps=1e-6), PushBlock(), TemporalResnetBlock(320, 320, 1280, eps=1e-6), PopMixBlock(), PushBlock(),
-            AttentionBlock(5, 64, 320, 1, 1024, need_proj_out=False), PushBlock(), TemporalAttentionBlock(5, 64, 320, 1024), PopMixBlock(320), PushBlock(),
+            AttentionBlock(5, 64, 320, 1, 1024, need_proj_out=False), PushBlock(), TemporalAttentionBlock(5, 64, 320, 1024, add_positional_conv), PopMixBlock(320), PushBlock(),
             DownSampler(320), PushBlock(),
             # CrossAttnDownBlockSpatioTemporal
             ResnetBlock(320, 640, 1280, eps=1e-6), PushBlock(), TemporalResnetBlock(640, 640, 1280, eps=1e-6), PopMixBlock(), PushBlock(),
-            AttentionBlock(10, 64, 640, 1, 1024, need_proj_out=False), PushBlock(), TemporalAttentionBlock(10, 64, 640, 1024), PopMixBlock(640), PushBlock(),
+            AttentionBlock(10, 64, 640, 1, 1024, need_proj_out=False), PushBlock(), TemporalAttentionBlock(10, 64, 640, 1024, add_positional_conv), PopMixBlock(640), PushBlock(),
             ResnetBlock(640, 640, 1280, eps=1e-6), PushBlock(), TemporalResnetBlock(640, 640, 1280, eps=1e-6), PopMixBlock(), PushBlock(),
-            AttentionBlock(10, 64, 640, 1, 1024, need_proj_out=False), PushBlock(), TemporalAttentionBlock(10, 64, 640, 1024), PopMixBlock(640), PushBlock(),
+            AttentionBlock(10, 64, 640, 1, 1024, need_proj_out=False), PushBlock(), TemporalAttentionBlock(10, 64, 640, 1024, add_positional_conv), PopMixBlock(640), PushBlock(),
             DownSampler(640), PushBlock(),
             # CrossAttnDownBlockSpatioTemporal
             ResnetBlock(640, 1280, 1280, eps=1e-6), PushBlock(), TemporalResnetBlock(1280, 1280, 1280, eps=1e-6), PopMixBlock(), PushBlock(),
-            AttentionBlock(20, 64, 1280, 1, 1024, need_proj_out=False), PushBlock(), TemporalAttentionBlock(20, 64, 1280, 1024), PopMixBlock(1280), PushBlock(),
+            AttentionBlock(20, 64, 1280, 1, 1024, need_proj_out=False), PushBlock(), TemporalAttentionBlock(20, 64, 1280, 1024, add_positional_conv), PopMixBlock(1280), PushBlock(),
             ResnetBlock(1280, 1280, 1280, eps=1e-6), PushBlock(), TemporalResnetBlock(1280, 1280, 1280, eps=1e-6), PopMixBlock(), PushBlock(),
-            AttentionBlock(20, 64, 1280, 1, 1024, need_proj_out=False), PushBlock(), TemporalAttentionBlock(20, 64, 1280, 1024), PopMixBlock(1280), PushBlock(),
+            AttentionBlock(20, 64, 1280, 1, 1024, need_proj_out=False), PushBlock(), TemporalAttentionBlock(20, 64, 1280, 1024, add_positional_conv), PopMixBlock(1280), PushBlock(),
             DownSampler(1280), PushBlock(),
             # DownBlockSpatioTemporal
             ResnetBlock(1280, 1280, 1280, eps=1e-5), PushBlock(), TemporalResnetBlock(1280, 1280, 1280, eps=1e-5), PopMixBlock(), PushBlock(),
             ResnetBlock(1280, 1280, 1280, eps=1e-5), PushBlock(), TemporalResnetBlock(1280, 1280, 1280, eps=1e-5), PopMixBlock(), PushBlock(),
             # UNetMidBlockSpatioTemporal
             ResnetBlock(1280, 1280, 1280, eps=1e-5), PushBlock(), TemporalResnetBlock(1280, 1280, 1280, eps=1e-5), PopMixBlock(), PushBlock(),
-            AttentionBlock(20, 64, 1280, 1, 1024, need_proj_out=False), PushBlock(), TemporalAttentionBlock(20, 64, 1280, 1024), PopMixBlock(1280),
+            AttentionBlock(20, 64, 1280, 1, 1024, need_proj_out=False), PushBlock(), TemporalAttentionBlock(20, 64, 1280, 1024, add_positional_conv), PopMixBlock(1280),
             ResnetBlock(1280, 1280, 1280, eps=1e-5), PushBlock(), TemporalResnetBlock(1280, 1280, 1280, eps=1e-5), PopMixBlock(),
             # UpBlockSpatioTemporal
             PopBlock(), ResnetBlock(2560, 1280, 1280, eps=1e-6), PushBlock(), TemporalResnetBlock(1280, 1280, 1280, eps=1e-5), PopMixBlock(),
@@ -241,28 +281,28 @@ class SVDUNet(torch.nn.Module):
             PopBlock(), ResnetBlock(2560, 1280, 1280, eps=1e-6), PushBlock(), TemporalResnetBlock(1280, 1280, 1280, eps=1e-5), PopMixBlock(),
             UpSampler(1280),
             # CrossAttnUpBlockSpatioTemporal
             PopBlock(), ResnetBlock(2560, 1280, 1280, eps=1e-6), PushBlock(), TemporalResnetBlock(1280, 1280, 1280, eps=1e-6), PopMixBlock(), PushBlock(),
-            AttentionBlock(20, 64, 1280, 1, 1024, need_proj_out=False), PushBlock(), TemporalAttentionBlock(20, 64, 1280, 1024), PopMixBlock(1280),
+            AttentionBlock(20, 64, 1280, 1, 1024, need_proj_out=False), PushBlock(), TemporalAttentionBlock(20, 64, 1280, 1024, add_positional_conv), PopMixBlock(1280),
             PopBlock(), ResnetBlock(2560, 1280, 1280, eps=1e-6), PushBlock(), TemporalResnetBlock(1280, 1280, 1280, eps=1e-6), PopMixBlock(), PushBlock(),
-            AttentionBlock(20, 64, 1280, 1, 1024, need_proj_out=False), PushBlock(), TemporalAttentionBlock(20, 64, 1280, 1024), PopMixBlock(1280),
+            AttentionBlock(20, 64, 1280, 1, 1024, need_proj_out=False), PushBlock(), TemporalAttentionBlock(20, 64, 1280, 1024, add_positional_conv), PopMixBlock(1280),
             PopBlock(), ResnetBlock(1920, 1280, 1280, eps=1e-6), PushBlock(), TemporalResnetBlock(1280, 1280, 1280, eps=1e-6), PopMixBlock(), PushBlock(),
-            AttentionBlock(20, 64, 1280, 1, 1024, need_proj_out=False), PushBlock(), TemporalAttentionBlock(20, 64, 1280, 1024), PopMixBlock(1280),
+            AttentionBlock(20, 64, 1280, 1, 1024, need_proj_out=False), PushBlock(), TemporalAttentionBlock(20, 64, 1280, 1024, add_positional_conv), PopMixBlock(1280),
             UpSampler(1280),
             # CrossAttnUpBlockSpatioTemporal
             PopBlock(), ResnetBlock(1920, 640, 1280, eps=1e-6), PushBlock(), TemporalResnetBlock(640, 640, 1280, eps=1e-6), PopMixBlock(), PushBlock(),
-            AttentionBlock(10, 64, 640, 1, 1024, need_proj_out=False), PushBlock(), TemporalAttentionBlock(10, 64, 640, 1024), PopMixBlock(640),
+            AttentionBlock(10, 64, 640, 1, 1024, need_proj_out=False), PushBlock(), TemporalAttentionBlock(10, 64, 640, 1024, add_positional_conv), PopMixBlock(640),
             PopBlock(), ResnetBlock(1280, 640, 1280, eps=1e-6), PushBlock(), TemporalResnetBlock(640, 640, 1280, eps=1e-6), PopMixBlock(), PushBlock(),
-            AttentionBlock(10, 64, 640, 1, 1024, need_proj_out=False), PushBlock(), TemporalAttentionBlock(10, 64, 640, 1024), PopMixBlock(640),
+            AttentionBlock(10, 64, 640, 1, 1024, need_proj_out=False), PushBlock(), TemporalAttentionBlock(10, 64, 640, 1024, add_positional_conv), PopMixBlock(640),
             PopBlock(), ResnetBlock(960, 640, 1280, eps=1e-6), PushBlock(), TemporalResnetBlock(640, 640, 1280, eps=1e-6), PopMixBlock(), PushBlock(),
-            AttentionBlock(10, 64, 640, 1, 1024, need_proj_out=False), PushBlock(), TemporalAttentionBlock(10, 64, 640, 1024), PopMixBlock(640),
+            AttentionBlock(10, 64, 640, 1, 1024, need_proj_out=False), PushBlock(), TemporalAttentionBlock(10, 64, 640, 1024, add_positional_conv), PopMixBlock(640),
             UpSampler(640),
             # CrossAttnUpBlockSpatioTemporal
             PopBlock(), ResnetBlock(960, 320, 1280, eps=1e-6), PushBlock(), TemporalResnetBlock(320, 320, 1280, eps=1e-6), PopMixBlock(), PushBlock(),
-            AttentionBlock(5, 64, 320, 1, 1024, need_proj_out=False), PushBlock(), TemporalAttentionBlock(5, 64, 320, 1024), PopMixBlock(320),
+            AttentionBlock(5, 64, 320, 1, 1024, need_proj_out=False), PushBlock(), TemporalAttentionBlock(5, 64, 320, 1024, add_positional_conv), PopMixBlock(320),
             PopBlock(), ResnetBlock(640, 320, 1280, eps=1e-6), PushBlock(), TemporalResnetBlock(320, 320, 1280, eps=1e-6), PopMixBlock(), PushBlock(),
-            AttentionBlock(5, 64, 320, 1, 1024, need_proj_out=False), PushBlock(), TemporalAttentionBlock(5, 64, 320, 1024), PopMixBlock(320),
+            AttentionBlock(5, 64, 320, 1, 1024, need_proj_out=False), PushBlock(), TemporalAttentionBlock(5, 64, 320, 1024, add_positional_conv), PopMixBlock(320),
             PopBlock(), ResnetBlock(640, 320, 1280, eps=1e-6), PushBlock(), TemporalResnetBlock(320, 320, 1280, eps=1e-6), PopMixBlock(), PushBlock(),
-            AttentionBlock(5, 64, 320, 1, 1024, need_proj_out=False), PushBlock(), TemporalAttentionBlock(5, 64, 320, 1024), PopMixBlock(320),
+            AttentionBlock(5, 64, 320, 1, 1024, need_proj_out=False), PushBlock(), TemporalAttentionBlock(5, 64, 320, 1024, add_positional_conv), PopMixBlock(320),
         ])
         self.conv_norm_out = torch.nn.GroupNorm(32, 320, eps=1e-05, affine=True)
@@ -327,7 +367,7 @@ class SVDUNet(torch.nn.Module):
         return values

-    def forward(self, sample, timestep, encoder_hidden_states, add_time_id, **kwargs):
+    def forward(self, sample, timestep, encoder_hidden_states, add_time_id, use_gradient_checkpointing=False, **kwargs):
         # 1. time
         timestep = torch.tensor((timestep,)).to(sample.device)
         t_emb = self.time_proj(timestep).to(sample.dtype)
@@ -346,8 +386,19 @@ class SVDUNet(torch.nn.Module):
         res_stack = [hidden_states]

         # 3. blocks
+        def create_custom_forward(module):
+            def custom_forward(*inputs):
+                return module(*inputs)
+            return custom_forward
         for i, block in enumerate(self.blocks):
-            hidden_states, time_emb, text_emb, res_stack = block(hidden_states, time_emb, text_emb, res_stack)
+            if self.training and use_gradient_checkpointing and not (isinstance(block, PushBlock) or isinstance(block, PopBlock) or isinstance(block, PopMixBlock)):
+                hidden_states, time_emb, text_emb, res_stack = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(block),
+                    hidden_states, time_emb, text_emb, res_stack,
+                    use_reentrant=False,
+                )
+            else:
+                hidden_states, time_emb, text_emb, res_stack = block(hidden_states, time_emb, text_emb, res_stack)

         # 4. output
         hidden_states = self.conv_norm_out(hidden_states)
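Gradient checkpointing is what makes 128-frame post-tuning fit in memory: each block's activations are recomputed during the backward pass instead of being cached. The `PushBlock`/`PopBlock`/`PopMixBlock` wrappers are excluded, presumably because they mainly shuffle the residual stack rather than compute activations. A minimal standalone sketch of the same pattern (module and shapes are illustrative):

    import torch

    block = torch.nn.Sequential(torch.nn.Linear(320, 1280), torch.nn.GELU(), torch.nn.Linear(1280, 320))
    x = torch.randn(8, 320, requires_grad=True)

    # Activations inside `block` are recomputed on backward rather than stored.
    y = torch.utils.checkpoint.checkpoint(block, x, use_reentrant=False)
    y.sum().backward()
    print(x.grad.shape)  # torch.Size([8, 320])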
@@ -497,7 +548,7 @@ class SVDUNetStateDictConverter:
         return state_dict_

-    def from_civitai(self, state_dict):
+    def from_civitai(self, state_dict, add_positional_conv=None):
         rename_dict = {
             "model.diffusion_model.input_blocks.0.0.bias": "conv_in.bias",
             "model.diffusion_model.input_blocks.0.0.weight": "conv_in.weight",
@@ -1935,4 +1986,18 @@ class SVDUNetStateDictConverter:
if ".proj_in." in name or ".proj_out." in name: if ".proj_in." in name or ".proj_out." in name:
param = param.squeeze() param = param.squeeze()
state_dict_[rename_dict[name]] = param state_dict_[rename_dict[name]] = param
if add_positional_conv is not None:
extra_names = [
"blocks.7.positional_conv", "blocks.17.positional_conv", "blocks.29.positional_conv", "blocks.39.positional_conv",
"blocks.51.positional_conv", "blocks.61.positional_conv", "blocks.83.positional_conv", "blocks.113.positional_conv",
"blocks.123.positional_conv", "blocks.133.positional_conv", "blocks.144.positional_conv", "blocks.154.positional_conv",
"blocks.164.positional_conv", "blocks.175.positional_conv", "blocks.185.positional_conv", "blocks.195.positional_conv",
]
extra_channels = [320, 320, 640, 640, 1280, 1280, 1280, 1280, 1280, 1280, 640, 640, 640, 320, 320, 320]
for name, channels in zip(extra_names, extra_channels):
weight = torch.zeros((channels, channels, 3, 3, 3))
weight[:,:,1,1,1] = torch.eye(channels, channels)
bias = torch.zeros((channels,))
state_dict_[name + ".weight"] = weight
state_dict_[name + ".bias"] = bias
return state_dict_ return state_dict_
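The injected `positional_conv` weights are deliberately an identity map: every kernel tap is zero except the center, which holds the identity matrix over channels, so a freshly extended UNet reproduces the base model exactly until the new layers are tuned. A quick check of that property:

    import torch

    channels = 8
    conv = torch.nn.Conv3d(channels, channels, kernel_size=3, padding=1, padding_mode="reflect")
    with torch.no_grad():
        conv.weight.zero_()
        conv.weight[:, :, 1, 1, 1] = torch.eye(channels)  # identity at the kernel center
        conv.bias.zero_()

    x = torch.randn(1, channels, 6, 4, 4)
    assert torch.allclose(conv(x), x, atol=1e-5)  # the conv starts out as a no-op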

View File

@@ -109,6 +109,14 @@ class SVDVideoPipeline(torch.nn.Module):
         return noise_pred

+    def post_process_latents(self, latents, post_normalize=True, contrast_enhance_scale=1.0):
+        if post_normalize:
+            mean, std = latents.mean(), latents.std()
+            latents = (latents - latents.mean(dim=[1, 2, 3], keepdim=True)) / latents.std(dim=[1, 2, 3], keepdim=True) * std + mean
+        latents = latents * contrast_enhance_scale
+        return latents
+
     @torch.no_grad()
     def __call__(
         self,
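When `post_normalize` is on, each frame's latent statistics are re-matched to the mean and standard deviation of the whole clip, which suppresses per-frame brightness drift over long generations; `contrast_enhance_scale` then globally scales the latents. A toy illustration of the normalization step:

    import torch

    latents = torch.randn(16, 4, 64, 64)  # (frames, channels, height, width)
    latents[3] += 0.5                      # one frame has drifted brighter

    mean, std = latents.mean(), latents.std()
    fixed = (latents - latents.mean(dim=[1, 2, 3], keepdim=True)) / latents.std(dim=[1, 2, 3], keepdim=True) * std + mean

    # Every frame's mean is now pulled to the clip-wide mean.
    print(latents.mean(dim=[1, 2, 3])[3].item(), fixed.mean(dim=[1, 2, 3])[3].item())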
@@ -126,6 +134,8 @@ class SVDVideoPipeline(torch.nn.Module):
         motion_bucket_id=127,
         noise_aug_strength=0.02,
         num_inference_steps=20,
+        post_normalize=True,
+        contrast_enhance_scale=1.2,
         progress_bar_cmd=tqdm,
         progress_bar_st=None,
     ):
@@ -178,6 +188,7 @@ class SVDVideoPipeline(torch.nn.Module):
                 progress_bar_st.progress(progress_id / len(self.scheduler.timesteps))

         # Decode image
+        latents = self.post_process_latents(latents, post_normalize=post_normalize, contrast_enhance_scale=contrast_enhance_scale)
         video = self.vae_decoder.decode_video(latents, progress_bar=progress_bar_cmd)
         video = self.tensor2video(video)

View File

@@ -43,3 +43,17 @@ class ContinuousODEScheduler():
         sigma = self.sigmas[timestep_id]
         sample = (original_samples + noise * sigma) / (sigma*sigma + 1).sqrt()
         return sample
+
+    def training_target(self, sample, noise, timestep):
+        timestep_id = torch.argmin((self.timesteps - timestep).abs())
+        sigma = self.sigmas[timestep_id]
+        target = (-(sigma*sigma + 1).sqrt() / sigma + 1 / (sigma*sigma + 1).sqrt() / sigma) * sample + 1 / (sigma*sigma + 1).sqrt() * noise
+        return target
+
+    def training_weight(self, timestep):
+        timestep_id = torch.argmin((self.timesteps - timestep).abs())
+        sigma = self.sigmas[timestep_id]
+        weight = (1 + sigma*sigma).sqrt() / sigma
+        return weight
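The coefficient on `sample` in `training_target` simplifies, which shows the target is the standard v-prediction objective (a derivation added here for clarity, not part of the commit):

    -\frac{\sqrt{\sigma^2+1}}{\sigma} + \frac{1}{\sigma\sqrt{\sigma^2+1}}
        = \frac{1 - (\sigma^2 + 1)}{\sigma\sqrt{\sigma^2+1}}
        = -\frac{\sigma}{\sqrt{\sigma^2+1}},
    \qquad\text{so}\qquad
    \text{target} = -\frac{\sigma}{\sqrt{\sigma^2+1}}\,x_0 + \frac{1}{\sqrt{\sigma^2+1}}\,\epsilon.

That is, with \alpha = 1/\sqrt{\sigma^2+1} and \bar\sigma = \sigma/\sqrt{\sigma^2+1} (matching the `add_noise` rule above, where \alpha^2 + \bar\sigma^2 = 1), the target is v = \alpha\epsilon - \bar\sigma x_0, and `training_weight` returns 1/\bar\sigma.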

View File

@@ -0,0 +1,83 @@
from diffsynth import save_video, ModelManager, SVDVideoPipeline, HunyuanDiTImagePipeline
import torch, os


def generate_image():
    # Load models
    os.environ["TOKENIZERS_PARALLELISM"] = "True"
    model_manager = ModelManager(torch_dtype=torch.float16, device="cuda")
    model_manager.load_models([
        "models/HunyuanDiT/t2i/clip_text_encoder/pytorch_model.bin",
        "models/HunyuanDiT/t2i/mt5/pytorch_model.bin",
        "models/HunyuanDiT/t2i/model/pytorch_model_ema.pt",
        "models/HunyuanDiT/t2i/sdxl-vae-fp16-fix/diffusion_pytorch_model.bin"
    ])
    pipe = HunyuanDiTImagePipeline.from_model_manager(model_manager)
    # Generate an image
    torch.manual_seed(0)
    image = pipe(
        prompt="bonfire, on the stone",
        # Negative prompt in Chinese: "wrong eyes, bad face, disfigured, bad art,
        # deformed, extra limbs, blurred colors, blurry, duplicated, morbid, mutilated"
        negative_prompt="错误的眼睛,糟糕的人脸,毁容,糟糕的艺术,变形,多余的肢体,模糊的颜色,模糊,重复,病态,残缺,",
        num_inference_steps=50, height=1024, width=1024,
    )
    return image


def generate_video(image):
    # Load models
    model_manager = ModelManager(torch_dtype=torch.float16, device="cuda")
    model_manager.load_models([
        "models/stable_video_diffusion/svd_xt.safetensors",
        "models/stable_video_diffusion/model.fp16.safetensors"
    ])
    pipe = SVDVideoPipeline.from_model_manager(model_manager)
    # Generate a video
    torch.manual_seed(1)
    video = pipe(
        input_image=image.resize((512, 512)),
        num_frames=128, fps=30, height=512, width=512,
        motion_bucket_id=127,
        num_inference_steps=50,
        min_cfg_scale=2, max_cfg_scale=2, contrast_enhance_scale=1.2
    )
    return video


def upscale_video(image, video):
    # Load models
    model_manager = ModelManager(torch_dtype=torch.float16, device="cuda")
    model_manager.load_models([
        "models/stable_video_diffusion/svd_xt.safetensors",
        "models/stable_video_diffusion/model.fp16.safetensors"
    ])
    pipe = SVDVideoPipeline.from_model_manager(model_manager)
    # Generate a video
    torch.manual_seed(2)
    video = pipe(
        input_image=image.resize((1024, 1024)),
        input_video=[frame.resize((1024, 1024)) for frame in video], denoising_strength=0.5,
        num_frames=128, fps=30, height=1024, width=1024,
        motion_bucket_id=127,
        num_inference_steps=25,
        min_cfg_scale=2, max_cfg_scale=2, contrast_enhance_scale=1.2
    )
    return video


# We use Hunyuan DiT to generate the first frame.
# If you want to use your own image,
# use `image = Image.open("your_image_file.png")` instead of the following call.
image = generate_image()
image.save("image.png")

# Now generate a video at a resolution of 512.
video = generate_video(image)
save_video(video, "video_512.mp4", fps=30)

# Upscale the video.
video = upscale_video(image, video)
save_video(video, "video_1024.mp4", fps=30)

View File

@@ -0,0 +1,16 @@
# ExVideo
ExVideo is a post-tuning technique aimed at enhancing the capability of video generation models. We have extended Stable Video Diffusion to generate long videos of up to 128 frames.
* [Project Page](https://ecnu-cilab.github.io/ExVideoProjectPage/)
* [Source Code](https://github.com/modelscope/DiffSynth-Studio/tree/main/examples/ExVideo)
* Technical report
* Extended models
  * [HuggingFace](https://huggingface.co/ECNU-CILab/ExVideo-SVD-128f-v1)
  * [ModelScope](https://modelscope.cn/models/ECNU-CILab/ExVideo-SVD-128f-v1)
## Example: Text-to-video via extended Stable Video Diffusion
Generate a video using a text-to-image model and our image-to-video model. See [ExVideo_svd.py](./ExVideo_svd.py).
https://github.com/modelscope/DiffSynth-Studio/assets/35051019/d97f6aa9-8064-4b5b-9d49-ed6001bb9acc