wan direct distill

This commit is contained in:
Artiprocher
2025-11-19 15:46:37 +08:00
parent 453ca89046
commit 6ad8d73717
4 changed files with 44 additions and 5 deletions

View File

@@ -23,6 +23,7 @@ def FlowMatchSFTLoss(pipe: BasePipeline, **inputs):
def DirectDistillLoss(pipe: BasePipeline, **inputs):
pipe.scheduler.set_timesteps(inputs["num_inference_steps"])
pipe.scheduler.training = True
models = {name: getattr(pipe, name) for name in pipe.in_iteration_models}
for progress_id, timestep in enumerate(pipe.scheduler.timesteps):
timestep = timestep.unsqueeze(0).to(dtype=pipe.torch_dtype, device=pipe.device)

View File

@@ -5,8 +5,9 @@ def script_is_processed(output_path, script):
return os.path.exists(os.path.join(output_path, script))
def filter_unprocessed_tasks(script_path, output_path):
def filter_unprocessed_tasks(script_path):
tasks = []
output_path = os.path.join("data", script_path)
for script in sorted(os.listdir(script_path)):
if not script.endswith(".sh") and not script.endswith(".py"):
continue
@@ -59,8 +60,8 @@ def run_train_multi_GPU(script_path, tasks):
def run_train_single_GPU(script_path):
processes = [multiprocessing.Process(target=run_tasks_on_single_GPU, args=(script_path, i, 8)) for i in range(8)]
def run_train_single_GPU(script_path, tasks):
processes = [multiprocessing.Process(target=run_tasks_on_single_GPU, args=(script_path, tasks, i, 8)) for i in range(8)]
for p in processes:
p.start()
for p in processes:
@@ -85,8 +86,8 @@ if __name__ == "__main__":
# run_train_single_GPU("examples/wanvideo/model_inference")
# move_files("video_", "data/output/model_inference")
# run_train_single_GPU("examples/wanvideo/model_training/lora")
# run_train_single_GPU("examples/wanvideo/model_training/validate_lora")
# move_files("video_", "data/output/validate_lora")
run_train_single_GPU("examples/wanvideo/model_training/validate_lora", filter_unprocessed_tasks("examples/wanvideo/model_training/validate_lora"))
move_files("video_", "data/output/validate_lora")
# run_train_multi_GPU("examples/wanvideo/model_training/full")
# run_train_single_GPU("examples/wanvideo/model_training/validate_full")
# move_files("video_", "data/output/validate_full")

View File

@@ -0,0 +1,14 @@
# Launch direct-distillation full training of the Wan2.1-T2V-1.3B DiT via
# Hugging Face Accelerate. The "direct_distill" task trains the DiT against the
# scheduler's full inference trajectory, so inference-time inputs
# (seed, rand_device, num_inference_steps, cfg_scale) are forwarded as extra inputs.
# NOTE(review): the config file is named accelerate_config_14B.yaml but the
# weights below are the 1.3B checkpoint — confirm the config matches the model size.
# Checkpoints are saved with the "pipe.dit." prefix stripped so they can be
# reloaded directly as DiT weights (see the validation script's ModelConfig).
accelerate launch --config_file examples/wanvideo/model_training/full/accelerate_config_14B.yaml examples/wanvideo/model_training/train.py \
--dataset_base_path data/example_video_dataset \
--dataset_metadata_path data/example_video_dataset/metadata_distill.csv \
--height 480 \
--width 832 \
--dataset_repeat 160 \
--model_id_with_origin_paths "Wan-AI/Wan2.1-T2V-1.3B:diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.1-T2V-1.3B:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.1-T2V-1.3B:Wan2.1_VAE.pth" \
--learning_rate 1e-5 \
--num_epochs 2 \
--remove_prefix_in_ckpt "pipe.dit." \
--output_path "./models/train/Wan2.1-T2V-1.3B_full_distill" \
--trainable_models "dit" \
--task "direct_distill" \
--extra_inputs "seed,rand_device,num_inference_steps,cfg_scale"

View File

@@ -0,0 +1,23 @@
# Validate the direct-distilled Wan2.1-T2V-1.3B model: generate one sample
# video in only 4 denoising steps with cfg_scale=1 (presumably the distilled
# DiT is meant to run without classifier-free guidance — confirm with training).
import torch
from PIL import Image
from diffsynth.utils.data import save_video, VideoData
from diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig

# The distilled DiT checkpoint from the training run, combined with the
# unmodified T5 text encoder and VAE from the base Wan2.1-T2V-1.3B release.
model_configs = [
    ModelConfig("models/train/Wan2.1-T2V-1.3B_full_distill/epoch-1.safetensors"),
    ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth"),
    ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="Wan2.1_VAE.pth"),
]
tokenizer_config = ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="google/umt5-xxl/")

pipe = WanVideoPipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda",
    model_configs=model_configs,
    tokenizer_config=tokenizer_config,
)

# Prompt kept verbatim (documentary-style footage of a puppy running on grass).
prompt = "纪实摄影风格画面,一只活泼的小狗在绿茵茵的草地上迅速奔跑。小狗毛色棕黄,两只耳朵立起,神情专注而欢快。阳光洒在它身上,使得毛发看上去格外柔软而闪亮。背景是一片开阔的草地,偶尔点缀着几朵野花,远处隐约可见蓝天和几片白云。透视感鲜明,捕捉小狗奔跑时的动感和四周草地的生机。中景侧面移动视角。"

video = pipe(
    prompt=prompt,
    cfg_scale=1,
    num_inference_steps=4,
    seed=0,
    tiled=True,
)
save_video(video, "video_distill_Wan2.1-T2V-1.3B.mp4", fps=15, quality=5)