mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-24 01:48:13 +00:00
release ExVideo
This commit is contained in:
@@ -110,7 +110,11 @@ class ModelManager:
|
|||||||
param_name = "quant_conv.weight"
|
param_name = "quant_conv.weight"
|
||||||
return param_name in state_dict
|
return param_name in state_dict
|
||||||
|
|
||||||
def load_stable_video_diffusion(self, state_dict, components=None, file_path=""):
|
def is_ExVideo_StableVideoDiffusion(self, state_dict):
|
||||||
|
param_name = "blocks.185.positional_embedding.embeddings"
|
||||||
|
return param_name in state_dict
|
||||||
|
|
||||||
|
def load_stable_video_diffusion(self, state_dict, components=None, file_path="", add_positional_conv=None):
|
||||||
component_dict = {
|
component_dict = {
|
||||||
"image_encoder": SVDImageEncoder,
|
"image_encoder": SVDImageEncoder,
|
||||||
"unet": SVDUNet,
|
"unet": SVDUNet,
|
||||||
@@ -120,8 +124,12 @@ class ModelManager:
|
|||||||
if components is None:
|
if components is None:
|
||||||
components = ["image_encoder", "unet", "vae_decoder", "vae_encoder"]
|
components = ["image_encoder", "unet", "vae_decoder", "vae_encoder"]
|
||||||
for component in components:
|
for component in components:
|
||||||
self.model[component] = component_dict[component]()
|
if component == "unet":
|
||||||
self.model[component].load_state_dict(self.model[component].state_dict_converter().from_civitai(state_dict))
|
self.model[component] = component_dict[component](add_positional_conv=add_positional_conv)
|
||||||
|
self.model[component].load_state_dict(self.model[component].state_dict_converter().from_civitai(state_dict, add_positional_conv=add_positional_conv), strict=False)
|
||||||
|
else:
|
||||||
|
self.model[component] = component_dict[component]()
|
||||||
|
self.model[component].load_state_dict(self.model[component].state_dict_converter().from_civitai(state_dict))
|
||||||
self.model[component].to(self.torch_dtype).to(self.device)
|
self.model[component].to(self.torch_dtype).to(self.device)
|
||||||
self.model_path[component] = file_path
|
self.model_path[component] = file_path
|
||||||
|
|
||||||
@@ -305,6 +313,16 @@ class ModelManager:
|
|||||||
self.model[component] = model
|
self.model[component] = model
|
||||||
self.model_path[component] = file_path
|
self.model_path[component] = file_path
|
||||||
|
|
||||||
|
def load_ExVideo_StableVideoDiffusion(self, state_dict, file_path=""):
|
||||||
|
unet_state_dict = self.model["unet"].state_dict()
|
||||||
|
self.model["unet"].to("cpu")
|
||||||
|
del self.model["unet"]
|
||||||
|
add_positional_conv = state_dict["blocks.185.positional_embedding.embeddings"].shape[0]
|
||||||
|
self.model["unet"] = SVDUNet(add_positional_conv=add_positional_conv)
|
||||||
|
self.model["unet"].load_state_dict(unet_state_dict, strict=False)
|
||||||
|
self.model["unet"].load_state_dict(state_dict, strict=False)
|
||||||
|
self.model["unet"].to(self.torch_dtype).to(self.device)
|
||||||
|
|
||||||
def search_for_embeddings(self, state_dict):
|
def search_for_embeddings(self, state_dict):
|
||||||
embeddings = []
|
embeddings = []
|
||||||
for k in state_dict:
|
for k in state_dict:
|
||||||
@@ -370,6 +388,8 @@ class ModelManager:
|
|||||||
self.load_hunyuan_dit(state_dict, file_path=file_path)
|
self.load_hunyuan_dit(state_dict, file_path=file_path)
|
||||||
elif self.is_diffusers_vae(state_dict):
|
elif self.is_diffusers_vae(state_dict):
|
||||||
self.load_diffusers_vae(state_dict, file_path=file_path)
|
self.load_diffusers_vae(state_dict, file_path=file_path)
|
||||||
|
elif self.is_ExVideo_StableVideoDiffusion(state_dict):
|
||||||
|
self.load_ExVideo_StableVideoDiffusion(state_dict, file_path=file_path)
|
||||||
|
|
||||||
def load_models(self, file_path_list, lora_alphas=[]):
|
def load_models(self, file_path_list, lora_alphas=[]):
|
||||||
for file_path in file_path_list:
|
for file_path in file_path_list:
|
||||||
|
|||||||
@@ -97,17 +97,57 @@ class TemporalTimesteps(torch.nn.Module):
|
|||||||
return t_emb
|
return t_emb
|
||||||
|
|
||||||
|
|
||||||
|
class TrainableTemporalTimesteps(torch.nn.Module):
|
||||||
|
def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float, num_frames: int):
|
||||||
|
super().__init__()
|
||||||
|
timesteps = PositionalID()(num_frames)
|
||||||
|
embeddings = get_timestep_embedding(timesteps, num_channels, flip_sin_to_cos, downscale_freq_shift)
|
||||||
|
self.embeddings = torch.nn.Parameter(embeddings)
|
||||||
|
|
||||||
|
def forward(self, timesteps):
|
||||||
|
t_emb = self.embeddings[timesteps]
|
||||||
|
return t_emb
|
||||||
|
|
||||||
|
|
||||||
|
class PositionalID(torch.nn.Module):
|
||||||
|
def __init__(self, max_id=25, repeat_length=20):
|
||||||
|
super().__init__()
|
||||||
|
self.max_id = max_id
|
||||||
|
self.repeat_length = repeat_length
|
||||||
|
|
||||||
|
def frame_id_to_position_id(self, frame_id):
|
||||||
|
if frame_id < self.max_id:
|
||||||
|
position_id = frame_id
|
||||||
|
else:
|
||||||
|
position_id = (frame_id - self.max_id) % (self.repeat_length * 2)
|
||||||
|
if position_id < self.repeat_length:
|
||||||
|
position_id = self.max_id - 2 - position_id
|
||||||
|
else:
|
||||||
|
position_id = self.max_id - 2 * self.repeat_length + position_id
|
||||||
|
return position_id
|
||||||
|
|
||||||
|
def forward(self, num_frames, pivot_frame_id=0):
|
||||||
|
position_ids = [self.frame_id_to_position_id(abs(i-pivot_frame_id)) for i in range(num_frames)]
|
||||||
|
position_ids = torch.IntTensor(position_ids)
|
||||||
|
return position_ids
|
||||||
|
|
||||||
|
|
||||||
class TemporalAttentionBlock(torch.nn.Module):
|
class TemporalAttentionBlock(torch.nn.Module):
|
||||||
|
|
||||||
def __init__(self, num_attention_heads, attention_head_dim, in_channels, cross_attention_dim=None):
|
def __init__(self, num_attention_heads, attention_head_dim, in_channels, cross_attention_dim=None, add_positional_conv=None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
self.positional_embedding = TemporalTimesteps(in_channels, True, 0)
|
|
||||||
self.positional_embedding_proj = torch.nn.Sequential(
|
self.positional_embedding_proj = torch.nn.Sequential(
|
||||||
torch.nn.Linear(in_channels, in_channels * 4),
|
torch.nn.Linear(in_channels, in_channels * 4),
|
||||||
torch.nn.SiLU(),
|
torch.nn.SiLU(),
|
||||||
torch.nn.Linear(in_channels * 4, in_channels)
|
torch.nn.Linear(in_channels * 4, in_channels)
|
||||||
)
|
)
|
||||||
|
if add_positional_conv is not None:
|
||||||
|
self.positional_embedding = TrainableTemporalTimesteps(in_channels, True, 0, add_positional_conv)
|
||||||
|
self.positional_conv = torch.nn.Conv3d(in_channels, in_channels, kernel_size=3, padding=1, padding_mode="reflect")
|
||||||
|
else:
|
||||||
|
self.positional_embedding = TemporalTimesteps(in_channels, True, 0)
|
||||||
|
self.positional_conv = None
|
||||||
|
|
||||||
self.norm_in = torch.nn.LayerNorm(in_channels)
|
self.norm_in = torch.nn.LayerNorm(in_channels)
|
||||||
self.act_fn_in = GEGLU(in_channels, in_channels * 4)
|
self.act_fn_in = GEGLU(in_channels, in_channels * 4)
|
||||||
@@ -137,14 +177,14 @@ class TemporalAttentionBlock(torch.nn.Module):
|
|||||||
def forward(self, hidden_states, time_emb, text_emb, res_stack, **kwargs):
|
def forward(self, hidden_states, time_emb, text_emb, res_stack, **kwargs):
|
||||||
|
|
||||||
batch, inner_dim, height, width = hidden_states.shape
|
batch, inner_dim, height, width = hidden_states.shape
|
||||||
pos_emb = torch.arange(batch, dtype=hidden_states.dtype)
|
pos_emb = torch.arange(batch)
|
||||||
if batch > 25:
|
|
||||||
pos_emb *= 25 / batch
|
|
||||||
pos_emb = self.positional_embedding(pos_emb).to(dtype=hidden_states.dtype, device=hidden_states.device)
|
pos_emb = self.positional_embedding(pos_emb).to(dtype=hidden_states.dtype, device=hidden_states.device)
|
||||||
pos_emb = self.positional_embedding_proj(pos_emb)[None, :, :]
|
pos_emb = self.positional_embedding_proj(pos_emb)
|
||||||
|
|
||||||
hidden_states = hidden_states.permute(2, 3, 0, 1).reshape(height * width, batch, inner_dim)
|
hidden_states = rearrange(hidden_states, "T C H W -> 1 C T H W") + rearrange(pos_emb, "T C -> 1 C T 1 1")
|
||||||
hidden_states = hidden_states + pos_emb
|
if self.positional_conv is not None:
|
||||||
|
hidden_states = self.positional_conv(hidden_states)
|
||||||
|
hidden_states = rearrange(hidden_states[0], "C T H W -> (H W) T C")
|
||||||
|
|
||||||
residual = hidden_states
|
residual = hidden_states
|
||||||
hidden_states = self.norm_in(hidden_states)
|
hidden_states = self.norm_in(hidden_states)
|
||||||
@@ -193,7 +233,7 @@ class PopMixBlock(torch.nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class SVDUNet(torch.nn.Module):
|
class SVDUNet(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self, add_positional_conv=None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.time_proj = Timesteps(320)
|
self.time_proj = Timesteps(320)
|
||||||
self.time_embedding = torch.nn.Sequential(
|
self.time_embedding = torch.nn.Sequential(
|
||||||
@@ -211,29 +251,29 @@ class SVDUNet(torch.nn.Module):
|
|||||||
|
|
||||||
self.blocks = torch.nn.ModuleList([
|
self.blocks = torch.nn.ModuleList([
|
||||||
# CrossAttnDownBlockSpatioTemporal
|
# CrossAttnDownBlockSpatioTemporal
|
||||||
ResnetBlock(320, 320, 1280, eps=1e-6), PushBlock(), TemporalResnetBlock(320, 320, 1280, eps=1e-6), PopMixBlock(), PushBlock(),
|
ResnetBlock(320, 320, 1280, eps=1e-6), PushBlock(), TemporalResnetBlock(320, 320, 1280, eps=1e-6), PopMixBlock(), PushBlock(),
|
||||||
AttentionBlock(5, 64, 320, 1, 1024, need_proj_out=False), PushBlock(), TemporalAttentionBlock(5, 64, 320, 1024), PopMixBlock(320), PushBlock(),
|
AttentionBlock(5, 64, 320, 1, 1024, need_proj_out=False), PushBlock(), TemporalAttentionBlock(5, 64, 320, 1024, add_positional_conv), PopMixBlock(320), PushBlock(),
|
||||||
ResnetBlock(320, 320, 1280, eps=1e-6), PushBlock(), TemporalResnetBlock(320, 320, 1280, eps=1e-6), PopMixBlock(), PushBlock(),
|
ResnetBlock(320, 320, 1280, eps=1e-6), PushBlock(), TemporalResnetBlock(320, 320, 1280, eps=1e-6), PopMixBlock(), PushBlock(),
|
||||||
AttentionBlock(5, 64, 320, 1, 1024, need_proj_out=False), PushBlock(), TemporalAttentionBlock(5, 64, 320, 1024), PopMixBlock(320), PushBlock(),
|
AttentionBlock(5, 64, 320, 1, 1024, need_proj_out=False), PushBlock(), TemporalAttentionBlock(5, 64, 320, 1024, add_positional_conv), PopMixBlock(320), PushBlock(),
|
||||||
DownSampler(320), PushBlock(),
|
DownSampler(320), PushBlock(),
|
||||||
# CrossAttnDownBlockSpatioTemporal
|
# CrossAttnDownBlockSpatioTemporal
|
||||||
ResnetBlock(320, 640, 1280, eps=1e-6), PushBlock(), TemporalResnetBlock(640, 640, 1280, eps=1e-6), PopMixBlock(), PushBlock(),
|
ResnetBlock(320, 640, 1280, eps=1e-6), PushBlock(), TemporalResnetBlock(640, 640, 1280, eps=1e-6), PopMixBlock(), PushBlock(),
|
||||||
AttentionBlock(10, 64, 640, 1, 1024, need_proj_out=False), PushBlock(), TemporalAttentionBlock(10, 64, 640, 1024), PopMixBlock(640), PushBlock(),
|
AttentionBlock(10, 64, 640, 1, 1024, need_proj_out=False), PushBlock(), TemporalAttentionBlock(10, 64, 640, 1024, add_positional_conv), PopMixBlock(640), PushBlock(),
|
||||||
ResnetBlock(640, 640, 1280, eps=1e-6), PushBlock(), TemporalResnetBlock(640, 640, 1280, eps=1e-6), PopMixBlock(), PushBlock(),
|
ResnetBlock(640, 640, 1280, eps=1e-6), PushBlock(), TemporalResnetBlock(640, 640, 1280, eps=1e-6), PopMixBlock(), PushBlock(),
|
||||||
AttentionBlock(10, 64, 640, 1, 1024, need_proj_out=False), PushBlock(), TemporalAttentionBlock(10, 64, 640, 1024), PopMixBlock(640), PushBlock(),
|
AttentionBlock(10, 64, 640, 1, 1024, need_proj_out=False), PushBlock(), TemporalAttentionBlock(10, 64, 640, 1024, add_positional_conv), PopMixBlock(640), PushBlock(),
|
||||||
DownSampler(640), PushBlock(),
|
DownSampler(640), PushBlock(),
|
||||||
# CrossAttnDownBlockSpatioTemporal
|
# CrossAttnDownBlockSpatioTemporal
|
||||||
ResnetBlock(640, 1280, 1280, eps=1e-6), PushBlock(), TemporalResnetBlock(1280, 1280, 1280, eps=1e-6), PopMixBlock(), PushBlock(),
|
ResnetBlock(640, 1280, 1280, eps=1e-6), PushBlock(), TemporalResnetBlock(1280, 1280, 1280, eps=1e-6), PopMixBlock(), PushBlock(),
|
||||||
AttentionBlock(20, 64, 1280, 1, 1024, need_proj_out=False), PushBlock(), TemporalAttentionBlock(20, 64, 1280, 1024), PopMixBlock(1280), PushBlock(),
|
AttentionBlock(20, 64, 1280, 1, 1024, need_proj_out=False), PushBlock(), TemporalAttentionBlock(20, 64, 1280, 1024, add_positional_conv), PopMixBlock(1280), PushBlock(),
|
||||||
ResnetBlock(1280, 1280, 1280, eps=1e-6), PushBlock(), TemporalResnetBlock(1280, 1280, 1280, eps=1e-6), PopMixBlock(), PushBlock(),
|
ResnetBlock(1280, 1280, 1280, eps=1e-6), PushBlock(), TemporalResnetBlock(1280, 1280, 1280, eps=1e-6), PopMixBlock(), PushBlock(),
|
||||||
AttentionBlock(20, 64, 1280, 1, 1024, need_proj_out=False), PushBlock(), TemporalAttentionBlock(20, 64, 1280, 1024), PopMixBlock(1280), PushBlock(),
|
AttentionBlock(20, 64, 1280, 1, 1024, need_proj_out=False), PushBlock(), TemporalAttentionBlock(20, 64, 1280, 1024, add_positional_conv), PopMixBlock(1280), PushBlock(),
|
||||||
DownSampler(1280), PushBlock(),
|
DownSampler(1280), PushBlock(),
|
||||||
# DownBlockSpatioTemporal
|
# DownBlockSpatioTemporal
|
||||||
ResnetBlock(1280, 1280, 1280, eps=1e-5), PushBlock(), TemporalResnetBlock(1280, 1280, 1280, eps=1e-5), PopMixBlock(), PushBlock(),
|
ResnetBlock(1280, 1280, 1280, eps=1e-5), PushBlock(), TemporalResnetBlock(1280, 1280, 1280, eps=1e-5), PopMixBlock(), PushBlock(),
|
||||||
ResnetBlock(1280, 1280, 1280, eps=1e-5), PushBlock(), TemporalResnetBlock(1280, 1280, 1280, eps=1e-5), PopMixBlock(), PushBlock(),
|
ResnetBlock(1280, 1280, 1280, eps=1e-5), PushBlock(), TemporalResnetBlock(1280, 1280, 1280, eps=1e-5), PopMixBlock(), PushBlock(),
|
||||||
# UNetMidBlockSpatioTemporal
|
# UNetMidBlockSpatioTemporal
|
||||||
ResnetBlock(1280, 1280, 1280, eps=1e-5), PushBlock(), TemporalResnetBlock(1280, 1280, 1280, eps=1e-5), PopMixBlock(), PushBlock(),
|
ResnetBlock(1280, 1280, 1280, eps=1e-5), PushBlock(), TemporalResnetBlock(1280, 1280, 1280, eps=1e-5), PopMixBlock(), PushBlock(),
|
||||||
AttentionBlock(20, 64, 1280, 1, 1024, need_proj_out=False), PushBlock(), TemporalAttentionBlock(20, 64, 1280, 1024), PopMixBlock(1280),
|
AttentionBlock(20, 64, 1280, 1, 1024, need_proj_out=False), PushBlock(), TemporalAttentionBlock(20, 64, 1280, 1024, add_positional_conv), PopMixBlock(1280),
|
||||||
ResnetBlock(1280, 1280, 1280, eps=1e-5), PushBlock(), TemporalResnetBlock(1280, 1280, 1280, eps=1e-5), PopMixBlock(),
|
ResnetBlock(1280, 1280, 1280, eps=1e-5), PushBlock(), TemporalResnetBlock(1280, 1280, 1280, eps=1e-5), PopMixBlock(),
|
||||||
# UpBlockSpatioTemporal
|
# UpBlockSpatioTemporal
|
||||||
PopBlock(), ResnetBlock(2560, 1280, 1280, eps=1e-6), PushBlock(), TemporalResnetBlock(1280, 1280, 1280, eps=1e-5), PopMixBlock(),
|
PopBlock(), ResnetBlock(2560, 1280, 1280, eps=1e-6), PushBlock(), TemporalResnetBlock(1280, 1280, 1280, eps=1e-5), PopMixBlock(),
|
||||||
@@ -241,28 +281,28 @@ class SVDUNet(torch.nn.Module):
|
|||||||
PopBlock(), ResnetBlock(2560, 1280, 1280, eps=1e-6), PushBlock(), TemporalResnetBlock(1280, 1280, 1280, eps=1e-5), PopMixBlock(),
|
PopBlock(), ResnetBlock(2560, 1280, 1280, eps=1e-6), PushBlock(), TemporalResnetBlock(1280, 1280, 1280, eps=1e-5), PopMixBlock(),
|
||||||
UpSampler(1280),
|
UpSampler(1280),
|
||||||
# CrossAttnUpBlockSpatioTemporal
|
# CrossAttnUpBlockSpatioTemporal
|
||||||
PopBlock(), ResnetBlock(2560, 1280, 1280, eps=1e-6), PushBlock(), TemporalResnetBlock(1280, 1280, 1280, eps=1e-6), PopMixBlock(), PushBlock(),
|
PopBlock(), ResnetBlock(2560, 1280, 1280, eps=1e-6), PushBlock(), TemporalResnetBlock(1280, 1280, 1280, eps=1e-6), PopMixBlock(), PushBlock(),
|
||||||
AttentionBlock(20, 64, 1280, 1, 1024, need_proj_out=False), PushBlock(), TemporalAttentionBlock(20, 64, 1280, 1024), PopMixBlock(1280),
|
AttentionBlock(20, 64, 1280, 1, 1024, need_proj_out=False), PushBlock(), TemporalAttentionBlock(20, 64, 1280, 1024, add_positional_conv), PopMixBlock(1280),
|
||||||
PopBlock(), ResnetBlock(2560, 1280, 1280, eps=1e-6), PushBlock(), TemporalResnetBlock(1280, 1280, 1280, eps=1e-6), PopMixBlock(), PushBlock(),
|
PopBlock(), ResnetBlock(2560, 1280, 1280, eps=1e-6), PushBlock(), TemporalResnetBlock(1280, 1280, 1280, eps=1e-6), PopMixBlock(), PushBlock(),
|
||||||
AttentionBlock(20, 64, 1280, 1, 1024, need_proj_out=False), PushBlock(), TemporalAttentionBlock(20, 64, 1280, 1024), PopMixBlock(1280),
|
AttentionBlock(20, 64, 1280, 1, 1024, need_proj_out=False), PushBlock(), TemporalAttentionBlock(20, 64, 1280, 1024, add_positional_conv), PopMixBlock(1280),
|
||||||
PopBlock(), ResnetBlock(1920, 1280, 1280, eps=1e-6), PushBlock(), TemporalResnetBlock(1280, 1280, 1280, eps=1e-6), PopMixBlock(), PushBlock(),
|
PopBlock(), ResnetBlock(1920, 1280, 1280, eps=1e-6), PushBlock(), TemporalResnetBlock(1280, 1280, 1280, eps=1e-6), PopMixBlock(), PushBlock(),
|
||||||
AttentionBlock(20, 64, 1280, 1, 1024, need_proj_out=False), PushBlock(), TemporalAttentionBlock(20, 64, 1280, 1024), PopMixBlock(1280),
|
AttentionBlock(20, 64, 1280, 1, 1024, need_proj_out=False), PushBlock(), TemporalAttentionBlock(20, 64, 1280, 1024, add_positional_conv), PopMixBlock(1280),
|
||||||
UpSampler(1280),
|
UpSampler(1280),
|
||||||
# CrossAttnUpBlockSpatioTemporal
|
# CrossAttnUpBlockSpatioTemporal
|
||||||
PopBlock(), ResnetBlock(1920, 640, 1280, eps=1e-6), PushBlock(), TemporalResnetBlock(640, 640, 1280, eps=1e-6), PopMixBlock(), PushBlock(),
|
PopBlock(), ResnetBlock(1920, 640, 1280, eps=1e-6), PushBlock(), TemporalResnetBlock(640, 640, 1280, eps=1e-6), PopMixBlock(), PushBlock(),
|
||||||
AttentionBlock(10, 64, 640, 1, 1024, need_proj_out=False), PushBlock(), TemporalAttentionBlock(10, 64, 640, 1024), PopMixBlock(640),
|
AttentionBlock(10, 64, 640, 1, 1024, need_proj_out=False), PushBlock(), TemporalAttentionBlock(10, 64, 640, 1024, add_positional_conv), PopMixBlock(640),
|
||||||
PopBlock(), ResnetBlock(1280, 640, 1280, eps=1e-6), PushBlock(), TemporalResnetBlock(640, 640, 1280, eps=1e-6), PopMixBlock(), PushBlock(),
|
PopBlock(), ResnetBlock(1280, 640, 1280, eps=1e-6), PushBlock(), TemporalResnetBlock(640, 640, 1280, eps=1e-6), PopMixBlock(), PushBlock(),
|
||||||
AttentionBlock(10, 64, 640, 1, 1024, need_proj_out=False), PushBlock(), TemporalAttentionBlock(10, 64, 640, 1024), PopMixBlock(640),
|
AttentionBlock(10, 64, 640, 1, 1024, need_proj_out=False), PushBlock(), TemporalAttentionBlock(10, 64, 640, 1024, add_positional_conv), PopMixBlock(640),
|
||||||
PopBlock(), ResnetBlock(960, 640, 1280, eps=1e-6), PushBlock(), TemporalResnetBlock(640, 640, 1280, eps=1e-6), PopMixBlock(), PushBlock(),
|
PopBlock(), ResnetBlock(960, 640, 1280, eps=1e-6), PushBlock(), TemporalResnetBlock(640, 640, 1280, eps=1e-6), PopMixBlock(), PushBlock(),
|
||||||
AttentionBlock(10, 64, 640, 1, 1024, need_proj_out=False), PushBlock(), TemporalAttentionBlock(10, 64, 640, 1024), PopMixBlock(640),
|
AttentionBlock(10, 64, 640, 1, 1024, need_proj_out=False), PushBlock(), TemporalAttentionBlock(10, 64, 640, 1024, add_positional_conv), PopMixBlock(640),
|
||||||
UpSampler(640),
|
UpSampler(640),
|
||||||
# CrossAttnUpBlockSpatioTemporal
|
# CrossAttnUpBlockSpatioTemporal
|
||||||
PopBlock(), ResnetBlock(960, 320, 1280, eps=1e-6), PushBlock(), TemporalResnetBlock(320, 320, 1280, eps=1e-6), PopMixBlock(), PushBlock(),
|
PopBlock(), ResnetBlock(960, 320, 1280, eps=1e-6), PushBlock(), TemporalResnetBlock(320, 320, 1280, eps=1e-6), PopMixBlock(), PushBlock(),
|
||||||
AttentionBlock(5, 64, 320, 1, 1024, need_proj_out=False), PushBlock(), TemporalAttentionBlock(5, 64, 320, 1024), PopMixBlock(320),
|
AttentionBlock(5, 64, 320, 1, 1024, need_proj_out=False), PushBlock(), TemporalAttentionBlock(5, 64, 320, 1024, add_positional_conv), PopMixBlock(320),
|
||||||
PopBlock(), ResnetBlock(640, 320, 1280, eps=1e-6), PushBlock(), TemporalResnetBlock(320, 320, 1280, eps=1e-6), PopMixBlock(), PushBlock(),
|
PopBlock(), ResnetBlock(640, 320, 1280, eps=1e-6), PushBlock(), TemporalResnetBlock(320, 320, 1280, eps=1e-6), PopMixBlock(), PushBlock(),
|
||||||
AttentionBlock(5, 64, 320, 1, 1024, need_proj_out=False), PushBlock(), TemporalAttentionBlock(5, 64, 320, 1024), PopMixBlock(320),
|
AttentionBlock(5, 64, 320, 1, 1024, need_proj_out=False), PushBlock(), TemporalAttentionBlock(5, 64, 320, 1024, add_positional_conv), PopMixBlock(320),
|
||||||
PopBlock(), ResnetBlock(640, 320, 1280, eps=1e-6), PushBlock(), TemporalResnetBlock(320, 320, 1280, eps=1e-6), PopMixBlock(), PushBlock(),
|
PopBlock(), ResnetBlock(640, 320, 1280, eps=1e-6), PushBlock(), TemporalResnetBlock(320, 320, 1280, eps=1e-6), PopMixBlock(), PushBlock(),
|
||||||
AttentionBlock(5, 64, 320, 1, 1024, need_proj_out=False), PushBlock(), TemporalAttentionBlock(5, 64, 320, 1024), PopMixBlock(320),
|
AttentionBlock(5, 64, 320, 1, 1024, need_proj_out=False), PushBlock(), TemporalAttentionBlock(5, 64, 320, 1024, add_positional_conv), PopMixBlock(320),
|
||||||
])
|
])
|
||||||
|
|
||||||
self.conv_norm_out = torch.nn.GroupNorm(32, 320, eps=1e-05, affine=True)
|
self.conv_norm_out = torch.nn.GroupNorm(32, 320, eps=1e-05, affine=True)
|
||||||
@@ -327,7 +367,7 @@ class SVDUNet(torch.nn.Module):
|
|||||||
return values
|
return values
|
||||||
|
|
||||||
|
|
||||||
def forward(self, sample, timestep, encoder_hidden_states, add_time_id, **kwargs):
|
def forward(self, sample, timestep, encoder_hidden_states, add_time_id, use_gradient_checkpointing=False, **kwargs):
|
||||||
# 1. time
|
# 1. time
|
||||||
timestep = torch.tensor((timestep,)).to(sample.device)
|
timestep = torch.tensor((timestep,)).to(sample.device)
|
||||||
t_emb = self.time_proj(timestep).to(sample.dtype)
|
t_emb = self.time_proj(timestep).to(sample.dtype)
|
||||||
@@ -346,8 +386,19 @@ class SVDUNet(torch.nn.Module):
|
|||||||
res_stack = [hidden_states]
|
res_stack = [hidden_states]
|
||||||
|
|
||||||
# 3. blocks
|
# 3. blocks
|
||||||
|
def create_custom_forward(module):
|
||||||
|
def custom_forward(*inputs):
|
||||||
|
return module(*inputs)
|
||||||
|
return custom_forward
|
||||||
for i, block in enumerate(self.blocks):
|
for i, block in enumerate(self.blocks):
|
||||||
hidden_states, time_emb, text_emb, res_stack = block(hidden_states, time_emb, text_emb, res_stack)
|
if self.training and use_gradient_checkpointing and not (isinstance(block, PushBlock) or isinstance(block, PopBlock) or isinstance(block, PopMixBlock)):
|
||||||
|
hidden_states, time_emb, text_emb, res_stack = torch.utils.checkpoint.checkpoint(
|
||||||
|
create_custom_forward(block),
|
||||||
|
hidden_states, time_emb, text_emb, res_stack,
|
||||||
|
use_reentrant=False,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
hidden_states, time_emb, text_emb, res_stack = block(hidden_states, time_emb, text_emb, res_stack)
|
||||||
|
|
||||||
# 4. output
|
# 4. output
|
||||||
hidden_states = self.conv_norm_out(hidden_states)
|
hidden_states = self.conv_norm_out(hidden_states)
|
||||||
@@ -497,7 +548,7 @@ class SVDUNetStateDictConverter:
|
|||||||
return state_dict_
|
return state_dict_
|
||||||
|
|
||||||
|
|
||||||
def from_civitai(self, state_dict):
|
def from_civitai(self, state_dict, add_positional_conv=None):
|
||||||
rename_dict = {
|
rename_dict = {
|
||||||
"model.diffusion_model.input_blocks.0.0.bias": "conv_in.bias",
|
"model.diffusion_model.input_blocks.0.0.bias": "conv_in.bias",
|
||||||
"model.diffusion_model.input_blocks.0.0.weight": "conv_in.weight",
|
"model.diffusion_model.input_blocks.0.0.weight": "conv_in.weight",
|
||||||
@@ -1935,4 +1986,18 @@ class SVDUNetStateDictConverter:
|
|||||||
if ".proj_in." in name or ".proj_out." in name:
|
if ".proj_in." in name or ".proj_out." in name:
|
||||||
param = param.squeeze()
|
param = param.squeeze()
|
||||||
state_dict_[rename_dict[name]] = param
|
state_dict_[rename_dict[name]] = param
|
||||||
|
if add_positional_conv is not None:
|
||||||
|
extra_names = [
|
||||||
|
"blocks.7.positional_conv", "blocks.17.positional_conv", "blocks.29.positional_conv", "blocks.39.positional_conv",
|
||||||
|
"blocks.51.positional_conv", "blocks.61.positional_conv", "blocks.83.positional_conv", "blocks.113.positional_conv",
|
||||||
|
"blocks.123.positional_conv", "blocks.133.positional_conv", "blocks.144.positional_conv", "blocks.154.positional_conv",
|
||||||
|
"blocks.164.positional_conv", "blocks.175.positional_conv", "blocks.185.positional_conv", "blocks.195.positional_conv",
|
||||||
|
]
|
||||||
|
extra_channels = [320, 320, 640, 640, 1280, 1280, 1280, 1280, 1280, 1280, 640, 640, 640, 320, 320, 320]
|
||||||
|
for name, channels in zip(extra_names, extra_channels):
|
||||||
|
weight = torch.zeros((channels, channels, 3, 3, 3))
|
||||||
|
weight[:,:,1,1,1] = torch.eye(channels, channels)
|
||||||
|
bias = torch.zeros((channels,))
|
||||||
|
state_dict_[name + ".weight"] = weight
|
||||||
|
state_dict_[name + ".bias"] = bias
|
||||||
return state_dict_
|
return state_dict_
|
||||||
|
|||||||
@@ -109,6 +109,14 @@ class SVDVideoPipeline(torch.nn.Module):
|
|||||||
return noise_pred
|
return noise_pred
|
||||||
|
|
||||||
|
|
||||||
|
def post_process_latents(self, latents, post_normalize=True, contrast_enhance_scale=1.0):
|
||||||
|
if post_normalize:
|
||||||
|
mean, std = latents.mean(), latents.std()
|
||||||
|
latents = (latents - latents.mean(dim=[1, 2, 3], keepdim=True)) / latents.std(dim=[1, 2, 3], keepdim=True) * std + mean
|
||||||
|
latents = latents * contrast_enhance_scale
|
||||||
|
return latents
|
||||||
|
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
@@ -126,6 +134,8 @@ class SVDVideoPipeline(torch.nn.Module):
|
|||||||
motion_bucket_id=127,
|
motion_bucket_id=127,
|
||||||
noise_aug_strength=0.02,
|
noise_aug_strength=0.02,
|
||||||
num_inference_steps=20,
|
num_inference_steps=20,
|
||||||
|
post_normalize=True,
|
||||||
|
contrast_enhance_scale=1.2,
|
||||||
progress_bar_cmd=tqdm,
|
progress_bar_cmd=tqdm,
|
||||||
progress_bar_st=None,
|
progress_bar_st=None,
|
||||||
):
|
):
|
||||||
@@ -178,6 +188,7 @@ class SVDVideoPipeline(torch.nn.Module):
|
|||||||
progress_bar_st.progress(progress_id / len(self.scheduler.timesteps))
|
progress_bar_st.progress(progress_id / len(self.scheduler.timesteps))
|
||||||
|
|
||||||
# Decode image
|
# Decode image
|
||||||
|
latents = self.post_process_latents(latents, post_normalize=post_normalize, contrast_enhance_scale=contrast_enhance_scale)
|
||||||
video = self.vae_decoder.decode_video(latents, progress_bar=progress_bar_cmd)
|
video = self.vae_decoder.decode_video(latents, progress_bar=progress_bar_cmd)
|
||||||
video = self.tensor2video(video)
|
video = self.tensor2video(video)
|
||||||
|
|
||||||
|
|||||||
@@ -43,3 +43,17 @@ class ContinuousODEScheduler():
|
|||||||
sigma = self.sigmas[timestep_id]
|
sigma = self.sigmas[timestep_id]
|
||||||
sample = (original_samples + noise * sigma) / (sigma*sigma + 1).sqrt()
|
sample = (original_samples + noise * sigma) / (sigma*sigma + 1).sqrt()
|
||||||
return sample
|
return sample
|
||||||
|
|
||||||
|
|
||||||
|
def training_target(self, sample, noise, timestep):
|
||||||
|
timestep_id = torch.argmin((self.timesteps - timestep).abs())
|
||||||
|
sigma = self.sigmas[timestep_id]
|
||||||
|
target = (-(sigma*sigma + 1).sqrt() / sigma + 1 / (sigma*sigma + 1).sqrt() / sigma) * sample + 1 / (sigma*sigma + 1).sqrt() * noise
|
||||||
|
return target
|
||||||
|
|
||||||
|
|
||||||
|
def training_weight(self, timestep):
|
||||||
|
timestep_id = torch.argmin((self.timesteps - timestep).abs())
|
||||||
|
sigma = self.sigmas[timestep_id]
|
||||||
|
weight = (1 + sigma*sigma).sqrt() / sigma
|
||||||
|
return weight
|
||||||
|
|||||||
83
examples/ExVideo/ExVideo_svd.py
Normal file
83
examples/ExVideo/ExVideo_svd.py
Normal file
@@ -0,0 +1,83 @@
|
|||||||
|
from diffsynth import save_video, ModelManager, SVDVideoPipeline, HunyuanDiTImagePipeline
|
||||||
|
from diffsynth import ModelManager
|
||||||
|
import torch, os
|
||||||
|
|
||||||
|
|
||||||
|
def generate_image():
|
||||||
|
# Load models
|
||||||
|
os.environ["TOKENIZERS_PARALLELISM"] = "True"
|
||||||
|
model_manager = ModelManager(torch_dtype=torch.float16, device="cuda")
|
||||||
|
model_manager.load_models([
|
||||||
|
"models/HunyuanDiT/t2i/clip_text_encoder/pytorch_model.bin",
|
||||||
|
"models/HunyuanDiT/t2i/mt5/pytorch_model.bin",
|
||||||
|
"models/HunyuanDiT/t2i/model/pytorch_model_ema.pt",
|
||||||
|
"models/HunyuanDiT/t2i/sdxl-vae-fp16-fix/diffusion_pytorch_model.bin"
|
||||||
|
])
|
||||||
|
pipe = HunyuanDiTImagePipeline.from_model_manager(model_manager)
|
||||||
|
|
||||||
|
# Generate an image
|
||||||
|
torch.manual_seed(0)
|
||||||
|
image = pipe(
|
||||||
|
prompt="bonfire, on the stone",
|
||||||
|
negative_prompt="错误的眼睛,糟糕的人脸,毁容,糟糕的艺术,变形,多余的肢体,模糊的颜色,模糊,重复,病态,残缺,",
|
||||||
|
num_inference_steps=50, height=1024, width=1024,
|
||||||
|
)
|
||||||
|
return image
|
||||||
|
|
||||||
|
|
||||||
|
def generate_video(image):
|
||||||
|
# Load models
|
||||||
|
model_manager = ModelManager(torch_dtype=torch.float16, device="cuda")
|
||||||
|
model_manager.load_models([
|
||||||
|
"models/stable_video_diffusion/svd_xt.safetensors",
|
||||||
|
"models/stable_video_diffusion/model.fp16.safetensors"
|
||||||
|
])
|
||||||
|
pipe = SVDVideoPipeline.from_model_manager(model_manager)
|
||||||
|
|
||||||
|
# Generate a video
|
||||||
|
torch.manual_seed(1)
|
||||||
|
video = pipe(
|
||||||
|
input_image=image.resize((512, 512)),
|
||||||
|
num_frames=128, fps=30, height=512, width=512,
|
||||||
|
motion_bucket_id=127,
|
||||||
|
num_inference_steps=50,
|
||||||
|
min_cfg_scale=2, max_cfg_scale=2, contrast_enhance_scale=1.2
|
||||||
|
)
|
||||||
|
return video
|
||||||
|
|
||||||
|
|
||||||
|
def upscale_video(image, video):
|
||||||
|
# Load models
|
||||||
|
model_manager = ModelManager(torch_dtype=torch.float16, device="cuda")
|
||||||
|
model_manager.load_models([
|
||||||
|
"models/stable_video_diffusion/svd_xt.safetensors",
|
||||||
|
"models/stable_video_diffusion/model.fp16.safetensors"
|
||||||
|
])
|
||||||
|
pipe = SVDVideoPipeline.from_model_manager(model_manager)
|
||||||
|
|
||||||
|
# Generate a video
|
||||||
|
torch.manual_seed(2)
|
||||||
|
video = pipe(
|
||||||
|
input_image=image.resize((1024, 1024)),
|
||||||
|
input_video=[frame.resize((1024, 1024)) for frame in video], denoising_strength=0.5,
|
||||||
|
num_frames=128, fps=30, height=1024, width=1024,
|
||||||
|
motion_bucket_id=127,
|
||||||
|
num_inference_steps=25,
|
||||||
|
min_cfg_scale=2, max_cfg_scale=2, contrast_enhance_scale=1.2
|
||||||
|
)
|
||||||
|
return video
|
||||||
|
|
||||||
|
|
||||||
|
# We use Hunyuan DiT to generate the first frame.
|
||||||
|
# If you want to use your own image,
|
||||||
|
# please use `image = Image.open("your_image_file.png")` to replace the following code.
|
||||||
|
image = generate_image()
|
||||||
|
image.save("image.png")
|
||||||
|
|
||||||
|
# Now, generate a video with resolution of 512.
|
||||||
|
video = generate_video(image)
|
||||||
|
save_video(video, "video_512.mp4", fps=30)
|
||||||
|
|
||||||
|
# Upscale the video.
|
||||||
|
video = upscale_video(image, video)
|
||||||
|
save_video(video, "video_1024.mp4", fps=30)
|
||||||
16
examples/ExVideo/README.md
Normal file
16
examples/ExVideo/README.md
Normal file
@@ -0,0 +1,16 @@
|
|||||||
|
# ExVideo
|
||||||
|
|
||||||
|
ExVideo is a post-tuning technique aimed at enhancing the capability of video generation models. We have extended Stable Video Diffusion to achieve the generation of long videos up to 128 frames.
|
||||||
|
|
||||||
|
* [Project Page](https://ecnu-cilab.github.io/ExVideoProjectPage/)
|
||||||
|
* [Source Code](https://github.com/modelscope/DiffSynth-Studio/tree/main/examples/ExVideo)
|
||||||
|
* Technical report
|
||||||
|
* Extended models
|
||||||
|
* [HuggingFace](https://huggingface.co/ECNU-CILab/ExVideo-SVD-128f-v1)
|
||||||
|
* [ModelScope](https://modelscope.cn/models/ECNU-CILab/ExVideo-SVD-128f-v1)
|
||||||
|
|
||||||
|
## Example: Text-to-video via extended Stable Video Diffusion
|
||||||
|
|
||||||
|
Generate a video using a text-to-image model and our image-to-video model. See [ExVideo_svd.py](./ExVideo_svd.py).
|
||||||
|
|
||||||
|
https://github.com/modelscope/DiffSynth-Studio/assets/35051019/d97f6aa9-8064-4b5b-9d49-ed6001bb9acc
|
||||||
Reference in New Issue
Block a user