fix wans2v bug and update readme

This commit is contained in:
lzws
2026-02-05 16:52:38 +08:00
parent c0f7e1db7c
commit 8d172127cd
6 changed files with 120 additions and 11 deletions

View File

@@ -7,6 +7,7 @@ accelerate launch --config_file examples/wanvideo/model_training/full/accelerate
--num_frames 81 \
--dataset_repeat 100 \
--model_id_with_origin_paths "Wan-AI/Wan2.2-S2V-14B:diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.2-S2V-14B:wav2vec2-large-xlsr-53-english/model.safetensors,Wan-AI/Wan2.2-S2V-14B:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.2-S2V-14B:Wan2.1_VAE.pth" \
--audio_processor_path "Wan-AI/Wan2.2-S2V-14B:wav2vec2-large-xlsr-53-english/" \
--learning_rate 1e-5 \
--num_epochs 1 \
--trainable_models "dit" \

View File

@@ -7,6 +7,7 @@ accelerate launch --config_file examples/wanvideo/model_training/full/accelerate
--num_frames 81 \
--dataset_repeat 100 \
--model_id_with_origin_paths "Wan-AI/Wan2.2-S2V-14B:diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.2-S2V-14B:wav2vec2-large-xlsr-53-english/model.safetensors,Wan-AI/Wan2.2-S2V-14B:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.2-S2V-14B:Wan2.1_VAE.pth" \
--audio_processor_path "Wan-AI/Wan2.2-S2V-14B:wav2vec2-large-xlsr-53-english/" \
--learning_rate 1e-4 \
--num_epochs 5 \
--remove_prefix_in_ckpt "pipe.dit." \

View File

@@ -33,7 +33,7 @@ class WanTrainingModule(DiffusionTrainingModule):
# Load models
model_configs = self.parse_model_configs(model_paths, model_id_with_origin_paths, fp8_models=fp8_models, offload_models=offload_models, device=device)
tokenizer_config = ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="google/umt5-xxl/") if tokenizer_path is None else ModelConfig(tokenizer_path)
audio_processor_config = ModelConfig(model_id="Wan-AI/Wan2.2-S2V-14B", origin_file_pattern="wav2vec2-large-xlsr-53-english/") if audio_processor_path is None else ModelConfig(audio_processor_path)
audio_processor_config = self.parse_path_or_model_id(audio_processor_path)
self.pipe = WanVideoPipeline.from_pretrained(torch_dtype=torch.bfloat16, device=device, model_configs=model_configs, tokenizer_config=tokenizer_config, audio_processor_config=audio_processor_config)
self.pipe = self.split_pipeline_units(task, self.pipe, trainable_models, lora_base_model)