mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-24 10:18:12 +00:00
accelerate
This commit is contained in:
@@ -212,6 +212,7 @@ class WanVideoPipeline(BasePipeline):
|
|||||||
WanVideoUnit_FunReference(),
|
WanVideoUnit_FunReference(),
|
||||||
WanVideoUnit_SpeedControl(),
|
WanVideoUnit_SpeedControl(),
|
||||||
WanVideoUnit_VACE(),
|
WanVideoUnit_VACE(),
|
||||||
|
WanVideoUnit_UnifiedSequenceParallel(),
|
||||||
WanVideoUnit_TeaCache(),
|
WanVideoUnit_TeaCache(),
|
||||||
WanVideoUnit_CfgMerger(),
|
WanVideoUnit_CfgMerger(),
|
||||||
]
|
]
|
||||||
@@ -375,6 +376,19 @@ class WanVideoPipeline(BasePipeline):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def initialize_usp(self):
|
||||||
|
import torch.distributed as dist
|
||||||
|
from xfuser.core.distributed import initialize_model_parallel, init_distributed_environment
|
||||||
|
dist.init_process_group(backend="nccl", init_method="env://")
|
||||||
|
init_distributed_environment(rank=dist.get_rank(), world_size=dist.get_world_size())
|
||||||
|
initialize_model_parallel(
|
||||||
|
sequence_parallel_degree=dist.get_world_size(),
|
||||||
|
ring_degree=1,
|
||||||
|
ulysses_degree=dist.get_world_size(),
|
||||||
|
)
|
||||||
|
torch.cuda.set_device(dist.get_rank())
|
||||||
|
|
||||||
|
|
||||||
def enable_usp(self):
|
def enable_usp(self):
|
||||||
from xfuser.core.distributed import get_sequence_parallel_world_size
|
from xfuser.core.distributed import get_sequence_parallel_world_size
|
||||||
from ..distributed.xdit_context_parallel import usp_attn_forward, usp_dit_forward
|
from ..distributed.xdit_context_parallel import usp_attn_forward, usp_dit_forward
|
||||||
@@ -423,6 +437,7 @@ class WanVideoPipeline(BasePipeline):
|
|||||||
|
|
||||||
# Initialize pipeline
|
# Initialize pipeline
|
||||||
pipe = WanVideoPipeline(device=device, torch_dtype=torch_dtype)
|
pipe = WanVideoPipeline(device=device, torch_dtype=torch_dtype)
|
||||||
|
if use_usp: pipe.initialize_usp()
|
||||||
pipe.text_encoder = model_manager.fetch_model("wan_video_text_encoder")
|
pipe.text_encoder = model_manager.fetch_model("wan_video_text_encoder")
|
||||||
pipe.dit = model_manager.fetch_model("wan_video_dit")
|
pipe.dit = model_manager.fetch_model("wan_video_dit")
|
||||||
pipe.vae = model_manager.fetch_model("wan_video_vae")
|
pipe.vae = model_manager.fetch_model("wan_video_vae")
|
||||||
@@ -434,6 +449,9 @@ class WanVideoPipeline(BasePipeline):
|
|||||||
tokenizer_config.download_if_necessary(local_model_path, skip_download=skip_download)
|
tokenizer_config.download_if_necessary(local_model_path, skip_download=skip_download)
|
||||||
pipe.prompter.fetch_models(pipe.text_encoder)
|
pipe.prompter.fetch_models(pipe.text_encoder)
|
||||||
pipe.prompter.fetch_tokenizer(tokenizer_config.path)
|
pipe.prompter.fetch_tokenizer(tokenizer_config.path)
|
||||||
|
|
||||||
|
# Unified Sequence Parallel
|
||||||
|
if use_usp: pipe.enable_usp()
|
||||||
return pipe
|
return pipe
|
||||||
|
|
||||||
|
|
||||||
@@ -492,11 +510,11 @@ class WanVideoPipeline(BasePipeline):
|
|||||||
# Inputs
|
# Inputs
|
||||||
inputs_posi = {
|
inputs_posi = {
|
||||||
"prompt": prompt,
|
"prompt": prompt,
|
||||||
"tea_cache_l1_thresh": tea_cache_l1_thresh, "tea_cache_model_id": tea_cache_model_id,
|
"tea_cache_l1_thresh": tea_cache_l1_thresh, "tea_cache_model_id": tea_cache_model_id, "num_inference_steps": num_inference_steps,
|
||||||
}
|
}
|
||||||
inputs_nega = {
|
inputs_nega = {
|
||||||
"negative_prompt": negative_prompt,
|
"negative_prompt": negative_prompt,
|
||||||
"tea_cache_l1_thresh": tea_cache_l1_thresh, "tea_cache_model_id": tea_cache_model_id,
|
"tea_cache_l1_thresh": tea_cache_l1_thresh, "tea_cache_model_id": tea_cache_model_id, "num_inference_steps": num_inference_steps,
|
||||||
}
|
}
|
||||||
inputs_shared = {
|
inputs_shared = {
|
||||||
"input_image": input_image,
|
"input_image": input_image,
|
||||||
@@ -507,7 +525,7 @@ class WanVideoPipeline(BasePipeline):
|
|||||||
"seed": seed, "rand_device": rand_device,
|
"seed": seed, "rand_device": rand_device,
|
||||||
"height": height, "width": width, "num_frames": num_frames,
|
"height": height, "width": width, "num_frames": num_frames,
|
||||||
"cfg_scale": cfg_scale, "cfg_merge": cfg_merge,
|
"cfg_scale": cfg_scale, "cfg_merge": cfg_merge,
|
||||||
"num_inference_steps": num_inference_steps, "sigma_shift": sigma_shift,
|
"sigma_shift": sigma_shift,
|
||||||
"motion_bucket_id": motion_bucket_id,
|
"motion_bucket_id": motion_bucket_id,
|
||||||
"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride,
|
"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride,
|
||||||
"sliding_window_size": sliding_window_size, "sliding_window_stride": sliding_window_stride,
|
"sliding_window_size": sliding_window_size, "sliding_window_stride": sliding_window_stride,
|
||||||
@@ -811,6 +829,18 @@ class WanVideoUnit_VACE(PipelineUnit):
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class WanVideoUnit_UnifiedSequenceParallel(PipelineUnit):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__(input_params=())
|
||||||
|
|
||||||
|
def process(self, pipe: WanVideoPipeline):
|
||||||
|
if hasattr(pipe, "use_unified_sequence_parallel"):
|
||||||
|
if pipe.use_unified_sequence_parallel:
|
||||||
|
return {"use_unified_sequence_parallel": True}
|
||||||
|
return {}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class WanVideoUnit_TeaCache(PipelineUnit):
|
class WanVideoUnit_TeaCache(PipelineUnit):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
|
|||||||
@@ -78,6 +78,7 @@ ModelConfig(path=[
|
|||||||
* `local_model_path`: 用于保存下载模型的路径,默认值为 `"./models"`。
|
* `local_model_path`: 用于保存下载模型的路径,默认值为 `"./models"`。
|
||||||
* `skip_download`: 是否跳过下载,默认值为 `False`。当您的网络无法访问[魔搭社区](https://modelscope.cn/)时,请手动下载必要的文件,并将其设置为 `True`。
|
* `skip_download`: 是否跳过下载,默认值为 `False`。当您的网络无法访问[魔搭社区](https://modelscope.cn/)时,请手动下载必要的文件,并将其设置为 `True`。
|
||||||
* `redirect_common_files`: 是否重定向重复模型文件,默认值为 `True`。由于 Wan 系列模型包括多个基础模型,每个基础模型的 text encoder 等模块都是相同的,为避免重复下载,我们会对模型路径进行重定向。
|
* `redirect_common_files`: 是否重定向重复模型文件,默认值为 `True`。由于 Wan 系列模型包括多个基础模型,每个基础模型的 text encoder 等模块都是相同的,为避免重复下载,我们会对模型路径进行重定向。
|
||||||
|
* `use_usp`: 是否启用 Unified Sequence Parallel,默认值为 `False`。用于多 GPU 并行推理。
|
||||||
|
|
||||||
</details>
|
</details>
|
||||||
|
|
||||||
@@ -142,6 +143,23 @@ FP8 量化能够大幅度减少显存占用,但不会加速,部分模型在
|
|||||||
</details>
|
</details>
|
||||||
|
|
||||||
|
|
||||||
|
<details>
|
||||||
|
|
||||||
|
<summary>推理加速</summary>
|
||||||
|
|
||||||
|
Wan 支持多种加速方案,包括
|
||||||
|
|
||||||
|
* 高效注意力机制实现:当您的 Python 环境中安装过这些注意力机制实现方案时,我们将会按照以下优先级自动启用。
|
||||||
|
* [Flash Attention 3](https://github.com/Dao-AILab/flash-attention)
|
||||||
|
* [Flash Attention 2](https://github.com/Dao-AILab/flash-attention)
|
||||||
|
* [Sage Attention](https://github.com/thu-ml/SageAttention)
|
||||||
|
* [torch SDPA](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html) (默认设置,建议安装 `torch>=2.5.0`)
|
||||||
|
* 统一序列并行:基于 [xDiT](https://github.com/xdit-project/xDiT) 实现的序列并行,请参考[示例代码](./acceleration/unified_sequence_parallel.py)。
|
||||||
|
* TeaCache:加速技术 [TeaCache](https://github.com/ali-vilab/TeaCache),请参考[示例代码](./acceleration/teacache.py)。
|
||||||
|
|
||||||
|
</details>
|
||||||
|
|
||||||
|
|
||||||
<details>
|
<details>
|
||||||
|
|
||||||
<summary>输入参数</summary>
|
<summary>输入参数</summary>
|
||||||
|
|||||||
@@ -1,34 +1,27 @@
|
|||||||
import torch
|
import torch
|
||||||
from diffsynth import ModelManager, WanVideoPipeline, save_video, VideoData
|
from PIL import Image
|
||||||
from modelscope import snapshot_download
|
from diffsynth import save_video, VideoData
|
||||||
|
from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig
|
||||||
|
|
||||||
|
|
||||||
# Download models
|
pipe = WanVideoPipeline.from_pretrained(
|
||||||
snapshot_download("Wan-AI/Wan2.1-T2V-1.3B", local_dir="models/Wan-AI/Wan2.1-T2V-1.3B")
|
torch_dtype=torch.bfloat16,
|
||||||
|
device="cuda",
|
||||||
# Load models
|
model_configs=[
|
||||||
model_manager = ModelManager(device="cpu")
|
ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"),
|
||||||
model_manager.load_models(
|
ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"),
|
||||||
[
|
ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"),
|
||||||
"models/Wan-AI/Wan2.1-T2V-1.3B/diffusion_pytorch_model.safetensors",
|
|
||||||
"models/Wan-AI/Wan2.1-T2V-1.3B/models_t5_umt5-xxl-enc-bf16.pth",
|
|
||||||
"models/Wan-AI/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth",
|
|
||||||
],
|
],
|
||||||
torch_dtype=torch.bfloat16, # You can set `torch_dtype=torch.float8_e4m3fn` to enable FP8 quantization.
|
|
||||||
)
|
)
|
||||||
pipe = WanVideoPipeline.from_model_manager(model_manager, torch_dtype=torch.bfloat16, device="cuda")
|
pipe.enable_vram_management()
|
||||||
pipe.enable_vram_management(num_persistent_param_in_dit=None)
|
|
||||||
|
|
||||||
# Text-to-video
|
|
||||||
video = pipe(
|
video = pipe(
|
||||||
prompt="纪实摄影风格画面,一只活泼的小狗在绿茵茵的草地上迅速奔跑。小狗毛色棕黄,两只耳朵立起,神情专注而欢快。阳光洒在它身上,使得毛发看上去格外柔软而闪亮。背景是一片开阔的草地,偶尔点缀着几朵野花,远处隐约可见蓝天和几片白云。透视感鲜明,捕捉小狗奔跑时的动感和四周草地的生机。中景侧面移动视角。",
|
prompt="纪实摄影风格画面,一只活泼的小狗在绿茵茵的草地上迅速奔跑。小狗毛色棕黄,两只耳朵立起,神情专注而欢快。阳光洒在它身上,使得毛发看上去格外柔软而闪亮。背景是一片开阔的草地,偶尔点缀着几朵野花,远处隐约可见蓝天和几片白云。透视感鲜明,捕捉小狗奔跑时的动感和四周草地的生机。中景侧面移动视角。",
|
||||||
negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
|
negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
|
||||||
num_inference_steps=50,
|
|
||||||
seed=0, tiled=True,
|
seed=0, tiled=True,
|
||||||
# TeaCache parameters
|
# TeaCache parameters
|
||||||
tea_cache_l1_thresh=0.05, # The larger this value is, the faster the speed, but the worse the visual quality.
|
tea_cache_l1_thresh=0.05, # The larger this value is, the faster the speed, but the worse the visual quality.
|
||||||
tea_cache_model_id="Wan2.1-T2V-1.3B", # Choose one in (Wan2.1-T2V-1.3B, Wan2.1-T2V-14B, Wan2.1-I2V-14B-480P, Wan2.1-I2V-14B-720P).
|
tea_cache_model_id="Wan2.1-T2V-1.3B", # Choose one in (Wan2.1-T2V-1.3B, Wan2.1-T2V-14B, Wan2.1-I2V-14B-480P, Wan2.1-I2V-14B-720P).
|
||||||
)
|
)
|
||||||
save_video(video, "video1.mp4", fps=15, quality=5)
|
save_video(video, "video1.mp4", fps=15, quality=5)
|
||||||
|
|
||||||
# TeaCache doesn't support video-to-video
|
|
||||||
27
examples/wanvideo/acceleration/unified_sequence_parallel.py
Normal file
27
examples/wanvideo/acceleration/unified_sequence_parallel.py
Normal file
@@ -0,0 +1,27 @@
|
|||||||
|
import torch
|
||||||
|
from PIL import Image
|
||||||
|
from diffsynth import save_video, VideoData
|
||||||
|
from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig
|
||||||
|
import torch.distributed as dist
|
||||||
|
|
||||||
|
|
||||||
|
pipe = WanVideoPipeline.from_pretrained(
|
||||||
|
torch_dtype=torch.bfloat16,
|
||||||
|
device="cuda",
|
||||||
|
use_usp=True,
|
||||||
|
model_configs=[
|
||||||
|
ModelConfig(model_id="Wan-AI/Wan2.1-T2V-14B", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"),
|
||||||
|
ModelConfig(model_id="Wan-AI/Wan2.1-T2V-14B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"),
|
||||||
|
ModelConfig(model_id="Wan-AI/Wan2.1-T2V-14B", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
pipe.enable_vram_management()
|
||||||
|
|
||||||
|
|
||||||
|
video = pipe(
|
||||||
|
prompt="一名宇航员身穿太空服,面朝镜头骑着一匹机械马在火星表面驰骋。红色的荒凉地表延伸至远方,点缀着巨大的陨石坑和奇特的岩石结构。机械马的步伐稳健,扬起微弱的尘埃,展现出未来科技与原始探索的完美结合。宇航员手持操控装置,目光坚定,仿佛正在开辟人类的新疆域。背景是深邃的宇宙和蔚蓝的地球,画面既科幻又充满希望,让人不禁畅想未来的星际生活。",
|
||||||
|
negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
|
||||||
|
seed=0, tiled=True,
|
||||||
|
)
|
||||||
|
if dist.get_rank() == 0:
|
||||||
|
save_video(video, "video1.mp4", fps=15, quality=5)
|
||||||
@@ -1,58 +0,0 @@
|
|||||||
import torch
|
|
||||||
from diffsynth import ModelManager, WanVideoPipeline, save_video, VideoData
|
|
||||||
from modelscope import snapshot_download
|
|
||||||
import torch.distributed as dist
|
|
||||||
|
|
||||||
|
|
||||||
# Download models
|
|
||||||
snapshot_download("Wan-AI/Wan2.1-T2V-14B", local_dir="models/Wan-AI/Wan2.1-T2V-14B")
|
|
||||||
|
|
||||||
# Load models
|
|
||||||
model_manager = ModelManager(device="cpu")
|
|
||||||
model_manager.load_models(
|
|
||||||
[
|
|
||||||
[
|
|
||||||
"models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00001-of-00006.safetensors",
|
|
||||||
"models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00002-of-00006.safetensors",
|
|
||||||
"models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00003-of-00006.safetensors",
|
|
||||||
"models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00004-of-00006.safetensors",
|
|
||||||
"models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00005-of-00006.safetensors",
|
|
||||||
"models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00006-of-00006.safetensors",
|
|
||||||
],
|
|
||||||
"models/Wan-AI/Wan2.1-T2V-14B/models_t5_umt5-xxl-enc-bf16.pth",
|
|
||||||
"models/Wan-AI/Wan2.1-T2V-14B/Wan2.1_VAE.pth",
|
|
||||||
],
|
|
||||||
torch_dtype=torch.float8_e4m3fn, # You can set `torch_dtype=torch.bfloat16` to disable FP8 quantization.
|
|
||||||
)
|
|
||||||
|
|
||||||
dist.init_process_group(
|
|
||||||
backend="nccl",
|
|
||||||
init_method="env://",
|
|
||||||
)
|
|
||||||
from xfuser.core.distributed import (initialize_model_parallel,
|
|
||||||
init_distributed_environment)
|
|
||||||
init_distributed_environment(
|
|
||||||
rank=dist.get_rank(), world_size=dist.get_world_size())
|
|
||||||
|
|
||||||
initialize_model_parallel(
|
|
||||||
sequence_parallel_degree=dist.get_world_size(),
|
|
||||||
ring_degree=1,
|
|
||||||
ulysses_degree=dist.get_world_size(),
|
|
||||||
)
|
|
||||||
torch.cuda.set_device(dist.get_rank())
|
|
||||||
|
|
||||||
pipe = WanVideoPipeline.from_model_manager(model_manager,
|
|
||||||
torch_dtype=torch.bfloat16,
|
|
||||||
device=f"cuda:{dist.get_rank()}",
|
|
||||||
use_usp=True if dist.get_world_size() > 1 else False)
|
|
||||||
pipe.enable_vram_management(num_persistent_param_in_dit=None) # You can set `num_persistent_param_in_dit` to a small number to reduce VRAM required.
|
|
||||||
|
|
||||||
# Text-to-video
|
|
||||||
video = pipe(
|
|
||||||
prompt="一名宇航员身穿太空服,面朝镜头骑着一匹机械马在火星表面驰骋。红色的荒凉地表延伸至远方,点缀着巨大的陨石坑和奇特的岩石结构。机械马的步伐稳健,扬起微弱的尘埃,展现出未来科技与原始探索的完美结合。宇航员手持操控装置,目光坚定,仿佛正在开辟人类的新疆域。背景是深邃的宇宙和蔚蓝的地球,画面既科幻又充满希望,让人不禁畅想未来的星际生活。",
|
|
||||||
negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
|
|
||||||
num_inference_steps=50,
|
|
||||||
seed=0, tiled=True
|
|
||||||
)
|
|
||||||
if dist.get_rank() == 0:
|
|
||||||
save_video(video, "video1.mp4", fps=25, quality=5)
|
|
||||||
Reference in New Issue
Block a user