diff --git a/README.md b/README.md index cdf45ca..9d85ead 100644 --- a/README.md +++ b/README.md @@ -236,7 +236,7 @@ save_video(video, "video1.mp4", fps=15, quality=5) |[Wan-AI/Wan2.1-VACE-14B](https://modelscope.cn/models/Wan-AI/Wan2.1-VACE-14B)|`vace_control_video`, `vace_reference_image`|[code](./examples/wanvideo/model_inference/Wan2.1-VACE-14B.py)|[code](./examples/wanvideo/model_training/full/Wan2.1-VACE-14B.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.1-VACE-14B.py)|[code](./examples/wanvideo/model_training/lora/Wan2.1-VACE-14B.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.1-VACE-14B.py)| |[DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1](https://modelscope.cn/models/DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1)|`motion_bucket_id`|[code](./examples/wanvideo/model_inference/Wan2.1-1.3b-speedcontrol-v1.py)|[code](./examples/wanvideo/model_training/full/Wan2.1-1.3b-speedcontrol-v1.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.1-1.3b-speedcontrol-v1.py)|[code](./examples/wanvideo/model_training/lora/Wan2.1-1.3b-speedcontrol-v1.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.1-1.3b-speedcontrol-v1.py)| |[krea/krea-realtime-video](https://www.modelscope.cn/models/krea/krea-realtime-video)||[code](./examples/wanvideo/model_inference/krea-realtime-video.py)|[code](./examples/wanvideo/model_training/full/krea-realtime-video.sh)|[code](./examples/wanvideo/model_training/validate_full/krea-realtime-video.py)|[code](./examples/wanvideo/model_training/lora/krea-realtime-video.sh)|[code](./examples/wanvideo/model_training/validate_lora/krea-realtime-video.py)| - 
+|[meituan-longcat/LongCat-Video](https://www.modelscope.cn/models/meituan-longcat/LongCat-Video)|`longcat_video`|[code](./examples/wanvideo/model_inference/LongCat-Video.py)|[code](./examples/wanvideo/model_training/full/LongCat-Video.sh)|[code](./examples/wanvideo/model_training/validate_full/LongCat-Video.py)|[code](./examples/wanvideo/model_training/lora/LongCat-Video.sh)|[code](./examples/wanvideo/model_training/validate_lora/LongCat-Video.py)| @@ -387,6 +387,8 @@ https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/59fb2f7b-8de0-44 ## Update History +- **October 30, 2025**: We support [meituan-longcat/LongCat-Video](https://www.modelscope.cn/models/meituan-longcat/LongCat-Video) model, which enables text-to-video, image-to-video, and video continuation capabilities. This model adopts Wan's framework for both inference and training in this project. + - **October 27, 2025**: We support [krea/krea-realtime-video](https://www.modelscope.cn/models/krea/krea-realtime-video) model, further expanding Wan's ecosystem. - **September 23, 2025** [DiffSynth-Studio/Qwen-Image-EliGen-Poster](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-Poster) is released! This model is jointly developed and open-sourced by us and the Taobao Design Team. The model is built upon Qwen-Image, specifically designed for e-commerce poster scenarios, and supports precise partition layout control. Please refer to [our example code](./examples/qwen_image/model_inference/Qwen-Image-EliGen-Poster.py). 
diff --git a/README_zh.md b/README_zh.md index 3d64abd..d6dc241 100644 --- a/README_zh.md +++ b/README_zh.md @@ -236,7 +236,7 @@ save_video(video, "video1.mp4", fps=15, quality=5) |[Wan-AI/Wan2.1-VACE-14B](https://modelscope.cn/models/Wan-AI/Wan2.1-VACE-14B)|`vace_control_video`, `vace_reference_image`|[code](./examples/wanvideo/model_inference/Wan2.1-VACE-14B.py)|[code](./examples/wanvideo/model_training/full/Wan2.1-VACE-14B.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.1-VACE-14B.py)|[code](./examples/wanvideo/model_training/lora/Wan2.1-VACE-14B.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.1-VACE-14B.py)| |[DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1](https://modelscope.cn/models/DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1)|`motion_bucket_id`|[code](./examples/wanvideo/model_inference/Wan2.1-1.3b-speedcontrol-v1.py)|[code](./examples/wanvideo/model_training/full/Wan2.1-1.3b-speedcontrol-v1.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.1-1.3b-speedcontrol-v1.py)|[code](./examples/wanvideo/model_training/lora/Wan2.1-1.3b-speedcontrol-v1.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.1-1.3b-speedcontrol-v1.py)| |[krea/krea-realtime-video](https://www.modelscope.cn/models/krea/krea-realtime-video)||[code](./examples/wanvideo/model_inference/krea-realtime-video.py)|[code](./examples/wanvideo/model_training/full/krea-realtime-video.sh)|[code](./examples/wanvideo/model_training/validate_full/krea-realtime-video.py)|[code](./examples/wanvideo/model_training/lora/krea-realtime-video.sh)|[code](./examples/wanvideo/model_training/validate_lora/krea-realtime-video.py)| - 
+|[meituan-longcat/LongCat-Video](https://www.modelscope.cn/models/meituan-longcat/LongCat-Video)|`longcat_video`|[code](./examples/wanvideo/model_inference/LongCat-Video.py)|[code](./examples/wanvideo/model_training/full/LongCat-Video.sh)|[code](./examples/wanvideo/model_training/validate_full/LongCat-Video.py)|[code](./examples/wanvideo/model_training/lora/LongCat-Video.sh)|[code](./examples/wanvideo/model_training/validate_lora/LongCat-Video.py)| @@ -403,6 +403,8 @@ https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/59fb2f7b-8de0-44 ## 更新历史 +- **2025年10月30日** 支持了 [meituan-longcat/LongCat-Video](https://www.modelscope.cn/models/meituan-longcat/LongCat-Video) 模型,该模型支持文生视频、图生视频、视频续写。这个模型在本项目中沿用 Wan 的框架进行推理和训练。 + - **2025年10月27日** 支持了 [krea/krea-realtime-video](https://www.modelscope.cn/models/krea/krea-realtime-video) 模型,Wan 模型生态再添一员。 - **2025年9月23日** [DiffSynth-Studio/Qwen-Image-EliGen-Poster](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-Poster) 发布!本模型由我们与淘天体验设计团队联合研发并开源。模型基于 Qwen-Image 构建,专为电商海报场景设计,支持精确的分区布局控制。 请参考[我们的示例代码](./examples/qwen_image/model_inference/Qwen-Image-EliGen-Poster.py)。 diff --git a/diffsynth/configs/model_config.py b/diffsynth/configs/model_config.py index c932ec1..47e26e0 100644 --- a/diffsynth/configs/model_config.py +++ b/diffsynth/configs/model_config.py @@ -80,6 +80,8 @@ from ..models.qwen_image_text_encoder import QwenImageTextEncoder from ..models.qwen_image_vae import QwenImageVAE from ..models.qwen_image_controlnet import QwenImageBlockWiseControlNet +from ..models.longcat_video_dit import LongCatVideoTransformer3DModel + model_loader_configs = [ # These configs are provided for detecting model type automatically. 
# The format is (state_dict_keys_hash, state_dict_keys_hash_with_shape, model_names, model_classes, model_resource) @@ -159,6 +161,7 @@ model_loader_configs = [ (None, "7a513e1f257a861512b1afd387a8ecd9", ["wan_video_dit", "wan_video_vace"], [WanModel, VaceWanModel], "civitai"), (None, "cb104773c6c2cb6df4f9529ad5c60d0b", ["wan_video_dit"], [WanModel], "diffusers"), (None, "966cffdcc52f9c46c391768b27637614", ["wan_video_dit"], [WanS2VModel], "civitai"), + (None, "8b27900f680d7251ce44e2dc8ae1ffef", ["wan_video_dit"], [LongCatVideoTransformer3DModel], "civitai"), (None, "9c8818c2cbea55eca56c7b447df170da", ["wan_video_text_encoder"], [WanTextEncoder], "civitai"), (None, "5941c53e207d62f20f9025686193c40b", ["wan_video_image_encoder"], [WanImageEncoder], "civitai"), (None, "1378ea763357eea97acdef78e65d6d96", ["wan_video_vae"], [WanVideoVAE], "civitai"), diff --git a/diffsynth/models/longcat_video_dit.py b/diffsynth/models/longcat_video_dit.py new file mode 100644 index 0000000..bc4e79d --- /dev/null +++ b/diffsynth/models/longcat_video_dit.py @@ -0,0 +1,901 @@ +from typing import List, Optional, Tuple + +import math +import torch +import torch.nn as nn +import torch.amp as amp + +import numpy as np +import torch.nn.functional as F +from einops import rearrange, repeat +from .wan_video_dit import flash_attention +from ..vram_management import gradient_checkpoint_forward + + +class RMSNorm_FP32(torch.nn.Module): + def __init__(self, dim: int, eps: float): + super().__init__() + self.eps = eps + self.weight = nn.Parameter(torch.ones(dim)) + + def _norm(self, x): + return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) + + def forward(self, x): + output = self._norm(x.float()).type_as(x) + return output * self.weight + + +def broadcat(tensors, dim=-1): + num_tensors = len(tensors) + shape_lens = set(list(map(lambda t: len(t.shape), tensors))) + assert len(shape_lens) == 1, "tensors must all have the same number of dimensions" + shape_len = list(shape_lens)[0] + 
dim = (dim + shape_len) if dim < 0 else dim + dims = list(zip(*map(lambda t: list(t.shape), tensors))) + expandable_dims = [(i, val) for i, val in enumerate(dims) if i != dim] + assert all( + [*map(lambda t: len(set(t[1])) <= 2, expandable_dims)] + ), "invalid dimensions for broadcastable concatenation" + max_dims = list(map(lambda t: (t[0], max(t[1])), expandable_dims)) + expanded_dims = list(map(lambda t: (t[0], (t[1],) * num_tensors), max_dims)) + expanded_dims.insert(dim, (dim, dims[dim])) + expandable_shapes = list(zip(*map(lambda t: t[1], expanded_dims))) + tensors = list(map(lambda t: t[0].expand(*t[1]), zip(tensors, expandable_shapes))) + return torch.cat(tensors, dim=dim) + + +def rotate_half(x): + x = rearrange(x, "... (d r) -> ... d r", r=2) + x1, x2 = x.unbind(dim=-1) + x = torch.stack((-x2, x1), dim=-1) + return rearrange(x, "... d r -> ... (d r)") + + +class RotaryPositionalEmbedding(nn.Module): + + def __init__(self, + head_dim, + cp_split_hw=None + ): + """Rotary positional embedding for 3D + Reference : https://blog.eleuther.ai/rotary-embeddings/ + Paper: https://arxiv.org/pdf/2104.09864.pdf + Args: + head_dim: Dimension of each attention head + cp_split_hw: Optional (H, W) context-parallel split factors + """ + super().__init__() + self.head_dim = head_dim + assert self.head_dim % 8 == 0, 'Dim must be a multiple of 8 for 3D RoPE.'
+ self.cp_split_hw = cp_split_hw + # We assume that the longest side of the grid will not be larger than 512, i.e., 512 * 8 = 4096 input pixels + self.base = 10000 + self.freqs_dict = {} + + def register_grid_size(self, grid_size): + if grid_size not in self.freqs_dict: + self.freqs_dict.update({ + grid_size: self.precompute_freqs_cis_3d(grid_size) + }) + + def precompute_freqs_cis_3d(self, grid_size): + num_frames, height, width = grid_size + dim_t = self.head_dim - 4 * (self.head_dim // 6) + dim_h = 2 * (self.head_dim // 6) + dim_w = 2 * (self.head_dim // 6) + freqs_t = 1.0 / (self.base ** (torch.arange(0, dim_t, 2)[: (dim_t // 2)].float() / dim_t)) + freqs_h = 1.0 / (self.base ** (torch.arange(0, dim_h, 2)[: (dim_h // 2)].float() / dim_h)) + freqs_w = 1.0 / (self.base ** (torch.arange(0, dim_w, 2)[: (dim_w // 2)].float() / dim_w)) + grid_t = np.linspace(0, num_frames, num_frames, endpoint=False, dtype=np.float32) + grid_h = np.linspace(0, height, height, endpoint=False, dtype=np.float32) + grid_w = np.linspace(0, width, width, endpoint=False, dtype=np.float32) + grid_t = torch.from_numpy(grid_t).float() + grid_h = torch.from_numpy(grid_h).float() + grid_w = torch.from_numpy(grid_w).float() + freqs_t = torch.einsum("..., f -> ... f", grid_t, freqs_t) + freqs_h = torch.einsum("..., f -> ... f", grid_h, freqs_h) + freqs_w = torch.einsum("..., f -> ... f", grid_w, freqs_w) + freqs_t = repeat(freqs_t, "... n -> ... (n r)", r=2) + freqs_h = repeat(freqs_h, "... n -> ... (n r)", r=2) + freqs_w = repeat(freqs_w, "... n -> ...
(n r)", r=2) + freqs = broadcat((freqs_t[:, None, None, :], freqs_h[None, :, None, :], freqs_w[None, None, :, :]), dim=-1) + # (T H W D) + freqs = rearrange(freqs, "T H W D -> (T H W) D") + # if self.cp_split_hw[0] * self.cp_split_hw[1] > 1: + # with torch.no_grad(): + # freqs = rearrange(freqs, "(T H W) D -> T H W D", T=num_frames, H=height, W=width) + # freqs = context_parallel_util.split_cp_2d(freqs, seq_dim_hw=(1, 2), split_hw=self.cp_split_hw) + # freqs = rearrange(freqs, "T H W D -> (T H W) D") + + return freqs + + def forward(self, q, k, grid_size): + """3D RoPE. + + Args: + query: [B, head, seq, head_dim] + key: [B, head, seq, head_dim] + Returns: + query and key with the same shape as input. + """ + + if grid_size not in self.freqs_dict: + self.register_grid_size(grid_size) + + freqs_cis = self.freqs_dict[grid_size].to(q.device) + q_, k_ = q.float(), k.float() + freqs_cis = freqs_cis.float().to(q.device) + cos, sin = freqs_cis.cos(), freqs_cis.sin() + cos, sin = rearrange(cos, 'n d -> 1 1 n d'), rearrange(sin, 'n d -> 1 1 n d') + q_ = (q_ * cos) + (rotate_half(q_) * sin) + k_ = (k_ * cos) + (rotate_half(k_) * sin) + + return q_.type_as(q), k_.type_as(k) + + +class Attention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + enable_flashattn3: bool = False, + enable_flashattn2: bool = False, + enable_xformers: bool = False, + enable_bsa: bool = False, + bsa_params: dict = None, + cp_split_hw: Optional[List[int]] = None + ) -> None: + super().__init__() + assert dim % num_heads == 0, "dim should be divisible by num_heads" + self.dim = dim + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.scale = self.head_dim**-0.5 + self.enable_flashattn3 = enable_flashattn3 + self.enable_flashattn2 = enable_flashattn2 + self.enable_xformers = enable_xformers + self.enable_bsa = enable_bsa + self.bsa_params = bsa_params + self.cp_split_hw = cp_split_hw + + self.qkv = nn.Linear(dim, dim * 3, bias=True) + self.q_norm = 
RMSNorm_FP32(self.head_dim, eps=1e-6) + self.k_norm = RMSNorm_FP32(self.head_dim, eps=1e-6) + self.proj = nn.Linear(dim, dim) + + self.rope_3d = RotaryPositionalEmbedding( + self.head_dim, + cp_split_hw=cp_split_hw + ) + + def _process_attn(self, q, k, v, shape): + q = rearrange(q, "B H S D -> B S (H D)") + k = rearrange(k, "B H S D -> B S (H D)") + v = rearrange(v, "B H S D -> B S (H D)") + x = flash_attention(q, k, v, num_heads=self.num_heads) + x = rearrange(x, "B S (H D) -> B H S D", H=self.num_heads) + return x + + def forward(self, x: torch.Tensor, shape=None, num_cond_latents=None, return_kv=False) -> torch.Tensor: + """ + """ + B, N, C = x.shape + qkv = self.qkv(x) + + qkv_shape = (B, N, 3, self.num_heads, self.head_dim) + qkv = qkv.view(qkv_shape).permute((2, 0, 3, 1, 4)) # [3, B, H, N, D] + q, k, v = qkv.unbind(0) + q, k = self.q_norm(q), self.k_norm(k) + + if return_kv: + k_cache, v_cache = k.clone(), v.clone() + + q, k = self.rope_3d(q, k, shape) + + # cond mode + if num_cond_latents is not None and num_cond_latents > 0: + num_cond_latents_thw = num_cond_latents * (N // shape[0]) + # process the condition tokens + q_cond = q[:, :, :num_cond_latents_thw].contiguous() + k_cond = k[:, :, :num_cond_latents_thw].contiguous() + v_cond = v[:, :, :num_cond_latents_thw].contiguous() + x_cond = self._process_attn(q_cond, k_cond, v_cond, shape) + # process the noise tokens + q_noise = q[:, :, num_cond_latents_thw:].contiguous() + x_noise = self._process_attn(q_noise, k, v, shape) + # merge x_cond and x_noise + x = torch.cat([x_cond, x_noise], dim=2).contiguous() + else: + x = self._process_attn(q, k, v, shape) + + x_output_shape = (B, N, C) + x = x.transpose(1, 2) # [B, H, N, D] --> [B, N, H, D] + x = x.reshape(x_output_shape) # [B, N, H, D] --> [B, N, C] + x = self.proj(x) + + if return_kv: + return x, (k_cache, v_cache) + else: + return x + + def forward_with_kv_cache(self, x: torch.Tensor, shape=None, num_cond_latents=None, kv_cache=None) -> torch.Tensor: + """ 
+ """ + B, N, C = x.shape + qkv = self.qkv(x) + + qkv_shape = (B, N, 3, self.num_heads, self.head_dim) + qkv = qkv.view(qkv_shape).permute((2, 0, 3, 1, 4)) # [3, B, H, N, D] + q, k, v = qkv.unbind(0) + q, k = self.q_norm(q), self.k_norm(k) + + T, H, W = shape + k_cache, v_cache = kv_cache + assert k_cache.shape[0] == v_cache.shape[0] and k_cache.shape[0] in [1, B] + if k_cache.shape[0] == 1: + k_cache = k_cache.repeat(B, 1, 1, 1) + v_cache = v_cache.repeat(B, 1, 1, 1) + + if num_cond_latents is not None and num_cond_latents > 0: + k_full = torch.cat([k_cache, k], dim=2).contiguous() + v_full = torch.cat([v_cache, v], dim=2).contiguous() + q_padding = torch.cat([torch.empty_like(k_cache), q], dim=2).contiguous() + q_padding, k_full = self.rope_3d(q_padding, k_full, (T + num_cond_latents, H, W)) + q = q_padding[:, :, -N:].contiguous() + + x = self._process_attn(q, k_full, v_full, shape) + + x_output_shape = (B, N, C) + x = x.transpose(1, 2) # [B, H, N, D] --> [B, N, H, D] + x = x.reshape(x_output_shape) # [B, N, H, D] --> [B, N, C] + x = self.proj(x) + + return x + + +class MultiHeadCrossAttention(nn.Module): + def __init__( + self, + dim, + num_heads, + enable_flashattn3=False, + enable_flashattn2=False, + enable_xformers=False, + ): + super(MultiHeadCrossAttention, self).__init__() + assert dim % num_heads == 0, "d_model must be divisible by num_heads" + + self.dim = dim + self.num_heads = num_heads + self.head_dim = dim // num_heads + + self.q_linear = nn.Linear(dim, dim) + self.kv_linear = nn.Linear(dim, dim * 2) + self.proj = nn.Linear(dim, dim) + + self.q_norm = RMSNorm_FP32(self.head_dim, eps=1e-6) + self.k_norm = RMSNorm_FP32(self.head_dim, eps=1e-6) + + self.enable_flashattn3 = enable_flashattn3 + self.enable_flashattn2 = enable_flashattn2 + self.enable_xformers = enable_xformers + + def _process_cross_attn(self, x, cond, kv_seqlen): + B, N, C = x.shape + assert C == self.dim and cond.shape[2] == self.dim + + q = self.q_linear(x).view(1, -1, self.num_heads, 
self.head_dim) + kv = self.kv_linear(cond).view(1, -1, 2, self.num_heads, self.head_dim) + k, v = kv.unbind(2) + + q, k = self.q_norm(q), self.k_norm(k) + + q = rearrange(q, "B S H D -> B S (H D)") + k = rearrange(k, "B S H D -> B S (H D)") + v = rearrange(v, "B S H D -> B S (H D)") + x = flash_attention(q, k, v, num_heads=self.num_heads) + + x = x.view(B, -1, C) + x = self.proj(x) + return x + + def forward(self, x, cond, kv_seqlen, num_cond_latents=None, shape=None): + """ + x: [B, N, C] + cond: [B, M, C] + """ + if num_cond_latents is None or num_cond_latents == 0: + return self._process_cross_attn(x, cond, kv_seqlen) + else: + B, N, C = x.shape + if num_cond_latents is not None and num_cond_latents > 0: + assert shape is not None, "SHOULD pass in the shape" + num_cond_latents_thw = num_cond_latents * (N // shape[0]) + x_noise = x[:, num_cond_latents_thw:] # [B, N_noise, C] + output_noise = self._process_cross_attn(x_noise, cond, kv_seqlen) # [B, N_noise, C] + output = torch.cat([ + torch.zeros((B, num_cond_latents_thw, C), dtype=output_noise.dtype, device=output_noise.device), + output_noise + ], dim=1).contiguous() + else: + raise NotImplementedError + + return output + + +class LayerNorm_FP32(nn.LayerNorm): + def __init__(self, dim, eps, elementwise_affine): + super().__init__(dim, eps=eps, elementwise_affine=elementwise_affine) + + def forward(self, inputs: torch.Tensor) -> torch.Tensor: + origin_dtype = inputs.dtype + out = F.layer_norm( + inputs.float(), + self.normalized_shape, + None if self.weight is None else self.weight.float(), + None if self.bias is None else self.bias.float(), + self.eps + ).to(origin_dtype) + return out + + +def modulate_fp32(norm_func, x, shift, scale): + # Suppose x is (B, N, D), shift is (B, -1, D), scale is (B, -1, D) + # ensure the modulation params be fp32 + assert shift.dtype == torch.float32 and scale.dtype == torch.float32 + dtype = x.dtype + x = norm_func(x.to(torch.float32)) + x = x * (scale + 1) + shift + x = x.to(dtype)
+ return x + + +class FinalLayer_FP32(nn.Module): + """ + The final layer of DiT. + """ + + def __init__(self, hidden_size, num_patch, out_channels, adaln_tembed_dim): + super().__init__() + self.hidden_size = hidden_size + self.num_patch = num_patch + self.out_channels = out_channels + self.adaln_tembed_dim = adaln_tembed_dim + + self.norm_final = LayerNorm_FP32(hidden_size, elementwise_affine=False, eps=1e-6) + self.linear = nn.Linear(hidden_size, num_patch * out_channels, bias=True) + self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(adaln_tembed_dim, 2 * hidden_size, bias=True)) + + def forward(self, x, t, latent_shape): + # timestep shape: [B, T, C] + assert t.dtype == torch.float32 + B, N, C = x.shape + T, _, _ = latent_shape + + with amp.autocast('cuda', dtype=torch.float32): + shift, scale = self.adaLN_modulation(t).unsqueeze(2).chunk(2, dim=-1) # [B, T, 1, C] + x = modulate_fp32(self.norm_final, x.view(B, T, -1, C), shift, scale).view(B, N, C) + x = self.linear(x) + return x + + +class FeedForwardSwiGLU(nn.Module): + def __init__( + self, + dim: int, + hidden_dim: int, + multiple_of: int = 256, + ffn_dim_multiplier: Optional[float] = None, + ): + super().__init__() + hidden_dim = int(2 * hidden_dim / 3) + # custom dim factor multiplier + if ffn_dim_multiplier is not None: + hidden_dim = int(ffn_dim_multiplier * hidden_dim) + hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of) + + self.dim = dim + self.hidden_dim = hidden_dim + self.w1 = nn.Linear(dim, hidden_dim, bias=False) + self.w2 = nn.Linear(hidden_dim, dim, bias=False) + self.w3 = nn.Linear(dim, hidden_dim, bias=False) + + def forward(self, x): + return self.w2(F.silu(self.w1(x)) * self.w3(x)) + + +class TimestepEmbedder(nn.Module): + """ + Embeds scalar timesteps into vector representations. 
+ """ + + def __init__(self, t_embed_dim, frequency_embedding_size=256): + super().__init__() + self.t_embed_dim = t_embed_dim + self.frequency_embedding_size = frequency_embedding_size + self.mlp = nn.Sequential( + nn.Linear(frequency_embedding_size, t_embed_dim, bias=True), + nn.SiLU(), + nn.Linear(t_embed_dim, t_embed_dim, bias=True), + ) + + @staticmethod + def timestep_embedding(t, dim, max_period=10000): + """ + Create sinusoidal timestep embeddings. + :param t: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + :param dim: the dimension of the output. + :param max_period: controls the minimum frequency of the embeddings. + :return: an (N, D) Tensor of positional embeddings. + """ + half = dim // 2 + freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half) + freqs = freqs.to(device=t.device) + args = t[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + return embedding + + def forward(self, t, dtype): + t_freq = self.timestep_embedding(t, self.frequency_embedding_size) + if t_freq.dtype != dtype: + t_freq = t_freq.to(dtype) + t_emb = self.mlp(t_freq) + return t_emb + + +class CaptionEmbedder(nn.Module): + """ + Embeds class labels into vector representations. + """ + + def __init__(self, in_channels, hidden_size): + super().__init__() + self.in_channels = in_channels + self.hidden_size = hidden_size + self.y_proj = nn.Sequential( + nn.Linear(in_channels, hidden_size, bias=True), + nn.GELU(approximate="tanh"), + nn.Linear(hidden_size, hidden_size, bias=True), + ) + + def forward(self, caption): + B, _, N, C = caption.shape + caption = self.y_proj(caption) + return caption + + +class PatchEmbed3D(nn.Module): + """Video to Patch Embedding. + + Args: + patch_size (int): Patch token size. Default: (2,4,4). 
+ in_chans (int): Number of input video channels. Default: 3. + embed_dim (int): Number of linear projection output channels. Default: 96. + norm_layer (nn.Module, optional): Normalization layer. Default: None + """ + + def __init__( + self, + patch_size=(2, 4, 4), + in_chans=3, + embed_dim=96, + norm_layer=None, + flatten=True, + ): + super().__init__() + self.patch_size = patch_size + self.flatten = flatten + + self.in_chans = in_chans + self.embed_dim = embed_dim + + self.proj = nn.Conv3d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) + if norm_layer is not None: + self.norm = norm_layer(embed_dim) + else: + self.norm = None + + def forward(self, x): + """Forward function.""" + # padding + _, _, D, H, W = x.size() + if W % self.patch_size[2] != 0: + x = F.pad(x, (0, self.patch_size[2] - W % self.patch_size[2])) + if H % self.patch_size[1] != 0: + x = F.pad(x, (0, 0, 0, self.patch_size[1] - H % self.patch_size[1])) + if D % self.patch_size[0] != 0: + x = F.pad(x, (0, 0, 0, 0, 0, self.patch_size[0] - D % self.patch_size[0])) + + B, C, T, H, W = x.shape + x = self.proj(x) # (B C T H W) + if self.norm is not None: + D, Wh, Ww = x.size(2), x.size(3), x.size(4) + x = x.flatten(2).transpose(1, 2) + x = self.norm(x) + x = x.transpose(1, 2).view(-1, self.embed_dim, D, Wh, Ww) + if self.flatten: + x = x.flatten(2).transpose(1, 2) # BCTHW -> BNC + return x + + +class LongCatSingleStreamBlock(nn.Module): + def __init__( + self, + hidden_size: int, + num_heads: int, + mlp_ratio: int, + adaln_tembed_dim: int, + enable_flashattn3: bool = False, + enable_flashattn2: bool = False, + enable_xformers: bool = False, + enable_bsa: bool = False, + bsa_params=None, + cp_split_hw=None + ): + super().__init__() + + self.hidden_size = hidden_size + + # scale and gate modulation + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), + nn.Linear(adaln_tembed_dim, 6 * hidden_size, bias=True) + ) + + self.mod_norm_attn = LayerNorm_FP32(hidden_size, eps=1e-6, 
elementwise_affine=False) + self.mod_norm_ffn = LayerNorm_FP32(hidden_size, eps=1e-6, elementwise_affine=False) + self.pre_crs_attn_norm = LayerNorm_FP32(hidden_size, eps=1e-6, elementwise_affine=True) + + self.attn = Attention( + dim=hidden_size, + num_heads=num_heads, + enable_flashattn3=enable_flashattn3, + enable_flashattn2=enable_flashattn2, + enable_xformers=enable_xformers, + enable_bsa=enable_bsa, + bsa_params=bsa_params, + cp_split_hw=cp_split_hw + ) + self.cross_attn = MultiHeadCrossAttention( + dim=hidden_size, + num_heads=num_heads, + enable_flashattn3=enable_flashattn3, + enable_flashattn2=enable_flashattn2, + enable_xformers=enable_xformers, + ) + self.ffn = FeedForwardSwiGLU(dim=hidden_size, hidden_dim=int(hidden_size * mlp_ratio)) + + def forward(self, x, y, t, y_seqlen, latent_shape, num_cond_latents=None, return_kv=False, kv_cache=None, skip_crs_attn=False): + """ + x: [B, N, C] + y: [1, N_valid_tokens, C] + t: [B, T, C_t] + y_seqlen: [B]; type of a list + latent_shape: latent shape of a single item + """ + x_dtype = x.dtype + + B, N, C = x.shape + T, _, _ = latent_shape # S != T*H*W in case of CP split on H*W. 
+ + # compute modulation params in fp32 + with amp.autocast(device_type='cuda', dtype=torch.float32): + shift_msa, scale_msa, gate_msa, \ + shift_mlp, scale_mlp, gate_mlp = \ + self.adaLN_modulation(t).unsqueeze(2).chunk(6, dim=-1) # [B, T, 1, C] + + # self attn with modulation + x_m = modulate_fp32(self.mod_norm_attn, x.view(B, T, -1, C), shift_msa, scale_msa).view(B, N, C) + + if kv_cache is not None: + kv_cache = (kv_cache[0].to(x.device), kv_cache[1].to(x.device)) + attn_outputs = self.attn.forward_with_kv_cache(x_m, shape=latent_shape, num_cond_latents=num_cond_latents, kv_cache=kv_cache) + else: + attn_outputs = self.attn(x_m, shape=latent_shape, num_cond_latents=num_cond_latents, return_kv=return_kv) + + if return_kv: + x_s, kv_cache = attn_outputs + else: + x_s = attn_outputs + + with amp.autocast(device_type='cuda', dtype=torch.float32): + x = x + (gate_msa * x_s.view(B, -1, N//T, C)).view(B, -1, C) # [B, N, C] + x = x.to(x_dtype) + + # cross attn + if not skip_crs_attn: + if kv_cache is not None: + num_cond_latents = None + x = x + self.cross_attn(self.pre_crs_attn_norm(x), y, y_seqlen, num_cond_latents=num_cond_latents, shape=latent_shape) + + # ffn with modulation + x_m = modulate_fp32(self.mod_norm_ffn, x.view(B, -1, N//T, C), shift_mlp, scale_mlp).view(B, -1, C) + x_s = self.ffn(x_m) + with amp.autocast(device_type='cuda', dtype=torch.float32): + x = x + (gate_mlp * x_s.view(B, -1, N//T, C)).view(B, -1, C) # [B, N, C] + x = x.to(x_dtype) + + if return_kv: + return x, kv_cache + else: + return x + + +class LongCatVideoTransformer3DModel(torch.nn.Module): + def __init__( + self, + in_channels: int = 16, + out_channels: int = 16, + hidden_size: int = 4096, + depth: int = 48, + num_heads: int = 32, + caption_channels: int = 4096, + mlp_ratio: int = 4, + adaln_tembed_dim: int = 512, + frequency_embedding_size: int = 256, + # default params + patch_size: Tuple[int] = (1, 2, 2), + # attention config + enable_flashattn3: bool = False, + enable_flashattn2: 
bool = True, + enable_xformers: bool = False, + enable_bsa: bool = False, + bsa_params: dict = {'sparsity': 0.9375, 'chunk_3d_shape_q': [4, 4, 4], 'chunk_3d_shape_k': [4, 4, 4]}, + cp_split_hw: Optional[List[int]] = [1, 1], + text_tokens_zero_pad: bool = True, + ) -> None: + super().__init__() + + self.patch_size = patch_size + self.in_channels = in_channels + self.out_channels = out_channels + self.cp_split_hw = cp_split_hw + + self.x_embedder = PatchEmbed3D(patch_size, in_channels, hidden_size) + self.t_embedder = TimestepEmbedder(t_embed_dim=adaln_tembed_dim, frequency_embedding_size=frequency_embedding_size) + self.y_embedder = CaptionEmbedder( + in_channels=caption_channels, + hidden_size=hidden_size, + ) + + self.blocks = nn.ModuleList( + [ + LongCatSingleStreamBlock( + hidden_size=hidden_size, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + adaln_tembed_dim=adaln_tembed_dim, + enable_flashattn3=enable_flashattn3, + enable_flashattn2=enable_flashattn2, + enable_xformers=enable_xformers, + enable_bsa=enable_bsa, + bsa_params=bsa_params, + cp_split_hw=cp_split_hw + ) + for i in range(depth) + ] + ) + + self.final_layer = FinalLayer_FP32( + hidden_size, + np.prod(self.patch_size), + out_channels, + adaln_tembed_dim, + ) + + self.gradient_checkpointing = False + self.text_tokens_zero_pad = text_tokens_zero_pad + + self.lora_dict = {} + self.active_loras = [] + + def enable_loras(self, lora_key_list=[]): + self.disable_all_loras() + + module_loras = {} # {module_name: [lora1, lora2, ...]} + model_device = next(self.parameters()).device + model_dtype = next(self.parameters()).dtype + + for lora_key in lora_key_list: + if lora_key in self.lora_dict: + for lora in self.lora_dict[lora_key].loras: + lora.to(model_device, dtype=model_dtype, non_blocking=True) + module_name = lora.lora_name.replace("lora___lorahyphen___", "").replace("___lorahyphen___", ".") + if module_name not in module_loras: + module_loras[module_name] = [] + module_loras[module_name].append(lora) + 
self.active_loras.append(lora_key) + + for module_name, loras in module_loras.items(): + module = self._get_module_by_name(module_name) + if not hasattr(module, 'org_forward'): + module.org_forward = module.forward + module.forward = self._create_multi_lora_forward(module, loras) + + def _create_multi_lora_forward(self, module, loras): + def multi_lora_forward(x, *args, **kwargs): + weight_dtype = x.dtype + org_output = module.org_forward(x, *args, **kwargs) + + total_lora_output = 0 + for lora in loras: + if lora.use_lora: + lx = lora.lora_down(x.to(lora.lora_down.weight.dtype)) + lx = lora.lora_up(lx) + lora_output = lx.to(weight_dtype) * lora.multiplier * lora.alpha_scale + total_lora_output += lora_output + + return org_output + total_lora_output + + return multi_lora_forward + + def _get_module_by_name(self, module_name): + try: + module = self + for part in module_name.split('.'): + module = getattr(module, part) + return module + except AttributeError as e: + raise ValueError(f"Cannot find module: {module_name}, error: {e}") + + def disable_all_loras(self): + for name, module in self.named_modules(): + if hasattr(module, 'org_forward'): + module.forward = module.org_forward + delattr(module, 'org_forward') + + for lora_key, lora_network in self.lora_dict.items(): + for lora in lora_network.loras: + lora.to("cpu") + + self.active_loras.clear() + + def enable_bsa(self,): + for block in self.blocks: + block.attn.enable_bsa = True + + def disable_bsa(self,): + for block in self.blocks: + block.attn.enable_bsa = False + + def forward( + self, + hidden_states, + timestep, + encoder_hidden_states, + encoder_attention_mask=None, + num_cond_latents=0, + return_kv=False, + kv_cache_dict={}, + skip_crs_attn=False, + offload_kv_cache=False, + use_gradient_checkpointing=False, + use_gradient_checkpointing_offload=False, + ): + + B, _, T, H, W = hidden_states.shape + + N_t = T // self.patch_size[0] + N_h = H // self.patch_size[1] + N_w = W // self.patch_size[2] + + assert 
self.patch_size[0]==1, "Currently, 3D x_embedder should not compress the temporal dimension." + + # expand the shape of timestep from [B] to [B, T] + if len(timestep.shape) == 1: + timestep = timestep.unsqueeze(1).expand(-1, N_t).clone() # [B, T] + timestep[:, :num_cond_latents] = 0 + + dtype = hidden_states.dtype + hidden_states = hidden_states.to(dtype) + timestep = timestep.to(dtype) + encoder_hidden_states = encoder_hidden_states.to(dtype) + + hidden_states = self.x_embedder(hidden_states) # [B, N, C] + + with amp.autocast(device_type='cuda', dtype=torch.float32): + t = self.t_embedder(timestep.float().flatten(), dtype=torch.float32).reshape(B, N_t, -1) # [B, T, C_t] + + encoder_hidden_states = self.y_embedder(encoder_hidden_states) # [B, 1, N_token, C] + + if self.text_tokens_zero_pad and encoder_attention_mask is not None: + encoder_hidden_states = encoder_hidden_states * encoder_attention_mask[:, None, :, None] + encoder_attention_mask = (encoder_attention_mask * 0 + 1).to(encoder_attention_mask.dtype) + + if encoder_attention_mask is not None: + encoder_attention_mask = encoder_attention_mask.squeeze(1).squeeze(1) + encoder_hidden_states = encoder_hidden_states.squeeze(1).masked_select(encoder_attention_mask.unsqueeze(-1) != 0).view(1, -1, hidden_states.shape[-1]) # [1, N_valid_tokens, C] + y_seqlens = encoder_attention_mask.sum(dim=1).tolist() # [B] + else: + y_seqlens = [encoder_hidden_states.shape[2]] * encoder_hidden_states.shape[0] + encoder_hidden_states = encoder_hidden_states.squeeze(1).view(1, -1, hidden_states.shape[-1]) + + # if self.cp_split_hw[0] * self.cp_split_hw[1] > 1: + # hidden_states = rearrange(hidden_states, "B (T H W) C -> B T H W C", T=N_t, H=N_h, W=N_w) + # hidden_states = context_parallel_util.split_cp_2d(hidden_states, seq_dim_hw=(2, 3), split_hw=self.cp_split_hw) + # hidden_states = rearrange(hidden_states, "B T H W C -> B (T H W) C") + + # blocks + kv_cache_dict_ret = {} + for i, block in enumerate(self.blocks): + block_outputs 
= gradient_checkpoint_forward( + block, + use_gradient_checkpointing=use_gradient_checkpointing, + use_gradient_checkpointing_offload=use_gradient_checkpointing_offload, + x=hidden_states, + y=encoder_hidden_states, + t=t, + y_seqlen=y_seqlens, + latent_shape=(N_t, N_h, N_w), + num_cond_latents=num_cond_latents, + return_kv=return_kv, + kv_cache=kv_cache_dict.get(i, None), + skip_crs_attn=skip_crs_attn, + ) + + if return_kv: + hidden_states, kv_cache = block_outputs + if offload_kv_cache: + kv_cache_dict_ret[i] = (kv_cache[0].cpu(), kv_cache[1].cpu()) + else: + kv_cache_dict_ret[i] = (kv_cache[0].contiguous(), kv_cache[1].contiguous()) + else: + hidden_states = block_outputs + + hidden_states = self.final_layer(hidden_states, t, (N_t, N_h, N_w)) # [B, N, C=T_p*H_p*W_p*C_out] + + # if self.cp_split_hw[0] * self.cp_split_hw[1] > 1: + # hidden_states = context_parallel_util.gather_cp_2d(hidden_states, shape=(N_t, N_h, N_w), split_hw=self.cp_split_hw) + + hidden_states = self.unpatchify(hidden_states, N_t, N_h, N_w) # [B, C_out, H, W] + + # cast to float32 for better accuracy + hidden_states = hidden_states.to(torch.float32) + + if return_kv: + return hidden_states, kv_cache_dict_ret + else: + return hidden_states + + + def unpatchify(self, x, N_t, N_h, N_w): + """ + Args: + x (torch.Tensor): of shape [B, N, C] + + Return: + x (torch.Tensor): of shape [B, C_out, T, H, W] + """ + T_p, H_p, W_p = self.patch_size + x = rearrange( + x, + "B (N_t N_h N_w) (T_p H_p W_p C_out) -> B C_out (N_t T_p) (N_h H_p) (N_w W_p)", + N_t=N_t, + N_h=N_h, + N_w=N_w, + T_p=T_p, + H_p=H_p, + W_p=W_p, + C_out=self.out_channels, + ) + return x + + @staticmethod + def state_dict_converter(): + return LongCatVideoTransformer3DModelDictConverter() + + +class LongCatVideoTransformer3DModelDictConverter: + def __init__(self): + pass + + def from_diffusers(self, state_dict): + return state_dict + + def from_civitai(self, state_dict): + return state_dict + diff --git 
a/diffsynth/pipelines/wan_video_new.py b/diffsynth/pipelines/wan_video_new.py index 141660f..d374afd 100644 --- a/diffsynth/pipelines/wan_video_new.py +++ b/diffsynth/pipelines/wan_video_new.py @@ -22,6 +22,7 @@ from ..models.wan_video_image_encoder import WanImageEncoder from ..models.wan_video_vace import VaceWanModel from ..models.wan_video_motion_controller import WanMotionControllerModel from ..models.wan_video_animate_adapter import WanAnimateAdapter +from ..models.longcat_video_dit import LongCatVideoTransformer3DModel from ..schedulers.flow_match import FlowMatchScheduler from ..prompters import WanPrompter from ..vram_management import enable_vram_management, AutoWrappedModule, AutoWrappedLinear, WanAutoCastLayerNorm @@ -71,6 +72,7 @@ class WanVideoPipeline(BasePipeline): WanVideoUnit_UnifiedSequenceParallel(), WanVideoUnit_TeaCache(), WanVideoUnit_CfgMerger(), + WanVideoUnit_LongCatVideo(), ] self.post_units = [ WanVideoPostUnit_S2V(), @@ -150,6 +152,7 @@ class WanVideoPipeline(BasePipeline): vram_limit=vram_limit, ) if self.dit is not None: + from ..models.longcat_video_dit import LayerNorm_FP32, RMSNorm_FP32 dtype = next(iter(self.dit.parameters())).dtype device = "cpu" if vram_limit is not None else self.device enable_vram_management( @@ -162,6 +165,8 @@ class WanVideoPipeline(BasePipeline): torch.nn.Conv2d: AutoWrappedModule, torch.nn.Conv1d: AutoWrappedModule, torch.nn.Embedding: AutoWrappedModule, + LayerNorm_FP32: AutoWrappedModule, + RMSNorm_FP32: AutoWrappedModule, }, module_config = dict( offload_dtype=dtype, @@ -467,6 +472,8 @@ class WanVideoPipeline(BasePipeline): sigma_shift: Optional[float] = 5.0, # Speed control motion_bucket_id: Optional[int] = None, + # LongCat-Video + longcat_video: Optional[list[Image.Image]] = None, # VAE tiling tiled: Optional[bool] = True, tile_size: Optional[tuple[int, int]] = (30, 52), @@ -504,6 +511,7 @@ class WanVideoPipeline(BasePipeline): "cfg_scale": cfg_scale, "cfg_merge": cfg_merge, "sigma_shift": 
sigma_shift, "motion_bucket_id": motion_bucket_id, + "longcat_video": longcat_video, "tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride, "sliding_window_size": sliding_window_size, "sliding_window_stride": sliding_window_stride, "input_audio": input_audio, "audio_sample_rate": audio_sample_rate, "s2v_pose_video": s2v_pose_video, "audio_embeds": audio_embeds, "s2v_pose_latents": s2v_pose_latents, "motion_video": motion_video, @@ -1151,6 +1159,22 @@ class WanVideoPostUnit_AnimateInpaint(PipelineUnit): return {"y": y} +class WanVideoUnit_LongCatVideo(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("longcat_video",), + onload_model_names=("vae",) + ) + + def process(self, pipe: WanVideoPipeline, longcat_video): + if longcat_video is None: + return {} + pipe.load_models_to_device(self.onload_model_names) + longcat_video = pipe.preprocess_video(longcat_video) + longcat_latents = pipe.vae.encode(longcat_video, device=pipe.device).to(dtype=pipe.torch_dtype, device=pipe.device) + return {"longcat_latents": longcat_latents} + + class TeaCache: def __init__(self, num_inference_steps, rel_l1_thresh, model_id): self.num_inference_steps = num_inference_steps @@ -1279,6 +1303,7 @@ def model_fn_wan_video( motion_bucket_id: Optional[torch.Tensor] = None, pose_latents=None, face_pixel_values=None, + longcat_latents=None, sliding_window_size: Optional[int] = None, sliding_window_stride: Optional[int] = None, cfg_merge: bool = False, @@ -1313,6 +1338,18 @@ def model_fn_wan_video( tensor_names=["latents", "y"], batch_size=2 if cfg_merge else 1 ) + # LongCat-Video + if isinstance(dit, LongCatVideoTransformer3DModel): + return model_fn_longcat_video( + dit=dit, + latents=latents, + timestep=timestep, + context=context, + longcat_latents=longcat_latents, + use_gradient_checkpointing=use_gradient_checkpointing, + use_gradient_checkpointing_offload=use_gradient_checkpointing_offload, + ) + # wan2.2 s2v if audio_embeds is not None: return 
model_fn_wans2v( @@ -1468,6 +1505,36 @@ def model_fn_wan_video( return x +def model_fn_longcat_video( + dit: LongCatVideoTransformer3DModel, + latents: torch.Tensor = None, + timestep: torch.Tensor = None, + context: torch.Tensor = None, + longcat_latents: torch.Tensor = None, + use_gradient_checkpointing=False, + use_gradient_checkpointing_offload=False, +): + if longcat_latents is not None: + latents[:, :, :longcat_latents.shape[2]] = longcat_latents + num_cond_latents = longcat_latents.shape[2] + else: + num_cond_latents = 0 + context = context.unsqueeze(0) + encoder_attention_mask = torch.any(context != 0, dim=-1)[:, 0].to(torch.int64) + output = dit( + latents, + timestep, + context, + encoder_attention_mask, + num_cond_latents=num_cond_latents, + use_gradient_checkpointing=use_gradient_checkpointing, + use_gradient_checkpointing_offload=use_gradient_checkpointing_offload, + ) + output = -output + output = output.to(latents.dtype) + return output + + def model_fn_wans2v( dit, latents, diff --git a/examples/wanvideo/README.md b/examples/wanvideo/README.md index 1e77970..d4c30cb 100644 --- a/examples/wanvideo/README.md +++ b/examples/wanvideo/README.md @@ -77,7 +77,7 @@ save_video(video, "video1.mp4", fps=15, quality=5) |[Wan-AI/Wan2.1-VACE-14B](https://modelscope.cn/models/Wan-AI/Wan2.1-VACE-14B)|`vace_control_video`, `vace_reference_image`|[code](./model_inference/Wan2.1-VACE-14B.py)|[code](./model_training/full/Wan2.1-VACE-14B.sh)|[code](./model_training/validate_full/Wan2.1-VACE-14B.py)|[code](./model_training/lora/Wan2.1-VACE-14B.sh)|[code](./model_training/validate_lora/Wan2.1-VACE-14B.py)| 
|[DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1](https://modelscope.cn/models/DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1)|`motion_bucket_id`|[code](./model_inference/Wan2.1-1.3b-speedcontrol-v1.py)|[code](./model_training/full/Wan2.1-1.3b-speedcontrol-v1.sh)|[code](./model_training/validate_full/Wan2.1-1.3b-speedcontrol-v1.py)|[code](./model_training/lora/Wan2.1-1.3b-speedcontrol-v1.sh)|[code](./model_training/validate_lora/Wan2.1-1.3b-speedcontrol-v1.py)| |[krea/krea-realtime-video](https://www.modelscope.cn/models/krea/krea-realtime-video)||[code](./model_inference/krea-realtime-video.py)|[code](./model_training/full/krea-realtime-video.sh)|[code](./model_training/validate_full/krea-realtime-video.py)|[code](./model_training/lora/krea-realtime-video.sh)|[code](./model_training/validate_lora/krea-realtime-video.py)| - +|[meituan-longcat/LongCat-Video](https://www.modelscope.cn/models/meituan-longcat/LongCat-Video)|`longcat_video`|[code](./model_inference/LongCat-Video.py)|[code](./model_training/full/LongCat-Video.sh)|[code](./model_training/validate_full/LongCat-Video.py)|[code](./model_training/lora/LongCat-Video.sh)|[code](./model_training/validate_lora/LongCat-Video.py)| ## Model Inference diff --git a/examples/wanvideo/README_zh.md b/examples/wanvideo/README_zh.md index 20f380e..79ee0d7 100644 --- a/examples/wanvideo/README_zh.md +++ b/examples/wanvideo/README_zh.md @@ -77,6 +77,7 @@ save_video(video, "video1.mp4", fps=15, quality=5) |[Wan-AI/Wan2.1-VACE-14B](https://modelscope.cn/models/Wan-AI/Wan2.1-VACE-14B)|`vace_control_video`, `vace_reference_image`|[code](./model_inference/Wan2.1-VACE-14B.py)|[code](./model_training/full/Wan2.1-VACE-14B.sh)|[code](./model_training/validate_full/Wan2.1-VACE-14B.py)|[code](./model_training/lora/Wan2.1-VACE-14B.sh)|[code](./model_training/validate_lora/Wan2.1-VACE-14B.py)| 
|[DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1](https://modelscope.cn/models/DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1)|`motion_bucket_id`|[code](./model_inference/Wan2.1-1.3b-speedcontrol-v1.py)|[code](./model_training/full/Wan2.1-1.3b-speedcontrol-v1.sh)|[code](./model_training/validate_full/Wan2.1-1.3b-speedcontrol-v1.py)|[code](./model_training/lora/Wan2.1-1.3b-speedcontrol-v1.sh)|[code](./model_training/validate_lora/Wan2.1-1.3b-speedcontrol-v1.py)| |[krea/krea-realtime-video](https://www.modelscope.cn/models/krea/krea-realtime-video)||[code](./model_inference/krea-realtime-video.py)|[code](./model_training/full/krea-realtime-video.sh)|[code](./model_training/validate_full/krea-realtime-video.py)|[code](./model_training/lora/krea-realtime-video.sh)|[code](./model_training/validate_lora/krea-realtime-video.py)| +|[meituan-longcat/LongCat-Video](https://www.modelscope.cn/models/meituan-longcat/LongCat-Video)|`longcat_video`|[code](./model_inference/LongCat-Video.py)|[code](./model_training/full/LongCat-Video.sh)|[code](./model_training/validate_full/LongCat-Video.py)|[code](./model_training/lora/LongCat-Video.sh)|[code](./model_training/validate_lora/LongCat-Video.py)| ## 模型推理 diff --git a/examples/wanvideo/model_inference/LongCat-Video.py b/examples/wanvideo/model_inference/LongCat-Video.py new file mode 100644 index 0000000..8df1ec8 --- /dev/null +++ b/examples/wanvideo/model_inference/LongCat-Video.py @@ -0,0 +1,35 @@ +import torch +from diffsynth import save_video, VideoData +from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig + + +pipe = WanVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="meituan-longcat/LongCat-Video", origin_file_pattern="dit/diffusion_pytorch_model*.safetensors", offload_device="cpu"), + ModelConfig(model_id="Wan-AI/Wan2.1-T2V-14B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"), + 
ModelConfig(model_id="Wan-AI/Wan2.1-T2V-14B", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"), + ], +) +pipe.enable_vram_management() + +# Text-to-video +video = pipe( + prompt="In a realistic photography style, a white boy around seven or eight years old sits on a park bench, wearing a light blue T-shirt, denim shorts, and white sneakers. He holds an ice cream cone with vanilla and chocolate flavors, and beside him is a medium-sized golden Labrador. Smiling, the boy offers the ice cream to the dog, who eagerly licks it with its tongue. The sun is shining brightly, and the background features a green lawn and several tall trees, creating a warm and loving scene.", + negative_prompt="Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards", + seed=0, tiled=True, num_frames=93, + cfg_scale=2, sigma_shift=1, +) +save_video(video, "video1.mp4", fps=15, quality=5) + +# Video-continuation (The number of frames in `longcat_video` should be 4n+1.) +longcat_video = video[-17:] +video = pipe( + prompt="In a realistic photography style, a white boy around seven or eight years old sits on a park bench, wearing a light blue T-shirt, denim shorts, and white sneakers. He holds an ice cream cone with vanilla and chocolate flavors, and beside him is a medium-sized golden Labrador. Smiling, the boy offers the ice cream to the dog, who eagerly licks it with its tongue. 
The sun is shining brightly, and the background features a green lawn and several tall trees, creating a warm and loving scene.", + negative_prompt="Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards", + seed=1, tiled=True, num_frames=93, + cfg_scale=2, sigma_shift=1, + longcat_video=longcat_video, +) +save_video(video, "video2.mp4", fps=15, quality=5) diff --git a/examples/wanvideo/model_training/full/LongCat-Video.sh b/examples/wanvideo/model_training/full/LongCat-Video.sh new file mode 100644 index 0000000..2d8902e --- /dev/null +++ b/examples/wanvideo/model_training/full/LongCat-Video.sh @@ -0,0 +1,12 @@ +accelerate launch --config_file examples/wanvideo/model_training/full/accelerate_config_14B.yaml examples/wanvideo/model_training/train.py \ + --dataset_base_path data/example_video_dataset \ + --dataset_metadata_path data/example_video_dataset/metadata.csv \ + --height 480 \ + --width 832 \ + --dataset_repeat 100 \ + --model_id_with_origin_paths "meituan-longcat/LongCat-Video:dit/diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.1-T2V-14B:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.1-T2V-14B:Wan2.1_VAE.pth" \ + --learning_rate 1e-5 \ + --num_epochs 2 \ + --remove_prefix_in_ckpt "pipe.dit." 
\ + --output_path "./models/train/LongCat-Video_full" \ + --trainable_models "dit" \ No newline at end of file diff --git a/examples/wanvideo/model_training/lora/LongCat-Video.sh b/examples/wanvideo/model_training/lora/LongCat-Video.sh new file mode 100644 index 0000000..022048c --- /dev/null +++ b/examples/wanvideo/model_training/lora/LongCat-Video.sh @@ -0,0 +1,14 @@ +accelerate launch examples/wanvideo/model_training/train.py \ + --dataset_base_path data/example_video_dataset \ + --dataset_metadata_path data/example_video_dataset/metadata.csv \ + --height 480 \ + --width 832 \ + --dataset_repeat 100 \ + --model_id_with_origin_paths "meituan-longcat/LongCat-Video:dit/diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.1-T2V-14B:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.1-T2V-14B:Wan2.1_VAE.pth" \ + --learning_rate 1e-4 \ + --num_epochs 5 \ + --remove_prefix_in_ckpt "pipe.dit." \ + --output_path "./models/train/LongCat-Video_lora" \ + --lora_base_model "dit" \ + --lora_target_modules "adaLN_modulation.1,attn.qkv,attn.proj,cross_attn.q_linear,cross_attn.kv_linear,cross_attn.proj,ffn.w1,ffn.w2,ffn.w3" \ + --lora_rank 32 \ No newline at end of file diff --git a/examples/wanvideo/model_training/validate_full/LongCat-Video.py b/examples/wanvideo/model_training/validate_full/LongCat-Video.py new file mode 100644 index 0000000..31ee9a7 --- /dev/null +++ b/examples/wanvideo/model_training/validate_full/LongCat-Video.py @@ -0,0 +1,25 @@ +import torch +from PIL import Image +from diffsynth import save_video, VideoData, load_state_dict +from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig + + +pipe = WanVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="meituan-longcat/LongCat-Video", origin_file_pattern="dit/diffusion_pytorch_model*.safetensors", offload_device="cpu"), + ModelConfig(model_id="Wan-AI/Wan2.1-T2V-14B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", 
offload_device="cpu"), + ModelConfig(model_id="Wan-AI/Wan2.1-T2V-14B", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"), + ], +) +state_dict = load_state_dict("models/train/LongCat-Video_full/epoch-1.safetensors") +pipe.dit.load_state_dict(state_dict) +pipe.enable_vram_management() + +video = pipe( + prompt="from sunset to night, a small town, light, house, river", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + seed=1, tiled=True +) +save_video(video, "video_LongCat-Video.mp4", fps=15, quality=5) diff --git a/examples/wanvideo/model_training/validate_lora/LongCat-Video.py b/examples/wanvideo/model_training/validate_lora/LongCat-Video.py new file mode 100644 index 0000000..45c1ddb --- /dev/null +++ b/examples/wanvideo/model_training/validate_lora/LongCat-Video.py @@ -0,0 +1,24 @@ +import torch +from PIL import Image +from diffsynth import save_video, VideoData +from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig + + +pipe = WanVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="meituan-longcat/LongCat-Video", origin_file_pattern="dit/diffusion_pytorch_model*.safetensors", offload_device="cpu"), + ModelConfig(model_id="Wan-AI/Wan2.1-T2V-14B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"), + ModelConfig(model_id="Wan-AI/Wan2.1-T2V-14B", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"), + ], +) +pipe.load_lora(pipe.dit, "models/train/LongCat-Video_lora/epoch-4.safetensors", alpha=1) +pipe.enable_vram_management() + +video = pipe( + prompt="from sunset to night, a small town, light, house, river", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + seed=1, tiled=True +) +save_video(video, 
"video_LongCat-Video.mp4", fps=15, quality=5)