diff --git a/diffsynth/models/__init__.py b/diffsynth/models/__init__.py
index de6686f..21757e9 100644
--- a/diffsynth/models/__init__.py
+++ b/diffsynth/models/__init__.py
@@ -110,7 +110,11 @@ class ModelManager:
         param_name = "quant_conv.weight"
         return param_name in state_dict
 
-    def load_stable_video_diffusion(self, state_dict, components=None, file_path=""):
+    def is_ExVideo_StableVideoDiffusion(self, state_dict):
+        param_name = "blocks.185.positional_embedding.embeddings"
+        return param_name in state_dict
+
+    def load_stable_video_diffusion(self, state_dict, components=None, file_path="", add_positional_conv=None):
         component_dict = {
             "image_encoder": SVDImageEncoder,
             "unet": SVDUNet,
@@ -120,8 +124,12 @@ class ModelManager:
         if components is None:
             components = ["image_encoder", "unet", "vae_decoder", "vae_encoder"]
         for component in components:
-            self.model[component] = component_dict[component]()
-            self.model[component].load_state_dict(self.model[component].state_dict_converter().from_civitai(state_dict))
+            if component == "unet":
+                self.model[component] = component_dict[component](add_positional_conv=add_positional_conv)
+                self.model[component].load_state_dict(self.model[component].state_dict_converter().from_civitai(state_dict, add_positional_conv=add_positional_conv), strict=False)
+            else:
+                self.model[component] = component_dict[component]()
+                self.model[component].load_state_dict(self.model[component].state_dict_converter().from_civitai(state_dict))
             self.model[component].to(self.torch_dtype).to(self.device)
             self.model_path[component] = file_path
 
@@ -305,6 +313,16 @@ class ModelManager:
             self.model[component] = model
             self.model_path[component] = file_path
 
+    def load_ExVideo_StableVideoDiffusion(self, state_dict, file_path=""):
+        unet_state_dict = self.model["unet"].state_dict()
+        self.model["unet"].to("cpu")
+        del self.model["unet"]
+        add_positional_conv = state_dict["blocks.185.positional_embedding.embeddings"].shape[0]
+        self.model["unet"] = SVDUNet(add_positional_conv=add_positional_conv)
+        self.model["unet"].load_state_dict(unet_state_dict, strict=False)
+        self.model["unet"].load_state_dict(state_dict, strict=False)
+        self.model["unet"].to(self.torch_dtype).to(self.device)
+
     def search_for_embeddings(self, state_dict):
         embeddings = []
         for k in state_dict:
@@ -370,6 +388,8 @@ class ModelManager:
             self.load_hunyuan_dit(state_dict, file_path=file_path)
         elif self.is_diffusers_vae(state_dict):
            self.load_diffusers_vae(state_dict, file_path=file_path)
+        elif self.is_ExVideo_StableVideoDiffusion(state_dict):
+            self.load_ExVideo_StableVideoDiffusion(state_dict, file_path=file_path)
 
     def load_models(self, file_path_list, lora_alphas=[]):
         for file_path in file_path_list:
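The `ModelManager` changes detect an ExVideo checkpoint purely by probing for the key `blocks.185.positional_embedding.embeddings`, infer the number of trainable temporal positions from that tensor's first dimension, and rebuild the already-loaded SVD UNet with `add_positional_conv` set accordingly. A minimal sketch of how this path is exercised; the file names below are taken from the example script later in this PR, and the base SVD weights must come first because `load_ExVideo_StableVideoDiffusion` patches the existing `self.model["unet"]`:

```python
import torch
from diffsynth import ModelManager, SVDVideoPipeline

model_manager = ModelManager(torch_dtype=torch.float16, device="cuda")
model_manager.load_models([
    "models/stable_video_diffusion/svd_xt.safetensors",      # base SVD UNet, loaded first
    "models/stable_video_diffusion/model.fp16.safetensors",  # ExVideo patch, detected via its probe key
])
pipe = SVDVideoPipeline.from_model_manager(model_manager)
```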
diff --git a/diffsynth/models/svd_unet.py b/diffsynth/models/svd_unet.py
index 1ec4a76..4bab083 100644
--- a/diffsynth/models/svd_unet.py
+++ b/diffsynth/models/svd_unet.py
@@ -95,19 +95,59 @@ class TemporalTimesteps(torch.nn.Module):
             downscale_freq_shift=self.downscale_freq_shift,
         )
         return t_emb
+
+
+class TrainableTemporalTimesteps(torch.nn.Module):
+    def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float, num_frames: int):
+        super().__init__()
+        timesteps = PositionalID()(num_frames)
+        embeddings = get_timestep_embedding(timesteps, num_channels, flip_sin_to_cos, downscale_freq_shift)
+        self.embeddings = torch.nn.Parameter(embeddings)
+
+    def forward(self, timesteps):
+        t_emb = self.embeddings[timesteps]
+        return t_emb
+
+
+class PositionalID(torch.nn.Module):
+    def __init__(self, max_id=25, repeat_length=20):
+        super().__init__()
+        self.max_id = max_id
+        self.repeat_length = repeat_length
+
+    def frame_id_to_position_id(self, frame_id):
+        if frame_id < self.max_id:
+            position_id = frame_id
+        else:
+            position_id = (frame_id - self.max_id) % (self.repeat_length * 2)
+            if position_id < self.repeat_length:
+                position_id = self.max_id - 2 - position_id
+            else:
+                position_id = self.max_id - 2 * self.repeat_length + position_id
+        return position_id
+
+    def forward(self, num_frames, pivot_frame_id=0):
+        position_ids = [self.frame_id_to_position_id(abs(i-pivot_frame_id)) for i in range(num_frames)]
+        position_ids = torch.IntTensor(position_ids)
+        return position_ids
 
 
 class TemporalAttentionBlock(torch.nn.Module):
 
-    def __init__(self, num_attention_heads, attention_head_dim, in_channels, cross_attention_dim=None):
+    def __init__(self, num_attention_heads, attention_head_dim, in_channels, cross_attention_dim=None, add_positional_conv=None):
         super().__init__()
 
-        self.positional_embedding = TemporalTimesteps(in_channels, True, 0)
         self.positional_embedding_proj = torch.nn.Sequential(
             torch.nn.Linear(in_channels, in_channels * 4),
             torch.nn.SiLU(),
             torch.nn.Linear(in_channels * 4, in_channels)
         )
+        if add_positional_conv is not None:
+            self.positional_embedding = TrainableTemporalTimesteps(in_channels, True, 0, add_positional_conv)
+            self.positional_conv = torch.nn.Conv3d(in_channels, in_channels, kernel_size=3, padding=1, padding_mode="reflect")
+        else:
+            self.positional_embedding = TemporalTimesteps(in_channels, True, 0)
+            self.positional_conv = None
 
         self.norm_in = torch.nn.LayerNorm(in_channels)
         self.act_fn_in = GEGLU(in_channels, in_channels * 4)
@@ -137,14 +177,14 @@ class TemporalAttentionBlock(torch.nn.Module):
     def forward(self, hidden_states, time_emb, text_emb, res_stack, **kwargs):
         batch, inner_dim, height, width = hidden_states.shape
 
-        pos_emb = torch.arange(batch, dtype=hidden_states.dtype)
-        if batch > 25:
-            pos_emb *= 25 / batch
+        pos_emb = torch.arange(batch)
         pos_emb = self.positional_embedding(pos_emb).to(dtype=hidden_states.dtype, device=hidden_states.device)
-        pos_emb = self.positional_embedding_proj(pos_emb)[None, :, :]
-
-        hidden_states = hidden_states.permute(2, 3, 0, 1).reshape(height * width, batch, inner_dim)
-        hidden_states = hidden_states + pos_emb
+        pos_emb = self.positional_embedding_proj(pos_emb)
+
+        hidden_states = rearrange(hidden_states, "T C H W -> 1 C T H W") + rearrange(pos_emb, "T C -> 1 C T 1 1")
+        if self.positional_conv is not None:
+            hidden_states = self.positional_conv(hidden_states)
+        hidden_states = rearrange(hidden_states[0], "C T H W -> (H W) T C")
 
         residual = hidden_states
         hidden_states = self.norm_in(hidden_states)
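The `PositionalID` mapping is the heart of the frame extension: the first `max_id` frames keep their own position IDs, and every later frame is reflected back into the trained range, so the UNet never sees an out-of-distribution temporal position. A standalone sketch of the same logic (the defaults `max_id=25`, `repeat_length=20` come from the diff above):

```python
def frame_id_to_position_id(frame_id, max_id=25, repeat_length=20):
    # Frames below max_id keep their own ID; later frames bounce back and
    # forth inside the trained range (between IDs 4 and 24 for the defaults).
    if frame_id < max_id:
        return frame_id
    position_id = (frame_id - max_id) % (repeat_length * 2)
    if position_id < repeat_length:
        return max_id - 2 - position_id               # descending leg: 23, 22, ..., 4
    return max_id - 2 * repeat_length + position_id   # ascending leg: 5, 6, ..., 24

print([frame_id_to_position_id(i) for i in range(30)])
# [0, 1, ..., 24, 23, 22, 21, 20, 19]
```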
@@ -193,7 +233,7 @@ class PopMixBlock(torch.nn.Module):
 
 
 class SVDUNet(torch.nn.Module):
-    def __init__(self):
+    def __init__(self, add_positional_conv=None):
         super().__init__()
         self.time_proj = Timesteps(320)
         self.time_embedding = torch.nn.Sequential(
@@ -211,29 +251,29 @@ class SVDUNet(torch.nn.Module):
 
         self.blocks = torch.nn.ModuleList([
             # CrossAttnDownBlockSpatioTemporal
-            ResnetBlock(320, 320, 1280, eps=1e-6), PushBlock(), TemporalResnetBlock(320, 320, 1280, eps=1e-6), PopMixBlock(), PushBlock(),
-            AttentionBlock(5, 64, 320, 1, 1024, need_proj_out=False), PushBlock(), TemporalAttentionBlock(5, 64, 320, 1024), PopMixBlock(320), PushBlock(),
-            ResnetBlock(320, 320, 1280, eps=1e-6), PushBlock(), TemporalResnetBlock(320, 320, 1280, eps=1e-6), PopMixBlock(), PushBlock(),
-            AttentionBlock(5, 64, 320, 1, 1024, need_proj_out=False), PushBlock(), TemporalAttentionBlock(5, 64, 320, 1024), PopMixBlock(320), PushBlock(),
+            ResnetBlock(320, 320, 1280, eps=1e-6), PushBlock(), TemporalResnetBlock(320, 320, 1280, eps=1e-6), PopMixBlock(), PushBlock(),
+            AttentionBlock(5, 64, 320, 1, 1024, need_proj_out=False), PushBlock(), TemporalAttentionBlock(5, 64, 320, 1024, add_positional_conv), PopMixBlock(320), PushBlock(),
+            ResnetBlock(320, 320, 1280, eps=1e-6), PushBlock(), TemporalResnetBlock(320, 320, 1280, eps=1e-6), PopMixBlock(), PushBlock(),
+            AttentionBlock(5, 64, 320, 1, 1024, need_proj_out=False), PushBlock(), TemporalAttentionBlock(5, 64, 320, 1024, add_positional_conv), PopMixBlock(320), PushBlock(),
             DownSampler(320), PushBlock(),
             # CrossAttnDownBlockSpatioTemporal
-            ResnetBlock(320, 640, 1280, eps=1e-6), PushBlock(), TemporalResnetBlock(640, 640, 1280, eps=1e-6), PopMixBlock(), PushBlock(),
-            AttentionBlock(10, 64, 640, 1, 1024, need_proj_out=False), PushBlock(), TemporalAttentionBlock(10, 64, 640, 1024), PopMixBlock(640), PushBlock(),
-            ResnetBlock(640, 640, 1280, eps=1e-6), PushBlock(), TemporalResnetBlock(640, 640, 1280, eps=1e-6), PopMixBlock(), PushBlock(),
-            AttentionBlock(10, 64, 640, 1, 1024, need_proj_out=False), PushBlock(), TemporalAttentionBlock(10, 64, 640, 1024), PopMixBlock(640), PushBlock(),
+            ResnetBlock(320, 640, 1280, eps=1e-6), PushBlock(), TemporalResnetBlock(640, 640, 1280, eps=1e-6), PopMixBlock(), PushBlock(),
+            AttentionBlock(10, 64, 640, 1, 1024, need_proj_out=False), PushBlock(), TemporalAttentionBlock(10, 64, 640, 1024, add_positional_conv), PopMixBlock(640), PushBlock(),
+            ResnetBlock(640, 640, 1280, eps=1e-6), PushBlock(), TemporalResnetBlock(640, 640, 1280, eps=1e-6), PopMixBlock(), PushBlock(),
+            AttentionBlock(10, 64, 640, 1, 1024, need_proj_out=False), PushBlock(), TemporalAttentionBlock(10, 64, 640, 1024, add_positional_conv), PopMixBlock(640), PushBlock(),
             DownSampler(640), PushBlock(),
             # CrossAttnDownBlockSpatioTemporal
-            ResnetBlock(640, 1280, 1280, eps=1e-6), PushBlock(), TemporalResnetBlock(1280, 1280, 1280, eps=1e-6), PopMixBlock(), PushBlock(),
-            AttentionBlock(20, 64, 1280, 1, 1024, need_proj_out=False), PushBlock(), TemporalAttentionBlock(20, 64, 1280, 1024), PopMixBlock(1280), PushBlock(),
-            ResnetBlock(1280, 1280, 1280, eps=1e-6), PushBlock(), TemporalResnetBlock(1280, 1280, 1280, eps=1e-6), PopMixBlock(), PushBlock(),
-            AttentionBlock(20, 64, 1280, 1, 1024, need_proj_out=False), PushBlock(), TemporalAttentionBlock(20, 64, 1280, 1024), PopMixBlock(1280), PushBlock(),
+            ResnetBlock(640, 1280, 1280, eps=1e-6), PushBlock(), TemporalResnetBlock(1280, 1280, 1280, eps=1e-6), PopMixBlock(), PushBlock(),
+            AttentionBlock(20, 64, 1280, 1, 1024, need_proj_out=False), PushBlock(), TemporalAttentionBlock(20, 64, 1280, 1024, add_positional_conv), PopMixBlock(1280), PushBlock(),
+            ResnetBlock(1280, 1280, 1280, eps=1e-6), PushBlock(), TemporalResnetBlock(1280, 1280, 1280, eps=1e-6), PopMixBlock(), PushBlock(),
+            AttentionBlock(20, 64, 1280, 1, 1024, need_proj_out=False), PushBlock(), TemporalAttentionBlock(20, 64, 1280, 1024, add_positional_conv), PopMixBlock(1280), PushBlock(),
             DownSampler(1280), PushBlock(),
             # DownBlockSpatioTemporal
-            ResnetBlock(1280, 1280, 1280, eps=1e-5), PushBlock(), TemporalResnetBlock(1280, 1280, 1280, eps=1e-5), PopMixBlock(), PushBlock(),
-            ResnetBlock(1280, 1280, 1280, eps=1e-5), PushBlock(), TemporalResnetBlock(1280, 1280, 1280, eps=1e-5), PopMixBlock(), PushBlock(),
+            ResnetBlock(1280, 1280, 1280, eps=1e-5), PushBlock(), TemporalResnetBlock(1280, 1280, 1280, eps=1e-5), PopMixBlock(), PushBlock(),
+            ResnetBlock(1280, 1280, 1280, eps=1e-5), PushBlock(), TemporalResnetBlock(1280, 1280, 1280, eps=1e-5), PopMixBlock(), PushBlock(),
             # UNetMidBlockSpatioTemporal
-            ResnetBlock(1280, 1280, 1280, eps=1e-5), PushBlock(), TemporalResnetBlock(1280, 1280, 1280, eps=1e-5), PopMixBlock(), PushBlock(),
-            AttentionBlock(20, 64, 1280, 1, 1024, need_proj_out=False), PushBlock(), TemporalAttentionBlock(20, 64, 1280, 1024), PopMixBlock(1280),
+            ResnetBlock(1280, 1280, 1280, eps=1e-5), PushBlock(), TemporalResnetBlock(1280, 1280, 1280, eps=1e-5), PopMixBlock(), PushBlock(),
+            AttentionBlock(20, 64, 1280, 1, 1024, need_proj_out=False), PushBlock(), TemporalAttentionBlock(20, 64, 1280, 1024, add_positional_conv), PopMixBlock(1280),
             ResnetBlock(1280, 1280, 1280, eps=1e-5), PushBlock(), TemporalResnetBlock(1280, 1280, 1280, eps=1e-5), PopMixBlock(),
             # UpBlockSpatioTemporal
             PopBlock(), ResnetBlock(2560, 1280, 1280, eps=1e-6), PushBlock(), TemporalResnetBlock(1280, 1280, 1280, eps=1e-5), PopMixBlock(),
@@ -241,28 +281,28 @@ class SVDUNet(torch.nn.Module):
             PopBlock(), ResnetBlock(2560, 1280, 1280, eps=1e-6), PushBlock(), TemporalResnetBlock(1280, 1280, 1280, eps=1e-5), PopMixBlock(),
             UpSampler(1280),
             # CrossAttnUpBlockSpatioTemporal
-            PopBlock(), ResnetBlock(2560, 1280, 1280, eps=1e-6), PushBlock(), TemporalResnetBlock(1280, 1280, 1280, eps=1e-6), PopMixBlock(), PushBlock(),
-            AttentionBlock(20, 64, 1280, 1, 1024, need_proj_out=False), PushBlock(), TemporalAttentionBlock(20, 64, 1280, 1024), PopMixBlock(1280),
-            PopBlock(), ResnetBlock(2560, 1280, 1280, eps=1e-6), PushBlock(), TemporalResnetBlock(1280, 1280, 1280, eps=1e-6), PopMixBlock(), PushBlock(),
-            AttentionBlock(20, 64, 1280, 1, 1024, need_proj_out=False), PushBlock(), TemporalAttentionBlock(20, 64, 1280, 1024), PopMixBlock(1280),
-            PopBlock(), ResnetBlock(1920, 1280, 1280, eps=1e-6), PushBlock(), TemporalResnetBlock(1280, 1280, 1280, eps=1e-6), PopMixBlock(), PushBlock(),
-            AttentionBlock(20, 64, 1280, 1, 1024, need_proj_out=False), PushBlock(), TemporalAttentionBlock(20, 64, 1280, 1024), PopMixBlock(1280),
+            PopBlock(), ResnetBlock(2560, 1280, 1280, eps=1e-6), PushBlock(), TemporalResnetBlock(1280, 1280, 1280, eps=1e-6), PopMixBlock(), PushBlock(),
+            AttentionBlock(20, 64, 1280, 1, 1024, need_proj_out=False), PushBlock(), TemporalAttentionBlock(20, 64, 1280, 1024, add_positional_conv), PopMixBlock(1280),
+            PopBlock(), ResnetBlock(2560, 1280, 1280, eps=1e-6), PushBlock(), TemporalResnetBlock(1280, 1280, 1280, eps=1e-6), PopMixBlock(), PushBlock(),
+            AttentionBlock(20, 64, 1280, 1, 1024, need_proj_out=False), PushBlock(), TemporalAttentionBlock(20, 64, 1280, 1024, add_positional_conv), PopMixBlock(1280),
+            PopBlock(), ResnetBlock(1920, 1280, 1280, eps=1e-6), PushBlock(), TemporalResnetBlock(1280, 1280, 1280, eps=1e-6), PopMixBlock(), PushBlock(),
+            AttentionBlock(20, 64, 1280, 1, 1024, need_proj_out=False), PushBlock(), TemporalAttentionBlock(20, 64, 1280, 1024, add_positional_conv), PopMixBlock(1280),
             UpSampler(1280),
             # CrossAttnUpBlockSpatioTemporal
-            PopBlock(), ResnetBlock(1920, 640, 1280, eps=1e-6), PushBlock(), TemporalResnetBlock(640, 640, 1280, eps=1e-6), PopMixBlock(), PushBlock(),
-            AttentionBlock(10, 64, 640, 1, 1024, need_proj_out=False), PushBlock(), TemporalAttentionBlock(10, 64, 640, 1024), PopMixBlock(640),
-            PopBlock(), ResnetBlock(1280, 640, 1280, eps=1e-6), PushBlock(), TemporalResnetBlock(640, 640, 1280, eps=1e-6), PopMixBlock(), PushBlock(),
-            AttentionBlock(10, 64, 640, 1, 1024, need_proj_out=False), PushBlock(), TemporalAttentionBlock(10, 64, 640, 1024), PopMixBlock(640),
-            PopBlock(), ResnetBlock(960, 640, 1280, eps=1e-6), PushBlock(), TemporalResnetBlock(640, 640, 1280, eps=1e-6), PopMixBlock(), PushBlock(),
-            AttentionBlock(10, 64, 640, 1, 1024, need_proj_out=False), PushBlock(), TemporalAttentionBlock(10, 64, 640, 1024), PopMixBlock(640),
+            PopBlock(), ResnetBlock(1920, 640, 1280, eps=1e-6), PushBlock(), TemporalResnetBlock(640, 640, 1280, eps=1e-6), PopMixBlock(), PushBlock(),
+            AttentionBlock(10, 64, 640, 1, 1024, need_proj_out=False), PushBlock(), TemporalAttentionBlock(10, 64, 640, 1024, add_positional_conv), PopMixBlock(640),
+            PopBlock(), ResnetBlock(1280, 640, 1280, eps=1e-6), PushBlock(), TemporalResnetBlock(640, 640, 1280, eps=1e-6), PopMixBlock(), PushBlock(),
+            AttentionBlock(10, 64, 640, 1, 1024, need_proj_out=False), PushBlock(), TemporalAttentionBlock(10, 64, 640, 1024, add_positional_conv), PopMixBlock(640),
+            PopBlock(), ResnetBlock(960, 640, 1280, eps=1e-6), PushBlock(), TemporalResnetBlock(640, 640, 1280, eps=1e-6), PopMixBlock(), PushBlock(),
+            AttentionBlock(10, 64, 640, 1, 1024, need_proj_out=False), PushBlock(), TemporalAttentionBlock(10, 64, 640, 1024, add_positional_conv), PopMixBlock(640),
             UpSampler(640),
             # CrossAttnUpBlockSpatioTemporal
-            PopBlock(), ResnetBlock(960, 320, 1280, eps=1e-6), PushBlock(), TemporalResnetBlock(320, 320, 1280, eps=1e-6), PopMixBlock(), PushBlock(),
-            AttentionBlock(5, 64, 320, 1, 1024, need_proj_out=False), PushBlock(), TemporalAttentionBlock(5, 64, 320, 1024), PopMixBlock(320),
-            PopBlock(), ResnetBlock(640, 320, 1280, eps=1e-6), PushBlock(), TemporalResnetBlock(320, 320, 1280, eps=1e-6), PopMixBlock(), PushBlock(),
-            AttentionBlock(5, 64, 320, 1, 1024, need_proj_out=False), PushBlock(), TemporalAttentionBlock(5, 64, 320, 1024), PopMixBlock(320),
-            PopBlock(), ResnetBlock(640, 320, 1280, eps=1e-6), PushBlock(), TemporalResnetBlock(320, 320, 1280, eps=1e-6), PopMixBlock(), PushBlock(),
-            AttentionBlock(5, 64, 320, 1, 1024, need_proj_out=False), PushBlock(), TemporalAttentionBlock(5, 64, 320, 1024), PopMixBlock(320),
+            PopBlock(), ResnetBlock(960, 320, 1280, eps=1e-6), PushBlock(), TemporalResnetBlock(320, 320, 1280, eps=1e-6), PopMixBlock(), PushBlock(),
+            AttentionBlock(5, 64, 320, 1, 1024, need_proj_out=False), PushBlock(), TemporalAttentionBlock(5, 64, 320, 1024, add_positional_conv), PopMixBlock(320),
+            PopBlock(), ResnetBlock(640, 320, 1280, eps=1e-6), PushBlock(), TemporalResnetBlock(320, 320, 1280, eps=1e-6), PopMixBlock(), PushBlock(),
+            AttentionBlock(5, 64, 320, 1, 1024, need_proj_out=False), PushBlock(), TemporalAttentionBlock(5, 64, 320, 1024, add_positional_conv), PopMixBlock(320),
+            PopBlock(), ResnetBlock(640, 320, 1280, eps=1e-6), PushBlock(), TemporalResnetBlock(320, 320, 1280, eps=1e-6), PopMixBlock(), PushBlock(),
+            AttentionBlock(5, 64, 320, 1, 1024, need_proj_out=False), PushBlock(), TemporalAttentionBlock(5, 64, 320, 1024, add_positional_conv), PopMixBlock(320),
         ])
 
         self.conv_norm_out = torch.nn.GroupNorm(32, 320, eps=1e-05, affine=True)
@@ -327,7 +367,7 @@ class SVDUNet(torch.nn.Module):
 
         return values
 
-    def forward(self, sample, timestep, encoder_hidden_states, add_time_id, **kwargs):
+    def forward(self, sample, timestep, encoder_hidden_states, add_time_id, use_gradient_checkpointing=False, **kwargs):
         # 1. time
         timestep = torch.tensor((timestep,)).to(sample.device)
         t_emb = self.time_proj(timestep).to(sample.dtype)
@@ -346,8 +386,19 @@ class SVDUNet(torch.nn.Module):
         res_stack = [hidden_states]
 
         # 3. blocks
+        def create_custom_forward(module):
+            def custom_forward(*inputs):
+                return module(*inputs)
+            return custom_forward
         for i, block in enumerate(self.blocks):
-            hidden_states, time_emb, text_emb, res_stack = block(hidden_states, time_emb, text_emb, res_stack)
+            if self.training and use_gradient_checkpointing and not (isinstance(block, PushBlock) or isinstance(block, PopBlock) or isinstance(block, PopMixBlock)):
+                hidden_states, time_emb, text_emb, res_stack = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(block),
+                    hidden_states, time_emb, text_emb, res_stack,
+                    use_reentrant=False,
+                )
+            else:
+                hidden_states, time_emb, text_emb, res_stack = block(hidden_states, time_emb, text_emb, res_stack)
 
         # 4. output
         hidden_states = self.conv_norm_out(hidden_states)
@@ -497,7 +548,7 @@ class SVDUNetStateDictConverter:
 
         return state_dict_
 
-    def from_civitai(self, state_dict):
+    def from_civitai(self, state_dict, add_positional_conv=None):
         rename_dict = {
             "model.diffusion_model.input_blocks.0.0.bias": "conv_in.bias",
             "model.diffusion_model.input_blocks.0.0.weight": "conv_in.weight",
@@ -1935,4 +1986,18 @@ class SVDUNetStateDictConverter:
             if ".proj_in." in name or ".proj_out." in name:
                 param = param.squeeze()
             state_dict_[rename_dict[name]] = param
+        if add_positional_conv is not None:
+            extra_names = [
+                "blocks.7.positional_conv", "blocks.17.positional_conv", "blocks.29.positional_conv", "blocks.39.positional_conv",
+                "blocks.51.positional_conv", "blocks.61.positional_conv", "blocks.83.positional_conv", "blocks.113.positional_conv",
+                "blocks.123.positional_conv", "blocks.133.positional_conv", "blocks.144.positional_conv", "blocks.154.positional_conv",
+                "blocks.164.positional_conv", "blocks.175.positional_conv", "blocks.185.positional_conv", "blocks.195.positional_conv",
+            ]
+            extra_channels = [320, 320, 640, 640, 1280, 1280, 1280, 1280, 1280, 1280, 640, 640, 640, 320, 320, 320]
+            for name, channels in zip(extra_names, extra_channels):
+                weight = torch.zeros((channels, channels, 3, 3, 3))
+                weight[:,:,1,1,1] = torch.eye(channels, channels)
+                bias = torch.zeros((channels,))
+                state_dict_[name + ".weight"] = weight
+                state_dict_[name + ".bias"] = bias
     return state_dict_
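Note how `from_civitai` synthesizes `positional_conv` weights when a base (non-ExVideo) checkpoint is converted for an extended UNet: each Conv3d kernel is zero everywhere except an identity matrix at its center tap, so the new convolutions start out as exact identity mappings and the extended model initially reproduces the base model. A quick sanity check of that construction (a standalone sketch, not part of the diff):

```python
import torch

channels = 320
weight = torch.zeros((channels, channels, 3, 3, 3))
weight[:, :, 1, 1, 1] = torch.eye(channels, channels)

conv = torch.nn.Conv3d(channels, channels, kernel_size=3, padding=1, padding_mode="reflect")
conv.load_state_dict({"weight": weight, "bias": torch.zeros(channels)})

x = torch.randn(1, channels, 8, 4, 4)  # (batch, C, T, H, W)
assert torch.allclose(conv(x), x, atol=1e-5)  # identity at initialization
```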
diff --git a/diffsynth/pipelines/stable_video_diffusion.py b/diffsynth/pipelines/stable_video_diffusion.py
index a14aa85..431303c 100644
--- a/diffsynth/pipelines/stable_video_diffusion.py
+++ b/diffsynth/pipelines/stable_video_diffusion.py
@@ -107,6 +107,14 @@ class SVDVideoPipeline(torch.nn.Module):
             noise_pred = noise_pred_nega + cfg_scales * (noise_pred_posi - noise_pred_nega)
         return noise_pred
+
+
+    def post_process_latents(self, latents, post_normalize=True, contrast_enhance_scale=1.0):
+        if post_normalize:
+            mean, std = latents.mean(), latents.std()
+            latents = (latents - latents.mean(dim=[1, 2, 3], keepdim=True)) / latents.std(dim=[1, 2, 3], keepdim=True) * std + mean
+        latents = latents * contrast_enhance_scale
+        return latents
 
 
     @torch.no_grad()
@@ -126,6 +134,8 @@ class SVDVideoPipeline(torch.nn.Module):
         motion_bucket_id=127,
         noise_aug_strength=0.02,
         num_inference_steps=20,
+        post_normalize=True,
+        contrast_enhance_scale=1.2,
         progress_bar_cmd=tqdm,
         progress_bar_st=None,
     ):
@@ -178,6 +188,7 @@ class SVDVideoPipeline(torch.nn.Module):
                progress_bar_st.progress(progress_id / len(self.scheduler.timesteps))
 
         # Decode image
+        latents = self.post_process_latents(latents, post_normalize=post_normalize, contrast_enhance_scale=contrast_enhance_scale)
         video = self.vae_decoder.decode_video(latents, progress_bar=progress_bar_cmd)
         video = self.tensor2video(video)
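`post_process_latents` re-normalizes each frame's latent to the clip-global mean and standard deviation, then scales everything by `contrast_enhance_scale` before decoding, which counteracts the statistics drift and washed-out look that long rollouts otherwise produce. A toy illustration of the same math (assuming, as in this pipeline, that latents are laid out as `(num_frames, C, H, W)`):

```python
import torch

# Fake latents whose per-frame scale drifts over time.
latents = torch.randn(128, 4, 64, 64) * torch.linspace(0.5, 2.0, 128).view(-1, 1, 1, 1)

# Same math as post_process_latents: pull every frame to the global statistics.
mean, std = latents.mean(), latents.std()
latents = (latents - latents.mean(dim=[1, 2, 3], keepdim=True)) / latents.std(dim=[1, 2, 3], keepdim=True) * std + mean
latents = latents * 1.2  # contrast_enhance_scale, the default used by the example script

print(latents.std(dim=[1, 2, 3])[:3])  # per-frame std is now uniform
```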
diff --git a/diffsynth/schedulers/continuous_ode.py b/diffsynth/schedulers/continuous_ode.py
index 7cc5e54..e6cd837 100644
--- a/diffsynth/schedulers/continuous_ode.py
+++ b/diffsynth/schedulers/continuous_ode.py
@@ -43,3 +43,17 @@ class ContinuousODEScheduler():
         sigma = self.sigmas[timestep_id]
         sample = (original_samples + noise * sigma) / (sigma*sigma + 1).sqrt()
         return sample
+
+
+    def training_target(self, sample, noise, timestep):
+        timestep_id = torch.argmin((self.timesteps - timestep).abs())
+        sigma = self.sigmas[timestep_id]
+        target = (-(sigma*sigma + 1).sqrt() / sigma + 1 / (sigma*sigma + 1).sqrt() / sigma) * sample + 1 / (sigma*sigma + 1).sqrt() * noise
+        return target
+
+
+    def training_weight(self, timestep):
+        timestep_id = torch.argmin((self.timesteps - timestep).abs())
+        sigma = self.sigmas[timestep_id]
+        weight = (1 + sigma*sigma).sqrt() / sigma
+        return weight
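The new `training_target` looks opaque, but its coefficient on `sample` simplifies: -sqrt(σ²+1)/σ + 1/(sqrt(σ²+1)·σ) = -σ/sqrt(σ²+1), so the target is (noise - σ·sample)/sqrt(σ²+1), a v-prediction-style objective consistent with the scheduler's noising convention `x_t = (x_0 + σ·ε)/sqrt(σ²+1)` shown in `add_noise` above. A quick numerical check of that simplification (a sketch, independent of the scheduler class):

```python
import torch

sigma = torch.tensor(2.5)
sample, noise = torch.randn(4), torch.randn(4)

# Expression exactly as written in training_target above.
target = (-(sigma*sigma + 1).sqrt() / sigma + 1 / (sigma*sigma + 1).sqrt() / sigma) * sample \
    + 1 / (sigma*sigma + 1).sqrt() * noise

# Algebraically simplified form.
simplified = (noise - sigma * sample) / (sigma*sigma + 1).sqrt()

assert torch.allclose(target, simplified, atol=1e-5)
```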
diff --git a/examples/ExVideo/ExVideo_svd.py b/examples/ExVideo/ExVideo_svd.py
new file mode 100644
index 0000000..cfe854a
--- /dev/null
+++ b/examples/ExVideo/ExVideo_svd.py
@@ -0,0 +1,83 @@
+from diffsynth import save_video, ModelManager, SVDVideoPipeline, HunyuanDiTImagePipeline
+from diffsynth import ModelManager
+import torch, os
+
+
+def generate_image():
+    # Load models
+    os.environ["TOKENIZERS_PARALLELISM"] = "True"
+    model_manager = ModelManager(torch_dtype=torch.float16, device="cuda")
+    model_manager.load_models([
+        "models/HunyuanDiT/t2i/clip_text_encoder/pytorch_model.bin",
+        "models/HunyuanDiT/t2i/mt5/pytorch_model.bin",
+        "models/HunyuanDiT/t2i/model/pytorch_model_ema.pt",
+        "models/HunyuanDiT/t2i/sdxl-vae-fp16-fix/diffusion_pytorch_model.bin"
+    ])
+    pipe = HunyuanDiTImagePipeline.from_model_manager(model_manager)
+
+    # Generate an image
+    torch.manual_seed(0)
+    image = pipe(
+        prompt="bonfire, on the stone",
+        negative_prompt="错误的眼睛,糟糕的人脸,毁容,糟糕的艺术,变形,多余的肢体,模糊的颜色,模糊,重复,病态,残缺,",
+        num_inference_steps=50, height=1024, width=1024,
+    )
+    return image
+
+
+def generate_video(image):
+    # Load models
+    model_manager = ModelManager(torch_dtype=torch.float16, device="cuda")
+    model_manager.load_models([
+        "models/stable_video_diffusion/svd_xt.safetensors",
+        "models/stable_video_diffusion/model.fp16.safetensors"
+    ])
+    pipe = SVDVideoPipeline.from_model_manager(model_manager)
+
+    # Generate a video
+    torch.manual_seed(1)
+    video = pipe(
+        input_image=image.resize((512, 512)),
+        num_frames=128, fps=30, height=512, width=512,
+        motion_bucket_id=127,
+        num_inference_steps=50,
+        min_cfg_scale=2, max_cfg_scale=2, contrast_enhance_scale=1.2
+    )
+    return video
+
+
+def upscale_video(image, video):
+    # Load models
+    model_manager = ModelManager(torch_dtype=torch.float16, device="cuda")
+    model_manager.load_models([
+        "models/stable_video_diffusion/svd_xt.safetensors",
+        "models/stable_video_diffusion/model.fp16.safetensors"
+    ])
+    pipe = SVDVideoPipeline.from_model_manager(model_manager)
+
+    # Upscale the video
+    torch.manual_seed(2)
+    video = pipe(
+        input_image=image.resize((1024, 1024)),
+        input_video=[frame.resize((1024, 1024)) for frame in video], denoising_strength=0.5,
+        num_frames=128, fps=30, height=1024, width=1024,
+        motion_bucket_id=127,
+        num_inference_steps=25,
+        min_cfg_scale=2, max_cfg_scale=2, contrast_enhance_scale=1.2
+    )
+    return video
+
+
+# We use Hunyuan DiT to generate the first frame.
+# If you want to use your own image,
+# please use `image = Image.open("your_image_file.png")` to replace the following code.
+image = generate_image()
+image.save("image.png")
+
+# Now, generate a video at a resolution of 512.
+video = generate_video(image)
+save_video(video, "video_512.mp4", fps=30)
+
+# Upscale the video.
+video = upscale_video(image, video)
+save_video(video, "video_1024.mp4", fps=30)
diff --git a/examples/ExVideo/README.md b/examples/ExVideo/README.md
new file mode 100644
index 0000000..e457c75
--- /dev/null
+++ b/examples/ExVideo/README.md
@@ -0,0 +1,16 @@
+# ExVideo
+
+ExVideo is a post-tuning technique aimed at enhancing the capability of video generation models. We have extended Stable Video Diffusion to generate long videos of up to 128 frames.
+
+* [Project Page](https://ecnu-cilab.github.io/ExVideoProjectPage/)
+* [Source Code](https://github.com/modelscope/DiffSynth-Studio/tree/main/examples/ExVideo)
+* Technical report
+* Extended models
+  * [HuggingFace](https://huggingface.co/ECNU-CILab/ExVideo-SVD-128f-v1)
+  * [ModelScope](https://modelscope.cn/models/ECNU-CILab/ExVideo-SVD-128f-v1)
+
+## Example: Text-to-video via extended Stable Video Diffusion
+
+Generate a video using a text-to-image model and our image-to-video model. See [ExVideo_svd.py](./ExVideo_svd.py).
+
+https://github.com/modelscope/DiffSynth-Studio/assets/35051019/d97f6aa9-8064-4b5b-9d49-ed6001bb9acc
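As the comments in `ExVideo_svd.py` note, the Hunyuan DiT stage is optional: you can animate your own picture instead. A minimal variant of the script's driver code (assumes Pillow is installed and `your_image_file.png` exists):

```python
from PIL import Image

# Skip generate_image() and start from an existing picture instead.
image = Image.open("your_image_file.png")

video = generate_video(image)
save_video(video, "video_512.mp4", fps=30)

video = upscale_video(image, video)
save_video(video, "video_1024.mp4", fps=30)
```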