diff --git a/diffsynth/data/video.py b/diffsynth/data/video.py index 16e1918..8eafa66 100644 --- a/diffsynth/data/video.py +++ b/diffsynth/data/video.py @@ -135,8 +135,8 @@ class VideoData: frame.save(os.path.join(folder, f"{i}.png")) -def save_video(frames, save_path, fps, quality=9): - writer = imageio.get_writer(save_path, fps=fps, quality=quality) +def save_video(frames, save_path, fps, quality=9, ffmpeg_params=None): + writer = imageio.get_writer(save_path, fps=fps, quality=quality, ffmpeg_params=ffmpeg_params) for frame in tqdm(frames, desc="Saving video"): frame = np.array(frame) writer.append_data(frame) diff --git a/diffsynth/models/stepvideo_vae.py b/diffsynth/models/stepvideo_vae.py index ba46cac..db244c0 100644 --- a/diffsynth/models/stepvideo_vae.py +++ b/diffsynth/models/stepvideo_vae.py @@ -14,6 +14,19 @@ import torch from einops import rearrange from torch import nn from torch.nn import functional as F +from tqdm import tqdm +from einops import repeat + + +class BaseGroupNorm(nn.GroupNorm): + def __init__(self, num_groups, num_channels): + super().__init__(num_groups=num_groups, num_channels=num_channels) + + def forward(self, x, zero_pad=False, **kwargs): + if zero_pad: + return base_group_norm_with_zero_pad(x, self, **kwargs) + else: + return base_group_norm(x, self, **kwargs) def base_group_norm(x, norm_layer, act_silu=False, channel_last=False): @@ -456,14 +469,14 @@ class AttnBlock(nn.Module): ): super().__init__() - self.norm = nn.GroupNorm(num_groups=32, num_channels=in_channels) + self.norm = BaseGroupNorm(num_groups=32, num_channels=in_channels) self.q = CausalConvChannelLast(in_channels, in_channels, kernel_size=1) self.k = CausalConvChannelLast(in_channels, in_channels, kernel_size=1) self.v = CausalConvChannelLast(in_channels, in_channels, kernel_size=1) self.proj_out = CausalConvChannelLast(in_channels, in_channels, kernel_size=1) def attention(self, x, is_init=True): - x = base_group_norm(x, self.norm, act_silu=False, 
channel_last=True) + x = self.norm(x, act_silu=False, channel_last=True) q = self.q(x, is_init) k = self.k(x, is_init) v = self.v(x, is_init) @@ -495,12 +508,12 @@ class Resnet3DBlock(nn.Module): out_channels = in_channels if out_channels is None else out_channels self.out_channels = out_channels - self.norm1 = nn.GroupNorm(num_groups=32, num_channels=in_channels) + self.norm1 = BaseGroupNorm(num_groups=32, num_channels=in_channels) self.conv1 = CausalConvAfterNorm(in_channels, out_channels, kernel_size=3) if temb_channels > 0: self.temb_proj = nn.Linear(temb_channels, out_channels) - self.norm2 = nn.GroupNorm(num_groups=32, num_channels=out_channels) + self.norm2 = BaseGroupNorm(num_groups=32, num_channels=out_channels) self.conv2 = CausalConvAfterNorm(out_channels, out_channels, kernel_size=3) assert conv_shortcut is False @@ -514,14 +527,14 @@ class Resnet3DBlock(nn.Module): def forward(self, x, temb=None, is_init=True): x = x.permute(0,2,3,4,1).contiguous() - h = base_group_norm_with_zero_pad(x, self.norm1, act_silu=True, pad_size=2) + h = self.norm1(x, zero_pad=True, act_silu=True, pad_size=2) h = self.conv1(h) if temb is not None: h = h + self.temb_proj(nn.functional.silu(temb))[:, :, None, None] x = self.nin_shortcut(x) if self.in_channels != self.out_channels else x - h = base_group_norm_with_zero_pad(h, self.norm2, act_silu=True, pad_size=2) + h = self.norm2(h, zero_pad=True, act_silu=True, pad_size=2) x = self.conv2(h, residual=x) x = x.permute(0,4,1,2,3) @@ -675,10 +688,10 @@ class Res3DBlockUpsample(nn.Module): self.act_ = nn.SiLU(inplace=True) self.conv1 = CausalConvChannelLast(num_filters, num_filters, kernel_size=[3, 3, 3]) - self.norm1 = nn.GroupNorm(32, num_filters) + self.norm1 = BaseGroupNorm(32, num_filters) self.conv2 = CausalConvChannelLast(num_filters, num_filters, kernel_size=[3, 3, 3]) - self.norm2 = nn.GroupNorm(32, num_filters) + self.norm2 = BaseGroupNorm(32, num_filters) self.down_sampling = down_sampling if down_sampling: @@ -688,7 
+701,7 @@ class Res3DBlockUpsample(nn.Module): if num_filters != input_filters or down_sampling: self.conv3 = CausalConvChannelLast(input_filters, num_filters, kernel_size=[1, 1, 1], stride=self.down_sampling_stride) - self.norm3 = nn.GroupNorm(32, num_filters) + self.norm3 = BaseGroupNorm(32, num_filters) def forward(self, x, is_init=False): x = x.permute(0,2,3,4,1).contiguous() @@ -696,14 +709,14 @@ class Res3DBlockUpsample(nn.Module): residual = x h = self.conv1(x, is_init) - h = base_group_norm(h, self.norm1, act_silu=True, channel_last=True) + h = self.norm1(h, act_silu=True, channel_last=True) h = self.conv2(h, is_init) - h = base_group_norm(h, self.norm2, act_silu=False, channel_last=True) + h = self.norm2(h, act_silu=False, channel_last=True) if self.down_sampling or self.num_filters != self.input_filters: x = self.conv3(x, is_init) - x = base_group_norm(x, self.norm3, act_silu=False, channel_last=True) + x = self.norm3(x, act_silu=False, channel_last=True) h.add_(x) h = self.act_(h) @@ -973,7 +986,7 @@ class StepVideoVAE(nn.Module): return dec @torch.inference_mode() - def decode(self, z): + def decode_original(self, z): # b (nc cf) c h w -> (b nc) cf c h w -> decode -> (b nc) c cf h w -> b (nc cf) c h w chunks = list(z.split(self.latent_len, dim=1)) @@ -998,15 +1011,104 @@ class StepVideoVAE(nn.Module): x = self.mix(x) return x - def mix(self, x): - remain_scale = 0.6 + def mix(self, x, smooth_scale = 0.6): + remain_scale = smooth_scale mix_scale = 1. 
- remain_scale front = slice(self.frame_len - 1, x.size(1) - 1, self.frame_len) back = slice(self.frame_len, x.size(1), self.frame_len) - x[:, back] = x[:, back] * remain_scale + x[:, front] * mix_scale - x[:, front] = x[:, front] * remain_scale + x[:, back] * mix_scale + x[:, front], x[:, back] = ( + x[:, front] * remain_scale + x[:, back] * mix_scale, + x[:, back] * remain_scale + x[:, front] * mix_scale + ) return x + def single_decode(self, hidden_states, device): + chunks = list(hidden_states.split(self.latent_len, dim=1)) + for i in range(len(chunks)): + chunks[i] = self.decode_naive(chunks[i].to(device), True).permute(0,2,1,3,4).cpu() + x = torch.cat(chunks, dim=1) + return x + + def build_1d_mask(self, length, left_bound, right_bound, border_width): + x = torch.ones((length,)) + if not left_bound: + x[:border_width] = (torch.arange(border_width) + 1) / border_width + if not right_bound: + x[-border_width:] = torch.flip((torch.arange(border_width) + 1) / border_width, dims=(0,)) + return x + + def build_mask(self, data, is_bound, border_width): + _, _, _, H, W = data.shape + h = self.build_1d_mask(H, is_bound[0], is_bound[1], border_width[0]) + w = self.build_1d_mask(W, is_bound[2], is_bound[3], border_width[1]) + + h = repeat(h, "H -> H W", H=H, W=W) + w = repeat(w, "W -> H W", H=H, W=W) + + mask = torch.stack([h, w]).min(dim=0).values + mask = rearrange(mask, "H W -> 1 1 1 H W") + return mask + + def tiled_decode(self, hidden_states, device, tile_size=(34, 34), tile_stride=(16, 16)): + B, T, C, H, W = hidden_states.shape + size_h, size_w = tile_size + stride_h, stride_w = tile_stride + + # Split tasks + tasks = [] + for t in range(0, T, 3): + for h in range(0, H, stride_h): + if (h-stride_h >= 0 and h-stride_h+size_h >= H): continue + for w in range(0, W, stride_w): + if (w-stride_w >= 0 and w-stride_w+size_w >= W): continue + t_, h_, w_ = t + 3, h + size_h, w + size_w + tasks.append((t, t_, h, h_, w, w_)) + + # Run + data_device = "cpu" + 
computation_device = device + + weight = torch.zeros((1, 1, T//3*17, H * 16, W * 16), dtype=hidden_states.dtype, device=data_device) + values = torch.zeros((B, 3, T//3*17, H * 16, W * 16), dtype=hidden_states.dtype, device=data_device) + + for t, t_, h, h_, w, w_ in tqdm(tasks, desc="VAE decoding"): + hidden_states_batch = hidden_states[:, t:t_, :, h:h_, w:w_].to(computation_device) + hidden_states_batch = self.decode_naive(hidden_states_batch, True).to(data_device) + + mask = self.build_mask( + hidden_states_batch, + is_bound=(h==0, h_>=H, w==0, w_>=W), + border_width=((size_h - stride_h) * 16, (size_w - stride_w) * 16) + ).to(dtype=hidden_states.dtype, device=data_device) + + target_t = t // 3 * 17 + target_h = h * 16 + target_w = w * 16 + values[ + :, + :, + target_t: target_t + hidden_states_batch.shape[2], + target_h: target_h + hidden_states_batch.shape[3], + target_w: target_w + hidden_states_batch.shape[4], + ] += hidden_states_batch * mask + weight[ + :, + :, + target_t: target_t + hidden_states_batch.shape[2], + target_h: target_h + hidden_states_batch.shape[3], + target_w: target_w + hidden_states_batch.shape[4], + ] += mask + return values / weight + + def decode(self, hidden_states, device, tiled=False, tile_size=(34, 34), tile_stride=(16, 16), smooth_scale=0.6): + hidden_states = hidden_states.to("cpu") + if tiled: + video = self.tiled_decode(hidden_states, device, tile_size, tile_stride) + else: + video = self.single_decode(hidden_states, device) + video = self.mix(video, smooth_scale=smooth_scale) + return video + @staticmethod def state_dict_converter(): return StepVideoVAEStateDictConverter() diff --git a/diffsynth/pipelines/step_video.py b/diffsynth/pipelines/step_video.py index c2dd463..5614017 100644 --- a/diffsynth/pipelines/step_video.py +++ b/diffsynth/pipelines/step_video.py @@ -13,7 +13,7 @@ from PIL import Image from ..vram_management import enable_vram_management, AutoWrappedModule, AutoWrappedLinear from 
transformers.models.bert.modeling_bert import BertEmbeddings from ..models.stepvideo_dit import RMSNorm -from ..models.stepvideo_vae import CausalConv, CausalConvAfterNorm, Resnet3DBlock, AttnBlock, Res3DBlockUpsample, Upsample2D +from ..models.stepvideo_vae import CausalConv, CausalConvAfterNorm, Upsample2D, BaseGroupNorm @@ -100,10 +100,8 @@ class StepVideoPipeline(BasePipeline): torch.nn.Conv3d: AutoWrappedModule, CausalConv: AutoWrappedModule, CausalConvAfterNorm: AutoWrappedModule, - Resnet3DBlock: AutoWrappedModule, - AttnBlock: AutoWrappedModule, - Res3DBlockUpsample: AutoWrappedModule, Upsample2D: AutoWrappedModule, + BaseGroupNorm: AutoWrappedModule, }, module_config = dict( offload_dtype=dtype, @@ -143,7 +141,7 @@ class StepVideoPipeline(BasePipeline): def tensor2video(self, frames): - frames = rearrange(frames, "T C H W -> T H W C") + frames = rearrange(frames, "C T H W -> T H W C") frames = ((frames.float() + 1) * 127.5).clip(0, 255).cpu().numpy().astype(np.uint8) frames = [Image.fromarray(frame) for frame in frames] return frames @@ -163,9 +161,16 @@ class StepVideoPipeline(BasePipeline): num_frames=204, cfg_scale=9.0, num_inference_steps=30, + tiled=True, + tile_size=(34, 34), + tile_stride=(16, 16), + smooth_scale=0.6, progress_bar_cmd=lambda x: x, progress_bar_st=None, ): + # Tiler parameters + tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride} + # Scheduler self.scheduler.set_timesteps(num_inference_steps, denoising_strength) @@ -197,7 +202,7 @@ class StepVideoPipeline(BasePipeline): # Decode self.load_models_to_device(['vae']) - frames = self.vae.decode(latents) + frames = self.vae.decode(latents, device=self.device, smooth_scale=smooth_scale, **tiler_kwargs) self.load_models_to_device([]) frames = self.tensor2video(frames[0]) diff --git a/examples/stepvideo/README.md b/examples/stepvideo/README.md index 261dcdb..ad2f404 100644 --- a/examples/stepvideo/README.md +++ b/examples/stepvideo/README.md @@ -10,6 +10,8 @@ 
StepVideo is a state-of-the-art (SoTA) text-to-video pre-trained model with 30 b For original BF16 version, please see [`./stepvideo_text_to_video.py`](./stepvideo_text_to_video.py). 80G VRAM required. +We also support auto-offload, which can reduce the VRAM requirement to **24GB**; however, it requires 2x time for inference. Please see [`./stepvideo_text_to_video_low_vram.py`](./stepvideo_text_to_video_low_vram.py). + https://github.com/user-attachments/assets/5954fdaa-a3cf-45a3-bd35-886e3cc4581b For FP8 quantized version, please see [`./stepvideo_text_to_video_quantized.py`](./stepvideo_text_to_video_quantized.py). 40G VRAM required. diff --git a/examples/stepvideo/stepvideo_text_to_video.py b/examples/stepvideo/stepvideo_text_to_video.py index 937f5d1..302ed08 100644 --- a/examples/stepvideo/stepvideo_text_to_video.py +++ b/examples/stepvideo/stepvideo_text_to_video.py @@ -44,4 +44,7 @@ video = pipe( negative_prompt="画面暗、低分辨率、不良手、文本、缺少手指、多余的手指、裁剪、低质量、颗粒状、签名、水印、用户名、模糊。", num_inference_steps=30, cfg_scale=9, num_frames=51, seed=1 ) -save_video(video, "video.mp4", fps=25, quality=5) +save_video( + video, "video.mp4", fps=25, quality=5, + ffmpeg_params=["-vf", "atadenoise=0a=0.1:0b=0.1:1a=0.1:1b=0.1"] +) diff --git a/examples/stepvideo/stepvideo_text_to_video_low_vram.py b/examples/stepvideo/stepvideo_text_to_video_low_vram.py new file mode 100644 index 0000000..95d4557 --- /dev/null +++ b/examples/stepvideo/stepvideo_text_to_video_low_vram.py @@ -0,0 +1,54 @@ +from modelscope import snapshot_download +from diffsynth import ModelManager, StepVideoPipeline, save_video +import torch + + +# Download models +snapshot_download(model_id="stepfun-ai/stepvideo-t2v", cache_dir="models") + +# Load the compiled attention for the LLM text encoder. +# If you encounter errors here, please select another compiled file that matches your environment, or delete this line.
+torch.ops.load_library("models/stepfun-ai/stepvideo-t2v/lib/liboptimus_ths-torch2.5-cu124.cpython-310-x86_64-linux-gnu.so") + +# Load models +model_manager = ModelManager() +model_manager.load_models( + ["models/stepfun-ai/stepvideo-t2v/hunyuan_clip/clip_text_encoder/pytorch_model.bin"], + torch_dtype=torch.float32, device="cpu" +) +model_manager.load_models( + [ + "models/stepfun-ai/stepvideo-t2v/step_llm", + [ + "models/stepfun-ai/stepvideo-t2v/transformer/diffusion_pytorch_model-00001-of-00006.safetensors", + "models/stepfun-ai/stepvideo-t2v/transformer/diffusion_pytorch_model-00002-of-00006.safetensors", + "models/stepfun-ai/stepvideo-t2v/transformer/diffusion_pytorch_model-00003-of-00006.safetensors", + "models/stepfun-ai/stepvideo-t2v/transformer/diffusion_pytorch_model-00004-of-00006.safetensors", + "models/stepfun-ai/stepvideo-t2v/transformer/diffusion_pytorch_model-00005-of-00006.safetensors", + "models/stepfun-ai/stepvideo-t2v/transformer/diffusion_pytorch_model-00006-of-00006.safetensors", + ] + ], + torch_dtype=torch.float8_e4m3fn, device="cpu" +) +model_manager.load_models( + ["models/stepfun-ai/stepvideo-t2v/vae/vae_v2.safetensors"], + torch_dtype=torch.bfloat16, device="cpu" +) +pipe = StepVideoPipeline.from_model_manager(model_manager, torch_dtype=torch.bfloat16, device="cuda") + +# Enable VRAM management +# This model requires 24G VRAM. +# In order to speed up, please set `num_persistent_param_in_dit` to a large number or None (unlimited). +pipe.enable_vram_management(num_persistent_param_in_dit=0) + +# Run! 
+video = pipe( + prompt="一名宇航员在月球上发现一块石碑,上面印有“stepfun”字样,闪闪发光。超高清、HDR 视频、环境光、杜比全景声、画面稳定、流畅动作、逼真的细节、专业级构图、超现实主义、自然、生动、超细节、清晰。", + negative_prompt="画面暗、低分辨率、不良手、文本、缺少手指、多余的手指、裁剪、低质量、颗粒状、签名、水印、用户名、模糊。", + num_inference_steps=30, cfg_scale=9, num_frames=51, seed=1, + tiled=True, tile_size=(34, 34), tile_stride=(16, 16) +) +save_video( + video, "video.mp4", fps=25, quality=5, + ffmpeg_params=["-vf", "atadenoise=0a=0.1:0b=0.1:1a=0.1:1b=0.1"] +) diff --git a/examples/stepvideo/stepvideo_text_to_video_quantized.py b/examples/stepvideo/stepvideo_text_to_video_quantized.py index 734b2fa..7868eb1 100644 --- a/examples/stepvideo/stepvideo_text_to_video_quantized.py +++ b/examples/stepvideo/stepvideo_text_to_video_quantized.py @@ -37,7 +37,7 @@ model_manager.load_models( pipe = StepVideoPipeline.from_model_manager(model_manager, torch_dtype=torch.bfloat16, device="cuda") # Enable VRAM management -# This model requires 80G VRAM. +# This model requires 40G VRAM. # In order to reduce VRAM required, please set `num_persistent_param_in_dit` to a small number. pipe.enable_vram_management(num_persistent_param_in_dit=None) @@ -47,4 +47,7 @@ video = pipe( negative_prompt="画面暗、低分辨率、不良手、文本、缺少手指、多余的手指、裁剪、低质量、颗粒状、签名、水印、用户名、模糊。", num_inference_steps=30, cfg_scale=9, num_frames=51, seed=1 ) -save_video(video, "video.mp4", fps=25, quality=5) +save_video( + video, "video.mp4", fps=25, quality=5, + ffmpeg_params=["-vf", "atadenoise=0a=0.1:0b=0.1:1a=0.1:1b=0.1"] +)