diff --git a/diffsynth/diffusion/loss.py b/diffsynth/diffusion/loss.py
index bb12dfd..2e9d390 100644
--- a/diffsynth/diffusion/loss.py
+++ b/diffsynth/diffusion/loss.py
@@ -23,6 +23,7 @@ def FlowMatchSFTLoss(pipe: BasePipeline, **inputs):
 
 def DirectDistillLoss(pipe: BasePipeline, **inputs):
     pipe.scheduler.set_timesteps(inputs["num_inference_steps"])
+    pipe.scheduler.training = True
     models = {name: getattr(pipe, name) for name in pipe.in_iteration_models}
     for progress_id, timestep in enumerate(pipe.scheduler.timesteps):
         timestep = timestep.unsqueeze(0).to(dtype=pipe.torch_dtype, device=pipe.device)
diff --git a/examples/test/run.py b/examples/test/run.py
index 097800a..7c78043 100644
--- a/examples/test/run.py
+++ b/examples/test/run.py
@@ -5,8 +5,9 @@ def script_is_processed(output_path, script):
     return os.path.exists(os.path.join(output_path, script))
 
 
-def filter_unprocessed_tasks(script_path, output_path):
+def filter_unprocessed_tasks(script_path):
     tasks = []
+    output_path = os.path.join("data", script_path)
     for script in sorted(os.listdir(script_path)):
         if not script.endswith(".sh") and not script.endswith(".py"):
             continue
@@ -59,8 +60,8 @@
 
 
 
-def run_train_single_GPU(script_path):
-    processes = [multiprocessing.Process(target=run_tasks_on_single_GPU, args=(script_path, i, 8)) for i in range(8)]
+def run_train_single_GPU(script_path, tasks):
+    processes = [multiprocessing.Process(target=run_tasks_on_single_GPU, args=(script_path, tasks, i, 8)) for i in range(8)]
     for p in processes:
         p.start()
     for p in processes:
@@ -85,8 +86,8 @@
     # run_train_single_GPU("examples/wanvideo/model_inference")
     # move_files("video_", "data/output/model_inference")
     # run_train_single_GPU("examples/wanvideo/model_training/lora")
-    # run_train_single_GPU("examples/wanvideo/model_training/validate_lora")
-    # move_files("video_", "data/output/validate_lora")
+    run_train_single_GPU("examples/wanvideo/model_training/validate_lora", filter_unprocessed_tasks("examples/wanvideo/model_training/validate_lora"))
+    move_files("video_", "data/output/validate_lora")
     # run_train_multi_GPU("examples/wanvideo/model_training/full")
     # run_train_single_GPU("examples/wanvideo/model_training/validate_full")
     # move_files("video_", "data/output/validate_full")
diff --git a/examples/wanvideo/model_training/special/direct_distill/Wan2.1-T2V-1.3B.sh b/examples/wanvideo/model_training/special/direct_distill/Wan2.1-T2V-1.3B.sh
new file mode 100644
index 0000000..73e85f3
--- /dev/null
+++ b/examples/wanvideo/model_training/special/direct_distill/Wan2.1-T2V-1.3B.sh
@@ -0,0 +1,14 @@
+accelerate launch --config_file examples/wanvideo/model_training/full/accelerate_config_14B.yaml examples/wanvideo/model_training/train.py \
+  --dataset_base_path data/example_video_dataset \
+  --dataset_metadata_path data/example_video_dataset/metadata_distill.csv \
+  --height 480 \
+  --width 832 \
+  --dataset_repeat 160 \
+  --model_id_with_origin_paths "Wan-AI/Wan2.1-T2V-1.3B:diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.1-T2V-1.3B:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.1-T2V-1.3B:Wan2.1_VAE.pth" \
+  --learning_rate 1e-5 \
+  --num_epochs 2 \
+  --remove_prefix_in_ckpt "pipe.dit." \
+  --output_path "./models/train/Wan2.1-T2V-1.3B_full_distill" \
+  --trainable_models "dit" \
+  --task "direct_distill" \
+  --extra_inputs "seed,rand_device,num_inference_steps,cfg_scale"
diff --git a/examples/wanvideo/model_training/special/direct_distill/validate.py b/examples/wanvideo/model_training/special/direct_distill/validate.py
new file mode 100644
index 0000000..6da0e1b
--- /dev/null
+++ b/examples/wanvideo/model_training/special/direct_distill/validate.py
@@ -0,0 +1,23 @@
+import torch
+from PIL import Image
+from diffsynth.utils.data import save_video, VideoData
+from diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig
+
+
+pipe = WanVideoPipeline.from_pretrained(
+    torch_dtype=torch.bfloat16,
+    device="cuda",
+    model_configs=[
+        ModelConfig("models/train/Wan2.1-T2V-1.3B_full_distill/epoch-1.safetensors"),
+        ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth"),
+        ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="Wan2.1_VAE.pth"),
+    ],
+    tokenizer_config=ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="google/umt5-xxl/"),
+)
+
+video = pipe(
+    prompt="纪实摄影风格画面,一只活泼的小狗在绿茵茵的草地上迅速奔跑。小狗毛色棕黄,两只耳朵立起,神情专注而欢快。阳光洒在它身上,使得毛发看上去格外柔软而闪亮。背景是一片开阔的草地,偶尔点缀着几朵野花,远处隐约可见蓝天和几片白云。透视感鲜明,捕捉小狗奔跑时的动感和四周草地的生机。中景侧面移动视角。",
+    cfg_scale=1, num_inference_steps=4,
+    seed=0, tiled=True,
+)
+save_video(video, "video_distill_Wan2.1-T2V-1.3B.mp4", fps=15, quality=5)