wan-series

This commit is contained in:
Artiprocher
2025-11-14 19:05:26 +08:00
parent 5be5c32fe4
commit e3356556ee
215 changed files with 5504 additions and 482 deletions

View File

@@ -9,6 +9,7 @@ from PIL import Image
from tqdm import tqdm
from typing import Optional
from typing_extensions import Literal
from transformers import Wav2Vec2Processor
from ..diffusion import FlowMatchScheduler
from ..core import ModelConfig, gradient_checkpoint_forward
@@ -23,6 +24,7 @@ from ..models.wan_video_vace import VaceWanModel
from ..models.wan_video_motion_controller import WanMotionControllerModel
from ..models.wan_video_animate_adapter import WanAnimateAdapter
from ..models.wan_video_mot import MotWanModel
from ..models.wav2vec import WanS2VAudioEncoder
from ..models.longcat_video_dit import LongCatVideoTransformer3DModel
@@ -35,6 +37,7 @@ class WanVideoPipeline(BasePipeline):
)
self.scheduler = FlowMatchScheduler(shift=5, sigma_min=0.0, extra_one_step=True)
self.tokenizer: HuggingfaceTokenizer = None
self.audio_processor: Wav2Vec2Processor = None
self.text_encoder: WanTextEncoder = None
self.image_encoder: WanImageEncoder = None
self.dit: WanModel = None
@@ -45,6 +48,7 @@ class WanVideoPipeline(BasePipeline):
self.vace2: VaceWanModel = None
self.vap: MotWanModel = None
self.animate_adapter: WanAnimateAdapter = None
self.audio_encoder: WanS2VAudioEncoder = None
self.in_iteration_models = ("dit", "motion_controller", "vace", "animate_adapter", "vap")
self.in_iteration_models_2 = ("dit2", "motion_controller", "vace2", "animate_adapter", "vap")
self.units = [
@@ -96,7 +100,7 @@ class WanVideoPipeline(BasePipeline):
torch_dtype: torch.dtype = torch.bfloat16,
device: Union[str, torch.device] = "cuda",
model_configs: list[ModelConfig] = [],
tokenizer_config: ModelConfig = ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="google/*"),
tokenizer_config: ModelConfig = ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="google/umt5-xxl/"),
audio_processor_config: ModelConfig = None,
redirect_common_files: bool = True,
use_usp: bool = False,
@@ -105,16 +109,18 @@ class WanVideoPipeline(BasePipeline):
# Redirect model path
if redirect_common_files:
redirect_dict = {
"models_t5_umt5-xxl-enc-bf16.pth": "Wan-AI/Wan2.1-T2V-1.3B",
"Wan2.1_VAE.pth": "Wan-AI/Wan2.1-T2V-1.3B",
"models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth": "Wan-AI/Wan2.1-I2V-14B-480P",
"models_t5_umt5-xxl-enc-bf16.pth": ("DiffSynth-Studio/Wan-Series-Converted-Safetensors", "models_t5_umt5-xxl-enc-bf16.safetensors"),
"models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth": ("DiffSynth-Studio/Wan-Series-Converted-Safetensors", "models_clip_open-clip-xlm-roberta-large-vit-huge-14.safetensors"),
"Wan2.1_VAE.pth": ("DiffSynth-Studio/Wan-Series-Converted-Safetensors", "Wan2.1_VAE.safetensors"),
"Wan2.2_VAE.pth": ("DiffSynth-Studio/Wan-Series-Converted-Safetensors", "Wan2.2_VAE.safetensors"),
}
for model_config in model_configs:
if model_config.origin_file_pattern is None or model_config.model_id is None:
continue
if model_config.origin_file_pattern in redirect_dict and model_config.model_id != redirect_dict[model_config.origin_file_pattern]:
print(f"To avoid repeatedly downloading model files, ({model_config.model_id}, {model_config.origin_file_pattern}) is redirected to ({redirect_dict[model_config.origin_file_pattern]}, {model_config.origin_file_pattern}). You can use `redirect_common_files=False` to disable file redirection.")
model_config.model_id = redirect_dict[model_config.origin_file_pattern]
if model_config.origin_file_pattern in redirect_dict and model_config.model_id != redirect_dict[model_config.origin_file_pattern][0]:
print(f"To avoid repeatedly downloading model files, ({model_config.model_id}, {model_config.origin_file_pattern}) is redirected to {redirect_dict[model_config.origin_file_pattern]}. You can use `redirect_common_files=False` to disable file redirection.")
model_config.model_id = redirect_dict[model_config.origin_file_pattern][0]
model_config.origin_file_pattern = redirect_dict[model_config.origin_file_pattern][1]
# Initialize pipeline
pipe = WanVideoPipeline(device=device, torch_dtype=torch_dtype)
@@ -153,11 +159,13 @@ class WanVideoPipeline(BasePipeline):
pipe.tokenizer = HuggingfaceTokenizer(name=tokenizer_config.path, seq_len=512, clean='whitespace')
if audio_processor_config is not None:
audio_processor_config.download_if_necessary()
from transformers import Wav2Vec2Processor
pipe.audio_processor = Wav2Vec2Processor.from_pretrained(audio_processor_config.path)
# Unified Sequence Parallel
if use_usp: pipe.enable_usp()
# VRAM Management
pipe.vram_management_enabled = pipe.check_vram_management_state()
return pipe