diff --git a/examples/wanvideo/README.md b/examples/wanvideo/README.md
index e4afaaa..656d7c9 100644
--- a/examples/wanvideo/README.md
+++ b/examples/wanvideo/README.md
@@ -42,7 +42,7 @@ https://github.com/user-attachments/assets/c0bdd5ca-292f-45ed-b9bc-afe193156e75
 
 ## Train
 
-We support Wan-Video LoRA training. Here is a tutorial.
+We support Wan-Video LoRA training and full training. Here is a tutorial.
 
 Step 1: Install additional packages
 
@@ -99,9 +99,12 @@ data/example_dataset/
 
 Step 4: Train
 
+LoRA training:
+
 ```shell
 CUDA_VISIBLE_DEVICES="0" python examples/wanvideo/train_wan_t2v.py \
   --task train \
+  --train_architecture lora \
   --dataset_path data/example_dataset \
   --output_path ./models \
   --dit_path "models/Wan-AI/Wan2.1-T2V-1.3B/diffusion_pytorch_model.safetensors" \
@@ -115,8 +118,26 @@ CUDA_VISIBLE_DEVICES="0" python examples/wanvideo/train_wan_t2v.py \
   --use_gradient_checkpointing
 ```
 
+Full training:
+
+```shell
+CUDA_VISIBLE_DEVICES="0" python examples/wanvideo/train_wan_t2v.py \
+  --task train \
+  --train_architecture full \
+  --dataset_path data/example_dataset \
+  --output_path ./models \
+  --dit_path "models/Wan-AI/Wan2.1-T2V-1.3B/diffusion_pytorch_model.safetensors" \
+  --steps_per_epoch 500 \
+  --max_epochs 10 \
+  --learning_rate 1e-4 \
+  --accumulate_grad_batches 1 \
+  --use_gradient_checkpointing
+```
+
 Step 5: Test
 
+Test LoRA:
+
 ```python
 import torch
 from diffsynth import ModelManager, WanVideoPipeline, save_video, VideoData
@@ -129,16 +150,39 @@ model_manager.load_models([
     "models/Wan-AI/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth",
 ])
 model_manager.load_lora("models/lightning_logs/version_1/checkpoints/epoch=0-step=500.ckpt", lora_alpha=1.0)
-
 pipe = WanVideoPipeline.from_model_manager(model_manager, device="cuda")
 pipe.enable_vram_management(num_persistent_param_in_dit=None)
 
-# Text-to-video
 video = pipe(
     prompt="...",
     negative_prompt="...",
     num_inference_steps=50,
     seed=0, tiled=True
 )
-save_video(video, "video_with_lora.mp4", fps=30, quality=5)
+save_video(video, "video.mp4", fps=30, quality=5)
+```
+
+Test fine-tuned base model:
+
+```python
+import torch
+from diffsynth import ModelManager, WanVideoPipeline, save_video, VideoData
+
+
+model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu")
+model_manager.load_models([
+    "models/lightning_logs/version_1/checkpoints/epoch=0-step=500.ckpt",
+    "models/Wan-AI/Wan2.1-T2V-1.3B/models_t5_umt5-xxl-enc-bf16.pth",
+    "models/Wan-AI/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth",
+])
+pipe = WanVideoPipeline.from_model_manager(model_manager, device="cuda")
+pipe.enable_vram_management(num_persistent_param_in_dit=None)
+
+video = pipe(
+    prompt="...",
+    negative_prompt="...",
+    num_inference_steps=50,
+    seed=0, tiled=True
+)
+save_video(video, "video.mp4", fps=30, quality=5)
 ```
diff --git a/examples/wanvideo/train_wan_t2v.py b/examples/wanvideo/train_wan_t2v.py
index 8bc2134..53de964 100644
--- a/examples/wanvideo/train_wan_t2v.py
+++ b/examples/wanvideo/train_wan_t2v.py
@@ -134,7 +134,7 @@ class TensorDataset(torch.utils.data.Dataset):
 
 
 class LightningModelForTrain(pl.LightningModule):
-    def __init__(self, dit_path, learning_rate=1e-5, lora_rank=4, lora_alpha=4, lora_target_modules="q,k,v,o,ffn.0,ffn.2", init_lora_weights="kaiming", use_gradient_checkpointing=True):
+    def __init__(self, dit_path, learning_rate=1e-5, lora_rank=4, lora_alpha=4, train_architecture="lora", lora_target_modules="q,k,v,o,ffn.0,ffn.2", init_lora_weights="kaiming", use_gradient_checkpointing=True):
         super().__init__()
         model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu")
         model_manager.load_models([dit_path])
@@ -142,13 +142,16 @@ class LightningModelForTrain(pl.LightningModule):
         self.pipe = WanVideoPipeline.from_model_manager(model_manager)
         self.pipe.scheduler.set_timesteps(1000, training=True)
         self.freeze_parameters()
-        self.add_lora_to_model(
-            self.pipe.denoising_model(),
-            lora_rank=lora_rank,
-            lora_alpha=lora_alpha,
-            lora_target_modules=lora_target_modules,
-            init_lora_weights=init_lora_weights,
-        )
+        if train_architecture == "lora":
+            self.add_lora_to_model(
+                self.pipe.denoising_model(),
+                lora_rank=lora_rank,
+                lora_alpha=lora_alpha,
+                lora_target_modules=lora_target_modules,
+                init_lora_weights=init_lora_weights,
+            )
+        else:
+            self.pipe.denoising_model().requires_grad_(True)
         self.learning_rate = learning_rate
         self.use_gradient_checkpointing = use_gradient_checkpointing
 
@@ -384,6 +387,13 @@ def parse_args():
         action="store_true",
         help="Whether to use gradient checkpointing.",
     )
+    parser.add_argument(
+        "--train_architecture",
+        type=str,
+        default="lora",
+        choices=["lora", "full"],
+        help="Model structure to train. LoRA training or full training.",
+    )
     args = parser.parse_args()
     return args
 
@@ -434,6 +444,7 @@ def train(args):
     model = LightningModelForTrain(
         dit_path=args.dit_path,
         learning_rate=args.learning_rate,
+        train_architecture=args.train_architecture,
         lora_rank=args.lora_rank,
         lora_alpha=args.lora_alpha,
         lora_target_modules=args.lora_target_modules,
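The core of this patch is the `train_architecture` switch: `freeze_parameters()` runs first, then either LoRA adapters are injected into the target projections or the entire DiT is unfrozen via `requires_grad_(True)`. The sketch below illustrates that freeze-then-branch pattern in isolation, using plain PyTorch; `ToyDiT`, `LoRALinear`, and `build_trainable` are hypothetical stand-ins for illustration, not DiffSynth-Studio APIs.

```python
# Minimal, self-contained sketch of the train_architecture branch, with a toy
# module standing in for the real Wan DiT. All names here are illustrative.
import torch


class LoRALinear(torch.nn.Module):
    """Wraps a frozen Linear with a trainable low-rank residual (B @ A)."""
    def __init__(self, base: torch.nn.Linear, rank: int = 4, alpha: float = 4.0):
        super().__init__()
        self.base = base
        self.lora_A = torch.nn.Linear(base.in_features, rank, bias=False)
        self.lora_B = torch.nn.Linear(rank, base.out_features, bias=False)
        torch.nn.init.kaiming_uniform_(self.lora_A.weight)  # "kaiming" init, as in the patch
        torch.nn.init.zeros_(self.lora_B.weight)            # adapter starts as a no-op
        self.scale = alpha / rank

    def forward(self, x):
        return self.base(x) + self.scale * self.lora_B(self.lora_A(x))


class ToyDiT(torch.nn.Module):
    """Stand-in for the denoising model, with two of the target projections."""
    def __init__(self):
        super().__init__()
        self.q = torch.nn.Linear(64, 64)
        self.o = torch.nn.Linear(64, 64)

    def forward(self, x):
        return self.o(self.q(x))


def build_trainable(dit: torch.nn.Module, train_architecture: str = "lora"):
    dit.requires_grad_(False)  # freeze_parameters(): everything starts frozen
    if train_architecture == "lora":
        # Inject trainable adapters into the target modules; base weights stay frozen.
        dit.q = LoRALinear(dit.q)
        dit.o = LoRALinear(dit.o)
    else:
        # Full training: simply unfreeze the whole denoising model.
        dit.requires_grad_(True)
    return dit


dit = build_trainable(ToyDiT(), train_architecture="lora")
trainable = sum(p.numel() for p in dit.parameters() if p.requires_grad)
print(f"trainable parameters: {trainable}")  # only the low-rank A/B matrices
```

In the sketch, zero-initializing `lora_B` keeps each wrapped layer an exact no-op at step 0, so LoRA training starts from the frozen base model's behavior, while the `full` branch trains every original weight; that difference is why the two README commands save checkpoints that are loaded differently at test time (via `load_lora` versus passing the checkpoint directly to `load_models`).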