From 3681adc5acbd9e5f98fb507e51c9e372d8f6bfe7 Mon Sep 17 00:00:00 2001 From: Artiprocher Date: Mon, 17 Feb 2025 17:32:25 +0800 Subject: [PATCH] support stepvideo --- diffsynth/configs/model_config.py | 8 +- diffsynth/models/model_manager.py | 14 +- diffsynth/models/stepvideo_dit.py | 940 +++++++++++++++ diffsynth/models/stepvideo_text_encoder.py | 553 +++++++++ diffsynth/models/stepvideo_vae.py | 1030 +++++++++++++++++ diffsynth/pipelines/__init__.py | 1 + diffsynth/pipelines/step_video.py | 204 ++++ diffsynth/prompters/__init__.py | 1 + diffsynth/prompters/stepvideo_prompter.py | 56 + diffsynth/schedulers/flow_match.py | 7 +- examples/stepvideo/README.md | 13 + examples/stepvideo/stepvideo_text_to_video.py | 47 + 12 files changed, 2866 insertions(+), 8 deletions(-) create mode 100644 diffsynth/models/stepvideo_dit.py create mode 100644 diffsynth/models/stepvideo_text_encoder.py create mode 100644 diffsynth/models/stepvideo_vae.py create mode 100644 diffsynth/pipelines/step_video.py create mode 100644 diffsynth/prompters/stepvideo_prompter.py create mode 100644 examples/stepvideo/README.md create mode 100644 examples/stepvideo/stepvideo_text_to_video.py diff --git a/diffsynth/configs/model_config.py b/diffsynth/configs/model_config.py index c0f2215..9ac7e8e 100644 --- a/diffsynth/configs/model_config.py +++ b/diffsynth/configs/model_config.py @@ -51,6 +51,9 @@ from ..extensions.ESRGAN import RRDBNet from ..models.hunyuan_video_dit import HunyuanVideoDiT +from ..models.stepvideo_vae import StepVideoVAE +from ..models.stepvideo_dit import StepVideoModel + model_loader_configs = [ # These configs are provided for detecting model type automatically. 
@@ -103,6 +106,8 @@ model_loader_configs = [ (None, "aeb82dce778a03dcb4d726cb03f3c43f", ["hunyuan_video_vae_decoder", "hunyuan_video_vae_encoder"], [HunyuanVideoVAEDecoder, HunyuanVideoVAEEncoder], "diffusers"), (None, "b9588f02e78f5ccafc9d7c0294e46308", ["hunyuan_video_dit"], [HunyuanVideoDiT], "civitai"), (None, "84ef4bd4757f60e906b54aa6a7815dc6", ["hunyuan_video_dit"], [HunyuanVideoDiT], "civitai"), + (None, "68beaf8429b7c11aa8ca05b1bd0058bd", ["stepvideo_vae"], [StepVideoVAE], "civitai"), + (None, "5c0216a2132b082c10cb7a0e0377e681", ["stepvideo_dit"], [StepVideoModel], "civitai"), ] huggingface_model_loader_configs = [ # These configs are provided for detecting model type automatically. @@ -115,7 +120,8 @@ huggingface_model_loader_configs = [ ("T5EncoderModel", "diffsynth.models.flux_text_encoder", "flux_text_encoder_2", "FluxTextEncoder2"), ("CogVideoXTransformer3DModel", "diffsynth.models.cog_dit", "cog_dit", "CogDiT"), ("SiglipModel", "transformers.models.siglip.modeling_siglip", "siglip_vision_model", "SiglipVisionModel"), - ("LlamaForCausalLM", "diffsynth.models.hunyuan_video_text_encoder", "hunyuan_video_text_encoder_2", "HunyuanVideoLLMEncoder") + ("LlamaForCausalLM", "diffsynth.models.hunyuan_video_text_encoder", "hunyuan_video_text_encoder_2", "HunyuanVideoLLMEncoder"), + ("Step1Model", "diffsynth.models.stepvideo_text_encoder", "stepvideo_text_encoder_2", "STEP1TextEncoder"), ] patch_model_loader_configs = [ # These configs are provided for detecting model type automatically. 
diff --git a/diffsynth/models/model_manager.py b/diffsynth/models/model_manager.py index c15d940..8216af0 100644 --- a/diffsynth/models/model_manager.py +++ b/diffsynth/models/model_manager.py @@ -158,7 +158,7 @@ class ModelDetectorFromSingleFile: def match(self, file_path="", state_dict={}): - if os.path.isdir(file_path): + if isinstance(file_path, str) and os.path.isdir(file_path): return False if len(state_dict) == 0: state_dict = load_state_dict(file_path) @@ -200,7 +200,7 @@ class ModelDetectorFromSplitedSingleFile(ModelDetectorFromSingleFile): def match(self, file_path="", state_dict={}): - if os.path.isdir(file_path): + if isinstance(file_path, str) and os.path.isdir(file_path): return False if len(state_dict) == 0: state_dict = load_state_dict(file_path) @@ -243,7 +243,7 @@ class ModelDetectorFromHuggingfaceFolder: def match(self, file_path="", state_dict={}): - if os.path.isfile(file_path): + if not isinstance(file_path, str) or os.path.isfile(file_path): return False file_list = os.listdir(file_path) if "config.json" not in file_list: @@ -284,7 +284,7 @@ class ModelDetectorFromPatchedSingleFile: def match(self, file_path="", state_dict={}): - if os.path.isdir(file_path): + if not isinstance(file_path, str) or os.path.isdir(file_path): return False if len(state_dict) == 0: state_dict = load_state_dict(file_path) @@ -390,7 +390,11 @@ class ModelManager: print(f"Loading models from: {file_path}") if device is None: device = self.device if torch_dtype is None: torch_dtype = self.torch_dtype - if os.path.isfile(file_path): + if isinstance(file_path, list): + state_dict = {} + for path in file_path: + state_dict.update(load_state_dict(path)) + elif os.path.isfile(file_path): state_dict = load_state_dict(file_path) else: state_dict = None diff --git a/diffsynth/models/stepvideo_dit.py b/diffsynth/models/stepvideo_dit.py new file mode 100644 index 0000000..18403d6 --- /dev/null +++ b/diffsynth/models/stepvideo_dit.py @@ -0,0 +1,940 @@ +# Copyright 2025 StepFun 
Inc. All Rights Reserved. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# ============================================================================== +from typing import Dict, Optional, Tuple +import torch, math +from torch import nn +from einops import rearrange, repeat +from tqdm import tqdm + + +class RMSNorm(nn.Module): + def __init__( + self, + dim: int, + elementwise_affine=True, + eps: float = 1e-6, + device=None, + dtype=None, + ): + """ + Initialize the RMSNorm normalization layer. + + Args: + dim (int): The dimension of the input tensor. + eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6. + + Attributes: + eps (float): A small value added to the denominator for numerical stability. + weight (nn.Parameter): Learnable scaling parameter. + + """ + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + self.eps = eps + if elementwise_affine: + self.weight = nn.Parameter(torch.ones(dim, **factory_kwargs)) + + def _norm(self, x): + """ + Apply the RMSNorm normalization to the input tensor. + + Args: + x (torch.Tensor): The input tensor. + + Returns: + torch.Tensor: The normalized tensor. + + """ + return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) + + def forward(self, x): + """ + Forward pass through the RMSNorm layer. + + Args: + x (torch.Tensor): The input tensor. 
+ + Returns: + torch.Tensor: The output tensor after applying RMSNorm. + + """ + output = self._norm(x.float()).type_as(x) + if hasattr(self, "weight"): + output = output * self.weight + return output + + +ACTIVATION_FUNCTIONS = { + "swish": nn.SiLU(), + "silu": nn.SiLU(), + "mish": nn.Mish(), + "gelu": nn.GELU(), + "relu": nn.ReLU(), +} + + +def get_activation(act_fn: str) -> nn.Module: + """Helper function to get activation function from string. + + Args: + act_fn (str): Name of activation function. + + Returns: + nn.Module: Activation function. + """ + + act_fn = act_fn.lower() + if act_fn in ACTIVATION_FUNCTIONS: + return ACTIVATION_FUNCTIONS[act_fn] + else: + raise ValueError(f"Unsupported activation function: {act_fn}") + + +def get_timestep_embedding( + timesteps: torch.Tensor, + embedding_dim: int, + flip_sin_to_cos: bool = False, + downscale_freq_shift: float = 1, + scale: float = 1, + max_period: int = 10000, +): + """ + This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings. + + :param timesteps: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + :param embedding_dim: the dimension of the output. :param max_period: controls the minimum frequency of the + embeddings. :return: an [N x dim] Tensor of positional embeddings. 
+ """ + assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array" + + half_dim = embedding_dim // 2 + exponent = -math.log(max_period) * torch.arange( + start=0, end=half_dim, dtype=torch.float32, device=timesteps.device + ) + exponent = exponent / (half_dim - downscale_freq_shift) + + emb = torch.exp(exponent) + emb = timesteps[:, None].float() * emb[None, :] + + # scale embeddings + emb = scale * emb + + # concat sine and cosine embeddings + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1) + + # flip sine and cosine embeddings + if flip_sin_to_cos: + emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1) + + # zero pad + if embedding_dim % 2 == 1: + emb = torch.nn.functional.pad(emb, (0, 1, 0, 0)) + return emb + + +class Timesteps(nn.Module): + def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float): + super().__init__() + self.num_channels = num_channels + self.flip_sin_to_cos = flip_sin_to_cos + self.downscale_freq_shift = downscale_freq_shift + + def forward(self, timesteps): + t_emb = get_timestep_embedding( + timesteps, + self.num_channels, + flip_sin_to_cos=self.flip_sin_to_cos, + downscale_freq_shift=self.downscale_freq_shift, + ) + return t_emb + + +class TimestepEmbedding(nn.Module): + def __init__( + self, + in_channels: int, + time_embed_dim: int, + act_fn: str = "silu", + out_dim: int = None, + post_act_fn: Optional[str] = None, + cond_proj_dim=None, + sample_proj_bias=True + ): + super().__init__() + linear_cls = nn.Linear + + self.linear_1 = linear_cls( + in_channels, + time_embed_dim, + bias=sample_proj_bias, + ) + + if cond_proj_dim is not None: + self.cond_proj = linear_cls( + cond_proj_dim, + in_channels, + bias=False, + ) + else: + self.cond_proj = None + + self.act = get_activation(act_fn) + + if out_dim is not None: + time_embed_dim_out = out_dim + else: + time_embed_dim_out = time_embed_dim + + self.linear_2 = linear_cls( + time_embed_dim, + time_embed_dim_out, + 
bias=sample_proj_bias, + ) + + if post_act_fn is None: + self.post_act = None + else: + self.post_act = get_activation(post_act_fn) + + def forward(self, sample, condition=None): + if condition is not None: + sample = sample + self.cond_proj(condition) + sample = self.linear_1(sample) + + if self.act is not None: + sample = self.act(sample) + + sample = self.linear_2(sample) + + if self.post_act is not None: + sample = self.post_act(sample) + return sample + + +class PixArtAlphaCombinedTimestepSizeEmbeddings(nn.Module): + def __init__(self, embedding_dim, size_emb_dim, use_additional_conditions: bool = False): + super().__init__() + + self.outdim = size_emb_dim + self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0) + self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim) + + self.use_additional_conditions = use_additional_conditions + if self.use_additional_conditions: + self.additional_condition_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0) + self.resolution_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=size_emb_dim) + self.nframe_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim) + self.fps_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim) + + def forward(self, timestep, resolution=None, nframe=None, fps=None): + hidden_dtype = next(self.timestep_embedder.parameters()).dtype + + timesteps_proj = self.time_proj(timestep) + timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype)) # (N, D) + + if self.use_additional_conditions: + batch_size = timestep.shape[0] + resolution_emb = self.additional_condition_proj(resolution.flatten()).to(hidden_dtype) + resolution_emb = self.resolution_embedder(resolution_emb).reshape(batch_size, -1) + nframe_emb = self.additional_condition_proj(nframe.flatten()).to(hidden_dtype) + nframe_emb = 
self.nframe_embedder(nframe_emb).reshape(batch_size, -1) + conditioning = timesteps_emb + resolution_emb + nframe_emb + + if fps is not None: + fps_emb = self.additional_condition_proj(fps.flatten()).to(hidden_dtype) + fps_emb = self.fps_embedder(fps_emb).reshape(batch_size, -1) + conditioning = conditioning + fps_emb + else: + conditioning = timesteps_emb + + return conditioning + + +class AdaLayerNormSingle(nn.Module): + r""" + Norm layer adaptive layer norm single (adaLN-single). + + As proposed in PixArt-Alpha (see: https://arxiv.org/abs/2310.00426; Section 2.3). + + Parameters: + embedding_dim (`int`): The size of each embedding vector. + use_additional_conditions (`bool`): To use additional conditions for normalization or not. + """ + def __init__(self, embedding_dim: int, use_additional_conditions: bool = False, time_step_rescale=1000): + super().__init__() + + self.emb = PixArtAlphaCombinedTimestepSizeEmbeddings( + embedding_dim, size_emb_dim=embedding_dim // 2, use_additional_conditions=use_additional_conditions + ) + + self.silu = nn.SiLU() + self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=True) + + self.time_step_rescale = time_step_rescale ## timestep usually in [0, 1], we rescale it to [0,1000] for stability + + def forward( + self, + timestep: torch.Tensor, + added_cond_kwargs: Dict[str, torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + embedded_timestep = self.emb(timestep*self.time_step_rescale, **added_cond_kwargs) + + out = self.linear(self.silu(embedded_timestep)) + + return out, embedded_timestep + + +class PixArtAlphaTextProjection(nn.Module): + """ + Projects caption embeddings. Also handles dropout for classifier-free guidance. 
+ + Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/nets/PixArt_blocks.py + """ + + def __init__(self, in_features, hidden_size): + super().__init__() + self.linear_1 = nn.Linear( + in_features, + hidden_size, + bias=True, + ) + self.act_1 = nn.GELU(approximate="tanh") + self.linear_2 = nn.Linear( + hidden_size, + hidden_size, + bias=True, + ) + + def forward(self, caption): + hidden_states = self.linear_1(caption) + hidden_states = self.act_1(hidden_states) + hidden_states = self.linear_2(hidden_states) + return hidden_states + + +class Attention(nn.Module): + def __init__(self): + super().__init__() + + def attn_processor(self, attn_type): + if attn_type == 'torch': + return self.torch_attn_func + elif attn_type == 'parallel': + return self.parallel_attn_func + else: + raise Exception('Not supported attention type...') + + def torch_attn_func( + self, + q, + k, + v, + attn_mask=None, + causal=False, + drop_rate=0.0, + **kwargs + ): + + if attn_mask is not None and attn_mask.dtype != torch.bool: + attn_mask = attn_mask.to(q.dtype) + + if attn_mask is not None and attn_mask.ndim == 3: ## no head + n_heads = q.shape[2] + attn_mask = attn_mask.unsqueeze(1).repeat(1, n_heads, 1, 1) + + q, k, v = map(lambda x: rearrange(x, 'b s h d -> b h s d'), (q, k, v)) + if attn_mask is not None: + attn_mask = attn_mask.to(q.device) + x = torch.nn.functional.scaled_dot_product_attention( + q, k, v, attn_mask=attn_mask, dropout_p=drop_rate, is_causal=causal + ) + x = rearrange(x, 'b h s d -> b s h d') + return x + + +class RoPE1D: + def __init__(self, freq=1e4, F0=1.0, scaling_factor=1.0): + self.base = freq + self.F0 = F0 + self.scaling_factor = scaling_factor + self.cache = {} + + def get_cos_sin(self, D, seq_len, device, dtype): + if (D, seq_len, device, dtype) not in self.cache: + inv_freq = 1.0 / (self.base ** (torch.arange(0, D, 2).float().to(device) / D)) + t = torch.arange(seq_len, device=device, dtype=inv_freq.dtype) + freqs = 
torch.einsum("i,j->ij", t, inv_freq).to(dtype) + freqs = torch.cat((freqs, freqs), dim=-1) + cos = freqs.cos() # (Seq, Dim) + sin = freqs.sin() + self.cache[D, seq_len, device, dtype] = (cos, sin) + return self.cache[D, seq_len, device, dtype] + + @staticmethod + def rotate_half(x): + x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2:] + return torch.cat((-x2, x1), dim=-1) + + def apply_rope1d(self, tokens, pos1d, cos, sin): + assert pos1d.ndim == 2 + cos = torch.nn.functional.embedding(pos1d, cos)[:, :, None, :] + sin = torch.nn.functional.embedding(pos1d, sin)[:, :, None, :] + return (tokens * cos) + (self.rotate_half(tokens) * sin) + + def __call__(self, tokens, positions): + """ + input: + * tokens: batch_size x ntokens x nheads x dim + * positions: batch_size x ntokens (t position of each token) + output: + * tokens after appplying RoPE2D (batch_size x ntokens x nheads x dim) + """ + D = tokens.size(3) + assert positions.ndim == 2 # Batch, Seq + cos, sin = self.get_cos_sin(D, int(positions.max()) + 1, tokens.device, tokens.dtype) + tokens = self.apply_rope1d(tokens, positions, cos, sin) + return tokens + + +class RoPE3D(RoPE1D): + def __init__(self, freq=1e4, F0=1.0, scaling_factor=1.0): + super(RoPE3D, self).__init__(freq, F0, scaling_factor) + self.position_cache = {} + + def get_mesh_3d(self, rope_positions, bsz): + f, h, w = rope_positions + + if f"{f}-{h}-{w}" not in self.position_cache: + x = torch.arange(f, device='cpu') + y = torch.arange(h, device='cpu') + z = torch.arange(w, device='cpu') + self.position_cache[f"{f}-{h}-{w}"] = torch.cartesian_prod(x, y, z).view(1, f*h*w, 3).expand(bsz, -1, 3) + return self.position_cache[f"{f}-{h}-{w}"] + + def __call__(self, tokens, rope_positions, ch_split, parallel=False): + """ + input: + * tokens: batch_size x ntokens x nheads x dim + * rope_positions: list of (f, h, w) + output: + * tokens after appplying RoPE2D (batch_size x ntokens x nheads x dim) + """ + assert sum(ch_split) == tokens.size(-1); + 
+ mesh_grid = self.get_mesh_3d(rope_positions, bsz=tokens.shape[0]) + out = [] + for i, (D, x) in enumerate(zip(ch_split, torch.split(tokens, ch_split, dim=-1))): + cos, sin = self.get_cos_sin(D, int(mesh_grid.max()) + 1, tokens.device, tokens.dtype) + + if parallel: + pass + else: + mesh = mesh_grid[:, :, i].clone() + x = self.apply_rope1d(x, mesh.to(tokens.device), cos, sin) + out.append(x) + + tokens = torch.cat(out, dim=-1) + return tokens + + +class SelfAttention(Attention): + def __init__(self, hidden_dim, head_dim, bias=False, with_rope=True, with_qk_norm=True, attn_type='torch'): + super().__init__() + self.head_dim = head_dim + self.n_heads = hidden_dim // head_dim + + self.wqkv = nn.Linear(hidden_dim, hidden_dim*3, bias=bias) + self.wo = nn.Linear(hidden_dim, hidden_dim, bias=bias) + + self.with_rope = with_rope + self.with_qk_norm = with_qk_norm + if self.with_qk_norm: + self.q_norm = RMSNorm(head_dim, elementwise_affine=True) + self.k_norm = RMSNorm(head_dim, elementwise_affine=True) + + if self.with_rope: + self.rope_3d = RoPE3D(freq=1e4, F0=1.0, scaling_factor=1.0) + self.rope_ch_split = [64, 32, 32] + + self.core_attention = self.attn_processor(attn_type=attn_type) + self.parallel = attn_type=='parallel' + + def apply_rope3d(self, x, fhw_positions, rope_ch_split, parallel=True): + x = self.rope_3d(x, fhw_positions, rope_ch_split, parallel) + return x + + def forward( + self, + x, + cu_seqlens=None, + max_seqlen=None, + rope_positions=None, + attn_mask=None + ): + xqkv = self.wqkv(x) + xqkv = xqkv.view(*x.shape[:-1], self.n_heads, 3*self.head_dim) + + xq, xk, xv = torch.split(xqkv, [self.head_dim]*3, dim=-1) ## seq_len, n, dim + + if self.with_qk_norm: + xq = self.q_norm(xq) + xk = self.k_norm(xk) + + if self.with_rope: + xq = self.apply_rope3d(xq, rope_positions, self.rope_ch_split, parallel=self.parallel) + xk = self.apply_rope3d(xk, rope_positions, self.rope_ch_split, parallel=self.parallel) + + output = self.core_attention( + xq, + xk, + xv, + 
cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + attn_mask=attn_mask + ) + output = rearrange(output, 'b s h d -> b s (h d)') + output = self.wo(output) + + return output + + +class CrossAttention(Attention): + def __init__(self, hidden_dim, head_dim, bias=False, with_qk_norm=True, attn_type='torch'): + super().__init__() + self.head_dim = head_dim + self.n_heads = hidden_dim // head_dim + + self.wq = nn.Linear(hidden_dim, hidden_dim, bias=bias) + self.wkv = nn.Linear(hidden_dim, hidden_dim*2, bias=bias) + self.wo = nn.Linear(hidden_dim, hidden_dim, bias=bias) + + self.with_qk_norm = with_qk_norm + if self.with_qk_norm: + self.q_norm = RMSNorm(head_dim, elementwise_affine=True) + self.k_norm = RMSNorm(head_dim, elementwise_affine=True) + + self.core_attention = self.attn_processor(attn_type=attn_type) + + def forward( + self, + x: torch.Tensor, + encoder_hidden_states: torch.Tensor, + attn_mask=None + ): + xq = self.wq(x) + xq = xq.view(*xq.shape[:-1], self.n_heads, self.head_dim) + + xkv = self.wkv(encoder_hidden_states) + xkv = xkv.view(*xkv.shape[:-1], self.n_heads, 2*self.head_dim) + + xk, xv = torch.split(xkv, [self.head_dim]*2, dim=-1) ## seq_len, n, dim + + if self.with_qk_norm: + xq = self.q_norm(xq) + xk = self.k_norm(xk) + + output = self.core_attention( + xq, + xk, + xv, + attn_mask=attn_mask + ) + + output = rearrange(output, 'b s h d -> b s (h d)') + output = self.wo(output) + + return output + + +class GELU(nn.Module): + r""" + GELU activation function with tanh approximation support with `approximate="tanh"`. + + Parameters: + dim_in (`int`): The number of channels in the input. + dim_out (`int`): The number of channels in the output. + approximate (`str`, *optional*, defaults to `"none"`): If `"tanh"`, use tanh approximation. + bias (`bool`, defaults to True): Whether to use a bias in the linear layer. 
+ """ + + def __init__(self, dim_in: int, dim_out: int, approximate: str = "none", bias: bool = True): + super().__init__() + self.proj = nn.Linear(dim_in, dim_out, bias=bias) + self.approximate = approximate + + def gelu(self, gate: torch.Tensor) -> torch.Tensor: + return torch.nn.functional.gelu(gate, approximate=self.approximate) + + def forward(self, hidden_states): + hidden_states = self.proj(hidden_states) + hidden_states = self.gelu(hidden_states) + return hidden_states + + +class FeedForward(nn.Module): + def __init__( + self, + dim: int, + inner_dim: Optional[int] = None, + dim_out: Optional[int] = None, + mult: int = 4, + bias: bool = False, + ): + super().__init__() + inner_dim = dim*mult if inner_dim is None else inner_dim + dim_out = dim if dim_out is None else dim_out + self.net = nn.ModuleList([ + GELU(dim, inner_dim, approximate="tanh", bias=bias), + nn.Identity(), + nn.Linear(inner_dim, dim_out, bias=bias) + ]) + + + def forward(self, hidden_states: torch.Tensor, *args, **kwargs) -> torch.Tensor: + for module in self.net: + hidden_states = module(hidden_states) + return hidden_states + + +def modulate(x, scale, shift): + x = x * (1 + scale) + shift + return x + + +def gate(x, gate): + x = gate * x + return x + + +class StepVideoTransformerBlock(nn.Module): + r""" + A basic Transformer block. + + Parameters: + dim (`int`): The number of channels in the input and output. + num_attention_heads (`int`): The number of heads to use for multi-head attention. + attention_head_dim (`int`): The number of channels in each head. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention. + activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. + num_embeds_ada_norm (: + obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`. 
+ attention_bias (: + obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter. + only_cross_attention (`bool`, *optional*): + Whether to use only cross-attention layers. In this case two cross attention layers are used. + double_self_attention (`bool`, *optional*): + Whether to use two self-attention layers. In this case no cross attention layers are used. + upcast_attention (`bool`, *optional*): + Whether to upcast the attention computation to float32. This is useful for mixed precision training. + norm_elementwise_affine (`bool`, *optional*, defaults to `True`): + Whether to use learnable elementwise affine parameters for normalization. + norm_type (`str`, *optional*, defaults to `"layer_norm"`): + The normalization layer to use. Can be `"layer_norm"`, `"ada_norm"` or `"ada_norm_zero"`. + final_dropout (`bool` *optional*, defaults to False): + Whether to apply a final dropout after the last feed-forward layer. + attention_type (`str`, *optional*, defaults to `"default"`): + The type of attention to use. Can be `"default"` or `"gated"` or `"gated-text-image"`. + positional_embeddings (`str`, *optional*, defaults to `None`): + The type of positional embeddings to apply to. + num_positional_embeddings (`int`, *optional*, defaults to `None`): + The maximum number of positional embeddings to apply. 
+ """ + + def __init__( + self, + dim: int, + attention_head_dim: int, + norm_eps: float = 1e-5, + ff_inner_dim: Optional[int] = None, + ff_bias: bool = False, + attention_type: str = 'parallel' + ): + super().__init__() + self.dim = dim + self.norm1 = nn.LayerNorm(dim, eps=norm_eps) + self.attn1 = SelfAttention(dim, attention_head_dim, bias=False, with_rope=True, with_qk_norm=True, attn_type=attention_type) + + self.norm2 = nn.LayerNorm(dim, eps=norm_eps) + self.attn2 = CrossAttention(dim, attention_head_dim, bias=False, with_qk_norm=True, attn_type='torch') + + self.ff = FeedForward(dim=dim, inner_dim=ff_inner_dim, dim_out=dim, bias=ff_bias) + + self.scale_shift_table = nn.Parameter(torch.randn(6, dim) /dim**0.5) + + @torch.no_grad() + def forward( + self, + q: torch.Tensor, + kv: Optional[torch.Tensor] = None, + timestep: Optional[torch.LongTensor] = None, + attn_mask = None, + rope_positions: list = None, + ) -> torch.Tensor: + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ( + torch.clone(chunk) for chunk in (self.scale_shift_table[None].to(dtype=q.dtype, device=q.device) + timestep.reshape(-1, 6, self.dim)).chunk(6, dim=1) + ) + + scale_shift_q = modulate(self.norm1(q), scale_msa, shift_msa) + + attn_q = self.attn1( + scale_shift_q, + rope_positions=rope_positions + ) + + q = gate(attn_q, gate_msa) + q + + attn_q = self.attn2( + q, + kv, + attn_mask + ) + + q = attn_q + q + + scale_shift_q = modulate(self.norm2(q), scale_mlp, shift_mlp) + + ff_output = self.ff(scale_shift_q) + + q = gate(ff_output, gate_mlp) + q + + return q + + +class PatchEmbed(nn.Module): + """2D Image to Patch Embedding""" + + def __init__( + self, + patch_size=64, + in_channels=3, + embed_dim=768, + layer_norm=False, + flatten=True, + bias=True, + ): + super().__init__() + + self.flatten = flatten + self.layer_norm = layer_norm + + self.proj = nn.Conv2d( + in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias + ) + + def 
forward(self, latent): + latent = self.proj(latent).to(latent.dtype) + if self.flatten: + latent = latent.flatten(2).transpose(1, 2) # BCHW -> BNC + if self.layer_norm: + latent = self.norm(latent) + + return latent + + +class StepVideoModel(torch.nn.Module): + def __init__( + self, + num_attention_heads: int = 48, + attention_head_dim: int = 128, + in_channels: int = 64, + out_channels: Optional[int] = 64, + num_layers: int = 48, + dropout: float = 0.0, + patch_size: int = 1, + norm_type: str = "ada_norm_single", + norm_elementwise_affine: bool = False, + norm_eps: float = 1e-6, + use_additional_conditions: Optional[bool] = False, + caption_channels: Optional[int]|list|tuple = [6144, 1024], + attention_type: Optional[str] = "torch", + ): + super().__init__() + + # Set some common variables used across the board. + self.inner_dim = num_attention_heads * attention_head_dim + self.out_channels = in_channels if out_channels is None else out_channels + + self.use_additional_conditions = use_additional_conditions + + self.pos_embed = PatchEmbed( + patch_size=patch_size, + in_channels=in_channels, + embed_dim=self.inner_dim, + ) + + self.transformer_blocks = nn.ModuleList( + [ + StepVideoTransformerBlock( + dim=self.inner_dim, + attention_head_dim=attention_head_dim, + attention_type=attention_type + ) + for _ in range(num_layers) + ] + ) + + # 3. Output blocks. 
+ self.norm_out = nn.LayerNorm(self.inner_dim, eps=norm_eps, elementwise_affine=norm_elementwise_affine) + self.scale_shift_table = nn.Parameter(torch.randn(2, self.inner_dim) / self.inner_dim**0.5) + self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels) + self.patch_size = patch_size + + self.adaln_single = AdaLayerNormSingle( + self.inner_dim, use_additional_conditions=self.use_additional_conditions + ) + + if isinstance(caption_channels, int): + caption_channel = caption_channels + else: + caption_channel, clip_channel = caption_channels + self.clip_projection = nn.Linear(clip_channel, self.inner_dim) + + self.caption_norm = nn.LayerNorm(caption_channel, eps=norm_eps, elementwise_affine=norm_elementwise_affine) + + self.caption_projection = PixArtAlphaTextProjection( + in_features=caption_channel, hidden_size=self.inner_dim + ) + + self.parallel = attention_type=='parallel' + + def patchfy(self, hidden_states): + hidden_states = rearrange(hidden_states, 'b f c h w -> (b f) c h w') + hidden_states = self.pos_embed(hidden_states) + return hidden_states + + def prepare_attn_mask(self, encoder_attention_mask, encoder_hidden_states, q_seqlen): + kv_seqlens = encoder_attention_mask.sum(dim=1).int() + mask = torch.zeros([len(kv_seqlens), q_seqlen, max(kv_seqlens)], dtype=torch.bool, device=encoder_attention_mask.device) + encoder_hidden_states = encoder_hidden_states[:,: max(kv_seqlens)] + for i, kv_len in enumerate(kv_seqlens): + mask[i, :, :kv_len] = 1 + return encoder_hidden_states, mask + + + def block_forward( + self, + hidden_states, + encoder_hidden_states=None, + timestep=None, + rope_positions=None, + attn_mask=None, + parallel=True + ): + for block in tqdm(self.transformer_blocks, desc="Transformer blocks"): + hidden_states = block( + hidden_states, + encoder_hidden_states, + timestep=timestep, + attn_mask=attn_mask, + rope_positions=rope_positions + ) + + return hidden_states + + + @torch.inference_mode() + def forward( + 
self, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_hidden_states_2: Optional[torch.Tensor] = None, + timestep: Optional[torch.LongTensor] = None, + added_cond_kwargs: Dict[str, torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + fps: torch.Tensor=None, + return_dict: bool = False, + ): + assert hidden_states.ndim==5; "hidden_states's shape should be (bsz, f, ch, h ,w)" + + bsz, frame, _, height, width = hidden_states.shape + height, width = height // self.patch_size, width // self.patch_size + + hidden_states = self.patchfy(hidden_states) + len_frame = hidden_states.shape[1] + + if self.use_additional_conditions: + added_cond_kwargs = { + "resolution": torch.tensor([(height, width)]*bsz, device=hidden_states.device, dtype=hidden_states.dtype), + "nframe": torch.tensor([frame]*bsz, device=hidden_states.device, dtype=hidden_states.dtype), + "fps": fps + } + else: + added_cond_kwargs = {} + + timestep, embedded_timestep = self.adaln_single( + timestep, added_cond_kwargs=added_cond_kwargs + ) + + encoder_hidden_states = self.caption_projection(self.caption_norm(encoder_hidden_states)) + + if encoder_hidden_states_2 is not None and hasattr(self, 'clip_projection'): + clip_embedding = self.clip_projection(encoder_hidden_states_2) + encoder_hidden_states = torch.cat([clip_embedding, encoder_hidden_states], dim=1) + + hidden_states = rearrange(hidden_states, '(b f) l d-> b (f l) d', b=bsz, f=frame, l=len_frame).contiguous() + encoder_hidden_states, attn_mask = self.prepare_attn_mask(encoder_attention_mask, encoder_hidden_states, q_seqlen=frame*len_frame) + + hidden_states = self.block_forward( + hidden_states, + encoder_hidden_states, + timestep=timestep, + rope_positions=[frame, height, width], + attn_mask=attn_mask, + parallel=self.parallel + ) + + hidden_states = rearrange(hidden_states, 'b (f l) d -> (b f) l d', b=bsz, f=frame, l=len_frame) + + embedded_timestep = repeat(embedded_timestep, 
class StepVideoDiTStateDictConverter:
    """Pass-through state-dict converter for the StepVideo DiT.

    No key renaming or tensor reshaping is performed: both entry points hand
    back the very mapping they were given (same object, no copy).
    """

    def __init__(self):
        super().__init__()

    def from_diffusers(self, state_dict):
        """Return ``state_dict`` unchanged."""
        converted = state_dict
        return converted

    def from_civitai(self, state_dict):
        """Return ``state_dict`` unchanged."""
        converted = state_dict
        return converted
class LLaMaEmbedding(nn.Module):
    """Token-embedding front end of the text encoder.

    All settings are read from ``cfg``: ``hidden_size``, ``padded_vocab_size``,
    ``hidden_dropout``, ``params_dtype``, ``fp32_residual_connection`` and
    ``embedding_weights_in_fp32``.
    """

    def __init__(self,
                 cfg,
                 ):
        super().__init__()
        self.hidden_size = cfg.hidden_size
        self.params_dtype = cfg.params_dtype
        self.fp32_residual_connection = cfg.fp32_residual_connection
        self.embedding_weights_in_fp32 = cfg.embedding_weights_in_fp32
        self.word_embeddings = torch.nn.Embedding(cfg.padded_vocab_size, self.hidden_size)
        self.embedding_dropout = torch.nn.Dropout(cfg.hidden_dropout)

    def forward(self, input_ids):
        """Embed ``input_ids`` and return the result with the first two axes swapped.

        The comment in the original code describes the swap as
        ``[b s h] --> [s b h]``.
        """
        lookup_in_fp32 = self.embedding_weights_in_fp32

        # When requested, the table is temporarily held in fp32 for the lookup
        # and moved back to ``params_dtype`` right after.
        if lookup_in_fp32:
            self.word_embeddings = self.word_embeddings.to(torch.float32)
        hidden = self.word_embeddings(input_ids)
        if lookup_in_fp32:
            hidden = hidden.to(self.params_dtype)
            self.word_embeddings = self.word_embeddings.to(self.params_dtype)

        # Swap the two leading axes up front to avoid explicit transposes later.
        hidden = hidden.transpose(0, 1).contiguous()

        if self.fp32_residual_connection:
            hidden = hidden.float()

        return self.embedding_dropout(hidden)
class Tokens:
    """Plain container for one tokenized batch.

    Holds the token ids, the attention mask, a second id tensor
    (``cu_input_ids``), ``cu_seqlens`` and the scalar ``max_seq_len`` exactly
    as they are passed in.
    """

    # Attributes that ``to`` moves; ``max_seq_len`` is left untouched.
    _TENSOR_FIELDS = ("input_ids", "attention_mask", "cu_input_ids", "cu_seqlens")

    def __init__(self, input_ids, cu_input_ids, attention_mask, cu_seqlens, max_seq_len) -> None:
        self.input_ids, self.attention_mask = input_ids, attention_mask
        self.cu_input_ids, self.cu_seqlens = cu_input_ids, cu_seqlens
        self.max_seq_len = max_seq_len

    def to(self, device):
        """Move every tensor attribute to ``device`` in place and return ``self``."""
        for field in self._TENSOR_FIELDS:
            setattr(self, field, getattr(self, field).to(device))
        return self
[self.PAD] * (max_length - valid_size) + out_tokens.append(part_tokens) + attn_mask.append([1]*valid_size+[0]*(max_length-valid_size)) + + out_tokens = torch.tensor(out_tokens, dtype=torch.long) + attn_mask = torch.tensor(attn_mask, dtype=torch.long) + + # padding y based on tp size + padded_len = 0 + padded_flag = True if padded_len > 0 else False + if padded_flag: + pad_tokens = torch.tensor([[self.PAD] * max_length], device=out_tokens.device) + pad_attn_mask = torch.tensor([[1]*padded_len+[0]*(max_length-padded_len)], device=attn_mask.device) + out_tokens = torch.cat([out_tokens, pad_tokens], dim=0) + attn_mask = torch.cat([attn_mask, pad_attn_mask], dim=0) + + # cu_seqlens + cu_out_tokens = out_tokens.masked_select(attn_mask != 0).unsqueeze(0) + seqlen = attn_mask.sum(dim=1).tolist() + cu_seqlens = torch.cumsum(torch.tensor([0]+seqlen), 0).to(device=out_tokens.device,dtype=torch.int32) + max_seq_len = max(seqlen) + return Tokens(out_tokens, cu_out_tokens, attn_mask, cu_seqlens, max_seq_len) + + + +def flash_attn_func(q, k, v, dropout_p=0.0, softmax_scale=None, causal=True, + return_attn_probs=False, tp_group_rank=0, tp_group_size=1): + softmax_scale = q.size(-1) ** (-0.5) if softmax_scale is None else softmax_scale + if hasattr(torch.ops.Optimus, "fwd"): + results = torch.ops.Optimus.fwd(q, k, v, None, dropout_p, softmax_scale, causal, return_attn_probs, None, tp_group_rank, tp_group_size)[0] + else: + warnings.warn("Cannot load `torch.ops.Optimus.fwd`. 
Using `torch.nn.functional.scaled_dot_product_attention` instead.") + results = torch.nn.functional.scaled_dot_product_attention(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), is_causal=True, scale=softmax_scale).transpose(1, 2) + return results + + +class FlashSelfAttention(torch.nn.Module): + def __init__( + self, + attention_dropout=0.0, + ): + super().__init__() + self.dropout_p = attention_dropout + + + def forward(self, q, k, v, cu_seqlens=None, max_seq_len=None): + if cu_seqlens is None: + output = flash_attn_func(q, k, v, dropout_p=self.dropout_p) + else: + raise ValueError('cu_seqlens is not supported!') + + return output + + + +def safediv(n, d): + q, r = divmod(n, d) + assert r == 0 + return q + + +class MultiQueryAttention(nn.Module): + def __init__(self, cfg, layer_id=None): + super().__init__() + + self.head_dim = cfg.hidden_size // cfg.num_attention_heads + self.max_seq_len = cfg.seq_length + self.use_flash_attention = cfg.use_flash_attn + assert self.use_flash_attention, 'FlashAttention is required!' + + self.n_groups = cfg.num_attention_groups + self.tp_size = 1 + self.n_local_heads = cfg.num_attention_heads + self.n_local_groups = self.n_groups + + self.wqkv = nn.Linear( + cfg.hidden_size, + cfg.hidden_size + self.head_dim * 2 * self.n_groups, + bias=False, + ) + self.wo = nn.Linear( + cfg.hidden_size, + cfg.hidden_size, + bias=False, + ) + + assert self.use_flash_attention, 'non-Flash attention not supported yet.' 
+ self.core_attention = FlashSelfAttention(attention_dropout=cfg.attention_dropout) + + self.layer_id = layer_id + + def forward( + self, + x: torch.Tensor, + mask: Optional[torch.Tensor], + cu_seqlens: Optional[torch.Tensor], + max_seq_len: Optional[torch.Tensor], + ): + seqlen, bsz, dim = x.shape + xqkv = self.wqkv(x) + + xq, xkv = torch.split( + xqkv, + (dim // self.tp_size, + self.head_dim*2*self.n_groups // self.tp_size + ), + dim=-1, + ) + + # gather on 1st dimention + xq = xq.view(seqlen, bsz, self.n_local_heads, self.head_dim) + xkv = xkv.view(seqlen, bsz, self.n_local_groups, 2 * self.head_dim) + xk, xv = xkv.chunk(2, -1) + + # rotary embedding + flash attn + xq = rearrange(xq, "s b h d -> b s h d") + xk = rearrange(xk, "s b h d -> b s h d") + xv = rearrange(xv, "s b h d -> b s h d") + + q_per_kv = self.n_local_heads // self.n_local_groups + if q_per_kv > 1: + b, s, h, d = xk.size() + if h == 1: + xk = xk.expand(b, s, q_per_kv, d) + xv = xv.expand(b, s, q_per_kv, d) + else: + ''' To cover the cases where h > 1, we have + the following implementation, which is equivalent to: + xk = xk.repeat_interleave(q_per_kv, dim=-2) + xv = xv.repeat_interleave(q_per_kv, dim=-2) + but can avoid calling aten::item() that involves cpu. + ''' + idx = torch.arange(q_per_kv * h, device=xk.device).reshape(q_per_kv, -1).permute(1, 0).flatten() + xk = torch.index_select(xk.repeat(1, 1, q_per_kv, 1), 2, idx).contiguous() + xv = torch.index_select(xv.repeat(1, 1, q_per_kv, 1), 2, idx).contiguous() + + if self.use_flash_attention: + output = self.core_attention(xq, xk, xv, + cu_seqlens=cu_seqlens, + max_seq_len=max_seq_len) + # reduce-scatter only support first dimention now + output = rearrange(output, "b s h d -> s b (h d)").contiguous() + else: + xq, xk, xv = [ + rearrange(x, "b s ... 
class FeedForward(nn.Module):
    """Bias-free SwiGLU feed-forward layer.

    ``w1`` projects ``dim`` to twice the (rounded-up) hidden width, the two
    halves are combined as ``silu(a) * b`` and ``w2`` projects back to ``dim``.
    ``cfg`` and ``layer_id`` are accepted but not used here.
    """

    def __init__(
        self,
        cfg,
        dim: int,
        hidden_dim: int,
        layer_id: int,
        multiple_of: int=256,
    ):
        super().__init__()

        # Round the hidden width up to the next multiple of `multiple_of`.
        n_chunks = (hidden_dim + multiple_of - 1) // multiple_of
        hidden_dim = n_chunks * multiple_of

        self.w1 = nn.Linear(dim, 2 * hidden_dim, bias=False)
        self.w2 = nn.Linear(hidden_dim, dim, bias=False)

    @staticmethod
    def swiglu(x):
        """Split the last axis in two and gate the second half with SiLU of the first."""
        gate, value = torch.chunk(x, 2, dim=-1)
        return F.silu(gate) * value

    def forward(self, x):
        return self.w2(self.swiglu(self.w1(x)))
class STEP1TextEncoder(torch.nn.Module):
    """Tokenizer plus ``Step1Model`` wrapper that turns prompts into embeddings.

    Args:
        model_dir: directory holding ``step1_chat_tokenizer.model`` and the
            ``Step1Model`` checkpoint.
        max_length: default token length used by ``forward``.
    """

    def __init__(self, model_dir, max_length=320):
        super().__init__()
        self.max_length = max_length
        self.text_tokenizer = Wrapped_StepChatTokenizer(os.path.join(model_dir, 'step1_chat_tokenizer.model'))
        text_encoder = Step1Model.from_pretrained(model_dir)
        self.text_encoder = text_encoder.eval().to(torch.bfloat16)

    @staticmethod
    def from_pretrained(path, torch_dtype=torch.bfloat16):
        """Build the encoder from ``path`` and cast it to ``torch_dtype``."""
        model = STEP1TextEncoder(path).to(torch_dtype)
        return model

    @torch.no_grad()
    def forward(self, prompts, with_mask=True, max_length=None, device="cuda"):
        """Encode one prompt or a list of prompts.

        Args:
            prompts: a string or a list of strings.
            with_mask: pass the tokenizer's attention mask to the encoder.
            max_length: overrides ``self.max_length`` when truthy.
            device: where the token tensors are moved; may carry an index
                (``"cuda:0"``) or be a ``torch.device``.

        Returns:
            ``(y, y_mask)``: the encoder output with its first two axes
            swapped, and the tokenizer's attention mask (not moved to
            ``device``).
        """
        self.device = device
        # autocast only accepts the bare device type ("cuda"/"cpu"), so strip
        # any index and accept torch.device objects as well.
        device_type = torch.device(device).type
        with torch.amp.autocast(dtype=torch.bfloat16, device_type=device_type):
            if isinstance(prompts, str):
                prompts = [prompts]

            txt_tokens = self.text_tokenizer(
                prompts, max_length=max_length or self.max_length, padding="max_length", truncation=True, return_tensors="pt"
            )
            y = self.text_encoder(
                txt_tokens.input_ids.to(self.device),
                attention_mask=txt_tokens.attention_mask.to(self.device) if with_mask else None
            )
            y_mask = txt_tokens.attention_mask
            return y.transpose(0, 1), y_mask
def base_group_norm(x, norm_layer, act_silu=False, channel_last=False):
    """Apply ``norm_layer``'s group norm functionally, optionally followed by SiLU.

    Args:
        x: input tensor. With ``channel_last=True`` the last axis is moved to
            position 1 before normalising and moved back afterwards.
        norm_layer: object providing ``num_groups``, ``weight``, ``bias``, ``eps``.
        act_silu: apply ``F.silu`` to the normalised output.
        channel_last: see ``x``.

    When the function attribute ``base_group_norm.spatial`` is set and truthy,
    the two leading axes of ``x`` are flattened for the normalisation and the
    original shape is restored on return; this mode requires ``channel_last``.
    """
    spatial_mode = getattr(base_group_norm, 'spatial', False)

    restore_shape = None
    if spatial_mode:
        assert channel_last == True
        restore_shape = x.shape
        x = x.flatten(0, 1)

    if channel_last:
        # channel-last -> channel-first, as F.group_norm expects
        x = x.permute(0, 3, 1, 2)

    out = F.group_norm(x.contiguous(), norm_layer.num_groups, norm_layer.weight, norm_layer.bias, norm_layer.eps)
    if act_silu:
        out = F.silu(out)

    if channel_last:
        # back to channel-last
        out = out.permute(0, 2, 3, 1)

    if restore_shape is not None:
        out = out.view(restore_shape)
    return out
def cal_outsize(input_sizes, kernel_sizes, stride, padding):
    """Compute the channel-last output shape of a 3D convolution.

    Args:
        input_sizes: input shape; index 0 is copied through unchanged and
            indices 1-3 are the three convolved extents.
        kernel_sizes: weight shape; index 0 is the output channel count and
            indices 2-4 are the kernel extents.
        stride: three per-axis strides.
        padding: three per-axis paddings.

    Returns:
        ``[input_sizes[0], out_d, out_h, out_w, out_channels]``. Dilation is
        fixed at 1.
    """
    stride_d, stride_h, stride_w = stride
    padding_d, padding_h, padding_w = padding
    dilation = 1

    out_d = calc_out_(input_sizes[1], padding_d, dilation, kernel_sizes[2], stride_d)
    out_h = calc_out_(input_sizes[2], padding_h, dilation, kernel_sizes[3], stride_h)
    out_w = calc_out_(input_sizes[3], padding_w, dilation, kernel_sizes[4], stride_w)
    return [input_sizes[0], out_d, out_h, out_w, kernel_sizes[0]]


def calc_out_(in_size, padding, dilation, kernel, stride):
    """Return the output extent of a convolution along one axis."""
    return (in_size + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1
class Upsample2D(nn.Module):
    """2x nearest-neighbour upsampling followed by a 3x3 convolution.

    ``forward`` asserts that the last axis of ``x`` equals ``channels`` and
    permutes it to position 1 for interpolation, i.e. it expects channel-last
    input.

    Args:
        channels: number of input channels.
        use_conv: build the 3x3 ``nn.Conv2d`` used after interpolation.
        use_conv_transpose: make ``forward`` return ``self.conv(x)`` directly.
        out_channels: output channels; defaults to ``channels``.
    """

    def __init__(self,
                 channels,
                 use_conv=False,
                 use_conv_transpose=False,
                 out_channels=None):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.use_conv_transpose = use_conv_transpose

        if use_conv:
            self.conv = nn.Conv2d(self.channels, self.out_channels, 3, padding=1)
        else:
            # The original guard was `assert "Not Supported"`, which can never
            # fail because a non-empty string is truthy. Keep building the
            # transposed conv so existing callers still work, but say so.
            import warnings
            warnings.warn(
                "Upsample2D is only supported with use_conv=True; "
                "falling back to an nn.ConvTranspose2d layer."
            )
            self.conv = nn.ConvTranspose2d(channels, self.out_channels, 4, 2, 1)

    def forward(self, x, output_size=None):
        """Upsample ``x`` by 2x, or to ``output_size`` when given, then convolve."""
        assert x.shape[-1] == self.channels

        if self.use_conv_transpose:
            return self.conv(x)

        if output_size is None:
            x = F.interpolate(
                x.permute(0, 3, 1, 2).to(memory_format=torch.channels_last),
                scale_factor=2.0, mode='nearest').permute(0, 2, 3, 1).contiguous()
        else:
            x = F.interpolate(
                x.permute(0, 3, 1, 2).to(memory_format=torch.channels_last),
                size=output_size, mode='nearest').permute(0, 2, 3, 1).contiguous()

        x = base_conv2d(x, self.conv, channel_last=True)
        return x
class CausalConv(nn.Module):
    """3D convolution whose temporal padding is applied only at the front.

    Args:
        chan_in: input channels.
        chan_out: output channels.
        kernel_size: int (used for all three axes) or a 3-tuple
            ``(time, height, width)``.
        **kwargs: ``dilation`` and ``stride`` are consumed here; anything else
            is forwarded to ``nn.Conv3d``. An int ``stride`` is applied to the
            time axis only.
    """

    def __init__(self,
                 chan_in,
                 chan_out,
                 kernel_size,
                 **kwargs
                 ):
        super().__init__()

        if isinstance(kernel_size, int):
            kernel_size = (kernel_size,) * 3
        k_t, k_h, k_w = kernel_size

        self.dilation = kwargs.pop('dilation', 1)
        stride = kwargs.pop('stride', 1)
        self.stride = (stride, 1, 1) if isinstance(stride, int) else stride

        pad_t = self.dilation * (k_t - 1) + max((1 - self.stride[0]), 0)
        # F.pad order is last axis first: (w_left, w_right, h_top, h_bottom, ...)
        spatial_pad = (k_w // 2, k_w // 2, k_h // 2, k_h // 2)
        self.time_causal_padding = spatial_pad + (pad_t, 0)
        self.time_uncausal_padding = spatial_pad + (0, 0)

        self.conv = nn.Conv3d(chan_in, chan_out, kernel_size, stride=self.stride, dilation=self.dilation, **kwargs)
        self.is_first_run = True

    def forward(self, x, is_init=True, residual=None):
        """Pad, convolve, and add ``residual`` in place when one is given.

        ``is_init`` selects the padding with front temporal padding; otherwise
        only the spatial axes are padded.
        """
        padding = self.time_causal_padding if is_init else self.time_uncausal_padding
        out = self.conv(nn.functional.pad(x, padding))
        if residual is not None:
            out.add_(residual)
        return out
class ConvPixelShuffleUpSampleLayer3D(nn.Module):
    """Causal convolution followed by a 3D pixel shuffle.

    The convolution produces ``out_channels * factor**3`` channels, which the
    shuffle redistributes over the three trailing axes.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        factor: int,
    ):
        super().__init__()
        self.factor = factor
        self.conv = CausalConv(
            in_channels,
            out_channels * factor**3,
            kernel_size=kernel_size
        )

    def forward(self, x: torch.Tensor, is_init=True) -> torch.Tensor:
        return self.pixel_shuffle_3d(self.conv(x, is_init), self.factor)

    @staticmethod
    def pixel_shuffle_3d(x: torch.Tensor, factor: int) -> torch.Tensor:
        """Fold ``factor**3`` channel groups into the three trailing axes.

        Each trailing axis grows by ``factor``; the first ``factor - 1`` slices
        along the first of them are then dropped.
        """
        batch_size, channels, depth, height, width = x.size()
        out_channels = channels // (factor ** 3)

        blocks = x.view(batch_size, out_channels, factor, factor, factor, depth, height, width)
        # Interleave every sub-position axis right after the axis it expands.
        interleaved = blocks.permute(0, 1, 5, 2, 6, 3, 7, 4).contiguous()
        merged = interleaved.view(
            batch_size, out_channels, depth * factor, height * factor, width * factor
        )
        return merged[:, :, factor - 1:, :, :]
class PixelUnshuffleChannelAveragingDownSampleLayer3D(nn.Module):
    """Parameter-free downsampling shortcut: 3-D pixel unshuffle, then channel averaging.

    The input is zero-padded with ``factor - 1`` slices at the front of the
    depth axis, each ``factor**3`` voxel block is folded into the channel axis,
    and consecutive groups of ``group_size`` channels are averaged so that the
    result has ``out_channels`` channels.

    Args:
        in_channels: Channel count of the input tensor.
        out_channels: Channel count of the output; must divide
            ``in_channels * factor**3``.
        factor: Downsampling factor for depth, height and width.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        factor: int,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.factor = factor
        assert in_channels * factor**3 % out_channels == 0
        # Number of unshuffled channels averaged into each output channel.
        self.group_size = in_channels * factor**3 // out_channels

    def forward(self, x: torch.Tensor, is_init=True) -> torch.Tensor:
        """Downsample ``x`` of shape ``(B, C, D, H, W)``.

        ``is_init`` is accepted for signature parity with the convolutional
        layers and is not used.

        Returns:
            Tensor of shape ``(B, out_channels, D', H // factor, W // factor)``
            where ``D' = (D + factor - 1) // factor``.
        """
        pad = (0, 0, 0, 0, self.factor-1, 0)  # (left, right, top, bottom, front, back)
        x = F.pad(x, pad)
        B, C, D, H, W = x.shape
        # Expose the factor-sized sub-blocks of each spatial axis ...
        x = x.view(B, C, D // self.factor, self.factor, H // self.factor, self.factor, W // self.factor, self.factor)
        # ... and move them next to the channel axis.
        x = x.permute(0, 1, 3, 5, 7, 2, 4, 6).contiguous()
        x = x.view(B, C * self.factor**3, D // self.factor, H // self.factor, W // self.factor)
        # Average each group of `group_size` channels down to one output channel.
        x = x.view(B, self.out_channels, self.group_size, D // self.factor, H // self.factor, W // self.factor)
        x = x.mean(dim=2)
        return x
class CausalConvAfterNorm(CausalConv):
    """Causal 3-D convolution for channel-last input coming out of a norm step.

    When the parent class reports the padding ``(1, 1, 1, 1, 2, 0)``, the
    spatial padding is handled by ``nn.Conv3d`` itself (``padding=(0, 1, 1)``)
    and ``forward`` adds no padding. For every other padding configuration the
    input is padded explicitly in ``forward``.

    NOTE(review): in the first case the two leading temporal frames are
    presumably supplied by the caller (e.g. ``base_group_norm_with_zero_pad``)
    -- confirm against the call sites.
    """

    # Padding configuration for which `forward` skips the explicit pad.
    _PREPADDED = (1, 1, 1, 1, 2, 0)

    def __init__(self,
                 chan_in,
                 chan_out,
                 kernel_size,
                 **kwargs
                 ):
        super().__init__(chan_in, chan_out, kernel_size, **kwargs)

        conv_options = dict(stride=self.stride, dilation=self.dilation)
        if self.time_causal_padding == self._PREPADDED:
            conv_options["padding"] = (0, 1, 1)
        self.conv = nn.Conv3d(chan_in, chan_out, kernel_size, **conv_options, **kwargs)
        self.is_first_run = True

    def forward(self, x, is_init=True, residual=None):
        """Apply the convolution to ``x``; ``residual`` is forwarded to the conv helper.

        ``is_init`` is accepted for interface compatibility and is not used here.
        """
        if self.is_first_run:
            self.is_first_run = False

        if self.time_causal_padding != self._PREPADDED:
            x = nn.functional.pad(x, self.time_causal_padding).contiguous()

        return base_conv3d_channel_last(x, self.conv, residual=residual)
self.k = CausalConvChannelLast(in_channels, in_channels, kernel_size=1) + self.v = CausalConvChannelLast(in_channels, in_channels, kernel_size=1) + self.proj_out = CausalConvChannelLast(in_channels, in_channels, kernel_size=1) + + def attention(self, x, is_init=True): + x = base_group_norm(x, self.norm, act_silu=False, channel_last=True) + q = self.q(x, is_init) + k = self.k(x, is_init) + v = self.v(x, is_init) + + b, t, h, w, c = q.shape + q, k, v = map(lambda x: rearrange(x, "b t h w c -> b 1 (t h w) c"), (q, k, v)) + x = nn.functional.scaled_dot_product_attention(q, k, v, is_causal=True) + x = rearrange(x, "b 1 (t h w) c -> b t h w c", t=t, h=h, w=w) + + return x + + def forward(self, x): + x = x.permute(0,2,3,4,1).contiguous() + h = self.attention(x) + x = self.proj_out(h, residual=x) + x = x.permute(0,4,1,2,3) + return x + +class Resnet3DBlock(nn.Module): + def __init__(self, + in_channels, + out_channels=None, + temb_channels=512, + conv_shortcut=False, + ): + super().__init__() + + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + + self.norm1 = nn.GroupNorm(num_groups=32, num_channels=in_channels) + self.conv1 = CausalConvAfterNorm(in_channels, out_channels, kernel_size=3) + if temb_channels > 0: + self.temb_proj = nn.Linear(temb_channels, out_channels) + + self.norm2 = nn.GroupNorm(num_groups=32, num_channels=out_channels) + self.conv2 = CausalConvAfterNorm(out_channels, out_channels, kernel_size=3) + + assert conv_shortcut is False + self.use_conv_shortcut = conv_shortcut + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + self.conv_shortcut = CausalConvAfterNorm(in_channels, out_channels, kernel_size=3) + else: + self.nin_shortcut = CausalConvAfterNorm(in_channels, out_channels, kernel_size=1) + + def forward(self, x, temb=None, is_init=True): + x = x.permute(0,2,3,4,1).contiguous() + + h = base_group_norm_with_zero_pad(x, self.norm1, 
class Downsample3D(nn.Module):
    """Downsample a 5-D tensor with a strided causal conv or with average pooling.

    Args:
        in_channels: Channel count; the convolution keeps it unchanged.
        with_conv: If true, use a ``CausalConv`` with kernel size 3 and the
            given ``stride``; otherwise use 2x2x2 average pooling.
        stride: Stride of the convolution (only used when ``with_conv`` is true).
    """

    def __init__(self,
                 in_channels,
                 with_conv,
                 stride
                 ):
        super().__init__()

        self.with_conv = with_conv
        if not with_conv:
            # Pooling path needs no parameters.
            return
        self.conv = CausalConv(in_channels, in_channels, kernel_size=3, stride=stride)

    def forward(self, x, is_init=True):
        """Return the downsampled tensor; ``is_init`` is passed to the conv only."""
        if self.with_conv:
            return self.conv(x, is_init)
        return nn.functional.avg_pool3d(x, kernel_size=2, stride=2)
self.down.append(down) + + # middle + self.mid = nn.Module() + self.mid.block_1 = Resnet3DBlock(in_channels=block_in, out_channels=block_in, temb_channels=temb_ch) + self.mid.attn_1 = AttnBlock(block_in) + self.mid.block_2 = Resnet3DBlock(in_channels=block_in, out_channels=block_in, temb_channels=temb_ch) + + # end + self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in) + self.version = version + if version == 2: + channels = 4 * z_channels * 2 ** 3 + self.conv_patchify = ConvPixelUnshuffleDownSampleLayer3D(block_in, channels, kernel_size=3, factor=2) + self.shortcut_pathify = PixelUnshuffleChannelAveragingDownSampleLayer3D(block_in, channels, 2) + self.shortcut_out = PixelUnshuffleChannelAveragingDownSampleLayer3D(channels, 2 * z_channels if double_z else z_channels, 1) + self.conv_out = CausalConvChannelLast(channels, 2 * z_channels if double_z else z_channels, kernel_size=3) + else: + self.conv_out = CausalConvAfterNorm(block_in, 2 * z_channels if double_z else z_channels, kernel_size=3) + + @torch.inference_mode() + def forward(self, x, video_frame_num, is_init=True): + # timestep embedding + temb = None + + t = video_frame_num + + # downsampling + h = self.conv_in(x, is_init) + + # make it real channel last, but behave like normal layout + h = h.permute(0,2,3,4,1).contiguous().permute(0,4,1,2,3) + + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + h = self.down[i_level].block[i_block](h, temb, is_init) + if len(self.down[i_level].attn) > 0: + h = self.down[i_level].attn[i_block](h) + + if i_level != self.num_resolutions - 1: + if isinstance(self.down[i_level].downsample, Downsample2D): + _, _, t, _, _ = h.shape + h = rearrange(h, "b c t h w -> (b t) h w c", t=t) + h = self.down[i_level].downsample(h) + h = rearrange(h, "(b t) h w c -> b c t h w", t=t) + else: + h = self.down[i_level].downsample(h, is_init) + + h = self.mid.block_1(h, temb, is_init) + h = self.mid.attn_1(h) + h = self.mid.block_2(h, temb, 
is_init) + + h = h.permute(0,2,3,4,1).contiguous() # b c l h w -> b l h w c + if self.version == 2: + h = base_group_norm(h, self.norm_out, act_silu=True, channel_last=True) + h = h.permute(0,4,1,2,3).contiguous() + shortcut = self.shortcut_pathify(h, is_init) + h = self.conv_patchify(h, is_init) + h = h.add_(shortcut) + shortcut = self.shortcut_out(h, is_init).permute(0,2,3,4,1) + h = self.conv_out(h.permute(0,2,3,4,1).contiguous(), is_init) + h = h.add_(shortcut) + else: + h = base_group_norm_with_zero_pad(h, self.norm_out, act_silu=True, pad_size=2) + h = self.conv_out(h, is_init) + h = h.permute(0,4,1,2,3) # b l h w c -> b c l h w + + h = rearrange(h, "b c t h w -> b t c h w") + return h + + +class Res3DBlockUpsample(nn.Module): + def __init__(self, + input_filters, + num_filters, + down_sampling_stride, + down_sampling=False + ): + super().__init__() + + self.input_filters = input_filters + self.num_filters = num_filters + + self.act_ = nn.SiLU(inplace=True) + + self.conv1 = CausalConvChannelLast(num_filters, num_filters, kernel_size=[3, 3, 3]) + self.norm1 = nn.GroupNorm(32, num_filters) + + self.conv2 = CausalConvChannelLast(num_filters, num_filters, kernel_size=[3, 3, 3]) + self.norm2 = nn.GroupNorm(32, num_filters) + + self.down_sampling = down_sampling + if down_sampling: + self.down_sampling_stride = down_sampling_stride + else: + self.down_sampling_stride = [1, 1, 1] + + if num_filters != input_filters or down_sampling: + self.conv3 = CausalConvChannelLast(input_filters, num_filters, kernel_size=[1, 1, 1], stride=self.down_sampling_stride) + self.norm3 = nn.GroupNorm(32, num_filters) + + def forward(self, x, is_init=False): + x = x.permute(0,2,3,4,1).contiguous() + + residual = x + + h = self.conv1(x, is_init) + h = base_group_norm(h, self.norm1, act_silu=True, channel_last=True) + + h = self.conv2(h, is_init) + h = base_group_norm(h, self.norm2, act_silu=False, channel_last=True) + + if self.down_sampling or self.num_filters != self.input_filters: + x 
class Upsample3D(nn.Module):
    """Interpolate a ``(B, C, T, H, W)`` tensor by ``scale_factor``, then refine it.

    The refinement is a ``Res3DBlockUpsample`` that keeps the channel count and
    uses stride ``(1, 1, 1)``.
    """

    def __init__(self,
                 in_channels,
                 scale_factor=2
                 ):
        super().__init__()

        self.scale_factor = scale_factor
        self.conv3d = Res3DBlockUpsample(
            input_filters=in_channels,
            num_filters=in_channels,
            down_sampling_stride=(1, 1, 1),
            down_sampling=False,
        )

    def forward(self, x, is_init=True, is_split=True):
        """Upsample ``x`` and pass it through the residual block.

        Args:
            x: 5-D input tensor.
            is_init: Forwarded to the residual block.
            is_split: If true, interpolate the channels in chunks of
                ``C // 8`` and concatenate the results instead of interpolating
                the whole tensor in one call.
        """
        _, channels, _, _, _ = x.shape
        scale = self.scale_factor

        if not is_split:
            upsampled = nn.functional.interpolate(x, scale_factor=scale)
        else:
            chunk = channels // 8
            pieces = []
            for part in torch.split(x, chunk, dim=1):
                pieces.append(nn.functional.interpolate(part, scale_factor=scale))
            upsampled = torch.cat(pieces, dim=1)

        return self.conv3d(upsampled, is_init)
self.mid = nn.Module() + self.mid.block_1 = Resnet3DBlock(in_channels=block_in, out_channels=block_in, temb_channels=temb_ch) + self.mid.attn_1 = AttnBlock(block_in) + self.mid.block_2 = Resnet3DBlock(in_channels=block_in, out_channels=block_in, temb_channels=temb_ch) + + # upsampling + self.up_id = len(temporal_up_layers) + self.video_frame_num = 1 + self.cur_video_frame_num = self.video_frame_num // 2 ** self.up_id + 1 + self.up = nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = nn.ModuleList() + attn = nn.ModuleList() + block_out = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks + 1): + block.append( + Resnet3DBlock(in_channels=block_in, out_channels=block_out, temb_channels=temb_ch)) + block_in = block_out + up = nn.Module() + up.block = block + up.attn = attn + if i_level != 0: + if i_level in temporal_up_layers: + up.upsample = Upsample3D(block_in) + self.cur_video_frame_num = self.cur_video_frame_num * 2 + else: + up.upsample = Upsample2D(block_in, resamp_with_conv) + self.up.insert(0, up) # prepend to get consistent order + + # end + self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in) + self.conv_out = CausalConvAfterNorm(block_in, out_channels, kernel_size=3) + + @torch.inference_mode() + def forward(self, z, is_init=True): + z = rearrange(z, "b t c h w -> b c t h w") + + h = self.conv_in(z, is_init=is_init) + if self.version == 2: + shortcut = self.shortcut_in(z, is_init=is_init) + h = h.add_(shortcut) + shortcut = self.shortcut_unpathify(h, is_init=is_init) + h = self.conv_unpatchify(h, is_init=is_init) + h = h.add_(shortcut) + + temb = None + + h = h.permute(0,2,3,4,1).contiguous().permute(0,4,1,2,3) + h = self.mid.block_1(h, temb, is_init=is_init) + h = self.mid.attn_1(h) + h = h.permute(0,2,3,4,1).contiguous().permute(0,4,1,2,3) + h = self.mid.block_2(h, temb, is_init=is_init) + + # upsampling + for i_level in reversed(range(self.num_resolutions)): + for i_block in 
def rms_norm(input, normalized_shape, eps=1e-6):
    """Scale ``input`` by the reciprocal root-mean-square of its trailing dimensions.

    The mean of squares is taken over the last ``len(normalized_shape)``
    dimensions; the computation runs in float32 and the result is cast back to
    the dtype of ``input``. No learnable scale is applied.

    Args:
        input: Tensor to normalize.
        normalized_shape: Only its length is used; it selects how many trailing
            dimensions are normalized together.
        eps: Value added to the mean square before the reciprocal square root.

    Returns:
        The normalized tensor with the same shape and dtype as ``input``.
    """
    source_dtype = input.dtype
    n_trailing = len(normalized_shape)

    values = input.to(torch.float32)
    mean_square = values.pow(2).flatten(-n_trailing).mean(-1)
    # Restore the reduced trailing axes as size-1 axes so broadcasting works.
    mean_square = mean_square.reshape(tuple(mean_square.shape) + (1,) * n_trailing)

    normalized = values * torch.rsqrt(mean_square + eps)
    return normalized.to(source_dtype)
same device + # as the parameters and has same dtype + sample = torch.randn( + self.mean.shape, generator=generator, device=self.parameters.device) + sample = sample.to(dtype=self.parameters.dtype) + x = self.mean + self.std * sample + if self.only_return_mean: + return self.mean + else: + return x + + +class StepVideoVAE(nn.Module): + def __init__(self, + in_channels=3, + out_channels=3, + z_channels=64, + num_res_blocks=2, + model_path=None, + weight_dict={}, + world_size=1, + version=2, + ): + super().__init__() + + self.frame_len = 17 + self.latent_len = 3 if version == 2 else 5 + + base_group_norm.spatial = True if version == 2 else False + + self.encoder = VideoEncoder( + in_channels=in_channels, + z_channels=z_channels, + num_res_blocks=num_res_blocks, + version=version, + ) + + self.decoder = VideoDecoder( + z_channels=z_channels, + out_channels=out_channels, + num_res_blocks=num_res_blocks, + version=version, + ) + + if model_path is not None: + weight_dict = self.init_from_ckpt(model_path) + if len(weight_dict) != 0: + self.load_from_dict(weight_dict) + self.convert_channel_last() + + self.world_size = world_size + + def init_from_ckpt(self, model_path): + from safetensors import safe_open + p = {} + with safe_open(model_path, framework="pt", device="cpu") as f: + for k in f.keys(): + tensor = f.get_tensor(k) + if k.startswith("decoder.conv_out."): + k = k.replace("decoder.conv_out.", "decoder.conv_out.conv.") + p[k] = tensor + return p + + def load_from_dict(self, p): + self.load_state_dict(p) + + def convert_channel_last(self): + #Conv2d NCHW->NHWC + pass + + def naive_encode(self, x, is_init_image=True): + b, l, c, h, w = x.size() + x = rearrange(x, 'b l c h w -> b c l h w').contiguous() + z = self.encoder(x, l, True) # 下采样[1, 4, 8, 16, 16] + return z + + @torch.inference_mode() + def encode(self, x): + # b (nc cf) c h w -> (b nc) cf c h w -> encode -> (b nc) cf c h w -> b (nc cf) c h w + chunks = list(x.split(self.frame_len, dim=1)) + for i in 
range(len(chunks)): + chunks[i] = self.naive_encode(chunks[i], True) + z = torch.cat(chunks, dim=1) + + posterior = DiagonalGaussianDistribution(z) + return posterior.sample() + + def decode_naive(self, z, is_init=True): + z = z.to(next(self.decoder.parameters()).dtype) + dec = self.decoder(z, is_init) + return dec + + @torch.inference_mode() + def decode(self, z): + # b (nc cf) c h w -> (b nc) cf c h w -> decode -> (b nc) c cf h w -> b (nc cf) c h w + chunks = list(z.split(self.latent_len, dim=1)) + + if self.world_size > 1: + chunks_total_num = len(chunks) + max_num_per_rank = (chunks_total_num + self.world_size - 1) // self.world_size + rank = torch.distributed.get_rank() + chunks_ = chunks[max_num_per_rank * rank : max_num_per_rank * (rank + 1)] + if len(chunks_) < max_num_per_rank: + chunks_.extend(chunks[:max_num_per_rank-len(chunks_)]) + chunks = chunks_ + + for i in range(len(chunks)): + chunks[i] = self.decode_naive(chunks[i], True).permute(0,2,1,3,4) + x = torch.cat(chunks, dim=1) + + if self.world_size > 1: + x_ = torch.empty([x.size(0), (self.world_size * max_num_per_rank) * self.frame_len, *x.shape[2:]], dtype=x.dtype, device=x.device) + torch.distributed.all_gather_into_tensor(x_, x) + x = x_[:, : chunks_total_num * self.frame_len] + + x = self.mix(x) + return x + + def mix(self, x): + remain_scale = 0.6 + mix_scale = 1. 
class StepVideoVAEStateDictConverter:
    """Rename checkpoint keys so they match the module layout of ``StepVideoVAE``.

    Keys under ``decoder.conv_out.`` are moved under ``decoder.conv_out.conv.``;
    all other keys are passed through unchanged.
    """

    _SOURCE_PREFIX = "decoder.conv_out."
    _TARGET_PREFIX = "decoder.conv_out.conv."

    def __init__(self):
        super().__init__()

    def from_diffusers(self, state_dict):
        """Convert a diffusers-format state dict (same mapping as ``from_civitai``)."""
        return self.from_civitai(state_dict)

    def from_civitai(self, state_dict):
        """Return a new dict with the ``decoder.conv_out.`` prefix rewritten.

        Only the leading prefix is rewritten, and keys that already carry the
        target prefix are left untouched, so converting an already converted
        state dict is a no-op instead of producing ``...conv.conv.`` keys.

        Args:
            state_dict: Mapping from parameter name to tensor.

        Returns:
            A new mapping; ``state_dict`` itself is not modified.
        """
        state_dict_ = {}
        for name, param in state_dict.items():
            needs_rename = (
                name.startswith(self._SOURCE_PREFIX)
                and not name.startswith(self._TARGET_PREFIX)
            )
            if needs_rename:
                name_ = self._TARGET_PREFIX + name[len(self._SOURCE_PREFIX):]
            else:
                name_ = name
            state_dict_[name_] = param
        return state_dict_
..vram_management import enable_vram_management, AutoWrappedModule, AutoWrappedLinear +from transformers.models.bert.modeling_bert import BertEmbeddings +from ..models.stepvideo_dit import RMSNorm +from ..models.stepvideo_vae import CausalConv, CausalConvAfterNorm, Resnet3DBlock, AttnBlock, Res3DBlockUpsample, Upsample2D + + + +class StepVideoPipeline(BasePipeline): + + def __init__(self, device="cuda", torch_dtype=torch.float16): + super().__init__(device=device, torch_dtype=torch_dtype) + self.scheduler = FlowMatchScheduler(sigma_min=0.0, extra_one_step=True, shift=13.0, reverse_sigmas=True, num_train_timesteps=1) + self.prompter = StepVideoPrompter() + self.text_encoder_1: HunyuanDiTCLIPTextEncoder = None + self.text_encoder_2: STEP1TextEncoder = None + self.dit: StepVideoModel = None + self.vae: StepVideoVAE = None + self.model_names = ['text_encoder_1', 'text_encoder_2', 'dit', 'vae'] + + + def enable_vram_management(self, num_persistent_param_in_dit=None): + dtype = next(iter(self.text_encoder_1.parameters())).dtype + enable_vram_management( + self.text_encoder_1, + module_map = { + torch.nn.Linear: AutoWrappedLinear, + BertEmbeddings: AutoWrappedModule, + torch.nn.LayerNorm: AutoWrappedModule, + }, + module_config = dict( + offload_dtype=dtype, + offload_device="cpu", + onload_dtype=dtype, + onload_device="cpu", + computation_dtype=torch.float32, + computation_device=self.device, + ), + ) + dtype = next(iter(self.text_encoder_2.parameters())).dtype + enable_vram_management( + self.text_encoder_2, + module_map = { + torch.nn.Linear: AutoWrappedLinear, + RMSNorm: AutoWrappedModule, + torch.nn.Embedding: AutoWrappedModule, + }, + module_config = dict( + offload_dtype=dtype, + offload_device="cpu", + onload_dtype=dtype, + onload_device="cpu", + computation_dtype=self.torch_dtype, + computation_device=self.device, + ), + ) + dtype = next(iter(self.dit.parameters())).dtype + enable_vram_management( + self.dit, + module_map = { + torch.nn.Linear: AutoWrappedLinear, 
+ torch.nn.Conv2d: AutoWrappedModule, + torch.nn.LayerNorm: AutoWrappedModule, + RMSNorm: AutoWrappedModule, + }, + module_config = dict( + offload_dtype=dtype, + offload_device="cpu", + onload_dtype=dtype, + onload_device=self.device, + computation_dtype=self.torch_dtype, + computation_device=self.device, + ), + max_num_param=num_persistent_param_in_dit, + overflow_module_config = dict( + offload_dtype=dtype, + offload_device="cpu", + onload_dtype=dtype, + onload_device="cpu", + computation_dtype=self.torch_dtype, + computation_device=self.device, + ), + ) + dtype = next(iter(self.vae.parameters())).dtype + enable_vram_management( + self.vae, + module_map = { + torch.nn.Linear: AutoWrappedLinear, + torch.nn.Conv3d: AutoWrappedModule, + CausalConv: AutoWrappedModule, + CausalConvAfterNorm: AutoWrappedModule, + Resnet3DBlock: AutoWrappedModule, + AttnBlock: AutoWrappedModule, + Res3DBlockUpsample: AutoWrappedModule, + Upsample2D: AutoWrappedModule, + }, + module_config = dict( + offload_dtype=dtype, + offload_device="cpu", + onload_dtype=dtype, + onload_device="cpu", + computation_dtype=self.torch_dtype, + computation_device=self.device, + ), + ) + self.enable_cpu_offload() + + + def fetch_models(self, model_manager: ModelManager): + self.text_encoder_1 = model_manager.fetch_model("hunyuan_dit_clip_text_encoder") + self.text_encoder_2 = model_manager.fetch_model("stepvideo_text_encoder_2") + self.dit = model_manager.fetch_model("stepvideo_dit") + self.vae = model_manager.fetch_model("stepvideo_vae") + self.prompter.fetch_models(self.text_encoder_1, self.text_encoder_2) + + + @staticmethod + def from_model_manager(model_manager: ModelManager, torch_dtype=None, device=None): + if device is None: device = model_manager.device + if torch_dtype is None: torch_dtype = model_manager.torch_dtype + pipe = StepVideoPipeline(device=device, torch_dtype=torch_dtype) + pipe.fetch_models(model_manager) + return pipe + + + def encode_prompt(self, prompt, positive=True): + 
clip_embeds, llm_embeds, llm_mask = self.prompter.encode_prompt(prompt, device=self.device, positive=positive) + clip_embeds = clip_embeds.to(dtype=self.torch_dtype, device=self.device) + llm_embeds = llm_embeds.to(dtype=self.torch_dtype, device=self.device) + llm_mask = llm_mask.to(dtype=self.torch_dtype, device=self.device) + return {"encoder_hidden_states_2": clip_embeds, "encoder_hidden_states": llm_embeds, "encoder_attention_mask": llm_mask} + + + def tensor2video(self, frames): + frames = rearrange(frames, "T C H W -> T H W C") + frames = ((frames.float() + 1) * 127.5).clip(0, 255).cpu().numpy().astype(np.uint8) + frames = [Image.fromarray(frame) for frame in frames] + return frames + + + @torch.no_grad() + def __call__( + self, + prompt, + negative_prompt="", + input_video=None, + denoising_strength=1.0, + seed=None, + rand_device="cpu", + height=544, + width=992, + num_frames=204, + cfg_scale=9.0, + num_inference_steps=30, + progress_bar_cmd=lambda x: x, + progress_bar_st=None, + ): + # Scheduler + self.scheduler.set_timesteps(num_inference_steps, denoising_strength) + + # Initialize noise + latents = self.generate_noise((1, max(num_frames//17*3, 1), 64, height//16, width//16), seed=seed, device=rand_device, dtype=self.torch_dtype).to(self.device) + + # Encode prompts + self.load_models_to_device(["text_encoder_1", "text_encoder_2"]) + prompt_emb_posi = self.encode_prompt(prompt, positive=True) + if cfg_scale != 1.0: + prompt_emb_nega = self.encode_prompt(negative_prompt, positive=False) + + # Denoise + self.load_models_to_device(["dit"]) + for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)): + timestep = timestep.unsqueeze(0).to(self.device) + print(f"Step {progress_id + 1} / {len(self.scheduler.timesteps)}") + + # Inference + noise_pred_posi = self.dit(latents, timestep=timestep, **prompt_emb_posi) + if cfg_scale != 1.0: + noise_pred_nega = self.dit(latents, timestep=timestep, **prompt_emb_nega) + noise_pred = 
class StepVideoPrompter(BasePrompter):
    """Encode prompts for StepVideo with a CLIP-style encoder and an LLM encoder."""

    def __init__(
        self,
        tokenizer_1_path=None,
    ):
        # Fall back to the tokenizer config located relative to this package.
        if tokenizer_1_path is None:
            package_root = os.path.dirname(os.path.dirname(__file__))
            tokenizer_1_path = os.path.join(
                package_root, "tokenizer_configs/hunyuan_dit/tokenizer")
        super().__init__()
        self.tokenizer_1 = BertTokenizer.from_pretrained(tokenizer_1_path)

    def fetch_models(self, text_encoder_1: HunyuanDiTCLIPTextEncoder = None, text_encoder_2: STEP1TextEncoder = None):
        """Attach the two text encoders used by ``encode_prompt``."""
        self.text_encoder_1, self.text_encoder_2 = text_encoder_1, text_encoder_2

    def encode_prompt_using_clip(self, prompt, max_length, device):
        """Tokenize ``prompt`` (padded/truncated to ``max_length``) and run encoder 1."""
        tokens = self.tokenizer_1(
            prompt,
            padding="max_length",
            max_length=max_length,
            truncation=True,
            return_attention_mask=True,
            return_tensors="pt",
        )
        input_ids = tokens.input_ids.to(device)
        attention_mask = tokens.attention_mask.to(device)
        return self.text_encoder_1(input_ids, attention_mask=attention_mask)

    def encode_prompt_using_llm(self, prompt, max_length, device):
        """Run encoder 2 on the raw prompt and return ``(embeddings, mask)``."""
        embeds, mask = self.text_encoder_2(prompt, max_length=max_length, device=device)
        return embeds, mask

    def encode_prompt(self,
                      prompt,
                      positive=True,
                      device="cuda"):
        """Return ``(clip_embeds, llm_embeds, llm_mask)`` for ``prompt``.

        The LLM mask is left-padded with ones, one per position along
        dimension 1 of the CLIP embeddings.
        """
        processed = self.process_prompt(prompt, positive=positive)

        clip_embeds = self.encode_prompt_using_clip(processed, max_length=77, device=device)
        llm_embeds, llm_mask = self.encode_prompt_using_llm(processed, max_length=320, device=device)

        clip_length = clip_embeds.shape[1]
        llm_mask = torch.nn.functional.pad(llm_mask, (clip_length, 0), value=1)

        return clip_embeds, llm_embeds, llm_mask
self.timesteps = self.sigmas * self.num_train_timesteps if training: x = self.timesteps @@ -38,7 +41,7 @@ class FlowMatchScheduler(): timestep_id = torch.argmin((self.timesteps - timestep).abs()) sigma = self.sigmas[timestep_id] if to_final or timestep_id + 1 >= len(self.timesteps): - sigma_ = 1 if self.inverse_timesteps else 0 + sigma_ = 1 if (self.inverse_timesteps or self.reverse_sigmas) else 0 else: sigma_ = self.sigmas[timestep_id + 1] prev_sample = sample + model_output * (sigma_ - sigma) diff --git a/examples/stepvideo/README.md b/examples/stepvideo/README.md new file mode 100644 index 0000000..8b14ae8 --- /dev/null +++ b/examples/stepvideo/README.md @@ -0,0 +1,13 @@ +# StepVideo + +StepVideo is a state-of-the-art (SoTA) text-to-video pre-trained model with 30 billion parameters and the capability to generate videos of up to 204 frames. + +* Model: https://modelscope.cn/models/stepfun-ai/stepvideo-t2v/summary +* GitHub: https://github.com/stepfun-ai/Step-Video-T2V +* Technical report: https://arxiv.org/abs/2502.10248 + +## Examples + +See [`./stepvideo_text_to_video.py`](./stepvideo_text_to_video.py). + +https://github.com/user-attachments/assets/5954fdaa-a3cf-45a3-bd35-886e3cc4581b diff --git a/examples/stepvideo/stepvideo_text_to_video.py b/examples/stepvideo/stepvideo_text_to_video.py new file mode 100644 index 0000000..aaa2b16 --- /dev/null +++ b/examples/stepvideo/stepvideo_text_to_video.py @@ -0,0 +1,47 @@ +from modelscope import snapshot_download +from diffsynth import ModelManager, StepVideoPipeline, save_video +import torch + + +# Download models +snapshot_download(model_id="stepfun-ai/stepvideo-t2v", cache_dir="models") + +# Load the compiled attention for the LLM text encoder. +# If you encounter errors here, please select another compiled file that matches your environment, or delete this line.
torch.ops.load_library("models/stepfun-ai/stepvideo-t2v/lib/liboptimus_ths-torch2.5-cu124.cpython-310-x86_64-linux-gnu.so")

# All checkpoints are read from one root so the paths cannot drift apart.
# The CLIP text encoder previously pointed at "models/stepvideo-t2v/...",
# which lacks the "stepfun-ai/" component used by every other path here.
model_dir = "models/stepfun-ai/stepvideo-t2v"

# The DiT weights are split into numbered shards: 00001-of-00006 ... 00006-of-00006.
num_dit_shards = 6
dit_shard_paths = [
    f"{model_dir}/transformer/diffusion_pytorch_model-{index:05d}-of-{num_dit_shards:05d}.safetensors"
    for index in range(1, num_dit_shards + 1)
]

# Load models
model_manager = ModelManager()
# The CLIP text encoder is loaded in float32, separately from the bfloat16 models.
model_manager.load_models(
    [f"{model_dir}/hunyuan_clip/clip_text_encoder/pytorch_model.bin"],
    torch_dtype=torch.float32, device="cpu"
)
# The nested list hands all DiT shards to the loader as a single entry.
model_manager.load_models(
    [
        f"{model_dir}/step_llm",
        f"{model_dir}/vae/vae_v2.safetensors",
        dit_shard_paths,
    ],
    torch_dtype=torch.bfloat16, device="cpu"
)
pipe = StepVideoPipeline.from_model_manager(model_manager, torch_dtype=torch.bfloat16, device="cuda")

# Enable VRAM management
# This model requires 80G VRAM.
# To reduce the VRAM required, set `num_persistent_param_in_dit` to a small number.
pipe.enable_vram_management(num_persistent_param_in_dit=None)

# Run!
video = pipe(
    prompt="一名宇航员在月球上发现一块石碑,上面印有“stepfun”字样,闪闪发光。超高清、HDR 视频、环境光、杜比全景声、画面稳定、流畅动作、逼真的细节、专业级构图、超现实主义、自然、生动、超细节、清晰。",
    negative_prompt="画面暗、低分辨率、不良手、文本、缺少手指、多余的手指、裁剪、低质量、颗粒状、签名、水印、用户名、模糊。",
    num_inference_steps=30, cfg_scale=9, num_frames=204, seed=1
)
save_video(video, "video.mp4", fps=25, quality=5)