mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-18 22:08:13 +00:00
support wan-series models
This commit is contained in:
2
.gitignore
vendored
2
.gitignore
vendored
@@ -7,6 +7,8 @@
|
||||
*.pt
|
||||
*.bin
|
||||
*.DS_Store
|
||||
*.msc
|
||||
*.mv
|
||||
|
||||
# Byte-compiled / optimized / DLL files
|
||||
__pycache__/
|
||||
|
||||
@@ -1,29 +1,266 @@
|
||||
MODEL_CONFIGS = [
|
||||
qwen_image_series = [
|
||||
{
|
||||
# Example: ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors")
|
||||
"model_hash": "0319a1cb19835fb510907dd3367c95ff",
|
||||
"model_name": "qwen_image_dit",
|
||||
"model_class": "diffsynth.models.qwen_image_dit.QwenImageDiT",
|
||||
},
|
||||
{
|
||||
# Example: ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors")
|
||||
"model_hash": "8004730443f55db63092006dd9f7110e",
|
||||
"model_name": "qwen_image_text_encoder",
|
||||
"model_class": "diffsynth.models.qwen_image_text_encoder.QwenImageTextEncoder",
|
||||
"state_dict_converter": "diffsynth.utils.state_dict_converters.qwen_image_text_encoder.QwenImageTextEncoderStateDictConverter",
|
||||
},
|
||||
{
|
||||
# Example: ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors")
|
||||
"model_hash": "ed4ea5824d55ec3107b09815e318123a",
|
||||
"model_name": "qwen_image_vae",
|
||||
"model_class": "diffsynth.models.qwen_image_vae.QwenImageVAE",
|
||||
},
|
||||
{
|
||||
# Example: ModelConfig(model_id="DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth", origin_file_pattern="model.safetensors")
|
||||
"model_hash": "073bce9cf969e317e5662cd570c3e79c",
|
||||
"model_name": "qwen_image_blockwise_controlnet",
|
||||
"model_class": "diffsynth.models.qwen_image_controlnet.QwenImageBlockWiseControlNet",
|
||||
},
|
||||
{
|
||||
# Example: ModelConfig(model_id="DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint", origin_file_pattern="model.safetensors")
|
||||
"model_hash": "a9e54e480a628f0b956a688a81c33bab",
|
||||
"model_name": "qwen_image_blockwise_controlnet",
|
||||
"model_class": "diffsynth.models.qwen_image_controlnet.QwenImageBlockWiseControlNet",
|
||||
"extra_kwargs": {"additional_in_dim": 4}
|
||||
"extra_kwargs": {"additional_in_dim": 4},
|
||||
},
|
||||
]
|
||||
|
||||
wan_series = [
|
||||
{
|
||||
# Example: ModelConfig(model_id="krea/krea-realtime-video", origin_file_pattern="krea-realtime-video-14b.safetensors")
|
||||
"model_hash": "5ec04e02b42d2580483ad69f4e76346a",
|
||||
"model_name": "wan_video_dit",
|
||||
"model_class": "diffsynth.models.wan_video_dit.WanModel",
|
||||
"extra_kwargs": {'has_image_input': False, 'patch_size': [1, 2, 2], 'in_dim': 16, 'dim': 5120, 'ffn_dim': 13824, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 40, 'num_layers': 40, 'eps': 1e-06},
|
||||
"state_dict_converter": "diffsynth.utils.state_dict_converters.wan_video_dit.WanVideoDiTStateDictConverter",
|
||||
},
|
||||
{
|
||||
# Example: ModelConfig(model_id="Wan-AI/Wan2.1-T2V-14B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth")
|
||||
"model_hash": "9c8818c2cbea55eca56c7b447df170da",
|
||||
"model_name": "wan_video_text_encoder",
|
||||
"model_class": "diffsynth.models.wan_video_text_encoder.WanTextEncoder",
|
||||
},
|
||||
{
|
||||
# Example: ModelConfig(model_id="Wan-AI/Wan2.1-T2V-14B", origin_file_pattern="Wan2.1_VAE.pth")
|
||||
"model_hash": "ccc42284ea13e1ad04693284c7a09be6",
|
||||
"model_name": "wan_video_vae",
|
||||
"model_class": "diffsynth.models.wan_video_vae.WanVideoVAE",
|
||||
"state_dict_converter": "diffsynth.utils.state_dict_converters.wan_video_vae.WanVideoVAEStateDictConverter",
|
||||
},
|
||||
{
|
||||
# Example: ModelConfig(model_id="meituan-longcat/LongCat-Video", origin_file_pattern="dit/diffusion_pytorch_model*.safetensors")
|
||||
"model_hash": "8b27900f680d7251ce44e2dc8ae1ffef",
|
||||
"model_name": "wan_video_dit",
|
||||
"model_class": "diffsynth.models.longcat_video_dit.LongCatVideoTransformer3DModel",
|
||||
},
|
||||
{
|
||||
# Example: ModelConfig(model_id="ByteDance/Video-As-Prompt-Wan2.1-14B", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors")
|
||||
"model_hash": "5f90e66a0672219f12d9a626c8c21f61",
|
||||
"model_name": "wan_video_dit",
|
||||
"model_class": "diffsynth.models.wan_video_dit.WanModel",
|
||||
"extra_kwargs": {'has_image_input': True, 'patch_size': [1, 2, 2], 'in_dim': 36, 'dim': 5120, 'ffn_dim': 13824, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 40, 'num_layers': 40, 'eps': 1e-06},
|
||||
"state_dict_converter": "diffsynth.utils.state_dict_converters.wan_video_dit.WanVideoDiTFromDiffusers"
|
||||
},
|
||||
{
|
||||
# Example: ModelConfig(model_id="ByteDance/Video-As-Prompt-Wan2.1-14B", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors")
|
||||
"model_hash": "5f90e66a0672219f12d9a626c8c21f61",
|
||||
"model_name": "wan_video_vap",
|
||||
"model_class": "diffsynth.models.wan_video_mot.MotWanModel",
|
||||
"state_dict_converter": "diffsynth.utils.state_dict_converters.wan_video_mot.WanVideoMotStateDictConverter"
|
||||
},
|
||||
{
|
||||
# Example: ModelConfig(model_id="Wan-AI/Wan2.1-I2V-14B-480P", origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth")
|
||||
"model_hash": "5941c53e207d62f20f9025686193c40b",
|
||||
"model_name": "wan_video_image_encoder",
|
||||
"model_class": "diffsynth.models.wan_video_image_encoder.WanImageEncoder",
|
||||
"state_dict_converter": "diffsynth.utils.state_dict_converters.wan_video_image_encoder.WanImageEncoderStateDictConverter"
|
||||
},
|
||||
{
|
||||
# Example: ModelConfig(model_id="DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1", origin_file_pattern="model.safetensors")
|
||||
"model_hash": "dbd5ec76bbf977983f972c151d545389",
|
||||
"model_name": "wan_video_motion_controller",
|
||||
"model_class": "diffsynth.models.wan_video_motion_controller.WanMotionControllerModel",
|
||||
},
|
||||
{
|
||||
# Example: ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="diffusion_pytorch_model*.safetensors")
|
||||
"model_hash": "9269f8db9040a9d860eaca435be61814",
|
||||
"model_name": "wan_video_dit",
|
||||
"model_class": "diffsynth.models.wan_video_dit.WanModel",
|
||||
"extra_kwargs": {'has_image_input': False, 'patch_size': [1, 2, 2], 'in_dim': 16, 'dim': 1536, 'ffn_dim': 8960, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 12, 'num_layers': 30, 'eps': 1e-06}
|
||||
},
|
||||
{
|
||||
# Example: ModelConfig(model_id="Wan-AI/Wan2.1-FLF2V-14B-720P", origin_file_pattern="diffusion_pytorch_model*.safetensors")
|
||||
"model_hash": "3ef3b1f8e1dab83d5b71fd7b617f859f",
|
||||
"model_name": "wan_video_dit",
|
||||
"model_class": "diffsynth.models.wan_video_dit.WanModel",
|
||||
"extra_kwargs": {'has_image_input': True, 'patch_size': [1, 2, 2], 'in_dim': 36, 'dim': 5120, 'ffn_dim': 13824, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 40, 'num_layers': 40, 'eps': 1e-06, 'has_image_pos_emb': True}
|
||||
},
|
||||
{
|
||||
# Example: ModelConfig(model_id="PAI/Wan2.1-Fun-1.3B-Control", origin_file_pattern="diffusion_pytorch_model*.safetensors")
|
||||
"model_hash": "349723183fc063b2bfc10bb2835cf677",
|
||||
"model_name": "wan_video_dit",
|
||||
"model_class": "diffsynth.models.wan_video_dit.WanModel",
|
||||
"extra_kwargs": {'has_image_input': True, 'patch_size': [1, 2, 2], 'in_dim': 48, 'dim': 1536, 'ffn_dim': 8960, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 12, 'num_layers': 30, 'eps': 1e-06}
|
||||
},
|
||||
{
|
||||
# Example: ModelConfig(model_id="PAI/Wan2.1-Fun-1.3B-InP", origin_file_pattern="diffusion_pytorch_model*.safetensors")
|
||||
"model_hash": "6d6ccde6845b95ad9114ab993d917893",
|
||||
"model_name": "wan_video_dit",
|
||||
"model_class": "diffsynth.models.wan_video_dit.WanModel",
|
||||
"extra_kwargs": {'has_image_input': True, 'patch_size': [1, 2, 2], 'in_dim': 36, 'dim': 1536, 'ffn_dim': 8960, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 12, 'num_layers': 30, 'eps': 1e-06}
|
||||
},
|
||||
{
|
||||
# Example: ModelConfig(model_id="PAI/Wan2.1-Fun-14B-Control", origin_file_pattern="diffusion_pytorch_model*.safetensors")
|
||||
"model_hash": "efa44cddf936c70abd0ea28b6cbe946c",
|
||||
"model_name": "wan_video_dit",
|
||||
"model_class": "diffsynth.models.wan_video_dit.WanModel",
|
||||
"extra_kwargs": {'has_image_input': True, 'patch_size': [1, 2, 2], 'in_dim': 48, 'dim': 5120, 'ffn_dim': 13824, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 40, 'num_layers': 40, 'eps': 1e-06}
|
||||
},
|
||||
{
|
||||
# Example: ModelConfig(model_id="PAI/Wan2.1-Fun-14B-InP", origin_file_pattern="diffusion_pytorch_model*.safetensors")
|
||||
"model_hash": "6bfcfb3b342cb286ce886889d519a77e",
|
||||
"model_name": "wan_video_dit",
|
||||
"model_class": "diffsynth.models.wan_video_dit.WanModel",
|
||||
"extra_kwargs": {'has_image_input': True, 'patch_size': [1, 2, 2], 'in_dim': 36, 'dim': 5120, 'ffn_dim': 13824, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 40, 'num_layers': 40, 'eps': 1e-06}
|
||||
},
|
||||
{
|
||||
# Example: ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera", origin_file_pattern="diffusion_pytorch_model*.safetensors")
|
||||
"model_hash": "ac6a5aa74f4a0aab6f64eb9a72f19901",
|
||||
"model_name": "wan_video_dit",
|
||||
"model_class": "diffsynth.models.wan_video_dit.WanModel",
|
||||
"extra_kwargs": {'has_image_input': True, 'patch_size': [1, 2, 2], 'in_dim': 32, 'dim': 1536, 'ffn_dim': 8960, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 12, 'num_layers': 30, 'eps': 1e-06, 'has_ref_conv': False, 'add_control_adapter': True, 'in_dim_control_adapter': 24}
|
||||
},
|
||||
{
|
||||
# Example: ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-1.3B-Control", origin_file_pattern="diffusion_pytorch_model*.safetensors")
|
||||
"model_hash": "70ddad9d3a133785da5ea371aae09504",
|
||||
"model_name": "wan_video_dit",
|
||||
"model_class": "diffsynth.models.wan_video_dit.WanModel",
|
||||
"extra_kwargs": {'has_image_input': True, 'patch_size': [1, 2, 2], 'in_dim': 48, 'dim': 1536, 'ffn_dim': 8960, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 12, 'num_layers': 30, 'eps': 1e-06, 'has_ref_conv': True}
|
||||
},
|
||||
{
|
||||
# Example: ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-14B-Control-Camera", origin_file_pattern="diffusion_pytorch_model*.safetensors")
|
||||
"model_hash": "b61c605c2adbd23124d152ed28e049ae",
|
||||
"model_name": "wan_video_dit",
|
||||
"model_class": "diffsynth.models.wan_video_dit.WanModel",
|
||||
"extra_kwargs": {'has_image_input': True, 'patch_size': [1, 2, 2], 'in_dim': 32, 'dim': 5120, 'ffn_dim': 13824, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 40, 'num_layers': 40, 'eps': 1e-06, 'has_ref_conv': False, 'add_control_adapter': True, 'in_dim_control_adapter': 24}
|
||||
},
|
||||
{
|
||||
# Example: ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-14B-Control", origin_file_pattern="diffusion_pytorch_model*.safetensors")
|
||||
"model_hash": "26bde73488a92e64cc20b0a7485b9e5b",
|
||||
"model_name": "wan_video_dit",
|
||||
"model_class": "diffsynth.models.wan_video_dit.WanModel",
|
||||
"extra_kwargs": {'has_image_input': True, 'patch_size': [1, 2, 2], 'in_dim': 48, 'dim': 5120, 'ffn_dim': 13824, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 40, 'num_layers': 40, 'eps': 1e-06, 'has_ref_conv': True}
|
||||
},
|
||||
{
|
||||
# Example: ModelConfig(model_id="Wan-AI/Wan2.1-T2V-14B", origin_file_pattern="diffusion_pytorch_model*.safetensors")
|
||||
"model_hash": "aafcfd9672c3a2456dc46e1cb6e52c70",
|
||||
"model_name": "wan_video_dit",
|
||||
"model_class": "diffsynth.models.wan_video_dit.WanModel",
|
||||
"extra_kwargs": {'has_image_input': False, 'patch_size': [1, 2, 2], 'in_dim': 16, 'dim': 5120, 'ffn_dim': 13824, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 40, 'num_layers': 40, 'eps': 1e-06}
|
||||
},
|
||||
{
|
||||
# Example: ModelConfig(model_id="iic/VACE-Wan2.1-1.3B-Preview", origin_file_pattern="diffusion_pytorch_model*.safetensors")
|
||||
"model_hash": "a61453409b67cd3246cf0c3bebad47ba",
|
||||
"model_name": "wan_video_dit",
|
||||
"model_class": "diffsynth.models.wan_video_dit.WanModel",
|
||||
"extra_kwargs": {'has_image_input': False, 'patch_size': [1, 2, 2], 'in_dim': 16, 'dim': 1536, 'ffn_dim': 8960, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 12, 'num_layers': 30, 'eps': 1e-06},
|
||||
"state_dict_converter": "diffsynth.utils.state_dict_converters.wan_video_dit.WanVideoDiTStateDictConverter",
|
||||
},
|
||||
{
|
||||
# Example: ModelConfig(model_id="iic/VACE-Wan2.1-1.3B-Preview", origin_file_pattern="diffusion_pytorch_model*.safetensors")
|
||||
"model_hash": "a61453409b67cd3246cf0c3bebad47ba",
|
||||
"model_name": "wan_video_vace",
|
||||
"model_class": "diffsynth.models.wan_video_vace.VaceWanModel",
|
||||
"state_dict_converter": "diffsynth.utils.state_dict_converters.wan_video_vace.VaceWanModelDictConverter"
|
||||
},
|
||||
{
|
||||
# Example: ModelConfig(model_id="Wan-AI/Wan2.1-VACE-14B", origin_file_pattern="diffusion_pytorch_model*.safetensors")
|
||||
"model_hash": "7a513e1f257a861512b1afd387a8ecd9",
|
||||
"model_name": "wan_video_dit",
|
||||
"model_class": "diffsynth.models.wan_video_dit.WanModel",
|
||||
"extra_kwargs": {'has_image_input': False, 'patch_size': [1, 2, 2], 'in_dim': 16, 'dim': 5120, 'ffn_dim': 13824, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 40, 'num_layers': 40, 'eps': 1e-06},
|
||||
"state_dict_converter": "diffsynth.utils.state_dict_converters.wan_video_dit.WanVideoDiTStateDictConverter",
|
||||
},
|
||||
{
|
||||
# Example: ModelConfig(model_id="Wan-AI/Wan2.1-VACE-14B", origin_file_pattern="diffusion_pytorch_model*.safetensors")
|
||||
"model_hash": "7a513e1f257a861512b1afd387a8ecd9",
|
||||
"model_name": "wan_video_vace",
|
||||
"model_class": "diffsynth.models.wan_video_vace.VaceWanModel",
|
||||
"extra_kwargs": {'vace_layers': (0, 5, 10, 15, 20, 25, 30, 35), 'vace_in_dim': 96, 'patch_size': (1, 2, 2), 'has_image_input': False, 'dim': 5120, 'num_heads': 40, 'ffn_dim': 13824, 'eps': 1e-06},
|
||||
"state_dict_converter": "diffsynth.utils.state_dict_converters.wan_video_vace.VaceWanModelDictConverter"
|
||||
},
|
||||
{
|
||||
# Example: ModelConfig(model_id="Wan-AI/Wan2.2-Animate-14B", origin_file_pattern="diffusion_pytorch_model*.safetensors")
|
||||
"model_hash": "31fa352acb8a1b1d33cd8764273d80a2",
|
||||
"model_name": "wan_video_dit",
|
||||
"model_class": "diffsynth.models.wan_video_dit.WanModel",
|
||||
"extra_kwargs": {'has_image_input': True, 'patch_size': [1, 2, 2], 'in_dim': 36, 'dim': 5120, 'ffn_dim': 13824, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 40, 'num_layers': 40, 'eps': 1e-06},
|
||||
"state_dict_converter": "diffsynth.utils.state_dict_converters.wan_video_dit.WanVideoDiTStateDictConverter"
|
||||
},
|
||||
{
|
||||
# Example: ModelConfig(model_id="Wan-AI/Wan2.2-Animate-14B", origin_file_pattern="diffusion_pytorch_model*.safetensors")
|
||||
"model_hash": "31fa352acb8a1b1d33cd8764273d80a2",
|
||||
"model_name": "wan_video_animate_adapter",
|
||||
"model_class": "diffsynth.models.wan_video_animate_adapter.WanAnimateAdapter",
|
||||
"state_dict_converter": "diffsynth.utils.state_dict_converters.wan_video_animate_adapter.WanAnimateAdapterStateDictConverter"
|
||||
},
|
||||
{
|
||||
# Example: ModelConfig(model_id="PAI/Wan2.2-Fun-A14B-Control-Camera", origin_file_pattern="high_noise_model/diffusion_pytorch_model*.safetensors")
|
||||
"model_hash": "47dbeab5e560db3180adf51dc0232fb1",
|
||||
"model_name": "wan_video_dit",
|
||||
"model_class": "diffsynth.models.wan_video_dit.WanModel",
|
||||
"extra_kwargs": {'has_image_input': False, 'patch_size': [1, 2, 2], 'in_dim': 36, 'dim': 5120, 'ffn_dim': 13824, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 40, 'num_layers': 40, 'eps': 1e-06, 'has_ref_conv': False, 'add_control_adapter': True, 'in_dim_control_adapter': 24, 'require_clip_embedding': False}
|
||||
},
|
||||
{
|
||||
# Example: ModelConfig(model_id="PAI/Wan2.2-Fun-A14B-Control", origin_file_pattern="high_noise_model/diffusion_pytorch_model*.safetensors")
|
||||
"model_hash": "2267d489f0ceb9f21836532952852ee5",
|
||||
"model_name": "wan_video_dit",
|
||||
"model_class": "diffsynth.models.wan_video_dit.WanModel",
|
||||
"extra_kwargs": {'has_image_input': False, 'patch_size': [1, 2, 2], 'in_dim': 52, 'dim': 5120, 'ffn_dim': 13824, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 40, 'num_layers': 40, 'eps': 1e-06, 'has_ref_conv': True, 'require_clip_embedding': False},
|
||||
},
|
||||
{
|
||||
# Example: ModelConfig(model_id="Wan-AI/Wan2.2-I2V-A14B", origin_file_pattern="high_noise_model/diffusion_pytorch_model*.safetensors")
|
||||
"model_hash": "5b013604280dd715f8457c6ed6d6a626",
|
||||
"model_name": "wan_video_dit",
|
||||
"model_class": "diffsynth.models.wan_video_dit.WanModel",
|
||||
"extra_kwargs": {'has_image_input': False, 'patch_size': [1, 2, 2], 'in_dim': 36, 'dim': 5120, 'ffn_dim': 13824, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 40, 'num_layers': 40, 'eps': 1e-06, 'require_clip_embedding': False}
|
||||
},
|
||||
{
|
||||
# Example: ModelConfig(model_id="Wan-AI/Wan2.2-S2V-14B", origin_file_pattern="diffusion_pytorch_model*.safetensors")
|
||||
"model_hash": "966cffdcc52f9c46c391768b27637614",
|
||||
"model_name": "wan_video_dit",
|
||||
"model_class": "diffsynth.models.wan_video_dit_s2v.WanS2VModel",
|
||||
"extra_kwargs": {'dim': 5120, 'in_dim': 16, 'ffn_dim': 13824, 'out_dim': 16, 'text_dim': 4096, 'freq_dim': 256, 'eps': 1e-06, 'patch_size': (1, 2, 2), 'num_heads': 40, 'num_layers': 40, 'cond_dim': 16, 'audio_dim': 1024, 'num_audio_token': 4}
|
||||
},
|
||||
{
|
||||
# Example: ModelConfig(model_id="Wan-AI/Wan2.2-TI2V-5B", origin_file_pattern="diffusion_pytorch_model*.safetensors")
|
||||
"model_hash": "1f5ab7703c6fc803fdded85ff040c316",
|
||||
"model_name": "wan_video_dit",
|
||||
"model_class": "diffsynth.models.wan_video_dit.WanModel",
|
||||
"extra_kwargs": {'has_image_input': False, 'patch_size': [1, 2, 2], 'in_dim': 48, 'dim': 3072, 'ffn_dim': 14336, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 48, 'num_heads': 24, 'num_layers': 30, 'eps': 1e-06, 'seperated_timestep': True, 'require_clip_embedding': False, 'require_vae_embedding': False, 'fuse_vae_embedding_in_latents': True}
|
||||
},
|
||||
{
|
||||
# Example: ModelConfig(model_id="Wan-AI/Wan2.2-TI2V-5B", origin_file_pattern="Wan2.2_VAE.pth")
|
||||
"model_hash": "e1de6c02cdac79f8b739f4d3698cd216",
|
||||
"model_name": "wan_video_vae",
|
||||
"model_class": "diffsynth.models.wan_video_vae.WanVideoVAE38",
|
||||
"state_dict_converter": "diffsynth.utils.state_dict_converters.wan_video_vae.WanVideoVAEStateDictConverter",
|
||||
},
|
||||
{
|
||||
# ModelConfig(model_id="Wan-AI/Wan2.2-S2V-14B", origin_file_pattern="wav2vec2-large-xlsr-53-english/model.safetensors")
|
||||
"model_hash": "06be60f3a4526586d8431cd038a71486",
|
||||
"model_name": "wans2v_audio_encoder",
|
||||
"model_class": "diffsynth.models.wav2vec.WanS2VAudioEncoder",
|
||||
"state_dict_converter": "diffsynth.utils.state_dict_converters.wans2v_audio_encoder.WanS2VAudioEncoderStateDictConverter",
|
||||
}
|
||||
]
|
||||
|
||||
MODEL_CONFIGS = qwen_image_series + wan_series
|
||||
|
||||
@@ -27,7 +27,8 @@ class ModelConfig:
|
||||
raise ValueError(f"""No valid model files. Please use `ModelConfig(path="xxx")` or `ModelConfig(model_id="xxx/yyy", origin_file_pattern="zzz")`. `skip_download=True` only supports the first one.""")
|
||||
|
||||
def download(self):
|
||||
downloaded_files = glob.glob(self.origin_file_pattern, root_dir=os.path.join(self.local_model_path, self.model_id))
|
||||
origin_file_pattern = self.origin_file_pattern + ("*" if self.origin_file_pattern.endswith("/") else "")
|
||||
downloaded_files = glob.glob(origin_file_pattern, root_dir=os.path.join(self.local_model_path, self.model_id))
|
||||
snapshot_download(
|
||||
self.model_id,
|
||||
local_dir=os.path.join(self.local_model_path, self.model_id),
|
||||
|
||||
@@ -31,6 +31,8 @@ def load_state_dict_from_bin(file_path, torch_dtype=None, device="cpu"):
|
||||
state_dict = state_dict["state_dict"]
|
||||
elif "module" in state_dict:
|
||||
state_dict = state_dict["module"]
|
||||
elif "model_state" in state_dict:
|
||||
state_dict = state_dict["model_state"]
|
||||
if torch_dtype is not None:
|
||||
for i in state_dict:
|
||||
if isinstance(state_dict[i], torch.Tensor):
|
||||
|
||||
@@ -28,7 +28,7 @@ def load_model(model_class, path, config=None, torch_dtype=torch.bfloat16, devic
|
||||
# and DiskMap can load only the parameters of a single model,
|
||||
# avoiding the need to load all parameters in the file.
|
||||
if use_disk_map:
|
||||
state_dict = DiskMap(path, device)
|
||||
state_dict = DiskMap(path, device, torch_dtype=torch_dtype)
|
||||
else:
|
||||
state_dict = load_state_dict(path, torch_dtype, device)
|
||||
# Why do we use `state_dict_converter`?
|
||||
|
||||
@@ -284,6 +284,16 @@ class BasePipeline(torch.nn.Module):
|
||||
if hasattr(module, "vram_management_enabled") and getattr(module, "vram_management_enabled"):
|
||||
vram_management_enabled = True
|
||||
return vram_management_enabled
|
||||
|
||||
|
||||
def cfg_guided_model_fn(self, model_fn, cfg_scale, inputs_shared, inputs_posi, inputs_nega, **inputs_others):
    """Classifier-free guidance wrapper around one denoising step.

    Evaluates ``model_fn`` on the positive (conditional) inputs and, unless
    ``cfg_scale`` is exactly 1.0, also on the negative (unconditional) inputs,
    then extrapolates: ``nega + cfg_scale * (posi - nega)``.

    Args:
        model_fn: callable invoked as ``model_fn(**branch, **shared, **others)``.
        cfg_scale: guidance strength; 1.0 disables the negative pass entirely.
        inputs_shared: kwargs common to both branches.
        inputs_posi: kwargs for the conditional branch.
        inputs_nega: kwargs for the unconditional branch.
    Returns:
        The guided noise prediction.
    """
    positive_pred = model_fn(**inputs_posi, **inputs_shared, **inputs_others)
    # Guard clause: skip the second forward pass when guidance is a no-op.
    if cfg_scale == 1.0:
        return positive_pred
    negative_pred = model_fn(**inputs_nega, **inputs_shared, **inputs_others)
    return negative_pred + cfg_scale * (positive_pred - negative_pred)
|
||||
|
||||
|
||||
class PipelineUnitGraph:
|
||||
|
||||
901
diffsynth/models/longcat_video_dit.py
Normal file
901
diffsynth/models/longcat_video_dit.py
Normal file
@@ -0,0 +1,901 @@
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import math
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.amp as amp
|
||||
|
||||
import numpy as np
|
||||
import torch.nn.functional as F
|
||||
from einops import rearrange, repeat
|
||||
from .wan_video_dit import flash_attention
|
||||
from ..core.gradient import gradient_checkpoint_forward
|
||||
|
||||
|
||||
class RMSNorm_FP32(torch.nn.Module):
    """RMSNorm computed in float32 for numerical stability.

    Normalizes the last dimension by the root-mean-square of its entries
    (no mean subtraction, unlike LayerNorm) and applies a learned per-channel
    gain. The normalization itself always runs in fp32; the result is cast
    back to the input dtype before the gain is applied.
    """

    def __init__(self, dim: int, eps: float):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        # x / sqrt(mean(x^2) + eps), written with rsqrt to avoid a division.
        mean_sq = x.pow(2).mean(dim=-1, keepdim=True)
        return x * torch.rsqrt(mean_sq + self.eps)

    def forward(self, x):
        normed = self._norm(x.float()).type_as(x)
        return normed * self.weight
|
||||
|
||||
|
||||
def broadcat(tensors, dim=-1):
    """Concatenate ``tensors`` along ``dim``, broadcasting every other axis.

    Each non-concatenated axis may take at most two distinct sizes across the
    tensors; every tensor is expanded to the maximum size on those axes (so a
    size-1 axis broadcasts) before the concatenation. Like
    ``torch.broadcast_tensors`` + ``torch.cat``, but leaving ``dim`` free to
    differ per tensor.
    """
    ranks = {t.dim() for t in tensors}
    assert len(ranks) == 1, "tensors must all have the same number of dimensions"
    rank = next(iter(ranks))
    if dim < 0:
        dim += rank
    # Per-axis tuple of sizes across all tensors, e.g. axis 0 -> (2, 1, 2).
    sizes_per_axis = list(zip(*(t.shape for t in tensors)))
    broadcast_axes = [(axis, sizes) for axis, sizes in enumerate(sizes_per_axis) if axis != dim]
    assert all(
        len(set(sizes)) <= 2 for _, sizes in broadcast_axes
    ), "invalid dimensions for broadcastable concatentation"
    # Target size of each broadcastable axis is the max across tensors.
    target = {axis: max(sizes) for axis, sizes in broadcast_axes}
    expanded = []
    for i, t in enumerate(tensors):
        # Keep the tensor's own size on the concat axis; broadcast the rest.
        shape = [
            sizes_per_axis[axis][i] if axis == dim else target[axis]
            for axis in range(rank)
        ]
        expanded.append(t.expand(*shape))
    return torch.cat(expanded, dim=dim)
|
||||
|
||||
|
||||
def rotate_half(x):
    """Rotate adjacent channel pairs: (x1, x2) -> (-x2, x1).

    Treats the last dimension as interleaved pairs — the layout produced by
    the RoPE frequency tables — and applies the 90-degree rotation that,
    combined with cos/sin weighting, realizes the rotary embedding.
    """
    pairs = x.reshape(*x.shape[:-1], -1, 2)
    x1, x2 = pairs.unbind(dim=-1)
    rotated = torch.stack((-x2, x1), dim=-1)
    return rotated.reshape(*x.shape)
|
||||
|
||||
|
||||
class RotaryPositionalEmbedding(nn.Module):
    """3D rotary positional embedding over (frames, height, width) token grids."""

    def __init__(self,
        head_dim,
        cp_split_hw=None
    ):
        """Rotary positional embedding for 3D
        Reference : https://blog.eleuther.ai/rotary-embeddings/
        Paper: https://arxiv.org/pdf/2104.09864.pdf
        Args:
            head_dim: Per-attention-head channel count; must be a multiple of 8.
            cp_split_hw: Optional (H, W) context-parallel split; currently only
                stored — the code that would use it is commented out below.
        """
        super().__init__()
        self.head_dim = head_dim
        assert self.head_dim % 8 == 0, 'Dim must be a multiply of 8 for 3D RoPE.'
        self.cp_split_hw = cp_split_hw
        # We take the assumption that the longest side of grid will not larger than 512, i.e, 512 * 8 = 4098 input pixels
        self.base = 10000
        # Cache: (T, H, W) grid size -> (T*H*W, head_dim) angle table.
        self.freqs_dict = {}

    def register_grid_size(self, grid_size):
        # Lazily precompute and cache the angle table for a new grid size.
        if grid_size not in self.freqs_dict:
            self.freqs_dict.update({
                grid_size: self.precompute_freqs_cis_3d(grid_size)
            })

    def precompute_freqs_cis_3d(self, grid_size):
        """Build the per-position RoPE angle table for one (T, H, W) grid.

        The head dim is partitioned into a temporal slice and two equal spatial
        slices (dim_t + dim_h + dim_w == head_dim); each axis gets standard RoPE
        inverse frequencies, and the three tables are broadcast-concatenated to
        shape (T*H*W, head_dim).
        """
        num_frames, height, width = grid_size
        # dim_h == dim_w == 2*(head_dim//6); dim_t takes the remainder.
        dim_t = self.head_dim - 4 * (self.head_dim // 6)
        dim_h = 2 * (self.head_dim // 6)
        dim_w = 2 * (self.head_dim // 6)
        freqs_t = 1.0 / (self.base ** (torch.arange(0, dim_t, 2)[: (dim_t // 2)].float() / dim_t))
        freqs_h = 1.0 / (self.base ** (torch.arange(0, dim_h, 2)[: (dim_h // 2)].float() / dim_h))
        freqs_w = 1.0 / (self.base ** (torch.arange(0, dim_w, 2)[: (dim_w // 2)].float() / dim_w))
        # Integer grid coordinates 0..size-1 per axis (endpoint=False).
        grid_t = np.linspace(0, num_frames, num_frames, endpoint=False, dtype=np.float32)
        grid_h = np.linspace(0, height, height, endpoint=False, dtype=np.float32)
        grid_w = np.linspace(0, width, width, endpoint=False, dtype=np.float32)
        grid_t = torch.from_numpy(grid_t).float()
        grid_h = torch.from_numpy(grid_h).float()
        grid_w = torch.from_numpy(grid_w).float()
        # Outer product position x frequency -> angle per (position, channel pair).
        freqs_t = torch.einsum("..., f -> ... f", grid_t, freqs_t)
        freqs_h = torch.einsum("..., f -> ... f", grid_h, freqs_h)
        freqs_w = torch.einsum("..., f -> ... f", grid_w, freqs_w)
        # Duplicate each angle for the (x1, x2) channel pair consumed by rotate_half.
        freqs_t = repeat(freqs_t, "... n -> ... (n r)", r=2)
        freqs_h = repeat(freqs_h, "... n -> ... (n r)", r=2)
        freqs_w = repeat(freqs_w, "... n -> ... (n r)", r=2)
        freqs = broadcat((freqs_t[:, None, None, :], freqs_h[None, :, None, :], freqs_w[None, None, :, :]), dim=-1)
        # (T H W D)
        freqs = rearrange(freqs, "T H W D -> (T H W) D")
        # Context-parallel splitting, kept for reference but disabled:
        # if self.cp_split_hw[0] * self.cp_split_hw[1] > 1:
        #     with torch.no_grad():
        #         freqs = rearrange(freqs, "(T H W) D -> T H W D", T=num_frames, H=height, W=width)
        #         freqs = context_parallel_util.split_cp_2d(freqs, seq_dim_hw=(1, 2), split_hw=self.cp_split_hw)
        #         freqs = rearrange(freqs, "T H W D -> (T H W) D")

        return freqs

    def forward(self, q, k, grid_size):
        """3D RoPE.

        Args:
            q: [B, head, seq, head_dim] queries
            k: [B, head, seq, head_dim] keys
            grid_size: (T, H, W) tuple with T*H*W == seq
        Returns:
            query and key with the same shape (and dtype) as the inputs.
        """

        if grid_size not in self.freqs_dict:
            self.register_grid_size(grid_size)

        freqs_cis = self.freqs_dict[grid_size].to(q.device)
        # Rotation happens in fp32; results are cast back to the input dtypes.
        q_, k_ = q.float(), k.float()
        # NOTE(review): .to(q.device) is applied a second time here — redundant but harmless.
        freqs_cis = freqs_cis.float().to(q.device)
        cos, sin = freqs_cis.cos(), freqs_cis.sin()
        # Broadcast (seq, head_dim) angles over batch and head dims.
        cos, sin = rearrange(cos, 'n d -> 1 1 n d'), rearrange(sin, 'n d -> 1 1 n d')
        q_ = (q_ * cos) + (rotate_half(q_) * sin)
        k_ = (k_ * cos) + (rotate_half(k_) * sin)

        return q_.type_as(q), k_.type_as(k)
|
||||
|
||||
|
||||
class Attention(nn.Module):
    """Self-attention with 3D RoPE, condition-token handling and a KV cache.

    NOTE(review): the ``enable_*`` flags and ``bsa_params`` are stored but not
    read in this class — attention is always dispatched through
    ``flash_attention`` from ``wan_video_dit`` (see ``_process_attn``).
    Presumably kept for upstream config compatibility; confirm before removing.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int,
        enable_flashattn3: bool = False,
        enable_flashattn2: bool = False,
        enable_xformers: bool = False,
        enable_bsa: bool = False,
        bsa_params: dict = None,
        cp_split_hw: Optional[List[int]] = None
    ) -> None:
        super().__init__()
        assert dim % num_heads == 0, "dim should be divisible by num_heads"
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim**-0.5
        self.enable_flashattn3 = enable_flashattn3
        self.enable_flashattn2 = enable_flashattn2
        self.enable_xformers = enable_xformers
        self.enable_bsa = enable_bsa
        self.bsa_params = bsa_params
        self.cp_split_hw = cp_split_hw

        # Fused QKV projection; per-head RMSNorm (fp32 internally) on q and k.
        self.qkv = nn.Linear(dim, dim * 3, bias=True)
        self.q_norm = RMSNorm_FP32(self.head_dim, eps=1e-6)
        self.k_norm = RMSNorm_FP32(self.head_dim, eps=1e-6)
        self.proj = nn.Linear(dim, dim)

        self.rope_3d = RotaryPositionalEmbedding(
            self.head_dim,
            cp_split_hw=cp_split_hw
        )

    def _process_attn(self, q, k, v, shape):
        # flash_attention expects [B, S, H*D]; convert from [B, H, S, D] and back.
        q = rearrange(q, "B H S D -> B S (H D)")
        k = rearrange(k, "B H S D -> B S (H D)")
        v = rearrange(v, "B H S D -> B S (H D)")
        x = flash_attention(q, k, v, num_heads=self.num_heads)
        x = rearrange(x, "B S (H D) -> B H S D", H=self.num_heads)
        return x

    def forward(self, x: torch.Tensor, shape=None, num_cond_latents=None, return_kv=False) -> torch.Tensor:
        """Self-attention over a flattened token sequence.

        Args:
            x: [B, N, C] tokens; N = T*H*W for grid ``shape``.
            shape: (T, H, W) latent grid, required by the 3D RoPE.
            num_cond_latents: number of leading condition *frames*. When > 0,
                condition tokens attend only among themselves while the
                remaining (noise) tokens attend over the full sequence.
            return_kv: additionally return the pre-RoPE (k, v) for later reuse
                via ``forward_with_kv_cache``.
        """
        B, N, C = x.shape
        qkv = self.qkv(x)

        qkv_shape = (B, N, 3, self.num_heads, self.head_dim)
        qkv = qkv.view(qkv_shape).permute((2, 0, 3, 1, 4))  # [3, B, H, N, D]
        q, k, v = qkv.unbind(0)
        q, k = self.q_norm(q), self.k_norm(k)

        if return_kv:
            # Cache k/v BEFORE RoPE so positions can be re-applied at reuse time.
            k_cache, v_cache = k.clone(), v.clone()

        q, k = self.rope_3d(q, k, shape)

        # cond mode
        if num_cond_latents is not None and num_cond_latents > 0:
            # Tokens per frame = N // T, so this is the condition-token count.
            num_cond_latents_thw = num_cond_latents * (N // shape[0])
            # process the condition tokens (they attend only to themselves)
            q_cond = q[:, :, :num_cond_latents_thw].contiguous()
            k_cond = k[:, :, :num_cond_latents_thw].contiguous()
            v_cond = v[:, :, :num_cond_latents_thw].contiguous()
            x_cond = self._process_attn(q_cond, k_cond, v_cond, shape)
            # process the noise tokens (they attend to cond + noise keys/values)
            q_noise = q[:, :, num_cond_latents_thw:].contiguous()
            x_noise = self._process_attn(q_noise, k, v, shape)
            # merge x_cond and x_noise
            x = torch.cat([x_cond, x_noise], dim=2).contiguous()
        else:
            x = self._process_attn(q, k, v, shape)

        x_output_shape = (B, N, C)
        x = x.transpose(1, 2)  # [B, H, N, D] --> [B, N, H, D]
        x = x.reshape(x_output_shape)  # [B, N, H, D] --> [B, N, C]
        x = self.proj(x)

        if return_kv:
            return x, (k_cache, v_cache)
        else:
            return x

    def forward_with_kv_cache(self, x: torch.Tensor, shape=None, num_cond_latents=None, kv_cache=None) -> torch.Tensor:
        """Self-attention for new tokens, reusing cached condition k/v.

        ``kv_cache`` holds the pre-RoPE (k, v) pair returned by a previous
        ``forward(..., return_kv=True)`` call; RoPE is applied jointly over
        the combined (condition + new) timeline so positions stay consistent.

        NOTE(review): ``x`` is only assigned inside the ``num_cond_latents > 0``
        branch — calling this with ``num_cond_latents`` None or 0 would raise
        ``NameError``. Presumably callers always pass a positive count; confirm.
        """
        B, N, C = x.shape
        qkv = self.qkv(x)

        qkv_shape = (B, N, 3, self.num_heads, self.head_dim)
        qkv = qkv.view(qkv_shape).permute((2, 0, 3, 1, 4))  # [3, B, H, N, D]
        q, k, v = qkv.unbind(0)
        q, k = self.q_norm(q), self.k_norm(k)

        T, H, W = shape
        k_cache, v_cache = kv_cache
        assert k_cache.shape[0] == v_cache.shape[0] and k_cache.shape[0] in [1, B]
        if k_cache.shape[0] == 1:
            # Broadcast a batch-1 cache (shared condition) across the batch.
            k_cache = k_cache.repeat(B, 1, 1, 1)
            v_cache = v_cache.repeat(B, 1, 1, 1)

        if num_cond_latents is not None and num_cond_latents > 0:
            k_full = torch.cat([k_cache, k], dim=2).contiguous()
            v_full = torch.cat([v_cache, v], dim=2).contiguous()
            # Pad q so sequence positions line up with the full timeline; the
            # padded part (empty_like — values never used) is sliced off after RoPE.
            q_padding = torch.cat([torch.empty_like(k_cache), q], dim=2).contiguous()
            q_padding, k_full = self.rope_3d(q_padding, k_full, (T + num_cond_latents, H, W))
            q = q_padding[:, :, -N:].contiguous()

            x = self._process_attn(q, k_full, v_full, shape)

        x_output_shape = (B, N, C)
        x = x.transpose(1, 2)  # [B, H, N, D] --> [B, N, H, D]
        x = x.reshape(x_output_shape)  # [B, N, H, D] --> [B, N, C]
        x = self.proj(x)

        return x
|
||||
|
||||
|
||||
class MultiHeadCrossAttention(nn.Module):
    """Cross-attention from video tokens (queries) to packed text tokens.

    The text tokens of all batch items are packed into one sequence
    ([1, N_valid_tokens, C]); `kv_seqlen` lists each item's token count so a
    varlen attention kernel can separate them.
    """

    def __init__(
        self,
        dim,
        num_heads,
        enable_flashattn3=False,
        enable_flashattn2=False,
        enable_xformers=False,
    ):
        super(MultiHeadCrossAttention, self).__init__()
        assert dim % num_heads == 0, "d_model must be divisible by num_heads"

        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads

        self.q_linear = nn.Linear(dim, dim)
        # fused key/value projection for the text tokens
        self.kv_linear = nn.Linear(dim, dim * 2)
        self.proj = nn.Linear(dim, dim)

        # per-head QK normalization in fp32 for numerical stability
        self.q_norm = RMSNorm_FP32(self.head_dim, eps=1e-6)
        self.k_norm = RMSNorm_FP32(self.head_dim, eps=1e-6)

        # backend selection flags consumed by the attention helper
        self.enable_flashattn3 = enable_flashattn3
        self.enable_flashattn2 = enable_flashattn2
        self.enable_xformers = enable_xformers

    def _process_cross_attn(self, x, cond, kv_seqlen):
        # x: [B, N, C] queries; cond: packed text tokens.
        # NOTE(review): kv_seqlen is not used in this body — presumably the
        # external flash_attention helper handles the packed layout; confirm
        # against its API before relying on per-item masking here.
        B, N, C = x.shape
        assert C == self.dim and cond.shape[2] == self.dim

        # fold the batch into a single packed sequence (leading dim becomes 1)
        q = self.q_linear(x).view(1, -1, self.num_heads, self.head_dim)
        kv = self.kv_linear(cond).view(1, -1, 2, self.num_heads, self.head_dim)
        k, v = kv.unbind(2)

        q, k = self.q_norm(q), self.k_norm(k)

        q = rearrange(q, "B S H D -> B S (H D)")
        k = rearrange(k, "B S H D -> B S (H D)")
        v = rearrange(v, "B S H D -> B S (H D)")
        x = flash_attention(q, k, v, num_heads=self.num_heads)

        # restore the [B, N, C] layout and project
        x = x.view(B, -1, C)
        x = self.proj(x)
        return x

    def forward(self, x, cond, kv_seqlen, num_cond_latents=None, shape=None):
        """
        x: [B, N, C]
        cond: [B, M, C]

        When `num_cond_latents` > 0, the leading condition-frame tokens skip
        cross-attention: only the noise tokens attend to the text, and the
        condition positions receive a zero residual.
        """
        if num_cond_latents is None or num_cond_latents == 0:
            return self._process_cross_attn(x, cond, kv_seqlen)
        else:
            B, N, C = x.shape
            if num_cond_latents is not None and num_cond_latents > 0:
                assert shape is not None, "SHOULD pass in the shape"
                # tokens per frame (N // T) times the number of condition frames
                num_cond_latents_thw = num_cond_latents * (N // shape[0])
                x_noise = x[:, num_cond_latents_thw:]  # [B, N_noise, C]
                output_noise = self._process_cross_attn(x_noise, cond, kv_seqlen)  # [B, N_noise, C]
                # zero residual for the condition tokens keeps them unchanged
                # when the caller adds this output to x
                output = torch.cat([
                    torch.zeros((B, num_cond_latents_thw, C), dtype=output_noise.dtype, device=output_noise.device),
                    output_noise
                ], dim=1).contiguous()
            else:
                # negative num_cond_latents is not supported
                raise NotImplementedError

            return output
|
||||
|
||||
|
||||
class LayerNorm_FP32(nn.LayerNorm):
    """LayerNorm that always normalizes in float32, then casts back to the
    input dtype (keeps half-precision activations numerically stable)."""

    def __init__(self, dim, eps, elementwise_affine):
        super().__init__(dim, eps=eps, elementwise_affine=elementwise_affine)

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        # Upcast the affine parameters alongside the activations.
        weight_fp32 = None if self.weight is None else self.weight.float()
        bias_fp32 = None if self.bias is None else self.bias.float()
        normalized = F.layer_norm(
            inputs.float(), self.normalized_shape, weight_fp32, bias_fp32, self.eps
        )
        return normalized.to(inputs.dtype)
|
||||
|
||||
|
||||
def modulate_fp32(norm_func, x, shift, scale):
    """Apply `norm_func` then an AdaLN-style affine modulation in float32.

    Suppose x is (B, N, D); shift and scale are (B, -1, D) and broadcast
    against the normalized x. Both must already be float32; the result is
    cast back to x's original dtype.
    """
    # BUGFIX: the original wrote `assert shift.dtype == torch.float32,
    # scale.dtype == torch.float32`, which uses the second comparison as the
    # assertion *message* — scale's dtype was never actually checked.
    assert shift.dtype == torch.float32 and scale.dtype == torch.float32
    dtype = x.dtype
    # normalize and modulate in fp32 for stability
    x = norm_func(x.to(torch.float32))
    x = x * (scale + 1) + shift
    x = x.to(dtype)
    return x
|
||||
|
||||
|
||||
class FinalLayer_FP32(nn.Module):
    """
    The final layer of DiT.

    Applies an AdaLN-modulated LayerNorm (computed in fp32) followed by a
    linear projection to the per-patch output channels.
    """

    def __init__(self, hidden_size, num_patch, out_channels, adaln_tembed_dim):
        super().__init__()
        self.hidden_size = hidden_size
        # num_patch: values produced per token (product of the patch size)
        self.num_patch = num_patch
        self.out_channels = out_channels
        self.adaln_tembed_dim = adaln_tembed_dim

        self.norm_final = LayerNorm_FP32(hidden_size, elementwise_affine=False, eps=1e-6)
        self.linear = nn.Linear(hidden_size, num_patch * out_channels, bias=True)
        # produces per-frame shift/scale for the final norm
        self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(adaln_tembed_dim, 2 * hidden_size, bias=True))

    def forward(self, x, t, latent_shape):
        # timestep shape: [B, T, C]
        assert t.dtype == torch.float32
        B, N, C = x.shape
        T, _, _ = latent_shape

        # modulation math runs in fp32 (modulate_fp32 asserts this)
        with amp.autocast('cuda', dtype=torch.float32):
            shift, scale = self.adaLN_modulation(t).unsqueeze(2).chunk(2, dim=-1) # [B, T, 1, C]
            # view tokens as [B, T, H*W, C] so the per-frame shift/scale broadcasts
            x = modulate_fp32(self.norm_final, x.view(B, T, -1, C), shift, scale).view(B, N, C)
            x = self.linear(x)
        return x
|
||||
|
||||
|
||||
class FeedForwardSwiGLU(nn.Module):
    """SwiGLU feed-forward network: w2(silu(w1(x)) * w3(x)).

    The hidden width follows the LLaMA convention: scaled to 2/3 of the
    requested size, optionally multiplied by `ffn_dim_multiplier`, then
    rounded up to a multiple of `multiple_of`.
    """

    def __init__(
        self,
        dim: int,
        hidden_dim: int,
        multiple_of: int = 256,
        ffn_dim_multiplier: Optional[float] = None,
    ):
        super().__init__()
        inner = int(2 * hidden_dim / 3)
        # custom dim factor multiplier
        if ffn_dim_multiplier is not None:
            inner = int(ffn_dim_multiplier * inner)
        # round up to the nearest multiple of `multiple_of`
        inner = multiple_of * ((inner + multiple_of - 1) // multiple_of)

        self.dim = dim
        self.hidden_dim = inner
        self.w1 = nn.Linear(dim, inner, bias=False)
        self.w2 = nn.Linear(inner, dim, bias=False)
        self.w3 = nn.Linear(dim, inner, bias=False)

    def forward(self, x):
        gate = F.silu(self.w1(x))
        return self.w2(gate * self.w3(x))
|
||||
|
||||
|
||||
class TimestepEmbedder(nn.Module):
    """
    Embeds scalar timesteps into vector representations: sinusoidal
    frequency features followed by a two-layer SiLU MLP.
    """

    def __init__(self, t_embed_dim, frequency_embedding_size=256):
        super().__init__()
        self.t_embed_dim = t_embed_dim
        self.frequency_embedding_size = frequency_embedding_size
        self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, t_embed_dim, bias=True),
            nn.SiLU(),
            nn.Linear(t_embed_dim, t_embed_dim, bias=True),
        )

    @staticmethod
    def timestep_embedding(t, dim, max_period=10000):
        """
        Create sinusoidal timestep embeddings.

        :param t: a 1-D Tensor of N indices (may be fractional), one per batch element.
        :param dim: the dimension of the output.
        :param max_period: controls the minimum frequency of the embeddings.
        :return: an (N, dim) Tensor of positional embeddings.
        """
        half = dim // 2
        # geometric frequency ladder from 1 down to 1/max_period
        exponents = torch.arange(start=0, end=half, dtype=torch.float32) / half
        freqs = torch.exp(-math.log(max_period) * exponents).to(device=t.device)
        args = t[:, None].float() * freqs[None]
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if dim % 2 == 1:
            # odd target dims get a zero column appended
            embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
        return embedding

    def forward(self, t, dtype):
        freq_emb = self.timestep_embedding(t, self.frequency_embedding_size)
        if freq_emb.dtype != dtype:
            freq_emb = freq_emb.to(dtype)
        return self.mlp(freq_emb)
|
||||
|
||||
|
||||
class CaptionEmbedder(nn.Module):
    """
    Projects caption (text-encoder) features into the transformer hidden
    size with a two-layer GELU MLP.
    """

    def __init__(self, in_channels, hidden_size):
        super().__init__()
        self.in_channels = in_channels
        self.hidden_size = hidden_size
        self.y_proj = nn.Sequential(
            nn.Linear(in_channels, hidden_size, bias=True),
            nn.GELU(approximate="tanh"),
            nn.Linear(hidden_size, hidden_size, bias=True),
        )

    def forward(self, caption):
        # The unpack enforces a 4-D [B, 1, N_token, in_channels] layout.
        B, _, N, C = caption.shape
        return self.y_proj(caption)
|
||||
|
||||
|
||||
class PatchEmbed3D(nn.Module):
    """Video to Patch Embedding via a strided Conv3d.

    Args:
        patch_size (tuple): Patch token size (T, H, W). Default: (2, 4, 4).
        in_chans (int): Number of input video channels. Default: 3.
        embed_dim (int): Number of linear projection output channels. Default: 96.
        norm_layer (nn.Module, optional): Normalization layer. Default: None.
        flatten (bool): When True, return tokens as [B, N, C]; otherwise keep
            the [B, C, T, H, W] layout.
    """

    def __init__(
        self,
        patch_size=(2, 4, 4),
        in_chans=3,
        embed_dim=96,
        norm_layer=None,
        flatten=True,
    ):
        super().__init__()
        self.patch_size = patch_size
        self.flatten = flatten

        self.in_chans = in_chans
        self.embed_dim = embed_dim

        self.proj = nn.Conv3d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
        self.norm = norm_layer(embed_dim) if norm_layer is not None else None

    def forward(self, x):
        """x: [B, C, D, H, W]; each dim is right-padded to a multiple of the
        patch size before projection."""
        _, _, D, H, W = x.size()
        pt, ph, pw = self.patch_size
        if W % pw != 0:
            x = F.pad(x, (0, pw - W % pw))
        if H % ph != 0:
            x = F.pad(x, (0, 0, 0, ph - H % ph))
        if D % pt != 0:
            x = F.pad(x, (0, 0, 0, 0, 0, pt - D % pt))

        x = self.proj(x)  # (B, embed_dim, T', H', W')
        if self.norm is not None:
            # LayerNorm wants channels last: flatten, normalize, restore.
            d, wh, ww = x.size(2), x.size(3), x.size(4)
            x = self.norm(x.flatten(2).transpose(1, 2))
            x = x.transpose(1, 2).view(-1, self.embed_dim, d, wh, ww)
        if self.flatten:
            x = x.flatten(2).transpose(1, 2)  # BCTHW -> BNC
        return x
|
||||
|
||||
|
||||
class LongCatSingleStreamBlock(nn.Module):
    """One LongCat DiT block: AdaLN-modulated self-attention, text
    cross-attention, and a SwiGLU FFN, with per-frame shift/scale/gate
    parameters computed from the timestep embedding in fp32."""

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        mlp_ratio: int,
        adaln_tembed_dim: int,
        enable_flashattn3: bool = False,
        enable_flashattn2: bool = False,
        enable_xformers: bool = False,
        enable_bsa: bool = False,
        bsa_params=None,
        cp_split_hw=None
    ):
        super().__init__()

        self.hidden_size = hidden_size

        # scale and gate modulation: 6 * hidden = shift/scale/gate for both
        # the attention path and the FFN path
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(adaln_tembed_dim, 6 * hidden_size, bias=True)
        )

        self.mod_norm_attn = LayerNorm_FP32(hidden_size, eps=1e-6, elementwise_affine=False)
        self.mod_norm_ffn = LayerNorm_FP32(hidden_size, eps=1e-6, elementwise_affine=False)
        self.pre_crs_attn_norm = LayerNorm_FP32(hidden_size, eps=1e-6, elementwise_affine=True)

        self.attn = Attention(
            dim=hidden_size,
            num_heads=num_heads,
            enable_flashattn3=enable_flashattn3,
            enable_flashattn2=enable_flashattn2,
            enable_xformers=enable_xformers,
            enable_bsa=enable_bsa,
            bsa_params=bsa_params,
            cp_split_hw=cp_split_hw
        )
        self.cross_attn = MultiHeadCrossAttention(
            dim=hidden_size,
            num_heads=num_heads,
            enable_flashattn3=enable_flashattn3,
            enable_flashattn2=enable_flashattn2,
            enable_xformers=enable_xformers,
        )
        self.ffn = FeedForwardSwiGLU(dim=hidden_size, hidden_dim=int(hidden_size * mlp_ratio))

    def forward(self, x, y, t, y_seqlen, latent_shape, num_cond_latents=None, return_kv=False, kv_cache=None, skip_crs_attn=False):
        """
        x: [B, N, C]
        y: [1, N_valid_tokens, C]
        t: [B, T, C_t]
        y_seqlen: [B]; type of a list
        latent_shape: latent shape of a single item
        num_cond_latents: number of clean condition frames at the front of x
        return_kv: also return this block's (k, v) for the condition frames
        kv_cache: previously cached (k, v); when set, attention runs in
            kv-cache mode and cross-attention treats all tokens as noise
        skip_crs_attn: skip the text cross-attention entirely
        """
        x_dtype = x.dtype

        B, N, C = x.shape
        T, _, _ = latent_shape # S != T*H*W in case of CP split on H*W.

        # compute modulation params in fp32
        with amp.autocast(device_type='cuda', dtype=torch.float32):
            shift_msa, scale_msa, gate_msa, \
            shift_mlp, scale_mlp, gate_mlp = \
                self.adaLN_modulation(t).unsqueeze(2).chunk(6, dim=-1) # [B, T, 1, C]

        # self attn with modulation (per-frame view so shift/scale broadcast)
        x_m = modulate_fp32(self.mod_norm_attn, x.view(B, T, -1, C), shift_msa, scale_msa).view(B, N, C)

        if kv_cache is not None:
            # cache may live on CPU when offloaded; bring it to x's device
            kv_cache = (kv_cache[0].to(x.device), kv_cache[1].to(x.device))
            attn_outputs = self.attn.forward_with_kv_cache(x_m, shape=latent_shape, num_cond_latents=num_cond_latents, kv_cache=kv_cache)
        else:
            attn_outputs = self.attn(x_m, shape=latent_shape, num_cond_latents=num_cond_latents, return_kv=return_kv)

        if return_kv:
            x_s, kv_cache = attn_outputs
        else:
            x_s = attn_outputs

        # gated residual in fp32; gate is per frame ([B, T, 1, C])
        with amp.autocast(device_type='cuda', dtype=torch.float32):
            x = x + (gate_msa * x_s.view(B, -1, N//T, C)).view(B, -1, C) # [B, N, C]
        x = x.to(x_dtype)

        # cross attn
        if not skip_crs_attn:
            if kv_cache is not None:
                # in kv-cache mode x holds only noise tokens, so no
                # condition-token skipping is needed inside cross-attention
                num_cond_latents = None
            x = x + self.cross_attn(self.pre_crs_attn_norm(x), y, y_seqlen, num_cond_latents=num_cond_latents, shape=latent_shape)

        # ffn with modulation
        x_m = modulate_fp32(self.mod_norm_ffn, x.view(B, -1, N//T, C), shift_mlp, scale_mlp).view(B, -1, C)
        x_s = self.ffn(x_m)
        with amp.autocast(device_type='cuda', dtype=torch.float32):
            x = x + (gate_mlp * x_s.view(B, -1, N//T, C)).view(B, -1, C) # [B, N, C]
        x = x.to(x_dtype)

        if return_kv:
            return x, kv_cache
        else:
            return x
|
||||
|
||||
|
||||
class LongCatVideoTransformer3DModel(torch.nn.Module):
    """LongCat-Video DiT: 3D patch embedding + a stack of single-stream
    blocks with AdaLN timestep modulation and text cross-attention.

    The module also carries a lightweight LoRA registry (`lora_dict`);
    adapters are activated by monkey-patching the target modules' `forward`.
    """

    def __init__(
        self,
        in_channels: int = 16,
        out_channels: int = 16,
        hidden_size: int = 4096,
        depth: int = 48,
        num_heads: int = 32,
        caption_channels: int = 4096,
        mlp_ratio: int = 4,
        adaln_tembed_dim: int = 512,
        frequency_embedding_size: int = 256,
        # default params
        patch_size: Tuple[int] = (1, 2, 2),
        # attention config
        enable_flashattn3: bool = False,
        enable_flashattn2: bool = True,
        enable_xformers: bool = False,
        enable_bsa: bool = False,
        bsa_params: Optional[dict] = None,
        cp_split_hw: Optional[List[int]] = None,
        text_tokens_zero_pad: bool = True,
    ) -> None:
        super().__init__()

        # BUGFIX: the original used mutable default arguments (a dict / a
        # list shared across all instances); materialize per-instance
        # defaults here instead. Values are unchanged.
        if bsa_params is None:
            bsa_params = {'sparsity': 0.9375, 'chunk_3d_shape_q': [4, 4, 4], 'chunk_3d_shape_k': [4, 4, 4]}
        if cp_split_hw is None:
            cp_split_hw = [1, 1]

        self.patch_size = patch_size
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.cp_split_hw = cp_split_hw

        self.x_embedder = PatchEmbed3D(patch_size, in_channels, hidden_size)
        self.t_embedder = TimestepEmbedder(t_embed_dim=adaln_tembed_dim, frequency_embedding_size=frequency_embedding_size)
        self.y_embedder = CaptionEmbedder(
            in_channels=caption_channels,
            hidden_size=hidden_size,
        )

        self.blocks = nn.ModuleList(
            [
                LongCatSingleStreamBlock(
                    hidden_size=hidden_size,
                    num_heads=num_heads,
                    mlp_ratio=mlp_ratio,
                    adaln_tembed_dim=adaln_tembed_dim,
                    enable_flashattn3=enable_flashattn3,
                    enable_flashattn2=enable_flashattn2,
                    enable_xformers=enable_xformers,
                    enable_bsa=enable_bsa,
                    bsa_params=bsa_params,
                    cp_split_hw=cp_split_hw
                )
                for i in range(depth)
            ]
        )

        self.final_layer = FinalLayer_FP32(
            hidden_size,
            np.prod(self.patch_size),
            out_channels,
            adaln_tembed_dim,
        )

        self.gradient_checkpointing = False
        self.text_tokens_zero_pad = text_tokens_zero_pad

        # LoRA registry: {lora_key: network with a `.loras` list};
        # `active_loras` lists the currently enabled keys.
        self.lora_dict = {}
        self.active_loras = []

    def enable_loras(self, lora_key_list=None):
        """Enable the LoRA networks named in `lora_key_list`, disabling all
        previously active ones first."""
        if lora_key_list is None:  # avoid a shared mutable default
            lora_key_list = []
        self.disable_all_loras()

        module_loras = {} # {module_name: [lora1, lora2, ...]}
        model_device = next(self.parameters()).device
        model_dtype = next(self.parameters()).dtype

        for lora_key in lora_key_list:
            if lora_key in self.lora_dict:
                for lora in self.lora_dict[lora_key].loras:
                    lora.to(model_device, dtype=model_dtype, non_blocking=True)
                    # recover the dotted module path from the mangled LoRA name
                    module_name = lora.lora_name.replace("lora___lorahyphen___", "").replace("___lorahyphen___", ".")
                    if module_name not in module_loras:
                        module_loras[module_name] = []
                    module_loras[module_name].append(lora)
                self.active_loras.append(lora_key)

        for module_name, loras in module_loras.items():
            module = self._get_module_by_name(module_name)
            if not hasattr(module, 'org_forward'):
                # stash the original forward once so it can be restored
                module.org_forward = module.forward
            module.forward = self._create_multi_lora_forward(module, loras)

    def _create_multi_lora_forward(self, module, loras):
        """Build a forward that adds every active LoRA delta to the module's
        original output."""
        def multi_lora_forward(x, *args, **kwargs):
            weight_dtype = x.dtype
            org_output = module.org_forward(x, *args, **kwargs)

            total_lora_output = 0
            for lora in loras:
                if lora.use_lora:
                    # down-project in the LoRA's own dtype, then up-project
                    lx = lora.lora_down(x.to(lora.lora_down.weight.dtype))
                    lx = lora.lora_up(lx)
                    lora_output = lx.to(weight_dtype) * lora.multiplier * lora.alpha_scale
                    total_lora_output += lora_output

            return org_output + total_lora_output

        return multi_lora_forward

    def _get_module_by_name(self, module_name):
        """Resolve a dotted attribute path (e.g. "blocks.0.attn.qkv") to a
        submodule; raises ValueError when any segment is missing."""
        try:
            module = self
            for part in module_name.split('.'):
                module = getattr(module, part)
            return module
        except AttributeError as e:
            raise ValueError(f"Cannot find module: {module_name}, error: {e}")

    def disable_all_loras(self):
        """Restore every monkey-patched forward and move all registered LoRA
        weights back to the CPU."""
        for name, module in self.named_modules():
            if hasattr(module, 'org_forward'):
                module.forward = module.org_forward
                delattr(module, 'org_forward')

        for lora_key, lora_network in self.lora_dict.items():
            for lora in lora_network.loras:
                lora.to("cpu")

        self.active_loras.clear()

    def enable_bsa(self,):
        # turn on block-sparse attention in every transformer block
        for block in self.blocks:
            block.attn.enable_bsa = True

    def disable_bsa(self,):
        for block in self.blocks:
            block.attn.enable_bsa = False

    def forward(
        self,
        hidden_states,
        timestep,
        encoder_hidden_states,
        encoder_attention_mask=None,
        num_cond_latents=0,
        return_kv=False,
        kv_cache_dict=None,
        skip_crs_attn=False,
        offload_kv_cache=False,
        use_gradient_checkpointing=False,
        use_gradient_checkpointing_offload=False,
    ):
        """Denoise `hidden_states` ([B, C, T, H, W]) conditioned on text.

        Returns the prediction as float32; with return_kv=True also returns
        {block_idx: (k, v)} caches for the condition frames.
        """
        if kv_cache_dict is None:  # avoid a shared mutable default
            kv_cache_dict = {}

        B, _, T, H, W = hidden_states.shape

        N_t = T // self.patch_size[0]
        N_h = H // self.patch_size[1]
        N_w = W // self.patch_size[2]

        assert self.patch_size[0]==1, "Currently, 3D x_embedder should not compress the temporal dimension."

        # expand the shape of timestep from [B] to [B, T]; condition frames
        # are clean, so their timestep is forced to 0
        if len(timestep.shape) == 1:
            timestep = timestep.unsqueeze(1).expand(-1, N_t).clone() # [B, T]
            timestep[:, :num_cond_latents] = 0

        dtype = hidden_states.dtype
        hidden_states = hidden_states.to(dtype)
        timestep = timestep.to(dtype)
        encoder_hidden_states = encoder_hidden_states.to(dtype)

        hidden_states = self.x_embedder(hidden_states) # [B, N, C]

        with amp.autocast(device_type='cuda', dtype=torch.float32):
            t = self.t_embedder(timestep.float().flatten(), dtype=torch.float32).reshape(B, N_t, -1) # [B, T, C_t]

        encoder_hidden_states = self.y_embedder(encoder_hidden_states) # [B, 1, N_token, C]

        if self.text_tokens_zero_pad and encoder_attention_mask is not None:
            # zero out the padded tokens, then mark every position as valid
            encoder_hidden_states = encoder_hidden_states * encoder_attention_mask[:, None, :, None]
            encoder_attention_mask = (encoder_attention_mask * 0 + 1).to(encoder_attention_mask.dtype)

        if encoder_attention_mask is not None:
            # pack the valid tokens of all batch items into one sequence
            encoder_attention_mask = encoder_attention_mask.squeeze(1).squeeze(1)
            encoder_hidden_states = encoder_hidden_states.squeeze(1).masked_select(encoder_attention_mask.unsqueeze(-1) != 0).view(1, -1, hidden_states.shape[-1]) # [1, N_valid_tokens, C]
            y_seqlens = encoder_attention_mask.sum(dim=1).tolist() # [B]
        else:
            y_seqlens = [encoder_hidden_states.shape[2]] * encoder_hidden_states.shape[0]
            encoder_hidden_states = encoder_hidden_states.squeeze(1).view(1, -1, hidden_states.shape[-1])

        # NOTE: context-parallel splitting of the H/W token grid
        # (cp_split_hw) is currently disabled in this implementation.

        # blocks
        kv_cache_dict_ret = {}
        for i, block in enumerate(self.blocks):
            block_outputs = gradient_checkpoint_forward(
                block,
                use_gradient_checkpointing=use_gradient_checkpointing,
                use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
                x=hidden_states,
                y=encoder_hidden_states,
                t=t,
                y_seqlen=y_seqlens,
                latent_shape=(N_t, N_h, N_w),
                num_cond_latents=num_cond_latents,
                return_kv=return_kv,
                kv_cache=kv_cache_dict.get(i, None),
                skip_crs_attn=skip_crs_attn,
            )

            if return_kv:
                hidden_states, kv_cache = block_outputs
                if offload_kv_cache:
                    # keep caches on CPU to save VRAM
                    kv_cache_dict_ret[i] = (kv_cache[0].cpu(), kv_cache[1].cpu())
                else:
                    kv_cache_dict_ret[i] = (kv_cache[0].contiguous(), kv_cache[1].contiguous())
            else:
                hidden_states = block_outputs

        hidden_states = self.final_layer(hidden_states, t, (N_t, N_h, N_w)) # [B, N, C=T_p*H_p*W_p*C_out]

        hidden_states = self.unpatchify(hidden_states, N_t, N_h, N_w) # [B, C_out, H, W]

        # cast to float32 for better accuracy
        hidden_states = hidden_states.to(torch.float32)

        if return_kv:
            return hidden_states, kv_cache_dict_ret
        else:
            return hidden_states

    def unpatchify(self, x, N_t, N_h, N_w):
        """
        Args:
            x (torch.Tensor): of shape [B, N, C]

        Return:
            x (torch.Tensor): of shape [B, C_out, T, H, W]
        """
        T_p, H_p, W_p = self.patch_size
        x = rearrange(
            x,
            "B (N_t N_h N_w) (T_p H_p W_p C_out) -> B C_out (N_t T_p) (N_h H_p) (N_w W_p)",
            N_t=N_t,
            N_h=N_h,
            N_w=N_w,
            T_p=T_p,
            H_p=H_p,
            W_p=W_p,
            C_out=self.out_channels,
        )
        return x

    @staticmethod
    def state_dict_converter():
        # checkpoints already use this module's parameter names (no-op converter)
        return LongCatVideoTransformer3DModelDictConverter()
|
||||
|
||||
|
||||
class LongCatVideoTransformer3DModelDictConverter:
    """Identity state-dict converter: LongCat checkpoints already use the
    module's native parameter names, so both loaders return the input
    unchanged."""

    def __init__(self):
        pass

    def from_diffusers(self, state_dict):
        # no renaming required for diffusers-style checkpoints
        return state_dict

    def from_civitai(self, state_dict):
        # no renaming required for civitai-style checkpoints
        return state_dict
|
||||
|
||||
670
diffsynth/models/wan_video_animate_adapter.py
Normal file
670
diffsynth/models/wan_video_animate_adapter.py
Normal file
@@ -0,0 +1,670 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.nn import functional as F
|
||||
import math
|
||||
from typing import Tuple, Optional, List
|
||||
from einops import rearrange
|
||||
|
||||
|
||||
|
||||
MEMORY_LAYOUT = {
    # (pre, post) layout transforms: `pre` converts [B, S, H, D] inputs to
    # the layout the backend expects; `post` converts its output back so the
    # final reshape below sees [B, S, H, D].
    "flash": (
        lambda x: x.view(x.shape[0] * x.shape[1], *x.shape[2:]),
        lambda x: x,
    ),
    "torch": (
        lambda x: x.transpose(1, 2),
        lambda x: x.transpose(1, 2),
    ),
    "vanilla": (
        lambda x: x.transpose(1, 2),
        lambda x: x.transpose(1, 2),
    ),
}


def attention(
    q,
    k,
    v,
    mode="torch",
    drop_rate=0,
    attn_mask=None,
    causal=False,
    max_seqlen_q=None,
    batch_size=1,
):
    """Multi-head attention over [B, S, H, D] tensors.

    Args:
        q, k, v: [batch, seq_len, num_heads, head_dim] tensors.
        mode: backend key into MEMORY_LAYOUT (only "torch" is implemented here).
        drop_rate: attention dropout probability.
        attn_mask: optional mask; non-bool masks are cast to q's dtype.
        causal: apply a causal mask.
        max_seqlen_q, batch_size: kept for API compatibility with varlen
            backends; unused by the "torch" path.

    Returns:
        [batch, seq_len, num_heads * head_dim] attention output.
    """
    pre_attn_layout, post_attn_layout = MEMORY_LAYOUT[mode]
    # BUGFIX: the layout transforms were looked up but never applied, so
    # scaled_dot_product_attention received [B, S, H, D] and attended over
    # the wrong axis. Apply the pre-layout (for "torch": move heads to dim 1).
    q = pre_attn_layout(q)
    k = pre_attn_layout(k)
    v = pre_attn_layout(v)

    if mode == "torch":
        if attn_mask is not None and attn_mask.dtype != torch.bool:
            attn_mask = attn_mask.to(q.dtype)
        x = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, dropout_p=drop_rate, is_causal=causal)

    # restore [B, S, H, D], then merge the head dimensions
    x = post_attn_layout(x)
    b, s, a, d = x.shape
    out = x.reshape(b, s, -1)
    return out
|
||||
|
||||
|
||||
class CausalConv1d(nn.Module):
    """Conv1d with left-only (causal) temporal padding: the output at time t
    never depends on inputs after t."""

    def __init__(self, chan_in, chan_out, kernel_size=3, stride=1, dilation=1, pad_mode="replicate", **kwargs):
        super().__init__()

        self.pad_mode = pad_mode
        # pad (kernel_size - 1) steps on the left, none on the right
        self.time_causal_padding = (kernel_size - 1, 0)

        self.conv = nn.Conv1d(chan_in, chan_out, kernel_size, stride=stride, dilation=dilation, **kwargs)

    def forward(self, x):
        padded = F.pad(x, self.time_causal_padding, mode=self.pad_mode)
        return self.conv(padded)
|
||||
|
||||
|
||||
|
||||
class FaceEncoder(nn.Module):
    """Encodes per-frame face/motion features: two stride-2 causal convs
    downsample time by ~4x, the result is projected to `hidden_dim`, split
    into `num_heads` parallel token streams, and one learned padding token
    is appended per frame."""

    def __init__(self, in_dim: int, hidden_dim: int, num_heads: int, dtype=None, device=None):
        # BUGFIX: the original signature read `num_heads=int`, i.e. the
        # builtin `int` *type* as default value; omitting the argument always
        # crashed later (1024 * <type>), so the parameter is simply required.
        factory_kwargs = {"dtype": dtype, "device": device}
        super().__init__()

        self.num_heads = num_heads
        self.conv1_local = CausalConv1d(in_dim, 1024 * num_heads, 3, stride=1)
        self.act = nn.SiLU()
        self.conv2 = CausalConv1d(1024, 1024, 3, stride=2)
        self.conv3 = CausalConv1d(1024, 1024, 3, stride=2)

        self.out_proj = nn.Linear(1024, hidden_dim)
        # BUGFIX: norm1 was assigned twice in the original (first as
        # LayerNorm(hidden_dim // 8), then overwritten with LayerNorm(1024));
        # only the second assignment ever took effect, so the dead one is
        # removed. State-dict layout is unchanged.
        self.norm1 = nn.LayerNorm(1024, elementwise_affine=False, eps=1e-6, **factory_kwargs)

        self.norm2 = nn.LayerNorm(1024, elementwise_affine=False, eps=1e-6, **factory_kwargs)

        self.norm3 = nn.LayerNorm(1024, elementwise_affine=False, eps=1e-6, **factory_kwargs)

        # learned token appended after the per-head tokens of every frame
        self.padding_tokens = nn.Parameter(torch.zeros(1, 1, 1, hidden_dim))

    def forward(self, x):
        """x: [B, T, in_dim] -> [B, T', num_heads + 1, hidden_dim], with
        T' ~= T/4 after the two stride-2 convolutions."""
        x = rearrange(x, "b t c -> b c t")
        b, c, t = x.shape

        x = self.conv1_local(x)
        # split the 1024*num_heads channels into num_heads parallel streams
        x = rearrange(x, "b (n c) t -> (b n) t c", n=self.num_heads)

        x = self.norm1(x)
        x = self.act(x)
        x = rearrange(x, "b t c -> b c t")
        x = self.conv2(x)
        x = rearrange(x, "b c t -> b t c")
        x = self.norm2(x)
        x = self.act(x)
        x = rearrange(x, "b t c -> b c t")
        x = self.conv3(x)
        x = rearrange(x, "b c t -> b t c")
        x = self.norm3(x)
        x = self.act(x)
        x = self.out_proj(x)
        # un-fold the head streams: [(B n), T', C] -> [B, T', n, C]
        x = rearrange(x, "(b n) t c -> b t n c", b=b)
        padding = self.padding_tokens.repeat(b, x.shape[1], 1, 1)
        x = torch.cat([x, padding], dim=-2)
        x_local = x.clone()

        return x_local
|
||||
|
||||
|
||||
|
||||
class RMSNorm(nn.Module):
    """
    Root-mean-square normalization (no mean subtraction).

    Divides the last dimension by its RMS (computed in float32 for
    stability) and, when `elementwise_affine` is True, applies a learnable
    per-channel scale.

    Args:
        dim (int): size of the normalized (last) dimension.
        elementwise_affine (bool): register a learnable `weight` scale.
        eps (float): small constant added inside the rsqrt for stability.
    """

    def __init__(
        self,
        dim: int,
        elementwise_affine=True,
        eps: float = 1e-6,
        device=None,
        dtype=None,
    ):
        super().__init__()
        self.eps = eps
        if elementwise_affine:
            # only registered when affine; forward checks via hasattr
            self.weight = nn.Parameter(torch.ones(dim, device=device, dtype=dtype))

    def _norm(self, x):
        """Divide x by the RMS of its last dimension."""
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        """Normalize in float32, cast back, then apply the optional scale."""
        out = self._norm(x.float()).type_as(x)
        if hasattr(self, "weight"):
            out = out * self.weight
        return out
|
||||
|
||||
|
||||
def get_norm_layer(norm_layer):
    """
    Get the normalization layer class for a name.

    Args:
        norm_layer (str): "layer" for nn.LayerNorm or "rms" for RMSNorm.

    Returns:
        The normalization layer class.

    Raises:
        NotImplementedError: for any other name.
    """
    if norm_layer == "rms":
        return RMSNorm
    if norm_layer == "layer":
        return nn.LayerNorm
    raise NotImplementedError(f"Norm layer {norm_layer} is not implemented")
|
||||
|
||||
|
||||
class FaceAdapter(nn.Module):
    """Stack of FaceBlock fuser layers that inject motion embeddings into the
    video stream; `forward` dispatches to the layer selected by `idx`."""

    def __init__(
        self,
        hidden_dim: int,
        heads_num: int,
        qk_norm: bool = True,
        qk_norm_type: str = "rms",
        num_adapter_layers: int = 1,
        dtype=None,
        device=None,
    ):
        factory_kwargs = {"dtype": dtype, "device": device}
        super().__init__()
        self.hidden_size = hidden_dim
        self.heads_num = heads_num
        fusers = []
        for _ in range(num_adapter_layers):
            fusers.append(
                FaceBlock(
                    self.hidden_size,
                    self.heads_num,
                    qk_norm=qk_norm,
                    qk_norm_type=qk_norm_type,
                    **factory_kwargs,
                )
            )
        self.fuser_blocks = nn.ModuleList(fusers)

    def forward(
        self,
        x: torch.Tensor,
        motion_embed: torch.Tensor,
        idx: int,
        freqs_cis_q: Tuple[torch.Tensor, torch.Tensor] = None,
        freqs_cis_k: Tuple[torch.Tensor, torch.Tensor] = None,
    ) -> torch.Tensor:
        # Delegate to the idx-th fuser block.
        # NOTE(review): FaceBlock.forward's 3rd/4th positional parameters are
        # motion_mask / use_context_parallel, so the freqs_cis_* arguments
        # land there — confirm this positional mapping is intentional.
        return self.fuser_blocks[idx](x, motion_embed, freqs_cis_q, freqs_cis_k)
|
||||
|
||||
|
||||
|
||||
class FaceBlock(nn.Module):
    """Frame-local cross-attention that injects per-frame motion tokens into
    the video token stream (queries = video tokens, keys/values = motion
    tokens of the same frame)."""

    def __init__(
        self,
        hidden_size: int,
        heads_num: int,
        qk_norm: bool = True,
        qk_norm_type: str = "rms",
        qk_scale: float = None,
        dtype: Optional[torch.dtype] = None,
        device: Optional[torch.device] = None,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()

        self.deterministic = False
        self.hidden_size = hidden_size
        self.heads_num = heads_num
        head_dim = hidden_size // heads_num
        # NOTE(review): `scale` is stored but the SDPA call below uses its own
        # default scaling — confirm whether qk_scale is meant to be honored.
        self.scale = qk_scale or head_dim**-0.5

        # fused K/V projection for motion tokens; Q projection for video tokens
        self.linear1_kv = nn.Linear(hidden_size, hidden_size * 2, **factory_kwargs)
        self.linear1_q = nn.Linear(hidden_size, hidden_size, **factory_kwargs)

        self.linear2 = nn.Linear(hidden_size, hidden_size, **factory_kwargs)

        qk_norm_layer = get_norm_layer(qk_norm_type)
        self.q_norm = (
            qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity()
        )
        self.k_norm = (
            qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity()
        )

        self.pre_norm_feat = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)

        self.pre_norm_motion = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)

    def forward(
        self,
        x: torch.Tensor,
        motion_vec: torch.Tensor,
        motion_mask: Optional[torch.Tensor] = None,
        use_context_parallel=False,
    ) -> torch.Tensor:
        # x: [B, T*S, C] video tokens; motion_vec: [B, T, N, C] motion tokens.
        # Frame t's S video tokens attend only to that frame's N motion tokens.
        B, T, N, C = motion_vec.shape
        T_comp = T

        x_motion = self.pre_norm_motion(motion_vec)
        x_feat = self.pre_norm_feat(x)

        kv = self.linear1_kv(x_motion)
        q = self.linear1_q(x_feat)

        # unbind along the K=2 axis produced by the fused projection
        k, v = rearrange(kv, "B L N (K H D) -> K B L N H D", K=2, H=self.heads_num)
        q = rearrange(q, "B S (H D) -> B S H D", H=self.heads_num)

        # Apply QK-Norm if needed.
        q = self.q_norm(q).to(v)
        k = self.k_norm(k).to(v)

        # fold the frame axis into the batch so attention stays frame-local
        k = rearrange(k, "B L N H D -> (B L) H N D")
        v = rearrange(v, "B L N H D -> (B L) H N D")

        q = rearrange(q, "B (L S) H D -> (B L) H S D", L=T_comp)
        # Compute attention.
        attn = F.scaled_dot_product_attention(q, k, v)

        # restore [B, T*S, C] with heads merged
        attn = rearrange(attn, "(B L) H S D -> B (L S) (H D)", L=T_comp)

        output = self.linear2(attn)

        if motion_mask is not None:
            # zero the contribution wherever the motion mask is off
            output = output * rearrange(motion_mask, "B T H W -> B (T H W)").unsqueeze(-1)

        return output
|
||||
|
||||
|
||||
|
||||
def custom_qr(input_tensor):
    """QR-decompose ``input_tensor``, working around reduced-precision dtypes.

    ``torch.linalg.qr`` does not support half-precision inputs, so bfloat16
    AND float16 tensors (the original only handled bfloat16) are upcast to
    float32 for the decomposition and the factors are cast back.

    Args:
        input_tensor: matrix (or batch of matrices) to decompose.

    Returns:
        Tuple (Q, R) with the same dtype as ``input_tensor``.
    """
    original_dtype = input_tensor.dtype
    if original_dtype in (torch.bfloat16, torch.float16):
        q, r = torch.linalg.qr(input_tensor.to(torch.float32))
        return q.to(original_dtype), r.to(original_dtype)
    return torch.linalg.qr(input_tensor)
|
||||
|
||||
def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2 ** 0.5):
    """Bias-add followed by leaky ReLU, rescaled to preserve activation variance."""
    shifted = input + bias
    activated = F.leaky_relu(shifted, negative_slope)
    return activated * scale
|
||||
|
||||
|
||||
def upfirdn2d_native(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1):
    """Reference upsample-FIR-downsample (StyleGAN2 ``upfirdn2d``) in pure PyTorch.

    Steps: zero-stuff by (up_y, up_x), pad (positive) or crop (negative) by the
    pad_* amounts, convolve with the flipped 2-D FIR kernel, then subsample by
    (down_y, down_x).

    Args:
        input: tensor of shape (batch, channel, in_h, in_w).
        kernel: 2-D FIR kernel of shape (kernel_h, kernel_w).
    """
    _, minor, in_h, in_w = input.shape
    kernel_h, kernel_w = kernel.shape

    # Zero-stuffing upsample: insert (up - 1) zeros after every input sample.
    out = input.view(-1, minor, in_h, 1, in_w, 1)
    out = F.pad(out, [0, up_x - 1, 0, 0, 0, up_y - 1, 0, 0])
    out = out.view(-1, minor, in_h * up_y, in_w * up_x)

    # Positive pads add zeros; negative pads crop instead.
    out = F.pad(out, [max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)])
    out = out[:, :, max(-pad_y0, 0): out.shape[2] - max(-pad_y1, 0),
              max(-pad_x0, 0): out.shape[3] - max(-pad_x1, 0), ]

    # Fold channels into batch and convolve with the flipped kernel (true
    # convolution rather than cross-correlation), one shared kernel for all.
    out = out.reshape([-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1])
    w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
    out = F.conv2d(out, w)
    out = out.reshape(-1, minor, in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
                      in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1, )
    # Final subsampling.
    return out[:, :, ::down_y, ::down_x]
|
||||
|
||||
|
||||
def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
    """Upsample, FIR-filter, then downsample with the same factor and padding
    applied to both spatial axes."""
    pad0, pad1 = pad
    return upfirdn2d_native(input, kernel, up, up, down, down, pad0, pad1, pad0, pad1)
|
||||
|
||||
|
||||
def make_kernel(k):
    """Build a normalized 2-D FIR kernel; a 1-D input becomes its outer product."""
    kernel = torch.tensor(k, dtype=torch.float32)
    if kernel.ndim == 1:
        # Separable kernel: outer product of the 1-D taps with themselves.
        kernel = torch.outer(kernel, kernel)
    # Normalize so the taps sum to one.
    kernel = kernel / kernel.sum()
    return kernel
|
||||
|
||||
|
||||
class FusedLeakyReLU(nn.Module):
    """Leaky ReLU with a learnable per-channel bias and a variance-preserving
    output scale (StyleGAN2-style)."""

    def __init__(self, channel, negative_slope=0.2, scale=2 ** 0.5):
        super().__init__()
        # One bias per channel, broadcast over batch and spatial dimensions.
        self.bias = nn.Parameter(torch.zeros(1, channel, 1, 1))
        self.negative_slope = negative_slope
        self.scale = scale

    def forward(self, input):
        return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale)
|
||||
|
||||
|
||||
class Blur(nn.Module):
    """Fixed FIR blur applied around up/downsampling (StyleGAN2-style).

    The kernel is normalized by make_kernel(); when placed after an upsample
    it is scaled by upsample_factor**2 to compensate for zero-stuffing.
    """

    def __init__(self, kernel, pad, upsample_factor=1):
        super().__init__()

        kernel = make_kernel(kernel)

        if upsample_factor > 1:
            kernel = kernel * (upsample_factor ** 2)

        # Non-trainable kernel; registered as a buffer so it follows the
        # module's device/dtype moves.
        self.register_buffer('kernel', kernel)

        self.pad = pad

    def forward(self, input):
        return upfirdn2d(input, self.kernel, pad=self.pad)
|
||||
|
||||
|
||||
class ScaledLeakyReLU(nn.Module):
    """Leaky ReLU activation module.

    NOTE(review): despite the name, no output scaling is applied here (unlike
    FusedLeakyReLU's sqrt(2) factor) — confirm against trained checkpoints
    before changing.
    """

    def __init__(self, negative_slope=0.2):
        super().__init__()
        self.negative_slope = negative_slope

    def forward(self, input):
        slope = self.negative_slope
        return F.leaky_relu(input, negative_slope=slope)
|
||||
|
||||
|
||||
class EqualConv2d(nn.Module):
    """Conv2d with equalized learning rate (StyleGAN-style).

    Weights are drawn from N(0, 1) and multiplied at runtime by
    1 / sqrt(fan_in), so every layer trains at the same effective rate.
    """

    def __init__(self, in_channel, out_channel, kernel_size, stride=1, padding=0, bias=True):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(out_channel, in_channel, kernel_size, kernel_size))
        # Runtime scale implementing the equalized-learning-rate trick.
        self.scale = 1 / math.sqrt(in_channel * kernel_size ** 2)
        self.stride = stride
        self.padding = padding
        self.bias = nn.Parameter(torch.zeros(out_channel)) if bias else None

    def forward(self, input):
        return F.conv2d(
            input,
            self.weight * self.scale,
            bias=self.bias,
            stride=self.stride,
            padding=self.padding,
        )

    def __repr__(self):
        out_c, in_c, k = self.weight.shape[0], self.weight.shape[1], self.weight.shape[2]
        return (
            f'{self.__class__.__name__}({in_c}, {out_c},'
            f' {k}, stride={self.stride}, padding={self.padding})'
        )
|
||||
|
||||
|
||||
class EqualLinear(nn.Module):
    """Linear layer with equalized learning rate (StyleGAN-style).

    Weights are initialised as N(0, 1) / lr_mul and rescaled at runtime by
    (1 / sqrt(in_dim)) * lr_mul; the bias is stored unscaled and multiplied
    by lr_mul on use. With ``activation`` set (truthy), the bias-add is fused
    into a scaled leaky ReLU via fused_leaky_relu().

    Args:
        in_dim, out_dim: layer dimensions.
        bias: whether to learn a bias.
        bias_init: initial bias value.
        lr_mul: learning-rate multiplier for the equalized-lr trick.
        activation: if truthy, apply the fused leaky-ReLU path.
    """

    def __init__(self, in_dim, out_dim, bias=True, bias_init=0, lr_mul=1, activation=None):
        super().__init__()

        self.weight = nn.Parameter(torch.randn(out_dim, in_dim).div_(lr_mul))

        if bias:
            self.bias = nn.Parameter(torch.zeros(out_dim).fill_(bias_init))
        else:
            self.bias = None

        self.activation = activation

        self.scale = (1 / math.sqrt(in_dim)) * lr_mul
        self.lr_mul = lr_mul

    def forward(self, input):
        # Bug fix: `self.bias * self.lr_mul` raised TypeError whenever the
        # layer was constructed with bias=False; guard the None bias.
        bias = self.bias * self.lr_mul if self.bias is not None else None
        if self.activation:
            # NOTE(review): the fused-activation path still requires a bias
            # (fused_leaky_relu adds it) — activation with bias=False is
            # unsupported, as in the reference implementation.
            out = F.linear(input, self.weight * self.scale)
            out = fused_leaky_relu(out, bias)
        else:
            out = F.linear(input, self.weight * self.scale, bias=bias)
        return out

    def __repr__(self):
        return (f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]})')
|
||||
|
||||
|
||||
class ConvLayer(nn.Sequential):
    """EqualConv2d optionally preceded by an anti-aliasing Blur (when
    downsampling) and followed by a (fused) leaky-ReLU activation.

    When both ``activate`` and ``bias`` are set, the conv itself is bias-free
    and the bias is learned inside FusedLeakyReLU instead.
    """

    def __init__(
        self,
        in_channel,
        out_channel,
        kernel_size,
        downsample=False,
        blur_kernel=[1, 3, 3, 1],
        bias=True,
        activate=True,
    ):
        layers = []

        if downsample:
            factor = 2
            # Blur padding sized so the stride-2 conv sees a correctly
            # aligned, anti-aliased input.
            p = (len(blur_kernel) - factor) + (kernel_size - 1)
            pad0 = (p + 1) // 2
            pad1 = p // 2

            layers.append(Blur(blur_kernel, pad=(pad0, pad1)))

            stride = 2
            # Plain-int attribute set before super().__init__() — allowed,
            # since it is not a Parameter/Module.
            self.padding = 0

        else:
            stride = 1
            self.padding = kernel_size // 2

        layers.append(EqualConv2d(in_channel, out_channel, kernel_size, padding=self.padding, stride=stride,
                                  bias=bias and not activate))

        if activate:
            if bias:
                layers.append(FusedLeakyReLU(out_channel))
            else:
                layers.append(ScaledLeakyReLU(0.2))

        super().__init__(*layers)
|
||||
|
||||
|
||||
class ResBlock(nn.Module):
    """Residual downsampling block: two convs plus a 1x1 downsampled skip.

    The output is (main + skip) / sqrt(2) to keep activation variance stable.
    """

    def __init__(self, in_channel, out_channel, blur_kernel=[1, 3, 3, 1]):
        super().__init__()

        self.conv1 = ConvLayer(in_channel, in_channel, 3)
        self.conv2 = ConvLayer(in_channel, out_channel, 3, downsample=True)

        # Bias-free, non-activated projection keeps the skip path linear.
        self.skip = ConvLayer(in_channel, out_channel, 1, downsample=True, activate=False, bias=False)

    def forward(self, input):
        main = self.conv2(self.conv1(input))
        shortcut = self.skip(input)
        return (main + shortcut) / math.sqrt(2)
|
||||
|
||||
|
||||
class EncoderApp(nn.Module):
    """Appearance encoder: StyleGAN2-discriminator-style CNN pyramid.

    Maps a (B, 3, size, size) image to a w_dim appearance code and also
    returns the intermediate feature maps (finest-last order reversed, minus
    the last two entries) as skip features.
    """

    def __init__(self, size, w_dim=512):
        super(EncoderApp, self).__init__()

        # Channel width per spatial resolution.
        channels = {
            4: 512,
            8: 512,
            16: 512,
            32: 512,
            64: 256,
            128: 128,
            256: 64,
            512: 32,
            1024: 16
        }

        self.w_dim = w_dim
        log_size = int(math.log(size, 2))

        self.convs = nn.ModuleList()
        self.convs.append(ConvLayer(3, channels[size], 1))

        # Halve the resolution down to 4x4 while widening channels.
        in_channel = channels[size]
        for i in range(log_size, 2, -1):
            out_channel = channels[2 ** (i - 1)]
            self.convs.append(ResBlock(in_channel, out_channel))
            in_channel = out_channel

        # Final 4x4 conv collapses the spatial dims into the latent vector.
        self.convs.append(EqualConv2d(in_channel, self.w_dim, 4, padding=0, bias=False))

    def forward(self, x):
        """Return (appearance code of shape (B, w_dim), reversed skip features)."""
        res = []
        h = x
        for conv in self.convs:
            h = conv(h)
            res.append(h)

        return res[-1].squeeze(-1).squeeze(-1), res[::-1][2:]
|
||||
|
||||
|
||||
class Encoder(nn.Module):
    """Appearance encoder plus a small MLP head that predicts a
    low-dimensional motion code from the appearance features (LIA-style)."""

    def __init__(self, size, dim=512, dim_motion=20):
        super(Encoder, self).__init__()

        # appearance network
        self.net_app = EncoderApp(size, dim)

        # motion network: five EqualLinear layers ending in dim_motion outputs
        fc = [EqualLinear(dim, dim)]
        for i in range(3):
            fc.append(EqualLinear(dim, dim))

        fc.append(EqualLinear(dim, dim_motion))
        self.fc = nn.Sequential(*fc)

    def enc_app(self, x):
        # Full appearance output: (code, skip features).
        h_source = self.net_app(x)
        return h_source

    def enc_motion(self, x):
        # Motion code only; the skip features are discarded.
        h, _ = self.net_app(x)
        h_motion = self.fc(h)
        return h_motion
|
||||
|
||||
|
||||
class Direction(nn.Module):
    """Learned basis of motion directions (LIA-style).

    The raw (512, motion_dim) weight is orthonormalized with a QR
    decomposition on every forward pass; a motion code then combines the
    resulting orthonormal directions.
    """

    def __init__(self, motion_dim):
        super(Direction, self).__init__()
        self.weight = nn.Parameter(torch.randn(512, motion_dim))

    def forward(self, input):
        """With input=None return the orthonormal basis Q; otherwise project
        the (B, motion_dim) code onto the basis, returning (B, 512)."""
        # Small epsilon keeps QR away from exactly-degenerate weights.
        weight = self.weight + 1e-8
        Q, R = custom_qr(weight)
        if input is None:
            return Q
        else:
            input_diag = torch.diag_embed(input)  # alpha, diagonal matrix
            out = torch.matmul(input_diag, Q.T)
            # Sum the per-direction contributions into one 512-d vector.
            out = torch.sum(out, dim=1)
            return out
|
||||
|
||||
|
||||
class Synthesis(nn.Module):
    """Holds the learned motion-direction basis used to decode motion codes."""

    def __init__(self, motion_dim):
        super().__init__()
        self.direction = Direction(motion_dim)
|
||||
|
||||
|
||||
class Generator(nn.Module):
    """Motion-feature extractor built from the LIA encoder/decoder pair.

    Only the motion pathway is exercised here: face images -> motion code ->
    projection onto the learned motion-direction basis.
    """

    def __init__(self, size, style_dim=512, motion_dim=20):
        super().__init__()

        self.enc = Encoder(size, style_dim, motion_dim)
        self.dec = Synthesis(motion_dim)

    def get_motion(self, img):
        """Return the motion representation of ``img``, shape (B, 512)."""
        #motion_feat = self.enc.enc_motion(img)
        # Gradient-checkpointed to bound activation memory on long face clips.
        motion_feat = torch.utils.checkpoint.checkpoint((self.enc.enc_motion), img, use_reentrant=True)
        motion = self.dec.direction(motion_feat)
        return motion
|
||||
|
||||
|
||||
class WanAnimateAdapter(torch.nn.Module):
    """Wan-Animate add-on that injects pose latents and face-motion features
    into the Wan DiT backbone.

    Components:
      - pose_patch_embedding: patchifies 16-channel pose latents to the 5120-d
        DiT hidden size.
      - motion_encoder: LIA generator extracting per-frame face-motion vectors.
      - face_encoder: projects motion vectors into per-frame token sets.
      - face_adapter: cross-attention fuser blocks applied every 5th DiT block
        (40 blocks // 5 = 8 fusers).
    """

    def __init__(self):
        super().__init__()
        self.pose_patch_embedding = torch.nn.Conv3d(16, 5120, kernel_size=(1, 2, 2), stride=(1, 2, 2))
        self.motion_encoder = Generator(size=512, style_dim=512, motion_dim=20)
        self.face_adapter = FaceAdapter(heads_num=40, hidden_dim=5120, num_adapter_layers=40 // 5)
        self.face_encoder = FaceEncoder(in_dim=512, hidden_dim=5120, num_heads=4)

    def after_patch_embedding(self, x: List[torch.Tensor], pose_latents, face_pixel_values):
        """Add pose conditioning to the patchified latents and build the
        motion token sequence from face frames.

        Returns:
            (x, motion_vec) where motion_vec has a zero "pad" token set
            prepended at frame index 0 (the reference frame carries no motion).
        """
        pose_latents = self.pose_patch_embedding(pose_latents)
        # Frame 0 holds the reference latent; pose applies to frames 1..T.
        x[:, :, 1:] += pose_latents

        b, c, T, h, w = face_pixel_values.shape
        face_pixel_values = rearrange(face_pixel_values, "b c t h w -> (b t) c h w")

        # Encode face frames in mini-batches of 8 to bound peak memory.
        encode_bs = 8
        face_pixel_values_tmp = []
        for i in range(math.ceil(face_pixel_values.shape[0] / encode_bs)):
            face_pixel_values_tmp.append(self.motion_encoder.get_motion(face_pixel_values[i * encode_bs:(i + 1) * encode_bs]))

        motion_vec = torch.cat(face_pixel_values_tmp)

        motion_vec = rearrange(motion_vec, "(b t) c -> b t c", t=T)
        motion_vec = self.face_encoder(motion_vec)

        B, L, H, C = motion_vec.shape
        pad_face = torch.zeros(B, 1, H, C).type_as(motion_vec)
        motion_vec = torch.cat([pad_face, motion_vec], dim=1)
        return x, motion_vec

    def after_transformer_block(self, block_idx, x, motion_vec, motion_masks=None):
        """Residually fuse motion tokens into x after every 5th DiT block."""
        if block_idx % 5 == 0:
            adapter_args = [x, motion_vec, motion_masks, False]
            residual_out = self.face_adapter.fuser_blocks[block_idx // 5](*adapter_args)
            x = residual_out + x
        return x

    @staticmethod
    def state_dict_converter():
        # Factory used by the generic model-loading machinery.
        return WanAnimateAdapterStateDictConverter()
|
||||
|
||||
|
||||
class WanAnimateAdapterStateDictConverter:
    """Filters raw checkpoints down to the parameters owned by WanAnimateAdapter."""

    # Parameter-name prefixes belonging to the adapter.
    _PREFIXES = ("pose_patch_embedding.", "face_adapter", "face_encoder", "motion_encoder")

    def __init__(self):
        pass

    def from_diffusers(self, state_dict):
        # Diffusers-format checkpoints already use the expected names.
        return state_dict

    def from_civitai(self, state_dict):
        # str.startswith accepts a tuple of prefixes.
        return {
            name: param
            for name, param in state_dict.items()
            if name.startswith(self._PREFIXES)
        }
|
||||
|
||||
206
diffsynth/models/wan_video_camera_controller.py
Normal file
206
diffsynth/models/wan_video_camera_controller.py
Normal file
@@ -0,0 +1,206 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import numpy as np
|
||||
from einops import rearrange
|
||||
import os
|
||||
from typing_extensions import Literal
|
||||
|
||||
class SimpleAdapter(nn.Module):
    """Camera-control adapter: encodes Plücker-embedding videos into DiT-sized
    feature maps via pixel-unshuffle + conv + residual blocks."""

    def __init__(self, in_dim, out_dim, kernel_size, stride, num_residual_blocks=1):
        super(SimpleAdapter, self).__init__()

        # Pixel Unshuffle: reduce spatial dimensions by a factor of 8
        self.pixel_unshuffle = nn.PixelUnshuffle(downscale_factor=8)

        # Convolution: reduce spatial dimensions by a factor
        # of 2 (without overlap); in_dim * 64 because unshuffle folds an
        # 8x8 neighborhood into channels.
        self.conv = nn.Conv2d(in_dim * 64, out_dim, kernel_size=kernel_size, stride=stride, padding=0)

        # Residual blocks for feature extraction
        self.residual_blocks = nn.Sequential(
            *[ResidualBlock(out_dim) for _ in range(num_residual_blocks)]
        )

    def forward(self, x):
        """x: (bs, c, f, h, w) Plücker embeddings -> (bs, out_dim, f, h', w')."""
        # Reshape to merge the frame dimension into batch
        bs, c, f, h, w = x.size()
        x = x.permute(0, 2, 1, 3, 4).contiguous().view(bs * f, c, h, w)

        # Pixel Unshuffle operation
        x_unshuffled = self.pixel_unshuffle(x)

        # Convolution operation
        x_conv = self.conv(x_unshuffled)

        # Feature extraction with residual blocks
        out = self.residual_blocks(x_conv)

        # Reshape to restore original bf dimension
        out = out.view(bs, f, out.size(1), out.size(2), out.size(3))

        # Permute dimensions to reorder (if needed), e.g., swap channels and feature frames
        out = out.permute(0, 2, 1, 3, 4)

        return out

    def process_camera_coordinates(
        self,
        direction: Literal["Left", "Right", "Up", "Down", "LeftUp", "LeftDown", "RightUp", "RightDown"],
        length: int,
        height: int,
        width: int,
        speed: float = 1/54,
        origin=(0, 0.532139961, 0.946026558, 0.5, 0.5, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0)
    ):
        """Generate a synthetic straight-line camera path and convert it to
        per-frame Plücker embeddings for a (length, height, width) clip."""
        if origin is None:
            origin = (0, 0.532139961, 0.946026558, 0.5, 0.5, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0)
        coordinates = generate_camera_coordinates(direction, length, speed, origin)
        plucker_embedding = process_pose_file(coordinates, width, height)
        return plucker_embedding
|
||||
|
||||
|
||||
|
||||
class ResidualBlock(nn.Module):
    """Two 3x3 convolutions with a ReLU in between and an identity skip."""

    def __init__(self, dim):
        super().__init__()
        self.conv1 = nn.Conv2d(dim, dim, kernel_size=3, padding=1)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(dim, dim, kernel_size=3, padding=1)

    def forward(self, x):
        branch = self.conv2(self.relu(self.conv1(x)))
        return branch + x
|
||||
|
||||
class Camera(object):
    """Single-frame camera parameters parsed from a pose-file row.

    Copied from https://github.com/hehao13/CameraCtrl/blob/main/inference.py
    """

    def __init__(self, entry):
        # entry[1:5] holds the (normalized) intrinsics fx, fy, cx, cy.
        self.fx, self.fy, self.cx, self.cy = entry[1:5]
        # entry[7:] is a row-major 3x4 world-to-camera matrix; promote to 4x4.
        w2c_3x4 = np.array(entry[7:]).reshape(3, 4)
        w2c = np.eye(4)
        w2c[:3, :] = w2c_3x4
        self.w2c_mat = w2c
        self.c2w_mat = np.linalg.inv(w2c)
|
||||
|
||||
def get_relative_pose(cam_params):
    """Re-express camera-to-world poses relative to the first frame.

    Copied from https://github.com/hehao13/CameraCtrl/blob/main/inference.py

    Returns:
        (n_frames, 4, 4) float32 array: the first pose is the canonical target
        pose and the remaining poses are the original c2w matrices mapped
        through the first frame's w2c.
    """
    abs_w2cs = [cam_param.w2c_mat for cam_param in cam_params]
    abs_c2ws = [cam_param.c2w_mat for cam_param in cam_params]
    cam_to_origin = 0
    # With cam_to_origin = 0 this is simply the identity pose.
    target_cam_c2w = np.array([
        [1, 0, 0, 0],
        [0, 1, 0, -cam_to_origin],
        [0, 0, 1, 0],
        [0, 0, 0, 1]
    ])
    abs2rel = target_cam_c2w @ abs_w2cs[0]
    ret_poses = [target_cam_c2w, ] + [abs2rel @ abs_c2w for abs_c2w in abs_c2ws[1:]]
    ret_poses = np.array(ret_poses, dtype=np.float32)
    return ret_poses
|
||||
|
||||
def custom_meshgrid(*args):
    """torch.meshgrid with explicit matrix ('ij') indexing; torch>=2.0 only."""
    return torch.meshgrid(*args, indexing='ij')
|
||||
|
||||
|
||||
def ray_condition(K, c2w, H, W, device):
    """Build per-pixel Plücker coordinates (moment, direction) for each camera.

    Copied from https://github.com/hehao13/CameraCtrl/blob/main/inference.py

    Args:
        K: intrinsics (fx, fy, cx, cy) in pixel units, shape (B, V, 4).
        c2w: camera-to-world matrices, shape (B, V, 4, 4).
    Returns:
        plucker: (B, V, H, W, 6) = cat(rays_o x rays_d, rays_d).
    """
    # c2w: B, V, 4, 4
    # K: B, V, 4

    B = K.shape[0]

    j, i = custom_meshgrid(
        torch.linspace(0, H - 1, H, device=device, dtype=c2w.dtype),
        torch.linspace(0, W - 1, W, device=device, dtype=c2w.dtype),
    )
    # +0.5 shifts sample points to pixel centers.
    i = i.reshape([1, 1, H * W]).expand([B, 1, H * W]) + 0.5  # [B, HxW]
    j = j.reshape([1, 1, H * W]).expand([B, 1, H * W]) + 0.5  # [B, HxW]

    fx, fy, cx, cy = K.chunk(4, dim=-1)  # B,V, 1

    # Camera-space ray directions through each pixel, then normalized.
    zs = torch.ones_like(i)  # [B, HxW]
    xs = (i - cx) / fx * zs
    ys = (j - cy) / fy * zs
    zs = zs.expand_as(ys)

    directions = torch.stack((xs, ys, zs), dim=-1)  # B, V, HW, 3
    directions = directions / directions.norm(dim=-1, keepdim=True)  # B, V, HW, 3

    # Rotate directions into world space; ray origins are the camera centers.
    rays_d = directions @ c2w[..., :3, :3].transpose(-1, -2)  # B, V, 3, HW
    rays_o = c2w[..., :3, 3]  # B, V, 3
    rays_o = rays_o[:, :, None].expand_as(rays_d)  # B, V, 3, HW
    # c2w @ directions
    rays_dxo = torch.linalg.cross(rays_o, rays_d)
    plucker = torch.cat([rays_dxo, rays_d], dim=-1)
    plucker = plucker.reshape(B, c2w.shape[1], H, W, 6)  # B, V, H, W, 6
    # plucker = plucker.permute(0, 1, 4, 2, 3)
    return plucker
|
||||
|
||||
|
||||
def process_pose_file(cam_params, width=672, height=384, original_pose_width=1280, original_pose_height=720, device='cpu', return_poses=False):
    """Convert raw camera rows into per-frame Plücker embeddings.

    Intrinsics are rescaled so that the original aspect ratio is preserved
    under a center-crop-style fit to (width, height).

    Returns:
        (n_frames, height, width, 6) Plücker tensor, or the raw rows untouched
        when ``return_poses`` is True.
    """
    if return_poses:
        return cam_params
    else:
        cam_params = [Camera(cam_param) for cam_param in cam_params]

        sample_wh_ratio = width / height
        pose_wh_ratio = original_pose_width / original_pose_height  # Assuming placeholder ratios, change as needed

        if pose_wh_ratio > sample_wh_ratio:
            # Source is wider than the target: fit height, crop width.
            resized_ori_w = height * pose_wh_ratio
            for cam_param in cam_params:
                cam_param.fx = resized_ori_w * cam_param.fx / width
        else:
            # Source is taller than the target: fit width, crop height.
            resized_ori_h = width / pose_wh_ratio
            for cam_param in cam_params:
                cam_param.fy = resized_ori_h * cam_param.fy / height

        # Denormalize intrinsics to pixel units.
        intrinsic = np.asarray([[cam_param.fx * width,
                                 cam_param.fy * height,
                                 cam_param.cx * width,
                                 cam_param.cy * height]
                                for cam_param in cam_params], dtype=np.float32)

        K = torch.as_tensor(intrinsic)[None]  # [1, 1, 4]
        c2ws = get_relative_pose(cam_params)  # Assuming this function is defined elsewhere
        c2ws = torch.as_tensor(c2ws)[None]  # [1, n_frame, 4, 4]
        plucker_embedding = ray_condition(K, c2ws, height, width, device=device)[0].permute(0, 3, 1, 2).contiguous()  # V, 6, H, W
        plucker_embedding = plucker_embedding[None]
        plucker_embedding = rearrange(plucker_embedding, "b f c h w -> b f h w c")[0]
        return plucker_embedding
|
||||
|
||||
|
||||
|
||||
def generate_camera_coordinates(
    direction: Literal["Left", "Right", "Up", "Down", "LeftUp", "LeftDown", "RightUp", "RightDown", "In", "Out"],
    length: int,
    speed: float = 1/54,
    origin=(0, 0.532139961, 0.946026558, 0.5, 0.5, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0)
):
    """Generate ``length`` camera-pose rows moving steadily in ``direction``.

    Each row is a copy of the previous one with the relevant translation
    components (indices 9 = x, 13 = y, 18 = z of the flattened pose row)
    shifted by ``speed``. Diagonal directions (e.g. "LeftUp") trigger both
    axis updates via the substring checks below.

    Fix: previously ``length <= 0`` still returned one row (the origin);
    the result now always has exactly max(length, 0) rows.

    Args:
        direction: camera motion direction.
        length: number of frames / rows to generate.
        speed: per-frame displacement.
        origin: the starting pose row.

    Returns:
        List of ``length`` pose rows (lists), starting at ``origin``.
    """
    coordinates = [list(origin)]
    while len(coordinates) < length:
        coor = coordinates[-1].copy()
        if "Left" in direction:
            coor[9] += speed
        if "Right" in direction:
            coor[9] -= speed
        if "Up" in direction:
            coor[13] += speed
        if "Down" in direction:
            coor[13] -= speed
        if "In" in direction:
            coor[18] -= speed
        if "Out" in direction:
            coor[18] += speed
        coordinates.append(coor)
    return coordinates[:length]
|
||||
772
diffsynth/models/wan_video_dit.py
Normal file
772
diffsynth/models/wan_video_dit.py
Normal file
@@ -0,0 +1,772 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import math
|
||||
from typing import Tuple, Optional
|
||||
from einops import rearrange
|
||||
from .wan_video_camera_controller import SimpleAdapter
|
||||
# Optional attention backends. Each import is probed once at module load; the
# corresponding *_AVAILABLE flag gates backend selection in flash_attention().
try:
    import flash_attn_interface
    FLASH_ATTN_3_AVAILABLE = True
except ModuleNotFoundError:
    FLASH_ATTN_3_AVAILABLE = False

try:
    import flash_attn
    FLASH_ATTN_2_AVAILABLE = True
except ModuleNotFoundError:
    FLASH_ATTN_2_AVAILABLE = False

try:
    from sageattention import sageattn
    SAGE_ATTN_AVAILABLE = True
except ModuleNotFoundError:
    SAGE_ATTN_AVAILABLE = False
|
||||
|
||||
|
||||
def flash_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, num_heads: int, compatibility_mode=False):
    """Multi-head scaled-dot-product attention with automatic backend choice.

    q/k/v: (batch, seq, num_heads * head_dim), heads flattened into channels.
    Backend priority: FlashAttention-3 > FlashAttention-2 > SageAttention >
    PyTorch SDPA; ``compatibility_mode`` forces the PyTorch SDPA fallback.
    Returns a tensor shaped like q.
    """
    if compatibility_mode:
        # Pure-PyTorch fallback (heads-first layout for SDPA).
        q = rearrange(q, "b s (n d) -> b n s d", n=num_heads)
        k = rearrange(k, "b s (n d) -> b n s d", n=num_heads)
        v = rearrange(v, "b s (n d) -> b n s d", n=num_heads)
        x = F.scaled_dot_product_attention(q, k, v)
        x = rearrange(x, "b n s d -> b s (n d)", n=num_heads)
    elif FLASH_ATTN_3_AVAILABLE:
        # flash-attn expects (batch, seq, heads, head_dim).
        q = rearrange(q, "b s (n d) -> b s n d", n=num_heads)
        k = rearrange(k, "b s (n d) -> b s n d", n=num_heads)
        v = rearrange(v, "b s (n d) -> b s n d", n=num_heads)
        x = flash_attn_interface.flash_attn_func(q, k, v)
        # Some flash-attn-3 versions return (out, lse).
        if isinstance(x, tuple):
            x = x[0]
        x = rearrange(x, "b s n d -> b s (n d)", n=num_heads)
    elif FLASH_ATTN_2_AVAILABLE:
        q = rearrange(q, "b s (n d) -> b s n d", n=num_heads)
        k = rearrange(k, "b s (n d) -> b s n d", n=num_heads)
        v = rearrange(v, "b s (n d) -> b s n d", n=num_heads)
        x = flash_attn.flash_attn_func(q, k, v)
        x = rearrange(x, "b s n d -> b s (n d)", n=num_heads)
    elif SAGE_ATTN_AVAILABLE:
        # sageattn uses the heads-first (b, n, s, d) layout.
        q = rearrange(q, "b s (n d) -> b n s d", n=num_heads)
        k = rearrange(k, "b s (n d) -> b n s d", n=num_heads)
        v = rearrange(v, "b s (n d) -> b n s d", n=num_heads)
        x = sageattn(q, k, v)
        x = rearrange(x, "b n s d -> b s (n d)", n=num_heads)
    else:
        # Default PyTorch SDPA path.
        q = rearrange(q, "b s (n d) -> b n s d", n=num_heads)
        k = rearrange(k, "b s (n d) -> b n s d", n=num_heads)
        v = rearrange(v, "b s (n d) -> b n s d", n=num_heads)
        x = F.scaled_dot_product_attention(q, k, v)
        x = rearrange(x, "b n s d -> b s (n d)", n=num_heads)
    return x
|
||||
|
||||
|
||||
def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor):
    """AdaLN-style modulation: scale x around 1, then add a shift."""
    scaled = x * (1 + scale)
    return scaled + shift
|
||||
|
||||
|
||||
def sinusoidal_embedding_1d(dim, position):
    """Classic transformer sinusoidal embedding of 1-D positions.

    Angles are computed in float64 for precision and cast back to position's
    dtype. Returns shape (len(position), dim): the first dim//2 channels are
    cosines, the remaining dim//2 are sines.
    """
    half = dim // 2
    inv_freq = torch.pow(
        10000,
        -torch.arange(half, dtype=torch.float64, device=position.device).div(half),
    )
    angles = torch.outer(position.type(torch.float64), inv_freq)
    embedding = torch.cat([torch.cos(angles), torch.sin(angles)], dim=1)
    return embedding.to(position.dtype)
|
||||
|
||||
|
||||
def precompute_freqs_cis_3d(dim: int, end: int = 1024, theta: float = 10000.0):
    """Precompute rotary tables for the frame/height/width axes of a 3-D RoPE.

    The head dim is split roughly in thirds: the h and w axes each take
    dim // 3 channels and the frame axis takes the remainder.
    """
    dim_spatial = dim // 3
    dim_temporal = dim - 2 * dim_spatial
    f_freqs_cis = precompute_freqs_cis(dim_temporal, end, theta)
    h_freqs_cis = precompute_freqs_cis(dim_spatial, end, theta)
    w_freqs_cis = precompute_freqs_cis(dim_spatial, end, theta)
    return f_freqs_cis, h_freqs_cis, w_freqs_cis
|
||||
|
||||
|
||||
def precompute_freqs_cis(dim: int, end: int = 1024, theta: float = 10000.0):
    """1-D RoPE table: complex rotations e^(i * pos * freq), shape (end, dim//2).

    Frequencies follow the standard RoPE schedule theta^(-2i/dim); angles are
    computed in float64, so the returned table is complex128.
    """
    exponents = torch.arange(0, dim, 2)[: (dim // 2)].double() / dim
    freqs = 1.0 / torch.pow(theta, exponents)
    angles = torch.outer(torch.arange(end, device=freqs.device), freqs)
    return torch.polar(torch.ones_like(angles), angles)
|
||||
|
||||
|
||||
def rope_apply(x, freqs, num_heads):
    """Apply rotary position embedding to x.

    x: (b, s, num_heads * head_dim). Consecutive channel pairs of each head
    are viewed as complex numbers and rotated by the complex table ``freqs``
    (broadcast against the (b, s, n, d/2) complex view). Math is done in
    float64 to match the precomputed table, then cast back to x's dtype.
    """
    x = rearrange(x, "b s (n d) -> b s n d", n=num_heads)
    x_out = torch.view_as_complex(x.to(torch.float64).reshape(
        x.shape[0], x.shape[1], x.shape[2], -1, 2))
    x_out = torch.view_as_real(x_out * freqs).flatten(2)
    return x_out.to(x.dtype)
|
||||
|
||||
|
||||
class RMSNorm(nn.Module):
    """Root-mean-square layer norm with a learnable gain (no mean centering)."""

    def __init__(self, dim, eps=1e-5):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def norm(self, x):
        rms_inv = torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return x * rms_inv

    def forward(self, x):
        # Normalize in float32 for stability, then restore the input dtype.
        input_dtype = x.dtype
        normalized = self.norm(x.float()).to(input_dtype)
        return normalized * self.weight
|
||||
|
||||
|
||||
class AttentionModule(nn.Module):
    """Thin module wrapper dispatching multi-head attention to flash_attention()."""

    def __init__(self, num_heads):
        super().__init__()
        self.num_heads = num_heads

    def forward(self, q, k, v):
        return flash_attention(q=q, k=k, v=v, num_heads=self.num_heads)
|
||||
|
||||
|
||||
class SelfAttention(nn.Module):
    """Multi-head self-attention with RMS-normalized Q/K and rotary embeddings."""

    def __init__(self, dim: int, num_heads: int, eps: float = 1e-6):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads

        self.q = nn.Linear(dim, dim)
        self.k = nn.Linear(dim, dim)
        self.v = nn.Linear(dim, dim)
        self.o = nn.Linear(dim, dim)
        # Q/K normalization stabilizes attention logits.
        self.norm_q = RMSNorm(dim, eps=eps)
        self.norm_k = RMSNorm(dim, eps=eps)

        self.attn = AttentionModule(self.num_heads)

    def forward(self, x, freqs):
        # freqs: precomputed rotary table matching x's token layout.
        q = self.norm_q(self.q(x))
        k = self.norm_k(self.k(x))
        v = self.v(x)
        # Rotary position embedding is applied to queries/keys only.
        q = rope_apply(q, freqs, self.num_heads)
        k = rope_apply(k, freqs, self.num_heads)
        x = self.attn(q, k, v)
        return self.o(x)
|
||||
|
||||
|
||||
class CrossAttention(nn.Module):
    """Cross-attention to the text context, optionally with CLIP image tokens.

    When ``has_image_input`` is True, the first 257 context tokens are CLIP
    image features attended through separate K/V projections, and that
    attention output is added to the text-attention output.
    """

    def __init__(self, dim: int, num_heads: int, eps: float = 1e-6, has_image_input: bool = False):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads

        self.q = nn.Linear(dim, dim)
        self.k = nn.Linear(dim, dim)
        self.v = nn.Linear(dim, dim)
        self.o = nn.Linear(dim, dim)
        self.norm_q = RMSNorm(dim, eps=eps)
        self.norm_k = RMSNorm(dim, eps=eps)
        self.has_image_input = has_image_input
        if has_image_input:
            # Separate K/V path for the CLIP image tokens.
            self.k_img = nn.Linear(dim, dim)
            self.v_img = nn.Linear(dim, dim)
            self.norm_k_img = RMSNorm(dim, eps=eps)

        self.attn = AttentionModule(self.num_heads)

    def forward(self, x: torch.Tensor, y: torch.Tensor):
        """x: latent queries; y: context (first 257 tokens = image features
        when has_image_input, remainder = text features)."""
        if self.has_image_input:
            img = y[:, :257]
            ctx = y[:, 257:]
        else:
            ctx = y
        q = self.norm_q(self.q(x))
        k = self.norm_k(self.k(ctx))
        v = self.v(ctx)
        x = self.attn(q, k, v)
        if self.has_image_input:
            # Image attention result is added to the text attention result.
            k_img = self.norm_k_img(self.k_img(img))
            v_img = self.v_img(img)
            y = flash_attention(q, k_img, v_img, num_heads=self.num_heads)
            x = x + y
        return self.o(x)
|
||||
|
||||
|
||||
class GateModule(nn.Module):
    """Gated residual update: returns x + gate * residual."""

    def __init__(self):
        super().__init__()

    def forward(self, x, gate, residual):
        gated = gate * residual
        return x + gated
|
||||
|
||||
class DiTBlock(nn.Module):
    """One Wan DiT block: AdaLN-modulated self-attention, cross-attention to
    the text/image context, and an AdaLN-modulated feed-forward network."""

    def __init__(self, has_image_input: bool, dim: int, num_heads: int, ffn_dim: int, eps: float = 1e-6):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.ffn_dim = ffn_dim

        self.self_attn = SelfAttention(dim, num_heads, eps)
        self.cross_attn = CrossAttention(
            dim, num_heads, eps, has_image_input=has_image_input)
        self.norm1 = nn.LayerNorm(dim, eps=eps, elementwise_affine=False)
        self.norm2 = nn.LayerNorm(dim, eps=eps, elementwise_affine=False)
        self.norm3 = nn.LayerNorm(dim, eps=eps)
        self.ffn = nn.Sequential(nn.Linear(dim, ffn_dim), nn.GELU(
            approximate='tanh'), nn.Linear(ffn_dim, dim))
        # Learned base modulation, added to the timestep-derived modulation.
        self.modulation = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5)
        self.gate = GateModule()

    def forward(self, x, context, t_mod, freqs):
        # t_mod is either (B, 6, dim) or per-token 4-D (with an extra axis
        # that is squeezed out after chunking).
        has_seq = len(t_mod.shape) == 4
        chunk_dim = 2 if has_seq else 1
        # msa: multi-head self-attention mlp: multi-layer perceptron
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
            self.modulation.to(dtype=t_mod.dtype, device=t_mod.device) + t_mod).chunk(6, dim=chunk_dim)
        if has_seq:
            shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
                shift_msa.squeeze(2), scale_msa.squeeze(2), gate_msa.squeeze(2),
                shift_mlp.squeeze(2), scale_mlp.squeeze(2), gate_mlp.squeeze(2),
            )
        # Modulated self-attention with gated residual.
        input_x = modulate(self.norm1(x), shift_msa, scale_msa)
        x = self.gate(x, gate_msa, self.self_attn(input_x, freqs))
        # Plain (ungated) residual for cross-attention.
        x = x + self.cross_attn(self.norm3(x), context)
        # Modulated FFN with gated residual.
        input_x = modulate(self.norm2(x), shift_mlp, scale_mlp)
        x = self.gate(x, gate_mlp, self.ffn(input_x))
        return x
|
||||
|
||||
|
||||
class MLP(torch.nn.Module):
    """Projects CLIP image embeddings into the DiT hidden space.

    Optionally adds a learned positional embedding (zeros-initialized, fixed
    shape (1, 514, 1280)) to the input before the projection MLP.
    """

    def __init__(self, in_dim, out_dim, has_pos_emb=False):
        super().__init__()
        # LayerNorm -> Linear -> GELU -> Linear -> LayerNorm.
        self.proj = torch.nn.Sequential(
            nn.LayerNorm(in_dim),
            nn.Linear(in_dim, in_dim),
            nn.GELU(),
            nn.Linear(in_dim, out_dim),
            nn.LayerNorm(out_dim),
        )
        self.has_pos_emb = has_pos_emb
        if has_pos_emb:
            self.emb_pos = torch.nn.Parameter(torch.zeros((1, 514, 1280)))

    def forward(self, x):
        if self.has_pos_emb:
            pos = self.emb_pos.to(dtype=x.dtype, device=x.device)
            x = x + pos
        return self.proj(x)
|
||||
|
||||
|
||||
class Head(nn.Module):
    """Final DiT head: AdaLN-modulated LayerNorm followed by a linear
    projection to out_dim * prod(patch_size) channels per token."""

    def __init__(self, dim: int, out_dim: int, patch_size: Tuple[int, int, int], eps: float):
        super().__init__()
        self.dim = dim
        self.patch_size = patch_size
        self.norm = nn.LayerNorm(dim, eps=eps, elementwise_affine=False)
        self.head = nn.Linear(dim, out_dim * math.prod(patch_size))
        # Learned base shift/scale added to the timestep modulation.
        self.modulation = nn.Parameter(torch.randn(1, 2, dim) / dim**0.5)

    def forward(self, x, t_mod):
        # Per-token modulation: 3-D t_mod is (B, L, dim), broadcast against
        # the (1, 1, 2, dim) base modulation before chunking.
        if len(t_mod.shape) == 3:
            shift, scale = (self.modulation.unsqueeze(0).to(dtype=t_mod.dtype, device=t_mod.device) + t_mod.unsqueeze(2)).chunk(2, dim=2)
            x = (self.head(self.norm(x) * (1 + scale.squeeze(2)) + shift.squeeze(2)))
        else:
            # Global modulation: t_mod broadcasts against (1, 2, dim).
            shift, scale = (self.modulation.to(dtype=t_mod.dtype, device=t_mod.device) + t_mod).chunk(2, dim=1)
            x = (self.head(self.norm(x) * (1 + scale) + shift))
        return x
|
||||
|
||||
|
||||
class WanModel(torch.nn.Module):
    """Wan video diffusion transformer (DiT).

    Patchifies a latent video with a 3D convolution, flattens it to a token
    sequence, conditions on text embeddings (plus optional CLIP image
    features and auxiliary VAE latents ``y``), runs ``num_layers`` DiT blocks
    with 3D rotary position embeddings, and projects tokens back to latent
    patches.
    """

    def __init__(
        self,
        dim: int,
        in_dim: int,
        ffn_dim: int,
        out_dim: int,
        text_dim: int,
        freq_dim: int,
        eps: float,
        patch_size: Tuple[int, int, int],
        num_heads: int,
        num_layers: int,
        has_image_input: bool,
        has_image_pos_emb: bool = False,
        has_ref_conv: bool = False,
        add_control_adapter: bool = False,
        in_dim_control_adapter: int = 24,
        seperated_timestep: bool = False,
        require_vae_embedding: bool = True,
        require_clip_embedding: bool = True,
        fuse_vae_embedding_in_latents: bool = False,
    ):
        super().__init__()
        self.dim = dim
        self.in_dim = in_dim
        self.freq_dim = freq_dim
        self.has_image_input = has_image_input
        self.patch_size = patch_size
        # "seperated" (sic) is kept as-is for checkpoint/caller compatibility.
        self.seperated_timestep = seperated_timestep
        self.require_vae_embedding = require_vae_embedding
        self.require_clip_embedding = require_clip_embedding
        self.fuse_vae_embedding_in_latents = fuse_vae_embedding_in_latents

        self.patch_embedding = nn.Conv3d(
            in_dim, dim, kernel_size=patch_size, stride=patch_size)
        self.text_embedding = nn.Sequential(
            nn.Linear(text_dim, dim),
            nn.GELU(approximate='tanh'),
            nn.Linear(dim, dim)
        )
        self.time_embedding = nn.Sequential(
            nn.Linear(freq_dim, dim),
            nn.SiLU(),
            nn.Linear(dim, dim)
        )
        # Projects the time embedding into the 6 per-block modulation vectors.
        self.time_projection = nn.Sequential(
            nn.SiLU(), nn.Linear(dim, dim * 6))
        self.blocks = nn.ModuleList([
            DiTBlock(has_image_input, dim, num_heads, ffn_dim, eps)
            for _ in range(num_layers)
        ])
        self.head = Head(dim, out_dim, patch_size, eps)
        head_dim = dim // num_heads
        # Tuple of per-axis (frame, height, width) RoPE frequency tables.
        self.freqs = precompute_freqs_cis_3d(head_dim)

        if has_image_input:
            self.img_emb = MLP(1280, dim, has_pos_emb=has_image_pos_emb)  # clip_feature_dim = 1280
        if has_ref_conv:
            self.ref_conv = nn.Conv2d(16, dim, kernel_size=(2, 2), stride=(2, 2))
        self.has_image_pos_emb = has_image_pos_emb
        self.has_ref_conv = has_ref_conv
        if add_control_adapter:
            self.control_adapter = SimpleAdapter(in_dim_control_adapter, dim, kernel_size=patch_size[1:], stride=patch_size[1:])
        else:
            self.control_adapter = None

    def patchify(self, x: torch.Tensor, control_camera_latents_input: Optional[torch.Tensor] = None):
        """Convert a latent video (b, c, f, h, w) into a token sequence.

        Returns:
            (tokens, grid_size): tokens of shape (b, f*h*w, dim), and the
            post-patch grid size (f, h, w) needed by ``unpatchify``.
        """
        x = self.patch_embedding(x)
        if self.control_adapter is not None and control_camera_latents_input is not None:
            y_camera = self.control_adapter(control_camera_latents_input)
            x = [u + v for u, v in zip(x, y_camera)]
            x = x[0].unsqueeze(0)
        # Bug fix: forward() unpacks `x, (f, h, w) = self.patchify(x)`, but the
        # previous implementation returned only the 5D conv output. Return the
        # grid size and flatten frame/height/width into one token axis.
        grid_size = x.shape[2:]
        x = rearrange(x, 'b c f h w -> b (f h w) c').contiguous()
        return x, grid_size

    def unpatchify(self, x: torch.Tensor, grid_size: torch.Tensor):
        """Inverse of ``patchify``: fold tokens back into a latent video."""
        return rearrange(
            x, 'b (f h w) (x y z c) -> b c (f x) (h y) (w z)',
            f=grid_size[0], h=grid_size[1], w=grid_size[2],
            x=self.patch_size[0], y=self.patch_size[1], z=self.patch_size[2]
        )

    def forward(self,
                x: torch.Tensor,
                timestep: torch.Tensor,
                context: torch.Tensor,
                clip_feature: Optional[torch.Tensor] = None,
                y: Optional[torch.Tensor] = None,
                use_gradient_checkpointing: bool = False,
                use_gradient_checkpointing_offload: bool = False,
                **kwargs,
                ):
        """Run one denoising step.

        Args:
            x: latent video, (b, c, f, h, w).
            timestep: diffusion timestep(s) for the batch.
            context: text-encoder embeddings, (b, seq, text_dim).
            clip_feature: CLIP image features; required when
                ``has_image_input`` is True.
            y: auxiliary VAE latents concatenated channel-wise onto ``x``
                when ``has_image_input`` is True.
            use_gradient_checkpointing: checkpoint each block in training.
            use_gradient_checkpointing_offload: additionally keep saved
                activations on CPU (slower, lower VRAM).
        """
        t = self.time_embedding(
            sinusoidal_embedding_1d(self.freq_dim, timestep).to(x.dtype))
        t_mod = self.time_projection(t).unflatten(1, (6, self.dim))
        context = self.text_embedding(context)

        if self.has_image_input:
            x = torch.cat([x, y], dim=1)  # (b, c_x + c_y, f, h, w)
            clip_embedding = self.img_emb(clip_feature)
            context = torch.cat([clip_embedding, context], dim=1)

        x, (f, h, w) = self.patchify(x)

        # 3D RoPE table for this grid: one frequency row per (f, h, w) token.
        freqs = torch.cat([
            self.freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
            self.freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
            self.freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
        ], dim=-1).reshape(f * h * w, 1, -1).to(x.device)

        def create_custom_forward(module):
            def custom_forward(*inputs):
                return module(*inputs)
            return custom_forward

        for block in self.blocks:
            if self.training and use_gradient_checkpointing:
                if use_gradient_checkpointing_offload:
                    # Trade speed for VRAM: saved tensors live on CPU.
                    with torch.autograd.graph.save_on_cpu():
                        x = torch.utils.checkpoint.checkpoint(
                            create_custom_forward(block),
                            x, context, t_mod, freqs,
                            use_reentrant=False,
                        )
                else:
                    x = torch.utils.checkpoint.checkpoint(
                        create_custom_forward(block),
                        x, context, t_mod, freqs,
                        use_reentrant=False,
                    )
            else:
                x = block(x, context, t_mod, freqs)

        x = self.head(x, t)
        x = self.unpatchify(x, (f, h, w))
        return x

    @staticmethod
    def state_dict_converter():
        return WanModelStateDictConverter()
|
||||
|
||||
|
||||
class WanModelStateDictConverter:
    """Maps external Wan DiT checkpoints onto this module's parameter names.

    Two layouts are supported: the Diffusers layout (``from_diffusers``) and
    the original/civitai layout (``from_civitai``). Both return
    ``(state_dict, config)`` where ``config`` holds the ``WanModel``
    constructor kwargs inferred from a hash of the state-dict keys (empty
    dict when the checkpoint variant is unrecognized).
    """

    def __init__(self):
        pass

    def from_diffusers(self, state_dict):
        """Rename Diffusers-layout keys to this module's names.

        The rename table is written out for block 0 only; the loop below
        reuses it for every other block index by substituting the index.
        """
        rename_dict = {
            "blocks.0.attn1.norm_k.weight": "blocks.0.self_attn.norm_k.weight",
            "blocks.0.attn1.norm_q.weight": "blocks.0.self_attn.norm_q.weight",
            "blocks.0.attn1.to_k.bias": "blocks.0.self_attn.k.bias",
            "blocks.0.attn1.to_k.weight": "blocks.0.self_attn.k.weight",
            "blocks.0.attn1.to_out.0.bias": "blocks.0.self_attn.o.bias",
            "blocks.0.attn1.to_out.0.weight": "blocks.0.self_attn.o.weight",
            "blocks.0.attn1.to_q.bias": "blocks.0.self_attn.q.bias",
            "blocks.0.attn1.to_q.weight": "blocks.0.self_attn.q.weight",
            "blocks.0.attn1.to_v.bias": "blocks.0.self_attn.v.bias",
            "blocks.0.attn1.to_v.weight": "blocks.0.self_attn.v.weight",
            "blocks.0.attn2.norm_k.weight": "blocks.0.cross_attn.norm_k.weight",
            "blocks.0.attn2.norm_q.weight": "blocks.0.cross_attn.norm_q.weight",
            "blocks.0.attn2.to_k.bias": "blocks.0.cross_attn.k.bias",
            "blocks.0.attn2.to_k.weight": "blocks.0.cross_attn.k.weight",
            "blocks.0.attn2.to_out.0.bias": "blocks.0.cross_attn.o.bias",
            "blocks.0.attn2.to_out.0.weight": "blocks.0.cross_attn.o.weight",
            "blocks.0.attn2.to_q.bias": "blocks.0.cross_attn.q.bias",
            "blocks.0.attn2.to_q.weight": "blocks.0.cross_attn.q.weight",
            "blocks.0.attn2.to_v.bias": "blocks.0.cross_attn.v.bias",
            "blocks.0.attn2.to_v.weight": "blocks.0.cross_attn.v.weight",
            # Image-branch cross-attention (I2V checkpoints only).
            "blocks.0.attn2.add_k_proj.bias":"blocks.0.cross_attn.k_img.bias",
            "blocks.0.attn2.add_k_proj.weight":"blocks.0.cross_attn.k_img.weight",
            "blocks.0.attn2.add_v_proj.bias":"blocks.0.cross_attn.v_img.bias",
            "blocks.0.attn2.add_v_proj.weight":"blocks.0.cross_attn.v_img.weight",
            "blocks.0.attn2.norm_added_k.weight":"blocks.0.cross_attn.norm_k_img.weight",
            "blocks.0.ffn.net.0.proj.bias": "blocks.0.ffn.0.bias",
            "blocks.0.ffn.net.0.proj.weight": "blocks.0.ffn.0.weight",
            "blocks.0.ffn.net.2.bias": "blocks.0.ffn.2.bias",
            "blocks.0.ffn.net.2.weight": "blocks.0.ffn.2.weight",
            "blocks.0.norm2.bias": "blocks.0.norm3.bias",
            "blocks.0.norm2.weight": "blocks.0.norm3.weight",
            "blocks.0.scale_shift_table": "blocks.0.modulation",
            "condition_embedder.text_embedder.linear_1.bias": "text_embedding.0.bias",
            "condition_embedder.text_embedder.linear_1.weight": "text_embedding.0.weight",
            "condition_embedder.text_embedder.linear_2.bias": "text_embedding.2.bias",
            "condition_embedder.text_embedder.linear_2.weight": "text_embedding.2.weight",
            "condition_embedder.time_embedder.linear_1.bias": "time_embedding.0.bias",
            "condition_embedder.time_embedder.linear_1.weight": "time_embedding.0.weight",
            "condition_embedder.time_embedder.linear_2.bias": "time_embedding.2.bias",
            "condition_embedder.time_embedder.linear_2.weight": "time_embedding.2.weight",
            "condition_embedder.time_proj.bias": "time_projection.1.bias",
            "condition_embedder.time_proj.weight": "time_projection.1.weight",
            "condition_embedder.image_embedder.ff.net.0.proj.bias":"img_emb.proj.1.bias",
            "condition_embedder.image_embedder.ff.net.0.proj.weight":"img_emb.proj.1.weight",
            "condition_embedder.image_embedder.ff.net.2.bias":"img_emb.proj.3.bias",
            "condition_embedder.image_embedder.ff.net.2.weight":"img_emb.proj.3.weight",
            "condition_embedder.image_embedder.norm1.bias":"img_emb.proj.0.bias",
            "condition_embedder.image_embedder.norm1.weight":"img_emb.proj.0.weight",
            "condition_embedder.image_embedder.norm2.bias":"img_emb.proj.4.bias",
            "condition_embedder.image_embedder.norm2.weight":"img_emb.proj.4.weight",
            "patch_embedding.bias": "patch_embedding.bias",
            "patch_embedding.weight": "patch_embedding.weight",
            "scale_shift_table": "head.modulation",
            "proj_out.bias": "head.head.bias",
            "proj_out.weight": "head.head.weight",
        }
        state_dict_ = {}
        for name, param in state_dict.items():
            if name in rename_dict:
                state_dict_[rename_dict[name]] = param
            else:
                # Normalize "blocks.<i>." to "blocks.0." so the table applies,
                # then restore the original block index in the renamed key.
                # Keys that still don't match are silently dropped.
                name_ = ".".join(name.split(".")[:1] + ["0"] + name.split(".")[2:])
                if name_ in rename_dict:
                    name_ = rename_dict[name_]
                    name_ = ".".join(name_.split(".")[:1] + [name.split(".")[1]] + name_.split(".")[2:])
                    state_dict_[name_] = param
        # Infer the architecture from the (renamed) key-set hash.
        if hash_state_dict_keys(state_dict_) == "cb104773c6c2cb6df4f9529ad5c60d0b":
            # Wan 14B text-to-video.
            config = {
                "model_type": "t2v",
                "patch_size": (1, 2, 2),
                "text_len": 512,
                "in_dim": 16,
                "dim": 5120,
                "ffn_dim": 13824,
                "freq_dim": 256,
                "text_dim": 4096,
                "out_dim": 16,
                "num_heads": 40,
                "num_layers": 40,
                "window_size": (-1, -1),
                "qk_norm": True,
                "cross_attn_norm": True,
                "eps": 1e-6,
            }
        elif hash_state_dict_keys(state_dict_) == "6bfcfb3b342cb286ce886889d519a77e":
            # Wan 14B image-to-video (extra VAE channels: in_dim 36).
            config = {
                "has_image_input": True,
                "patch_size": [1, 2, 2],
                "in_dim": 36,
                "dim": 5120,
                "ffn_dim": 13824,
                "freq_dim": 256,
                "text_dim": 4096,
                "out_dim": 16,
                "num_heads": 40,
                "num_layers": 40,
                "eps": 1e-6
            }
        else:
            config = {}
        return state_dict_, config

    def from_civitai(self, state_dict):
        """Strip non-DiT weights and the "model." prefix from an original-layout
        checkpoint, then infer the architecture from the key-set hash."""
        # Drop VACE and auxiliary-module weights that belong to other models.
        state_dict = {name: param for name, param in state_dict.items() if not name.startswith("vace")}
        state_dict = {name: param for name, param in state_dict.items() if name.split(".")[0] not in ["pose_patch_embedding", "face_adapter", "face_encoder", "motion_encoder"]}
        state_dict_ = {}
        for name, param in state_dict.items():
            if name.startswith("model."):
                name = name[len("model."):]
            state_dict_[name] = param
        state_dict = state_dict_
        if hash_state_dict_keys(state_dict) == "9269f8db9040a9d860eaca435be61814":
            # Wan 1.3B text-to-video.
            config = {
                "has_image_input": False,
                "patch_size": [1, 2, 2],
                "in_dim": 16,
                "dim": 1536,
                "ffn_dim": 8960,
                "freq_dim": 256,
                "text_dim": 4096,
                "out_dim": 16,
                "num_heads": 12,
                "num_layers": 30,
                "eps": 1e-6
            }
        elif hash_state_dict_keys(state_dict) == "aafcfd9672c3a2456dc46e1cb6e52c70":
            # Wan 14B text-to-video.
            config = {
                "has_image_input": False,
                "patch_size": [1, 2, 2],
                "in_dim": 16,
                "dim": 5120,
                "ffn_dim": 13824,
                "freq_dim": 256,
                "text_dim": 4096,
                "out_dim": 16,
                "num_heads": 40,
                "num_layers": 40,
                "eps": 1e-6
            }
        elif hash_state_dict_keys(state_dict) == "6bfcfb3b342cb286ce886889d519a77e":
            # Wan 14B image-to-video.
            config = {
                "has_image_input": True,
                "patch_size": [1, 2, 2],
                "in_dim": 36,
                "dim": 5120,
                "ffn_dim": 13824,
                "freq_dim": 256,
                "text_dim": 4096,
                "out_dim": 16,
                "num_heads": 40,
                "num_layers": 40,
                "eps": 1e-6
            }
        elif hash_state_dict_keys(state_dict) == "6d6ccde6845b95ad9114ab993d917893":
            # Wan 1.3B image-to-video.
            config = {
                "has_image_input": True,
                "patch_size": [1, 2, 2],
                "in_dim": 36,
                "dim": 1536,
                "ffn_dim": 8960,
                "freq_dim": 256,
                "text_dim": 4096,
                "out_dim": 16,
                "num_heads": 12,
                "num_layers": 30,
                "eps": 1e-6
            }
        elif hash_state_dict_keys(state_dict) == "349723183fc063b2bfc10bb2835cf677":
            # 1.3B PAI control
            config = {
                "has_image_input": True,
                "patch_size": [1, 2, 2],
                "in_dim": 48,
                "dim": 1536,
                "ffn_dim": 8960,
                "freq_dim": 256,
                "text_dim": 4096,
                "out_dim": 16,
                "num_heads": 12,
                "num_layers": 30,
                "eps": 1e-6
            }
        elif hash_state_dict_keys(state_dict) == "efa44cddf936c70abd0ea28b6cbe946c":
            # 14B PAI control
            config = {
                "has_image_input": True,
                "patch_size": [1, 2, 2],
                "in_dim": 48,
                "dim": 5120,
                "ffn_dim": 13824,
                "freq_dim": 256,
                "text_dim": 4096,
                "out_dim": 16,
                "num_heads": 40,
                "num_layers": 40,
                "eps": 1e-6
            }
        elif hash_state_dict_keys(state_dict) == "3ef3b1f8e1dab83d5b71fd7b617f859f":
            # 14B image-to-video variant with image positional embedding.
            config = {
                "has_image_input": True,
                "patch_size": [1, 2, 2],
                "in_dim": 36,
                "dim": 5120,
                "ffn_dim": 13824,
                "freq_dim": 256,
                "text_dim": 4096,
                "out_dim": 16,
                "num_heads": 40,
                "num_layers": 40,
                "eps": 1e-6,
                "has_image_pos_emb": True
            }
        elif hash_state_dict_keys(state_dict) == "70ddad9d3a133785da5ea371aae09504":
            # 1.3B PAI control v1.1
            config = {
                "has_image_input": True,
                "patch_size": [1, 2, 2],
                "in_dim": 48,
                "dim": 1536,
                "ffn_dim": 8960,
                "freq_dim": 256,
                "text_dim": 4096,
                "out_dim": 16,
                "num_heads": 12,
                "num_layers": 30,
                "eps": 1e-6,
                "has_ref_conv": True
            }
        elif hash_state_dict_keys(state_dict) == "26bde73488a92e64cc20b0a7485b9e5b":
            # 14B PAI control v1.1
            config = {
                "has_image_input": True,
                "patch_size": [1, 2, 2],
                "in_dim": 48,
                "dim": 5120,
                "ffn_dim": 13824,
                "freq_dim": 256,
                "text_dim": 4096,
                "out_dim": 16,
                "num_heads": 40,
                "num_layers": 40,
                "eps": 1e-6,
                "has_ref_conv": True
            }
        elif hash_state_dict_keys(state_dict) == "ac6a5aa74f4a0aab6f64eb9a72f19901":
            # 1.3B PAI control-camera v1.1
            config = {
                "has_image_input": True,
                "patch_size": [1, 2, 2],
                "in_dim": 32,
                "dim": 1536,
                "ffn_dim": 8960,
                "freq_dim": 256,
                "text_dim": 4096,
                "out_dim": 16,
                "num_heads": 12,
                "num_layers": 30,
                "eps": 1e-6,
                "has_ref_conv": False,
                "add_control_adapter": True,
                "in_dim_control_adapter": 24,
            }
        elif hash_state_dict_keys(state_dict) == "b61c605c2adbd23124d152ed28e049ae":
            # 14B PAI control-camera v1.1
            config = {
                "has_image_input": True,
                "patch_size": [1, 2, 2],
                "in_dim": 32,
                "dim": 5120,
                "ffn_dim": 13824,
                "freq_dim": 256,
                "text_dim": 4096,
                "out_dim": 16,
                "num_heads": 40,
                "num_layers": 40,
                "eps": 1e-6,
                "has_ref_conv": False,
                "add_control_adapter": True,
                "in_dim_control_adapter": 24,
            }
        elif hash_state_dict_keys(state_dict) == "1f5ab7703c6fc803fdded85ff040c316":
            # Wan-AI/Wan2.2-TI2V-5B
            config = {
                "has_image_input": False,
                "patch_size": [1, 2, 2],
                "in_dim": 48,
                "dim": 3072,
                "ffn_dim": 14336,
                "freq_dim": 256,
                "text_dim": 4096,
                "out_dim": 48,
                "num_heads": 24,
                "num_layers": 30,
                "eps": 1e-6,
                "seperated_timestep": True,
                "require_clip_embedding": False,
                "require_vae_embedding": False,
                "fuse_vae_embedding_in_latents": True,
            }
        elif hash_state_dict_keys(state_dict) == "5b013604280dd715f8457c6ed6d6a626":
            # Wan-AI/Wan2.2-I2V-A14B
            config = {
                "has_image_input": False,
                "patch_size": [1, 2, 2],
                "in_dim": 36,
                "dim": 5120,
                "ffn_dim": 13824,
                "freq_dim": 256,
                "text_dim": 4096,
                "out_dim": 16,
                "num_heads": 40,
                "num_layers": 40,
                "eps": 1e-6,
                "require_clip_embedding": False,
            }
        elif hash_state_dict_keys(state_dict) == "2267d489f0ceb9f21836532952852ee5":
            # Wan2.2-Fun-A14B-Control
            config = {
                "has_image_input": False,
                "patch_size": [1, 2, 2],
                "in_dim": 52,
                "dim": 5120,
                "ffn_dim": 13824,
                "freq_dim": 256,
                "text_dim": 4096,
                "out_dim": 16,
                "num_heads": 40,
                "num_layers": 40,
                "eps": 1e-6,
                "has_ref_conv": True,
                "require_clip_embedding": False,
            }
        elif hash_state_dict_keys(state_dict) == "47dbeab5e560db3180adf51dc0232fb1":
            # Wan2.2-Fun-A14B-Control-Camera
            config = {
                "has_image_input": False,
                "patch_size": [1, 2, 2],
                "in_dim": 36,
                "dim": 5120,
                "ffn_dim": 13824,
                "freq_dim": 256,
                "text_dim": 4096,
                "out_dim": 16,
                "num_heads": 40,
                "num_layers": 40,
                "eps": 1e-6,
                "has_ref_conv": False,
                "add_control_adapter": True,
                "in_dim_control_adapter": 24,
                "require_clip_embedding": False,
            }
        else:
            # Unknown variant: caller must supply the config explicitly.
            config = {}
        return state_dict, config
|
||||
594
diffsynth/models/wan_video_dit_s2v.py
Normal file
594
diffsynth/models/wan_video_dit_s2v.py
Normal file
@@ -0,0 +1,594 @@
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from typing import Tuple
|
||||
from .wan_video_dit import rearrange, precompute_freqs_cis_3d, DiTBlock, Head, CrossAttention, modulate, sinusoidal_embedding_1d
|
||||
|
||||
|
||||
def torch_dfs(model: nn.Module, parent_name='root'):
    """Pre-order depth-first enumeration of a module tree.

    Returns ``(modules, names)`` where ``modules[i]`` is the i-th module in
    DFS order and ``names[i]`` is its dotted path rooted at ``parent_name``
    (falling back to ``'root'`` when ``parent_name`` is empty).
    """
    root_label = parent_name if parent_name else 'root'
    modules = [model]
    names = [root_label]

    for child_key, child in model.named_children():
        qualified = f'{parent_name}.{child_key}' if parent_name else child_key
        sub_modules, sub_names = torch_dfs(child, qualified)
        modules.extend(sub_modules)
        names.extend(sub_names)
    return modules, names
|
||||
|
||||
|
||||
def rope_precompute(x, grid_sizes, freqs, start=None):
    """Precompute complex rotary-position-embedding factors for a token layout.

    Each entry of ``grid_sizes`` describes one contiguous token span as a
    triple of per-sample (f, h, w) tensors: [start grid, end grid, target
    grid]; a bare tensor is treated as an end grid starting at zero. The
    precomputed per-token complex frequencies are written span-by-span into
    a complex128 view shaped like ``x``.

    Args:
        x: (b, s, n, c) tensor used only for its shape/device; c must be even.
        grid_sizes: grid descriptor (tensor, list of tensors, or list of
            [start, end, target] triples).
        freqs: per-axis frequency table, or [table, trainable_freqs].
        start: optional explicit per-sample start offsets overriding g[0].

    Returns:
        Complex tensor of shape (b, s, n, c // 2) holding the RoPE factors.
    """
    b, s, n, c = x.size(0), x.size(1), x.size(2), x.size(3) // 2

    # split freqs
    if type(freqs) is list:
        trainable_freqs = freqs[1]
        freqs = freqs[0]
    # Split per-axis frequency columns: frame gets the remainder, h/w get c//3.
    freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)

    # loop over samples
    output = torch.view_as_complex(x.detach().reshape(b, s, n, -1, 2).to(torch.float64))
    seq_bucket = [0]  # running token offsets of the spans written so far
    if not type(grid_sizes) is list:
        grid_sizes = [grid_sizes]
    for g in grid_sizes:
        if not type(g) is list:
            # NOTE(review): g[2] is read below, but a bare-tensor g only gets
            # two entries here — looks like this should be [zeros, g, g];
            # confirm that callers always pass list-form grids.
            g = [torch.zeros_like(g), g]
        batch_size = g[0].shape[0]
        for i in range(batch_size):
            if start is None:
                f_o, h_o, w_o = g[0][i]
            else:
                f_o, h_o, w_o = start[i]

            f, h, w = g[1][i]
            t_f, t_h, t_w = g[2][i]
            seq_f, seq_h, seq_w = f - f_o, h - h_o, w - w_o
            seq_len = int(seq_f * seq_h * seq_w)
            if seq_len > 0:
                if t_f > 0:
                    # factor_* are computed but unused below (kept as-is).
                    factor_f, factor_h, factor_w = (t_f / seq_f).item(), (t_h / seq_h).item(), (t_w / seq_w).item()
                    # Generate a list of seq_f integers starting from f_o and ending at math.ceil(factor_f * seq_f.item() + f_o.item())
                    if f_o >= 0:
                        f_sam = np.linspace(f_o.item(), (t_f + f_o).item() - 1, seq_f).astype(int).tolist()
                    else:
                        # Negative frame offsets index with mirrored positions
                        # and use conjugated frequencies below.
                        f_sam = np.linspace(-f_o.item(), (-t_f - f_o).item() + 1, seq_f).astype(int).tolist()
                    h_sam = np.linspace(h_o.item(), (t_h + h_o).item() - 1, seq_h).astype(int).tolist()
                    w_sam = np.linspace(w_o.item(), (t_w + w_o).item() - 1, seq_w).astype(int).tolist()

                    assert f_o * f >= 0 and h_o * h >= 0 and w_o * w >= 0
                    freqs_0 = freqs[0][f_sam] if f_o >= 0 else freqs[0][f_sam].conj()
                    freqs_0 = freqs_0.view(seq_f, 1, 1, -1)

                    # Outer-product expansion of the three axes into one
                    # (seq_len, 1, c) frequency row per token.
                    freqs_i = torch.cat(
                        [
                            freqs_0.expand(seq_f, seq_h, seq_w, -1),
                            freqs[1][h_sam].view(1, seq_h, 1, -1).expand(seq_f, seq_h, seq_w, -1),
                            freqs[2][w_sam].view(1, 1, seq_w, -1).expand(seq_f, seq_h, seq_w, -1),
                        ],
                        dim=-1
                    ).reshape(seq_len, 1, -1)
                elif t_f < 0:
                    # Learned (trainable) frequencies for special spans.
                    # NOTE(review): t_f == 0 would leave freqs_i unset from a
                    # previous iteration — presumably unreachable; confirm.
                    freqs_i = trainable_freqs.unsqueeze(1)
                # apply rotary embedding
                output[i, seq_bucket[-1]:seq_bucket[-1] + seq_len] = freqs_i
                seq_bucket.append(seq_bucket[-1] + seq_len)
    return output
|
||||
|
||||
|
||||
class CausalConv1d(nn.Module):
    """1-D convolution made causal by left-padding the time axis.

    The input is padded with ``kernel_size - 1`` steps on the left (past)
    only, so each output step depends only on current and past inputs.
    NOTE(review): the padding ignores ``dilation``; causal only for
    dilation=1 — confirm callers never pass dilation > 1.
    """

    def __init__(self, chan_in, chan_out, kernel_size=3, stride=1, dilation=1, pad_mode='replicate', **kwargs):
        super().__init__()
        self.pad_mode = pad_mode
        # (left, right) padding along time; only the past side is padded.
        self.time_causal_padding = (kernel_size - 1, 0)
        self.conv = nn.Conv1d(chan_in, chan_out, kernel_size, stride=stride, dilation=dilation, **kwargs)

    def forward(self, x):
        padded = F.pad(x, self.time_causal_padding, mode=self.pad_mode)
        return self.conv(padded)
|
||||
|
||||
|
||||
class MotionEncoder_tc(nn.Module):
    """Temporal-causal motion/audio feature encoder.

    A stack of causal 1-D convolutions with LayerNorm + SiLU that downsamples
    the time axis 4x (two stride-2 convs). The "local" branch produces
    ``num_heads`` per-frame tokens (plus one learned padding token); the
    optional "global" branch produces a single pooled token stream passed
    through a final linear layer.
    """

    # NOTE(review): the default ``num_heads=int`` binds the *type* int as the
    # default value — calling without num_heads would fail; presumably every
    # caller passes it explicitly. Kept as-is for interface compatibility.
    def __init__(self, in_dim: int, hidden_dim: int, num_heads=int, need_global=True, dtype=None, device=None):
        factory_kwargs = {"dtype": dtype, "device": device}
        super().__init__()

        self.num_heads = num_heads
        self.need_global = need_global
        # Local branch: one hidden_dim//4 channel group per head.
        self.conv1_local = CausalConv1d(in_dim, hidden_dim // 4 * num_heads, 3, stride=1)
        if need_global:
            self.conv1_global = CausalConv1d(in_dim, hidden_dim // 4, 3, stride=1)
            # NOTE(review): this norm1 is unconditionally re-created below,
            # so this assignment is redundant (same name, same shape).
            self.norm1 = nn.LayerNorm(hidden_dim // 4, elementwise_affine=False, eps=1e-6, **factory_kwargs)
        self.act = nn.SiLU()
        self.conv2 = CausalConv1d(hidden_dim // 4, hidden_dim // 2, 3, stride=2)
        self.conv3 = CausalConv1d(hidden_dim // 2, hidden_dim, 3, stride=2)

        if need_global:
            self.final_linear = nn.Linear(hidden_dim, hidden_dim, **factory_kwargs)

        self.norm1 = nn.LayerNorm(hidden_dim // 4, elementwise_affine=False, eps=1e-6, **factory_kwargs)
        self.norm2 = nn.LayerNorm(hidden_dim // 2, elementwise_affine=False, eps=1e-6, **factory_kwargs)
        self.norm3 = nn.LayerNorm(hidden_dim, elementwise_affine=False, eps=1e-6, **factory_kwargs)
        # Learned token appended after the per-head tokens of each frame.
        self.padding_tokens = nn.Parameter(torch.zeros(1, 1, 1, hidden_dim))

    def forward(self, x):
        """Encode (b, t, in_dim) features.

        Returns the local tokens (b, t', num_heads + 1, hidden_dim), or
        ``(global, local)`` when ``need_global`` is set.
        """
        x = rearrange(x, 'b t c -> b c t')
        x_ori = x.clone()  # keep the raw input for the global branch
        b, c, t = x.shape
        # --- local branch: conv -> norm -> act, heads folded into batch ---
        x = self.conv1_local(x)
        x = rearrange(x, 'b (n c) t -> (b n) t c', n=self.num_heads)
        x = self.norm1(x)
        x = self.act(x)
        x = rearrange(x, 'b t c -> b c t')
        x = self.conv2(x)
        x = rearrange(x, 'b c t -> b t c')
        x = self.norm2(x)
        x = self.act(x)
        x = rearrange(x, 'b t c -> b c t')
        x = self.conv3(x)
        x = rearrange(x, 'b c t -> b t c')
        x = self.norm3(x)
        x = self.act(x)
        x = rearrange(x, '(b n) t c -> b t n c', b=b)
        # Append the learned padding token per frame.
        padding = self.padding_tokens.repeat(b, x.shape[1], 1, 1).to(device=x.device, dtype=x.dtype)
        x = torch.cat([x, padding], dim=-2)
        x_local = x.clone()

        if not self.need_global:
            return x_local

        # --- global branch: same conv stack on the raw input, single head ---
        x = self.conv1_global(x_ori)
        x = rearrange(x, 'b c t -> b t c')
        x = self.norm1(x)
        x = self.act(x)
        x = rearrange(x, 'b t c -> b c t')
        x = self.conv2(x)
        x = rearrange(x, 'b c t -> b t c')
        x = self.norm2(x)
        x = self.act(x)
        x = rearrange(x, 'b t c -> b c t')
        x = self.conv3(x)
        x = rearrange(x, 'b c t -> b t c')
        x = self.norm3(x)
        x = self.act(x)
        x = self.final_linear(x)
        # Batch already equals b here, so this adds a singleton head axis.
        x = rearrange(x, '(b n) t c -> b t n c', b=b)

        return x, x_local
|
||||
|
||||
|
||||
class FramePackMotioner(nn.Module):
    """FramePack-style motion conditioning encoder.

    Splits the most recent motion latents into buckets (default 1/2/16
    frames), patchifies each bucket at increasing spatio-temporal compression
    (1x/2x/4x), and precomputes matching RoPE factors so the motion tokens
    can be attended to with (negative-time) positional information.
    """

    def __init__(self, inner_dim=1024, num_heads=16, zip_frame_buckets=[1, 2, 16], drop_mode="drop", *args, **kwargs):
        super().__init__(*args, **kwargs)
        # One Conv3d patchifier per compression level (latent channels = 16).
        self.proj = nn.Conv3d(16, inner_dim, kernel_size=(1, 2, 2), stride=(1, 2, 2))
        self.proj_2x = nn.Conv3d(16, inner_dim, kernel_size=(2, 4, 4), stride=(2, 4, 4))
        self.proj_4x = nn.Conv3d(16, inner_dim, kernel_size=(4, 8, 8), stride=(4, 8, 8))
        self.zip_frame_buckets = torch.tensor(zip_frame_buckets, dtype=torch.long)

        self.inner_dim = inner_dim
        self.num_heads = num_heads
        self.freqs = torch.cat(precompute_freqs_cis_3d(inner_dim // num_heads), dim=1)
        self.drop_mode = drop_mode  # "drop": drop finer buckets; else zero them

    def forward(self, motion_latents, add_last_motion=2):
        """Encode a list of per-sample motion latents (c=16, f, h, w).

        Returns (mot, mot_remb): motion token sequences and their
        precomputed RoPE factors, one entry per input latent.
        """
        # NOTE(review): motion_frames is computed but never used below.
        motion_frames = motion_latents[0].shape[1]
        mot = []
        mot_remb = []
        for m in motion_latents:
            lat_height, lat_width = m.shape[2], m.shape[3]
            # Zero-pad (on the left/past side) up to the total bucket length,
            # then copy in as many trailing frames as are available.
            padd_lat = torch.zeros(16, self.zip_frame_buckets.sum(), lat_height, lat_width).to(device=m.device, dtype=m.dtype)
            overlap_frame = min(padd_lat.shape[1], m.shape[1])
            if overlap_frame > 0:
                padd_lat[:, -overlap_frame:] = m[:, -overlap_frame:]

            if add_last_motion < 2 and self.drop_mode != "drop":
                # "Zero" mode: blank out the finest buckets instead of
                # dropping their tokens.
                zero_end_frame = self.zip_frame_buckets[:self.zip_frame_buckets.__len__() - add_last_motion - 1].sum()
                padd_lat[:, -zero_end_frame:] = 0

            padd_lat = padd_lat.unsqueeze(0)
            # Trailing frames split into [oldest 16, middle 2, newest 1].
            clean_latents_4x, clean_latents_2x, clean_latents_post = padd_lat[:, :, -self.zip_frame_buckets.sum():, :, :].split(
                list(self.zip_frame_buckets)[::-1], dim=2
            )  # 16, 2 ,1

            # patchfy
            clean_latents_post = self.proj(clean_latents_post).flatten(2).transpose(1, 2)
            clean_latents_2x = self.proj_2x(clean_latents_2x).flatten(2).transpose(1, 2)
            clean_latents_4x = self.proj_4x(clean_latents_4x).flatten(2).transpose(1, 2)

            if add_last_motion < 2 and self.drop_mode == "drop":
                # "Drop" mode: remove tokens of the disabled buckets entirely.
                clean_latents_post = clean_latents_post[:, :0] if add_last_motion < 2 else clean_latents_post
                clean_latents_2x = clean_latents_2x[:, :0] if add_last_motion < 1 else clean_latents_2x

            motion_lat = torch.cat([clean_latents_post, clean_latents_2x, clean_latents_4x], dim=1)

            # rope
            # Each grid triple is [start (f,h,w), end (f,h,w), target (f,h,w)];
            # frame offsets are negative (motion precedes the current clip).
            start_time_id = -(self.zip_frame_buckets[:1].sum())
            end_time_id = start_time_id + self.zip_frame_buckets[0]
            grid_sizes = [] if add_last_motion < 2 and self.drop_mode == "drop" else \
                [
                    [torch.tensor([start_time_id, 0, 0]).unsqueeze(0).repeat(1, 1),
                     torch.tensor([end_time_id, lat_height // 2, lat_width // 2]).unsqueeze(0).repeat(1, 1),
                     torch.tensor([self.zip_frame_buckets[0], lat_height // 2, lat_width // 2]).unsqueeze(0).repeat(1, 1), ]
                ]

            start_time_id = -(self.zip_frame_buckets[:2].sum())
            end_time_id = start_time_id + self.zip_frame_buckets[1] // 2
            grid_sizes_2x = [] if add_last_motion < 1 and self.drop_mode == "drop" else \
                [
                    [torch.tensor([start_time_id, 0, 0]).unsqueeze(0).repeat(1, 1),
                     torch.tensor([end_time_id, lat_height // 4, lat_width // 4]).unsqueeze(0).repeat(1, 1),
                     torch.tensor([self.zip_frame_buckets[1], lat_height // 2, lat_width // 2]).unsqueeze(0).repeat(1, 1), ]
                ]

            start_time_id = -(self.zip_frame_buckets[:3].sum())
            end_time_id = start_time_id + self.zip_frame_buckets[2] // 4
            grid_sizes_4x = [
                [
                    torch.tensor([start_time_id, 0, 0]).unsqueeze(0).repeat(1, 1),
                    torch.tensor([end_time_id, lat_height // 8, lat_width // 8]).unsqueeze(0).repeat(1, 1),
                    torch.tensor([self.zip_frame_buckets[2], lat_height // 2, lat_width // 2]).unsqueeze(0).repeat(1, 1),
                ]
            ]

            grid_sizes = grid_sizes + grid_sizes_2x + grid_sizes_4x

            motion_rope_emb = rope_precompute(
                motion_lat.detach().view(1, motion_lat.shape[1], self.num_heads, self.inner_dim // self.num_heads),
                grid_sizes,
                self.freqs,
                start=None
            )

            mot.append(motion_lat)
            mot_remb.append(motion_rope_emb)
        return mot, mot_remb
|
||||
|
||||
|
||||
class AdaLayerNorm(nn.Module):
    """LayerNorm whose shift/scale are predicted from a conditioning embedding.

    ``temb`` (b, embedding_dim) is passed through SiLU + Linear to produce a
    concatenated (shift, scale) pair of size ``output_dim``; the normalized
    input (last dim ``output_dim // 2``) is then modulated per sample.
    """

    def __init__(
        self,
        embedding_dim: int,
        output_dim: int,
        norm_eps: float = 1e-5,
    ):
        super().__init__()
        # Kept to preserve the module's public attributes; forward uses F.silu.
        self.silu = nn.SiLU()
        self.linear = nn.Linear(embedding_dim, output_dim)
        self.norm = nn.LayerNorm(output_dim // 2, norm_eps, elementwise_affine=False)

    def forward(self, x, temb):
        shift, scale = self.linear(F.silu(temb)).chunk(2, dim=1)
        # Broadcast the per-sample modulation over the sequence axis.
        return self.norm(x) * (1 + scale[:, None, :]) + shift[:, None, :]
|
||||
|
||||
|
||||
class AudioInjector_WAN(nn.Module):
    """Builds per-layer audio cross-attention injectors for selected DiT blocks.

    Walks ``(all_modules_names, all_modules)`` — as produced by torch_dfs —
    and, for every DiTBlock whose dotted name contains
    ``transformer_blocks.<i>`` with ``i`` in ``inject_layer``, records a
    mapping from that layer index to a dedicated CrossAttention injector
    (plus pre-norms, and optional AdaLN layers when ``enable_adain``).
    """

    def __init__(
        self,
        all_modules,
        all_modules_names,
        dim=2048,
        num_heads=32,
        inject_layer=[0, 27],
        enable_adain=False,
        adain_dim=2048,
    ):
        super().__init__()
        # DiT layer index -> index into the ModuleLists below.
        self.injected_block_id = {}
        injector_count = 0
        for module_name, module in zip(all_modules_names, all_modules):
            if not isinstance(module, DiTBlock):
                continue
            for layer_id in inject_layer:
                # NOTE(review): substring match — 'transformer_blocks.1' also
                # matches 'transformer_blocks.1x'; confirm layer ids are safe.
                if f'transformer_blocks.{layer_id}' in module_name:
                    self.injected_block_id[layer_id] = injector_count
                    injector_count += 1

        def _make_norms():
            return nn.ModuleList([
                nn.LayerNorm(
                    dim,
                    elementwise_affine=False,
                    eps=1e-6,
                ) for _ in range(injector_count)
            ])

        self.injector = nn.ModuleList([
            CrossAttention(
                dim=dim,
                num_heads=num_heads,
            ) for _ in range(injector_count)
        ])
        self.injector_pre_norm_feat = _make_norms()
        self.injector_pre_norm_vec = _make_norms()
        if enable_adain:
            self.injector_adain_layers = nn.ModuleList([
                AdaLayerNorm(output_dim=dim * 2, embedding_dim=adain_dim)
                for _ in range(injector_count)
            ])
|
||||
|
||||
|
||||
class CausalAudioEncoder(nn.Module):
    """Pools per-layer audio features and encodes them temporally.

    Learns one SiLU-activated mixing weight per feature layer, takes the
    normalized weighted sum across layers, then encodes the pooled sequence
    with a temporal-causal MotionEncoder_tc.
    """

    def __init__(self, dim=5120, num_layers=25, out_dim=2048, num_token=4, need_global=False):
        super().__init__()
        self.encoder = MotionEncoder_tc(in_dim=dim, hidden_dim=out_dim, num_heads=num_token, need_global=need_global)
        # One mixing weight per audio-feature layer, initialised to 0.01.
        weight = torch.ones((1, num_layers, 1, 1)) * 0.01

        self.weights = torch.nn.Parameter(weight)
        self.act = torch.nn.SiLU()

    def forward(self, features):
        # features: (batch, num_layers, dim, video_length)
        layer_w = self.act(self.weights.to(device=features.device, dtype=features.dtype))
        denom = layer_w.sum(dim=1, keepdims=True)
        pooled = ((features * layer_w) / denom).sum(dim=1)  # (b, dim, f)
        pooled = pooled.permute(0, 2, 1)                    # (b, f, dim)
        return self.encoder(pooled)                         # (b, f, n, dim)
|
||||
|
||||
|
||||
class WanS2VDiTBlock(DiTBlock):
    """DiT block variant for speech-to-video: the first ``seq_len_x`` tokens
    (the video latents) and the remaining tokens (reference/motion/etc.)
    receive different timestep modulation vectors."""

    def forward(self, x, context, t_mod, seq_len_x, freqs):
        # Add the learned base modulation, then split into the six
        # (shift/scale/gate) x (attn/mlp) vectors.
        t_mod = (self.modulation.unsqueeze(2).to(dtype=t_mod.dtype, device=t_mod.device) + t_mod).chunk(6, dim=1)
        # t_mod[:, :, 0] for x, t_mod[:, :, 1] for other like ref, motion, etc.
        t_mod = [
            torch.cat([element[:, :, 0].expand(1, seq_len_x, x.shape[-1]), element[:, :, 1].expand(1, x.shape[1] - seq_len_x, x.shape[-1])], dim=1)
            for element in t_mod
        ]
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = t_mod
        # Standard DiT flow: modulated self-attn, cross-attn, modulated FFN,
        # each residual path gated by its corresponding gate vector.
        input_x = modulate(self.norm1(x), shift_msa, scale_msa)
        x = self.gate(x, gate_msa, self.self_attn(input_x, freqs))
        x = x + self.cross_attn(self.norm3(x), context)
        input_x = modulate(self.norm2(x), shift_mlp, scale_mlp)
        x = self.gate(x, gate_mlp, self.ffn(input_x))
        return x
|
||||
|
||||
|
||||
class WanS2VModel(torch.nn.Module):
    """Wan speech-to-video (S2V) diffusion transformer.

    Combines a patchified video/pose stream, a reference-image stream and
    frame-packed motion tokens into one token sequence, runs it through a
    stack of ``WanS2VDiTBlock`` layers with audio features injected after
    selected blocks (``AudioInjector_WAN``), and decodes the video tokens
    back into latents.
    """

    def __init__(
        self,
        dim: int,
        in_dim: int,
        ffn_dim: int,
        out_dim: int,
        text_dim: int,
        freq_dim: int,
        eps: float,
        patch_size: Tuple[int, int, int],
        num_heads: int,
        num_layers: int,
        cond_dim: int,
        audio_dim: int,
        num_audio_token: int,
        enable_adain: bool = True,
        audio_inject_layers: list = None,
        zero_timestep: bool = True,
        add_last_motion: bool = True,
        framepack_drop_mode: str = "padd",
        fuse_vae_embedding_in_latents: bool = True,
        require_vae_embedding: bool = False,
        seperated_timestep: bool = False,
        require_clip_embedding: bool = False,
    ):
        super().__init__()
        # Bug fix: the default for ``audio_inject_layers`` used to be a shared
        # mutable list literal; use a None sentinel and build it per call.
        if audio_inject_layers is None:
            audio_inject_layers = [0, 4, 8, 12, 16, 20, 24, 27, 30, 33, 36, 39]
        self.dim = dim
        self.in_dim = in_dim
        self.freq_dim = freq_dim
        self.patch_size = patch_size
        self.num_heads = num_heads
        # NOTE: attribute keeps the historical misspelling ("enbale") so any
        # external code reading it keeps working.
        self.enbale_adain = enable_adain
        self.add_last_motion = add_last_motion
        self.zero_timestep = zero_timestep
        self.fuse_vae_embedding_in_latents = fuse_vae_embedding_in_latents
        self.require_vae_embedding = require_vae_embedding
        self.seperated_timestep = seperated_timestep
        self.require_clip_embedding = require_clip_embedding

        # embeddings for video patches, text context, and timestep
        self.patch_embedding = nn.Conv3d(in_dim, dim, kernel_size=patch_size, stride=patch_size)
        self.text_embedding = nn.Sequential(nn.Linear(text_dim, dim), nn.GELU(approximate='tanh'), nn.Linear(dim, dim))
        self.time_embedding = nn.Sequential(nn.Linear(freq_dim, dim), nn.SiLU(), nn.Linear(dim, dim))
        self.time_projection = nn.Sequential(nn.SiLU(), nn.Linear(dim, dim * 6))

        self.blocks = nn.ModuleList([WanS2VDiTBlock(False, dim, num_heads, ffn_dim, eps) for _ in range(num_layers)])
        self.head = Head(dim, out_dim, patch_size, eps)
        self.freqs = torch.cat(precompute_freqs_cis_3d(dim // num_heads), dim=1)

        # pose-condition and audio encoders
        self.cond_encoder = nn.Conv3d(cond_dim, dim, kernel_size=patch_size, stride=patch_size)
        self.casual_audio_encoder = CausalAudioEncoder(dim=audio_dim, out_dim=dim, num_token=num_audio_token, need_global=enable_adain)
        all_modules, all_modules_names = torch_dfs(self.blocks, parent_name="root.transformer_blocks")
        self.audio_injector = AudioInjector_WAN(
            all_modules,
            all_modules_names,
            dim=dim,
            num_heads=num_heads,
            inject_layer=audio_inject_layers,
            enable_adain=enable_adain,
            adain_dim=dim,
        )
        # token-type embedding: 0 = video, 1 = reference, 2 = motion tokens
        self.trainable_cond_mask = nn.Embedding(3, dim)
        self.frame_packer = FramePackMotioner(inner_dim=dim, num_heads=num_heads, zip_frame_buckets=[1, 2, 16], drop_mode=framepack_drop_mode)

    def patchify(self, x: torch.Tensor):
        """Flatten a patch-embedded video tensor to tokens; also return the grid."""
        grid_size = x.shape[2:]
        x = rearrange(x, 'b c f h w -> b (f h w) c').contiguous()
        return x, grid_size  # x, grid_size: (f, h, w)

    def unpatchify(self, x: torch.Tensor, grid_size: Tuple[int, int, int]):
        """Inverse of ``patchify``: fold tokens back into a latent video tensor."""
        return rearrange(
            x,
            'b (f h w) (x y z c) -> b c (f x) (h y) (w z)',
            f=grid_size[0],
            h=grid_size[1],
            w=grid_size[2],
            x=self.patch_size[0],
            y=self.patch_size[1],
            z=self.patch_size[2]
        )

    def process_motion_frame_pack(self, motion_latents, drop_motion_frames=False, add_last_motion=2):
        """Frame-pack the motion latents; optionally drop all motion tokens."""
        flattern_mot, mot_remb = self.frame_packer(motion_latents, add_last_motion)
        if drop_motion_frames:
            # zero-length slices keep downstream concatenation logic uniform
            return [m[:, :0] for m in flattern_mot], [m[:, :0] for m in mot_remb]
        else:
            return flattern_mot, mot_remb

    def inject_motion(self, x, rope_embs, mask_input, motion_latents, drop_motion_frames=True, add_last_motion=2):
        """Append motion-frame tokens (and their RoPE embeddings) to the sequence."""
        # inject the motion frames token to the hidden states
        mot, mot_remb = self.process_motion_frame_pack(motion_latents, drop_motion_frames=drop_motion_frames, add_last_motion=add_last_motion)
        if len(mot) > 0:
            x = torch.cat([x, mot[0]], dim=1)
            rope_embs = torch.cat([rope_embs, mot_remb[0]], dim=1)
            # motion tokens are marked with cond-mask id 2
            mask_input = torch.cat(
                [mask_input, 2 * torch.ones([1, x.shape[1] - mask_input.shape[1]], device=mask_input.device, dtype=mask_input.dtype)], dim=1
            )
        return x, rope_embs, mask_input

    def after_transformer_block(self, block_idx, hidden_states, audio_emb_global, audio_emb, original_seq_len, use_unified_sequence_parallel=False):
        """Cross-attend the video tokens to audio features after selected blocks."""
        if block_idx in self.audio_injector.injected_block_id.keys():
            audio_attn_id = self.audio_injector.injected_block_id[block_idx]
            num_frames = audio_emb.shape[1]
            if use_unified_sequence_parallel:
                from xfuser.core.distributed import get_sp_group
                hidden_states = get_sp_group().all_gather(hidden_states, dim=1)

            # only the first ``original_seq_len`` (video) tokens attend to audio
            input_hidden_states = hidden_states[:, :original_seq_len].clone()  # b (f h w) c
            input_hidden_states = rearrange(input_hidden_states, "b (t n) c -> (b t) n c", t=num_frames)

            # AdaIN conditioning on the global audio embedding of each frame
            audio_emb_global = rearrange(audio_emb_global, "b t n c -> (b t) n c")
            adain_hidden_states = self.audio_injector.injector_adain_layers[audio_attn_id](input_hidden_states, temb=audio_emb_global[:, 0])
            attn_hidden_states = adain_hidden_states

            audio_emb = rearrange(audio_emb, "b t n c -> (b t) n c", t=num_frames)
            attn_audio_emb = audio_emb
            residual_out = self.audio_injector.injector[audio_attn_id](attn_hidden_states, attn_audio_emb)
            residual_out = rearrange(residual_out, "(b t) n c -> b (t n) c", t=num_frames)
            hidden_states[:, :original_seq_len] = hidden_states[:, :original_seq_len] + residual_out
            if use_unified_sequence_parallel:
                from xfuser.core.distributed import get_sequence_parallel_world_size, get_sequence_parallel_rank
                hidden_states = torch.chunk(hidden_states, get_sequence_parallel_world_size(), dim=1)[get_sequence_parallel_rank()]
        return hidden_states

    def cal_audio_emb(self, audio_input, motion_frames=None):
        """Encode audio, padding with repeats of the first frame to cover motion history."""
        # Bug fix: mutable default argument replaced with a None sentinel.
        if motion_frames is None:
            motion_frames = [73, 19]
        audio_input = torch.cat([audio_input[..., 0:1].repeat(1, 1, 1, motion_frames[0]), audio_input], dim=-1)
        audio_emb_global, audio_emb = self.casual_audio_encoder(audio_input)
        # drop the leading motion-history frames from both embeddings
        audio_emb_global = audio_emb_global[:, motion_frames[1]:].clone()
        merged_audio_emb = audio_emb[:, motion_frames[1]:, :]
        return audio_emb_global, merged_audio_emb

    def get_grid_sizes(self, grid_size_x, grid_size_ref):
        """Build RoPE grid descriptors for the video tokens and reference tokens."""
        f, h, w = grid_size_x
        rf, rh, rw = grid_size_ref
        grid_sizes_x = torch.tensor([f, h, w], dtype=torch.long).unsqueeze(0)
        grid_sizes_x = [[torch.zeros_like(grid_sizes_x), grid_sizes_x, grid_sizes_x]]
        # reference tokens live at temporal positions 30..31 — presumably an
        # offset that keeps them outside the video's time range (TODO confirm)
        grid_sizes_ref = [[
            torch.tensor([30, 0, 0]).unsqueeze(0),
            torch.tensor([31, rh, rw]).unsqueeze(0),
            torch.tensor([1, rh, rw]).unsqueeze(0),
        ]]
        return grid_sizes_x + grid_sizes_ref

    def forward(
        self,
        latents,
        timestep,
        context,
        audio_input,
        motion_latents,
        pose_cond,
        use_gradient_checkpointing_offload=False,
        use_gradient_checkpointing=False
    ):
        """Denoise ``latents``; returns latents with the reference frame re-attached."""
        # the first latent frame is the reference image, the rest is video
        origin_ref_latents = latents[:, :, 0:1]
        x = latents[:, :, 1:]

        # context embedding
        context = self.text_embedding(context)

        # audio encode
        audio_emb_global, merged_audio_emb = self.cal_audio_emb(audio_input)

        # x and pose_cond (absent pose condition behaves like zeros)
        pose_cond = torch.zeros_like(x) if pose_cond is None else pose_cond
        x, (f, h, w) = self.patchify(self.patch_embedding(x) + self.cond_encoder(pose_cond))  # torch.Size([1, 29120, 5120])
        seq_len_x = x.shape[1]

        # reference image
        ref_latents, (rf, rh, rw) = self.patchify(self.patch_embedding(origin_ref_latents))  # torch.Size([1, 1456, 5120])
        grid_sizes = self.get_grid_sizes((f, h, w), (rf, rh, rw))
        x = torch.cat([x, ref_latents], dim=1)
        # mask: 0 = video tokens, 1 = reference tokens (2 added later for motion)
        mask = torch.cat([torch.zeros([1, seq_len_x]), torch.ones([1, ref_latents.shape[1]])], dim=1).to(torch.long).to(x.device)
        # freqs
        pre_compute_freqs = rope_precompute(
            x.detach().view(1, x.size(1), self.num_heads, self.dim // self.num_heads), grid_sizes, self.freqs, start=None
        )
        # motion
        x, pre_compute_freqs, mask = self.inject_motion(x, pre_compute_freqs, mask, motion_latents, add_last_motion=2)

        x = x + self.trainable_cond_mask(mask).to(x.dtype)

        # t_mod: append a zero timestep for the non-video (ref/motion) tokens
        timestep = torch.cat([timestep, torch.zeros([1], dtype=timestep.dtype, device=timestep.device)])
        t = self.time_embedding(sinusoidal_embedding_1d(self.freq_dim, timestep))
        t_mod = self.time_projection(t).unflatten(1, (6, self.dim)).unsqueeze(2).transpose(0, 2)

        def create_custom_forward(module):
            def custom_forward(*inputs):
                return module(*inputs)
            return custom_forward

        for block_id, block in enumerate(self.blocks):
            if use_gradient_checkpointing_offload:
                with torch.autograd.graph.save_on_cpu():
                    x = torch.utils.checkpoint.checkpoint(
                        create_custom_forward(block),
                        x,
                        context,
                        t_mod,
                        seq_len_x,
                        pre_compute_freqs[0],
                        use_reentrant=False,
                    )
                    x = torch.utils.checkpoint.checkpoint(
                        create_custom_forward(lambda x: self.after_transformer_block(block_id, x, audio_emb_global, merged_audio_emb, seq_len_x)),
                        x,
                        use_reentrant=False,
                    )
            elif use_gradient_checkpointing:
                x = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(block),
                    x,
                    context,
                    t_mod,
                    seq_len_x,
                    pre_compute_freqs[0],
                    use_reentrant=False,
                )
                x = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(lambda x: self.after_transformer_block(block_id, x, audio_emb_global, merged_audio_emb, seq_len_x)),
                    x,
                    use_reentrant=False,
                )
            else:
                x = block(x, context, t_mod, seq_len_x, pre_compute_freqs[0])
                x = self.after_transformer_block(block_id, x, audio_emb_global, merged_audio_emb, seq_len_x)

        # keep only the video tokens; drop the appended zero-timestep entry
        x = x[:, :seq_len_x]
        x = self.head(x, t[:-1])
        x = self.unpatchify(x, (f, h, w))
        # make compatible with wan video
        x = torch.cat([origin_ref_latents, x], dim=2)
        return x
|
||||
902
diffsynth/models/wan_video_image_encoder.py
Normal file
902
diffsynth/models/wan_video_image_encoder.py
Normal file
@@ -0,0 +1,902 @@
|
||||
"""
|
||||
Concise re-implementation of
|
||||
``https://github.com/openai/CLIP'' and
|
||||
``https://github.com/mlfoundations/open_clip''.
|
||||
"""
|
||||
import math
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torchvision.transforms as T
|
||||
from .wan_video_dit import flash_attention
|
||||
|
||||
|
||||
class SelfAttention(nn.Module):
    """Multi-head self-attention with dropout (text-encoder variant)."""

    def __init__(self, dim, num_heads, dropout=0.1, eps=1e-5):
        assert dim % num_heads == 0
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.eps = eps

        # projection layers
        self.q = nn.Linear(dim, dim)
        self.k = nn.Linear(dim, dim)
        self.v = nn.Linear(dim, dim)
        self.o = nn.Linear(dim, dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask):
        """
        x: [B, L, C]; ``mask`` is forwarded to scaled_dot_product_attention.
        """
        b, s, c = x.size()
        n, d = self.num_heads, self.head_dim

        # project and split heads: [B, L, C] -> [B, n, L, d]
        q = self.q(x).view(b, s, n, d).transpose(1, 2)
        k = self.k(x).view(b, s, n, d).transpose(1, 2)
        v = self.v(x).view(b, s, n, d).transpose(1, 2)

        # attention (dropout is only active in training mode)
        drop_p = self.dropout.p if self.training else 0.0
        out = F.scaled_dot_product_attention(q, k, v, mask, drop_p)
        out = out.transpose(1, 2).reshape(b, s, c)

        # output projection followed by dropout
        return self.dropout(self.o(out))
|
||||
|
||||
|
||||
class AttentionBlock(nn.Module):
    """Transformer encoder block (self-attention + MLP), pre- or post-norm."""

    def __init__(self, dim, num_heads, post_norm, dropout=0.1, eps=1e-5):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.post_norm = post_norm
        self.eps = eps

        # sub-layers
        self.attn = SelfAttention(dim, num_heads, dropout, eps)
        self.norm1 = nn.LayerNorm(dim, eps=eps)
        self.ffn = nn.Sequential(
            nn.Linear(dim, dim * 4), nn.GELU(), nn.Linear(dim * 4, dim),
            nn.Dropout(dropout))
        self.norm2 = nn.LayerNorm(dim, eps=eps)

    def forward(self, x, mask):
        if self.post_norm:
            # BERT-style: normalize after each residual sum
            x = self.norm1(x + self.attn(x, mask))
            return self.norm2(x + self.ffn(x))
        # pre-norm: normalize the branch input, residual stays unnormalized
        x = x + self.attn(self.norm1(x), mask)
        return x + self.ffn(self.norm2(x))
|
||||
|
||||
|
||||
class XLMRoberta(nn.Module):
    """
    XLMRobertaModel with no pooler and no LM head.
    """

    def __init__(self,
                 vocab_size=250002,
                 max_seq_len=514,
                 type_size=1,
                 pad_id=1,
                 dim=1024,
                 num_heads=16,
                 num_layers=24,
                 post_norm=True,
                 dropout=0.1,
                 eps=1e-5):
        super().__init__()
        self.vocab_size = vocab_size
        self.max_seq_len = max_seq_len
        self.type_size = type_size
        self.pad_id = pad_id
        self.dim = dim
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.post_norm = post_norm
        self.eps = eps

        # embeddings
        self.token_embedding = nn.Embedding(vocab_size, dim, padding_idx=pad_id)
        self.type_embedding = nn.Embedding(type_size, dim)
        self.pos_embedding = nn.Embedding(max_seq_len, dim, padding_idx=pad_id)
        self.dropout = nn.Dropout(dropout)

        # blocks
        self.blocks = nn.ModuleList([
            AttentionBlock(dim, num_heads, post_norm, dropout, eps)
            for _ in range(num_layers)
        ])

        # norm layer
        self.norm = nn.LayerNorm(dim, eps=eps)

    def forward(self, ids):
        """
        ids: [B, L] of torch.LongTensor.
        """
        b, s = ids.shape
        # 1 for real tokens, 0 for padding
        mask = ids.ne(self.pad_id).long()

        # embeddings; positions are counted over non-pad tokens only, offset
        # past pad_id so padding positions index the embedding's padding_idx
        x = self.token_embedding(ids) + \
            self.type_embedding(torch.zeros_like(ids)) + \
            self.pos_embedding(self.pad_id + torch.cumsum(mask, dim=1) * mask)
        if self.post_norm:
            x = self.norm(x)
        x = self.dropout(x)

        # blocks: turn the padding mask into an additive attention bias
        # (0 where attendable, dtype-min where padded)
        mask = torch.where(
            mask.view(b, 1, 1, s).gt(0), 0.0,
            torch.finfo(x.dtype).min)
        for block in self.blocks:
            x = block(x, mask)

        # output
        if not self.post_norm:
            x = self.norm(x)
        return x
|
||||
|
||||
|
||||
def xlm_roberta_large(pretrained=False,
                      return_tokenizer=False,
                      device='cpu',
                      **kwargs):
    """
    XLMRobertaLarge adapted from Huggingface.

    Returns the model, or ``(model, tokenizer)`` when ``return_tokenizer`` is
    True. ``kwargs`` override the default architecture configuration.
    """
    # default architecture parameters
    cfg = dict(
        vocab_size=250002,
        max_seq_len=514,
        type_size=1,
        pad_id=1,
        dim=1024,
        num_heads=16,
        num_layers=24,
        post_norm=True,
        dropout=0.1,
        eps=1e-5)
    cfg.update(**kwargs)

    # init model
    if pretrained:
        from sora import DOWNLOAD_TO_CACHE

        # init a meta model (no weight allocation), then load the checkpoint
        # into it with assign=True
        with torch.device('meta'):
            model = XLMRoberta(**cfg)

        # load checkpoint
        model.load_state_dict(
            torch.load(
                DOWNLOAD_TO_CACHE('models/xlm_roberta/xlm_roberta_large.pth'),
                map_location=device),
            assign=True)
    else:
        # init a model on device
        with torch.device(device):
            model = XLMRoberta(**cfg)

    # init tokenizer
    if return_tokenizer:
        from sora.data import HuggingfaceTokenizer
        # Bug fix: XLMRoberta stores the sequence length as ``max_seq_len``;
        # it has no ``text_len`` attribute, so the old code raised
        # AttributeError here.
        tokenizer = HuggingfaceTokenizer(
            name='xlm-roberta-large',
            seq_len=model.max_seq_len,
            clean='whitespace')
        return model, tokenizer
    else:
        return model
|
||||
|
||||
|
||||
|
||||
def pos_interpolate(pos, seq_len):
    """Resize positional embeddings to ``seq_len`` tokens.

    Any leading non-grid tokens (e.g. the cls token) are kept as-is; the
    square grid part is bicubically interpolated to the new grid size.
    """
    if pos.size(1) == seq_len:
        return pos
    src_grid = int(math.sqrt(pos.size(1)))
    tar_grid = int(math.sqrt(seq_len))
    # number of leading tokens that are not part of the square grid
    n = pos.size(1) - src_grid * src_grid
    prefix, grid = pos[:, :n], pos[:, n:]
    # [1, L, C] -> [1, C, H, W] for 2-D interpolation
    grid = grid.float().reshape(1, src_grid, src_grid, -1).permute(0, 3, 1, 2)
    grid = F.interpolate(
        grid,
        size=(tar_grid, tar_grid),
        mode='bicubic',
        align_corners=False)
    # back to [1, L', C]
    grid = grid.flatten(2).transpose(1, 2)
    return torch.cat([prefix, grid], dim=1)
|
||||
|
||||
|
||||
class QuickGELU(nn.Module):
    """Sigmoid-based fast approximation of GELU: x * sigmoid(1.702 * x)."""

    def forward(self, x):
        gate = torch.sigmoid(1.702 * x)
        return x * gate
|
||||
|
||||
|
||||
class LayerNorm(nn.LayerNorm):
    """``nn.LayerNorm`` that casts the result back to the input's dtype."""

    def forward(self, x):
        normed = super().forward(x)
        return normed.type_as(x)
|
||||
|
||||
|
||||
class SelfAttention(nn.Module):
    """Multi-head self-attention backed by ``flash_attention`` (vision tower)."""

    def __init__(self,
                 dim,
                 num_heads,
                 causal=False,
                 attn_dropout=0.0,
                 proj_dropout=0.0):
        assert dim % num_heads == 0
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.causal = causal
        self.attn_dropout = attn_dropout
        self.proj_dropout = proj_dropout

        # fused qkv projection and output projection
        self.to_qkv = nn.Linear(dim, dim * 3)
        self.proj = nn.Linear(dim, dim)

    def forward(self, x):
        """
        x: [B, L, C].
        """
        # fused projection, then split into query / key / value
        q, k, v = self.to_qkv(x).chunk(3, dim=-1)

        # attention (head splitting is handled inside flash_attention)
        out = flash_attention(q, k, v, num_heads=self.num_heads, compatibility_mode=True)

        # output projection with dropout
        out = self.proj(out)
        return F.dropout(out, self.proj_dropout, self.training)
|
||||
|
||||
|
||||
class SwiGLU(nn.Module):
    """SwiGLU feed-forward: silu(fc1(x)) * fc2(x), projected back by fc3."""

    def __init__(self, dim, mid_dim):
        super().__init__()
        self.dim = dim
        self.mid_dim = mid_dim

        # two parallel up-projections and one down-projection
        self.fc1 = nn.Linear(dim, mid_dim)
        self.fc2 = nn.Linear(dim, mid_dim)
        self.fc3 = nn.Linear(mid_dim, dim)

    def forward(self, x):
        gate = F.silu(self.fc1(x))
        value = self.fc2(x)
        return self.fc3(gate * value)
|
||||
|
||||
|
||||
class AttentionBlock(nn.Module):
    """ViT-style transformer block: self-attention + MLP, pre- or post-norm."""

    def __init__(self,
                 dim,
                 mlp_ratio,
                 num_heads,
                 post_norm=False,
                 causal=False,
                 activation='quick_gelu',
                 attn_dropout=0.0,
                 proj_dropout=0.0,
                 norm_eps=1e-5):
        assert activation in ['quick_gelu', 'gelu', 'swi_glu']
        super().__init__()
        self.dim = dim
        self.mlp_ratio = mlp_ratio
        self.num_heads = num_heads
        self.post_norm = post_norm
        self.causal = causal
        self.norm_eps = norm_eps

        # sub-layers
        mid_dim = int(dim * mlp_ratio)
        self.norm1 = LayerNorm(dim, eps=norm_eps)
        self.attn = SelfAttention(dim, num_heads, causal, attn_dropout,
                                  proj_dropout)
        self.norm2 = LayerNorm(dim, eps=norm_eps)
        if activation == 'swi_glu':
            self.mlp = SwiGLU(dim, mid_dim)
        else:
            act = QuickGELU() if activation == 'quick_gelu' else nn.GELU()
            self.mlp = nn.Sequential(
                nn.Linear(dim, mid_dim),
                act,
                nn.Linear(mid_dim, dim),
                nn.Dropout(proj_dropout))

    def forward(self, x):
        if self.post_norm:
            # normalize each residual branch's output
            x = x + self.norm1(self.attn(x))
            return x + self.norm2(self.mlp(x))
        # pre-norm: normalize each branch's input
        x = x + self.attn(self.norm1(x))
        return x + self.mlp(self.norm2(x))
|
||||
|
||||
|
||||
class AttentionPool(nn.Module):
    """Pools a token sequence into one vector with a learned cls query."""

    def __init__(self,
                 dim,
                 mlp_ratio,
                 num_heads,
                 activation='gelu',
                 proj_dropout=0.0,
                 norm_eps=1e-5):
        assert dim % num_heads == 0
        super().__init__()
        self.dim = dim
        self.mlp_ratio = mlp_ratio
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.proj_dropout = proj_dropout
        self.norm_eps = norm_eps

        # learned query and projections
        gain = 1.0 / math.sqrt(dim)
        self.cls_embedding = nn.Parameter(gain * torch.randn(1, 1, dim))
        self.to_q = nn.Linear(dim, dim)
        self.to_kv = nn.Linear(dim, dim * 2)
        self.proj = nn.Linear(dim, dim)
        self.norm = LayerNorm(dim, eps=norm_eps)
        mid_dim = int(dim * mlp_ratio)
        self.mlp = nn.Sequential(
            nn.Linear(dim, mid_dim),
            QuickGELU() if activation == 'quick_gelu' else nn.GELU(),
            nn.Linear(mid_dim, dim), nn.Dropout(proj_dropout))

    def forward(self, x):
        """
        x: [B, L, C] -> pooled [B, C].
        """
        b = x.size(0)
        c = self.dim

        # single learned query, broadcast over the batch
        q = self.to_q(self.cls_embedding).view(1, 1, c).expand(b, -1, -1)
        k, v = self.to_kv(x).chunk(2, dim=-1)

        # attend the query over every token
        pooled = flash_attention(q, k, v, num_heads=self.num_heads, compatibility_mode=True)
        pooled = pooled.reshape(b, 1, c)

        # output projection with dropout
        pooled = self.proj(pooled)
        pooled = F.dropout(pooled, self.proj_dropout, self.training)

        # residual MLP, then squeeze the token axis
        pooled = pooled + self.mlp(self.norm(pooled))
        return pooled[:, 0]
|
||||
|
||||
|
||||
class VisionTransformer(nn.Module):
    """CLIP-style vision transformer.

    ``forward`` returns the token sequence after the transformer stack (or
    after all-but-the-last block when ``use_31_block`` is True). Note that
    the pooling ``head`` built in ``__init__`` is never applied inside
    ``forward`` — callers are expected to pool externally if needed.
    """

    def __init__(self,
                 image_size=224,
                 patch_size=16,
                 dim=768,
                 mlp_ratio=4,
                 out_dim=512,
                 num_heads=12,
                 num_layers=12,
                 pool_type='token',
                 pre_norm=True,
                 post_norm=False,
                 activation='quick_gelu',
                 attn_dropout=0.0,
                 proj_dropout=0.0,
                 embedding_dropout=0.0,
                 norm_eps=1e-5):
        if image_size % patch_size != 0:
            print(
                '[WARNING] image_size is not divisible by patch_size',
                flush=True)
        assert pool_type in ('token', 'token_fc', 'attn_pool')
        out_dim = out_dim or dim
        super().__init__()
        self.image_size = image_size
        self.patch_size = patch_size
        self.num_patches = (image_size // patch_size)**2
        self.dim = dim
        self.mlp_ratio = mlp_ratio
        self.out_dim = out_dim
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.pool_type = pool_type
        self.post_norm = post_norm
        self.norm_eps = norm_eps

        # embeddings
        gain = 1.0 / math.sqrt(dim)
        # patch embedding has no bias when a pre-norm layer follows it
        self.patch_embedding = nn.Conv2d(
            3,
            dim,
            kernel_size=patch_size,
            stride=patch_size,
            bias=not pre_norm)
        if pool_type in ('token', 'token_fc'):
            self.cls_embedding = nn.Parameter(gain * torch.randn(1, 1, dim))
        # one extra position for the cls token when present
        self.pos_embedding = nn.Parameter(gain * torch.randn(
            1, self.num_patches +
            (1 if pool_type in ('token', 'token_fc') else 0), dim))
        self.dropout = nn.Dropout(embedding_dropout)

        # transformer
        self.pre_norm = LayerNorm(dim, eps=norm_eps) if pre_norm else None
        self.transformer = nn.Sequential(*[
            AttentionBlock(dim, mlp_ratio, num_heads, post_norm, False,
                           activation, attn_dropout, proj_dropout, norm_eps)
            for _ in range(num_layers)
        ])
        # NOTE(review): this overwrites the boolean ``self.post_norm`` set
        # above with a LayerNorm module, and forward() never applies it —
        # kept as-is for checkpoint compatibility.
        self.post_norm = LayerNorm(dim, eps=norm_eps)

        # head (unused by forward(); see class docstring)
        if pool_type == 'token':
            self.head = nn.Parameter(gain * torch.randn(dim, out_dim))
        elif pool_type == 'token_fc':
            self.head = nn.Linear(dim, out_dim)
        elif pool_type == 'attn_pool':
            self.head = AttentionPool(dim, mlp_ratio, num_heads, activation,
                                      proj_dropout, norm_eps)

    def forward(self, x, interpolation=False, use_31_block=False):
        # x: [B, 3, H, W] image batch
        b = x.size(0)

        # embeddings: patchify, optionally prepend cls token, add positions
        x = self.patch_embedding(x).flatten(2).permute(0, 2, 1)
        if self.pool_type in ('token', 'token_fc'):
            x = torch.cat([self.cls_embedding.expand(b, -1, -1).to(dtype=x.dtype, device=x.device), x], dim=1)
        if interpolation:
            # resize positional embeddings when the token count differs
            e = pos_interpolate(self.pos_embedding, x.size(1))
        else:
            e = self.pos_embedding
        e = e.to(dtype=x.dtype, device=x.device)
        x = self.dropout(x + e)
        if self.pre_norm is not None:
            x = self.pre_norm(x)

        # transformer
        if use_31_block:
            # stop one block early (penultimate-layer features)
            x = self.transformer[:-1](x)
            return x
        else:
            x = self.transformer(x)
            return x
|
||||
|
||||
|
||||
class CLIP(nn.Module):
    """Standard CLIP: a vision tower and a text tower with a learned logit
    scale.

    NOTE(review): ``TextTransformer`` is not defined in this chunk of the
    file; it is assumed to be defined elsewhere in the module — confirm.
    """

    def __init__(self,
                 embed_dim=512,
                 image_size=224,
                 patch_size=16,
                 vision_dim=768,
                 vision_mlp_ratio=4,
                 vision_heads=12,
                 vision_layers=12,
                 vision_pool='token',
                 vision_pre_norm=True,
                 vision_post_norm=False,
                 vocab_size=49408,
                 text_len=77,
                 text_dim=512,
                 text_mlp_ratio=4,
                 text_heads=8,
                 text_layers=12,
                 text_causal=True,
                 text_pool='argmax',
                 text_head_bias=False,
                 logit_bias=None,
                 activation='quick_gelu',
                 attn_dropout=0.0,
                 proj_dropout=0.0,
                 embedding_dropout=0.0,
                 norm_eps=1e-5):
        super().__init__()
        self.embed_dim = embed_dim
        self.image_size = image_size
        self.patch_size = patch_size
        self.vision_dim = vision_dim
        self.vision_mlp_ratio = vision_mlp_ratio
        self.vision_heads = vision_heads
        self.vision_layers = vision_layers
        self.vision_pool = vision_pool
        self.vision_pre_norm = vision_pre_norm
        self.vision_post_norm = vision_post_norm
        self.vocab_size = vocab_size
        self.text_len = text_len
        self.text_dim = text_dim
        self.text_mlp_ratio = text_mlp_ratio
        self.text_heads = text_heads
        self.text_layers = text_layers
        self.text_causal = text_causal
        self.text_pool = text_pool
        self.text_head_bias = text_head_bias
        self.norm_eps = norm_eps

        # models
        self.visual = VisionTransformer(
            image_size=image_size,
            patch_size=patch_size,
            dim=vision_dim,
            mlp_ratio=vision_mlp_ratio,
            out_dim=embed_dim,
            num_heads=vision_heads,
            num_layers=vision_layers,
            pool_type=vision_pool,
            pre_norm=vision_pre_norm,
            post_norm=vision_post_norm,
            activation=activation,
            attn_dropout=attn_dropout,
            proj_dropout=proj_dropout,
            embedding_dropout=embedding_dropout,
            norm_eps=norm_eps)
        self.textual = TextTransformer(
            vocab_size=vocab_size,
            text_len=text_len,
            dim=text_dim,
            mlp_ratio=text_mlp_ratio,
            out_dim=embed_dim,
            num_heads=text_heads,
            num_layers=text_layers,
            causal=text_causal,
            pool_type=text_pool,
            head_bias=text_head_bias,
            activation=activation,
            attn_dropout=attn_dropout,
            proj_dropout=proj_dropout,
            embedding_dropout=embedding_dropout,
            norm_eps=norm_eps)
        # temperature parameter initialized to log(1/0.07), as in CLIP
        self.log_scale = nn.Parameter(math.log(1 / 0.07) * torch.ones([]))
        if logit_bias is not None:
            self.logit_bias = nn.Parameter(logit_bias * torch.ones([]))

        # initialize weights
        self.init_weights()

    def forward(self, imgs, txt_ids):
        """
        imgs: [B, 3, H, W] of torch.float32.
        - mean: [0.48145466, 0.4578275, 0.40821073]
        - std: [0.26862954, 0.26130258, 0.27577711]
        txt_ids: [B, L] of torch.long. Encoded by data.CLIPTokenizer.
        """
        xi = self.visual(imgs)
        xt = self.textual(txt_ids)
        return xi, xt

    def init_weights(self):
        # embeddings
        nn.init.normal_(self.textual.token_embedding.weight, std=0.02)
        nn.init.normal_(self.visual.patch_embedding.weight, std=0.1)

        # attentions: depth-scaled init for projection layers
        for modality in ['visual', 'textual']:
            dim = self.vision_dim if modality == 'visual' else self.text_dim
            transformer = getattr(self, modality).transformer
            proj_gain = (1.0 / math.sqrt(dim)) * (
                1.0 / math.sqrt(2 * len(transformer)))
            attn_gain = 1.0 / math.sqrt(dim)
            mlp_gain = 1.0 / math.sqrt(2.0 * dim)
            # NOTE(review): ``block.mlp[0]`` / ``block.mlp[2]`` assumes the
            # Sequential MLP variant; would fail for activation='swi_glu'.
            for block in transformer:
                nn.init.normal_(block.attn.to_qkv.weight, std=attn_gain)
                nn.init.normal_(block.attn.proj.weight, std=proj_gain)
                nn.init.normal_(block.mlp[0].weight, std=mlp_gain)
                nn.init.normal_(block.mlp[2].weight, std=proj_gain)

    def param_groups(self):
        # norms and biases get zero weight decay; everything else the default
        groups = [{
            'params': [
                p for n, p in self.named_parameters()
                if 'norm' in n or n.endswith('bias')
            ],
            'weight_decay': 0.0
        }, {
            'params': [
                p for n, p in self.named_parameters()
                if not ('norm' in n or n.endswith('bias'))
            ]
        }]
        return groups
|
||||
|
||||
|
||||
class XLMRobertaWithHead(XLMRoberta):
    """XLMRoberta plus a two-layer projection head over mean-pooled features."""

    def __init__(self, **kwargs):
        # ``out_dim`` must be popped before the parent consumes the kwargs
        self.out_dim = kwargs.pop('out_dim')
        super().__init__(**kwargs)

        # projection head: dim -> mid -> out_dim, bias-free with GELU
        mid_dim = (self.dim + self.out_dim) // 2
        self.head = nn.Sequential(
            nn.Linear(self.dim, mid_dim, bias=False), nn.GELU(),
            nn.Linear(mid_dim, self.out_dim, bias=False))

    def forward(self, ids):
        # backbone features
        feats = super().forward(ids)

        # mean-pool over non-padding tokens only
        valid = ids.ne(self.pad_id).unsqueeze(-1).to(feats)
        pooled = (feats * valid).sum(dim=1) / valid.sum(dim=1)

        # project to the output dimension
        return self.head(pooled)
|
||||
|
||||
|
||||
class XLMRobertaCLIP(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
embed_dim=1024,
|
||||
image_size=224,
|
||||
patch_size=14,
|
||||
vision_dim=1280,
|
||||
vision_mlp_ratio=4,
|
||||
vision_heads=16,
|
||||
vision_layers=32,
|
||||
vision_pool='token',
|
||||
vision_pre_norm=True,
|
||||
vision_post_norm=False,
|
||||
activation='gelu',
|
||||
vocab_size=250002,
|
||||
max_text_len=514,
|
||||
type_size=1,
|
||||
pad_id=1,
|
||||
text_dim=1024,
|
||||
text_heads=16,
|
||||
text_layers=24,
|
||||
text_post_norm=True,
|
||||
text_dropout=0.1,
|
||||
attn_dropout=0.0,
|
||||
proj_dropout=0.0,
|
||||
embedding_dropout=0.0,
|
||||
norm_eps=1e-5):
|
||||
super().__init__()
|
||||
self.embed_dim = embed_dim
|
||||
self.image_size = image_size
|
||||
self.patch_size = patch_size
|
||||
self.vision_dim = vision_dim
|
||||
self.vision_mlp_ratio = vision_mlp_ratio
|
||||
self.vision_heads = vision_heads
|
||||
self.vision_layers = vision_layers
|
||||
self.vision_pre_norm = vision_pre_norm
|
||||
self.vision_post_norm = vision_post_norm
|
||||
self.activation = activation
|
||||
self.vocab_size = vocab_size
|
||||
self.max_text_len = max_text_len
|
||||
self.type_size = type_size
|
||||
self.pad_id = pad_id
|
||||
self.text_dim = text_dim
|
||||
self.text_heads = text_heads
|
||||
self.text_layers = text_layers
|
||||
self.text_post_norm = text_post_norm
|
||||
self.norm_eps = norm_eps
|
||||
|
||||
# models
|
||||
self.visual = VisionTransformer(
|
||||
image_size=image_size,
|
||||
patch_size=patch_size,
|
||||
dim=vision_dim,
|
||||
mlp_ratio=vision_mlp_ratio,
|
||||
out_dim=embed_dim,
|
||||
num_heads=vision_heads,
|
||||
num_layers=vision_layers,
|
||||
pool_type=vision_pool,
|
||||
pre_norm=vision_pre_norm,
|
||||
post_norm=vision_post_norm,
|
||||
activation=activation,
|
||||
attn_dropout=attn_dropout,
|
||||
proj_dropout=proj_dropout,
|
||||
embedding_dropout=embedding_dropout,
|
||||
norm_eps=norm_eps)
|
||||
self.textual = None
|
||||
self.log_scale = nn.Parameter(math.log(1 / 0.07) * torch.ones([]))
|
||||
|
||||
def forward(self, imgs, txt_ids):
    """Encode images and token ids; returns the (image, text) embedding pair.

    imgs: [B, 3, H, W] of torch.float32.
        - mean: [0.48145466, 0.4578275, 0.40821073]
        - std: [0.26862954, 0.26130258, 0.27577711]
    txt_ids: [B, L] of torch.long.
        Encoded by data.CLIPTokenizer.

    NOTE(review): ``__init__`` leaves ``self.textual = None``, so calling this
    method will raise unless a text tower is attached first — confirm intended
    usage before relying on it.
    """
    xi = self.visual(imgs)
    xt = self.textual(txt_ids)
    return xi, xt
|
||||
|
||||
def param_groups(self):
    """Split parameters into two optimizer groups.

    Norm parameters and biases get ``weight_decay=0.0``; everything else
    falls into the default-decay group. Ordering within each group follows
    ``named_parameters`` iteration order.
    """
    def _skip_decay(param_name):
        return 'norm' in param_name or param_name.endswith('bias')

    no_decay, with_decay = [], []
    for param_name, param in self.named_parameters():
        (no_decay if _skip_decay(param_name) else with_decay).append(param)
    return [
        {'params': no_decay, 'weight_decay': 0.0},
        {'params': with_decay},
    ]
|
||||
|
||||
|
||||
def _clip(pretrained=False,
          pretrained_name=None,
          model_cls=CLIP,
          return_transforms=False,
          return_tokenizer=False,
          tokenizer_padding='eos',
          dtype=torch.float32,
          device='cpu',
          **kwargs):
    """Instantiate a CLIP-family model, optionally returning eval transforms
    and a tokenizer.

    Returns ``model`` alone, or a tuple ``(model[, transforms][, tokenizer])``
    depending on the ``return_*`` flags.

    NOTE(review): when ``return_transforms`` or ``return_tokenizer`` is set,
    ``pretrained_name`` must be a string (``.lower()`` is called on it) even
    if ``pretrained`` is False — confirm callers always pass a name.
    """
    # init model
    if pretrained and pretrained_name:
        from sora import BUCKET, DOWNLOAD_TO_CACHE

        # init a meta model (no weight allocation; weights assigned on load)
        with torch.device('meta'):
            model = model_cls(**kwargs)

        # checkpoint path; prefer a dtype-specific file when one exists
        checkpoint = f'models/clip/{pretrained_name}'
        if dtype in (torch.float16, torch.bfloat16):
            suffix = '-' + {
                torch.float16: 'fp16',
                torch.bfloat16: 'bf16'
            }[dtype]
            # NOTE(review): ``object_exists`` is not imported in this function
            # (only BUCKET / DOWNLOAD_TO_CACHE are) — presumably a module-level
            # import elsewhere in the file; verify.
            if object_exists(BUCKET, f'{checkpoint}{suffix}.pth'):
                checkpoint = f'{checkpoint}{suffix}'
        checkpoint += '.pth'

        # load; assign=True materializes the meta parameters from the file
        model.load_state_dict(
            torch.load(DOWNLOAD_TO_CACHE(checkpoint), map_location=device),
            assign=True,
            strict=False)
    else:
        # init a randomly-initialized model directly on the target device
        with torch.device(device):
            model = model_cls(**kwargs)

    # output accumulates (model[, transforms][, tokenizer])
    output = (model,)

    # init transforms
    if return_transforms:
        # mean and std: SigLIP checkpoints use 0.5/0.5, others the CLIP stats
        if 'siglip' in pretrained_name.lower():
            mean, std = [0.5, 0.5, 0.5], [0.5, 0.5, 0.5]
        else:
            mean = [0.48145466, 0.4578275, 0.40821073]
            std = [0.26862954, 0.26130258, 0.27577711]

        # transforms: resize to the model's square input, tensorize, normalize
        transforms = T.Compose([
            T.Resize((model.image_size, model.image_size),
                     interpolation=T.InterpolationMode.BICUBIC),
            T.ToTensor(),
            T.Normalize(mean=mean, std=std)
        ])
        output += (transforms,)

    # init tokenizer, chosen by checkpoint-name heuristics
    if return_tokenizer:
        from sora import data
        if 'siglip' in pretrained_name.lower():
            tokenizer = data.HuggingfaceTokenizer(
                name=f'timm/{pretrained_name}',
                seq_len=model.text_len,
                clean='canonicalize')
        elif 'xlm' in pretrained_name.lower():
            # seq_len excludes the 2 special tokens XLM-R adds
            tokenizer = data.HuggingfaceTokenizer(
                name='xlm-roberta-large',
                seq_len=model.max_text_len - 2,
                clean='whitespace')
        elif 'mba' in pretrained_name.lower():
            tokenizer = data.HuggingfaceTokenizer(
                name='facebook/xlm-roberta-xl',
                seq_len=model.max_text_len - 2,
                clean='whitespace')
        else:
            tokenizer = data.CLIPTokenizer(
                seq_len=model.text_len, padding=tokenizer_padding)
        output += (tokenizer,)
    return output[0] if len(output) == 1 else output
|
||||
|
||||
|
||||
def clip_xlm_roberta_vit_h_14(
        pretrained=False,
        pretrained_name='open-clip-xlm-roberta-large-vit-huge-14',
        **kwargs):
    """Build the XLM-RoBERTa-Large + ViT-Huge/14 CLIP configuration.

    Keyword arguments override the defaults below before being forwarded
    to ``_clip``.
    """
    cfg = {
        'embed_dim': 1024,
        'image_size': 224,
        'patch_size': 14,
        'vision_dim': 1280,
        'vision_mlp_ratio': 4,
        'vision_heads': 16,
        'vision_layers': 32,
        'vision_pool': 'token',
        'activation': 'gelu',
        'vocab_size': 250002,
        'max_text_len': 514,
        'type_size': 1,
        'pad_id': 1,
        'text_dim': 1024,
        'text_heads': 16,
        'text_layers': 24,
        'text_post_norm': True,
        'text_dropout': 0.1,
        'attn_dropout': 0.0,
        'proj_dropout': 0.0,
        'embedding_dropout': 0.0,
    }
    cfg.update(**kwargs)
    return _clip(pretrained, pretrained_name, XLMRobertaCLIP, **cfg)
|
||||
|
||||
|
||||
class WanImageEncoder(torch.nn.Module):
    """CLIP-based image encoder used for Wan image conditioning."""

    def __init__(self):
        super().__init__()
        # Build the XLM-RoBERTa/ViT-H-14 CLIP with random weights (actual
        # weights are loaded later via the state-dict converter) plus its
        # evaluation transforms.
        self.model, self.transforms = clip_xlm_roberta_vit_h_14(
            pretrained=False,
            return_transforms=True,
            return_tokenizer=False,
            dtype=torch.float32,
            device="cpu")

    def encode_image(self, videos):
        """Encode frame tensors with the CLIP visual tower.

        videos: iterable of [B, 3, H, W] tensors; presumably scaled to
        [-1, 1] since they are mapped through x * 0.5 + 0.5 before
        normalization — TODO confirm against callers.
        """
        # preprocess: resize every clip to the CLIP input resolution
        size = (self.model.image_size,) * 2
        videos = torch.cat([
            F.interpolate(
                u,
                size=size,
                mode='bicubic',
                align_corners=False) for u in videos
        ])
        # mul_/add_ mutate the interpolated copy in place (safe: `videos` was
        # re-assigned above); then apply only the last transform of the
        # Compose, which per `_clip` is the Normalize step.
        videos = self.transforms.transforms[-1](videos.mul_(0.5).add_(0.5))

        # forward in whatever dtype the visual tower's parameters use
        dtype = next(iter(self.model.visual.parameters())).dtype
        videos = videos.to(dtype)
        # use_31_block=True: presumably returns features from block 31 rather
        # than the pooled embedding — TODO confirm in VisionTransformer.
        out = self.model.visual(videos, use_31_block=True)
        return out

    @staticmethod
    def state_dict_converter():
        """Converter that adapts raw CLIP checkpoints to this module's keys."""
        return WanImageEncoderStateDictConverter()
|
||||
|
||||
|
||||
class WanImageEncoderStateDictConverter:
    """Remap checkpoint keys to match ``WanImageEncoder``'s module layout."""

    def __init__(self):
        pass

    def from_diffusers(self, state_dict):
        # Diffusers-format checkpoints already match; pass through unchanged.
        return state_dict

    def from_civitai(self, state_dict):
        # Drop the text tower (it is never instantiated) and nest every
        # remaining key under the "model." prefix.
        return {
            "model." + name: param
            for name, param in state_dict.items()
            if not name.startswith("textual.")
        }
|
||||
|
||||
# ===== new file: diffsynth/models/wan_video_mot.py (169 lines) =====
|
||||
import torch
|
||||
from .wan_video_dit import DiTBlock, SelfAttention, rope_apply, flash_attention, modulate, MLP
|
||||
import einops
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
class MotSelfAttention(SelfAttention):
    """SelfAttention split into two callable phases.

    With ``is_before_attn=True`` it returns the rotary-embedded (q, k, v)
    triple so the caller can run a joint attention over concatenated
    streams; with ``is_before_attn=False`` it applies only the output
    projection to an externally computed attention result.
    """

    def __init__(self, dim: int, num_heads: int, eps: float = 1e-6):
        super().__init__(dim, num_heads, eps)

    def forward(self, x, freqs, is_before_attn=False):
        if not is_before_attn:
            # Post-attention phase: output projection only.
            return self.o(x)
        # Pre-attention phase: project, normalize, and apply RoPE to q/k.
        q = rope_apply(self.norm_q(self.q(x)), freqs, self.num_heads)
        k = rope_apply(self.norm_k(self.k(x)), freqs, self.num_heads)
        v = self.v(x)
        return q, k, v
|
||||
|
||||
|
||||
class MotWanAttentionBlock(DiTBlock):
    """DiT block that jointly attends a base stream and a reference ("mot")
    stream: both streams' q/k/v are concatenated into one attention call,
    then split and post-processed independently."""

    def __init__(self, has_image_input, dim, num_heads, ffn_dim, eps=1e-6, block_id=0):
        super().__init__(has_image_input, dim, num_heads, ffn_dim, eps=eps)
        # Index of the base-model layer this block is attached to.
        self.block_id = block_id

        # Replaces the parent's self_attn with the two-phase variant.
        self.self_attn = MotSelfAttention(dim, num_heads, eps)

    def forward(self, wan_block, x, context, t_mod, freqs, x_mot, context_mot, t_mod_mot, freqs_mot):
        """Run one dual-stream block.

        wan_block: the corresponding base-model DiT block (its weights are
            used for the base stream); x/context/t_mod/freqs belong to the
            base stream, the *_mot arguments to the reference stream.
        Returns the updated (x, x_mot) pair.
        """
        # 1. prepare scale parameters (AdaLN shift/scale/gate, 6 per stream)
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
            wan_block.modulation.to(dtype=t_mod.dtype, device=t_mod.device) + t_mod).chunk(6, dim=1)

        # Mot-stream modulation is computed in fp32 and chunked along a
        # rearranged axis (n=1 is a singleton grouping dimension).
        scale_params_mot_ref = self.modulation + t_mod_mot.float()
        scale_params_mot_ref = einops.rearrange(scale_params_mot_ref, '(b n) t c -> b n t c', n=1)
        shift_msa_mot_ref, scale_msa_mot_ref, gate_msa_mot_ref, c_shift_msa_mot_ref, c_scale_msa_mot_ref, c_gate_msa_mot_ref = scale_params_mot_ref.chunk(6, dim=2)

        # 2. Self-attention
        input_x = modulate(wan_block.norm1(x), shift_msa, scale_msa)
        # original block self-attn: q/k/v from the base block's own weights
        attn1 = wan_block.self_attn
        q = attn1.norm_q(attn1.q(input_x))
        k = attn1.norm_k(attn1.k(input_x))
        v = attn1.v(input_x)
        q = rope_apply(q, freqs, attn1.num_heads)
        k = rope_apply(k, freqs, attn1.num_heads)

        # mot block self-attn: normalize/modulate in fp32, cast back, then
        # get (q, k, v) from the two-phase attention (pre-attention phase)
        norm_x_mot = einops.rearrange(self.norm1(x_mot.float()), 'b (n t) c -> b n t c', n=1)
        norm_x_mot = modulate(norm_x_mot, shift_msa_mot_ref, scale_msa_mot_ref).type_as(x_mot)
        norm_x_mot = einops.rearrange(norm_x_mot, 'b n t c -> b (n t) c', n=1)
        q_mot, k_mot, v_mot = self.self_attn(norm_x_mot, freqs_mot, is_before_attn=True)

        # Joint attention over the concatenated token axes of both streams.
        tmp_hidden_states = flash_attention(
            torch.cat([q, q_mot], dim=-2),
            torch.cat([k, k_mot], dim=-2),
            torch.cat([v, v_mot], dim=-2),
            num_heads=attn1.num_heads)

        # Split the joint result back into per-stream chunks.
        attn_output, attn_output_mot = torch.split(tmp_hidden_states, [q.shape[-2], q_mot.shape[-2]], dim=-2)

        # Base stream: output projection + gated residual via the base block.
        attn_output = attn1.o(attn_output)
        x = wan_block.gate(x, gate_msa, attn_output)

        # Mot stream: output projection (post-attention phase), then gate.
        attn_output_mot = self.self_attn(x=attn_output_mot, freqs=freqs_mot, is_before_attn=False)
        # gate (applied in fp32, result cast back to the stream dtype)
        attn_output_mot = einops.rearrange(attn_output_mot, 'b (n t) c -> b n t c', n=1)
        attn_output_mot = attn_output_mot * gate_msa_mot_ref
        attn_output_mot = einops.rearrange(attn_output_mot, 'b n t c -> b (n t) c', n=1)
        x_mot = (x_mot.float() + attn_output_mot).type_as(x_mot)

        # 3. cross-attention and feed-forward (each stream with its own context)
        x = x + wan_block.cross_attn(wan_block.norm3(x), context)
        input_x = modulate(wan_block.norm2(x), shift_mlp, scale_mlp)
        x = wan_block.gate(x, gate_mlp, wan_block.ffn(input_x))

        x_mot = x_mot + self.cross_attn(self.norm3(x_mot), context_mot)
        # modulate (fp32 for numerical stability, cast back afterwards)
        norm_x_mot_ref = einops.rearrange(self.norm2(x_mot.float()), 'b (n t) c -> b n t c', n=1)
        norm_x_mot_ref = (norm_x_mot_ref * (1 + c_scale_msa_mot_ref) + c_shift_msa_mot_ref).type_as(x_mot)
        norm_x_mot_ref = einops.rearrange(norm_x_mot_ref, 'b n t c -> b (n t) c', n=1)
        input_x_mot = self.ffn(norm_x_mot_ref)
        # gate the FFN branch and add the residual in fp32
        input_x_mot = einops.rearrange(input_x_mot, 'b (n t) c -> b n t c', n=1)
        input_x_mot = input_x_mot.float() * c_gate_msa_mot_ref
        input_x_mot = einops.rearrange(input_x_mot, 'b n t c -> b (n t) c', n=1)
        x_mot = (x_mot.float() + input_x_mot).type_as(x_mot)

        return x, x_mot
|
||||
|
||||
|
||||
class MotWanModel(torch.nn.Module):
    """Side branch that threads a reference ("mot") stream through selected
    layers of a Wan DiT backbone via MotWanAttentionBlock adapters."""

    def __init__(
        self,
        mot_layers=(0, 4, 8, 12, 16, 20, 24, 28, 32, 36),
        patch_size=(1, 2, 2),
        has_image_input=True,
        has_image_pos_emb=False,
        dim=5120,
        num_heads=40,
        ffn_dim=13824,
        freq_dim=256,
        text_dim=4096,
        in_dim=36,
        eps=1e-6,
    ):
        super().__init__()
        self.mot_layers = mot_layers
        self.freq_dim = freq_dim
        self.dim = dim

        # base-layer index -> position within self.blocks
        self.mot_layers_mapping = {i: n for n, i in enumerate(self.mot_layers)}
        self.head_dim = dim // num_heads

        # (t, h, w) patchify of the mot latent video
        self.patch_embedding = nn.Conv3d(
            in_dim, dim, kernel_size=patch_size, stride=patch_size)

        self.text_embedding = nn.Sequential(
            nn.Linear(text_dim, dim),
            nn.GELU(approximate='tanh'),
            nn.Linear(dim, dim)
        )
        self.time_embedding = nn.Sequential(
            nn.Linear(freq_dim, dim),
            nn.SiLU(),
            nn.Linear(dim, dim)
        )
        # Projects the time embedding to the 6 AdaLN modulation vectors.
        self.time_projection = nn.Sequential(
            nn.SiLU(), nn.Linear(dim, dim * 6))
        if has_image_input:
            # 1280 presumably matches the CLIP visual feature width — confirm.
            self.img_emb = MLP(1280, dim, has_pos_emb=has_image_pos_emb)

        # one adapter block per adapted base layer
        self.blocks = torch.nn.ModuleList([
            MotWanAttentionBlock(has_image_input, dim, num_heads, ffn_dim, eps, block_id=i)
            for i in self.mot_layers
        ])

    def patchify(self, x: torch.Tensor):
        """Apply the Conv3d patch embedding (caller flattens spatial dims)."""
        x = self.patch_embedding(x)
        return x

    def compute_freqs_mot(self, f, h, w, end: int = 1024, theta: float = 10000.0):
        """Precompute the 3D RoPE table for a mot clip of (f, h, w) patches.

        The frame axis is precomputed starting at -f, i.e. the reference clip
        presumably occupies rope positions *before* the main clip — TODO
        confirm against the base model's freqs layout.
        Returns a [f*h*w, 1, head_dim/2] complex tensor.
        """
        def precompute_freqs_cis(dim: int, start: int = 0, end: int = 1024, theta: float = 10000.0):
            # 1d rope precompute (double precision for the frequency table)
            freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)
                           [: (dim // 2)].double() / dim))
            freqs = torch.outer(torch.arange(start, end, device=freqs.device), freqs)
            freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64
            return freqs_cis

        # head_dim split: remainder for frames, a third each for h and w
        f_freqs_cis = precompute_freqs_cis(self.head_dim - 2 * (self.head_dim // 3), -f, end, theta)
        h_freqs_cis = precompute_freqs_cis(self.head_dim // 3, 0, end, theta)
        w_freqs_cis = precompute_freqs_cis(self.head_dim // 3, 0, end, theta)

        # broadcast each axis table over the full (f, h, w) grid and flatten
        freqs = torch.cat([
            f_freqs_cis[:f].view(f, 1, 1, -1).expand(f, h, w, -1),
            h_freqs_cis[:h].view(1, h, 1, -1).expand(f, h, w, -1),
            w_freqs_cis[:w].view(1, 1, w, -1).expand(f, h, w, -1)
        ], dim=-1).reshape(f * h * w, 1, -1)
        return freqs

    def forward(self, wan_block, x, context, t_mod, freqs, x_mot, context_mot, t_mod_mot, freqs_mot, block_id):
        """Dispatch to the adapter block registered for base layer ``block_id``."""
        block = self.blocks[self.mot_layers_mapping[block_id]]
        x, x_mot = block(wan_block, x, context, t_mod, freqs, x_mot, context_mot, t_mod_mot, freqs_mot)
        return x, x_mot
|
||||
# ===== new file: diffsynth/models/wan_video_motion_controller.py (44 lines) =====
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from .wan_video_dit import sinusoidal_embedding_1d
|
||||
|
||||
|
||||
|
||||
class WanMotionControllerModel(torch.nn.Module):
    """Maps a scalar motion bucket id to a modulation vector via a small MLP.

    The bucket id is sinusoidally embedded (after a x10 scaling) and pushed
    through three linear layers; the final layer outputs ``dim * 6`` values.
    """

    def __init__(self, freq_dim=256, dim=1536):
        super().__init__()
        self.freq_dim = freq_dim
        self.linear = nn.Sequential(
            nn.Linear(freq_dim, dim),
            nn.SiLU(),
            nn.Linear(dim, dim),
            nn.SiLU(),
            nn.Linear(dim, dim * 6),
        )

    def forward(self, motion_bucket_id):
        """Embed the (scaled) bucket id and project it through the MLP."""
        embedding = sinusoidal_embedding_1d(self.freq_dim, motion_bucket_id * 10)
        return self.linear(embedding)

    def init(self):
        """Zero the final projection so the controller starts as a no-op."""
        final_layer = self.linear[-1]
        zeroed = {key: value * 0 for key, value in final_layer.state_dict().items()}
        final_layer.load_state_dict(zeroed)

    @staticmethod
    def state_dict_converter():
        return WanMotionControllerModelDictConverter()
|
||||
|
||||
|
||||
|
||||
class WanMotionControllerModelDictConverter:
    """Identity converter: motion-controller checkpoints need no key remapping."""

    def __init__(self):
        pass

    def from_diffusers(self, state_dict):
        # Keys already match the module layout.
        return state_dict

    def from_civitai(self, state_dict):
        # Keys already match the module layout.
        return state_dict
|
||||
|
||||
# ===== new file: diffsynth/models/wan_video_text_encoder.py (330 lines) =====
|
||||
import math
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from transformers import AutoTokenizer
|
||||
import ftfy
|
||||
import html
|
||||
import string
|
||||
import regex as re
|
||||
|
||||
def fp16_clamp(x):
|
||||
if x.dtype == torch.float16 and torch.isinf(x).any():
|
||||
clamp = torch.finfo(x.dtype).max - 1000
|
||||
x = torch.clamp(x, min=-clamp, max=clamp)
|
||||
return x
|
||||
|
||||
|
||||
class GELU(nn.Module):
    """Tanh-approximation GELU activation.

    Computes 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))).
    """

    def forward(self, x):
        inner = math.sqrt(2.0 / math.pi) * (x + 0.044715 * x.pow(3.0))
        return 0.5 * x * (1.0 + torch.tanh(inner))
|
||||
|
||||
|
||||
class T5LayerNorm(nn.Module):
    """T5-style RMS norm: no mean subtraction, learned scale, no bias.

    The RMS statistic is always computed in float32; in half precision the
    normalized result is cast back to the weight dtype before scaling.
    """

    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.dim = dim
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        inv_rms = torch.rsqrt(x.float().pow(2).mean(dim=-1, keepdim=True) + self.eps)
        x = x * inv_rms
        if self.weight.dtype in [torch.float16, torch.bfloat16]:
            x = x.type_as(self.weight)
        return self.weight * x
|
||||
|
||||
|
||||
class T5Attention(nn.Module):
    """Multi-head attention, T5-style: no query scaling, additive position
    bias, bias-free projections, dropout on the output."""

    def __init__(self, dim, dim_attn, num_heads, dropout=0.1):
        assert dim_attn % num_heads == 0
        super(T5Attention, self).__init__()
        self.dim = dim
        self.dim_attn = dim_attn
        self.num_heads = num_heads
        self.head_dim = dim_attn // num_heads

        # layers (all projections are bias-free, per T5)
        self.q = nn.Linear(dim, dim_attn, bias=False)
        self.k = nn.Linear(dim, dim_attn, bias=False)
        self.v = nn.Linear(dim, dim_attn, bias=False)
        self.o = nn.Linear(dim_attn, dim, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, context=None, mask=None, pos_bias=None):
        """
        x: [B, L1, C].
        context: [B, L2, C] or None (self-attention when None).
        mask: [B, L2] or [B, L1, L2] or None; zero entries are masked out.
        pos_bias: optional additive attention bias broadcastable to
            [B, N, L1, L2] (e.g. from T5RelativeEmbedding).
        """
        # check inputs
        context = x if context is None else context
        b, n, c = x.size(0), self.num_heads, self.head_dim

        # compute query, key, value, reshaped to [B, L, N, C_head]
        q = self.q(x).view(b, -1, n, c)
        k = self.k(context).view(b, -1, n, c)
        v = self.v(context).view(b, -1, n, c)

        # attention bias: position bias plus mask; masked keys get the most
        # negative representable value so softmax sends them to ~0
        attn_bias = x.new_zeros(b, n, q.size(1), k.size(1))
        if pos_bias is not None:
            attn_bias += pos_bias
        if mask is not None:
            assert mask.ndim in [2, 3]
            mask = mask.view(b, 1, 1,
                             -1) if mask.ndim == 2 else mask.unsqueeze(1)
            attn_bias.masked_fill_(mask == 0, torch.finfo(x.dtype).min)

        # compute attention (T5 does not use scaling)
        attn = torch.einsum('binc,bjnc->bnij', q, k) + attn_bias
        # softmax in float32 for stability, then cast back
        attn = F.softmax(attn.float(), dim=-1).type_as(attn)
        x = torch.einsum('bnij,bjnc->binc', attn, v)

        # output: merge heads, project, dropout
        x = x.reshape(b, -1, n * c)
        x = self.o(x)
        x = self.dropout(x)
        return x
|
||||
|
||||
|
||||
class T5FeedForward(nn.Module):
    """Gated feed-forward network (GEGLU-style).

    Computes fc2(dropout(fc1(x) * gelu_gate(x))) with dropout also applied
    to the final projection output; all linear layers are bias-free.
    """

    def __init__(self, dim, dim_ffn, dropout=0.1):
        super().__init__()
        self.dim = dim
        self.dim_ffn = dim_ffn

        # layers: the gate path shares the input but passes through GELU
        self.gate = nn.Sequential(nn.Linear(dim, dim_ffn, bias=False), GELU())
        self.fc1 = nn.Linear(dim, dim_ffn, bias=False)
        self.fc2 = nn.Linear(dim_ffn, dim, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        hidden = self.dropout(self.fc1(x) * self.gate(x))
        return self.dropout(self.fc2(hidden))
|
||||
|
||||
|
||||
class T5SelfAttention(nn.Module):
    """One T5 encoder layer: pre-norm self-attention plus gated feed-forward,
    with fp16 overflow clamping applied to each residual sum."""

    def __init__(self,
                 dim,
                 dim_attn,
                 dim_ffn,
                 num_heads,
                 num_buckets,
                 shared_pos=True,
                 dropout=0.1):
        super(T5SelfAttention, self).__init__()
        self.dim = dim
        self.dim_attn = dim_attn
        self.dim_ffn = dim_ffn
        self.num_heads = num_heads
        self.num_buckets = num_buckets
        self.shared_pos = shared_pos

        # layers
        self.norm1 = T5LayerNorm(dim)
        self.attn = T5Attention(dim, dim_attn, num_heads, dropout)
        self.norm2 = T5LayerNorm(dim)
        self.ffn = T5FeedForward(dim, dim_ffn, dropout)
        # With shared positions the bias is supplied by the caller; otherwise
        # each layer owns its own relative-position table.
        self.pos_embedding = None if shared_pos else T5RelativeEmbedding(
            num_buckets, num_heads, bidirectional=True)

    def forward(self, x, mask=None, pos_bias=None):
        # pos_bias is expected only when shared_pos is True.
        e = pos_bias if self.shared_pos else self.pos_embedding(
            x.size(1), x.size(1))
        # Residual connections, each clamped to avoid fp16 inf propagation.
        x = fp16_clamp(x + self.attn(self.norm1(x), mask=mask, pos_bias=e))
        x = fp16_clamp(x + self.ffn(self.norm2(x)))
        return x
|
||||
|
||||
|
||||
class T5RelativeEmbedding(nn.Module):
    """T5 relative position bias: relative distances are bucketized (exact
    for small offsets, log-spaced up to ``max_dist``) and mapped to learned
    per-head scalars."""

    def __init__(self, num_buckets, num_heads, bidirectional, max_dist=128):
        super(T5RelativeEmbedding, self).__init__()
        self.num_buckets = num_buckets
        self.num_heads = num_heads
        self.bidirectional = bidirectional
        self.max_dist = max_dist

        # layers: one learned scalar per (bucket, head)
        self.embedding = nn.Embedding(num_buckets, num_heads)

    def forward(self, lq, lk):
        """Return the additive attention bias [1, N, Lq, Lk] for the given
        query/key lengths."""
        device = self.embedding.weight.device
        # relative offset matrix: key position minus query position
        rel_pos = torch.arange(lk, device=device).unsqueeze(0) - \
            torch.arange(lq, device=device).unsqueeze(1)
        rel_pos = self._relative_position_bucket(rel_pos)
        rel_pos_embeds = self.embedding(rel_pos)
        rel_pos_embeds = rel_pos_embeds.permute(2, 0, 1).unsqueeze(
            0)  # [1, N, Lq, Lk]
        return rel_pos_embeds.contiguous()

    def _relative_position_bucket(self, rel_pos):
        # preprocess: bidirectional models split the buckets between the two
        # signs; unidirectional models bucket only non-positive (past) offsets
        if self.bidirectional:
            num_buckets = self.num_buckets // 2
            rel_buckets = (rel_pos > 0).long() * num_buckets
            rel_pos = torch.abs(rel_pos)
        else:
            num_buckets = self.num_buckets
            rel_buckets = 0
            rel_pos = -torch.min(rel_pos, torch.zeros_like(rel_pos))

        # exact buckets for distances < max_exact, logarithmically spaced
        # buckets for larger distances, capped at the last bucket
        max_exact = num_buckets // 2
        rel_pos_large = max_exact + (torch.log(rel_pos.float() / max_exact) /
                                     math.log(self.max_dist / max_exact) *
                                     (num_buckets - max_exact)).long()
        rel_pos_large = torch.min(
            rel_pos_large, torch.full_like(rel_pos_large, num_buckets - 1))
        rel_buckets += torch.where(rel_pos < max_exact, rel_pos, rel_pos_large)
        return rel_buckets
|
||||
|
||||
def init_weights(m):
    """Per-module T5-style weight initialization, applied via ``Module.apply``.

    Uses fan-scaled normal initializations matching the module dimensions;
    modules of other types are left untouched.
    """
    if isinstance(m, T5LayerNorm):
        nn.init.ones_(m.weight)
    elif isinstance(m, T5FeedForward):
        # gate/fc1 scale with the input dim, fc2 with the FFN dim
        nn.init.normal_(m.gate[0].weight, std=m.dim**-0.5)
        nn.init.normal_(m.fc1.weight, std=m.dim**-0.5)
        nn.init.normal_(m.fc2.weight, std=m.dim_ffn**-0.5)
    elif isinstance(m, T5Attention):
        # query and output projections get extra down-scaling
        nn.init.normal_(m.q.weight, std=(m.dim * m.dim_attn)**-0.5)
        nn.init.normal_(m.k.weight, std=m.dim**-0.5)
        nn.init.normal_(m.v.weight, std=m.dim**-0.5)
        nn.init.normal_(m.o.weight, std=(m.num_heads * m.dim_attn)**-0.5)
    elif isinstance(m, T5RelativeEmbedding):
        nn.init.normal_(
            m.embedding.weight, std=(2 * m.num_buckets * m.num_heads)**-0.5)
|
||||
|
||||
|
||||
class WanTextEncoder(torch.nn.Module):
    """T5-style encoder-only text model used as the Wan prompt encoder.

    Built from the T5 components defined above (RMS norm, relative position
    bias, gated FFN); weights are initialized with ``init_weights``.
    """

    def __init__(self,
                 vocab=256384,
                 dim=4096,
                 dim_attn=4096,
                 dim_ffn=10240,
                 num_heads=64,
                 num_layers=24,
                 num_buckets=32,
                 shared_pos=False,
                 dropout=0.1):
        super(WanTextEncoder, self).__init__()
        self.dim = dim
        self.dim_attn = dim_attn
        self.dim_ffn = dim_ffn
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.num_buckets = num_buckets
        self.shared_pos = shared_pos

        # layers; ``vocab`` may be a pre-built nn.Embedding to share weights
        self.token_embedding = vocab if isinstance(vocab, nn.Embedding) \
            else nn.Embedding(vocab, dim)
        # With shared_pos, one relative-position table serves all layers;
        # otherwise each T5SelfAttention layer constructs its own.
        self.pos_embedding = T5RelativeEmbedding(
            num_buckets, num_heads, bidirectional=True) if shared_pos else None
        self.dropout = nn.Dropout(dropout)
        self.blocks = nn.ModuleList([
            T5SelfAttention(dim, dim_attn, dim_ffn, num_heads, num_buckets,
                            shared_pos, dropout) for _ in range(num_layers)
        ])
        self.norm = T5LayerNorm(dim)

        # initialize weights
        self.apply(init_weights)

    def forward(self, ids, mask=None):
        """Encode token ids [B, L] (with optional attention mask) into
        contextual features [B, L, dim]."""
        x = self.token_embedding(ids)
        x = self.dropout(x)
        # shared position bias computed once and passed to every layer
        e = self.pos_embedding(x.size(1),
                               x.size(1)) if self.shared_pos else None
        for block in self.blocks:
            x = block(x, mask, pos_bias=e)
        x = self.norm(x)
        x = self.dropout(x)
        return x
|
||||
|
||||
|
||||
def basic_clean(text):
    """Fix mojibake via ftfy and undo (possibly double) HTML escaping."""
    fixed = ftfy.fix_text(text)
    # unescape twice: handles strings that were HTML-escaped twice upstream
    return html.unescape(html.unescape(fixed)).strip()
|
||||
|
||||
|
||||
def whitespace_clean(text):
    """Collapse every run of whitespace to a single space and trim the ends."""
    return re.sub(r'\s+', ' ', text).strip()
|
||||
|
||||
|
||||
def canonicalize(text, keep_punctuation_exact_string=None):
    """Lowercase, strip punctuation, and collapse whitespace.

    Underscores become spaces first. If ``keep_punctuation_exact_string`` is
    given, that exact substring is preserved while punctuation is stripped
    from the pieces around it.
    """
    text = text.replace('_', ' ')
    strip_punct = str.maketrans('', '', string.punctuation)
    if keep_punctuation_exact_string:
        pieces = text.split(keep_punctuation_exact_string)
        text = keep_punctuation_exact_string.join(
            piece.translate(strip_punct) for piece in pieces)
    else:
        text = text.translate(strip_punct)
    text = re.sub(r'\s+', ' ', text.lower())
    return text.strip()
|
||||
|
||||
|
||||
class HuggingfaceTokenizer:
    """Thin wrapper around ``transformers.AutoTokenizer`` adding optional
    text cleaning and fixed-length padding/truncation."""

    def __init__(self, name, seq_len=None, clean=None, **kwargs):
        # ``clean`` selects one of the module-level cleaning helpers.
        assert clean in (None, 'whitespace', 'lower', 'canonicalize')
        self.name = name
        self.seq_len = seq_len
        self.clean = clean

        # init tokenizer (downloads / loads from the HF hub cache)
        self.tokenizer = AutoTokenizer.from_pretrained(name, **kwargs)
        self.vocab_size = self.tokenizer.vocab_size

    def __call__(self, sequence, **kwargs):
        """Tokenize a string or list of strings into id tensors.

        Returns ``input_ids`` or, with ``return_mask=True``,
        ``(input_ids, attention_mask)``.
        """
        return_mask = kwargs.pop('return_mask', False)

        # arguments: pad/truncate to seq_len when one was configured;
        # caller kwargs override these defaults
        _kwargs = {'return_tensors': 'pt'}
        if self.seq_len is not None:
            _kwargs.update({
                'padding': 'max_length',
                'truncation': True,
                'max_length': self.seq_len
            })
        _kwargs.update(**kwargs)

        # tokenization (single strings are wrapped into a one-element batch)
        if isinstance(sequence, str):
            sequence = [sequence]
        if self.clean:
            sequence = [self._clean(u) for u in sequence]
        ids = self.tokenizer(sequence, **_kwargs)

        # output
        if return_mask:
            return ids.input_ids, ids.attention_mask
        else:
            return ids.input_ids

    def _clean(self, text):
        # Dispatch to the module-level cleaning helpers per configured mode.
        if self.clean == 'whitespace':
            text = whitespace_clean(basic_clean(text))
        elif self.clean == 'lower':
            text = whitespace_clean(basic_clean(text)).lower()
        elif self.clean == 'canonicalize':
            text = canonicalize(basic_clean(text))
        return text
|
||||
# ===== new file: diffsynth/models/wan_video_vace.py (87 lines) =====
|
||||
import torch
|
||||
from .wan_video_dit import DiTBlock
|
||||
|
||||
|
||||
class VaceWanAttentionBlock(DiTBlock):
    """DiT block for the VACE control branch.

    The running tensor ``c`` is a stack [skip_0, ..., skip_{k-1}, hidden]:
    each block pops the current hidden state, runs the parent DiTBlock on it,
    and pushes its projected skip plus the new hidden state.
    """

    def __init__(self, has_image_input, dim, num_heads, ffn_dim, eps=1e-6, block_id=0):
        super().__init__(has_image_input, dim, num_heads, ffn_dim, eps=eps)
        self.block_id = block_id
        if block_id == 0:
            # Only the first block projects the control signal into the stream.
            self.before_proj = torch.nn.Linear(self.dim, self.dim)
        # Every block projects its output into a per-layer skip ("hint").
        self.after_proj = torch.nn.Linear(self.dim, self.dim)

    def forward(self, c, x, context, t_mod, freqs):
        if self.block_id == 0:
            # First block: inject the main hidden state x into the control
            # stream and start an empty skip list.
            c = self.before_proj(c) + x
            all_c = []
        else:
            # Later blocks: unpack the stack and take the running hidden state.
            all_c = list(torch.unbind(c))
            c = all_c.pop(-1)
        c = super().forward(c, context, t_mod, freqs)
        c_skip = self.after_proj(c)
        # Re-stack: accumulated skips, then the updated hidden state last.
        all_c += [c_skip, c]
        c = torch.stack(all_c)
        return c
|
||||
|
||||
|
||||
class VaceWanModel(torch.nn.Module):
    """VACE control branch: patch-embeds a control video and produces one
    skip "hint" per adapted base-model layer."""

    def __init__(
        self,
        vace_layers=(0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28),
        vace_in_dim=96,
        patch_size=(1, 2, 2),
        has_image_input=False,
        dim=1536,
        num_heads=12,
        ffn_dim=8960,
        eps=1e-6,
    ):
        super().__init__()
        self.vace_layers = vace_layers
        self.vace_in_dim = vace_in_dim
        # base-layer index -> position within self.vace_blocks
        self.vace_layers_mapping = {i: n for n, i in enumerate(self.vace_layers)}

        # vace blocks, one per adapted base layer
        self.vace_blocks = torch.nn.ModuleList([
            VaceWanAttentionBlock(has_image_input, dim, num_heads, ffn_dim, eps, block_id=i)
            for i in self.vace_layers
        ])

        # vace patch embeddings ((t, h, w) Conv3d patchify of the control video)
        self.vace_patch_embedding = torch.nn.Conv3d(vace_in_dim, dim, kernel_size=patch_size, stride=patch_size)

    def forward(
        self, x, vace_context, context, t_mod, freqs,
        use_gradient_checkpointing: bool = False,
        use_gradient_checkpointing_offload: bool = False,
    ):
        """Run the control branch; returns per-layer hint tensors.

        x: main-stream hidden states [B, L, dim] (used for length padding and
            as the injection input of the first block).
        vace_context: iterable of control latents, one Conv3d input per item.
        Returns a tuple of len(vace_layers) hint tensors.
        """
        # Patch-embed each control clip, flatten to tokens, and zero-pad each
        # to the main sequence length before batching.
        c = [self.vace_patch_embedding(u.unsqueeze(0)) for u in vace_context]
        c = [u.flatten(2).transpose(1, 2) for u in c]
        c = torch.cat([
            torch.cat([u, u.new_zeros(1, x.shape[1] - u.size(1), u.size(2))],
                      dim=1) for u in c
        ])

        def create_custom_forward(module):
            # Wrapper required by torch.utils.checkpoint.
            def custom_forward(*inputs):
                return module(*inputs)
            return custom_forward

        for block in self.vace_blocks:
            if use_gradient_checkpointing_offload:
                # Checkpoint and additionally keep saved activations on CPU.
                with torch.autograd.graph.save_on_cpu():
                    c = torch.utils.checkpoint.checkpoint(
                        create_custom_forward(block),
                        c, x, context, t_mod, freqs,
                        use_reentrant=False,
                    )
            elif use_gradient_checkpointing:
                c = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(block),
                    c, x, context, t_mod, freqs,
                    use_reentrant=False,
                )
            else:
                c = block(c, x, context, t_mod, freqs)
        # All stacked entries except the last are per-layer skip hints; the
        # last is the running hidden state, which is discarded.
        hints = torch.unbind(c)[:-1]
        return hints
|
||||
# ===== new file: diffsynth/models/wan_video_vae.py (1382 lines; diff suppressed as too large) =====
# ===== new file: diffsynth/models/wav2vec.py (204 lines) =====
|
||||
import math
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
def get_sample_indices(original_fps, total_frames, target_fps, num_sample, fixed_start=None):
    """Pick ``num_sample`` frame indices spanning ``num_sample / target_fps``
    seconds of a clip recorded at ``original_fps``.

    The sampling window starts at ``fixed_start`` (frame index) when given
    and non-negative; otherwise a random valid start frame is drawn with
    ``np.random.randint``.

    Returns an int ndarray of frame indices, clipped to [0, total_frames - 1].
    Raises ValueError when the clip is shorter than the required duration.
    """
    required_duration = num_sample / target_fps
    required_origin_frames = int(np.ceil(required_duration * original_fps))
    if required_duration > total_frames / original_fps:
        raise ValueError("required_duration must be less than video length")

    # Idiomatic None check (was `not fixed_start is None`).
    if fixed_start is not None and fixed_start >= 0:
        start_frame = fixed_start
    else:
        max_start = total_frames - required_origin_frames
        if max_start < 0:
            raise ValueError("video length is too short")
        start_frame = np.random.randint(0, max_start + 1)
    start_time = start_frame / original_fps

    # Evenly spaced sample times; endpoint excluded so spacing is exactly
    # 1 / target_fps.
    end_time = start_time + required_duration
    time_points = np.linspace(start_time, end_time, num_sample, endpoint=False)

    frame_indices = np.round(np.array(time_points) * original_fps).astype(int)
    frame_indices = np.clip(frame_indices, 0, total_frames - 1)
    return frame_indices
|
||||
|
||||
|
||||
def linear_interpolation(features, input_fps, output_fps, output_len=None):
    """Resample per-frame features along the time axis by linear interpolation.

    features: shape [1, T, C] (e.g. [1, T, 512] audio features).
    input_fps: frame rate of the incoming features (f_a).
    output_fps: target frame rate (f_m).
    output_len: explicit output length; when None it is derived from the
        duration implied by the fps change.

    Returns features of shape [1, output_len, C].
    """
    features = features.transpose(1, 2)  # -> [1, C, T] for F.interpolate
    if output_len is None:
        # Derive the output length only when needed (previously the duration
        # was computed unconditionally).
        duration = features.shape[2] / float(input_fps)
        output_len = int(duration * output_fps)
    # align_corners=True keeps the first/last feature values at the endpoints.
    output_features = F.interpolate(
        features, size=output_len, align_corners=True, mode='linear')
    return output_features.transpose(1, 2)  # back to [1, output_len, C]
|
||||
|
||||
|
||||
class WanS2VAudioEncoder(torch.nn.Module):
|
||||
|
||||
def __init__(self):
    """Build a randomly-initialized wav2vec2-large-xlsr-53 CTC model.

    The config below is the facebook/wav2vec2-large-xlsr-53 architecture
    (weights are expected to be loaded separately). ``video_rate`` is the
    target video frame rate that audio features are resampled to.
    """
    super().__init__()
    # Local import keeps transformers out of module import time.
    from transformers import Wav2Vec2ForCTC, Wav2Vec2Config
    config = {
        "_name_or_path": "facebook/wav2vec2-large-xlsr-53",
        "activation_dropout": 0.05,
        "apply_spec_augment": True,
        "architectures": ["Wav2Vec2ForCTC"],
        "attention_dropout": 0.1,
        "bos_token_id": 1,
        "conv_bias": True,
        "conv_dim": [512, 512, 512, 512, 512, 512, 512],
        "conv_kernel": [10, 3, 3, 3, 3, 2, 2],
        "conv_stride": [5, 2, 2, 2, 2, 2, 2],
        "ctc_loss_reduction": "mean",
        "ctc_zero_infinity": True,
        "do_stable_layer_norm": True,
        "eos_token_id": 2,
        "feat_extract_activation": "gelu",
        "feat_extract_dropout": 0.0,
        "feat_extract_norm": "layer",
        "feat_proj_dropout": 0.05,
        "final_dropout": 0.0,
        "hidden_act": "gelu",
        "hidden_dropout": 0.05,
        "hidden_size": 1024,
        "initializer_range": 0.02,
        "intermediate_size": 4096,
        "layer_norm_eps": 1e-05,
        "layerdrop": 0.05,
        "mask_channel_length": 10,
        "mask_channel_min_space": 1,
        "mask_channel_other": 0.0,
        "mask_channel_prob": 0.0,
        "mask_channel_selection": "static",
        "mask_feature_length": 10,
        "mask_feature_prob": 0.0,
        "mask_time_length": 10,
        "mask_time_min_space": 1,
        "mask_time_other": 0.0,
        "mask_time_prob": 0.05,
        "mask_time_selection": "static",
        "model_type": "wav2vec2",
        "num_attention_heads": 16,
        "num_conv_pos_embedding_groups": 16,
        "num_conv_pos_embeddings": 128,
        "num_feat_extract_layers": 7,
        "num_hidden_layers": 24,
        "pad_token_id": 0,
        "transformers_version": "4.7.0.dev0",
        "vocab_size": 33
    }
    self.model = Wav2Vec2ForCTC(Wav2Vec2Config(**config))
    # Target video frame rate for feature resampling (frames per second).
    self.video_rate = 30
|
||||
|
||||
def extract_audio_feat(self, input_audio, sample_rate, processor, return_all_layers=False, dtype=torch.float32, device='cpu'):
|
||||
input_values = processor(input_audio, sampling_rate=sample_rate, return_tensors="pt").input_values.to(dtype=dtype, device=device)
|
||||
|
||||
# retrieve logits & take argmax
|
||||
res = self.model(input_values, output_hidden_states=True)
|
||||
if return_all_layers:
|
||||
feat = torch.cat(res.hidden_states)
|
||||
else:
|
||||
feat = res.hidden_states[-1]
|
||||
feat = linear_interpolation(feat, input_fps=50, output_fps=self.video_rate)
|
||||
return feat
|
||||
|
||||
def get_audio_embed_bucket(self, audio_embed, stride=2, batch_frames=12, m=2):
|
||||
num_layers, audio_frame_num, audio_dim = audio_embed.shape
|
||||
|
||||
if num_layers > 1:
|
||||
return_all_layers = True
|
||||
else:
|
||||
return_all_layers = False
|
||||
|
||||
min_batch_num = int(audio_frame_num / (batch_frames * stride)) + 1
|
||||
|
||||
bucket_num = min_batch_num * batch_frames
|
||||
batch_idx = [stride * i for i in range(bucket_num)]
|
||||
batch_audio_eb = []
|
||||
for bi in batch_idx:
|
||||
if bi < audio_frame_num:
|
||||
audio_sample_stride = 2
|
||||
chosen_idx = list(range(bi - m * audio_sample_stride, bi + (m + 1) * audio_sample_stride, audio_sample_stride))
|
||||
chosen_idx = [0 if c < 0 else c for c in chosen_idx]
|
||||
chosen_idx = [audio_frame_num - 1 if c >= audio_frame_num else c for c in chosen_idx]
|
||||
|
||||
if return_all_layers:
|
||||
frame_audio_embed = audio_embed[:, chosen_idx].flatten(start_dim=-2, end_dim=-1)
|
||||
else:
|
||||
frame_audio_embed = audio_embed[0][chosen_idx].flatten()
|
||||
else:
|
||||
frame_audio_embed = \
|
||||
torch.zeros([audio_dim * (2 * m + 1)], device=audio_embed.device) if not return_all_layers \
|
||||
else torch.zeros([num_layers, audio_dim * (2 * m + 1)], device=audio_embed.device)
|
||||
batch_audio_eb.append(frame_audio_embed)
|
||||
batch_audio_eb = torch.cat([c.unsqueeze(0) for c in batch_audio_eb], dim=0)
|
||||
|
||||
return batch_audio_eb, min_batch_num
|
||||
|
||||
def get_audio_embed_bucket_fps(self, audio_embed, fps=16, batch_frames=81, m=0):
|
||||
num_layers, audio_frame_num, audio_dim = audio_embed.shape
|
||||
|
||||
if num_layers > 1:
|
||||
return_all_layers = True
|
||||
else:
|
||||
return_all_layers = False
|
||||
|
||||
scale = self.video_rate / fps
|
||||
|
||||
min_batch_num = int(audio_frame_num / (batch_frames * scale)) + 1
|
||||
|
||||
bucket_num = min_batch_num * batch_frames
|
||||
padd_audio_num = math.ceil(min_batch_num * batch_frames / fps * self.video_rate) - audio_frame_num
|
||||
batch_idx = get_sample_indices(
|
||||
original_fps=self.video_rate, total_frames=audio_frame_num + padd_audio_num, target_fps=fps, num_sample=bucket_num, fixed_start=0
|
||||
)
|
||||
batch_audio_eb = []
|
||||
audio_sample_stride = int(self.video_rate / fps)
|
||||
for bi in batch_idx:
|
||||
if bi < audio_frame_num:
|
||||
|
||||
chosen_idx = list(range(bi - m * audio_sample_stride, bi + (m + 1) * audio_sample_stride, audio_sample_stride))
|
||||
chosen_idx = [0 if c < 0 else c for c in chosen_idx]
|
||||
chosen_idx = [audio_frame_num - 1 if c >= audio_frame_num else c for c in chosen_idx]
|
||||
|
||||
if return_all_layers:
|
||||
frame_audio_embed = audio_embed[:, chosen_idx].flatten(start_dim=-2, end_dim=-1)
|
||||
else:
|
||||
frame_audio_embed = audio_embed[0][chosen_idx].flatten()
|
||||
else:
|
||||
frame_audio_embed = \
|
||||
torch.zeros([audio_dim * (2 * m + 1)], device=audio_embed.device) if not return_all_layers \
|
||||
else torch.zeros([num_layers, audio_dim * (2 * m + 1)], device=audio_embed.device)
|
||||
batch_audio_eb.append(frame_audio_embed)
|
||||
batch_audio_eb = torch.cat([c.unsqueeze(0) for c in batch_audio_eb], dim=0)
|
||||
|
||||
return batch_audio_eb, min_batch_num
|
||||
|
||||
def get_audio_feats_per_inference(self, input_audio, sample_rate, processor, fps=16, batch_frames=80, m=0, dtype=torch.float32, device='cpu'):
|
||||
audio_feat = self.extract_audio_feat(input_audio, sample_rate, processor, return_all_layers=True, dtype=dtype, device=device)
|
||||
audio_embed_bucket, min_batch_num = self.get_audio_embed_bucket_fps(audio_feat, fps=fps, batch_frames=batch_frames, m=m)
|
||||
audio_embed_bucket = audio_embed_bucket.unsqueeze(0).permute(0, 2, 3, 1).to(device, dtype)
|
||||
audio_embeds = [audio_embed_bucket[..., i * batch_frames:(i + 1) * batch_frames] for i in range(min_batch_num)]
|
||||
return audio_embeds
|
||||
|
||||
@staticmethod
|
||||
def state_dict_converter():
|
||||
return WanS2VAudioEncoderStateDictConverter()
|
||||
|
||||
|
||||
class WanS2VAudioEncoderStateDictConverter():
    """Remaps raw wav2vec2 checkpoint keys onto the WanS2VAudioEncoder wrapper."""

    def __init__(self):
        pass

    def from_civitai(self, state_dict):
        """Prefix every key with 'model.' so weights load into the nested wav2vec2."""
        converted = {}
        for key, value in state_dict.items():
            converted["model." + key] = value
        return converted
|
||||
@@ -151,16 +151,11 @@ class QwenImagePipeline(BasePipeline):
|
||||
models = {name: getattr(self, name) for name in self.in_iteration_models}
|
||||
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
|
||||
timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device)
|
||||
|
||||
# Inference
|
||||
noise_pred_posi = self.model_fn(**models, **inputs_shared, **inputs_posi, timestep=timestep, progress_id=progress_id)
|
||||
if cfg_scale != 1.0:
|
||||
noise_pred_nega = self.model_fn(**models, **inputs_shared, **inputs_nega, timestep=timestep, progress_id=progress_id)
|
||||
noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
|
||||
else:
|
||||
noise_pred = noise_pred_posi
|
||||
|
||||
# Scheduler
|
||||
noise_pred = self.cfg_guided_model_fn(
|
||||
self.model_fn, cfg_scale,
|
||||
inputs_shared, inputs_posi, inputs_nega,
|
||||
**models, timestep=timestep, progress_id=progress_id
|
||||
)
|
||||
inputs_shared["latents"] = self.step(self.scheduler, progress_id=progress_id, noise_pred=noise_pred, **inputs_shared)
|
||||
|
||||
# Decode
|
||||
|
||||
1509
diffsynth/pipelines/wan_video.py
Normal file
1509
diffsynth/pipelines/wan_video.py
Normal file
File diff suppressed because it is too large
Load Diff
217
diffsynth/utils/data/__init__.py
Normal file
217
diffsynth/utils/data/__init__.py
Normal file
@@ -0,0 +1,217 @@
|
||||
import imageio, os
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
from tqdm import tqdm
|
||||
import subprocess
|
||||
import shutil
|
||||
|
||||
|
||||
class LowMemoryVideo:
    """Frame-at-a-time video reader backed by imageio (avoids loading the whole file)."""

    def __init__(self, file_name):
        self.reader = imageio.get_reader(file_name)

    def __len__(self):
        return self.reader.count_frames()

    def __getitem__(self, item):
        # Decode a single frame on demand and normalize it to an RGB PIL image.
        raw = np.array(self.reader.get_data(item))
        return Image.fromarray(raw).convert("RGB")

    def __del__(self):
        self.reader.close()
|
||||
|
||||
|
||||
def split_file_name(file_name):
    """Tokenize a file name into a natural-sort key.

    Consecutive ASCII digits collapse into a single int; every other character
    becomes its own element. E.g. "img12.png" -> ('i', 'm', 'g', 12, '.', 'p', 'n', 'g').
    """
    tokens = []
    digit_run = ""
    for ch in file_name:
        if "0" <= ch <= "9":
            digit_run += ch
        else:
            if digit_run:
                tokens.append(int(digit_run))
                digit_run = ""
            tokens.append(ch)
    # Flush a trailing run of digits.
    if digit_run:
        tokens.append(int(digit_run))
    return tuple(tokens)
|
||||
|
||||
|
||||
def search_for_images(folder):
    """Return full paths of .jpg/.png files under `folder` in natural numeric order."""
    names = [name for name in os.listdir(folder) if name.endswith((".jpg", ".png"))]
    # Sort by the natural-sort key, breaking ties on the raw name (matches the
    # original (key, name) tuple sort).
    names.sort(key=lambda name: (split_file_name(name), name))
    return [os.path.join(folder, name) for name in names]
|
||||
|
||||
|
||||
class LowMemoryImageFolder:
    """Frame-at-a-time reader over a folder of images; files open lazily per access."""

    def __init__(self, folder, file_list=None):
        if file_list is not None:
            self.file_list = [os.path.join(folder, name) for name in file_list]
        else:
            # No explicit list: discover .jpg/.png files in natural order.
            self.file_list = search_for_images(folder)

    def __len__(self):
        return len(self.file_list)

    def __getitem__(self, item):
        path = self.file_list[item]
        return Image.open(path).convert("RGB")

    def __del__(self):
        pass
|
||||
|
||||
|
||||
def crop_and_resize(image, height, width):
    """Center-crop `image` to the target aspect ratio, then resize to (width, height)."""
    pixels = np.array(image)
    src_h, src_w, _ = pixels.shape
    if src_h / src_w < height / width:
        # Source is wider than target: trim columns symmetrically.
        crop_w = int(src_h / height * width)
        x0 = (src_w - crop_w) // 2
        pixels = pixels[:, x0: x0 + crop_w]
    else:
        # Source is taller than target: trim rows symmetrically.
        crop_h = int(src_w / width * height)
        y0 = (src_h - crop_h) // 2
        pixels = pixels[y0: y0 + crop_h, :]
    return Image.fromarray(pixels).resize((width, height))
|
||||
|
||||
|
||||
class VideoData:
    """Uniform lazy frame source backed by either a video file or an image folder.

    Frames are loaded on access and optionally center-cropped/resized to a
    fixed (height, width).
    """

    def __init__(self, video_file=None, image_folder=None, height=None, width=None, **kwargs):
        # Exactly one source should be given; a video file takes precedence.
        if video_file is not None:
            self.data_type = "video"
            self.data = LowMemoryVideo(video_file, **kwargs)
        elif image_folder is not None:
            self.data_type = "images"
            self.data = LowMemoryImageFolder(image_folder, **kwargs)
        else:
            raise ValueError("Cannot open video or image folder")
        # Optional override of the underlying frame count (None = use the source's).
        self.length = None
        self.set_shape(height, width)

    def raw_data(self):
        """Materialize every frame into a list (loses the low-memory property)."""
        frames = []
        for i in range(self.__len__()):
            frames.append(self.__getitem__(i))
        return frames

    def set_length(self, length):
        # Cap/override the reported number of frames.
        self.length = length

    def set_shape(self, height, width):
        # Target output size; None disables crop/resize in __getitem__.
        self.height = height
        self.width = width

    def __len__(self):
        if self.length is None:
            return len(self.data)
        else:
            return self.length

    def shape(self):
        """Return (height, width), probing the first frame if no size was set."""
        if self.height is not None and self.width is not None:
            return self.height, self.width
        else:
            # NOTE(review): __getitem__ returns a PIL Image, which has no .shape
            # attribute — this fallback likely raises; verify against callers.
            height, width, _ = self.__getitem__(0).shape
            return height, width

    def __getitem__(self, item):
        frame = self.data.__getitem__(item)
        width, height = frame.size
        # Only reshape when a target size was set and the frame differs from it.
        if self.height is not None and self.width is not None:
            if self.height != height or self.width != width:
                frame = crop_and_resize(frame, self.height, self.width)
        return frame

    def __del__(self):
        pass

    def save_images(self, folder):
        """Dump every frame to folder/<index>.png."""
        os.makedirs(folder, exist_ok=True)
        for i in tqdm(range(self.__len__()), desc="Saving images"):
            frame = self.__getitem__(i)
            frame.save(os.path.join(folder, f"{i}.png"))
|
||||
|
||||
|
||||
def save_video(frames, save_path, fps, quality=9, ffmpeg_params=None):
    """Encode an iterable of frames (PIL images or arrays) into a video file."""
    writer = imageio.get_writer(save_path, fps=fps, quality=quality, ffmpeg_params=ffmpeg_params)
    for frame in tqdm(frames, desc="Saving video"):
        # imageio expects ndarray data, so coerce each frame.
        writer.append_data(np.array(frame))
    writer.close()
|
||||
|
||||
def save_frames(frames, save_path):
    """Write each frame to save_path/<index>.png, creating the folder if needed."""
    os.makedirs(save_path, exist_ok=True)
    for index, frame in enumerate(tqdm(frames, desc="Saving images")):
        frame.save(os.path.join(save_path, f"{index}.png"))
|
||||
|
||||
|
||||
def merge_video_audio(video_path: str, audio_path: str) -> None:
    # TODO: may need a in-python implementation to avoid subprocess dependency
    """
    Merge the video and audio into a new video, with the duration set to the shorter of the two,
    and overwrite the original video file.

    Requires the `ffmpeg` binary on PATH. On any failure the temp file is
    removed and the error is printed (NOT re-raised) — callers cannot detect
    failure; the original video is left untouched in that case.

    Parameters:
        video_path (str): Path to the original video file
        audio_path (str): Path to the audio file
    """

    # check
    if not os.path.exists(video_path):
        raise FileNotFoundError(f"video file {video_path} does not exist")
    if not os.path.exists(audio_path):
        raise FileNotFoundError(f"audio file {audio_path} does not exist")

    # Write to a sibling temp file first, then atomically replace the original.
    base, ext = os.path.splitext(video_path)
    temp_output = f"{base}_temp{ext}"

    try:
        # create ffmpeg command
        command = [
            'ffmpeg',
            '-y', # overwrite
            '-i',
            video_path,
            '-i',
            audio_path,
            '-c:v',
            'copy', # copy video stream
            '-c:a',
            'aac', # use AAC audio encoder
            '-b:a',
            '192k', # set audio bitrate (optional)
            '-map',
            '0:v:0', # select the first video stream
            '-map',
            '1:a:0', # select the first audio stream
            '-shortest', # choose the shortest duration
            temp_output
        ]

        # execute the command
        result = subprocess.run(
            command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)

        # check result
        if result.returncode != 0:
            error_msg = f"FFmpeg execute failed: {result.stderr}"
            print(error_msg)
            raise RuntimeError(error_msg)

        shutil.move(temp_output, video_path)
        print(f"Merge completed, saved to {video_path}")

    except Exception as e:
        # NOTE(review): deliberate best-effort — failures are printed, not raised.
        if os.path.exists(temp_output):
            os.remove(temp_output)
        print(f"merge_video_audio failed with error: {e}")
|
||||
|
||||
|
||||
def save_video_with_audio(frames, save_path, audio_path, fps=16, quality=9, ffmpeg_params=None):
    """Encode `frames` to `save_path`, then mux `audio_path` into it in place."""
    save_video(frames, save_path, fps=fps, quality=quality, ffmpeg_params=ffmpeg_params)
    merge_video_audio(save_path, audio_path)
|
||||
@@ -0,0 +1,6 @@
|
||||
def WanAnimateAdapterStateDictConverter(state_dict):
    """Keep only the animate-adapter weights from a full Wan checkpoint.

    Selects keys belonging to the pose patch embedding and the face/motion
    adapter modules; everything else (DiT backbone, etc.) is dropped.
    """
    # str.startswith accepts a tuple of prefixes — one call instead of an or-chain.
    adapter_prefixes = ("pose_patch_embedding.", "face_adapter", "face_encoder", "motion_encoder")
    return {
        name: tensor
        for name, tensor in state_dict.items()
        if name.startswith(adapter_prefixes)
    }
|
||||
83
diffsynth/utils/state_dict_converters/wan_video_dit.py
Normal file
83
diffsynth/utils/state_dict_converters/wan_video_dit.py
Normal file
@@ -0,0 +1,83 @@
|
||||
def WanVideoDiTFromDiffusers(state_dict):
    """Rename diffusers-layout Wan DiT weights to DiffSynth's native layout.

    Block-level keys are matched against block-0 templates (the index is
    substituted with "0", looked up, then restored). Keys matching neither a
    template nor a direct entry are silently dropped.
    """
    rename_dict = {
        "blocks.0.attn1.norm_k.weight": "blocks.0.self_attn.norm_k.weight",
        "blocks.0.attn1.norm_q.weight": "blocks.0.self_attn.norm_q.weight",
        "blocks.0.attn1.to_k.bias": "blocks.0.self_attn.k.bias",
        "blocks.0.attn1.to_k.weight": "blocks.0.self_attn.k.weight",
        "blocks.0.attn1.to_out.0.bias": "blocks.0.self_attn.o.bias",
        "blocks.0.attn1.to_out.0.weight": "blocks.0.self_attn.o.weight",
        "blocks.0.attn1.to_q.bias": "blocks.0.self_attn.q.bias",
        "blocks.0.attn1.to_q.weight": "blocks.0.self_attn.q.weight",
        "blocks.0.attn1.to_v.bias": "blocks.0.self_attn.v.bias",
        "blocks.0.attn1.to_v.weight": "blocks.0.self_attn.v.weight",
        "blocks.0.attn2.norm_k.weight": "blocks.0.cross_attn.norm_k.weight",
        "blocks.0.attn2.norm_q.weight": "blocks.0.cross_attn.norm_q.weight",
        "blocks.0.attn2.to_k.bias": "blocks.0.cross_attn.k.bias",
        "blocks.0.attn2.to_k.weight": "blocks.0.cross_attn.k.weight",
        "blocks.0.attn2.to_out.0.bias": "blocks.0.cross_attn.o.bias",
        "blocks.0.attn2.to_out.0.weight": "blocks.0.cross_attn.o.weight",
        "blocks.0.attn2.to_q.bias": "blocks.0.cross_attn.q.bias",
        "blocks.0.attn2.to_q.weight": "blocks.0.cross_attn.q.weight",
        "blocks.0.attn2.to_v.bias": "blocks.0.cross_attn.v.bias",
        "blocks.0.attn2.to_v.weight": "blocks.0.cross_attn.v.weight",
        "blocks.0.attn2.add_k_proj.bias": "blocks.0.cross_attn.k_img.bias",
        "blocks.0.attn2.add_k_proj.weight": "blocks.0.cross_attn.k_img.weight",
        "blocks.0.attn2.add_v_proj.bias": "blocks.0.cross_attn.v_img.bias",
        "blocks.0.attn2.add_v_proj.weight": "blocks.0.cross_attn.v_img.weight",
        "blocks.0.attn2.norm_added_k.weight": "blocks.0.cross_attn.norm_k_img.weight",
        "blocks.0.ffn.net.0.proj.bias": "blocks.0.ffn.0.bias",
        "blocks.0.ffn.net.0.proj.weight": "blocks.0.ffn.0.weight",
        "blocks.0.ffn.net.2.bias": "blocks.0.ffn.2.bias",
        "blocks.0.ffn.net.2.weight": "blocks.0.ffn.2.weight",
        "blocks.0.norm2.bias": "blocks.0.norm3.bias",
        "blocks.0.norm2.weight": "blocks.0.norm3.weight",
        "blocks.0.scale_shift_table": "blocks.0.modulation",
        "condition_embedder.text_embedder.linear_1.bias": "text_embedding.0.bias",
        "condition_embedder.text_embedder.linear_1.weight": "text_embedding.0.weight",
        "condition_embedder.text_embedder.linear_2.bias": "text_embedding.2.bias",
        "condition_embedder.text_embedder.linear_2.weight": "text_embedding.2.weight",
        "condition_embedder.time_embedder.linear_1.bias": "time_embedding.0.bias",
        "condition_embedder.time_embedder.linear_1.weight": "time_embedding.0.weight",
        "condition_embedder.time_embedder.linear_2.bias": "time_embedding.2.bias",
        "condition_embedder.time_embedder.linear_2.weight": "time_embedding.2.weight",
        "condition_embedder.time_proj.bias": "time_projection.1.bias",
        "condition_embedder.time_proj.weight": "time_projection.1.weight",
        "condition_embedder.image_embedder.ff.net.0.proj.bias": "img_emb.proj.1.bias",
        "condition_embedder.image_embedder.ff.net.0.proj.weight": "img_emb.proj.1.weight",
        "condition_embedder.image_embedder.ff.net.2.bias": "img_emb.proj.3.bias",
        "condition_embedder.image_embedder.ff.net.2.weight": "img_emb.proj.3.weight",
        "condition_embedder.image_embedder.norm1.bias": "img_emb.proj.0.bias",
        "condition_embedder.image_embedder.norm1.weight": "img_emb.proj.0.weight",
        "condition_embedder.image_embedder.norm2.bias": "img_emb.proj.4.bias",
        "condition_embedder.image_embedder.norm2.weight": "img_emb.proj.4.weight",
        "patch_embedding.bias": "patch_embedding.bias",
        "patch_embedding.weight": "patch_embedding.weight",
        "scale_shift_table": "head.modulation",
        "proj_out.bias": "head.head.bias",
        "proj_out.weight": "head.head.weight",
    }
    remapped = {}
    for source_key, tensor in state_dict.items():
        if source_key in rename_dict:
            remapped[rename_dict[source_key]] = tensor
            continue
        # Per-block key: substitute the block index with "0", look up the
        # template, then restore the real index in the renamed key.
        parts = source_key.split(".")
        template = ".".join(parts[:1] + ["0"] + parts[2:])
        if template in rename_dict:
            renamed_parts = rename_dict[template].split(".")
            target_key = ".".join(renamed_parts[:1] + [parts[1]] + renamed_parts[2:])
            remapped[target_key] = tensor
    return remapped
|
||||
|
||||
|
||||
def WanVideoDiTStateDictConverter(state_dict):
    """Strip VACE and animate-adapter weights; remove an optional 'model.' prefix."""
    adapter_modules = {"pose_patch_embedding", "face_adapter", "face_encoder", "motion_encoder"}
    converted = {}
    for key, tensor in state_dict.items():
        # Skip VACE weights and the adapter sub-modules handled by other converters.
        if key.startswith("vace") or key.split(".")[0] in adapter_modules:
            continue
        target = key[len("model."):] if key.startswith("model.") else key
        converted[target] = tensor
    return converted
|
||||
@@ -0,0 +1,8 @@
|
||||
def WanImageEncoderStateDictConverter(state_dict):
    """Drop the text tower ('textual.*') and nest remaining keys under 'model.'."""
    return {
        "model." + key: value
        for key, value in state_dict.items()
        if not key.startswith("textual.")
    }
|
||||
77
diffsynth/utils/state_dict_converters/wan_video_mot.py
Normal file
77
diffsynth/utils/state_dict_converters/wan_video_mot.py
Normal file
@@ -0,0 +1,77 @@
|
||||
def WanVideoMotStateDictConverter(state_dict):
    """Extract the "_mot_ref" (motion-reference) weights from a diffusers-style
    checkpoint and rename them to DiffSynth's native layout.

    The mot-ref variants exist only at block indices (0, 4, 8, ..., 36); those
    indices are compacted into consecutive target indices 0..9. Keys without
    "_mot_ref" are ignored; mot-ref keys matching no template are dropped.
    """
    rename_dict = {
        "blocks.0.attn1.norm_k.weight": "blocks.0.self_attn.norm_k.weight",
        "blocks.0.attn1.norm_q.weight": "blocks.0.self_attn.norm_q.weight",
        "blocks.0.attn1.to_k.bias": "blocks.0.self_attn.k.bias",
        "blocks.0.attn1.to_k.weight": "blocks.0.self_attn.k.weight",
        "blocks.0.attn1.to_out.0.bias": "blocks.0.self_attn.o.bias",
        "blocks.0.attn1.to_out.0.weight": "blocks.0.self_attn.o.weight",
        "blocks.0.attn1.to_q.bias": "blocks.0.self_attn.q.bias",
        "blocks.0.attn1.to_q.weight": "blocks.0.self_attn.q.weight",
        "blocks.0.attn1.to_v.bias": "blocks.0.self_attn.v.bias",
        "blocks.0.attn1.to_v.weight": "blocks.0.self_attn.v.weight",
        "blocks.0.attn2.norm_k.weight": "blocks.0.cross_attn.norm_k.weight",
        "blocks.0.attn2.norm_q.weight": "blocks.0.cross_attn.norm_q.weight",
        "blocks.0.attn2.to_k.bias": "blocks.0.cross_attn.k.bias",
        "blocks.0.attn2.to_k.weight": "blocks.0.cross_attn.k.weight",
        "blocks.0.attn2.to_out.0.bias": "blocks.0.cross_attn.o.bias",
        "blocks.0.attn2.to_out.0.weight": "blocks.0.cross_attn.o.weight",
        "blocks.0.attn2.to_q.bias": "blocks.0.cross_attn.q.bias",
        "blocks.0.attn2.to_q.weight": "blocks.0.cross_attn.q.weight",
        "blocks.0.attn2.to_v.bias": "blocks.0.cross_attn.v.bias",
        "blocks.0.attn2.to_v.weight": "blocks.0.cross_attn.v.weight",
        "blocks.0.attn2.add_k_proj.bias": "blocks.0.cross_attn.k_img.bias",
        "blocks.0.attn2.add_k_proj.weight": "blocks.0.cross_attn.k_img.weight",
        "blocks.0.attn2.add_v_proj.bias": "blocks.0.cross_attn.v_img.bias",
        "blocks.0.attn2.add_v_proj.weight": "blocks.0.cross_attn.v_img.weight",
        "blocks.0.attn2.norm_added_k.weight": "blocks.0.cross_attn.norm_k_img.weight",
        "blocks.0.ffn.net.0.proj.bias": "blocks.0.ffn.0.bias",
        "blocks.0.ffn.net.0.proj.weight": "blocks.0.ffn.0.weight",
        "blocks.0.ffn.net.2.bias": "blocks.0.ffn.2.bias",
        "blocks.0.ffn.net.2.weight": "blocks.0.ffn.2.weight",
        "blocks.0.norm2.bias": "blocks.0.norm3.bias",
        "blocks.0.norm2.weight": "blocks.0.norm3.weight",
        "blocks.0.scale_shift_table": "blocks.0.modulation",
        "condition_embedder.text_embedder.linear_1.bias": "text_embedding.0.bias",
        "condition_embedder.text_embedder.linear_1.weight": "text_embedding.0.weight",
        "condition_embedder.text_embedder.linear_2.bias": "text_embedding.2.bias",
        "condition_embedder.text_embedder.linear_2.weight": "text_embedding.2.weight",
        "condition_embedder.time_embedder.linear_1.bias": "time_embedding.0.bias",
        "condition_embedder.time_embedder.linear_1.weight": "time_embedding.0.weight",
        "condition_embedder.time_embedder.linear_2.bias": "time_embedding.2.bias",
        "condition_embedder.time_embedder.linear_2.weight": "time_embedding.2.weight",
        "condition_embedder.time_proj.bias": "time_projection.1.bias",
        "condition_embedder.time_proj.weight": "time_projection.1.weight",
        "condition_embedder.image_embedder.ff.net.0.proj.bias": "img_emb.proj.1.bias",
        "condition_embedder.image_embedder.ff.net.0.proj.weight": "img_emb.proj.1.weight",
        "condition_embedder.image_embedder.ff.net.2.bias": "img_emb.proj.3.bias",
        "condition_embedder.image_embedder.ff.net.2.weight": "img_emb.proj.3.weight",
        "condition_embedder.image_embedder.norm1.bias": "img_emb.proj.0.bias",
        "condition_embedder.image_embedder.norm1.weight": "img_emb.proj.0.weight",
        "condition_embedder.image_embedder.norm2.bias": "img_emb.proj.4.bias",
        "condition_embedder.image_embedder.norm2.weight": "img_emb.proj.4.weight",
        "patch_embedding.bias": "patch_embedding.bias",
        "patch_embedding.weight": "patch_embedding.weight",
        "scale_shift_table": "head.modulation",
        "proj_out.bias": "head.head.bias",
        "proj_out.weight": "head.head.weight",
    }
    mot_layers = (0, 4, 8, 12, 16, 20, 24, 28, 32, 36)
    mot_layers_mapping = {i: n for n, i in enumerate(mot_layers)}
    state_dict_ = {}
    for source_key in state_dict:
        if "_mot_ref" not in source_key:
            continue
        # Fix: fetch the tensor under its ORIGINAL key. The previous code indexed
        # state_dict with the stripped/remapped name, which either raised KeyError
        # or silently picked up a different layer's base weight.
        tensor = state_dict[source_key]
        name = source_key.replace("_mot_ref", "")
        if name in rename_dict:
            state_dict_[rename_dict[name]] = tensor
            continue
        parts = name.split(".")
        if len(parts) > 1 and parts[1].isdigit():
            # Fix: splice the block-index segment directly instead of str.replace,
            # which could also rewrite matching digits elsewhere in the key.
            parts[1] = str(mot_layers_mapping[int(parts[1])])
            name = ".".join(parts)
        # Match against the block-0 template, then restore the remapped index.
        template = ".".join(name.split(".")[:1] + ["0"] + name.split(".")[2:])
        if template in rename_dict:
            renamed_parts = rename_dict[template].split(".")
            target_key = ".".join(renamed_parts[:1] + [name.split(".")[1]] + renamed_parts[2:])
            state_dict_[target_key] = tensor
    return state_dict_
|
||||
3
diffsynth/utils/state_dict_converters/wan_video_vace.py
Normal file
3
diffsynth/utils/state_dict_converters/wan_video_vace.py
Normal file
@@ -0,0 +1,3 @@
|
||||
def VaceWanModelDictConverter(state_dict):
    """Keep only VACE-specific weights (keys prefixed with 'vace')."""
    vace_only = {}
    for key, value in state_dict.items():
        if key.startswith("vace"):
            vace_only[key] = value
    return vace_only
|
||||
7
diffsynth/utils/state_dict_converters/wan_video_vae.py
Normal file
7
diffsynth/utils/state_dict_converters/wan_video_vae.py
Normal file
@@ -0,0 +1,7 @@
|
||||
def WanVideoVAEStateDictConverter(state_dict):
    """Unwrap an optional 'model_state' container and prefix every key with 'model.'."""
    weights = state_dict.get('model_state', state_dict)
    return {'model.' + key: value for key, value in weights.items()}
|
||||
@@ -0,0 +1,3 @@
|
||||
def WanS2VAudioEncoderStateDictConverter(state_dict):
    """Prefix every wav2vec2 key with 'model.' to match the encoder wrapper layout."""
    prefixed = {}
    for key in state_dict:
        prefixed["model." + key] = state_dict[key]
    return prefixed
|
||||
1
diffsynth/utils/xfuser/__init__.py
Normal file
1
diffsynth/utils/xfuser/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
from .xdit_context_parallel import usp_attn_forward, usp_dit_forward, get_sequence_parallel_world_size, initialize_usp
|
||||
145
diffsynth/utils/xfuser/xdit_context_parallel.py
Normal file
145
diffsynth/utils/xfuser/xdit_context_parallel.py
Normal file
@@ -0,0 +1,145 @@
|
||||
import torch
|
||||
from typing import Optional
|
||||
from einops import rearrange
|
||||
from xfuser.core.distributed import (get_sequence_parallel_rank,
|
||||
get_sequence_parallel_world_size,
|
||||
get_sp_group)
|
||||
from xfuser.core.long_ctx_attention import xFuserLongContextAttention
|
||||
|
||||
|
||||
def initialize_usp():
    """Initialize distributed state for unified sequence parallelism (USP).

    Side effects: creates the NCCL process group (env:// rendezvous), sets up
    xfuser's model-parallel state with full-world Ulysses parallelism (ring
    degree 1), and binds each rank to the CUDA device matching its rank.
    Must be called once per process before any usp_* forward functions.
    """
    import torch.distributed as dist
    from xfuser.core.distributed import initialize_model_parallel, init_distributed_environment
    dist.init_process_group(backend="nccl", init_method="env://")
    init_distributed_environment(rank=dist.get_rank(), world_size=dist.get_world_size())
    initialize_model_parallel(
        sequence_parallel_degree=dist.get_world_size(),
        ring_degree=1,
        ulysses_degree=dist.get_world_size(),
    )
    # One GPU per rank, indexed by global rank.
    torch.cuda.set_device(dist.get_rank())
|
||||
|
||||
|
||||
def sinusoidal_embedding_1d(dim, position):
    """Transformer-style sinusoidal embedding: [cos | sin] halves of size dim // 2."""
    half = dim // 2
    # Geometric frequency ladder 10000^(-k/half), computed in float64 for accuracy.
    inv_freq = torch.pow(10000, -torch.arange(half, dtype=torch.float64, device=position.device) / half)
    angles = torch.outer(position.type(torch.float64), inv_freq)
    embedding = torch.cat([angles.cos(), angles.sin()], dim=1)
    return embedding.to(position.dtype)
|
||||
|
||||
def pad_freqs(original_tensor, target_len):
    """Extend a (seq, s1, s2) freqs tensor along dim 0 to `target_len`.

    Pads with ones — the multiplicative identity, since rope_apply multiplies
    activations by these frequencies.
    """
    current_len, s1, s2 = original_tensor.shape
    filler = torch.ones(
        target_len - current_len,
        s1,
        s2,
        dtype=original_tensor.dtype,
        device=original_tensor.device,
    )
    return torch.cat([original_tensor, filler], dim=0)
|
||||
|
||||
def rope_apply(x, freqs, num_heads):
    """Apply rotary position embedding to this rank's shard of the sequence.

    x holds only this sequence-parallel rank's slice of the tokens; freqs holds
    the full sequence's frequencies, which are padded (with ones) to a multiple
    of the shard length and then sliced to this rank's window.
    """
    x = rearrange(x, "b s (n d) -> b s n d", n=num_heads)
    s_per_rank = x.shape[1]  # local (per-rank) sequence length

    # Pair up the last dim into (real, imag) and view as complex for rotation.
    x_out = torch.view_as_complex(x.to(torch.float64).reshape(
        x.shape[0], x.shape[1], x.shape[2], -1, 2))

    sp_size = get_sequence_parallel_world_size()
    sp_rank = get_sequence_parallel_rank()
    # Pad freqs so every rank gets an equal-length slice, then take ours.
    freqs = pad_freqs(freqs, s_per_rank * sp_size)
    freqs_rank = freqs[(sp_rank * s_per_rank):((sp_rank + 1) * s_per_rank), :, :]

    # Complex multiply rotates each (real, imag) pair, then flatten heads back.
    x_out = torch.view_as_real(x_out * freqs_rank).flatten(2)
    return x_out.to(x.dtype)
|
||||
|
||||
def usp_dit_forward(self,
    x: torch.Tensor,
    timestep: torch.Tensor,
    context: torch.Tensor,
    clip_feature: Optional[torch.Tensor] = None,
    y: Optional[torch.Tensor] = None,
    use_gradient_checkpointing: bool = False,
    use_gradient_checkpointing_offload: bool = False,
    **kwargs,
):
    """Sequence-parallel (USP) replacement for the Wan DiT forward pass.

    Splits the patchified token sequence across sequence-parallel ranks, runs
    the transformer blocks on each shard, then all-gathers and unpatchifies.
    Intended to be monkey-patched over the model's forward; `self` is the DiT.
    """
    # Timestep conditioning: sinusoidal embedding -> MLP -> 6 modulation vectors.
    t = self.time_embedding(
        sinusoidal_embedding_1d(self.freq_dim, timestep))
    t_mod = self.time_projection(t).unflatten(1, (6, self.dim))
    context = self.text_embedding(context)

    if self.has_image_input:
        x = torch.cat([x, y], dim=1)  # (b, c_x + c_y, f, h, w)
        clip_embdding = self.img_emb(clip_feature)
        # Image (CLIP) tokens are prepended to the text context.
        context = torch.cat([clip_embdding, context], dim=1)

    x, (f, h, w) = self.patchify(x)

    # 3D RoPE frequencies: per-axis tables broadcast over (f, h, w) and
    # concatenated along the feature dim, flattened to one row per token.
    freqs = torch.cat([
        self.freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
        self.freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
        self.freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
    ], dim=-1).reshape(f * h * w, 1, -1).to(x.device)

    def create_custom_forward(module):
        # Wrapper so torch.utils.checkpoint can re-invoke the block.
        def custom_forward(*inputs):
            return module(*inputs)
        return custom_forward

    # Context Parallel
    # Shard tokens across ranks; zero-pad the (shorter) last chunk so every
    # rank processes an equal-length slice. pad_shape = rows of padding added.
    chunks = torch.chunk(x, get_sequence_parallel_world_size(), dim=1)
    pad_shape = chunks[0].shape[1] - chunks[-1].shape[1]
    chunks = [torch.nn.functional.pad(chunk, (0, 0, 0, chunks[0].shape[1]-chunk.shape[1]), value=0) for chunk in chunks]
    x = chunks[get_sequence_parallel_rank()]

    for block in self.blocks:
        if self.training and use_gradient_checkpointing:
            if use_gradient_checkpointing_offload:
                # Trade GPU memory for PCIe traffic: stash activations on CPU.
                with torch.autograd.graph.save_on_cpu():
                    x = torch.utils.checkpoint.checkpoint(
                        create_custom_forward(block),
                        x, context, t_mod, freqs,
                        use_reentrant=False,
                    )
            else:
                x = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(block),
                    x, context, t_mod, freqs,
                    use_reentrant=False,
                )
        else:
            x = block(x, context, t_mod, freqs)

    x = self.head(x, t)

    # Context Parallel
    # Reassemble the full sequence and strip the padding added before sharding.
    x = get_sp_group().all_gather(x, dim=1)
    x = x[:, :-pad_shape] if pad_shape > 0 else x

    # unpatchify
    x = self.unpatchify(x, (f, h, w))
    return x
|
||||
|
||||
|
||||
def usp_attn_forward(self, x, freqs):
    """Sequence-parallel self-attention using xFuser's long-context attention.

    `self` is a Wan attention module (provides q/k/v/o projections and QK
    norms); x is this rank's token shard, freqs the full RoPE table.
    """
    q = self.norm_q(self.q(x))
    k = self.norm_k(self.k(x))
    v = self.v(x)

    # RoPE is applied on this rank's slice of the sequence (see rope_apply).
    q = rope_apply(q, freqs, self.num_heads)
    k = rope_apply(k, freqs, self.num_heads)
    q = rearrange(q, "b s (n d) -> b s n d", n=self.num_heads)
    k = rearrange(k, "b s (n d) -> b s n d", n=self.num_heads)
    v = rearrange(v, "b s (n d) -> b s n d", n=self.num_heads)

    # xFuser handles the cross-rank (Ulysses) exchange internally; the first
    # positional argument (attn) is unused here.
    x = xFuserLongContextAttention()(
        None,
        query=q,
        key=k,
        value=v,
    )
    x = x.flatten(2)

    del q, k, v
    torch.cuda.empty_cache()
    return self.o(x)
|
||||
@@ -182,19 +182,18 @@ https://github.com/user-attachments/assets/1d66ae74-3b02-40a9-acc3-ea95fc039314
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffsynth import save_video
|
||||
from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig
|
||||
from diffsynth.utils.data import save_video
|
||||
from diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig
|
||||
|
||||
pipe = WanVideoPipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"),
|
||||
ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"),
|
||||
ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"),
|
||||
ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="diffusion_pytorch_model*.safetensors"),
|
||||
ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth"),
|
||||
ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="Wan2.1_VAE.pth"),
|
||||
],
|
||||
)
|
||||
pipe.enable_vram_management()
|
||||
|
||||
video = pipe(
|
||||
prompt="纪实摄影风格画面,一只活泼的小狗在绿茵茵的草地上迅速奔跑。小狗毛色棕黄,两只耳朵立起,神情专注而欢快。阳光洒在它身上,使得毛发看上去格外柔软而闪亮。背景是一片开阔的草地,偶尔点缀着几朵野花,远处隐约可见蓝天和几片白云。透视感鲜明,捕捉小狗奔跑时的动感和四周草地的生机。中景侧面移动视角。",
|
||||
|
||||
@@ -20,14 +20,17 @@ def run_inference(script_path):
|
||||
def run_tasks_on_single_GPU(script_path, gpu_id, num_gpu):
|
||||
output_path = os.path.join("data", script_path)
|
||||
for script_id, script in enumerate(sorted(os.listdir(script_path))):
|
||||
if not script.endswith(".sh"):
|
||||
if not script.endswith(".sh") and not script.endswith(".py"):
|
||||
continue
|
||||
if script_id % num_gpu != gpu_id:
|
||||
continue
|
||||
source_path = os.path.join(script_path, script)
|
||||
target_path = os.path.join(output_path, script)
|
||||
os.makedirs(target_path, exist_ok=True)
|
||||
cmd = f"CUDA_VISIBLE_DEVICES={gpu_id} bash {source_path} > {target_path}/log.txt 2>&1"
|
||||
if script.endswith(".sh"):
|
||||
cmd = f"CUDA_VISIBLE_DEVICES={gpu_id} bash {source_path} > {target_path}/log.txt 2>&1"
|
||||
else:
|
||||
cmd = f"CUDA_VISIBLE_DEVICES={gpu_id} python {source_path} > {target_path}/log.txt 2>&1"
|
||||
print(cmd)
|
||||
os.system(cmd)
|
||||
|
||||
@@ -60,5 +63,6 @@ if __name__ == "__main__":
|
||||
# run_train_single_GPU("examples/qwen_image/model_training/lora")
|
||||
# run_inference("examples/qwen_image/model_inference")
|
||||
# run_inference("examples/qwen_image/model_inference_low_vram")
|
||||
run_inference("examples/qwen_image/model_training/validate_full")
|
||||
run_inference("examples/qwen_image/model_training/validate_lora")
|
||||
# run_inference("examples/qwen_image/model_training/validate_full")
|
||||
# run_inference("examples/qwen_image/model_training/validate_lora")
|
||||
run_train_single_GPU("examples/wanvideo/model_inference")
|
||||
34
examples/wanvideo/model_inference/LongCat-Video.py
Normal file
34
examples/wanvideo/model_inference/LongCat-Video.py
Normal file
@@ -0,0 +1,34 @@
|
||||
import torch
|
||||
from diffsynth.utils.data import save_video, VideoData
|
||||
from diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig
|
||||
|
||||
|
||||
pipe = WanVideoPipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="meituan-longcat/LongCat-Video", origin_file_pattern="dit/diffusion_pytorch_model*.safetensors"),
|
||||
ModelConfig(model_id="Wan-AI/Wan2.1-T2V-14B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth"),
|
||||
ModelConfig(model_id="Wan-AI/Wan2.1-T2V-14B", origin_file_pattern="Wan2.1_VAE.pth"),
|
||||
],
|
||||
)
|
||||
|
||||
# Text-to-video
|
||||
video = pipe(
|
||||
prompt="In a realistic photography style, a white boy around seven or eight years old sits on a park bench, wearing a light blue T-shirt, denim shorts, and white sneakers. He holds an ice cream cone with vanilla and chocolate flavors, and beside him is a medium-sized golden Labrador. Smiling, the boy offers the ice cream to the dog, who eagerly licks it with its tongue. The sun is shining brightly, and the background features a green lawn and several tall trees, creating a warm and loving scene.",
|
||||
negative_prompt="Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards",
|
||||
seed=0, tiled=True, num_frames=93,
|
||||
cfg_scale=2, sigma_shift=1,
|
||||
)
|
||||
save_video(video, "video_1_LongCat-Video.mp4", fps=15, quality=5)
|
||||
|
||||
# Video-continuation (The number of frames in `longcat_video` should be 4n+1.)
|
||||
longcat_video = video[-17:]
|
||||
video = pipe(
|
||||
prompt="In a realistic photography style, a white boy around seven or eight years old sits on a park bench, wearing a light blue T-shirt, denim shorts, and white sneakers. He holds an ice cream cone with vanilla and chocolate flavors, and beside him is a medium-sized golden Labrador. Smiling, the boy offers the ice cream to the dog, who eagerly licks it with its tongue. The sun is shining brightly, and the background features a green lawn and several tall trees, creating a warm and loving scene.",
|
||||
negative_prompt="Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards",
|
||||
seed=1, tiled=True, num_frames=93,
|
||||
cfg_scale=2, sigma_shift=1,
|
||||
longcat_video=longcat_video,
|
||||
)
|
||||
save_video(video, "video_2_LongCat-Video.mp4", fps=15, quality=5)
|
||||
@@ -0,0 +1,48 @@
|
||||
import torch
|
||||
import PIL
|
||||
from PIL import Image
|
||||
from diffsynth.utils.data import save_video, VideoData
|
||||
from diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig
|
||||
from modelscope import dataset_snapshot_download
|
||||
from typing import List
|
||||
|
||||
|
||||
pipe = WanVideoPipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="ByteDance/Video-As-Prompt-Wan2.1-14B", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors"),
|
||||
ModelConfig(model_id="Wan-AI/Wan2.1-I2V-14B-720P", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth"),
|
||||
ModelConfig(model_id="Wan-AI/Wan2.1-I2V-14B-720P", origin_file_pattern="Wan2.1_VAE.pth"),
|
||||
ModelConfig(model_id="Wan-AI/Wan2.1-I2V-14B-720P", origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth"),
|
||||
],
|
||||
)
|
||||
|
||||
dataset_snapshot_download("DiffSynth-Studio/example_video_dataset", allow_file_pattern="wanvap/*", local_dir="data/example_video_dataset")
|
||||
ref_video_path = 'data/example_video_dataset/wanvap/vap_ref.mp4'
|
||||
target_image_path = 'data/example_video_dataset/wanvap/input_image.jpg'
|
||||
|
||||
def select_frames(video_frames, num):
|
||||
idx = torch.linspace(0, len(video_frames) - 1, num).long().tolist()
|
||||
return [video_frames[i] for i in idx]
|
||||
|
||||
image = Image.open(target_image_path).convert("RGB")
|
||||
ref_video = VideoData(ref_video_path, height=480, width=832)
|
||||
ref_frames = select_frames(ref_video, num=49)
|
||||
|
||||
vap_prompt = "A man stands with his back to the camera on a dirt path overlooking sun-drenched, rolling green tea plantations. He wears a blue and green plaid shirt, dark pants, and white shoes. As he turns to face the camera and spreads his arms, a brief, magical burst of sparkling golden light particles envelops him. Through this shimmer, he seamlessly transforms into a Labubu toy character. His head morphs into the iconic large, furry-eared head of the toy, featuring a wide grin with pointed teeth and red cheek markings. The character retains the man's original plaid shirt and clothing, which now fit its stylized, cartoonish body. The camera remains static throughout the transformation, positioned low among the tea bushes, maintaining a consistent view of the subject and the expansive scenery."
|
||||
prompt = "A young woman with curly hair, wearing a green hijab and a floral dress, plays a violin in front of a vintage green car on a tree-lined street. She executes a swift counter-clockwise turn to face the camera. During the turn, a brilliant shower of golden, sparkling particles erupts and momentarily obscures her figure. As the particles fade, she is revealed to have seamlessly transformed into a Labubu toy character. This new figure, now with the toy's signature large ears, big eyes, and toothy grin, maintains the original pose and continues playing the violin. The character's clothing—the green hijab, floral dress, and black overcoat—remains identical to the woman's. Throughout this transition, the camera stays static, and the street-side environment remains completely consistent."
|
||||
negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"
|
||||
|
||||
video = pipe(
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
input_image=image,
|
||||
seed=42, tiled=True,
|
||||
height=480, width=832,
|
||||
num_frames=49,
|
||||
vap_video=ref_frames,
|
||||
vap_prompt=vap_prompt,
|
||||
negative_vap_prompt=negative_prompt,
|
||||
)
|
||||
save_video(video, "video_Video-As-Prompt-Wan2.1-14B.mp4", fps=15, quality=5)
|
||||
@@ -0,0 +1,33 @@
|
||||
import torch
|
||||
from PIL import Image
|
||||
from diffsynth.utils.data import save_video, VideoData
|
||||
from diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig
|
||||
|
||||
|
||||
pipe = WanVideoPipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="diffusion_pytorch_model*.safetensors"),
|
||||
ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth"),
|
||||
ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="Wan2.1_VAE.pth"),
|
||||
ModelConfig(model_id="DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1", origin_file_pattern="model.safetensors"),
|
||||
],
|
||||
)
|
||||
|
||||
# Text-to-video
|
||||
video = pipe(
|
||||
prompt="纪实摄影风格画面,一只活泼的小狗在绿茵茵的草地上迅速奔跑。小狗毛色棕黄,两只耳朵立起,神情专注而欢快。阳光洒在它身上,使得毛发看上去格外柔软而闪亮。背景是一片开阔的草地,偶尔点缀着几朵野花,远处隐约可见蓝天和几片白云。透视感鲜明,捕捉小狗奔跑时的动感和四周草地的生机。中景侧面移动视角。",
|
||||
negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
|
||||
seed=1, tiled=True,
|
||||
motion_bucket_id=0
|
||||
)
|
||||
save_video(video, "video_slow_Wan2.1-1.3b-speedcontrol-v1.mp4", fps=15, quality=5)
|
||||
|
||||
video = pipe(
|
||||
prompt="纪实摄影风格画面,一只活泼的小狗在绿茵茵的草地上迅速奔跑。小狗毛色棕黄,两只耳朵立起,神情专注而欢快。阳光洒在它身上,使得毛发看上去格外柔软而闪亮。背景是一片开阔的草地,偶尔点缀着几朵野花,远处隐约可见蓝天和几片白云。透视感鲜明,捕捉小狗奔跑时的动感和四周草地的生机。中景侧面移动视角。",
|
||||
negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
|
||||
seed=1, tiled=True,
|
||||
motion_bucket_id=100
|
||||
)
|
||||
save_video(video, "video_fast_Wan2.1-1.3b-speedcontrol-v1.mp4", fps=15, quality=5)
|
||||
35
examples/wanvideo/model_inference/Wan2.1-FLF2V-14B-720P.py
Normal file
35
examples/wanvideo/model_inference/Wan2.1-FLF2V-14B-720P.py
Normal file
@@ -0,0 +1,35 @@
|
||||
import torch
|
||||
from PIL import Image
|
||||
from diffsynth.utils.data import save_video, VideoData
|
||||
from diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig
|
||||
from modelscope import dataset_snapshot_download
|
||||
|
||||
|
||||
pipe = WanVideoPipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="Wan-AI/Wan2.1-FLF2V-14B-720P", origin_file_pattern="diffusion_pytorch_model*.safetensors"),
|
||||
ModelConfig(model_id="Wan-AI/Wan2.1-FLF2V-14B-720P", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth"),
|
||||
ModelConfig(model_id="Wan-AI/Wan2.1-FLF2V-14B-720P", origin_file_pattern="Wan2.1_VAE.pth"),
|
||||
ModelConfig(model_id="Wan-AI/Wan2.1-FLF2V-14B-720P", origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth"),
|
||||
],
|
||||
)
|
||||
|
||||
dataset_snapshot_download(
|
||||
dataset_id="DiffSynth-Studio/examples_in_diffsynth",
|
||||
local_dir="./",
|
||||
allow_file_pattern=["data/examples/wan/first_frame.jpeg", "data/examples/wan/last_frame.jpeg"]
|
||||
)
|
||||
|
||||
# First and last frame to video
|
||||
video = pipe(
|
||||
prompt="写实风格,一个女生手持枯萎的花站在花园中,镜头逐渐拉远,记录下花园的全貌。",
|
||||
negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
|
||||
input_image=Image.open("data/examples/wan/first_frame.jpeg").resize((960, 960)),
|
||||
end_image=Image.open("data/examples/wan/last_frame.jpeg").resize((960, 960)),
|
||||
seed=0, tiled=True,
|
||||
height=960, width=960, num_frames=33,
|
||||
sigma_shift=16,
|
||||
)
|
||||
save_video(video, "video_Wan2.1-FLF2V-14B-720P.mp4", fps=15, quality=5)
|
||||
33
examples/wanvideo/model_inference/Wan2.1-Fun-1.3B-Control.py
Normal file
33
examples/wanvideo/model_inference/Wan2.1-Fun-1.3B-Control.py
Normal file
@@ -0,0 +1,33 @@
|
||||
import torch
|
||||
from PIL import Image
|
||||
from diffsynth.utils.data import save_video, VideoData
|
||||
from diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig
|
||||
from modelscope import dataset_snapshot_download
|
||||
|
||||
|
||||
pipe = WanVideoPipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="PAI/Wan2.1-Fun-1.3B-Control", origin_file_pattern="diffusion_pytorch_model*.safetensors"),
|
||||
ModelConfig(model_id="PAI/Wan2.1-Fun-1.3B-Control", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth"),
|
||||
ModelConfig(model_id="PAI/Wan2.1-Fun-1.3B-Control", origin_file_pattern="Wan2.1_VAE.pth"),
|
||||
ModelConfig(model_id="PAI/Wan2.1-Fun-1.3B-Control", origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth"),
|
||||
],
|
||||
)
|
||||
|
||||
dataset_snapshot_download(
|
||||
dataset_id="DiffSynth-Studio/examples_in_diffsynth",
|
||||
local_dir="./",
|
||||
allow_file_pattern=f"data/examples/wan/control_video.mp4"
|
||||
)
|
||||
|
||||
# Control video
|
||||
control_video = VideoData("data/examples/wan/control_video.mp4", height=832, width=576)
|
||||
video = pipe(
|
||||
prompt="扁平风格动漫,一位长发少女优雅起舞。她五官精致,大眼睛明亮有神,黑色长发柔顺光泽。身穿淡蓝色T恤和深蓝色牛仔短裤。背景是粉色。",
|
||||
negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
|
||||
control_video=control_video, height=832, width=576, num_frames=49,
|
||||
seed=1, tiled=True
|
||||
)
|
||||
save_video(video, "video_Wan2.1-Fun-1.3B-Control.mp4", fps=15, quality=5)
|
||||
35
examples/wanvideo/model_inference/Wan2.1-Fun-1.3B-InP.py
Normal file
35
examples/wanvideo/model_inference/Wan2.1-Fun-1.3B-InP.py
Normal file
@@ -0,0 +1,35 @@
|
||||
import torch
|
||||
from PIL import Image
|
||||
from diffsynth.utils.data import save_video, VideoData
|
||||
from diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig
|
||||
from modelscope import dataset_snapshot_download
|
||||
|
||||
|
||||
pipe = WanVideoPipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="PAI/Wan2.1-Fun-1.3B-InP", origin_file_pattern="diffusion_pytorch_model*.safetensors"),
|
||||
ModelConfig(model_id="PAI/Wan2.1-Fun-1.3B-InP", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth"),
|
||||
ModelConfig(model_id="PAI/Wan2.1-Fun-1.3B-InP", origin_file_pattern="Wan2.1_VAE.pth"),
|
||||
ModelConfig(model_id="PAI/Wan2.1-Fun-1.3B-InP", origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth"),
|
||||
],
|
||||
)
|
||||
|
||||
dataset_snapshot_download(
|
||||
dataset_id="DiffSynth-Studio/examples_in_diffsynth",
|
||||
local_dir="./",
|
||||
allow_file_pattern=f"data/examples/wan/input_image.jpg"
|
||||
)
|
||||
image = Image.open("data/examples/wan/input_image.jpg")
|
||||
|
||||
# First and last frame to video
|
||||
video = pipe(
|
||||
prompt="一艘小船正勇敢地乘风破浪前行。蔚蓝的大海波涛汹涌,白色的浪花拍打着船身,但小船毫不畏惧,坚定地驶向远方。阳光洒在水面上,闪烁着金色的光芒,为这壮丽的场景增添了一抹温暖。镜头拉近,可以看到船上的旗帜迎风飘扬,象征着不屈的精神与冒险的勇气。这段画面充满力量,激励人心,展现了面对挑战时的无畏与执着。",
|
||||
negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
|
||||
input_image=image,
|
||||
seed=0, tiled=True
|
||||
# You can input `end_image=xxx` to control the last frame of the video.
|
||||
# The model will automatically generate the dynamic content between `input_image` and `end_image`.
|
||||
)
|
||||
save_video(video, "video_Wan2.1-Fun-1.3B-InP.mp4", fps=15, quality=5)
|
||||
33
examples/wanvideo/model_inference/Wan2.1-Fun-14B-Control.py
Normal file
33
examples/wanvideo/model_inference/Wan2.1-Fun-14B-Control.py
Normal file
@@ -0,0 +1,33 @@
|
||||
import torch
|
||||
from PIL import Image
|
||||
from diffsynth.utils.data import save_video, VideoData
|
||||
from diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig
|
||||
from modelscope import dataset_snapshot_download
|
||||
|
||||
|
||||
pipe = WanVideoPipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="PAI/Wan2.1-Fun-14B-Control", origin_file_pattern="diffusion_pytorch_model*.safetensors"),
|
||||
ModelConfig(model_id="PAI/Wan2.1-Fun-14B-Control", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth"),
|
||||
ModelConfig(model_id="PAI/Wan2.1-Fun-14B-Control", origin_file_pattern="Wan2.1_VAE.pth"),
|
||||
ModelConfig(model_id="PAI/Wan2.1-Fun-14B-Control", origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth"),
|
||||
],
|
||||
)
|
||||
|
||||
dataset_snapshot_download(
|
||||
dataset_id="DiffSynth-Studio/examples_in_diffsynth",
|
||||
local_dir="./",
|
||||
allow_file_pattern=f"data/examples/wan/control_video.mp4"
|
||||
)
|
||||
|
||||
# Control video
|
||||
control_video = VideoData("data/examples/wan/control_video.mp4", height=832, width=576)
|
||||
video = pipe(
|
||||
prompt="扁平风格动漫,一位长发少女优雅起舞。她五官精致,大眼睛明亮有神,黑色长发柔顺光泽。身穿淡蓝色T恤和深蓝色牛仔短裤。背景是粉色。",
|
||||
negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
|
||||
control_video=control_video, height=832, width=576, num_frames=49,
|
||||
seed=1, tiled=True
|
||||
)
|
||||
save_video(video, "video_Wan2.1-Fun-14B-Control.mp4", fps=15, quality=5)
|
||||
35
examples/wanvideo/model_inference/Wan2.1-Fun-14B-InP.py
Normal file
35
examples/wanvideo/model_inference/Wan2.1-Fun-14B-InP.py
Normal file
@@ -0,0 +1,35 @@
|
||||
import torch
|
||||
from PIL import Image
|
||||
from diffsynth.utils.data import save_video, VideoData
|
||||
from diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig
|
||||
from modelscope import dataset_snapshot_download
|
||||
|
||||
|
||||
pipe = WanVideoPipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="PAI/Wan2.1-Fun-14B-InP", origin_file_pattern="diffusion_pytorch_model*.safetensors"),
|
||||
ModelConfig(model_id="PAI/Wan2.1-Fun-14B-InP", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth"),
|
||||
ModelConfig(model_id="PAI/Wan2.1-Fun-14B-InP", origin_file_pattern="Wan2.1_VAE.pth"),
|
||||
ModelConfig(model_id="PAI/Wan2.1-Fun-14B-InP", origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth"),
|
||||
],
|
||||
)
|
||||
|
||||
dataset_snapshot_download(
|
||||
dataset_id="DiffSynth-Studio/examples_in_diffsynth",
|
||||
local_dir="./",
|
||||
allow_file_pattern=f"data/examples/wan/input_image.jpg"
|
||||
)
|
||||
image = Image.open("data/examples/wan/input_image.jpg")
|
||||
|
||||
# First and last frame to video
|
||||
video = pipe(
|
||||
prompt="一艘小船正勇敢地乘风破浪前行。蔚蓝的大海波涛汹涌,白色的浪花拍打着船身,但小船毫不畏惧,坚定地驶向远方。阳光洒在水面上,闪烁着金色的光芒,为这壮丽的场景增添了一抹温暖。镜头拉近,可以看到船上的旗帜迎风飘扬,象征着不屈的精神与冒险的勇气。这段画面充满力量,激励人心,展现了面对挑战时的无畏与执着。",
|
||||
negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
|
||||
input_image=image,
|
||||
seed=0, tiled=True
|
||||
# You can input `end_image=xxx` to control the last frame of the video.
|
||||
# The model will automatically generate the dynamic content between `input_image` and `end_image`.
|
||||
)
|
||||
save_video(video, "video_Wan2.1-Fun-14B-InP.mp4", fps=15, quality=5)
|
||||
@@ -0,0 +1,43 @@
|
||||
import torch
|
||||
from PIL import Image
|
||||
from diffsynth.utils.data import save_video, VideoData
|
||||
from diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig
|
||||
from modelscope import dataset_snapshot_download
|
||||
|
||||
|
||||
pipe = WanVideoPipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera", origin_file_pattern="diffusion_pytorch_model*.safetensors"),
|
||||
ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth"),
|
||||
ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera", origin_file_pattern="Wan2.1_VAE.pth"),
|
||||
ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera", origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth"),
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
dataset_snapshot_download(
|
||||
dataset_id="DiffSynth-Studio/examples_in_diffsynth",
|
||||
local_dir="./",
|
||||
allow_file_pattern=f"data/examples/wan/input_image.jpg"
|
||||
)
|
||||
input_image = Image.open("data/examples/wan/input_image.jpg")
|
||||
|
||||
video = pipe(
|
||||
prompt="一艘小船正勇敢地乘风破浪前行。蔚蓝的大海波涛汹涌,白色的浪花拍打着船身,但小船毫不畏惧,坚定地驶向远方。阳光洒在水面上,闪烁着金色的光芒,为这壮丽的场景增添了一抹温暖。镜头拉近,可以看到船上的旗帜迎风飘扬,象征着不屈的精神与冒险的勇气。这段画面充满力量,激励人心,展现了面对挑战时的无畏与执着。",
|
||||
negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
|
||||
seed=0, tiled=True,
|
||||
input_image=input_image,
|
||||
camera_control_direction="Left", camera_control_speed=0.01,
|
||||
)
|
||||
save_video(video, "video_left_Wan2.1-Fun-V1.1-1.3B-Control-Camera.mp4", fps=15, quality=5)
|
||||
|
||||
video = pipe(
|
||||
prompt="一艘小船正勇敢地乘风破浪前行。蔚蓝的大海波涛汹涌,白色的浪花拍打着船身,但小船毫不畏惧,坚定地驶向远方。阳光洒在水面上,闪烁着金色的光芒,为这壮丽的场景增添了一抹温暖。镜头拉近,可以看到船上的旗帜迎风飘扬,象征着不屈的精神与冒险的勇气。这段画面充满力量,激励人心,展现了面对挑战时的无畏与执着。",
|
||||
negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
|
||||
seed=0, tiled=True,
|
||||
input_image=input_image,
|
||||
camera_control_direction="Up", camera_control_speed=0.01,
|
||||
)
|
||||
save_video(video, "video_up_Wan2.1-Fun-V1.1-1.3B-Control-Camera.mp4", fps=15, quality=5)
|
||||
@@ -0,0 +1,35 @@
|
||||
import torch
|
||||
from PIL import Image
|
||||
from diffsynth.utils.data import save_video, VideoData
|
||||
from diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig
|
||||
from modelscope import dataset_snapshot_download
|
||||
|
||||
|
||||
pipe = WanVideoPipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-1.3B-Control", origin_file_pattern="diffusion_pytorch_model*.safetensors"),
|
||||
ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-1.3B-Control", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth"),
|
||||
ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-1.3B-Control", origin_file_pattern="Wan2.1_VAE.pth"),
|
||||
ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-1.3B-Control", origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth"),
|
||||
],
|
||||
)
|
||||
|
||||
dataset_snapshot_download(
|
||||
dataset_id="DiffSynth-Studio/examples_in_diffsynth",
|
||||
local_dir="./",
|
||||
allow_file_pattern=["data/examples/wan/control_video.mp4", "data/examples/wan/reference_image_girl.png"]
|
||||
)
|
||||
|
||||
# Control video
|
||||
control_video = VideoData("data/examples/wan/control_video.mp4", height=832, width=576)
|
||||
reference_image = Image.open("data/examples/wan/reference_image_girl.png").resize((576, 832))
|
||||
video = pipe(
|
||||
prompt="扁平风格动漫,一位长发少女优雅起舞。她五官精致,大眼睛明亮有神,黑色长发柔顺光泽。身穿淡蓝色T恤和深蓝色牛仔短裤。背景是粉色。",
|
||||
negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
|
||||
control_video=control_video, reference_image=reference_image,
|
||||
height=832, width=576, num_frames=49,
|
||||
seed=1, tiled=True
|
||||
)
|
||||
save_video(video, "video_Wan2.1-Fun-V1.1-1.3B-Control.mp4", fps=15, quality=5)
|
||||
@@ -0,0 +1,35 @@
|
||||
import torch
|
||||
from PIL import Image
|
||||
from diffsynth.utils.data import save_video, VideoData
|
||||
from diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig
|
||||
from modelscope import dataset_snapshot_download
|
||||
|
||||
|
||||
pipe = WanVideoPipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-1.3B-InP", origin_file_pattern="diffusion_pytorch_model*.safetensors"),
|
||||
ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-1.3B-InP", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth"),
|
||||
ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-1.3B-InP", origin_file_pattern="Wan2.1_VAE.pth"),
|
||||
ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-1.3B-InP", origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth"),
|
||||
],
|
||||
)
|
||||
|
||||
dataset_snapshot_download(
|
||||
dataset_id="DiffSynth-Studio/examples_in_diffsynth",
|
||||
local_dir="./",
|
||||
allow_file_pattern=f"data/examples/wan/input_image.jpg"
|
||||
)
|
||||
image = Image.open("data/examples/wan/input_image.jpg")
|
||||
|
||||
# First and last frame to video
|
||||
video = pipe(
|
||||
prompt="一艘小船正勇敢地乘风破浪前行。蔚蓝的大海波涛汹涌,白色的浪花拍打着船身,但小船毫不畏惧,坚定地驶向远方。阳光洒在水面上,闪烁着金色的光芒,为这壮丽的场景增添了一抹温暖。镜头拉近,可以看到船上的旗帜迎风飘扬,象征着不屈的精神与冒险的勇气。这段画面充满力量,激励人心,展现了面对挑战时的无畏与执着。",
|
||||
negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
|
||||
input_image=image,
|
||||
seed=0, tiled=True
|
||||
# You can input `end_image=xxx` to control the last frame of the video.
|
||||
# The model will automatically generate the dynamic content between `input_image` and `end_image`.
|
||||
)
|
||||
save_video(video, "video_Wan2.1-Fun-V1.1-1.3B-InP.mp4", fps=15, quality=5)
|
||||
@@ -0,0 +1,43 @@
|
||||
import torch
|
||||
from PIL import Image
|
||||
from diffsynth.utils.data import save_video, VideoData
|
||||
from diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig
|
||||
from modelscope import dataset_snapshot_download
|
||||
|
||||
|
||||
pipe = WanVideoPipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-14B-Control-Camera", origin_file_pattern="diffusion_pytorch_model*.safetensors"),
|
||||
ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-14B-Control-Camera", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth"),
|
||||
ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-14B-Control-Camera", origin_file_pattern="Wan2.1_VAE.pth"),
|
||||
ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-14B-Control-Camera", origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth"),
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
dataset_snapshot_download(
|
||||
dataset_id="DiffSynth-Studio/examples_in_diffsynth",
|
||||
local_dir="./",
|
||||
allow_file_pattern=f"data/examples/wan/input_image.jpg"
|
||||
)
|
||||
input_image = Image.open("data/examples/wan/input_image.jpg")
|
||||
|
||||
video = pipe(
|
||||
prompt="一艘小船正勇敢地乘风破浪前行。蔚蓝的大海波涛汹涌,白色的浪花拍打着船身,但小船毫不畏惧,坚定地驶向远方。阳光洒在水面上,闪烁着金色的光芒,为这壮丽的场景增添了一抹温暖。镜头拉近,可以看到船上的旗帜迎风飘扬,象征着不屈的精神与冒险的勇气。这段画面充满力量,激励人心,展现了面对挑战时的无畏与执着。",
|
||||
negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
|
||||
seed=0, tiled=True,
|
||||
input_image=input_image,
|
||||
camera_control_direction="Left", camera_control_speed=0.01,
|
||||
)
|
||||
save_video(video, "video_left_Wan2.1-Fun-V1.1-14B-Control-Camera.mp4", fps=15, quality=5)
|
||||
|
||||
video = pipe(
|
||||
prompt="一艘小船正勇敢地乘风破浪前行。蔚蓝的大海波涛汹涌,白色的浪花拍打着船身,但小船毫不畏惧,坚定地驶向远方。阳光洒在水面上,闪烁着金色的光芒,为这壮丽的场景增添了一抹温暖。镜头拉近,可以看到船上的旗帜迎风飘扬,象征着不屈的精神与冒险的勇气。这段画面充满力量,激励人心,展现了面对挑战时的无畏与执着。",
|
||||
negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
|
||||
seed=0, tiled=True,
|
||||
input_image=input_image,
|
||||
camera_control_direction="Up", camera_control_speed=0.01,
|
||||
)
|
||||
save_video(video, "video_up_Wan2.1-Fun-V1.1-14B-Control-Camera.mp4", fps=15, quality=5)
|
||||
@@ -0,0 +1,35 @@
|
||||
import torch
|
||||
from PIL import Image
|
||||
from diffsynth.utils.data import save_video, VideoData
|
||||
from diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig
|
||||
from modelscope import dataset_snapshot_download
|
||||
|
||||
|
||||
pipe = WanVideoPipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-14B-Control", origin_file_pattern="diffusion_pytorch_model*.safetensors"),
|
||||
ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-14B-Control", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth"),
|
||||
ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-14B-Control", origin_file_pattern="Wan2.1_VAE.pth"),
|
||||
ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-14B-Control", origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth"),
|
||||
],
|
||||
)
|
||||
|
||||
dataset_snapshot_download(
|
||||
dataset_id="DiffSynth-Studio/examples_in_diffsynth",
|
||||
local_dir="./",
|
||||
allow_file_pattern=["data/examples/wan/control_video.mp4", "data/examples/wan/reference_image_girl.png"]
|
||||
)
|
||||
|
||||
# Control video
|
||||
control_video = VideoData("data/examples/wan/control_video.mp4", height=832, width=576)
|
||||
reference_image = Image.open("data/examples/wan/reference_image_girl.png").resize((576, 832))
|
||||
video = pipe(
|
||||
prompt="扁平风格动漫,一位长发少女优雅起舞。她五官精致,大眼睛明亮有神,黑色长发柔顺光泽。身穿淡蓝色T恤和深蓝色牛仔短裤。背景是粉色。",
|
||||
negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
|
||||
control_video=control_video, reference_image=reference_image,
|
||||
height=832, width=576, num_frames=49,
|
||||
seed=1, tiled=True
|
||||
)
|
||||
save_video(video, "video_Wan2.1-Fun-V1.1-14B-Control.mp4", fps=15, quality=5)
|
||||
35
examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-14B-InP.py
Normal file
35
examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-14B-InP.py
Normal file
@@ -0,0 +1,35 @@
|
||||
import torch
|
||||
from PIL import Image
|
||||
from diffsynth.utils.data import save_video, VideoData
|
||||
from diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig
|
||||
from modelscope import dataset_snapshot_download
|
||||
|
||||
|
||||
pipe = WanVideoPipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-14B-InP", origin_file_pattern="diffusion_pytorch_model*.safetensors"),
|
||||
ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-14B-InP", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth"),
|
||||
ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-14B-InP", origin_file_pattern="Wan2.1_VAE.pth"),
|
||||
ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-14B-InP", origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth"),
|
||||
],
|
||||
)
|
||||
|
||||
dataset_snapshot_download(
|
||||
dataset_id="DiffSynth-Studio/examples_in_diffsynth",
|
||||
local_dir="./",
|
||||
allow_file_pattern=f"data/examples/wan/input_image.jpg"
|
||||
)
|
||||
image = Image.open("data/examples/wan/input_image.jpg")
|
||||
|
||||
# First and last frame to video
|
||||
video = pipe(
|
||||
prompt="一艘小船正勇敢地乘风破浪前行。蔚蓝的大海波涛汹涌,白色的浪花拍打着船身,但小船毫不畏惧,坚定地驶向远方。阳光洒在水面上,闪烁着金色的光芒,为这壮丽的场景增添了一抹温暖。镜头拉近,可以看到船上的旗帜迎风飘扬,象征着不屈的精神与冒险的勇气。这段画面充满力量,激励人心,展现了面对挑战时的无畏与执着。",
|
||||
negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
|
||||
input_image=image,
|
||||
seed=0, tiled=True
|
||||
# You can input `end_image=xxx` to control the last frame of the video.
|
||||
# The model will automatically generate the dynamic content between `input_image` and `end_image`.
|
||||
)
|
||||
save_video(video, "video_Wan2.1-Fun-V1.1-14B-InP.mp4", fps=15, quality=5)
|
||||
33
examples/wanvideo/model_inference/Wan2.1-I2V-14B-480P.py
Normal file
33
examples/wanvideo/model_inference/Wan2.1-I2V-14B-480P.py
Normal file
@@ -0,0 +1,33 @@
|
||||
import torch
|
||||
from PIL import Image
|
||||
from diffsynth.utils.data import save_video, VideoData
|
||||
from diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig
|
||||
from modelscope import dataset_snapshot_download
|
||||
|
||||
|
||||
pipe = WanVideoPipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="Wan-AI/Wan2.1-I2V-14B-480P", origin_file_pattern="diffusion_pytorch_model*.safetensors"),
|
||||
ModelConfig(model_id="Wan-AI/Wan2.1-I2V-14B-480P", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth"),
|
||||
ModelConfig(model_id="Wan-AI/Wan2.1-I2V-14B-480P", origin_file_pattern="Wan2.1_VAE.pth"),
|
||||
ModelConfig(model_id="Wan-AI/Wan2.1-I2V-14B-480P", origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth"),
|
||||
],
|
||||
)
|
||||
|
||||
dataset_snapshot_download(
|
||||
dataset_id="DiffSynth-Studio/examples_in_diffsynth",
|
||||
local_dir="./",
|
||||
allow_file_pattern=f"data/examples/wan/input_image.jpg"
|
||||
)
|
||||
image = Image.open("data/examples/wan/input_image.jpg")
|
||||
|
||||
# Image-to-video
|
||||
video = pipe(
|
||||
prompt="一艘小船正勇敢地乘风破浪前行。蔚蓝的大海波涛汹涌,白色的浪花拍打着船身,但小船毫不畏惧,坚定地驶向远方。阳光洒在水面上,闪烁着金色的光芒,为这壮丽的场景增添了一抹温暖。镜头拉近,可以看到船上的旗帜迎风飘扬,象征着不屈的精神与冒险的勇气。这段画面充满力量,激励人心,展现了面对挑战时的无畏与执着。",
|
||||
negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
|
||||
input_image=image,
|
||||
seed=0, tiled=True
|
||||
)
|
||||
save_video(video, "video_Wan2.1-I2V-14B-480P.mp4", fps=15, quality=5)
|
||||
34
examples/wanvideo/model_inference/Wan2.1-I2V-14B-720P.py
Normal file
34
examples/wanvideo/model_inference/Wan2.1-I2V-14B-720P.py
Normal file
@@ -0,0 +1,34 @@
|
||||
import torch
|
||||
from PIL import Image
|
||||
from diffsynth.utils.data import save_video, VideoData
|
||||
from diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig
|
||||
from modelscope import dataset_snapshot_download
|
||||
|
||||
|
||||
pipe = WanVideoPipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="Wan-AI/Wan2.1-I2V-14B-720P", origin_file_pattern="diffusion_pytorch_model*.safetensors"),
|
||||
ModelConfig(model_id="Wan-AI/Wan2.1-I2V-14B-720P", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth"),
|
||||
ModelConfig(model_id="Wan-AI/Wan2.1-I2V-14B-720P", origin_file_pattern="Wan2.1_VAE.pth"),
|
||||
ModelConfig(model_id="Wan-AI/Wan2.1-I2V-14B-720P", origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth"),
|
||||
],
|
||||
)
|
||||
|
||||
dataset_snapshot_download(
|
||||
dataset_id="DiffSynth-Studio/examples_in_diffsynth",
|
||||
local_dir="./",
|
||||
allow_file_pattern=f"data/examples/wan/input_image.jpg"
|
||||
)
|
||||
image = Image.open("data/examples/wan/input_image.jpg")
|
||||
|
||||
# Image-to-video
|
||||
video = pipe(
|
||||
prompt="一艘小船正勇敢地乘风破浪前行。蔚蓝的大海波涛汹涌,白色的浪花拍打着船身,但小船毫不畏惧,坚定地驶向远方。阳光洒在水面上,闪烁着金色的光芒,为这壮丽的场景增添了一抹温暖。镜头拉近,可以看到船上的旗帜迎风飘扬,象征着不屈的精神与冒险的勇气。这段画面充满力量,激励人心,展现了面对挑战时的无畏与执着。",
|
||||
negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
|
||||
input_image=image,
|
||||
seed=0, tiled=True,
|
||||
height=720, width=1280,
|
||||
)
|
||||
save_video(video, "video_Wan2.1-I2V-14B-720P.mp4", fps=15, quality=5)
|
||||
33
examples/wanvideo/model_inference/Wan2.1-T2V-1.3B.py
Normal file
33
examples/wanvideo/model_inference/Wan2.1-T2V-1.3B.py
Normal file
@@ -0,0 +1,33 @@
|
||||
import torch
|
||||
from PIL import Image
|
||||
from diffsynth.utils.data import save_video, VideoData
|
||||
from diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig
|
||||
|
||||
|
||||
pipe = WanVideoPipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="diffusion_pytorch_model*.safetensors"),
|
||||
ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth"),
|
||||
ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="Wan2.1_VAE.pth"),
|
||||
],
|
||||
)
|
||||
|
||||
# Text-to-video
|
||||
video = pipe(
|
||||
prompt="纪实摄影风格画面,一只活泼的小狗在绿茵茵的草地上迅速奔跑。小狗毛色棕黄,两只耳朵立起,神情专注而欢快。阳光洒在它身上,使得毛发看上去格外柔软而闪亮。背景是一片开阔的草地,偶尔点缀着几朵野花,远处隐约可见蓝天和几片白云。透视感鲜明,捕捉小狗奔跑时的动感和四周草地的生机。中景侧面移动视角。",
|
||||
negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
|
||||
seed=0, tiled=True,
|
||||
)
|
||||
save_video(video, "video_1_Wan2.1-T2V-1.3B.mp4", fps=15, quality=5)
|
||||
|
||||
# Video-to-video
|
||||
video = VideoData("video_1_Wan2.1-T2V-1.3B.mp4", height=480, width=832)
|
||||
video = pipe(
|
||||
prompt="纪实摄影风格画面,一只活泼的小狗戴着黑色墨镜在绿茵茵的草地上迅速奔跑。小狗毛色棕黄,戴着黑色墨镜,两只耳朵立起,神情专注而欢快。阳光洒在它身上,使得毛发看上去格外柔软而闪亮。背景是一片开阔的草地,偶尔点缀着几朵野花,远处隐约可见蓝天和几片白云。透视感鲜明,捕捉小狗奔跑时的动感和四周草地的生机。中景侧面移动视角。",
|
||||
negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
|
||||
input_video=video, denoising_strength=0.7,
|
||||
seed=1, tiled=True
|
||||
)
|
||||
save_video(video, "video_2_Wan2.1-T2V-1.3B.mp4", fps=15, quality=5)
|
||||
23
examples/wanvideo/model_inference/Wan2.1-T2V-14B.py
Normal file
23
examples/wanvideo/model_inference/Wan2.1-T2V-14B.py
Normal file
@@ -0,0 +1,23 @@
|
||||
import torch
|
||||
from PIL import Image
|
||||
from diffsynth.utils.data import save_video, VideoData
|
||||
from diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig
|
||||
|
||||
|
||||
pipe = WanVideoPipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="Wan-AI/Wan2.1-T2V-14B", origin_file_pattern="diffusion_pytorch_model*.safetensors"),
|
||||
ModelConfig(model_id="Wan-AI/Wan2.1-T2V-14B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth"),
|
||||
ModelConfig(model_id="Wan-AI/Wan2.1-T2V-14B", origin_file_pattern="Wan2.1_VAE.pth"),
|
||||
],
|
||||
)
|
||||
|
||||
# Text-to-video
|
||||
video = pipe(
|
||||
prompt="一名宇航员身穿太空服,面朝镜头骑着一匹机械马在火星表面驰骋。红色的荒凉地表延伸至远方,点缀着巨大的陨石坑和奇特的岩石结构。机械马的步伐稳健,扬起微弱的尘埃,展现出未来科技与原始探索的完美结合。宇航员手持操控装置,目光坚定,仿佛正在开辟人类的新疆域。背景是深邃的宇宙和蔚蓝的地球,画面既科幻又充满希望,让人不禁畅想未来的星际生活。",
|
||||
negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
|
||||
seed=0, tiled=True,
|
||||
)
|
||||
save_video(video, "video_Wan2.1-T2V-14B.mp4", fps=15, quality=5)
|
||||
@@ -0,0 +1,51 @@
|
||||
import torch
|
||||
from PIL import Image
|
||||
from diffsynth.utils.data import save_video, VideoData
|
||||
from diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig
|
||||
from modelscope import dataset_snapshot_download
|
||||
|
||||
|
||||
pipe = WanVideoPipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="iic/VACE-Wan2.1-1.3B-Preview", origin_file_pattern="diffusion_pytorch_model*.safetensors"),
|
||||
ModelConfig(model_id="iic/VACE-Wan2.1-1.3B-Preview", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth"),
|
||||
ModelConfig(model_id="iic/VACE-Wan2.1-1.3B-Preview", origin_file_pattern="Wan2.1_VAE.pth"),
|
||||
],
|
||||
)
|
||||
|
||||
dataset_snapshot_download(
|
||||
dataset_id="DiffSynth-Studio/examples_in_diffsynth",
|
||||
local_dir="./",
|
||||
allow_file_pattern=["data/examples/wan/depth_video.mp4", "data/examples/wan/cat_fightning.jpg"]
|
||||
)
|
||||
|
||||
# Depth video -> Video
|
||||
control_video = VideoData("data/examples/wan/depth_video.mp4", height=480, width=832)
|
||||
video = pipe(
|
||||
prompt="两只可爱的橘猫戴上拳击手套,站在一个拳击台上搏斗。",
|
||||
negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
|
||||
vace_video=control_video,
|
||||
seed=1, tiled=True
|
||||
)
|
||||
save_video(video, "video_1_Wan2.1-VACE-1.3B-Preview.mp4", fps=15, quality=5)
|
||||
|
||||
# Reference image -> Video
|
||||
video = pipe(
|
||||
prompt="两只可爱的橘猫戴上拳击手套,站在一个拳击台上搏斗。",
|
||||
negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
|
||||
vace_reference_image=Image.open("data/examples/wan/cat_fightning.jpg").resize((832, 480)),
|
||||
seed=1, tiled=True
|
||||
)
|
||||
save_video(video, "video_2_Wan2.1-VACE-1.3B-Preview.mp4", fps=15, quality=5)
|
||||
|
||||
# Depth video + Reference image -> Video
|
||||
video = pipe(
|
||||
prompt="两只可爱的橘猫戴上拳击手套,站在一个拳击台上搏斗。",
|
||||
negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
|
||||
vace_video=control_video,
|
||||
vace_reference_image=Image.open("data/examples/wan/cat_fightning.jpg").resize((832, 480)),
|
||||
seed=1, tiled=True
|
||||
)
|
||||
save_video(video, "video_3_Wan2.1-VACE-1.3B-Preview.mp4", fps=15, quality=5)
|
||||
52
examples/wanvideo/model_inference/Wan2.1-VACE-1.3B.py
Normal file
52
examples/wanvideo/model_inference/Wan2.1-VACE-1.3B.py
Normal file
@@ -0,0 +1,52 @@
|
||||
import torch
|
||||
from PIL import Image
|
||||
from diffsynth.utils.data import save_video, VideoData
|
||||
from diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig
|
||||
from modelscope import dataset_snapshot_download
|
||||
|
||||
|
||||
pipe = WanVideoPipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="Wan-AI/Wan2.1-VACE-1.3B", origin_file_pattern="diffusion_pytorch_model*.safetensors"),
|
||||
ModelConfig(model_id="Wan-AI/Wan2.1-VACE-1.3B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth"),
|
||||
ModelConfig(model_id="Wan-AI/Wan2.1-VACE-1.3B", origin_file_pattern="Wan2.1_VAE.pth"),
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
dataset_snapshot_download(
|
||||
dataset_id="DiffSynth-Studio/examples_in_diffsynth",
|
||||
local_dir="./",
|
||||
allow_file_pattern=["data/examples/wan/depth_video.mp4", "data/examples/wan/cat_fightning.jpg"]
|
||||
)
|
||||
|
||||
# Depth video -> Video
|
||||
control_video = VideoData("data/examples/wan/depth_video.mp4", height=480, width=832)
|
||||
video = pipe(
|
||||
prompt="两只可爱的橘猫戴上拳击手套,站在一个拳击台上搏斗。",
|
||||
negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
|
||||
vace_video=control_video,
|
||||
seed=1, tiled=True
|
||||
)
|
||||
save_video(video, "video_1_Wan2.1-VACE-1.3B.mp4", fps=15, quality=5)
|
||||
|
||||
# Reference image -> Video
|
||||
video = pipe(
|
||||
prompt="两只可爱的橘猫戴上拳击手套,站在一个拳击台上搏斗。",
|
||||
negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
|
||||
vace_reference_image=Image.open("data/examples/wan/cat_fightning.jpg").resize((832, 480)),
|
||||
seed=1, tiled=True
|
||||
)
|
||||
save_video(video, "video_2_Wan2.1-VACE-1.3B.mp4", fps=15, quality=5)
|
||||
|
||||
# Depth video + Reference image -> Video
|
||||
video = pipe(
|
||||
prompt="两只可爱的橘猫戴上拳击手套,站在一个拳击台上搏斗。",
|
||||
negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
|
||||
vace_video=control_video,
|
||||
vace_reference_image=Image.open("data/examples/wan/cat_fightning.jpg").resize((832, 480)),
|
||||
seed=1, tiled=True
|
||||
)
|
||||
save_video(video, "video_3_Wan2.1-VACE-1.3B.mp4", fps=15, quality=5)
|
||||
53
examples/wanvideo/model_inference/Wan2.1-VACE-14B.py
Normal file
53
examples/wanvideo/model_inference/Wan2.1-VACE-14B.py
Normal file
@@ -0,0 +1,53 @@
|
||||
import torch
|
||||
from PIL import Image
|
||||
from diffsynth.utils.data import save_video, VideoData
|
||||
from diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig
|
||||
from modelscope import dataset_snapshot_download
|
||||
|
||||
|
||||
pipe = WanVideoPipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="Wan-AI/Wan2.1-VACE-14B", origin_file_pattern="diffusion_pytorch_model*.safetensors"),
|
||||
ModelConfig(model_id="Wan-AI/Wan2.1-VACE-14B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth"),
|
||||
ModelConfig(model_id="Wan-AI/Wan2.1-VACE-14B", origin_file_pattern="Wan2.1_VAE.pth"),
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
|
||||
dataset_snapshot_download(
|
||||
dataset_id="DiffSynth-Studio/examples_in_diffsynth",
|
||||
local_dir="./",
|
||||
allow_file_pattern=["data/examples/wan/depth_video.mp4", "data/examples/wan/cat_fightning.jpg"]
|
||||
)
|
||||
|
||||
# Depth video -> Video
|
||||
control_video = VideoData("data/examples/wan/depth_video.mp4", height=480, width=832)
|
||||
video = pipe(
|
||||
prompt="两只可爱的橘猫戴上拳击手套,站在一个拳击台上搏斗。",
|
||||
negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
|
||||
vace_video=control_video,
|
||||
seed=1, tiled=True
|
||||
)
|
||||
save_video(video, "video_1_Wan2.1-VACE-14B.mp4", fps=15, quality=5)
|
||||
|
||||
# Reference image -> Video
|
||||
video = pipe(
|
||||
prompt="两只可爱的橘猫戴上拳击手套,站在一个拳击台上搏斗。",
|
||||
negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
|
||||
vace_reference_image=Image.open("data/examples/wan/cat_fightning.jpg").resize((832, 480)),
|
||||
seed=1, tiled=True
|
||||
)
|
||||
save_video(video, "video_2_Wan2.1-VACE-14B.mp4", fps=15, quality=5)
|
||||
|
||||
# Depth video + Reference image -> Video
|
||||
video = pipe(
|
||||
prompt="两只可爱的橘猫戴上拳击手套,站在一个拳击台上搏斗。",
|
||||
negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
|
||||
vace_video=control_video,
|
||||
vace_reference_image=Image.open("data/examples/wan/cat_fightning.jpg").resize((832, 480)),
|
||||
seed=1, tiled=True
|
||||
)
|
||||
save_video(video, "video_3_Wan2.1-VACE-14B.mp4", fps=15, quality=5)
|
||||
58
examples/wanvideo/model_inference/Wan2.2-Animate-14B.py
Normal file
58
examples/wanvideo/model_inference/Wan2.2-Animate-14B.py
Normal file
@@ -0,0 +1,58 @@
|
||||
import torch
|
||||
from PIL import Image
|
||||
from diffsynth.utils.data import save_video, VideoData
|
||||
from diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig
|
||||
from modelscope import dataset_snapshot_download
|
||||
|
||||
|
||||
pipe = WanVideoPipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="Wan-AI/Wan2.2-Animate-14B", origin_file_pattern="diffusion_pytorch_model*.safetensors"),
|
||||
ModelConfig(model_id="Wan-AI/Wan2.2-Animate-14B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth"),
|
||||
ModelConfig(model_id="Wan-AI/Wan2.2-Animate-14B", origin_file_pattern="Wan2.1_VAE.pth"),
|
||||
ModelConfig(model_id="Wan-AI/Wan2.2-Animate-14B", origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth"),
|
||||
],
|
||||
)
|
||||
|
||||
dataset_snapshot_download(
|
||||
dataset_id="DiffSynth-Studio/examples_in_diffsynth",
|
||||
local_dir="./",
|
||||
allow_file_pattern="data/examples/wan/animate/*",
|
||||
)
|
||||
|
||||
# Animate
|
||||
input_image = Image.open("data/examples/wan/animate/animate_input_image.png")
|
||||
animate_pose_video = VideoData("data/examples/wan/animate/animate_pose_video.mp4").raw_data()[:81-4]
|
||||
animate_face_video = VideoData("data/examples/wan/animate/animate_face_video.mp4").raw_data()[:81-4]
|
||||
video = pipe(
|
||||
prompt="视频中的人在做动作",
|
||||
seed=0, tiled=True,
|
||||
input_image=input_image,
|
||||
animate_pose_video=animate_pose_video,
|
||||
animate_face_video=animate_face_video,
|
||||
num_frames=81, height=720, width=1280,
|
||||
num_inference_steps=20, cfg_scale=1,
|
||||
)
|
||||
save_video(video, "video_1_Wan2.2-Animate-14B.mp4", fps=15, quality=5)
|
||||
|
||||
# Replace
|
||||
pipe.load_lora(pipe.dit, ModelConfig(model_id="Wan-AI/Wan2.2-Animate-14B", origin_file_pattern="relighting_lora.ckpt"))
|
||||
input_image = Image.open("data/examples/wan/animate/replace_input_image.png")
|
||||
animate_pose_video = VideoData("data/examples/wan/animate/replace_pose_video.mp4").raw_data()[:81-4]
|
||||
animate_face_video = VideoData("data/examples/wan/animate/replace_face_video.mp4").raw_data()[:81-4]
|
||||
animate_inpaint_video = VideoData("data/examples/wan/animate/replace_inpaint_video.mp4").raw_data()[:81-4]
|
||||
animate_mask_video = VideoData("data/examples/wan/animate/replace_mask_video.mp4").raw_data()[:81-4]
|
||||
video = pipe(
|
||||
prompt="视频中的人在做动作",
|
||||
seed=0, tiled=True,
|
||||
input_image=input_image,
|
||||
animate_pose_video=animate_pose_video,
|
||||
animate_face_video=animate_face_video,
|
||||
animate_inpaint_video=animate_inpaint_video,
|
||||
animate_mask_video=animate_mask_video,
|
||||
num_frames=81, height=720, width=1280,
|
||||
num_inference_steps=20, cfg_scale=1,
|
||||
)
|
||||
save_video(video, "video_2_Wan2.2-Animate-14B.mp4", fps=15, quality=5)
|
||||
@@ -0,0 +1,42 @@
|
||||
import torch
|
||||
from diffsynth.utils.data import save_video,VideoData
|
||||
from diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig
|
||||
from PIL import Image
|
||||
from modelscope import dataset_snapshot_download
|
||||
|
||||
pipe = WanVideoPipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="PAI/Wan2.2-Fun-A14B-Control-Camera", origin_file_pattern="high_noise_model/diffusion_pytorch_model*.safetensors"),
|
||||
ModelConfig(model_id="PAI/Wan2.2-Fun-A14B-Control-Camera", origin_file_pattern="low_noise_model/diffusion_pytorch_model*.safetensors"),
|
||||
ModelConfig(model_id="PAI/Wan2.2-Fun-A14B-Control-Camera", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth"),
|
||||
ModelConfig(model_id="PAI/Wan2.2-Fun-A14B-Control-Camera", origin_file_pattern="Wan2.1_VAE.pth"),
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
dataset_snapshot_download(
|
||||
dataset_id="DiffSynth-Studio/examples_in_diffsynth",
|
||||
local_dir="./",
|
||||
allow_file_pattern=f"data/examples/wan/input_image.jpg"
|
||||
)
|
||||
input_image = Image.open("data/examples/wan/input_image.jpg")
|
||||
|
||||
video = pipe(
|
||||
prompt="一艘小船正勇敢地乘风破浪前行。蔚蓝的大海波涛汹涌,白色的浪花拍打着船身,但小船毫不畏惧,坚定地驶向远方。阳光洒在水面上,闪烁着金色的光芒,为这壮丽的场景增添了一抹温暖。镜头拉近,可以看到船上的旗帜迎风飘扬,象征着不屈的精神与冒险的勇气。这段画面充满力量,激励人心,展现了面对挑战时的无畏与执着。",
|
||||
negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
|
||||
seed=0, tiled=True,
|
||||
input_image=input_image,
|
||||
camera_control_direction="Left", camera_control_speed=0.01,
|
||||
)
|
||||
save_video(video, "video_left_Wan2.2-Fun-A14B-Control-Camera.mp4", fps=15, quality=5)
|
||||
|
||||
video = pipe(
|
||||
prompt="一艘小船正勇敢地乘风破浪前行。蔚蓝的大海波涛汹涌,白色的浪花拍打着船身,但小船毫不畏惧,坚定地驶向远方。阳光洒在水面上,闪烁着金色的光芒,为这壮丽的场景增添了一抹温暖。镜头拉近,可以看到船上的旗帜迎风飘扬,象征着不屈的精神与冒险的勇气。这段画面充满力量,激励人心,展现了面对挑战时的无畏与执着。",
|
||||
negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
|
||||
seed=0, tiled=True,
|
||||
input_image=input_image,
|
||||
camera_control_direction="Up", camera_control_speed=0.01,
|
||||
)
|
||||
save_video(video, "video_up_Wan2.2-Fun-A14B-Control-Camera.mp4", fps=15, quality=5)
|
||||
34
examples/wanvideo/model_inference/Wan2.2-Fun-A14B-Control.py
Normal file
34
examples/wanvideo/model_inference/Wan2.2-Fun-A14B-Control.py
Normal file
@@ -0,0 +1,34 @@
|
||||
import torch
|
||||
from diffsynth.utils.data import save_video,VideoData
|
||||
from diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig
|
||||
from PIL import Image
|
||||
from modelscope import dataset_snapshot_download
|
||||
|
||||
pipe = WanVideoPipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="PAI/Wan2.2-Fun-A14B-Control", origin_file_pattern="high_noise_model/diffusion_pytorch_model*.safetensors"),
|
||||
ModelConfig(model_id="PAI/Wan2.2-Fun-A14B-Control", origin_file_pattern="low_noise_model/diffusion_pytorch_model*.safetensors"),
|
||||
ModelConfig(model_id="PAI/Wan2.2-Fun-A14B-Control", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth"),
|
||||
ModelConfig(model_id="PAI/Wan2.2-Fun-A14B-Control", origin_file_pattern="Wan2.1_VAE.pth"),
|
||||
],
|
||||
)
|
||||
|
||||
dataset_snapshot_download(
|
||||
dataset_id="DiffSynth-Studio/examples_in_diffsynth",
|
||||
local_dir="./",
|
||||
allow_file_pattern=["data/examples/wan/control_video.mp4", "data/examples/wan/reference_image_girl.png"]
|
||||
)
|
||||
|
||||
# Control video
|
||||
control_video = VideoData("data/examples/wan/control_video.mp4", height=832, width=576)
|
||||
reference_image = Image.open("data/examples/wan/reference_image_girl.png").resize((576, 832))
|
||||
video = pipe(
|
||||
prompt="扁平风格动漫,一位长发少女优雅起舞。她五官精致,大眼睛明亮有神,黑色长发柔顺光泽。身穿淡蓝色T恤和深蓝色牛仔短裤。背景是粉色。",
|
||||
negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
|
||||
control_video=control_video, reference_image=reference_image,
|
||||
height=832, width=576, num_frames=49,
|
||||
seed=1, tiled=True
|
||||
)
|
||||
save_video(video, "video_Wan2.2-Fun-A14B-Control.mp4", fps=15, quality=5)
|
||||
34
examples/wanvideo/model_inference/Wan2.2-Fun-A14B-InP.py
Normal file
34
examples/wanvideo/model_inference/Wan2.2-Fun-A14B-InP.py
Normal file
@@ -0,0 +1,34 @@
|
||||
import torch
|
||||
from diffsynth.utils.data import save_video
|
||||
from diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig
|
||||
from PIL import Image
|
||||
from modelscope import dataset_snapshot_download
|
||||
|
||||
pipe = WanVideoPipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="PAI/Wan2.2-Fun-A14B-InP", origin_file_pattern="high_noise_model/diffusion_pytorch_model*.safetensors"),
|
||||
ModelConfig(model_id="PAI/Wan2.2-Fun-A14B-InP", origin_file_pattern="low_noise_model/diffusion_pytorch_model*.safetensors"),
|
||||
ModelConfig(model_id="PAI/Wan2.2-Fun-A14B-InP", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth"),
|
||||
ModelConfig(model_id="PAI/Wan2.2-Fun-A14B-InP", origin_file_pattern="Wan2.1_VAE.pth"),
|
||||
],
|
||||
)
|
||||
|
||||
dataset_snapshot_download(
|
||||
dataset_id="DiffSynth-Studio/examples_in_diffsynth",
|
||||
local_dir="./",
|
||||
allow_file_pattern=f"data/examples/wan/input_image.jpg"
|
||||
)
|
||||
image = Image.open("data/examples/wan/input_image.jpg")
|
||||
|
||||
# First and last frame to video
|
||||
video = pipe(
|
||||
prompt="一艘小船正勇敢地乘风破浪前行。蔚蓝的大海波涛汹涌,白色的浪花拍打着船身,但小船毫不畏惧,坚定地驶向远方。阳光洒在水面上,闪烁着金色的光芒,为这壮丽的场景增添了一抹温暖。镜头拉近,可以看到船上的旗帜迎风飘扬,象征着不屈的精神与冒险的勇气。这段画面充满力量,激励人心,展现了面对挑战时的无畏与执着。",
|
||||
negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
|
||||
input_image=image,
|
||||
seed=0, tiled=True,
|
||||
# You can input `end_image=xxx` to control the last frame of the video.
|
||||
# The model will automatically generate the dynamic content between `input_image` and `end_image`.
|
||||
)
|
||||
save_video(video, "video_Wan2.2-Fun-A14B-InP.mp4", fps=15, quality=5)
|
||||
32
examples/wanvideo/model_inference/Wan2.2-I2V-A14B.py
Normal file
32
examples/wanvideo/model_inference/Wan2.2-I2V-A14B.py
Normal file
@@ -0,0 +1,32 @@
|
||||
import torch
|
||||
from PIL import Image
|
||||
from diffsynth.utils.data import save_video
|
||||
from diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig
|
||||
from modelscope import dataset_snapshot_download
|
||||
|
||||
pipe = WanVideoPipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="Wan-AI/Wan2.2-I2V-A14B", origin_file_pattern="high_noise_model/diffusion_pytorch_model*.safetensors"),
|
||||
ModelConfig(model_id="Wan-AI/Wan2.2-I2V-A14B", origin_file_pattern="low_noise_model/diffusion_pytorch_model*.safetensors"),
|
||||
ModelConfig(model_id="Wan-AI/Wan2.2-I2V-A14B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth"),
|
||||
ModelConfig(model_id="Wan-AI/Wan2.2-I2V-A14B", origin_file_pattern="Wan2.1_VAE.pth"),
|
||||
],
|
||||
)
|
||||
|
||||
dataset_snapshot_download(
|
||||
dataset_id="DiffSynth-Studio/examples_in_diffsynth",
|
||||
local_dir="./",
|
||||
allow_file_pattern=["data/examples/wan/cat_fightning.jpg"]
|
||||
)
|
||||
input_image = Image.open("data/examples/wan/cat_fightning.jpg").resize((832, 480))
|
||||
|
||||
video = pipe(
|
||||
prompt="Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage.",
|
||||
negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
|
||||
seed=0, tiled=True,
|
||||
input_image=input_image,
|
||||
switch_DiT_boundary=0.9,
|
||||
)
|
||||
save_video(video, "video_Wan2.2-I2V-A14B.mp4", fps=15, quality=5)
|
||||
72
examples/wanvideo/model_inference/Wan2.2-S2V-14B.py
Normal file
72
examples/wanvideo/model_inference/Wan2.2-S2V-14B.py
Normal file
@@ -0,0 +1,72 @@
|
||||
# This script can generate a single video clip.
|
||||
# If you need generate long videos, please refer to `Wan2.2-S2V-14B_multi_clips.py`.
|
||||
import torch
|
||||
from PIL import Image
|
||||
import librosa
|
||||
from diffsynth.utils.data import VideoData, save_video_with_audio
|
||||
from diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig
|
||||
from modelscope import dataset_snapshot_download
|
||||
|
||||
|
||||
pipe = WanVideoPipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="Wan-AI/Wan2.2-S2V-14B", origin_file_pattern="diffusion_pytorch_model*.safetensors"),
|
||||
ModelConfig(model_id="Wan-AI/Wan2.2-S2V-14B", origin_file_pattern="wav2vec2-large-xlsr-53-english/model.safetensors"),
|
||||
ModelConfig(model_id="Wan-AI/Wan2.2-S2V-14B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth"),
|
||||
ModelConfig(model_id="Wan-AI/Wan2.2-S2V-14B", origin_file_pattern="Wan2.1_VAE.pth"),
|
||||
],
|
||||
audio_processor_config=ModelConfig(model_id="Wan-AI/Wan2.2-S2V-14B", origin_file_pattern="wav2vec2-large-xlsr-53-english/"),
|
||||
)
|
||||
dataset_snapshot_download(
|
||||
dataset_id="DiffSynth-Studio/example_video_dataset",
|
||||
local_dir="./data/example_video_dataset",
|
||||
allow_file_pattern=f"wans2v/*"
|
||||
)
|
||||
|
||||
num_frames = 81 # 4n+1
|
||||
height = 448
|
||||
width = 832
|
||||
|
||||
prompt = "a person is singing"
|
||||
negative_prompt = "画面模糊,最差质量,画面模糊,细节模糊不清,情绪激动剧烈,手快速抖动,字幕,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走"
|
||||
input_image = Image.open("data/example_video_dataset/wans2v/pose.png").convert("RGB").resize((width, height))
|
||||
# s2v audio input, recommend 16kHz sampling rate
|
||||
audio_path = 'data/example_video_dataset/wans2v/sing.MP3'
|
||||
input_audio, sample_rate = librosa.load(audio_path, sr=16000)
|
||||
|
||||
# Speech-to-video
|
||||
video = pipe(
|
||||
prompt=prompt,
|
||||
input_image=input_image,
|
||||
negative_prompt=negative_prompt,
|
||||
seed=0,
|
||||
num_frames=num_frames,
|
||||
height=height,
|
||||
width=width,
|
||||
audio_sample_rate=sample_rate,
|
||||
input_audio=input_audio,
|
||||
num_inference_steps=40,
|
||||
)
|
||||
save_video_with_audio(video[1:], "video_1_Wan2.2-S2V-14B.mp4", audio_path, fps=16, quality=5)
|
||||
|
||||
# s2v will use the first (num_frames) frames as reference. height and width must be the same as input_image. And fps should be 16, the same as output video fps.
|
||||
pose_video_path = 'data/example_video_dataset/wans2v/pose.mp4'
|
||||
pose_video = VideoData(pose_video_path, height=height, width=width)
|
||||
|
||||
# Speech-to-video with pose
|
||||
video = pipe(
|
||||
prompt=prompt,
|
||||
input_image=input_image,
|
||||
negative_prompt=negative_prompt,
|
||||
seed=0,
|
||||
num_frames=num_frames,
|
||||
height=height,
|
||||
width=width,
|
||||
audio_sample_rate=sample_rate,
|
||||
input_audio=input_audio,
|
||||
s2v_pose_video=pose_video,
|
||||
num_inference_steps=40,
|
||||
)
|
||||
save_video_with_audio(video[1:], "video_2_Wan2.2-S2V-14B.mp4", audio_path, fps=16, quality=5)
|
||||
116
examples/wanvideo/model_inference/Wan2.2-S2V-14B_multi_clips.py
Normal file
116
examples/wanvideo/model_inference/Wan2.2-S2V-14B_multi_clips.py
Normal file
@@ -0,0 +1,116 @@
|
||||
import torch
|
||||
from PIL import Image
|
||||
import librosa
|
||||
from diffsynth.utils.data import VideoData, save_video_with_audio
|
||||
from diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig, WanVideoUnit_S2V
|
||||
from modelscope import dataset_snapshot_download
|
||||
|
||||
|
||||
def speech_to_video(
|
||||
prompt,
|
||||
input_image,
|
||||
audio_path,
|
||||
negative_prompt="",
|
||||
num_clip=None,
|
||||
audio_sample_rate=16000,
|
||||
pose_video_path=None,
|
||||
infer_frames=80,
|
||||
height=448,
|
||||
width=832,
|
||||
num_inference_steps=40,
|
||||
fps=16, # recommend fixing fps as 16 for s2v
|
||||
motion_frames=73, # hyperparameter of wan2.2-s2v
|
||||
save_path=None,
|
||||
):
|
||||
# s2v audio input, recommend 16kHz sampling rate
|
||||
input_audio, sample_rate = librosa.load(audio_path, sr=audio_sample_rate)
|
||||
# s2v will use the first (num_frames) frames as reference. height and width must be the same as input_image. And fps should be 16, the same as output video fps.
|
||||
pose_video = VideoData(pose_video_path, height=height, width=width) if pose_video_path is not None else None
|
||||
|
||||
audio_embeds, pose_latents, num_repeat = WanVideoUnit_S2V.pre_calculate_audio_pose(
|
||||
pipe=pipe,
|
||||
input_audio=input_audio,
|
||||
audio_sample_rate=sample_rate,
|
||||
s2v_pose_video=pose_video,
|
||||
num_frames=infer_frames + 1,
|
||||
height=height,
|
||||
width=width,
|
||||
fps=fps,
|
||||
)
|
||||
num_repeat = min(num_repeat, num_clip) if num_clip is not None else num_repeat
|
||||
print(f"Generating {num_repeat} video clips...")
|
||||
motion_videos = []
|
||||
video = []
|
||||
for r in range(num_repeat):
|
||||
s2v_pose_latents = pose_latents[r] if pose_latents is not None else None
|
||||
current_clip = pipe(
|
||||
prompt=prompt,
|
||||
input_image=input_image,
|
||||
negative_prompt=negative_prompt,
|
||||
seed=0,
|
||||
num_frames=infer_frames + 1,
|
||||
height=height,
|
||||
width=width,
|
||||
audio_embeds=audio_embeds[r],
|
||||
s2v_pose_latents=s2v_pose_latents,
|
||||
motion_video=motion_videos,
|
||||
num_inference_steps=num_inference_steps,
|
||||
)
|
||||
current_clip = current_clip[-infer_frames:]
|
||||
if r == 0:
|
||||
current_clip = current_clip[3:]
|
||||
overlap_frames_num = min(motion_frames, len(current_clip))
|
||||
motion_videos = motion_videos[overlap_frames_num:] + current_clip[-overlap_frames_num:]
|
||||
video.extend(current_clip)
|
||||
save_video_with_audio(video, save_path, audio_path, fps=16, quality=5)
|
||||
print(f"processed the {r+1}th clip of total {num_repeat} clips.")
|
||||
return video
|
||||
|
||||
|
||||
pipe = WanVideoPipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="Wan-AI/Wan2.2-S2V-14B", origin_file_pattern="diffusion_pytorch_model*.safetensors"),
|
||||
ModelConfig(model_id="Wan-AI/Wan2.2-S2V-14B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth"),
|
||||
ModelConfig(model_id="Wan-AI/Wan2.2-S2V-14B", origin_file_pattern="wav2vec2-large-xlsr-53-english/model.safetensors"),
|
||||
ModelConfig(model_id="Wan-AI/Wan2.2-S2V-14B", origin_file_pattern="Wan2.1_VAE.pth"),
|
||||
],
|
||||
audio_processor_config=ModelConfig(model_id="Wan-AI/Wan2.2-S2V-14B", origin_file_pattern="wav2vec2-large-xlsr-53-english/"),
|
||||
)
|
||||
|
||||
dataset_snapshot_download(
|
||||
dataset_id="DiffSynth-Studio/example_video_dataset",
|
||||
local_dir="./data/example_video_dataset",
|
||||
allow_file_pattern=f"wans2v/*",
|
||||
)
|
||||
|
||||
infer_frames = 80 # 4n
|
||||
height = 448
|
||||
width = 832
|
||||
|
||||
prompt = "a person is singing"
|
||||
negative_prompt = "画面模糊,最差质量,画面模糊,细节模糊不清,情绪激动剧烈,手快速抖动,字幕,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走"
|
||||
input_image = Image.open("data/example_video_dataset/wans2v/pose.png").convert("RGB").resize((width, height))
|
||||
|
||||
video_with_audio = speech_to_video(
|
||||
prompt=prompt,
|
||||
input_image=input_image,
|
||||
audio_path='data/example_video_dataset/wans2v/sing.MP3',
|
||||
negative_prompt=negative_prompt,
|
||||
pose_video_path='data/example_video_dataset/wans2v/pose.mp4',
|
||||
save_path="video_full_Wan2.2-S2V-14B.mp4",
|
||||
infer_frames=infer_frames,
|
||||
height=height,
|
||||
width=width,
|
||||
)
|
||||
# num_clip means generating only the first n clips with n * infer_frames frames.
|
||||
video_with_audio_pose = speech_to_video(
|
||||
prompt=prompt,
|
||||
input_image=input_image,
|
||||
audio_path='data/example_video_dataset/wans2v/sing.MP3',
|
||||
negative_prompt=negative_prompt,
|
||||
pose_video_path='data/example_video_dataset/wans2v/pose.mp4',
|
||||
save_path="video_clip_2_Wan2.2-S2V-14B.mp4",
|
||||
num_clip=2
|
||||
)
|
||||
23
examples/wanvideo/model_inference/Wan2.2-T2V-A14B.py
Normal file
23
examples/wanvideo/model_inference/Wan2.2-T2V-A14B.py
Normal file
@@ -0,0 +1,23 @@
|
||||
import torch
|
||||
from diffsynth.utils.data import save_video
|
||||
from diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig
|
||||
|
||||
|
||||
pipe = WanVideoPipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="Wan-AI/Wan2.2-T2V-A14B", origin_file_pattern="high_noise_model/diffusion_pytorch_model*.safetensors"),
|
||||
ModelConfig(model_id="Wan-AI/Wan2.2-T2V-A14B", origin_file_pattern="low_noise_model/diffusion_pytorch_model*.safetensors"),
|
||||
ModelConfig(model_id="Wan-AI/Wan2.2-T2V-A14B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth"),
|
||||
ModelConfig(model_id="Wan-AI/Wan2.2-T2V-A14B", origin_file_pattern="Wan2.1_VAE.pth"),
|
||||
],
|
||||
)
|
||||
|
||||
# Text-to-video
|
||||
video = pipe(
|
||||
prompt="Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage.",
|
||||
negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
|
||||
seed=0, tiled=True,
|
||||
)
|
||||
save_video(video, "video_Wan2.2-T2V-A14B.mp4", fps=15, quality=5)
|
||||
42
examples/wanvideo/model_inference/Wan2.2-TI2V-5B.py
Normal file
42
examples/wanvideo/model_inference/Wan2.2-TI2V-5B.py
Normal file
@@ -0,0 +1,42 @@
|
||||
import torch
|
||||
from PIL import Image
|
||||
from diffsynth.utils.data import save_video
|
||||
from diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig
|
||||
from modelscope import dataset_snapshot_download
|
||||
|
||||
pipe = WanVideoPipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="Wan-AI/Wan2.2-TI2V-5B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth"),
|
||||
ModelConfig(model_id="Wan-AI/Wan2.2-TI2V-5B", origin_file_pattern="diffusion_pytorch_model*.safetensors"),
|
||||
ModelConfig(model_id="Wan-AI/Wan2.2-TI2V-5B", origin_file_pattern="Wan2.2_VAE.pth"),
|
||||
],
|
||||
)
|
||||
|
||||
# Text-to-video
|
||||
video = pipe(
|
||||
prompt="两只可爱的橘猫戴上拳击手套,站在一个拳击台上搏斗。",
|
||||
negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
|
||||
seed=0, tiled=True,
|
||||
height=704, width=1248,
|
||||
num_frames=121,
|
||||
)
|
||||
save_video(video, "video_1_Wan2.2-TI2V-5B.mp4", fps=15, quality=5)
|
||||
|
||||
# Image-to-video
|
||||
dataset_snapshot_download(
|
||||
dataset_id="DiffSynth-Studio/examples_in_diffsynth",
|
||||
local_dir="./",
|
||||
allow_file_pattern=["data/examples/wan/cat_fightning.jpg"]
|
||||
)
|
||||
input_image = Image.open("data/examples/wan/cat_fightning.jpg").resize((1248, 704))
|
||||
video = pipe(
|
||||
prompt="两只可爱的橘猫戴上拳击手套,站在一个拳击台上搏斗。",
|
||||
negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
|
||||
seed=0, tiled=True,
|
||||
height=704, width=1248,
|
||||
input_image=input_image,
|
||||
num_frames=121,
|
||||
)
|
||||
save_video(video, "video_2_Wan2.2-TI2V-5B.mp4", fps=15, quality=5)
|
||||
53
examples/wanvideo/model_inference/Wan2.2-VACE-Fun-A14B.py
Normal file
53
examples/wanvideo/model_inference/Wan2.2-VACE-Fun-A14B.py
Normal file
@@ -0,0 +1,53 @@
|
||||
import torch
|
||||
from PIL import Image
|
||||
from diffsynth.utils.data import save_video, VideoData
|
||||
from diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig
|
||||
from modelscope import dataset_snapshot_download
|
||||
|
||||
|
||||
pipe = WanVideoPipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="PAI/Wan2.2-VACE-Fun-A14B", origin_file_pattern="high_noise_model/diffusion_pytorch_model*.safetensors"),
|
||||
ModelConfig(model_id="PAI/Wan2.2-VACE-Fun-A14B", origin_file_pattern="low_noise_model/diffusion_pytorch_model*.safetensors"),
|
||||
ModelConfig(model_id="PAI/Wan2.2-VACE-Fun-A14B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth"),
|
||||
ModelConfig(model_id="PAI/Wan2.2-VACE-Fun-A14B", origin_file_pattern="Wan2.1_VAE.pth"),
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
dataset_snapshot_download(
|
||||
dataset_id="DiffSynth-Studio/examples_in_diffsynth",
|
||||
local_dir="./",
|
||||
allow_file_pattern=["data/examples/wan/depth_video.mp4", "data/examples/wan/cat_fightning.jpg"]
|
||||
)
|
||||
|
||||
# Depth video -> Video
|
||||
control_video = VideoData("data/examples/wan/depth_video.mp4", height=480, width=832)
|
||||
video = pipe(
|
||||
prompt="两只可爱的橘猫戴上拳击手套,站在一个拳击台上搏斗。",
|
||||
negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
|
||||
vace_video=control_video,
|
||||
seed=1, tiled=True
|
||||
)
|
||||
save_video(video, "video_1_Wan2.2-VACE-Fun-A14B.mp4", fps=15, quality=5)
|
||||
|
||||
# Reference image -> Video
|
||||
video = pipe(
|
||||
prompt="两只可爱的橘猫戴上拳击手套,站在一个拳击台上搏斗。",
|
||||
negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
|
||||
vace_reference_image=Image.open("data/examples/wan/cat_fightning.jpg").resize((832, 480)),
|
||||
seed=1, tiled=True
|
||||
)
|
||||
save_video(video, "video_2_Wan2.2-VACE-Fun-A14B.mp4", fps=15, quality=5)
|
||||
|
||||
# Depth video + Reference image -> Video
|
||||
video = pipe(
|
||||
prompt="两只可爱的橘猫戴上拳击手套,站在一个拳击台上搏斗。",
|
||||
negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
|
||||
vace_video=control_video,
|
||||
vace_reference_image=Image.open("data/examples/wan/cat_fightning.jpg").resize((832, 480)),
|
||||
seed=1, tiled=True
|
||||
)
|
||||
save_video(video, "video_3_Wan2.2-VACE-Fun-A14B.mp4", fps=15, quality=5)
|
||||
24
examples/wanvideo/model_inference/krea-realtime-video.py
Normal file
24
examples/wanvideo/model_inference/krea-realtime-video.py
Normal file
@@ -0,0 +1,24 @@
|
||||
import torch
|
||||
from diffsynth.utils.data import save_video
|
||||
from diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig
|
||||
|
||||
|
||||
pipe = WanVideoPipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="krea/krea-realtime-video", origin_file_pattern="krea-realtime-video-14b.safetensors"),
|
||||
ModelConfig(model_id="Wan-AI/Wan2.1-T2V-14B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth"),
|
||||
ModelConfig(model_id="Wan-AI/Wan2.1-T2V-14B", origin_file_pattern="Wan2.1_VAE.pth"),
|
||||
],
|
||||
)
|
||||
|
||||
# Text-to-video
|
||||
video = pipe(
|
||||
prompt="a cat sitting on a boat",
|
||||
num_inference_steps=6, num_frames=81,
|
||||
seed=0, tiled=True,
|
||||
cfg_scale=1,
|
||||
sigma_shift=20,
|
||||
)
|
||||
save_video(video, "video_krea-realtime-video.mp4", fps=15, quality=5)
|
||||
Reference in New Issue
Block a user