mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-24 10:18:12 +00:00
wan direct distill
This commit is contained in:
@@ -23,6 +23,7 @@ def FlowMatchSFTLoss(pipe: BasePipeline, **inputs):
|
|||||||
|
|
||||||
def DirectDistillLoss(pipe: BasePipeline, **inputs):
|
def DirectDistillLoss(pipe: BasePipeline, **inputs):
|
||||||
pipe.scheduler.set_timesteps(inputs["num_inference_steps"])
|
pipe.scheduler.set_timesteps(inputs["num_inference_steps"])
|
||||||
|
pipe.scheduler.training = True
|
||||||
models = {name: getattr(pipe, name) for name in pipe.in_iteration_models}
|
models = {name: getattr(pipe, name) for name in pipe.in_iteration_models}
|
||||||
for progress_id, timestep in enumerate(pipe.scheduler.timesteps):
|
for progress_id, timestep in enumerate(pipe.scheduler.timesteps):
|
||||||
timestep = timestep.unsqueeze(0).to(dtype=pipe.torch_dtype, device=pipe.device)
|
timestep = timestep.unsqueeze(0).to(dtype=pipe.torch_dtype, device=pipe.device)
|
||||||
|
|||||||
@@ -5,8 +5,9 @@ def script_is_processed(output_path, script):
|
|||||||
return os.path.exists(os.path.join(output_path, script))
|
return os.path.exists(os.path.join(output_path, script))
|
||||||
|
|
||||||
|
|
||||||
def filter_unprocessed_tasks(script_path, output_path):
|
def filter_unprocessed_tasks(script_path):
|
||||||
tasks = []
|
tasks = []
|
||||||
|
output_path = os.path.join("data", script_path)
|
||||||
for script in sorted(os.listdir(script_path)):
|
for script in sorted(os.listdir(script_path)):
|
||||||
if not script.endswith(".sh") and not script.endswith(".py"):
|
if not script.endswith(".sh") and not script.endswith(".py"):
|
||||||
continue
|
continue
|
||||||
@@ -59,8 +60,8 @@ def run_train_multi_GPU(script_path, tasks):
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
def run_train_single_GPU(script_path):
|
def run_train_single_GPU(script_path, tasks):
|
||||||
processes = [multiprocessing.Process(target=run_tasks_on_single_GPU, args=(script_path, i, 8)) for i in range(8)]
|
processes = [multiprocessing.Process(target=run_tasks_on_single_GPU, args=(script_path, tasks, i, 8)) for i in range(8)]
|
||||||
for p in processes:
|
for p in processes:
|
||||||
p.start()
|
p.start()
|
||||||
for p in processes:
|
for p in processes:
|
||||||
@@ -85,8 +86,8 @@ if __name__ == "__main__":
|
|||||||
# run_train_single_GPU("examples/wanvideo/model_inference")
|
# run_train_single_GPU("examples/wanvideo/model_inference")
|
||||||
# move_files("video_", "data/output/model_inference")
|
# move_files("video_", "data/output/model_inference")
|
||||||
# run_train_single_GPU("examples/wanvideo/model_training/lora")
|
# run_train_single_GPU("examples/wanvideo/model_training/lora")
|
||||||
# run_train_single_GPU("examples/wanvideo/model_training/validate_lora")
|
run_train_single_GPU("examples/wanvideo/model_training/validate_lora", filter_unprocessed_tasks("examples/wanvideo/model_training/validate_lora"))
|
||||||
# move_files("video_", "data/output/validate_lora")
|
move_files("video_", "data/output/validate_lora")
|
||||||
# run_train_multi_GPU("examples/wanvideo/model_training/full")
|
# run_train_multi_GPU("examples/wanvideo/model_training/full")
|
||||||
# run_train_single_GPU("examples/wanvideo/model_training/validate_full")
|
# run_train_single_GPU("examples/wanvideo/model_training/validate_full")
|
||||||
# move_files("video_", "data/output/validate_full")
|
# move_files("video_", "data/output/validate_full")
|
||||||
|
|||||||
@@ -0,0 +1,14 @@
|
|||||||
|
accelerate launch --config_file examples/wanvideo/model_training/full/accelerate_config_14B.yaml examples/wanvideo/model_training/train.py \
|
||||||
|
--dataset_base_path data/example_video_dataset \
|
||||||
|
--dataset_metadata_path data/example_video_dataset/metadata_distill.csv \
|
||||||
|
--height 480 \
|
||||||
|
--width 832 \
|
||||||
|
--dataset_repeat 160 \
|
||||||
|
--model_id_with_origin_paths "Wan-AI/Wan2.1-T2V-1.3B:diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.1-T2V-1.3B:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.1-T2V-1.3B:Wan2.1_VAE.pth" \
|
||||||
|
--learning_rate 1e-5 \
|
||||||
|
--num_epochs 2 \
|
||||||
|
--remove_prefix_in_ckpt "pipe.dit." \
|
||||||
|
--output_path "./models/train/Wan2.1-T2V-1.3B_full_distill" \
|
||||||
|
--trainable_models "dit" \
|
||||||
|
--task "direct_distill" \
|
||||||
|
--extra_inputs "seed,rand_device,num_inference_steps,cfg_scale"
|
||||||
@@ -0,0 +1,23 @@
|
|||||||
|
import torch
|
||||||
|
from PIL import Image
|
||||||
|
from diffsynth.utils.data import save_video, VideoData
|
||||||
|
from diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig
|
||||||
|
|
||||||
|
|
||||||
|
pipe = WanVideoPipeline.from_pretrained(
|
||||||
|
torch_dtype=torch.bfloat16,
|
||||||
|
device="cuda",
|
||||||
|
model_configs=[
|
||||||
|
ModelConfig("models/train/Wan2.1-T2V-1.3B_full_distill/epoch-1.safetensors"),
|
||||||
|
ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth"),
|
||||||
|
ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="Wan2.1_VAE.pth"),
|
||||||
|
],
|
||||||
|
tokenizer_config=ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="google/umt5-xxl/"),
|
||||||
|
)
|
||||||
|
|
||||||
|
video = pipe(
|
||||||
|
prompt="纪实摄影风格画面,一只活泼的小狗在绿茵茵的草地上迅速奔跑。小狗毛色棕黄,两只耳朵立起,神情专注而欢快。阳光洒在它身上,使得毛发看上去格外柔软而闪亮。背景是一片开阔的草地,偶尔点缀着几朵野花,远处隐约可见蓝天和几片白云。透视感鲜明,捕捉小狗奔跑时的动感和四周草地的生机。中景侧面移动视角。",
|
||||||
|
cfg_scale=1, num_inference_steps=4,
|
||||||
|
seed=0, tiled=True,
|
||||||
|
)
|
||||||
|
save_video(video, "video_distill_Wan2.1-T2V-1.3B.mp4", fps=15, quality=5)
|
||||||
Reference in New Issue
Block a user