Mirror of https://github.com/modelscope/DiffSynth-Studio.git (synced 2026-03-23 17:38:10 +00:00)

Commit: bugfix. The changes below normalize the example scripts (placeholder-free f-strings, output file names), fix an Annotator import path and two model configurations, raise the Qwen-Image ControlNet training hyperparameters, and add CPU initialization and VRAM offload options for the large Wan models.
Changed file: .gitignore (vendored, 1 addition)
@@ -1,6 +1,7 @@
 /data
 /models
 /scripts
+/diffusers
 *.pkl
 *.safetensors
 *.pth
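Note on the new entry: a gitignore pattern with a leading slash is anchored to the repository root, so only a top-level diffusers/ directory is ignored, consistent with the existing /data, /models, and /scripts entries.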

@@ -21,12 +21,12 @@ image = pipe(
     num_inference_steps=50, embedded_guidance=3.5,
     seed=0
 )
-image.save(f"image_1.jpg")
+image.save("image_1.jpg")
 
 mask = np.zeros((1024, 1024, 3), dtype=np.uint8)
 mask[200:400, 400:700] = 255
 mask = Image.fromarray(mask)
-mask.save(f"image_mask.jpg")
+mask.save("image_mask.jpg")
 
 inpaint_image = image
 

@@ -36,7 +36,7 @@ image = pipe(
     flex_inpaint_image=inpaint_image, flex_inpaint_mask=mask,
     seed=4
 )
-image.save(f"image_2_new.jpg")
+image.save("image_2.jpg")
 
 control_image = Annotator("canny")(image)
 control_image.save("image_control.jpg")

@@ -47,4 +47,4 @@ image = pipe(
     flex_control_image=control_image,
     seed=4
 )
-image.save(f"image_3_new.jpg")
+image.save("image_3.jpg")
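Two kinds of fixes recur in these example hunks: the f prefix is dropped from string literals that contain no replacement fields (a no-op that linters such as pyflakes/Ruff flag as F541), and the output file names are brought back to the numbered image_N.jpg pattern. A minimal check of the first point:

# An f-string without placeholders evaluates to exactly the same value as a plain literal.
assert f"image_1.jpg" == "image_1.jpg"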

@@ -32,12 +32,12 @@ image = pipe(
     num_inference_steps=50, embedded_guidance=3.5,
     seed=0
 )
-image.save(f"image_1.jpg")
+image.save("image_1.jpg")
 
 mask = np.zeros((1024, 1024, 3), dtype=np.uint8)
 mask[200:400, 400:700] = 255
 mask = Image.fromarray(mask)
-mask.save(f"image_mask.jpg")
+mask.save("image_mask.jpg")
 
 inpaint_image = image
 

@@ -47,7 +47,7 @@ image = pipe(
     flex_inpaint_image=inpaint_image, flex_inpaint_mask=mask,
     seed=4
 )
-image.save(f"image_2_new.jpg")
+image.save("image_2.jpg")
 
 control_image = Annotator("canny")(image)
 control_image.save("image_control.jpg")

@@ -58,4 +58,4 @@ image = pipe(
     flex_control_image=control_image,
     seed=4
 )
-image.save(f"image_3_new.jpg")
+image.save("image_3.jpg")

@@ -14,12 +14,11 @@ pipe = FluxImagePipeline.from_pretrained(
         ModelConfig(model_id="DiffSynth-Studio/LoRA-Encoder-FLUX.1-Dev", origin_file_pattern="model.safetensors"),
     ],
 )
-pipe.enable_lora_magic()
 state_dict = load_state_dict("models/train/FLUX.1-dev-LoRA-Encoder_full/epoch-0.safetensors")
 pipe.lora_encoder.load_state_dict(state_dict)
 
 lora = ModelConfig(model_id="VoidOc/flux_animal_forest1", origin_file_pattern="20.safetensors")
-pipe.load_lora(pipe.dit, lora, hotload=True) # Use `pipe.clear_lora()` to drop the loaded LoRA.
+pipe.load_lora(pipe.dit, lora) # Use `pipe.clear_lora()` to drop the loaded LoRA.
 
 image = pipe(prompt="", seed=0, lora_encoder_inputs=lora)
 image.save("image_FLUX.1-dev-LoRA-Encoder_full.jpg")
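The comment kept in the diff already documents the swap pattern: attach a LoRA, generate, then call pipe.clear_lora() before attaching another. A short sketch using only names that appear above (the clean-up call at the end is a hypothetical continuation):

pipe.load_lora(pipe.dit, lora)  # attach the LoRA from this example
image = pipe(prompt="", seed=0, lora_encoder_inputs=lora)
pipe.clear_lora()               # drop it before loading a different LoRA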

@@ -8,7 +8,7 @@ pipe = FluxImagePipeline.from_pretrained(
     torch_dtype=torch.bfloat16,
     device="cuda",
     model_configs=[
-        ModelConfig(model_id="Qwen/Qwen2.5-VL-7B-Instruct"),
+        ModelConfig(model_id="Qwen/Qwen2.5-VL-7B-Instruct", origin_file_pattern="model-*.safetensors"),
         ModelConfig(model_id="stepfun-ai/Step1X-Edit", origin_file_pattern="step1x-edit-i1258.safetensors"),
         ModelConfig(model_id="stepfun-ai/Step1X-Edit", origin_file_pattern="vae.safetensors"),
     ],
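The added origin_file_pattern narrows the download to the sharded weight files of the Qwen2.5-VL repo. Roughly how such a glob selects files, illustrated with Python's fnmatch (DiffSynth's internal matching may differ):

from fnmatch import fnmatch

repo_files = ["model-00001-of-00005.safetensors", "model.safetensors.index.json", "tokenizer.json"]
print([f for f in repo_files if fnmatch(f, "model-*.safetensors")])
# ['model-00001-of-00005.safetensors']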

@@ -6,7 +6,7 @@ pipe = FluxImagePipeline.from_pretrained(
     torch_dtype=torch.bfloat16,
     device="cuda",
     model_configs=[
-        ModelConfig(model_id="Qwen/Qwen2.5-VL-7B-Instruct", origin_file_pattern="model-*.safetensors"),
+        ModelConfig(model_id="ostris/Flex.2-preview", origin_file_pattern="Flex.2-preview.safetensors"),
         ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors"),
         ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/*.safetensors"),
         ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors"),

@@ -2,7 +2,7 @@ from PIL import Image
 import torch
 from modelscope import dataset_snapshot_download, snapshot_download
 from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig
-from diffsynth.controlnets.processors import Annotator
+from diffsynth.utils.controlnet import Annotator
 
 allow_file_pattern = ["sk_model.pth", "sk_model2.pth", "dpt_hybrid-midas-501f0c75.pt", "ControlNetHED.pth", "body_pose_model.pth", "hand_pose_model.pth", "facenet.pth", "scannet.pt"]
 snapshot_download("lllyasviel/Annotators", local_dir="models/Annotators", allow_file_pattern=allow_file_pattern)
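Because Annotator moved from diffsynth.controlnets.processors to diffsynth.utils.controlnet, downstream code that must run against both layouts can guard the import. This shim is a suggestion, not part of the commit:

try:
    from diffsynth.utils.controlnet import Annotator  # layout after this commit
except ImportError:
    from diffsynth.controlnets.processors import Annotator  # older layout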

@@ -3,9 +3,9 @@ accelerate launch examples/qwen_image/model_training/train.py \
   --dataset_metadata_path data/example_image_dataset/metadata_blockwise_controlnet_canny.csv \
   --data_file_keys "image,blockwise_controlnet_image" \
   --max_pixels 1048576 \
-  --dataset_repeat 50 \
+  --dataset_repeat 400 \
   --model_id_with_origin_paths "Qwen/Qwen-Image:transformer/diffusion_pytorch_model*.safetensors,Qwen/Qwen-Image:text_encoder/model*.safetensors,Qwen/Qwen-Image:vae/diffusion_pytorch_model.safetensors,DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny:model.safetensors" \
-  --learning_rate 1e-4 \
+  --learning_rate 1e-3 \
   --num_epochs 2 \
   --remove_prefix_in_ckpt "pipe.blockwise_controlnet.models.0." \
   --output_path "./models/train/Qwen-Image-Blockwise-ControlNet-Canny_full" \

@@ -3,9 +3,9 @@ accelerate launch examples/qwen_image/model_training/train.py \
   --dataset_metadata_path data/example_image_dataset/metadata_blockwise_controlnet_depth.csv \
   --data_file_keys "image,blockwise_controlnet_image" \
   --max_pixels 1048576 \
-  --dataset_repeat 50 \
+  --dataset_repeat 400 \
   --model_id_with_origin_paths "Qwen/Qwen-Image:transformer/diffusion_pytorch_model*.safetensors,Qwen/Qwen-Image:text_encoder/model*.safetensors,Qwen/Qwen-Image:vae/diffusion_pytorch_model.safetensors,DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth:model.safetensors" \
-  --learning_rate 1e-4 \
+  --learning_rate 1e-3 \
   --num_epochs 2 \
   --remove_prefix_in_ckpt "pipe.blockwise_controlnet.models.0." \
   --output_path "./models/train/Qwen-Image-Blockwise-ControlNet-Depth_full" \

@@ -3,9 +3,9 @@ accelerate launch --config_file examples/qwen_image/model_training/full/accelera
   --dataset_metadata_path data/example_image_dataset/metadata_blockwise_controlnet_inpaint.csv \
   --data_file_keys "image,blockwise_controlnet_image,blockwise_controlnet_inpaint_mask" \
   --max_pixels 1048576 \
-  --dataset_repeat 50 \
+  --dataset_repeat 400 \
   --model_id_with_origin_paths "Qwen/Qwen-Image:transformer/diffusion_pytorch_model*.safetensors,Qwen/Qwen-Image:text_encoder/model*.safetensors,Qwen/Qwen-Image:vae/diffusion_pytorch_model.safetensors,DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint:model.safetensors" \
-  --learning_rate 1e-4 \
+  --learning_rate 1e-3 \
   --num_epochs 2 \
   --remove_prefix_in_ckpt "pipe.blockwise_controlnet.models.0." \
   --output_path "./models/train/Qwen-Image-Blockwise-ControlNet-Inpaint_full" \
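The hyperparameter changes in these three scripts (dataset_repeat 50 -> 400, learning_rate 1e-4 -> 1e-3) multiply the optimizer steps per epoch by eight. A back-of-envelope helper, assuming batch size 1 and no gradient accumulation; the sample count is illustrative:

def optimizer_steps(num_samples: int, dataset_repeat: int, num_epochs: int) -> int:
    # One step per sample; each epoch walks the repeated dataset once.
    return num_samples * dataset_repeat * num_epochs

print(optimizer_steps(40, 50, 2))   # 4000 steps before this commit
print(optimizer_steps(40, 400, 2))  # 32000 steps after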

@@ -1,5 +1,6 @@
 # Without VRAM Management, 80G VRAM is not enough to run this example.
 # We recommend to use `examples/wanvideo/model_inference_low_vram/Wan2.2-VACE-Fun-A14B.py`.
+# CPU Offload is enabled in this example.
 import torch
 from PIL import Image
 from diffsynth.utils.data import save_video, VideoData

@@ -7,16 +8,27 @@ from diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig
 from modelscope import dataset_snapshot_download
 
 
+vram_config = {
+    "offload_dtype": torch.bfloat16,
+    "offload_device": "cpu",
+    "onload_dtype": torch.bfloat16,
+    "onload_device": "cpu",
+    "preparing_dtype": torch.bfloat16,
+    "preparing_device": "cuda",
+    "computation_dtype": torch.bfloat16,
+    "computation_device": "cuda",
+}
 pipe = WanVideoPipeline.from_pretrained(
     torch_dtype=torch.bfloat16,
     device="cuda",
     model_configs=[
-        ModelConfig(model_id="PAI/Wan2.2-VACE-Fun-A14B", origin_file_pattern="high_noise_model/diffusion_pytorch_model*.safetensors"),
-        ModelConfig(model_id="PAI/Wan2.2-VACE-Fun-A14B", origin_file_pattern="low_noise_model/diffusion_pytorch_model*.safetensors"),
-        ModelConfig(model_id="PAI/Wan2.2-VACE-Fun-A14B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth"),
-        ModelConfig(model_id="PAI/Wan2.2-VACE-Fun-A14B", origin_file_pattern="Wan2.1_VAE.pth"),
+        ModelConfig(model_id="PAI/Wan2.2-VACE-Fun-A14B", origin_file_pattern="high_noise_model/diffusion_pytorch_model*.safetensors", **vram_config),
+        ModelConfig(model_id="PAI/Wan2.2-VACE-Fun-A14B", origin_file_pattern="low_noise_model/diffusion_pytorch_model*.safetensors", **vram_config),
+        ModelConfig(model_id="PAI/Wan2.2-VACE-Fun-A14B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", **vram_config),
+        ModelConfig(model_id="PAI/Wan2.2-VACE-Fun-A14B", origin_file_pattern="Wan2.1_VAE.pth", **vram_config),
     ],
     tokenizer_config=ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="google/umt5-xxl/"),
+    vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 2,
 )
 
 
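Each vram_config key pairs a dtype with a device for one phase of a model's life cycle: resting (offload), staged (onload), weight preparation, and computation. As a rough mental model of what CPU offload buys, a sketch rather than DiffSynth's actual scheduler:

import torch

def run_offloaded(module: torch.nn.Module, x: torch.Tensor) -> torch.Tensor:
    # Weights rest on CPU; move them to the GPU only for the forward pass.
    module.to("cuda")
    try:
        return module(x.to("cuda"))
    finally:
        module.to("cpu")  # release VRAM immediately afterwards

The added vram_limit is the card's total memory in GiB minus a 2 GiB head room: torch.cuda.mem_get_info returns a (free_bytes, total_bytes) tuple, so index [1] is total device memory.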

@@ -53,6 +53,7 @@ save_video(video, "video_1_Wan2.2-Animate-14B.mp4", fps=15, quality=5)
 # Replace
 snapshot_download("Wan-AI/Wan2.2-Animate-14B", allow_file_pattern="relighting_lora.ckpt", local_dir="models/Wan-AI/Wan2.2-Animate-14B")
 lora_state_dict = load_state_dict("models/Wan-AI/Wan2.2-Animate-14B/relighting_lora.ckpt", torch_dtype=torch.bfloat16, device="cuda")["state_dict"]
+lora_state_dict = {i: lora_state_dict[i].to(torch.bfloat16) for i in lora_state_dict}
 pipe.load_lora(pipe.dit, state_dict=lora_state_dict)
 input_image = Image.open("data/examples/wan/animate/replace_input_image.png")
 animate_pose_video = VideoData("data/examples/wan/animate/replace_pose_video.mp4").raw_data()[:81-4]
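The inserted comprehension force-casts every tensor in the relighting checkpoint to bfloat16, which guards against checkpoints that mix fp32 and bf16 entries. An equivalent, slightly more explicit spelling:

lora_state_dict = {name: tensor.to(torch.bfloat16) for name, tensor in lora_state_dict.items()}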

@@ -1,5 +1,3 @@
-# Without VRAM Management, 80G VRAM is not enough to run this example.
-# We recommend to use `examples/wanvideo/model_inference_low_vram/Wan2.2-VACE-Fun-A14B.py`.
 import torch
 from PIL import Image
 from diffsynth.utils.data import save_video, VideoData

@@ -10,4 +10,5 @@ accelerate launch --config_file examples/wanvideo/model_training/full/accelerate
   --remove_prefix_in_ckpt "pipe.dit." \
   --output_path "./models/train/Wan2.1-FLF2V-14B-720P_full" \
   --trainable_models "dit" \
-  --extra_inputs "input_image,end_image"
+  --extra_inputs "input_image,end_image" \
+  --initialize_model_on_cpu

@@ -10,4 +10,5 @@ accelerate launch --config_file examples/wanvideo/model_training/full/accelerate
   --remove_prefix_in_ckpt "pipe.dit." \
   --output_path "./models/train/Wan2.1-I2V-14B-480P_full" \
   --trainable_models "dit" \
-  --extra_inputs "input_image"
+  --extra_inputs "input_image" \
+  --initialize_model_on_cpu

@@ -12,4 +12,5 @@ accelerate launch --config_file examples/wanvideo/model_training/full/accelerate
   --output_path "./models/train/Wan2.1-I2V-14B-720P_full" \
   --trainable_models "dit" \
   --extra_inputs "input_image" \
-  --use_gradient_checkpointing_offload
+  --use_gradient_checkpointing_offload \
+  --initialize_model_on_cpu
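Note what the trailing-backslash edits in these three scripts fix: a shell line continuation is required for the next line to remain part of the same command, so without the added backslash the new --initialize_model_on_cpu line would have been executed as a separate, failing command instead of being passed to train.py.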

@@ -1,6 +1,6 @@
 accelerate launch --config_file examples/wanvideo/model_training/full/accelerate_config_14B.yaml examples/wanvideo/model_training/train.py \
-  --dataset_base_path data/example_video_dataset/wans2v \
-  --dataset_metadata_path data/example_video_dataset/wans2v/metadata.csv \
+  --dataset_base_path data/example_video_dataset \
+  --dataset_metadata_path data/example_video_dataset/metadata_s2v.csv \
   --data_file_keys "video,input_audio,s2v_pose_video" \
   --height 448 \
   --width 832 \
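The S2V metadata now sits beside the other example CSVs instead of in a wans2v/ subdirectory, and the file columns are resolved against --dataset_base_path. A quick sanity check of the layout, with column names taken from --data_file_keys above (assumes the example dataset has been downloaded):

import csv

with open("data/example_video_dataset/metadata_s2v.csv") as f:
    row = next(csv.DictReader(f))
print(row["video"], row["input_audio"], row["s2v_pose_video"])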

@@ -15,7 +15,8 @@ accelerate launch --config_file examples/wanvideo/model_training/full/accelerate
   --extra_inputs "vace_video,vace_reference_image" \
   --use_gradient_checkpointing_offload \
   --max_timestep_boundary 0.358 \
-  --min_timestep_boundary 0
+  --min_timestep_boundary 0 \
+  --initialize_model_on_cpu
 # boundary corresponds to timesteps [900, 1000]
 
 

@@ -36,5 +37,6 @@ accelerate launch --config_file examples/wanvideo/model_training/full/accelerate
   --extra_inputs "vace_video,vace_reference_image" \
   --use_gradient_checkpointing_offload \
   --max_timestep_boundary 1 \
-  --min_timestep_boundary 0.358
+  --min_timestep_boundary 0.358 \
+  --initialize_model_on_cpu
 # boundary corresponds to timesteps [0, 900]

@@ -1,9 +1,10 @@
-accelerate launch examples/wanvideo/model_training/train.py \
+# 1*80G GPU cannot train Wan2.2-Animate-14B LoRA
+# We tested on 8*80G GPUs
+accelerate launch --config_file examples/wanvideo/model_training/full/accelerate_config_14B.yaml examples/wanvideo/model_training/train.py \
   --dataset_base_path data/example_video_dataset \
   --dataset_metadata_path data/example_video_dataset/metadata.csv \
   --height 720 \
   --width 1280 \
-  --num_frames 49 \
   --dataset_repeat 100 \
   --model_id_with_origin_paths "Wan-AI/Wan2.1-I2V-14B-720P:diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.1-I2V-14B-720P:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.1-I2V-14B-720P:Wan2.1_VAE.pth,Wan-AI/Wan2.1-I2V-14B-720P:models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth" \
   --learning_rate 1e-4 \

@@ -14,4 +15,5 @@ accelerate launch examples/wanvideo/model_training/train.py \
   --lora_target_modules "q,k,v,o,ffn.0,ffn.2" \
   --lora_rank 32 \
   --extra_inputs "input_image" \
-  --use_gradient_checkpointing_offload
+  --use_gradient_checkpointing_offload \
+  --initialize_model_on_cpu

@@ -1,6 +1,6 @@
 accelerate launch --config_file examples/wanvideo/model_training/full/accelerate_config_14B.yaml examples/wanvideo/model_training/train.py \
-  --dataset_base_path data/example_video_dataset/wans2v \
-  --dataset_metadata_path data/example_video_dataset/wans2v/metadata.csv \
+  --dataset_base_path data/example_video_dataset \
+  --dataset_metadata_path data/example_video_dataset/metadata_s2v.csv \
   --data_file_keys "video,input_audio,s2v_pose_video" \
   --height 448 \
   --width 832 \

@@ -116,6 +116,7 @@ def wan_parser():
     parser.add_argument("--audio_processor_path", type=str, default=None, help="Path to the audio processor. If provided, the processor will be used for Wan2.2-S2V model.")
     parser.add_argument("--max_timestep_boundary", type=float, default=1.0, help="Max timestep boundary (for mixed models, e.g., Wan-AI/Wan2.2-I2V-A14B).")
     parser.add_argument("--min_timestep_boundary", type=float, default=0.0, help="Min timestep boundary (for mixed models, e.g., Wan-AI/Wan2.2-I2V-A14B).")
+    parser.add_argument("--initialize_model_on_cpu", default=False, action="store_true", help="Whether to initialize models on CPU.")
     return parser
 
 

@@ -165,7 +166,7 @@ if __name__ == "__main__":
         fp8_models=args.fp8_models,
         offload_models=args.offload_models,
         task=args.task,
-        device=accelerator.device,
+        device="cpu" if args.initialize_model_on_cpu else accelerator.device,
         max_timestep_boundary=args.max_timestep_boundary,
         min_timestep_boundary=args.min_timestep_boundary,
     )
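The new flag materializes models on CPU and leaves the move to the training device for later, presumably so full-size weights need not sit on every GPU during deserialization (my reading; the commit itself does not explain). The device selection is an ordinary conditional over the argparse flag, isolated here:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--initialize_model_on_cpu", default=False, action="store_true")
args = parser.parse_args(["--initialize_model_on_cpu"])

accelerator_device = "cuda"  # stand-in for accelerator.device
device = "cpu" if args.initialize_model_on_cpu else accelerator_device
print(device)  # cpu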

@@ -5,19 +5,29 @@ from diffsynth.core import load_state_dict
 from diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig
 
 
+vram_config = {
+    "offload_dtype": torch.bfloat16,
+    "offload_device": "cpu",
+    "onload_dtype": torch.bfloat16,
+    "onload_device": "cpu",
+    "preparing_dtype": torch.bfloat16,
+    "preparing_device": "cuda",
+    "computation_dtype": torch.bfloat16,
+    "computation_device": "cuda",
+}
 pipe = WanVideoPipeline.from_pretrained(
     torch_dtype=torch.bfloat16,
     device="cuda",
     model_configs=[
         ModelConfig(model_id="PAI/Wan2.2-VACE-Fun-A14B", origin_file_pattern="high_noise_model/diffusion_pytorch_model*.safetensors"),
         ModelConfig(model_id="PAI/Wan2.2-VACE-Fun-A14B", origin_file_pattern="low_noise_model/diffusion_pytorch_model*.safetensors"),
-        ModelConfig(model_id="PAI/Wan2.2-VACE-Fun-A14B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth"),
-        ModelConfig(model_id="PAI/Wan2.2-VACE-Fun-A14B", origin_file_pattern="Wan2.1_VAE.pth"),
+        ModelConfig(model_id="PAI/Wan2.2-VACE-Fun-A14B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", **vram_config),
+        ModelConfig(model_id="PAI/Wan2.2-VACE-Fun-A14B", origin_file_pattern="Wan2.1_VAE.pth", **vram_config),
     ],
 )
-state_dict = load_state_dict("models/train/Wan2.2-VACE-Fun-A14B_high_noise_full/epoch-1.safetensors")
+state_dict = load_state_dict("models/train/Wan2.2-VACE-Fun-A14B_high_noise_full/epoch-1.safetensors", torch_dtype=torch.bfloat16, device="cpu")
 pipe.vace.load_state_dict(state_dict)
-state_dict = load_state_dict("models/train/Wan2.2-VACE-Fun-A14B_low_noise_full/epoch-1.safetensors")
+state_dict = load_state_dict("models/train/Wan2.2-VACE-Fun-A14B_low_noise_full/epoch-1.safetensors", torch_dtype=torch.bfloat16, device="cpu")
 pipe.vace2.load_state_dict(state_dict)
 
 video = VideoData("data/example_video_dataset/video1_softedge.mp4", height=480, width=832)
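Loading the trained VACE weights straight to CPU in bfloat16 keeps the checkpoint out of VRAM until the offload machinery needs it. For readers without DiffSynth's load_state_dict at hand, an approximate equivalent using the safetensors library directly, assuming a flat safetensors state dict:

import torch
from safetensors.torch import load_file

state_dict = load_file("models/train/Wan2.2-VACE-Fun-A14B_high_noise_full/epoch-1.safetensors", device="cpu")
state_dict = {k: v.to(torch.bfloat16) for k, v in state_dict.items()}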