mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-18 22:08:13 +00:00
wan direct distill
This commit is contained in:
@@ -23,6 +23,7 @@ def FlowMatchSFTLoss(pipe: BasePipeline, **inputs):
|
||||
|
||||
def DirectDistillLoss(pipe: BasePipeline, **inputs):
|
||||
pipe.scheduler.set_timesteps(inputs["num_inference_steps"])
|
||||
pipe.scheduler.training = True
|
||||
models = {name: getattr(pipe, name) for name in pipe.in_iteration_models}
|
||||
for progress_id, timestep in enumerate(pipe.scheduler.timesteps):
|
||||
timestep = timestep.unsqueeze(0).to(dtype=pipe.torch_dtype, device=pipe.device)
|
||||
|
||||
@@ -5,8 +5,9 @@ def script_is_processed(output_path, script):
|
||||
return os.path.exists(os.path.join(output_path, script))
|
||||
|
||||
|
||||
def filter_unprocessed_tasks(script_path, output_path):
|
||||
def filter_unprocessed_tasks(script_path):
|
||||
tasks = []
|
||||
output_path = os.path.join("data", script_path)
|
||||
for script in sorted(os.listdir(script_path)):
|
||||
if not script.endswith(".sh") and not script.endswith(".py"):
|
||||
continue
|
||||
@@ -59,8 +60,8 @@ def run_train_multi_GPU(script_path, tasks):
|
||||
|
||||
|
||||
|
||||
def run_train_single_GPU(script_path):
|
||||
processes = [multiprocessing.Process(target=run_tasks_on_single_GPU, args=(script_path, i, 8)) for i in range(8)]
|
||||
def run_train_single_GPU(script_path, tasks):
|
||||
processes = [multiprocessing.Process(target=run_tasks_on_single_GPU, args=(script_path, tasks, i, 8)) for i in range(8)]
|
||||
for p in processes:
|
||||
p.start()
|
||||
for p in processes:
|
||||
@@ -85,8 +86,8 @@ if __name__ == "__main__":
|
||||
# run_train_single_GPU("examples/wanvideo/model_inference")
|
||||
# move_files("video_", "data/output/model_inference")
|
||||
# run_train_single_GPU("examples/wanvideo/model_training/lora")
|
||||
# run_train_single_GPU("examples/wanvideo/model_training/validate_lora")
|
||||
# move_files("video_", "data/output/validate_lora")
|
||||
run_train_single_GPU("examples/wanvideo/model_training/validate_lora", filter_unprocessed_tasks("examples/wanvideo/model_training/validate_lora"))
|
||||
move_files("video_", "data/output/validate_lora")
|
||||
# run_train_multi_GPU("examples/wanvideo/model_training/full")
|
||||
# run_train_single_GPU("examples/wanvideo/model_training/validate_full")
|
||||
# move_files("video_", "data/output/validate_full")
|
||||
|
||||
@@ -0,0 +1,14 @@
|
||||
accelerate launch --config_file examples/wanvideo/model_training/full/accelerate_config_14B.yaml examples/wanvideo/model_training/train.py \
|
||||
--dataset_base_path data/example_video_dataset \
|
||||
--dataset_metadata_path data/example_video_dataset/metadata_distill.csv \
|
||||
--height 480 \
|
||||
--width 832 \
|
||||
--dataset_repeat 160 \
|
||||
--model_id_with_origin_paths "Wan-AI/Wan2.1-T2V-1.3B:diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.1-T2V-1.3B:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.1-T2V-1.3B:Wan2.1_VAE.pth" \
|
||||
--learning_rate 1e-5 \
|
||||
--num_epochs 2 \
|
||||
--remove_prefix_in_ckpt "pipe.dit." \
|
||||
--output_path "./models/train/Wan2.1-T2V-1.3B_full_distill" \
|
||||
--trainable_models "dit" \
|
||||
--task "direct_distill" \
|
||||
--extra_inputs "seed,rand_device,num_inference_steps,cfg_scale"
|
||||
@@ -0,0 +1,23 @@
|
||||
import torch
|
||||
from PIL import Image
|
||||
from diffsynth.utils.data import save_video, VideoData
|
||||
from diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig
|
||||
|
||||
|
||||
pipe = WanVideoPipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig("models/train/Wan2.1-T2V-1.3B_full_distill/epoch-1.safetensors"),
|
||||
ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth"),
|
||||
ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="Wan2.1_VAE.pth"),
|
||||
],
|
||||
tokenizer_config=ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="google/umt5-xxl/"),
|
||||
)
|
||||
|
||||
video = pipe(
|
||||
prompt="纪实摄影风格画面,一只活泼的小狗在绿茵茵的草地上迅速奔跑。小狗毛色棕黄,两只耳朵立起,神情专注而欢快。阳光洒在它身上,使得毛发看上去格外柔软而闪亮。背景是一片开阔的草地,偶尔点缀着几朵野花,远处隐约可见蓝天和几片白云。透视感鲜明,捕捉小狗奔跑时的动感和四周草地的生机。中景侧面移动视角。",
|
||||
cfg_scale=1, num_inference_steps=4,
|
||||
seed=0, tiled=True,
|
||||
)
|
||||
save_video(video, "video_distill_Wan2.1-T2V-1.3B.mp4", fps=15, quality=5)
|
||||
Reference in New Issue
Block a user