diff --git a/diffsynth/models/wan_video_vace.py b/diffsynth/models/wan_video_vace.py
index ff5eab4..40f3804 100644
--- a/diffsynth/models/wan_video_vace.py
+++ b/diffsynth/models/wan_video_vace.py
@@ -50,7 +50,11 @@ class VaceWanModel(torch.nn.Module):
# vace patch embeddings
self.vace_patch_embedding = torch.nn.Conv3d(vace_in_dim, dim, kernel_size=patch_size, stride=patch_size)
- def forward(self, x, vace_context, context, t_mod, freqs):
+ def forward(
+ self, x, vace_context, context, t_mod, freqs,
+ use_gradient_checkpointing: bool = False,
+ use_gradient_checkpointing_offload: bool = False,
+ ):
c = [self.vace_patch_embedding(u.unsqueeze(0)) for u in vace_context]
c = [u.flatten(2).transpose(1, 2) for u in c]
c = torch.cat([
@@ -58,8 +62,27 @@ class VaceWanModel(torch.nn.Module):
dim=1) for u in c
])
+ def create_custom_forward(module):
+ def custom_forward(*inputs):
+ return module(*inputs)
+ return custom_forward
+
for block in self.vace_blocks:
- c = block(c, x, context, t_mod, freqs)
+ if use_gradient_checkpointing_offload:
+ with torch.autograd.graph.save_on_cpu():
+ c = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(block),
+ c, x, context, t_mod, freqs,
+ use_reentrant=False,
+ )
+ elif use_gradient_checkpointing:
+ c = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(block),
+ c, x, context, t_mod, freqs,
+ use_reentrant=False,
+ )
+ else:
+ c = block(c, x, context, t_mod, freqs)
hints = torch.unbind(c)[:-1]
return hints
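
For reviewers unfamiliar with this pattern, below is a minimal, self-contained sketch (with a toy block standing in for the real VACE block) of what the two new flags do: `use_gradient_checkpointing` recomputes each block during the backward pass instead of storing its activations, and `use_gradient_checkpointing_offload` additionally parks whatever tensors are saved in CPU RAM via `torch.autograd.graph.save_on_cpu()`.

```python
import torch
import torch.utils.checkpoint

class ToyBlock(torch.nn.Module):
    """Stand-in for a VACE block; only the call pattern matters here."""
    def __init__(self, dim=64):
        super().__init__()
        self.proj = torch.nn.Linear(dim, dim)

    def forward(self, c, x):
        return self.proj(c) + x

def run_blocks(blocks, c, x, use_gradient_checkpointing=False, use_gradient_checkpointing_offload=False):
    def create_custom_forward(module):
        def custom_forward(*inputs):
            return module(*inputs)
        return custom_forward

    for block in blocks:
        if use_gradient_checkpointing_offload:
            # Recompute in backward AND keep the saved tensors in CPU RAM.
            with torch.autograd.graph.save_on_cpu():
                c = torch.utils.checkpoint.checkpoint(create_custom_forward(block), c, x, use_reentrant=False)
        elif use_gradient_checkpointing:
            # Recompute in backward; activations are not stored on the GPU.
            c = torch.utils.checkpoint.checkpoint(create_custom_forward(block), c, x, use_reentrant=False)
        else:
            c = block(c, x)
    return c

blocks = torch.nn.ModuleList([ToyBlock() for _ in range(4)])
c = torch.randn(2, 16, 64, requires_grad=True)
run_blocks(blocks, c, torch.randn(2, 16, 64), use_gradient_checkpointing=True).sum().backward()
```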
diff --git a/diffsynth/pipelines/wan_video_new.py b/diffsynth/pipelines/wan_video_new.py
index 809d0e5..a167d8e 100644
--- a/diffsynth/pipelines/wan_video_new.py
+++ b/diffsynth/pipelines/wan_video_new.py
@@ -1,4 +1,4 @@
-import torch, warnings, glob, os
+import torch, warnings, glob, os, types
import numpy as np
from PIL import Image
from einops import repeat, reduce
@@ -213,6 +213,7 @@ class WanVideoPipeline(BasePipeline):
WanVideoUnit_FunReference(),
WanVideoUnit_SpeedControl(),
WanVideoUnit_VACE(),
+ WanVideoUnit_UnifiedSequenceParallel(),
WanVideoUnit_TeaCache(),
WanVideoUnit_CfgMerger(),
]
@@ -374,6 +375,30 @@ class WanVideoPipeline(BasePipeline):
),
vram_limit=vram_limit,
)
+
+
+ def initialize_usp(self):
+ import torch.distributed as dist
+ from xfuser.core.distributed import initialize_model_parallel, init_distributed_environment
+ dist.init_process_group(backend="nccl", init_method="env://")
+ init_distributed_environment(rank=dist.get_rank(), world_size=dist.get_world_size())
+ initialize_model_parallel(
+ sequence_parallel_degree=dist.get_world_size(),
+ ring_degree=1,
+ ulysses_degree=dist.get_world_size(),
+ )
+ torch.cuda.set_device(dist.get_rank())
+
+
+ def enable_usp(self):
+ from xfuser.core.distributed import get_sequence_parallel_world_size
+ from ..distributed.xdit_context_parallel import usp_attn_forward, usp_dit_forward
+
+ for block in self.dit.blocks:
+ block.self_attn.forward = types.MethodType(usp_attn_forward, block.self_attn)
+ self.dit.forward = types.MethodType(usp_dit_forward, self.dit)
+ self.sp_size = get_sequence_parallel_world_size()
+ self.use_unified_sequence_parallel = True
@staticmethod
@@ -385,6 +410,7 @@ class WanVideoPipeline(BasePipeline):
local_model_path: str = "./models",
skip_download: bool = False,
redirect_common_files: bool = True,
+ use_usp=False,
):
# Redirect model path
if redirect_common_files:
@@ -412,6 +438,7 @@ class WanVideoPipeline(BasePipeline):
# Initialize pipeline
pipe = WanVideoPipeline(device=device, torch_dtype=torch_dtype)
+ if use_usp: pipe.initialize_usp()
pipe.text_encoder = model_manager.fetch_model("wan_video_text_encoder")
pipe.dit = model_manager.fetch_model("wan_video_dit")
pipe.vae = model_manager.fetch_model("wan_video_vae")
@@ -423,6 +450,9 @@ class WanVideoPipeline(BasePipeline):
tokenizer_config.download_if_necessary(local_model_path, skip_download=skip_download)
pipe.prompter.fetch_models(pipe.text_encoder)
pipe.prompter.fetch_tokenizer(tokenizer_config.path)
+
+ # Unified Sequence Parallel
+ if use_usp: pipe.enable_usp()
return pipe
@@ -483,11 +513,11 @@ class WanVideoPipeline(BasePipeline):
# Inputs
inputs_posi = {
"prompt": prompt,
- "tea_cache_l1_thresh": tea_cache_l1_thresh, "tea_cache_model_id": tea_cache_model_id,
+ "tea_cache_l1_thresh": tea_cache_l1_thresh, "tea_cache_model_id": tea_cache_model_id, "num_inference_steps": num_inference_steps,
}
inputs_nega = {
"negative_prompt": negative_prompt,
- "tea_cache_l1_thresh": tea_cache_l1_thresh, "tea_cache_model_id": tea_cache_model_id,
+ "tea_cache_l1_thresh": tea_cache_l1_thresh, "tea_cache_model_id": tea_cache_model_id, "num_inference_steps": num_inference_steps,
}
inputs_shared = {
"input_image": input_image,
@@ -499,7 +529,7 @@ class WanVideoPipeline(BasePipeline):
"seed": seed, "rand_device": rand_device,
"height": height, "width": width, "num_frames": num_frames,
"cfg_scale": cfg_scale, "cfg_merge": cfg_merge,
- "num_inference_steps": num_inference_steps, "sigma_shift": sigma_shift,
+ "sigma_shift": sigma_shift,
"motion_bucket_id": motion_bucket_id,
"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride,
"sliding_window_size": sliding_window_size, "sliding_window_stride": sliding_window_stride,
@@ -620,16 +650,20 @@ class WanVideoUnit_NoiseInitializer(PipelineUnit):
class WanVideoUnit_InputVideoEmbedder(PipelineUnit):
def __init__(self):
super().__init__(
- input_params=("input_video", "noise", "tiled", "tile_size", "tile_stride", "denoising_strength"),
+ input_params=("input_video", "noise", "tiled", "tile_size", "tile_stride", "vace_reference_image"),
onload_model_names=("vae",)
)
- def process(self, pipe: WanVideoPipeline, input_video, noise, tiled, tile_size, tile_stride, denoising_strength):
+ def process(self, pipe: WanVideoPipeline, input_video, noise, tiled, tile_size, tile_stride, vace_reference_image):
if input_video is None:
return {"latents": noise}
pipe.load_models_to_device(["vae"])
input_video = pipe.preprocess_video(input_video)
input_latents = pipe.vae.encode(input_video, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device)
+ if vace_reference_image is not None:
+ vace_reference_image = pipe.preprocess_video([vace_reference_image])
+ vace_reference_latents = pipe.vae.encode(vace_reference_image, device=pipe.device).to(dtype=pipe.torch_dtype, device=pipe.device)
+ input_latents = torch.concat([vace_reference_latents, input_latents], dim=2)
if pipe.scheduler.training:
return {"latents": noise, "input_latents": input_latents}
else:
@@ -829,6 +863,18 @@ class WanVideoUnit_VACE(PipelineUnit):
+class WanVideoUnit_UnifiedSequenceParallel(PipelineUnit):
+ def __init__(self):
+ super().__init__(input_params=())
+
+ def process(self, pipe: WanVideoPipeline):
+ if hasattr(pipe, "use_unified_sequence_parallel"):
+ if pipe.use_unified_sequence_parallel:
+ return {"use_unified_sequence_parallel": True}
+ return {}
+
+
+
class WanVideoUnit_TeaCache(PipelineUnit):
def __init__(self):
super().__init__(
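
`enable_usp` above patches `forward` on live module instances rather than subclassing. A minimal sketch of that `types.MethodType` mechanism follows (toy classes; the real `usp_attn_forward`/`usp_dit_forward` live in `diffsynth.distributed.xdit_context_parallel`):

```python
import types

class Attention:
    def forward(self, x):
        return f"local attention over {x}"

def usp_attn_forward(self, x):
    # The real replacement shards the sequence across ranks; this is a stub.
    return f"sequence-parallel attention over {x}"

attn = Attention()
attn.forward = types.MethodType(usp_attn_forward, attn)  # rebind on this instance only
print(attn.forward("tokens"))         # patched
print(Attention().forward("tokens"))  # fresh instances keep the original forward
```

Binding per instance keeps the patch reversible and avoids touching classes shared with non-USP pipelines.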
diff --git a/examples/wanvideo/README_zh.md b/examples/wanvideo/README_zh.md
index bb73ca1..8511e9c 100644
--- a/examples/wanvideo/README_zh.md
+++ b/examples/wanvideo/README_zh.md
@@ -13,13 +13,13 @@
|[PAI/Wan2.1-Fun-14B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-14B-Control)|Base model|`control_video`|[code](./model_inference/Wan2.1-Fun-14B-Control.py)|[code](./model_training/full/Wan2.1-Fun-14B-Control.sh)|[code](./model_training/validate_full/Wan2.1-Fun-14B-Control.py)|[code](./model_training/lora/Wan2.1-Fun-14B-Control.sh)|[code](./model_training/validate_lora/Wan2.1-Fun-14B-Control.py)|
|[PAI/Wan2.1-Fun-V1.1-1.3B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-1.3B-Control)|Base model|`control_video`, `reference_image`|[code](./model_inference/Wan2.1-Fun-V1.1-1.3B-Control.py)|[code](./model_training/full/Wan2.1-Fun-V1.1-1.3B-Control.sh)|[code](./model_training/validate_full/Wan2.1-Fun-V1.1-1.3B-Control.py)|[code](./model_training/lora/Wan2.1-Fun-V1.1-1.3B-Control.sh)|[code](./model_training/validate_lora/Wan2.1-Fun-V1.1-1.3B-Control.py)|
|[PAI/Wan2.1-Fun-V1.1-14B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-14B-Control)|Base model|`control_video`, `reference_image`|[code](./model_inference/Wan2.1-Fun-V1.1-14B-Control.py)|[code](./model_training/full/Wan2.1-Fun-V1.1-14B-Control.sh)|[code](./model_training/validate_full/Wan2.1-Fun-V1.1-14B-Control.py)|[code](./model_training/lora/Wan2.1-Fun-V1.1-14B-Control.sh)|[code](./model_training/validate_lora/Wan2.1-Fun-V1.1-14B-Control.py)|
-|[PAI/Wan2.1-Fun-V1.1-1.3B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-1.3B-InP)|Base model|`input_image`, `end_image`||||||
-|[PAI/Wan2.1-Fun-V1.1-14B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-14B-InP)|Base model|`input_image`, `end_image`||||||
+|[PAI/Wan2.1-Fun-V1.1-1.3B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-1.3B-InP)|Base model|`input_image`, `end_image`|[code](./model_inference/Wan2.1-Fun-V1.1-1.3B-InP.py)|[code](./model_training/full/Wan2.1-Fun-V1.1-1.3B-InP.sh)|[code](./model_training/validate_full/Wan2.1-Fun-V1.1-1.3B-InP.py)|[code](./model_training/lora/Wan2.1-Fun-V1.1-1.3B-InP.sh)|[code](./model_training/validate_lora/Wan2.1-Fun-V1.1-1.3B-InP.py)|
+|[PAI/Wan2.1-Fun-V1.1-14B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-14B-InP)|Base model|`input_image`, `end_image`|[code](./model_inference/Wan2.1-Fun-V1.1-14B-InP.py)|[code](./model_training/full/Wan2.1-Fun-V1.1-14B-InP.sh)|[code](./model_training/validate_full/Wan2.1-Fun-V1.1-14B-InP.py)|[code](./model_training/lora/Wan2.1-Fun-V1.1-14B-InP.sh)|[code](./model_training/validate_lora/Wan2.1-Fun-V1.1-14B-InP.py)|
|[PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera)|Base model|`control_camera_video`, `input_image`|[code](./model_inference/Wan2.1-Fun-V1.1-1.3B-Control-Camera.py)|||||
|[PAI/Wan2.1-Fun-V1.1-14B-Control-Camera](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-14B-Control-Camera)|Base model|||||||
-|[iic/VACE-Wan2.1-1.3B-Preview](https://modelscope.cn/models/iic/VACE-Wan2.1-1.3B-Preview)|Adapter|`vace_control_video`, `vace_reference_image`|[code](./model_inference/Wan2.1-VACE-1.3B-Preview.py)|[code](./model_training/full/VACE-Wan2.1-1.3B-Preview.sh)|[code](./model_training/validate_full/VACE-Wan2.1-1.3B-Preview.py)|[code](./model_training/lora/VACE-Wan2.1-1.3B-Preview.sh)|[code](./model_training/validate_lora/VACE-Wan2.1-1.3B-Preview.py)|
-|[Wan-AI/Wan2.1-VACE-1.3B](https://modelscope.cn/models/Wan-AI/Wan2.1-VACE-1.3B)|Adapter|`vace_control_video`, `vace_reference_image`|[code](./model_inference/Wan2.1-VACE-1.3B.py)|||||
-|[Wan-AI/Wan2.1-VACE-14B](https://modelscope.cn/models/Wan-AI/Wan2.1-VACE-14B)|Adapter|`vace_control_video`, `vace_reference_image`|[code](./model_inference/Wan2.1-VACE-14B.py)|||||
+|[iic/VACE-Wan2.1-1.3B-Preview](https://modelscope.cn/models/iic/VACE-Wan2.1-1.3B-Preview)|Adapter|`vace_control_video`, `vace_reference_image`|[code](./model_inference/Wan2.1-VACE-1.3B-Preview.py)|[code](./model_training/full/Wan2.1-VACE-1.3B-Preview.sh)|[code](./model_training/validate_full/Wan2.1-VACE-1.3B-Preview.py)|[code](./model_training/lora/Wan2.1-VACE-1.3B-Preview.sh)|[code](./model_training/validate_lora/Wan2.1-VACE-1.3B-Preview.py)|
+|[Wan-AI/Wan2.1-VACE-1.3B](https://modelscope.cn/models/Wan-AI/Wan2.1-VACE-1.3B)|Adapter|`vace_control_video`, `vace_reference_image`|[code](./model_inference/Wan2.1-VACE-1.3B.py)|[code](./model_training/full/Wan2.1-VACE-1.3B.sh)|[code](./model_training/validate_full/Wan2.1-VACE-1.3B.py)|[code](./model_training/lora/Wan2.1-VACE-1.3B.sh)|[code](./model_training/validate_lora/Wan2.1-VACE-1.3B.py)|
+|[Wan-AI/Wan2.1-VACE-14B](https://modelscope.cn/models/Wan-AI/Wan2.1-VACE-14B)|Adapter|`vace_control_video`, `vace_reference_image`|[code](./model_inference/Wan2.1-VACE-14B.py)|[code](./model_training/full/Wan2.1-VACE-14B.sh)|[code](./model_training/validate_full/Wan2.1-VACE-14B.py)|[code](./model_training/lora/Wan2.1-VACE-14B.sh)|[code](./model_training/validate_lora/Wan2.1-VACE-14B.py)|
|[DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1](https://modelscope.cn/models/DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1)|Adapter|`motion_bucket_id`|[code](./model_inference/Wan2.1-1.3b-speedcontrol-v1.py)|[code](./model_training/full/Wan2.1-1.3b-speedcontrol-v1.sh)|[code](./model_training/validate_full/Wan2.1-1.3b-speedcontrol-v1.py)|[code](./model_training/lora/Wan2.1-1.3b-speedcontrol-v1.sh)|[code](./model_training/validate_lora/Wan2.1-1.3b-speedcontrol-v1.py)|
## Model Inference
@@ -78,6 +78,7 @@ ModelConfig(path=[
* `local_model_path`: Path used to store downloaded models. Default: `"./models"`.
* `skip_download`: Whether to skip downloading. Default: `False`. If your network cannot reach [ModelScope](https://modelscope.cn/), download the required files manually and set this to `True`.
* `redirect_common_files`: Whether to redirect duplicated model files. Default: `True`. Since the Wan series contains multiple base models whose text encoders and similar modules are identical, we redirect their model paths to avoid downloading the same files repeatedly.
+* `use_usp`: Whether to enable Unified Sequence Parallel. Default: `False`. Used for multi-GPU parallel inference.
@@ -142,6 +143,23 @@ FP8 quantization greatly reduces VRAM usage but provides no speedup; some models
+
+
+Inference Acceleration
+
+Wan supports multiple acceleration approaches, including:
+
+* Efficient attention implementations: when any of the following attention implementations is installed in your Python environment, it will be enabled automatically according to the priority below.
+  * [Flash Attention 3](https://github.com/Dao-AILab/flash-attention)
+  * [Flash Attention 2](https://github.com/Dao-AILab/flash-attention)
+  * [Sage Attention](https://github.com/thu-ml/SageAttention)
+  * [torch SDPA](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html) (default; installing `torch>=2.5.0` is recommended)
+* Unified sequence parallelism: sequence parallelism implemented on top of [xDiT](https://github.com/xdit-project/xDiT). See the [example code](./acceleration/unified_sequence_parallel.py), and run it with `torchrun --standalone --nproc_per_node=8 examples/wanvideo/acceleration/unified_sequence_parallel.py`.
+* TeaCache: the [TeaCache](https://github.com/ali-vilab/TeaCache) acceleration technique. See the [example code](./acceleration/teacache.py).
+
+
+
+
Input Parameters
@@ -224,6 +242,8 @@ Wan series models are trained through the unified [`./model_training/train.py`](./model_trai
* VRAM management
  * `--use_gradient_checkpointing_offload`: Whether to offload gradient checkpointing activations to CPU memory.
+In addition, the training framework is built on [`accelerate`](https://huggingface.co/docs/accelerate/index); running `accelerate config` before training lets you configure GPU-related settings. For some training scripts (e.g., full training of the 14B models), we provide recommended `accelerate` config files, which are referenced in the corresponding training scripts.
+
diff --git a/examples/wanvideo/wan_1.3b_text_to_video_accelerate.py b/examples/wanvideo/acceleration/teacache.py
similarity index 59%
rename from examples/wanvideo/wan_1.3b_text_to_video_accelerate.py
rename to examples/wanvideo/acceleration/teacache.py
index b56915c..b88656a 100644
--- a/examples/wanvideo/wan_1.3b_text_to_video_accelerate.py
+++ b/examples/wanvideo/acceleration/teacache.py
@@ -1,34 +1,27 @@
import torch
-from diffsynth import ModelManager, WanVideoPipeline, save_video, VideoData
-from modelscope import snapshot_download
+from PIL import Image
+from diffsynth import save_video, VideoData
+from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig
-# Download models
-snapshot_download("Wan-AI/Wan2.1-T2V-1.3B", local_dir="models/Wan-AI/Wan2.1-T2V-1.3B")
-
-# Load models
-model_manager = ModelManager(device="cpu")
-model_manager.load_models(
- [
- "models/Wan-AI/Wan2.1-T2V-1.3B/diffusion_pytorch_model.safetensors",
- "models/Wan-AI/Wan2.1-T2V-1.3B/models_t5_umt5-xxl-enc-bf16.pth",
- "models/Wan-AI/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth",
+pipe = WanVideoPipeline.from_pretrained(
+ torch_dtype=torch.bfloat16,
+ device="cuda",
+ model_configs=[
+ ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"),
+ ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"),
+ ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"),
],
- torch_dtype=torch.bfloat16, # You can set `torch_dtype=torch.float8_e4m3fn` to enable FP8 quantization.
)
-pipe = WanVideoPipeline.from_model_manager(model_manager, torch_dtype=torch.bfloat16, device="cuda")
-pipe.enable_vram_management(num_persistent_param_in_dit=None)
+pipe.enable_vram_management()
+
-# Text-to-video
video = pipe(
prompt="纪实摄影风格画面,一只活泼的小狗在绿茵茵的草地上迅速奔跑。小狗毛色棕黄,两只耳朵立起,神情专注而欢快。阳光洒在它身上,使得毛发看上去格外柔软而闪亮。背景是一片开阔的草地,偶尔点缀着几朵野花,远处隐约可见蓝天和几片白云。透视感鲜明,捕捉小狗奔跑时的动感和四周草地的生机。中景侧面移动视角。",
negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
- num_inference_steps=50,
seed=0, tiled=True,
# TeaCache parameters
tea_cache_l1_thresh=0.05, # The larger this value is, the faster the speed, but the worse the visual quality.
tea_cache_model_id="Wan2.1-T2V-1.3B", # Choose one in (Wan2.1-T2V-1.3B, Wan2.1-T2V-14B, Wan2.1-I2V-14B-480P, Wan2.1-I2V-14B-720P).
)
save_video(video, "video1.mp4", fps=15, quality=5)
-
-# TeaCache doesn't support video-to-video
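
If you want to tune the speed/quality trade-off, a small sweep is the easiest way to pick a threshold. Below is a sketch reusing the `pipe` object from the script above; the specific threshold values are illustrative only.

```python
# Hypothetical sweep: larger tea_cache_l1_thresh skips more DiT steps
# (faster, lower fidelity); smaller values skip fewer steps.
for thresh in (0.03, 0.05, 0.08):
    video = pipe(
        prompt="...",            # reuse the prompt from the script above
        negative_prompt="...",   # reuse the negative prompt from the script above
        seed=0, tiled=True,
        tea_cache_l1_thresh=thresh,
        tea_cache_model_id="Wan2.1-T2V-1.3B",
    )
    save_video(video, f"video_teacache_{thresh}.mp4", fps=15, quality=5)
```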
diff --git a/examples/wanvideo/acceleration/unified_sequence_parallel.py b/examples/wanvideo/acceleration/unified_sequence_parallel.py
new file mode 100644
index 0000000..44b580b
--- /dev/null
+++ b/examples/wanvideo/acceleration/unified_sequence_parallel.py
@@ -0,0 +1,27 @@
+import torch
+from PIL import Image
+from diffsynth import save_video, VideoData
+from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig
+import torch.distributed as dist
+
+
+pipe = WanVideoPipeline.from_pretrained(
+ torch_dtype=torch.bfloat16,
+ device="cuda",
+ use_usp=True,
+ model_configs=[
+ ModelConfig(model_id="Wan-AI/Wan2.1-T2V-14B", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"),
+ ModelConfig(model_id="Wan-AI/Wan2.1-T2V-14B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"),
+ ModelConfig(model_id="Wan-AI/Wan2.1-T2V-14B", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"),
+ ],
+)
+pipe.enable_vram_management()
+
+
+video = pipe(
+ prompt="一名宇航员身穿太空服,面朝镜头骑着一匹机械马在火星表面驰骋。红色的荒凉地表延伸至远方,点缀着巨大的陨石坑和奇特的岩石结构。机械马的步伐稳健,扬起微弱的尘埃,展现出未来科技与原始探索的完美结合。宇航员手持操控装置,目光坚定,仿佛正在开辟人类的新疆域。背景是深邃的宇宙和蔚蓝的地球,画面既科幻又充满希望,让人不禁畅想未来的星际生活。",
+ negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
+ seed=0, tiled=True,
+)
+if dist.get_rank() == 0:
+ save_video(video, "video1.mp4", fps=15, quality=5)
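
For intuition, the tensor movement behind Ulysses-style sequence parallelism can be emulated in a single process with plain reshapes (this is not the xDiT implementation, just the idea): with world size W, each rank starts with seq_len/W tokens of all heads, and an all-to-all exchanges that for all tokens of heads/W heads, so full-sequence attention runs locally on each rank.

```python
import torch

# Assumed toy sizes; in the script above W would be the torchrun world size.
W, seq_len, heads, head_dim = 4, 16, 8, 32
x = torch.randn(seq_len, heads, head_dim)

shards = list(x.chunk(W, dim=0))  # rank r starts with its token shard, all heads
after_all_to_all = [
    torch.cat([shard.chunk(W, dim=1)[r] for shard in shards], dim=0)
    for r in range(W)
]  # after the exchange: rank r holds ALL tokens for its subset of heads
assert after_all_to_all[0].shape == (seq_len, heads // W, head_dim)
```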
diff --git a/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-1.3B-InP.py b/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-1.3B-InP.py
new file mode 100644
index 0000000..f2fc560
--- /dev/null
+++ b/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-1.3B-InP.py
@@ -0,0 +1,36 @@
+import torch
+from PIL import Image
+from diffsynth import save_video, VideoData
+from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig
+from modelscope import dataset_snapshot_download
+
+
+pipe = WanVideoPipeline.from_pretrained(
+ torch_dtype=torch.bfloat16,
+ device="cuda",
+ model_configs=[
+ ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-1.3B-InP", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"),
+ ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-1.3B-InP", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"),
+ ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-1.3B-InP", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"),
+ ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-1.3B-InP", origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth", offload_device="cpu"),
+ ],
+)
+pipe.enable_vram_management()
+
+dataset_snapshot_download(
+ dataset_id="DiffSynth-Studio/examples_in_diffsynth",
+ local_dir="./",
+    allow_file_pattern="data/examples/wan/input_image.jpg"
+)
+image = Image.open("data/examples/wan/input_image.jpg")
+
+# First and last frame to video
+video = pipe(
+ prompt="一艘小船正勇敢地乘风破浪前行。蔚蓝的大海波涛汹涌,白色的浪花拍打着船身,但小船毫不畏惧,坚定地驶向远方。阳光洒在水面上,闪烁着金色的光芒,为这壮丽的场景增添了一抹温暖。镜头拉近,可以看到船上的旗帜迎风飘扬,象征着不屈的精神与冒险的勇气。这段画面充满力量,激励人心,展现了面对挑战时的无畏与执着。",
+ negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
+ input_image=image,
+ seed=0, tiled=True
+    # You can pass `end_image=xxx` to control the last frame of the video.
+ # The model will automatically generate the dynamic content between `input_image` and `end_image`.
+)
+save_video(video, "video.mp4", fps=15, quality=5)
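
A hedged variant of the call above showing the `end_image` control mentioned in the comment; `"end_image.jpg"` is a placeholder path you would supply yourself, not a file shipped with the example dataset.

```python
# Placeholder first-and-last-frame example, reusing `pipe`, `image`, `Image`,
# and `save_video` from the script above.
end_image = Image.open("end_image.jpg")  # hypothetical path
video = pipe(
    prompt="...",            # reuse the prompt from the script above
    negative_prompt="...",   # reuse the negative prompt from the script above
    input_image=image,
    end_image=end_image,     # the model fills in the motion between the two frames
    seed=0, tiled=True,
)
save_video(video, "video_first_last.mp4", fps=15, quality=5)
```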
diff --git a/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-14B-InP.py b/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-14B-InP.py
new file mode 100644
index 0000000..334e981
--- /dev/null
+++ b/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-14B-InP.py
@@ -0,0 +1,36 @@
+import torch
+from PIL import Image
+from diffsynth import save_video, VideoData
+from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig
+from modelscope import dataset_snapshot_download
+
+
+pipe = WanVideoPipeline.from_pretrained(
+ torch_dtype=torch.bfloat16,
+ device="cuda",
+ model_configs=[
+ ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-14B-InP", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"),
+ ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-14B-InP", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"),
+ ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-14B-InP", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"),
+ ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-14B-InP", origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth", offload_device="cpu"),
+ ],
+)
+pipe.enable_vram_management()
+
+dataset_snapshot_download(
+ dataset_id="DiffSynth-Studio/examples_in_diffsynth",
+ local_dir="./",
+    allow_file_pattern="data/examples/wan/input_image.jpg"
+)
+image = Image.open("data/examples/wan/input_image.jpg")
+
+# First and last frame to video
+video = pipe(
+ prompt="一艘小船正勇敢地乘风破浪前行。蔚蓝的大海波涛汹涌,白色的浪花拍打着船身,但小船毫不畏惧,坚定地驶向远方。阳光洒在水面上,闪烁着金色的光芒,为这壮丽的场景增添了一抹温暖。镜头拉近,可以看到船上的旗帜迎风飘扬,象征着不屈的精神与冒险的勇气。这段画面充满力量,激励人心,展现了面对挑战时的无畏与执着。",
+ negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
+ input_image=image,
+ seed=0, tiled=True
+    # You can pass `end_image=xxx` to control the last frame of the video.
+ # The model will automatically generate the dynamic content between `input_image` and `end_image`.
+)
+save_video(video, "video.mp4", fps=15, quality=5)
diff --git a/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-1.3B-InP.sh b/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-1.3B-InP.sh
new file mode 100644
index 0000000..d3b280f
--- /dev/null
+++ b/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-1.3B-InP.sh
@@ -0,0 +1,14 @@
+accelerate launch examples/wanvideo/model_training/train.py \
+ --dataset_base_path data/example_video_dataset \
+ --dataset_metadata_path data/example_video_dataset/metadata.csv \
+ --height 480 \
+ --width 832 \
+ --dataset_repeat 100 \
+ --model_id_with_origin_paths "PAI/Wan2.1-Fun-V1.1-1.3B-InP:diffusion_pytorch_model*.safetensors,PAI/Wan2.1-Fun-V1.1-1.3B-InP:models_t5_umt5-xxl-enc-bf16.pth,PAI/Wan2.1-Fun-V1.1-1.3B-InP:Wan2.1_VAE.pth,PAI/Wan2.1-Fun-V1.1-1.3B-InP:models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth" \
+ --learning_rate 1e-5 \
+ --num_epochs 2 \
+ --remove_prefix_in_ckpt "pipe.dit." \
+ --output_path "./models/train/Wan2.1-Fun-V1.1-1.3B-InP_full" \
+ --trainable_models "dit" \
+ --input_contains_input_image \
+ --input_contains_end_image
\ No newline at end of file
diff --git a/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-14B-InP.sh b/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-14B-InP.sh
new file mode 100644
index 0000000..11e7cc3
--- /dev/null
+++ b/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-14B-InP.sh
@@ -0,0 +1,14 @@
+accelerate launch --config_file examples/wanvideo/model_training/full/accelerate_config_14B.yaml examples/wanvideo/model_training/train.py \
+ --dataset_base_path data/example_video_dataset \
+ --dataset_metadata_path data/example_video_dataset/metadata.csv \
+ --height 480 \
+ --width 832 \
+ --dataset_repeat 100 \
+ --model_id_with_origin_paths "PAI/Wan2.1-Fun-V1.1-14B-InP:diffusion_pytorch_model*.safetensors,PAI/Wan2.1-Fun-V1.1-14B-InP:models_t5_umt5-xxl-enc-bf16.pth,PAI/Wan2.1-Fun-V1.1-14B-InP:Wan2.1_VAE.pth,PAI/Wan2.1-Fun-V1.1-14B-InP:models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth" \
+ --learning_rate 1e-5 \
+ --num_epochs 2 \
+ --remove_prefix_in_ckpt "pipe.dit." \
+ --output_path "./models/train/Wan2.1-Fun-V1.1-14B-InP_full" \
+ --trainable_models "dit" \
+ --input_contains_input_image \
+ --input_contains_end_image
\ No newline at end of file
diff --git a/examples/wanvideo/model_training/full/Wan2.1-VACE-1.3B-Preview.sh b/examples/wanvideo/model_training/full/Wan2.1-VACE-1.3B-Preview.sh
new file mode 100644
index 0000000..9fb6c3e
--- /dev/null
+++ b/examples/wanvideo/model_training/full/Wan2.1-VACE-1.3B-Preview.sh
@@ -0,0 +1,17 @@
+accelerate launch examples/wanvideo/model_training/train.py \
+ --dataset_base_path data/example_video_dataset \
+ --dataset_metadata_path data/example_video_dataset/metadata_vace.csv \
+ --data_file_keys "video,vace_video,vace_reference_image" \
+ --height 480 \
+ --width 832 \
+ --num_frames 49 \
+ --dataset_repeat 100 \
+ --model_id_with_origin_paths "iic/VACE-Wan2.1-1.3B-Preview:diffusion_pytorch_model*.safetensors,iic/VACE-Wan2.1-1.3B-Preview:models_t5_umt5-xxl-enc-bf16.pth,iic/VACE-Wan2.1-1.3B-Preview:Wan2.1_VAE.pth" \
+ --learning_rate 1e-4 \
+ --num_epochs 2 \
+ --remove_prefix_in_ckpt "pipe.vace." \
+ --output_path "./models/train/Wan2.1-VACE-1.3B-Preview_full" \
+ --trainable_models "vace" \
+ --input_contains_vace_video \
+ --input_contains_vace_reference_image \
+ --use_gradient_checkpointing_offload
\ No newline at end of file
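
`--remove_prefix_in_ckpt "pipe.vace."` pairs with the `pipe.vace.load_state_dict(...)` calls in the validation scripts below. A sketch of the assumed semantics (the exact implementation lives in `train.py`):

```python
# Assumed behavior of --remove_prefix_in_ckpt: strip the training-wrapper
# prefix so the saved keys match the bare `pipe.vace` module.
state_dict = {
    "pipe.vace.vace_patch_embedding.weight": "...",
    "pipe.vace.vace_blocks.0.proj.weight": "...",
}
prefix = "pipe.vace."
stripped = {k[len(prefix):]: v for k, v in state_dict.items() if k.startswith(prefix)}
assert "vace_patch_embedding.weight" in stripped
```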
diff --git a/examples/wanvideo/model_training/full/Wan2.1-VACE-1.3B.sh b/examples/wanvideo/model_training/full/Wan2.1-VACE-1.3B.sh
new file mode 100644
index 0000000..1479475
--- /dev/null
+++ b/examples/wanvideo/model_training/full/Wan2.1-VACE-1.3B.sh
@@ -0,0 +1,17 @@
+accelerate launch examples/wanvideo/model_training/train.py \
+ --dataset_base_path data/example_video_dataset \
+ --dataset_metadata_path data/example_video_dataset/metadata_vace.csv \
+ --data_file_keys "video,vace_video,vace_reference_image" \
+ --height 480 \
+ --width 832 \
+ --num_frames 49 \
+ --dataset_repeat 100 \
+ --model_id_with_origin_paths "Wan-AI/Wan2.1-VACE-1.3B:diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.1-VACE-1.3B:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.1-VACE-1.3B:Wan2.1_VAE.pth" \
+ --learning_rate 1e-4 \
+ --num_epochs 2 \
+ --remove_prefix_in_ckpt "pipe.vace." \
+ --output_path "./models/train/Wan2.1-VACE-1.3B_full" \
+ --trainable_models "vace" \
+ --input_contains_vace_video \
+ --input_contains_vace_reference_image \
+ --use_gradient_checkpointing_offload
\ No newline at end of file
diff --git a/examples/wanvideo/model_training/full/Wan2.1-VACE-14B.sh b/examples/wanvideo/model_training/full/Wan2.1-VACE-14B.sh
new file mode 100644
index 0000000..85fc317
--- /dev/null
+++ b/examples/wanvideo/model_training/full/Wan2.1-VACE-14B.sh
@@ -0,0 +1,17 @@
+accelerate launch --config_file examples/wanvideo/model_training/full/accelerate_config_14B.yaml examples/wanvideo/model_training/train.py \
+ --dataset_base_path data/example_video_dataset \
+ --dataset_metadata_path data/example_video_dataset/metadata_vace.csv \
+ --data_file_keys "video,vace_video,vace_reference_image" \
+ --height 480 \
+ --width 832 \
+ --num_frames 17 \
+ --dataset_repeat 100 \
+ --model_id_with_origin_paths "Wan-AI/Wan2.1-VACE-14B:diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.1-VACE-14B:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.1-VACE-14B:Wan2.1_VAE.pth" \
+ --learning_rate 1e-4 \
+ --num_epochs 2 \
+ --remove_prefix_in_ckpt "pipe.vace." \
+ --output_path "./models/train/Wan2.1-VACE-14B_full" \
+ --trainable_models "vace" \
+ --input_contains_vace_video \
+ --input_contains_vace_reference_image \
+ --use_gradient_checkpointing_offload
\ No newline at end of file
diff --git a/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-1.3B-InP.sh b/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-1.3B-InP.sh
new file mode 100644
index 0000000..b3a582a
--- /dev/null
+++ b/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-1.3B-InP.sh
@@ -0,0 +1,16 @@
+accelerate launch examples/wanvideo/model_training/train.py \
+ --dataset_base_path data/example_video_dataset \
+ --dataset_metadata_path data/example_video_dataset/metadata.csv \
+ --height 480 \
+ --width 832 \
+ --dataset_repeat 100 \
+ --model_id_with_origin_paths "PAI/Wan2.1-Fun-V1.1-1.3B-InP:diffusion_pytorch_model*.safetensors,PAI/Wan2.1-Fun-V1.1-1.3B-InP:models_t5_umt5-xxl-enc-bf16.pth,PAI/Wan2.1-Fun-V1.1-1.3B-InP:Wan2.1_VAE.pth,PAI/Wan2.1-Fun-V1.1-1.3B-InP:models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth" \
+ --learning_rate 1e-4 \
+ --num_epochs 5 \
+ --remove_prefix_in_ckpt "pipe.dit." \
+ --output_path "./models/train/Wan2.1-Fun-V1.1-1.3B-InP_lora" \
+ --lora_base_model "dit" \
+ --lora_target_modules "q,k,v,o,ffn.0,ffn.2" \
+ --lora_rank 32 \
+ --input_contains_input_image \
+ --input_contains_end_image
\ No newline at end of file
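
A rough illustration of how a module list like `"q,k,v,o,ffn.0,ffn.2"` is typically resolved against parameter names; the exact matching logic in `train.py` may differ.

```python
# Hypothetical name matching: inject LoRA into Linear layers whose qualified
# names end with one of the listed suffixes.
targets = "q,k,v,o,ffn.0,ffn.2".split(",")
names = ["blocks.0.self_attn.q", "blocks.0.ffn.0", "blocks.0.ffn.1", "blocks.0.norm"]
selected = [name for name in names if any(name.endswith(t) for t in targets)]
assert selected == ["blocks.0.self_attn.q", "blocks.0.ffn.0"]
```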
diff --git a/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-14B-InP.sh b/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-14B-InP.sh
new file mode 100644
index 0000000..91fead7
--- /dev/null
+++ b/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-14B-InP.sh
@@ -0,0 +1,16 @@
+accelerate launch examples/wanvideo/model_training/train.py \
+ --dataset_base_path data/example_video_dataset \
+ --dataset_metadata_path data/example_video_dataset/metadata.csv \
+ --height 480 \
+ --width 832 \
+ --dataset_repeat 100 \
+ --model_id_with_origin_paths "PAI/Wan2.1-Fun-V1.1-14B-InP:diffusion_pytorch_model*.safetensors,PAI/Wan2.1-Fun-V1.1-14B-InP:models_t5_umt5-xxl-enc-bf16.pth,PAI/Wan2.1-Fun-V1.1-14B-InP:Wan2.1_VAE.pth,PAI/Wan2.1-Fun-V1.1-14B-InP:models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth" \
+ --learning_rate 1e-4 \
+ --num_epochs 5 \
+ --remove_prefix_in_ckpt "pipe.dit." \
+ --output_path "./models/train/Wan2.1-Fun-V1.1-14B-InP_lora" \
+ --lora_base_model "dit" \
+ --lora_target_modules "q,k,v,o,ffn.0,ffn.2" \
+ --lora_rank 32 \
+ --input_contains_input_image \
+ --input_contains_end_image
\ No newline at end of file
diff --git a/examples/wanvideo/model_training/lora/Wan2.1-VACE-1.3B-Preview.sh b/examples/wanvideo/model_training/lora/Wan2.1-VACE-1.3B-Preview.sh
new file mode 100644
index 0000000..85dff46
--- /dev/null
+++ b/examples/wanvideo/model_training/lora/Wan2.1-VACE-1.3B-Preview.sh
@@ -0,0 +1,18 @@
+accelerate launch examples/wanvideo/model_training/train.py \
+ --dataset_base_path data/example_video_dataset \
+ --dataset_metadata_path data/example_video_dataset/metadata_vace.csv \
+ --data_file_keys "video,vace_video,vace_reference_image" \
+ --height 480 \
+ --width 832 \
+ --dataset_repeat 100 \
+ --model_id_with_origin_paths "iic/VACE-Wan2.1-1.3B-Preview:diffusion_pytorch_model*.safetensors,iic/VACE-Wan2.1-1.3B-Preview:models_t5_umt5-xxl-enc-bf16.pth,iic/VACE-Wan2.1-1.3B-Preview:Wan2.1_VAE.pth" \
+ --learning_rate 1e-4 \
+ --num_epochs 5 \
+ --remove_prefix_in_ckpt "pipe.vace." \
+ --output_path "./models/train/Wan2.1-VACE-1.3B-Preview_lora" \
+ --lora_base_model "vace" \
+ --lora_target_modules "q,k,v,o,ffn.0,ffn.2" \
+ --lora_rank 32 \
+ --input_contains_vace_video \
+ --input_contains_vace_reference_image \
+ --use_gradient_checkpointing_offload
\ No newline at end of file
diff --git a/examples/wanvideo/model_training/lora/Wan2.1-VACE-1.3B.sh b/examples/wanvideo/model_training/lora/Wan2.1-VACE-1.3B.sh
new file mode 100644
index 0000000..0845e16
--- /dev/null
+++ b/examples/wanvideo/model_training/lora/Wan2.1-VACE-1.3B.sh
@@ -0,0 +1,18 @@
+accelerate launch examples/wanvideo/model_training/train.py \
+ --dataset_base_path data/example_video_dataset \
+ --dataset_metadata_path data/example_video_dataset/metadata_vace.csv \
+ --data_file_keys "video,vace_video,vace_reference_image" \
+ --height 480 \
+ --width 832 \
+ --dataset_repeat 100 \
+ --model_id_with_origin_paths "Wan-AI/Wan2.1-VACE-1.3B:diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.1-VACE-1.3B:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.1-VACE-1.3B:Wan2.1_VAE.pth" \
+ --learning_rate 1e-4 \
+ --num_epochs 5 \
+ --remove_prefix_in_ckpt "pipe.vace." \
+ --output_path "./models/train/Wan2.1-VACE-1.3B_lora" \
+ --lora_base_model "vace" \
+ --lora_target_modules "q,k,v,o,ffn.0,ffn.2" \
+ --lora_rank 32 \
+ --input_contains_vace_video \
+ --input_contains_vace_reference_image \
+ --use_gradient_checkpointing_offload
\ No newline at end of file
diff --git a/examples/wanvideo/model_training/lora/Wan2.1-VACE-14B.sh b/examples/wanvideo/model_training/lora/Wan2.1-VACE-14B.sh
new file mode 100644
index 0000000..7d596ed
--- /dev/null
+++ b/examples/wanvideo/model_training/lora/Wan2.1-VACE-14B.sh
@@ -0,0 +1,19 @@
+accelerate launch examples/wanvideo/model_training/train.py \
+ --dataset_base_path data/example_video_dataset \
+ --dataset_metadata_path data/example_video_dataset/metadata_vace.csv \
+ --data_file_keys "video,vace_video,vace_reference_image" \
+ --height 480 \
+ --width 832 \
+ --num_frames 17 \
+ --dataset_repeat 100 \
+ --model_id_with_origin_paths "Wan-AI/Wan2.1-VACE-14B:diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.1-VACE-14B:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.1-VACE-14B:Wan2.1_VAE.pth" \
+ --learning_rate 1e-4 \
+ --num_epochs 5 \
+ --remove_prefix_in_ckpt "pipe.vace." \
+ --output_path "./models/train/Wan2.1-VACE-14B_lora" \
+ --lora_base_model "vace" \
+ --lora_target_modules "q,k,v,o,ffn.0,ffn.2" \
+ --lora_rank 32 \
+ --input_contains_vace_video \
+ --input_contains_vace_reference_image \
+ --use_gradient_checkpointing_offload
\ No newline at end of file
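
All VACE scripts point at `metadata_vace.csv` with `--data_file_keys "video,vace_video,vace_reference_image"`. Below is a hypothetical sketch of a compatible metadata file; the exact column set, including the `prompt` column and the file names, is an assumption for illustration.

```python
import csv

# Hypothetical metadata_vace.csv; the training scripts expect it under
# data/example_video_dataset/. Columns beyond --data_file_keys are assumed.
rows = [{
    "video": "video1.mp4",
    "vace_video": "video1_softedge.mp4",
    "vace_reference_image": "reference1.jpg",
    "prompt": "from sunset to night, a small town, light, house, river",
}]
with open("metadata_vace.csv", "w", newline="") as f:
    writer = csv.DictWriter(f, fieldnames=["video", "vace_video", "vace_reference_image", "prompt"])
    writer.writeheader()
    writer.writerows(rows)
```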
diff --git a/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-1.3B-InP.py b/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-1.3B-InP.py
new file mode 100644
index 0000000..cd8ee20
--- /dev/null
+++ b/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-1.3B-InP.py
@@ -0,0 +1,31 @@
+import torch
+from PIL import Image
+from diffsynth import save_video, VideoData, load_state_dict
+from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig
+from modelscope import dataset_snapshot_download
+
+
+pipe = WanVideoPipeline.from_pretrained(
+ torch_dtype=torch.bfloat16,
+ device="cuda",
+ model_configs=[
+ ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-1.3B-InP", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"),
+ ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-1.3B-InP", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"),
+ ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-1.3B-InP", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"),
+ ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-1.3B-InP", origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth", offload_device="cpu"),
+ ],
+)
+state_dict = load_state_dict("models/train/Wan2.1-Fun-V1.1-1.3B-InP_full/epoch-1.safetensors")
+pipe.dit.load_state_dict(state_dict)
+pipe.enable_vram_management()
+
+video = VideoData("data/example_video_dataset/video1.mp4", height=480, width=832)
+
+# First and last frame to video
+video = pipe(
+ prompt="from sunset to night, a small town, light, house, river",
+ negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
+ input_image=video[0], end_image=video[80],
+ seed=0, tiled=True
+)
+save_video(video, "video_Wan2.1-Fun-V1.1-1.3B-InP.mp4", fps=15, quality=5)
diff --git a/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-14B-InP.py b/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-14B-InP.py
new file mode 100644
index 0000000..7e944b0
--- /dev/null
+++ b/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-14B-InP.py
@@ -0,0 +1,31 @@
+import torch
+from PIL import Image
+from diffsynth import save_video, VideoData, load_state_dict
+from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig
+from modelscope import dataset_snapshot_download
+
+
+pipe = WanVideoPipeline.from_pretrained(
+ torch_dtype=torch.bfloat16,
+ device="cuda",
+ model_configs=[
+ ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-14B-InP", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"),
+ ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-14B-InP", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"),
+ ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-14B-InP", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"),
+ ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-14B-InP", origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth", offload_device="cpu"),
+ ],
+)
+state_dict = load_state_dict("models/train/Wan2.1-Fun-V1.1-14B-InP_full/epoch-1.safetensors")
+pipe.dit.load_state_dict(state_dict)
+pipe.enable_vram_management()
+
+video = VideoData("data/example_video_dataset/video1.mp4", height=480, width=832)
+
+# First and last frame to video
+video = pipe(
+ prompt="from sunset to night, a small town, light, house, river",
+ negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
+ input_image=video[0], end_image=video[80],
+ seed=0, tiled=True
+)
+save_video(video, "video_Wan2.1-Fun-V1.1-14B-InP.mp4", fps=15, quality=5)
diff --git a/examples/wanvideo/model_training/validate_full/Wan2.1-VACE-1.3B-Preview.py b/examples/wanvideo/model_training/validate_full/Wan2.1-VACE-1.3B-Preview.py
new file mode 100644
index 0000000..7db26e0
--- /dev/null
+++ b/examples/wanvideo/model_training/validate_full/Wan2.1-VACE-1.3B-Preview.py
@@ -0,0 +1,30 @@
+import torch
+from PIL import Image
+from diffsynth import save_video, VideoData, load_state_dict
+from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig
+
+
+pipe = WanVideoPipeline.from_pretrained(
+ torch_dtype=torch.bfloat16,
+ device="cuda",
+ model_configs=[
+ ModelConfig(model_id="iic/VACE-Wan2.1-1.3B-Preview", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"),
+ ModelConfig(model_id="iic/VACE-Wan2.1-1.3B-Preview", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"),
+ ModelConfig(model_id="iic/VACE-Wan2.1-1.3B-Preview", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"),
+ ],
+)
+state_dict = load_state_dict("models/train/Wan2.1-VACE-1.3B-Preview_full/epoch-1.safetensors")
+pipe.vace.load_state_dict(state_dict)
+pipe.enable_vram_management()
+
+video = VideoData("data/example_video_dataset/video1_softedge.mp4", height=480, width=832)
+video = [video[i] for i in range(49)]
+reference_image = VideoData("data/example_video_dataset/video1.mp4", height=480, width=832)[0]
+
+video = pipe(
+ prompt="from sunset to night, a small town, light, house, river",
+ negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
+ vace_video=video, vace_reference_image=reference_image, num_frames=49,
+ seed=1, tiled=True
+)
+save_video(video, "video_Wan2.1-VACE-1.3B-Preview.mp4", fps=15, quality=5)
diff --git a/examples/wanvideo/model_training/validate_full/Wan2.1-VACE-1.3B.py b/examples/wanvideo/model_training/validate_full/Wan2.1-VACE-1.3B.py
new file mode 100644
index 0000000..5a371e7
--- /dev/null
+++ b/examples/wanvideo/model_training/validate_full/Wan2.1-VACE-1.3B.py
@@ -0,0 +1,30 @@
+import torch
+from PIL import Image
+from diffsynth import save_video, VideoData, load_state_dict
+from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig
+
+
+pipe = WanVideoPipeline.from_pretrained(
+ torch_dtype=torch.bfloat16,
+ device="cuda",
+ model_configs=[
+ ModelConfig(model_id="Wan-AI/Wan2.1-VACE-1.3B", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"),
+ ModelConfig(model_id="Wan-AI/Wan2.1-VACE-1.3B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"),
+ ModelConfig(model_id="Wan-AI/Wan2.1-VACE-1.3B", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"),
+ ],
+)
+state_dict = load_state_dict("models/train/Wan2.1-VACE-1.3B_full/epoch-1.safetensors")
+pipe.vace.load_state_dict(state_dict)
+pipe.enable_vram_management()
+
+video = VideoData("data/example_video_dataset/video1_softedge.mp4", height=480, width=832)
+video = [video[i] for i in range(49)]
+reference_image = VideoData("data/example_video_dataset/video1.mp4", height=480, width=832)[0]
+
+video = pipe(
+ prompt="from sunset to night, a small town, light, house, river",
+ negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
+ vace_video=video, vace_reference_image=reference_image, num_frames=49,
+ seed=1, tiled=True
+)
+save_video(video, "video_Wan2.1-VACE-1.3B.mp4", fps=15, quality=5)
diff --git a/examples/wanvideo/model_training/validate_full/Wan2.1-VACE-14B.py b/examples/wanvideo/model_training/validate_full/Wan2.1-VACE-14B.py
new file mode 100644
index 0000000..5553471
--- /dev/null
+++ b/examples/wanvideo/model_training/validate_full/Wan2.1-VACE-14B.py
@@ -0,0 +1,30 @@
+import torch
+from PIL import Image
+from diffsynth import save_video, VideoData, load_state_dict
+from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig
+
+
+pipe = WanVideoPipeline.from_pretrained(
+ torch_dtype=torch.bfloat16,
+ device="cuda",
+ model_configs=[
+ ModelConfig(model_id="Wan-AI/Wan2.1-VACE-14B", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"),
+ ModelConfig(model_id="Wan-AI/Wan2.1-VACE-14B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"),
+ ModelConfig(model_id="Wan-AI/Wan2.1-VACE-14B", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"),
+ ],
+)
+state_dict = load_state_dict("models/train/Wan2.1-VACE-14B_full/epoch-1.safetensors")
+pipe.vace.load_state_dict(state_dict)
+pipe.enable_vram_management()
+
+video = VideoData("data/example_video_dataset/video1_softedge.mp4", height=480, width=832)
+video = [video[i] for i in range(17)]
+reference_image = VideoData("data/example_video_dataset/video1.mp4", height=480, width=832)[0]
+
+video = pipe(
+ prompt="from sunset to night, a small town, light, house, river",
+ negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
+ vace_video=video, vace_reference_image=reference_image, num_frames=17,
+ seed=1, tiled=True
+)
+save_video(video, "video_Wan2.1-VACE-14B.mp4", fps=15, quality=5)
diff --git a/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-1.3B-InP.py b/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-1.3B-InP.py
new file mode 100644
index 0000000..99eb2b4
--- /dev/null
+++ b/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-1.3B-InP.py
@@ -0,0 +1,30 @@
+import torch
+from PIL import Image
+from diffsynth import save_video, VideoData
+from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig
+from modelscope import dataset_snapshot_download
+
+
+pipe = WanVideoPipeline.from_pretrained(
+ torch_dtype=torch.bfloat16,
+ device="cuda",
+ model_configs=[
+ ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-1.3B-InP", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"),
+ ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-1.3B-InP", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"),
+ ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-1.3B-InP", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"),
+ ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-1.3B-InP", origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth", offload_device="cpu"),
+ ],
+)
+pipe.load_lora(pipe.dit, "models/train/Wan2.1-Fun-V1.1-1.3B-InP_lora/epoch-4.safetensors", alpha=1)
+pipe.enable_vram_management()
+
+video = VideoData("data/example_video_dataset/video1.mp4", height=480, width=832)
+
+# First and last frame to video
+video = pipe(
+ prompt="from sunset to night, a small town, light, house, river",
+ negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
+ input_image=video[0], end_image=video[80],
+ seed=0, tiled=True
+)
+save_video(video, "video_Wan2.1-Fun-V1.1-1.3B-InP.mp4", fps=15, quality=5)
diff --git a/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-14B-InP.py b/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-14B-InP.py
new file mode 100644
index 0000000..35088fb
--- /dev/null
+++ b/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-14B-InP.py
@@ -0,0 +1,30 @@
+import torch
+from PIL import Image
+from diffsynth import save_video, VideoData
+from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig
+from modelscope import dataset_snapshot_download
+
+
+pipe = WanVideoPipeline.from_pretrained(
+ torch_dtype=torch.bfloat16,
+ device="cuda",
+ model_configs=[
+ ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-14B-InP", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"),
+ ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-14B-InP", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"),
+ ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-14B-InP", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"),
+ ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-14B-InP", origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth", offload_device="cpu"),
+ ],
+)
+pipe.load_lora(pipe.dit, "models/train/Wan2.1-Fun-V1.1-14B-InP_lora/epoch-4.safetensors", alpha=1)
+pipe.enable_vram_management()
+
+video = VideoData("data/example_video_dataset/video1.mp4", height=480, width=832)
+
+# First and last frame to video
+video = pipe(
+ prompt="from sunset to night, a small town, light, house, river",
+ negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
+ input_image=video[0], end_image=video[80],
+ seed=0, tiled=True
+)
+save_video(video, "video_Wan2.1-Fun-V1.1-14B-InP.mp4", fps=15, quality=5)
diff --git a/examples/wanvideo/model_training/validate_lora/Wan2.1-VACE-1.3B-Preview.py b/examples/wanvideo/model_training/validate_lora/Wan2.1-VACE-1.3B-Preview.py
new file mode 100644
index 0000000..91cbf92
--- /dev/null
+++ b/examples/wanvideo/model_training/validate_lora/Wan2.1-VACE-1.3B-Preview.py
@@ -0,0 +1,29 @@
+import torch
+from PIL import Image
+from diffsynth import save_video, VideoData
+from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig
+
+
+pipe = WanVideoPipeline.from_pretrained(
+ torch_dtype=torch.bfloat16,
+ device="cuda",
+ model_configs=[
+ ModelConfig(model_id="iic/VACE-Wan2.1-1.3B-Preview", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"),
+ ModelConfig(model_id="iic/VACE-Wan2.1-1.3B-Preview", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"),
+ ModelConfig(model_id="iic/VACE-Wan2.1-1.3B-Preview", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"),
+ ],
+)
+pipe.load_lora(pipe.vace, "models/train/Wan2.1-VACE-1.3B-Preview_lora/epoch-4.safetensors", alpha=1)
+pipe.enable_vram_management()
+
+video = VideoData("data/example_video_dataset/video1_softedge.mp4", height=480, width=832)
+video = [video[i] for i in range(49)]
+reference_image = VideoData("data/example_video_dataset/video1.mp4", height=480, width=832)[0]
+
+video = pipe(
+ prompt="from sunset to night, a small town, light, house, river",
+ negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
+ vace_video=video, vace_reference_image=reference_image, num_frames=49,
+ seed=1, tiled=True
+)
+save_video(video, "video_Wan2.1-VACE-1.3B-Preview.mp4", fps=15, quality=5)
diff --git a/examples/wanvideo/model_training/validate_lora/Wan2.1-VACE-1.3B.py b/examples/wanvideo/model_training/validate_lora/Wan2.1-VACE-1.3B.py
new file mode 100644
index 0000000..b5fd203
--- /dev/null
+++ b/examples/wanvideo/model_training/validate_lora/Wan2.1-VACE-1.3B.py
@@ -0,0 +1,29 @@
+import torch
+from PIL import Image
+from diffsynth import save_video, VideoData
+from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig
+
+
+pipe = WanVideoPipeline.from_pretrained(
+ torch_dtype=torch.bfloat16,
+ device="cuda",
+ model_configs=[
+ ModelConfig(model_id="Wan-AI/Wan2.1-VACE-1.3B", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"),
+ ModelConfig(model_id="Wan-AI/Wan2.1-VACE-1.3B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"),
+ ModelConfig(model_id="Wan-AI/Wan2.1-VACE-1.3B", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"),
+ ],
+)
+pipe.load_lora(pipe.vace, "models/train/Wan2.1-VACE-1.3B_lora/epoch-4.safetensors", alpha=1)
+pipe.enable_vram_management()
+
+video = VideoData("data/example_video_dataset/video1_softedge.mp4", height=480, width=832)
+video = [video[i] for i in range(49)]
+reference_image = VideoData("data/example_video_dataset/video1.mp4", height=480, width=832)[0]
+
+video = pipe(
+ prompt="from sunset to night, a small town, light, house, river",
+ negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
+ vace_video=video, vace_reference_image=reference_image, num_frames=49,
+ seed=1, tiled=True
+)
+save_video(video, "video_Wan2.1-VACE-1.3B.mp4", fps=15, quality=5)
diff --git a/examples/wanvideo/model_training/validate_lora/Wan2.1-VACE-14B.py b/examples/wanvideo/model_training/validate_lora/Wan2.1-VACE-14B.py
new file mode 100644
index 0000000..bec5df3
--- /dev/null
+++ b/examples/wanvideo/model_training/validate_lora/Wan2.1-VACE-14B.py
@@ -0,0 +1,29 @@
+import torch
+from PIL import Image
+from diffsynth import save_video, VideoData
+from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig
+
+
+pipe = WanVideoPipeline.from_pretrained(
+ torch_dtype=torch.bfloat16,
+ device="cuda",
+ model_configs=[
+ ModelConfig(model_id="Wan-AI/Wan2.1-VACE-14B", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"),
+ ModelConfig(model_id="Wan-AI/Wan2.1-VACE-14B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"),
+ ModelConfig(model_id="Wan-AI/Wan2.1-VACE-14B", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"),
+ ],
+)
+pipe.load_lora(pipe.vace, "models/train/Wan2.1-VACE-14B_lora/epoch-4.safetensors", alpha=1)
+pipe.enable_vram_management()
+
+video = VideoData("data/example_video_dataset/video1_softedge.mp4", height=480, width=832)
+video = [video[i] for i in range(17)]
+reference_image = VideoData("data/example_video_dataset/video1.mp4", height=480, width=832)[0]
+
+video = pipe(
+ prompt="from sunset to night, a small town, light, house, river",
+ negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
+ vace_video=video, vace_reference_image=reference_image, num_frames=17,
+ seed=1, tiled=True
+)
+save_video(video, "video_Wan2.1-VACE-14B.mp4", fps=15, quality=5)
diff --git a/examples/wanvideo/wan_14b_text_to_video_usp.py b/examples/wanvideo/wan_14b_text_to_video_usp.py
deleted file mode 100644
index 8837294..0000000
--- a/examples/wanvideo/wan_14b_text_to_video_usp.py
+++ /dev/null
@@ -1,58 +0,0 @@
-import torch
-from diffsynth import ModelManager, WanVideoPipeline, save_video, VideoData
-from modelscope import snapshot_download
-import torch.distributed as dist
-
-
-# Download models
-snapshot_download("Wan-AI/Wan2.1-T2V-14B", local_dir="models/Wan-AI/Wan2.1-T2V-14B")
-
-# Load models
-model_manager = ModelManager(device="cpu")
-model_manager.load_models(
- [
- [
- "models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00001-of-00006.safetensors",
- "models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00002-of-00006.safetensors",
- "models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00003-of-00006.safetensors",
- "models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00004-of-00006.safetensors",
- "models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00005-of-00006.safetensors",
- "models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00006-of-00006.safetensors",
- ],
- "models/Wan-AI/Wan2.1-T2V-14B/models_t5_umt5-xxl-enc-bf16.pth",
- "models/Wan-AI/Wan2.1-T2V-14B/Wan2.1_VAE.pth",
- ],
- torch_dtype=torch.float8_e4m3fn, # You can set `torch_dtype=torch.bfloat16` to disable FP8 quantization.
-)
-
-dist.init_process_group(
- backend="nccl",
- init_method="env://",
-)
-from xfuser.core.distributed import (initialize_model_parallel,
- init_distributed_environment)
-init_distributed_environment(
- rank=dist.get_rank(), world_size=dist.get_world_size())
-
-initialize_model_parallel(
- sequence_parallel_degree=dist.get_world_size(),
- ring_degree=1,
- ulysses_degree=dist.get_world_size(),
-)
-torch.cuda.set_device(dist.get_rank())
-
-pipe = WanVideoPipeline.from_model_manager(model_manager,
- torch_dtype=torch.bfloat16,
- device=f"cuda:{dist.get_rank()}",
- use_usp=True if dist.get_world_size() > 1 else False)
-pipe.enable_vram_management(num_persistent_param_in_dit=None) # You can set `num_persistent_param_in_dit` to a small number to reduce VRAM required.
-
-# Text-to-video
-video = pipe(
- prompt="一名宇航员身穿太空服,面朝镜头骑着一匹机械马在火星表面驰骋。红色的荒凉地表延伸至远方,点缀着巨大的陨石坑和奇特的岩石结构。机械马的步伐稳健,扬起微弱的尘埃,展现出未来科技与原始探索的完美结合。宇航员手持操控装置,目光坚定,仿佛正在开辟人类的新疆域。背景是深邃的宇宙和蔚蓝的地球,画面既科幻又充满希望,让人不禁畅想未来的星际生活。",
- negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
- num_inference_steps=50,
- seed=0, tiled=True
-)
-if dist.get_rank() == 0:
- save_video(video, "video1.mp4", fps=25, quality=5)