From 5418ca781e12218417c21c856c998d4d48fea6f2 Mon Sep 17 00:00:00 2001
From: "lzw478614@alibaba-inc.com"
Date: Thu, 3 Apr 2025 16:37:59 +0800
Subject: [PATCH 1/5] support loading Wan2.1-Fun-InP 1.3B and 14B models

---
 diffsynth/configs/model_config.py |  2 ++
 diffsynth/models/wan_video_dit.py | 28 ++++++++++++++++++++++++++++
 2 files changed, 30 insertions(+)

diff --git a/diffsynth/configs/model_config.py b/diffsynth/configs/model_config.py
index 969afae..9853a47 100644
--- a/diffsynth/configs/model_config.py
+++ b/diffsynth/configs/model_config.py
@@ -120,6 +120,8 @@ model_loader_configs = [
     (None, "9269f8db9040a9d860eaca435be61814", ["wan_video_dit"], [WanModel], "civitai"),
     (None, "aafcfd9672c3a2456dc46e1cb6e52c70", ["wan_video_dit"], [WanModel], "civitai"),
     (None, "6bfcfb3b342cb286ce886889d519a77e", ["wan_video_dit"], [WanModel], "civitai"),
+    (None, "6d6ccde6845b95ad9114ab993d917893", ["wan_video_dit"], [WanModel], "civitai"),
+    (None, "6bfcfb3b342cb286ce886889d519a77e", ["wan_video_dit"], [WanModel], "civitai"),
     (None, "cb104773c6c2cb6df4f9529ad5c60d0b", ["wan_video_dit"], [WanModel], "diffusers"),
     (None, "9c8818c2cbea55eca56c7b447df170da", ["wan_video_text_encoder"], [WanTextEncoder], "civitai"),
     (None, "5941c53e207d62f20f9025686193c40b", ["wan_video_image_encoder"], [WanImageEncoder], "civitai"),
diff --git a/diffsynth/models/wan_video_dit.py b/diffsynth/models/wan_video_dit.py
index 650e08f..b3692ae 100644
--- a/diffsynth/models/wan_video_dit.py
+++ b/diffsynth/models/wan_video_dit.py
@@ -493,6 +493,34 @@ class WanModelStateDictConverter:
                 "num_layers": 40,
                 "eps": 1e-6
             }
+        elif hash_state_dict_keys(state_dict) == "6d6ccde6845b95ad9114ab993d917893":
+            config = {
+                "has_image_input": True,
+                "patch_size": [1, 2, 2],
+                "in_dim": 36,
+                "dim": 1536,
+                "ffn_dim": 8960,
+                "freq_dim": 256,
+                "text_dim": 4096,
+                "out_dim": 16,
+                "num_heads": 12,
+                "num_layers": 30,
+                "eps": 1e-6
+            }
+        elif hash_state_dict_keys(state_dict) == "6bfcfb3b342cb286ce886889d519a77e":
+            config = {
+                "has_image_input": True,
+                "patch_size": [1, 2, 2],
+                "in_dim": 36,
+                "dim": 5120,
+                "ffn_dim": 13824,
+                "freq_dim": 256,
+                "text_dim": 4096,
+                "out_dim": 16,
+                "num_heads": 40,
+                "num_layers": 40,
+                "eps": 1e-6
+            }
         else:
             config = {}
         return state_dict, config

From a98700feb24329b3b8d9defac41de844c10b9a7a Mon Sep 17 00:00:00 2001
From: "lzw478614@alibaba-inc.com"
Date: Sun, 6 Apr 2025 22:55:42 +0800
Subject: [PATCH 2/5] support Wan-Fun-InP generation

---
 diffsynth/pipelines/wan_video.py | 13 ++++++++++---
 1 file changed, 10 insertions(+), 3 deletions(-)

diff --git a/diffsynth/pipelines/wan_video.py b/diffsynth/pipelines/wan_video.py
index b6f2c74..6b95a69 100644
--- a/diffsynth/pipelines/wan_video.py
+++ b/diffsynth/pipelines/wan_video.py
@@ -163,16 +163,22 @@ class WanVideoPipeline(BasePipeline):
         return {"context": prompt_emb}


-    def encode_image(self, image, num_frames, height, width):
+    def encode_image(self, image, end_image, num_frames, height, width):
         image = self.preprocess_image(image.resize((width, height))).to(self.device)
         clip_context = self.image_encoder.encode_image([image])
         msk = torch.ones(1, num_frames, height//8, width//8, device=self.device)
         msk[:, 1:] = 0
+        if end_image is not None:
+            end_image = self.preprocess_image(end_image.resize((width, height))).to(self.device)
+            vae_input = torch.concat([image.transpose(0,1), torch.zeros(3, num_frames-2, height, width).to(image.device), end_image.transpose(0,1)],dim=1)
+            msk[:, -1:] = 1
+        else:
+            vae_input = torch.concat([image.transpose(0, 1), torch.zeros(3, num_frames-1, height, width).to(image.device)], dim=1)
+
         msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1)
         msk = msk.view(1, msk.shape[1] // 4, 4, height//8, width//8)
         msk = msk.transpose(1, 2)[0]
-        vae_input = torch.concat([image.transpose(0, 1), torch.zeros(3, num_frames-1, height, width).to(image.device)], dim=1)
         y = self.vae.encode([vae_input.to(dtype=self.torch_dtype, device=self.device)], device=self.device)[0]
         y = torch.concat([msk, y])
         y = y.unsqueeze(0)
@@ -212,6 +218,7 @@ class WanVideoPipeline(BasePipeline):
         prompt,
         negative_prompt="",
         input_image=None,
+        end_image=None,
         input_video=None,
         denoising_strength=1.0,
         seed=None,
@@ -263,7 +270,7 @@ class WanVideoPipeline(BasePipeline):
         # Encode image
         if input_image is not None and self.image_encoder is not None:
             self.load_models_to_device(["image_encoder", "vae"])
-            image_emb = self.encode_image(input_image, num_frames, height, width)
+            image_emb = self.encode_image(input_image, end_image, num_frames, height, width)
         else:
             image_emb = {}

From 60a9db706ecdbc56f62827f812b80849ec3a5a55 Mon Sep 17 00:00:00 2001
From: Artiprocher
Date: Tue, 8 Apr 2025 17:07:10 +0800
Subject: [PATCH 3/5] support more wan models

---
 diffsynth/configs/model_config.py          |  3 +
 diffsynth/models/wan_video_dit.py          | 14 ++++
 .../models/wan_video_motion_controller.py  | 44 ++++++++++
 diffsynth/pipelines/wan_video.py           | 84 +++++++++++++++++--
 examples/wanvideo/README.md                | 81 ++++++++++--------
 .../wanvideo/wan_1.3b_motion_controller.py | 41 +++++++++
 examples/wanvideo/wan_fun_InP.py           | 42 ++++++++++
 examples/wanvideo/wan_fun_control.py       | 40 +++++++++
 8 files changed, 307 insertions(+), 42 deletions(-)
 create mode 100644 diffsynth/models/wan_video_motion_controller.py
 create mode 100644 examples/wanvideo/wan_1.3b_motion_controller.py
 create mode 100644 examples/wanvideo/wan_fun_InP.py
 create mode 100644 examples/wanvideo/wan_fun_control.py

diff --git a/diffsynth/configs/model_config.py b/diffsynth/configs/model_config.py
index 9853a47..052d63e 100644
--- a/diffsynth/configs/model_config.py
+++ b/diffsynth/configs/model_config.py
@@ -59,6 +59,7 @@ from ..models.wan_video_dit import WanModel
 from ..models.wan_video_text_encoder import WanTextEncoder
 from ..models.wan_video_image_encoder import WanImageEncoder
 from ..models.wan_video_vae import WanVideoVAE
+from ..models.wan_video_motion_controller import WanMotionControllerModel


 model_loader_configs = [
@@ -122,11 +123,13 @@ model_loader_configs = [
     (None, "6bfcfb3b342cb286ce886889d519a77e", ["wan_video_dit"], [WanModel], "civitai"),
     (None, "6d6ccde6845b95ad9114ab993d917893", ["wan_video_dit"], [WanModel], "civitai"),
     (None, "6bfcfb3b342cb286ce886889d519a77e", ["wan_video_dit"], [WanModel], "civitai"),
+    (None, "349723183fc063b2bfc10bb2835cf677", ["wan_video_dit"], [WanModel], "civitai"),
     (None, "cb104773c6c2cb6df4f9529ad5c60d0b", ["wan_video_dit"], [WanModel], "diffusers"),
     (None, "9c8818c2cbea55eca56c7b447df170da", ["wan_video_text_encoder"], [WanTextEncoder], "civitai"),
     (None, "5941c53e207d62f20f9025686193c40b", ["wan_video_image_encoder"], [WanImageEncoder], "civitai"),
     (None, "1378ea763357eea97acdef78e65d6d96", ["wan_video_vae"], [WanVideoVAE], "civitai"),
     (None, "ccc42284ea13e1ad04693284c7a09be6", ["wan_video_vae"], [WanVideoVAE], "civitai"),
+    (None, "dbd5ec76bbf977983f972c151d545389", ["wan_video_motion_controller"], [WanMotionControllerModel], "civitai"),
 ]
 huggingface_model_loader_configs = [
     # These configs are provided for detecting model type automatically.
diff --git a/diffsynth/models/wan_video_dit.py b/diffsynth/models/wan_video_dit.py
index b3692ae..d423c4b 100644
--- a/diffsynth/models/wan_video_dit.py
+++ b/diffsynth/models/wan_video_dit.py
@@ -521,6 +521,20 @@ class WanModelStateDictConverter:
                 "num_layers": 40,
                 "eps": 1e-6
             }
+        elif hash_state_dict_keys(state_dict) == "349723183fc063b2bfc10bb2835cf677":
+            config = {
+                "has_image_input": True,
+                "patch_size": [1, 2, 2],
+                "in_dim": 48,
+                "dim": 1536,
+                "ffn_dim": 8960,
+                "freq_dim": 256,
+                "text_dim": 4096,
+                "out_dim": 16,
+                "num_heads": 12,
+                "num_layers": 30,
+                "eps": 1e-6
+            }
         else:
             config = {}
         return state_dict, config
diff --git a/diffsynth/models/wan_video_motion_controller.py b/diffsynth/models/wan_video_motion_controller.py
new file mode 100644
index 0000000..518c1c6
--- /dev/null
+++ b/diffsynth/models/wan_video_motion_controller.py
@@ -0,0 +1,44 @@
+import torch
+import torch.nn as nn
+from .wan_video_dit import sinusoidal_embedding_1d
+
+
+
+class WanMotionControllerModel(torch.nn.Module):
+    def __init__(self, freq_dim=256, dim=1536):
+        super().__init__()
+        self.freq_dim = freq_dim
+        self.linear = nn.Sequential(
+            nn.Linear(freq_dim, dim),
+            nn.SiLU(),
+            nn.Linear(dim, dim),
+            nn.SiLU(),
+            nn.Linear(dim, dim * 6),
+        )
+
+    def forward(self, motion_bucket_id):
+        emb = sinusoidal_embedding_1d(self.freq_dim, motion_bucket_id * 10)
+        emb = self.linear(emb)
+        return emb
+
+    def init(self):
+        state_dict = self.linear[-1].state_dict()
+        state_dict = {i: state_dict[i] * 0 for i in state_dict}
+        self.linear[-1].load_state_dict(state_dict)
+
+    @staticmethod
+    def state_dict_converter():
+        return WanMotionControllerModelDictConverter()
+
+
+
+class WanMotionControllerModelDictConverter:
+    def __init__(self):
+        pass
+
+    def from_diffusers(self, state_dict):
+        return state_dict
+
+    def from_civitai(self, state_dict):
+        return state_dict
+
diff --git a/diffsynth/pipelines/wan_video.py b/diffsynth/pipelines/wan_video.py
index 6b95a69..9a80f78 100644
--- a/diffsynth/pipelines/wan_video.py
+++ b/diffsynth/pipelines/wan_video.py
@@ -18,6 +18,7 @@ from ..vram_management import enable_vram_management, AutoWrappedModule, AutoWra
 from ..models.wan_video_text_encoder import T5RelativeEmbedding, T5LayerNorm
 from ..models.wan_video_dit import RMSNorm, sinusoidal_embedding_1d
 from ..models.wan_video_vae import RMS_norm, CausalConv3d, Upsample
+from ..models.wan_video_motion_controller import WanMotionControllerModel


@@ -31,7 +32,8 @@ class WanVideoPipeline(BasePipeline):
         self.image_encoder: WanImageEncoder = None
         self.dit: WanModel = None
         self.vae: WanVideoVAE = None
-        self.model_names = ['text_encoder', 'dit', 'vae', 'image_encoder']
+        self.motion_controller: WanMotionControllerModel = None
+        self.model_names = ['text_encoder', 'dit', 'vae', 'image_encoder', 'motion_controller']
         self.height_division_factor = 16
         self.width_division_factor = 16
         self.use_unified_sequence_parallel = False
@@ -122,6 +124,22 @@ class WanVideoPipeline(BasePipeline):
                     computation_device=self.device,
                 ),
             )
+        if self.motion_controller is not None:
+            dtype = next(iter(self.motion_controller.parameters())).dtype
+            enable_vram_management(
+                self.motion_controller,
+                module_map = {
+                    torch.nn.Linear: AutoWrappedLinear,
+                },
+                module_config = dict(
+                    offload_dtype=dtype,
+                    offload_device="cpu",
+                    onload_dtype=dtype,
+                    onload_device="cpu",
+                    computation_dtype=dtype,
+                    computation_device=self.device,
+                ),
+            )
         self.enable_cpu_offload()


@@ -134,6 +152,7 @@ class WanVideoPipeline(BasePipeline):
         self.dit = model_manager.fetch_model("wan_video_dit")
model_manager.fetch_model("wan_video_dit") self.vae = model_manager.fetch_model("wan_video_vae") self.image_encoder = model_manager.fetch_model("wan_video_image_encoder") + self.motion_controller = model_manager.fetch_model("wan_video_motion_controller") @staticmethod @@ -185,6 +204,25 @@ class WanVideoPipeline(BasePipeline): clip_context = clip_context.to(dtype=self.torch_dtype, device=self.device) y = y.to(dtype=self.torch_dtype, device=self.device) return {"clip_feature": clip_context, "y": y} + + + def encode_control_video(self, control_video, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)): + control_video = self.preprocess_images(control_video) + control_video = torch.stack(control_video, dim=2).to(dtype=self.torch_dtype, device=self.device) + latents = self.encode_video(control_video, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=self.torch_dtype, device=self.device) + return latents + + + def prepare_controlnet_kwargs(self, control_video, num_frames, height, width, clip_feature=None, y=None, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)): + if control_video is not None: + control_latents = self.encode_control_video(control_video, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride) + if clip_feature is None or y is None: + clip_feature = torch.zeros((1, 257, 1280), dtype=self.torch_dtype, device=self.device) + y = torch.zeros((1, 16, (num_frames - 1) // 4 + 1, height//8, width//8), dtype=self.torch_dtype, device=self.device) + else: + y = y[:, -16:] + y = torch.concat([control_latents, y], dim=1) + return {"clip_feature": clip_feature, "y": y} def tensor2video(self, frames): @@ -210,6 +248,11 @@ class WanVideoPipeline(BasePipeline): def prepare_unified_sequence_parallel(self): return {"use_unified_sequence_parallel": self.use_unified_sequence_parallel} + + + def prepare_motion_bucket_id(self, motion_bucket_id): + motion_bucket_id = torch.Tensor((motion_bucket_id,)).to(dtype=self.torch_dtype, device=self.device) + return {"motion_bucket_id": motion_bucket_id} @torch.no_grad() @@ -220,6 +263,7 @@ class WanVideoPipeline(BasePipeline): input_image=None, end_image=None, input_video=None, + control_video=None, denoising_strength=1.0, seed=None, rand_device="cpu", @@ -229,6 +273,7 @@ class WanVideoPipeline(BasePipeline): cfg_scale=5.0, num_inference_steps=50, sigma_shift=5.0, + motion_bucket_id=None, tiled=True, tile_size=(30, 52), tile_stride=(15, 26), @@ -274,6 +319,17 @@ class WanVideoPipeline(BasePipeline): else: image_emb = {} + # ControlNet + if control_video is not None: + self.load_models_to_device(["image_encoder", "vae"]) + image_emb = self.prepare_controlnet_kwargs(control_video, num_frames, height, width, **image_emb, **tiler_kwargs) + + # Motion Controller + if self.motion_controller is not None and motion_bucket_id is not None: + motion_kwargs = self.prepare_motion_bucket_id(motion_bucket_id) + else: + motion_kwargs = {} + # Extra input extra_input = self.prepare_extra_input(latents) @@ -285,14 +341,24 @@ class WanVideoPipeline(BasePipeline): usp_kwargs = self.prepare_unified_sequence_parallel() # Denoise - self.load_models_to_device(["dit"]) + self.load_models_to_device(["dit", "motion_controller"]) for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)): timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device) # Inference - noise_pred_posi = model_fn_wan_video(self.dit, latents, timestep=timestep, **prompt_emb_posi, **image_emb, **extra_input, **tea_cache_posi, **usp_kwargs) + 
+            noise_pred_posi = model_fn_wan_video(
+                self.dit, motion_controller=self.motion_controller,
+                x=latents, timestep=timestep,
+                **prompt_emb_posi, **image_emb, **extra_input,
+                **tea_cache_posi, **usp_kwargs, **motion_kwargs
+            )
             if cfg_scale != 1.0:
-                noise_pred_nega = model_fn_wan_video(self.dit, latents, timestep=timestep, **prompt_emb_nega, **image_emb, **extra_input, **tea_cache_nega, **usp_kwargs)
+                noise_pred_nega = model_fn_wan_video(
+                    self.dit, motion_controller=self.motion_controller,
+                    x=latents, timestep=timestep,
+                    **prompt_emb_nega, **image_emb, **extra_input,
+                    **tea_cache_nega, **usp_kwargs, **motion_kwargs
+                )
                 noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
             else:
                 noise_pred = noise_pred_posi
@@ -365,13 +431,15 @@ class TeaCache:

 def model_fn_wan_video(
     dit: WanModel,
-    x: torch.Tensor,
-    timestep: torch.Tensor,
-    context: torch.Tensor,
+    motion_controller: WanMotionControllerModel = None,
+    x: torch.Tensor = None,
+    timestep: torch.Tensor = None,
+    context: torch.Tensor = None,
     clip_feature: Optional[torch.Tensor] = None,
     y: Optional[torch.Tensor] = None,
     tea_cache: TeaCache = None,
     use_unified_sequence_parallel: bool = False,
+    motion_bucket_id: Optional[torch.Tensor] = None,
     **kwargs,
 ):
     if use_unified_sequence_parallel:
@@ -382,6 +450,8 @@ def model_fn_wan_video(

     t = dit.time_embedding(sinusoidal_embedding_1d(dit.freq_dim, timestep))
     t_mod = dit.time_projection(t).unflatten(1, (6, dit.dim))
+    if motion_bucket_id is not None and motion_controller is not None:
+        t_mod = t_mod + motion_controller(motion_bucket_id).unflatten(1, (6, dit.dim))
     context = dit.text_embedding(context)

     if dit.has_image_input:
diff --git a/examples/wanvideo/README.md b/examples/wanvideo/README.md
index f8b3e0b..f3064a0 100644
--- a/examples/wanvideo/README.md
+++ b/examples/wanvideo/README.md
@@ -10,34 +10,30 @@ cd DiffSynth-Studio
 pip install -e .
 ```

-Wan-Video supports multiple Attention implementations. If you have installed any of the following Attention implementations, they will be enabled based on priority.
+## Model Zoo

-* [Flash Attention 3](https://github.com/Dao-AILab/flash-attention)
-* [Flash Attention 2](https://github.com/Dao-AILab/flash-attention)
-* [Sage Attention](https://github.com/thu-ml/SageAttention)
-* [torch SDPA](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html) (default. `torch>=2.5.0` is recommended.)
+|Developer|Name|Link|Scripts|
+|-|-|-|-|
+|Wan Team|1.3B text-to-video|[Link](https://modelscope.cn/models/Wan-AI/Wan2.1-T2V-1.3B)|[wan_1.3b_text_to_video.py](./wan_1.3b_text_to_video.py)|
+|Wan Team|14B text-to-video|[Link](https://modelscope.cn/models/Wan-AI/Wan2.1-T2V-14B)|[wan_14b_text_to_video.py](./wan_14b_text_to_video.py)|
+|Wan Team|14B image-to-video 480P|[Link](https://modelscope.cn/models/Wan-AI/Wan2.1-I2V-14B-480P)|[wan_14b_image_to_video.py](./wan_14b_image_to_video.py)|
+|Wan Team|14B image-to-video 720P|[Link](https://modelscope.cn/models/Wan-AI/Wan2.1-I2V-14B-720P)|[wan_14b_image_to_video.py](./wan_14b_image_to_video.py)|
+|DiffSynth-Studio Team|1.3B aesthetics LoRA|[Link](https://modelscope.cn/models/DiffSynth-Studio/Wan2.1-1.3b-lora-aesthetics-v1)|Please see the [model card](https://modelscope.cn/models/DiffSynth-Studio/Wan2.1-1.3b-lora-aesthetics-v1).|
+|DiffSynth-Studio Team|1.3B Highres-fix LoRA|[Link](https://modelscope.cn/models/DiffSynth-Studio/Wan2.1-1.3b-lora-highresfix-v1)|Please see the [model card](https://modelscope.cn/models/DiffSynth-Studio/Wan2.1-1.3b-lora-highresfix-v1).|
+|DiffSynth-Studio Team|1.3B ExVideo LoRA|[Link](https://modelscope.cn/models/DiffSynth-Studio/Wan2.1-1.3b-lora-exvideo-v1)|Please see the [model card](https://modelscope.cn/models/DiffSynth-Studio/Wan2.1-1.3b-lora-exvideo-v1).|
+|DiffSynth-Studio Team|1.3B Speed Control adapter|[Link](https://modelscope.cn/models/DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1)|[wan_1.3b_motion_controller.py](./wan_1.3b_motion_controller.py)|
+|PAI Team|1.3B InP|[Link](https://modelscope.cn/models/PAI/Wan2.1-Fun-1.3B-InP)|[wan_fun_InP.py](./wan_fun_InP.py)|
+|PAI Team|14B InP|[Link](https://modelscope.cn/models/PAI/Wan2.1-Fun-14B-InP)|[wan_fun_InP.py](./wan_fun_InP.py)|
+|PAI Team|1.3B Control|[Link](https://modelscope.cn/models/PAI/Wan2.1-Fun-1.3B-Control)|[wan_fun_control.py](./wan_fun_control.py)|
+|PAI Team|14B Control|[Link](https://modelscope.cn/models/PAI/Wan2.1-Fun-14B-Control)|[wan_fun_control.py](./wan_fun_control.py)|

-## Inference
+## VRAM Usage

-### Wan-Video-1.3B-T2V
+* Fine-grained offload: We recommend that users adjust the `num_persistent_param_in_dit` settings to find an optimal balance between speed and VRAM requirements. See [`./wan_14b_text_to_video.py`](./wan_14b_text_to_video.py).

-Wan-Video-1.3B-T2V supports text-to-video and video-to-video. See [`./wan_1.3b_text_to_video.py`](./wan_1.3b_text_to_video.py).
+* FP8 Quantization: You only need to adjust the `torch_dtype` in the `ModelManager` (not the pipeline!).

-Required VRAM: 6G
-
-https://github.com/user-attachments/assets/124397be-cd6a-4f29-a87c-e4c695aaabb8
-
-Put sunglasses on the dog.
-
-https://github.com/user-attachments/assets/272808d7-fbeb-4747-a6df-14a0860c75fb
-
-[TeaCache](https://github.com/ali-vilab/TeaCache) is supported in both T2V and I2V models. It can significantly improve the efficiency. See [`./wan_1.3b_text_to_video_accelerate.py`](./wan_1.3b_text_to_video_accelerate.py).
-
-### Wan-Video-14B-T2V
-
-Wan-Video-14B-T2V is an enhanced version of Wan-Video-1.3B-T2V, offering greater size and power. To utilize this model, you need additional VRAM. We recommend that users adjust the `torch_dtype` and `num_persistent_param_in_dit` settings to find an optimal balance between speed and VRAM requirements. See [`./wan_14b_text_to_video.py`](./wan_14b_text_to_video.py).
-
-We present a detailed table here. The model is tested on a single A100.
+We present a detailed table here. The model (14B text-to-video) is tested on a single A100.

 |`torch_dtype`|`num_persistent_param_in_dit`|Speed|Required VRAM|Default Setting|
 |-|-|-|-|-|
 |torch.bfloat16|None (unlimited)|15.5s/it|40G||
 |torch.bfloat16|7*10**9 (7B)|32.8s/it|20G||
 |torch.bfloat16|0|42.6s/it|5G||
 |torch.float8_e4m3fn|None (unlimited)|18.3s/it|24G|yes|
 |torch.float8_e4m3fn|0|24.0s/it|10G||

-https://github.com/user-attachments/assets/3908bc64-d451-485a-8b61-28f6d32dd92f
+**We found that the 14B image-to-video model is more sensitive to precision, so when the generated video content experiences issues such as artifacts, please switch to bfloat16 precision and use the `num_persistent_param_in_dit` parameter to control VRAM usage.**

-### Parallel Inference
+## Efficient Attention Implementation

-1. Unified Sequence Parallel (USP)
+DiffSynth-Studio supports multiple Attention implementations. If you have installed any of the following Attention implementations, they will be enabled based on priority. However, we recommend using the default torch SDPA.
+
+* [Flash Attention 3](https://github.com/Dao-AILab/flash-attention)
+* [Flash Attention 2](https://github.com/Dao-AILab/flash-attention)
+* [Sage Attention](https://github.com/thu-ml/SageAttention)
+* [torch SDPA](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html) (default. `torch>=2.5.0` is recommended.)
+
+## Acceleration
+
+We support multiple acceleration solutions:
+* [TeaCache](https://github.com/ali-vilab/TeaCache): See [wan_1.3b_text_to_video_accelerate.py](./wan_1.3b_text_to_video_accelerate.py).
+
+* [Unified Sequence Parallel](https://github.com/xdit-project/xDiT): See [wan_14b_text_to_video_usp.py](./wan_14b_text_to_video_usp.py)

 ```bash
 pip install xfuser>=0.4.3
-```
-
-```bash
 torchrun --standalone --nproc_per_node=8 examples/wanvideo/wan_14b_text_to_video_usp.py
 ```

-2. Tensor Parallel
+* Tensor Parallel: See [wan_14b_text_to_video_tensor_parallel.py](./wan_14b_text_to_video_tensor_parallel.py).

-Tensor parallel module of Wan-Video-14B-T2V is still under development. An example script is provided in [`./wan_14b_text_to_video_tensor_parallel.py`](./wan_14b_text_to_video_tensor_parallel.py).
+## Gallery

-### Wan-Video-14B-I2V
+1.3B text-to-video.

-Wan-Video-14B-I2V adds the functionality of image-to-video based on Wan-Video-14B-T2V. The model size remains the same, therefore the speed and VRAM requirements are also consistent. See [`./wan_14b_image_to_video.py`](./wan_14b_image_to_video.py).
 https://github.com/user-attachments/assets/124397be-cd6a-4f29-a87c-e4c695aaabb8

-**In the sample code, we use the same settings as the T2V 14B model, with FP8 quantization enabled by default. However, we found that this model is more sensitive to precision, so when the generated video content experiences issues such as artifacts, please switch to bfloat16 precision and use the `num_persistent_param_in_dit` parameter to control VRAM usage.**
+Put sunglasses on the dog.

-![Image](https://github.com/user-attachments/assets/adf8047f-7943-4aaa-a555-2b32dc415f39)
 https://github.com/user-attachments/assets/272808d7-fbeb-4747-a6df-14a0860c75fb
+
+14B text-to-video.
+
+https://github.com/user-attachments/assets/3908bc64-d451-485a-8b61-28f6d32dd92f
+
+14B image-to-video.
+
 https://github.com/user-attachments/assets/c0bdd5ca-292f-45ed-b9bc-afe193156e75
diff --git a/examples/wanvideo/wan_1.3b_motion_controller.py b/examples/wanvideo/wan_1.3b_motion_controller.py
new file mode 100644
index 0000000..8036819
--- /dev/null
+++ b/examples/wanvideo/wan_1.3b_motion_controller.py
@@ -0,0 +1,41 @@
+import torch
+from diffsynth import ModelManager, WanVideoPipeline, save_video, VideoData
+from modelscope import snapshot_download
+
+
+# Download models
+snapshot_download("Wan-AI/Wan2.1-T2V-1.3B", local_dir="models/Wan-AI/Wan2.1-T2V-1.3B")
+snapshot_download("DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1", local_dir="models/DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1")
+
+# Load models
+model_manager = ModelManager(device="cpu")
+model_manager.load_models(
+    [
+        "models/Wan-AI/Wan2.1-T2V-1.3B/diffusion_pytorch_model.safetensors",
+        "models/Wan-AI/Wan2.1-T2V-1.3B/models_t5_umt5-xxl-enc-bf16.pth",
+        "models/Wan-AI/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth",
+        "models/DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1/model.safetensors",
+    ],
+    torch_dtype=torch.bfloat16, # You can set `torch_dtype=torch.float8_e4m3fn` to enable FP8 quantization.
+)
+pipe = WanVideoPipeline.from_model_manager(model_manager, torch_dtype=torch.bfloat16, device="cuda")
+pipe.enable_vram_management(num_persistent_param_in_dit=None)
+
+# Text-to-video
+video = pipe(
+    prompt="纪实摄影风格画面,一只活泼的小狗在绿茵茵的草地上迅速奔跑。小狗毛色棕黄,两只耳朵立起,神情专注而欢快。阳光洒在它身上,使得毛发看上去格外柔软而闪亮。背景是一片开阔的草地,偶尔点缀着几朵野花,远处隐约可见蓝天和几片白云。透视感鲜明,捕捉小狗奔跑时的动感和四周草地的生机。中景侧面移动视角。",
+    negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
+    num_inference_steps=50,
+    seed=1, tiled=True,
+    motion_bucket_id=0
+)
+save_video(video, "video_slow.mp4", fps=15, quality=5)
+
+video = pipe(
+    prompt="纪实摄影风格画面,一只活泼的小狗在绿茵茵的草地上迅速奔跑。小狗毛色棕黄,两只耳朵立起,神情专注而欢快。阳光洒在它身上,使得毛发看上去格外柔软而闪亮。背景是一片开阔的草地,偶尔点缀着几朵野花,远处隐约可见蓝天和几片白云。透视感鲜明,捕捉小狗奔跑时的动感和四周草地的生机。中景侧面移动视角。",
+    negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
+    num_inference_steps=50,
+    seed=1, tiled=True,
+    motion_bucket_id=100
+)
+save_video(video, "video_fast.mp4", fps=15, quality=5)
\ No newline at end of file
diff --git a/examples/wanvideo/wan_fun_InP.py b/examples/wanvideo/wan_fun_InP.py
new file mode 100644
index 0000000..ae23ee0
--- /dev/null
+++ b/examples/wanvideo/wan_fun_InP.py
@@ -0,0 +1,42 @@
+import torch
+from diffsynth import ModelManager, WanVideoPipeline, save_video, VideoData
+from modelscope import snapshot_download, dataset_snapshot_download
+from PIL import Image
+
+
+# Download models
+snapshot_download("PAI/Wan2.1-Fun-1.3B-InP", local_dir="models/PAI/Wan2.1-Fun-1.3B-InP")
+
+# Load models
+model_manager = ModelManager(device="cpu")
+model_manager.load_models(
+    [
+        "models/PAI/Wan2.1-Fun-1.3B-InP/diffusion_pytorch_model.safetensors",
+        "models/PAI/Wan2.1-Fun-1.3B-InP/models_t5_umt5-xxl-enc-bf16.pth",
+        "models/PAI/Wan2.1-Fun-1.3B-InP/Wan2.1_VAE.pth",
+        "models/PAI/Wan2.1-Fun-1.3B-InP/models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth",
+    ],
+    torch_dtype=torch.bfloat16, # You can set `torch_dtype=torch.float8_e4m3fn` to enable FP8 quantization.
+)
+pipe = WanVideoPipeline.from_model_manager(model_manager, torch_dtype=torch.bfloat16, device="cuda")
+pipe.enable_vram_management(num_persistent_param_in_dit=None)
+
+# Download example image
+dataset_snapshot_download(
+    dataset_id="DiffSynth-Studio/examples_in_diffsynth",
+    local_dir="./",
+    allow_file_pattern=f"data/examples/wan/input_image.jpg"
+)
+image = Image.open("data/examples/wan/input_image.jpg")
+
+# Image-to-video
+video = pipe(
+    prompt="一艘小船正勇敢地乘风破浪前行。蔚蓝的大海波涛汹涌,白色的浪花拍打着船身,但小船毫不畏惧,坚定地驶向远方。阳光洒在水面上,闪烁着金色的光芒,为这壮丽的场景增添了一抹温暖。镜头拉近,可以看到船上的旗帜迎风飘扬,象征着不屈的精神与冒险的勇气。这段画面充满力量,激励人心,展现了面对挑战时的无畏与执着。",
+    negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
+    num_inference_steps=50,
+    input_image=image,
+    # You can input `end_image=xxx` to control the last frame of the video.
+    # The model will automatically generate the dynamic content between `input_image` and `end_image`.
+    seed=1, tiled=True
+)
+save_video(video, "video1.mp4", fps=15, quality=5)
diff --git a/examples/wanvideo/wan_fun_control.py b/examples/wanvideo/wan_fun_control.py
new file mode 100644
index 0000000..e2c4d0c
--- /dev/null
+++ b/examples/wanvideo/wan_fun_control.py
@@ -0,0 +1,40 @@
+import torch
+from diffsynth import ModelManager, WanVideoPipeline, save_video, VideoData
+from modelscope import snapshot_download, dataset_snapshot_download
+from PIL import Image
+
+
+# Download models
+snapshot_download("PAI/Wan2.1-Fun-1.3B-Control", local_dir="models/PAI/Wan2.1-Fun-1.3B-Control")
+
+# Load models
+model_manager = ModelManager(device="cpu")
+model_manager.load_models(
+    [
+        "models/PAI/Wan2.1-Fun-1.3B-Control/diffusion_pytorch_model.safetensors",
+        "models/PAI/Wan2.1-Fun-1.3B-Control/models_t5_umt5-xxl-enc-bf16.pth",
+        "models/PAI/Wan2.1-Fun-1.3B-Control/Wan2.1_VAE.pth",
+        "models/PAI/Wan2.1-Fun-1.3B-Control/models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth",
+    ],
+    torch_dtype=torch.bfloat16, # You can set `torch_dtype=torch.float8_e4m3fn` to enable FP8 quantization.
+)
+pipe = WanVideoPipeline.from_model_manager(model_manager, torch_dtype=torch.bfloat16, device="cuda")
+pipe.enable_vram_management(num_persistent_param_in_dit=None)
+
+# Download example video
+dataset_snapshot_download(
+    dataset_id="DiffSynth-Studio/examples_in_diffsynth",
+    local_dir="./",
+    allow_file_pattern=f"data/examples/wan/control_video.mp4"
+)
+
+# Control-to-video
+control_video = VideoData("data/examples/wan/control_video.mp4", height=832, width=576)
+video = pipe(
+    prompt="扁平风格动漫,一位长发少女优雅起舞。她五官精致,大眼睛明亮有神,黑色长发柔顺光泽。身穿淡蓝色T恤和深蓝色牛仔短裤。背景是粉色。",
+    negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
+    num_inference_steps=50,
+    control_video=control_video, height=832, width=576, num_frames=49,
+    seed=1, tiled=True
+)
+save_video(video, "video1.mp4", fps=15, quality=5)

From f6c6e3c640b2c4fcc3673484c18f64f722f9a4be Mon Sep 17 00:00:00 2001
From: Artiprocher
Date: Tue, 8 Apr 2025 17:19:54 +0800
Subject: [PATCH 4/5] support more wan models

---
 examples/wanvideo/README.md | 22 ++++++++++++++++++++++
 1 file changed, 22 insertions(+)

diff --git a/examples/wanvideo/README.md b/examples/wanvideo/README.md
index f3064a0..3530976 100644
--- a/examples/wanvideo/README.md
+++ b/examples/wanvideo/README.md
@@ -27,6 +27,28 @@ pip install -e .
 |PAI Team|1.3B Control|[Link](https://modelscope.cn/models/PAI/Wan2.1-Fun-1.3B-Control)|[wan_fun_control.py](./wan_fun_control.py)|
 |PAI Team|14B Control|[Link](https://modelscope.cn/models/PAI/Wan2.1-Fun-14B-Control)|[wan_fun_control.py](./wan_fun_control.py)|
+Base model features
+
+||Text-to-video|Image-to-video|End frame|Control|
+|-|-|-|-|-|
+|1.3B text-to-video|✅||||
+|14B text-to-video|✅||||
+|14B image-to-video 480P||✅|||
+|14B image-to-video 720P||✅|||
+|1.3B InP||✅|✅||
+|14B InP||✅|✅||
+|1.3B Control||||✅|
+|14B Control||||✅|
+
+Adapter model compatibility
+
+||1.3B text-to-video|1.3B InP|
+|-|-|-|
+|1.3B aesthetics LoRA|✅||
+|1.3B Highres-fix LoRA|✅||
+|1.3B ExVideo LoRA|✅||
+|1.3B Speed Control adapter|✅|✅|
+
 ## VRAM Usage

 * Fine-grained offload: We recommend that users adjust the `num_persistent_param_in_dit` settings to find an optimal balance between speed and VRAM requirements. See [`./wan_14b_text_to_video.py`](./wan_14b_text_to_video.py).

From 3cc9764bc90fe9b4e5dca7a0118bb3a81ee8987d Mon Sep 17 00:00:00 2001
From: Artiprocher
Date: Tue, 8 Apr 2025 19:22:53 +0800
Subject: [PATCH 5/5] support more wan models

---
 diffsynth/configs/model_config.py |  1 +
 diffsynth/models/wan_video_dit.py | 14 ++++++++++++++
 2 files changed, 15 insertions(+)

diff --git a/diffsynth/configs/model_config.py b/diffsynth/configs/model_config.py
index 052d63e..8fdb50a 100644
--- a/diffsynth/configs/model_config.py
+++ b/diffsynth/configs/model_config.py
@@ -124,6 +124,7 @@ model_loader_configs = [
     (None, "6d6ccde6845b95ad9114ab993d917893", ["wan_video_dit"], [WanModel], "civitai"),
     (None, "6bfcfb3b342cb286ce886889d519a77e", ["wan_video_dit"], [WanModel], "civitai"),
     (None, "349723183fc063b2bfc10bb2835cf677", ["wan_video_dit"], [WanModel], "civitai"),
+    (None, "efa44cddf936c70abd0ea28b6cbe946c", ["wan_video_dit"], [WanModel], "civitai"),
     (None, "cb104773c6c2cb6df4f9529ad5c60d0b", ["wan_video_dit"], [WanModel], "diffusers"),
     (None, "9c8818c2cbea55eca56c7b447df170da", ["wan_video_text_encoder"], [WanTextEncoder], "civitai"),
     (None, "5941c53e207d62f20f9025686193c40b", ["wan_video_image_encoder"], [WanImageEncoder], "civitai"),
diff --git a/diffsynth/models/wan_video_dit.py b/diffsynth/models/wan_video_dit.py
index d423c4b..c999596 100644
--- a/diffsynth/models/wan_video_dit.py
+++ b/diffsynth/models/wan_video_dit.py
@@ -535,6 +535,20 @@ class WanModelStateDictConverter:
                 "num_layers": 30,
                 "eps": 1e-6
             }
+        elif hash_state_dict_keys(state_dict) == "efa44cddf936c70abd0ea28b6cbe946c":
+            config = {
+                "has_image_input": True,
+                "patch_size": [1, 2, 2],
+                "in_dim": 48,
+                "dim": 5120,
+                "ffn_dim": 13824,
+                "freq_dim": 256,
+                "text_dim": 4096,
+                "out_dim": 16,
+                "num_heads": 40,
+                "num_layers": 40,
+                "eps": 1e-6
+            }
         else:
             config = {}
         return state_dict, config
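A note on the new `in_dim` values (36 for the Fun-InP checkpoints, 48 for the Fun-Control checkpoints): they follow from how the pipeline assembles the conditioning tensor `y` that is concatenated with the 16-channel noisy latents when `has_image_input` is set. The following minimal sketch of that channel arithmetic is illustrative only and is not part of the patches above:

```python
# Illustrative channel bookkeeping, assuming the tensor shapes used in the hunks above.
noisy_latent_channels = 16  # x: the noisy video latents from the VAE

# Fun-InP (hashes 6d6ccde6... / 6bfcfb3b...): encode_image builds
# y = [4 mask channels, 16 image latents]
inp_in_dim = noisy_latent_channels + 4 + 16
assert inp_in_dim == 36  # matches "in_dim": 36

# Fun-Control (hashes 34972318... / efa44cdd...): prepare_controlnet_kwargs keeps the
# last 16 channels of y and prepends 16 control-video latents
control_in_dim = noisy_latent_channels + 16 + 16
assert control_in_dim == 48  # matches "in_dim": 48
```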