Mirror of https://github.com/modelscope/DiffSynth-Studio.git
Synced 2026-03-20 15:48:20 +00:00

Compare commits (89 commits)
| SHA1 |
|---|
| a10635818a |
| 9ea29a769d |
| 2a5355b7cb |
| 7a06a58f49 |
| a3b4f235a0 |
| a572254a1d |
| 9e78bf5e89 |
| d21676b4dc |
| 53f01e72e6 |
| 55e5e373dd |
| 4a0921ada1 |
| 5129d3dc52 |
| ee9bab80f2 |
| cd8884c9ef |
| 46744362de |
| 0f0cdc3afc |
| a33c63af87 |
| 3cc9764bc9 |
| f6c6e3c640 |
| 60a9db706e |
| a98700feb2 |
| 5418ca781e |
| 71eee780fb |
| 4864453e0a |
| c5a32f76c2 |
| c4ed3d3e4b |
| 803ddcccc7 |
| 4cd51fecf2 |
| 3b0211a547 |
| e88328d152 |
| 52896fa8dd |
| c7035ad911 |
| 070811e517 |
| 7e010d88a5 |
| 4e43d4d461 |
| d7efe7e539 |
| 633f789c47 |
| 88607f404e |
| 6d405b669c |
| d0fed6ba72 |
| 64eaa0d76a |
| 3dc28f428f |
| 3c8a3fe2e1 |
| e28c246bcc |
| 04d03500ff |
| 54081bdcbb |
| d8b250607a |
| 1e58e6ef82 |
| 42cb7d96bb |
| 39890f023f |
| e425753f79 |
| ca40074d72 |
| 1fd3d67379 |
| 3acd9c73be |
| 32422b49ee |
| 5c4d3185fb |
| 762bcbee58 |
| 6b411ada16 |
| a25bd74d8b |
| fb5fc09bad |
| 3fdba19e02 |
| 4bec2983a9 |
| 03ea27893f |
| 718b45f2af |
| 63a79eeb2a |
| e757013a14 |
| a05f647633 |
| 7604be0301 |
| 945b43492e |
| b548d7caf2 |
| 6e316fd825 |
| 84fb61aaaf |
| 50a9946b57 |
| 384d1a8198 |
| a58c193d0c |
| 34a5ef8c15 |
| 41e3e4e157 |
| e576d71908 |
| 906aadbf1b |
| bf0bf2d5ba |
| fe0fff1399 |
| 50fceb84d2 |
| 100da41034 |
| c382237833 |
| 98ac191750 |
| 2f73dbe7a3 |
| 490d420d82 |
| 0aca943a39 |
| 0dbb3d333f |
2  .github/workflows/publish.yaml  (vendored)

@@ -20,7 +20,7 @@ jobs:
       with:
         python-version: '3.10'
     - name: Install wheel
-      run: pip install wheel && pip install -r requirements.txt
+      run: pip install wheel==0.44.0 && pip install -r requirements.txt
     - name: Build DiffSynth
       run: python setup.py sdist bdist_wheel
     - name: Publish package to PyPI
21  README.md

@@ -13,13 +13,19 @@ Document: https://diffsynth-studio.readthedocs.io/zh-cn/latest/index.html
 ## Introduction
 
-DiffSynth Studio is a Diffusion engine. We have restructured architectures including Text Encoder, UNet, VAE, among others, maintaining compatibility with models from the open-source community while enhancing computational performance. We provide many interesting features. Enjoy the magic of Diffusion models!
+Welcome to the magic world of Diffusion models!
 
-Until now, DiffSynth Studio has supported the following models:
+DiffSynth consists of two open-source projects:
+* [DiffSynth-Studio](https://github.com/modelscope/DiffSynth-Studio): Focused on aggressive technological exploration. Targeted at academia. Provides more cutting-edge technical support and novel inference capabilities.
+* [DiffSynth-Engine](https://github.com/modelscope/DiffSynth-Engine): Focused on stable model deployment. Geared towards industry. Offers better engineering support, higher computational performance, and more stable functionality.
+
+DiffSynth-Studio is an open-source project aimed at exploring innovations in AIGC technology. We have integrated numerous open-source Diffusion models, including FLUX and Wan, among others. Through this open-source project, we hope to connect models within the open-source community and explore new technologies based on diffusion models.
+
+Until now, DiffSynth-Studio has supported the following models:
 
 * [Wan-Video](https://github.com/Wan-Video/Wan2.1)
 * [StepVideo](https://github.com/stepfun-ai/Step-Video-T2V)
-* [HunyuanVideo](https://github.com/Tencent/HunyuanVideo)
+* [HunyuanVideo](https://github.com/Tencent/HunyuanVideo), [HunyuanVideo-I2V]()
 * [CogVideoX](https://huggingface.co/THUDM/CogVideoX-5b)
 * [FLUX](https://huggingface.co/black-forest-labs/FLUX.1-dev)
 * [ExVideo](https://huggingface.co/ECNU-CILab/ExVideo-SVD-128f-v1)

@@ -36,6 +42,11 @@ Until now, DiffSynth Studio has supported the following models:
 * [Stable Diffusion](https://huggingface.co/runwayml/stable-diffusion-v1-5)
 
 ## News
+- **March 31, 2025** We support InfiniteYou, an identity preserving method for FLUX. Please refer to [./examples/InfiniteYou/](./examples/InfiniteYou/) for more details.
+
+- **March 25, 2025** 🔥🔥🔥 Our new open-source project, [DiffSynth-Engine](https://github.com/modelscope/DiffSynth-Engine), is now open-sourced! Focused on stable model deployment. Geared towards industry. Offers better engineering support, higher computational performance, and more stable functionality.
+
+- **March 13, 2025** We support HunyuanVideo-I2V, the image-to-video generation version of HunyuanVideo open-sourced by Tencent. Please refer to [./examples/HunyuanVideo/](./examples/HunyuanVideo/) for more details.
 
 - **February 25, 2025** We support Wan-Video, a collection of SOTA video synthesis models open-sourced by Alibaba. See [./examples/wanvideo/](./examples/wanvideo/).
 

@@ -43,7 +54,7 @@ Until now, DiffSynth Studio has supported the following models:
 
 - **December 31, 2024** We propose EliGen, a novel framework for precise entity-level controlled text-to-image generation, complemented by an inpainting fusion pipeline to extend its capabilities to image inpainting tasks. EliGen seamlessly integrates with existing community models, such as IP-Adapter and In-Context LoRA, enhancing its versatility. For more details, see [./examples/EntityControl](./examples/EntityControl/).
   - Paper: [EliGen: Entity-Level Controlled Image Generation with Regional Attention](https://arxiv.org/abs/2501.01097)
-  - Model: [ModelScope](https://www.modelscope.cn/models/DiffSynth-Studio/Eligen)
+  - Model: [ModelScope](https://www.modelscope.cn/models/DiffSynth-Studio/Eligen), [HuggingFace](https://huggingface.co/modelscope/EliGen)
   - Online Demo: [ModelScope EliGen Studio](https://www.modelscope.cn/studios/DiffSynth-Studio/EliGen)
   - Training Dataset: [EliGen Train Set](https://www.modelscope.cn/datasets/DiffSynth-Studio/EliGenTrainSet)

@@ -72,7 +83,7 @@ Until now, DiffSynth Studio has supported the following models:
   - Enable CFG and highres-fix to improve visual quality. See [here](/examples/image_synthesis/README.md)
   - LoRA, ControlNet, and additional models will be available soon.
 
-- **June 21, 2024.** 🔥🔥🔥 We propose ExVideo, a post-tuning technique aimed at enhancing the capability of video generation models. We have extended Stable Video Diffusion to achieve the generation of long videos up to 128 frames.
+- **June 21, 2024.** We propose ExVideo, a post-tuning technique aimed at enhancing the capability of video generation models. We have extended Stable Video Diffusion to achieve the generation of long videos up to 128 frames.
   - [Project Page](https://ecnu-cilab.github.io/ExVideoProjectPage/)
   - Source code is released in this repo. See [`examples/ExVideo`](./examples/ExVideo/).
   - Models are released on [HuggingFace](https://huggingface.co/ECNU-CILab/ExVideo-SVD-128f-v1) and [ModelScope](https://modelscope.cn/models/ECNU-CILab/ExVideo-SVD-128f-v1).
@@ -37,6 +37,7 @@ from ..models.flux_text_encoder import FluxTextEncoder2
 from ..models.flux_vae import FluxVAEEncoder, FluxVAEDecoder
 from ..models.flux_controlnet import FluxControlNet
 from ..models.flux_ipadapter import FluxIpAdapter
+from ..models.flux_infiniteyou import InfiniteYouImageProjector
 
 from ..models.cog_vae import CogVAEEncoder, CogVAEDecoder
 from ..models.cog_dit import CogDiT

@@ -58,6 +59,7 @@ from ..models.wan_video_dit import WanModel
 from ..models.wan_video_text_encoder import WanTextEncoder
 from ..models.wan_video_image_encoder import WanImageEncoder
 from ..models.wan_video_vae import WanVideoVAE
+from ..models.wan_video_motion_controller import WanMotionControllerModel
 
 
 model_loader_configs = [

@@ -95,6 +97,7 @@ model_loader_configs = [
     (None, "57b02550baab820169365b3ee3afa2c9", ["flux_dit"], [FluxDiT], "civitai"),
     (None, "3394f306c4cbf04334b712bf5aaed95f", ["flux_dit"], [FluxDiT], "civitai"),
     (None, "023f054d918a84ccf503481fd1e3379e", ["flux_dit"], [FluxDiT], "civitai"),
+    (None, "605c56eab23e9e2af863ad8f0813a25d", ["flux_dit"], [FluxDiT], "diffusers"),
     (None, "280189ee084bca10f70907bf6ce1649d", ["cog_vae_encoder", "cog_vae_decoder"], [CogVAEEncoder, CogVAEDecoder], "diffusers"),
     (None, "9b9313d104ac4df27991352fec013fd4", ["rife"], [IFNet], "civitai"),
     (None, "6b7116078c4170bfbeaedc8fe71f6649", ["esrgan"], [RRDBNet], "civitai"),

@@ -103,6 +106,8 @@ model_loader_configs = [
     (None, "b001c89139b5f053c715fe772362dd2a", ["flux_controlnet"], [FluxControlNet], "diffusers"),
     (None, "52357cb26250681367488a8954c271e8", ["flux_controlnet"], [FluxControlNet], "diffusers"),
     (None, "0cfd1740758423a2a854d67c136d1e8c", ["flux_controlnet"], [FluxControlNet], "diffusers"),
+    (None, "7f9583eb8ba86642abb9a21a4b2c9e16", ["flux_controlnet"], [FluxControlNet], "diffusers"),
+    (None, "c07c0f04f5ff55e86b4e937c7a40d481", ["infiniteyou_image_projector"], [InfiniteYouImageProjector], "diffusers"),
     (None, "4daaa66cc656a8fe369908693dad0a35", ["flux_ipadapter"], [FluxIpAdapter], "diffusers"),
     (None, "51aed3d27d482fceb5e0739b03060e8f", ["sd3_dit", "sd3_vae_encoder", "sd3_vae_decoder"], [SD3DiT, SD3VAEEncoder, SD3VAEDecoder], "civitai"),
     (None, "98cc34ccc5b54ae0e56bdea8688dcd5a", ["sd3_text_encoder_2"], [SD3TextEncoder2], "civitai"),

@@ -116,10 +121,16 @@ model_loader_configs = [
     (None, "9269f8db9040a9d860eaca435be61814", ["wan_video_dit"], [WanModel], "civitai"),
     (None, "aafcfd9672c3a2456dc46e1cb6e52c70", ["wan_video_dit"], [WanModel], "civitai"),
     (None, "6bfcfb3b342cb286ce886889d519a77e", ["wan_video_dit"], [WanModel], "civitai"),
+    (None, "6d6ccde6845b95ad9114ab993d917893", ["wan_video_dit"], [WanModel], "civitai"),
+    (None, "6bfcfb3b342cb286ce886889d519a77e", ["wan_video_dit"], [WanModel], "civitai"),
+    (None, "349723183fc063b2bfc10bb2835cf677", ["wan_video_dit"], [WanModel], "civitai"),
+    (None, "efa44cddf936c70abd0ea28b6cbe946c", ["wan_video_dit"], [WanModel], "civitai"),
+    (None, "cb104773c6c2cb6df4f9529ad5c60d0b", ["wan_video_dit"], [WanModel], "diffusers"),
     (None, "9c8818c2cbea55eca56c7b447df170da", ["wan_video_text_encoder"], [WanTextEncoder], "civitai"),
     (None, "5941c53e207d62f20f9025686193c40b", ["wan_video_image_encoder"], [WanImageEncoder], "civitai"),
     (None, "1378ea763357eea97acdef78e65d6d96", ["wan_video_vae"], [WanVideoVAE], "civitai"),
     (None, "ccc42284ea13e1ad04693284c7a09be6", ["wan_video_vae"], [WanVideoVAE], "civitai"),
+    (None, "dbd5ec76bbf977983f972c151d545389", ["wan_video_motion_controller"], [WanMotionControllerModel], "civitai"),
 ]
 huggingface_model_loader_configs = [
     # These configs are provided for detecting model type automatically.

@@ -133,6 +144,7 @@ huggingface_model_loader_configs = [
     ("CogVideoXTransformer3DModel", "diffsynth.models.cog_dit", "cog_dit", "CogDiT"),
     ("SiglipModel", "transformers.models.siglip.modeling_siglip", "siglip_vision_model", "SiglipVisionModel"),
     ("LlamaForCausalLM", "diffsynth.models.hunyuan_video_text_encoder", "hunyuan_video_text_encoder_2", "HunyuanVideoLLMEncoder"),
+    ("LlavaForConditionalGeneration", "diffsynth.models.hunyuan_video_text_encoder", "hunyuan_video_text_encoder_2", "HunyuanVideoMLLMEncoder"),
     ("Step1Model", "diffsynth.models.stepvideo_text_encoder", "stepvideo_text_encoder_2", "STEP1TextEncoder"),
 ]
 patch_model_loader_configs = [

@@ -595,6 +607,25 @@ preset_models_on_modelscope = {
             "models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/image_encoder",
         ],
     },
+    "InfiniteYou":{
+        "file_list":[
+            ("ByteDance/InfiniteYou", "infu_flux_v1.0/aes_stage2/InfuseNetModel/diffusion_pytorch_model-00001-of-00002.safetensors", "models/InfiniteYou/InfuseNetModel"),
+            ("ByteDance/InfiniteYou", "infu_flux_v1.0/aes_stage2/InfuseNetModel/diffusion_pytorch_model-00002-of-00002.safetensors", "models/InfiniteYou/InfuseNetModel"),
+            ("ByteDance/InfiniteYou", "infu_flux_v1.0/aes_stage2/image_proj_model.bin", "models/InfiniteYou"),
+            ("ByteDance/InfiniteYou", "supports/insightface/models/antelopev2/1k3d68.onnx", "models/InfiniteYou/insightface/models/antelopev2"),
+            ("ByteDance/InfiniteYou", "supports/insightface/models/antelopev2/2d106det.onnx", "models/InfiniteYou/insightface/models/antelopev2"),
+            ("ByteDance/InfiniteYou", "supports/insightface/models/antelopev2/genderage.onnx", "models/InfiniteYou/insightface/models/antelopev2"),
+            ("ByteDance/InfiniteYou", "supports/insightface/models/antelopev2/glintr100.onnx", "models/InfiniteYou/insightface/models/antelopev2"),
+            ("ByteDance/InfiniteYou", "supports/insightface/models/antelopev2/scrfd_10g_bnkps.onnx", "models/InfiniteYou/insightface/models/antelopev2"),
+        ],
+        "load_path":[
+            [
+                "models/InfiniteYou/InfuseNetModel/diffusion_pytorch_model-00001-of-00002.safetensors",
+                "models/InfiniteYou/InfuseNetModel/diffusion_pytorch_model-00002-of-00002.safetensors"
+            ],
+            "models/InfiniteYou/image_proj_model.bin",
+        ],
+    },
     # ESRGAN
     "ESRGAN_x4": [
         ("AI-ModelScope/Real-ESRGAN", "RealESRGAN_x4.pth", "models/ESRGAN"),

@@ -675,6 +706,25 @@ preset_models_on_modelscope = {
             "models/HunyuanVideo/transformers/mp_rank_00_model_states.pt"
         ],
     },
+    "HunyuanVideoI2V":{
+        "file_list": [
+            ("AI-ModelScope/clip-vit-large-patch14", "model.safetensors", "models/HunyuanVideoI2V/text_encoder"),
+            ("AI-ModelScope/llava-llama-3-8b-v1_1-transformers", "model-00001-of-00004.safetensors", "models/HunyuanVideoI2V/text_encoder_2"),
+            ("AI-ModelScope/llava-llama-3-8b-v1_1-transformers", "model-00002-of-00004.safetensors", "models/HunyuanVideoI2V/text_encoder_2"),
+            ("AI-ModelScope/llava-llama-3-8b-v1_1-transformers", "model-00003-of-00004.safetensors", "models/HunyuanVideoI2V/text_encoder_2"),
+            ("AI-ModelScope/llava-llama-3-8b-v1_1-transformers", "model-00004-of-00004.safetensors", "models/HunyuanVideoI2V/text_encoder_2"),
+            ("AI-ModelScope/llava-llama-3-8b-v1_1-transformers", "config.json", "models/HunyuanVideoI2V/text_encoder_2"),
+            ("AI-ModelScope/llava-llama-3-8b-v1_1-transformers", "model.safetensors.index.json", "models/HunyuanVideoI2V/text_encoder_2"),
+            ("AI-ModelScope/HunyuanVideo-I2V", "hunyuan-video-i2v-720p/vae/pytorch_model.pt", "models/HunyuanVideoI2V/vae"),
+            ("AI-ModelScope/HunyuanVideo-I2V", "hunyuan-video-i2v-720p/transformers/mp_rank_00_model_states.pt", "models/HunyuanVideoI2V/transformers")
+        ],
+        "load_path": [
+            "models/HunyuanVideoI2V/text_encoder/model.safetensors",
+            "models/HunyuanVideoI2V/text_encoder_2",
+            "models/HunyuanVideoI2V/vae/pytorch_model.pt",
+            "models/HunyuanVideoI2V/transformers/mp_rank_00_model_states.pt"
+        ],
+    },
     "HunyuanVideo-fp8":{
         "file_list": [
             ("AI-ModelScope/clip-vit-large-patch14", "model.safetensors", "models/HunyuanVideo/text_encoder"),

@@ -735,6 +785,7 @@ Preset_model_id: TypeAlias = Literal[
     "Shakker-Labs/FLUX.1-dev-ControlNet-Depth",
     "Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro",
     "InstantX/FLUX.1-dev-IP-Adapter",
+    "InfiniteYou",
     "SDXL_lora_zyd232_ChineseInkStyle_SDXL_v1_0",
     "QwenPrompt",
     "OmostPrompt",

@@ -751,4 +802,5 @@ Preset_model_id: TypeAlias = Literal[
     "StableDiffusion3.5-medium",
     "HunyuanVideo",
     "HunyuanVideo-fp8",
+    "HunyuanVideoI2V",
 ]
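The `(None, "<hash>", [...], [...], "<format>")` entries above map a fingerprint of a checkpoint's layout to the names and classes that can load it, which is how model type detection can work without a config file. A minimal sketch of that lookup idea, with hypothetical names (`state_dict_fingerprint`, `register_loader`) and a simplified hashing scheme that is not DiffSynth-Studio's actual one:

```python
import hashlib

def state_dict_fingerprint(state_dict):
    # Hash the sorted parameter names, so the architecture (layout of keys),
    # not the weight values, determines the fingerprint.
    keys = ",".join(sorted(state_dict))
    return hashlib.md5(keys.encode("utf-8")).hexdigest()

# fingerprint -> (model_names, model_classes, source_format)
LOADER_CONFIGS = {}

def register_loader(fingerprint, names, classes, fmt):
    LOADER_CONFIGS[fingerprint] = (names, classes, fmt)

# Register a toy architecture, then detect it again from its state dict keys.
toy_state_dict = {"blocks.0.attn.q.weight": None, "blocks.0.attn.k.weight": None}
register_loader(state_dict_fingerprint(toy_state_dict), ["toy_dit"], ["ToyDiT"], "civitai")
names, classes, fmt = LOADER_CONFIGS[state_dict_fingerprint(toy_state_dict)]
```

Because only key names are hashed, two checkpoints of the same architecture with different weights resolve to the same loader entry.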
@@ -1,10 +1,4 @@
 from typing_extensions import Literal, TypeAlias
-import warnings
-with warnings.catch_warnings():
-    warnings.simplefilter("ignore")
-    from controlnet_aux.processor import (
-        CannyDetector, MidasDetector, HEDdetector, LineartDetector, LineartAnimeDetector, OpenposeDetector, NormalBaeDetector
-    )
 
 
 Processor_id: TypeAlias = Literal[

@@ -15,18 +9,25 @@ class Annotator:
     def __init__(self, processor_id: Processor_id, model_path="models/Annotators", detect_resolution=None, device='cuda', skip_processor=False):
         if not skip_processor:
             if processor_id == "canny":
+                from controlnet_aux.processor import CannyDetector
                 self.processor = CannyDetector()
             elif processor_id == "depth":
+                from controlnet_aux.processor import MidasDetector
                 self.processor = MidasDetector.from_pretrained(model_path).to(device)
             elif processor_id == "softedge":
+                from controlnet_aux.processor import HEDdetector
                 self.processor = HEDdetector.from_pretrained(model_path).to(device)
             elif processor_id == "lineart":
+                from controlnet_aux.processor import LineartDetector
                 self.processor = LineartDetector.from_pretrained(model_path).to(device)
             elif processor_id == "lineart_anime":
+                from controlnet_aux.processor import LineartAnimeDetector
                 self.processor = LineartAnimeDetector.from_pretrained(model_path).to(device)
             elif processor_id == "openpose":
+                from controlnet_aux.processor import OpenposeDetector
                 self.processor = OpenposeDetector.from_pretrained(model_path).to(device)
             elif processor_id == "normal":
+                from controlnet_aux.processor import NormalBaeDetector
                 self.processor = NormalBaeDetector.from_pretrained(model_path).to(device)
             elif processor_id == "tile" or processor_id == "none" or processor_id == "inpaint":
                 self.processor = None
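The change above moves each `controlnet_aux` import from module scope into the branch that constructs the corresponding detector, so importing the annotator module no longer pays for every backend up front. A minimal sketch of the same lazy-import pattern, using stdlib modules as stand-ins for the heavy dependencies (`make_processor` and the spec table are illustrative names, not DiffSynth-Studio APIs):

```python
import importlib

# processor_id -> (module, attribute); json/csv stand in for heavy libraries.
_PROCESSOR_SPECS = {
    "json_like": ("json", "dumps"),
    "csv_like": ("csv", "reader"),
}

def make_processor(processor_id):
    if processor_id in ("tile", "none", "inpaint"):
        return None  # these ids use no preprocessor, as in Annotator
    module_name, attr = _PROCESSOR_SPECS[processor_id]
    # The import happens here, on first use, not at module import time.
    module = importlib.import_module(module_name)
    return getattr(module, attr)

processor = make_processor("json_like")
```

The trade-off is the usual one: a missing optional dependency now fails at construction time for that one processor rather than at import time for the whole module.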
125  diffsynth/data/image_pulse.py  (new file)

@@ -0,0 +1,125 @@
import torch, os, json, torchvision
from PIL import Image
from torchvision.transforms import v2


class SingleTaskDataset(torch.utils.data.Dataset):
    def __init__(self, base_path, keys=(("image_1", "image_2", "editing_instruction"), ("image_2", "image_1", "reverse_editing_instruction")), height=1024, width=1024, random=True, steps_per_epoch=1000, metadata_path=None):
        self.base_path = base_path
        self.keys = keys
        self.metadata = []
        self.bad_data = []
        self.height = height
        self.width = width
        self.random = random
        self.steps_per_epoch = steps_per_epoch
        self.image_process = v2.Compose([
            v2.CenterCrop(size=(height, width)),
            v2.ToImage(),
            v2.ToDtype(torch.float32, scale=True),
            v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
        ])
        if metadata_path is None:
            self.search_for_data("", report_data_log=True)
            self.report_data_log()
        else:
            with open(metadata_path, "r", encoding="utf-8-sig") as f:
                self.metadata = json.load(f)

    def report_data_log(self):
        print(f"{len(self.metadata)} valid data, {len(self.bad_data)} invalid data.")

    def dump_metadata(self, path):
        with open(path, "w", encoding="utf-8") as f:
            json.dump(self.metadata, f, ensure_ascii=False, indent=4)

    def parse_json_file(self, absolute_path, relative_path):
        data_list = []
        with open(absolute_path, "r") as f:
            metadata = json.load(f)
        for image_1, image_2, instruction in self.keys:
            image_1 = os.path.join(relative_path, metadata[image_1])
            image_2 = os.path.join(relative_path, metadata[image_2])
            instruction = metadata[instruction]
            data_list.append((image_1, image_2, instruction))
        return data_list

    def search_for_data(self, path, report_data_log=False):
        now_path = os.path.join(self.base_path, path)
        if os.path.isfile(now_path) and path.endswith(".json"):
            try:
                data_list = self.parse_json_file(now_path, os.path.dirname(path))
                self.metadata.extend(data_list)
            except:
                self.bad_data.append(now_path)
        elif os.path.isdir(now_path):
            for sub_path in os.listdir(now_path):
                self.search_for_data(os.path.join(path, sub_path))
                if report_data_log and os.path.isdir(os.path.join(self.base_path, path, sub_path)):
                    self.report_data_log()

    def load_image(self, image_path):
        image_path = os.path.join(self.base_path, image_path)
        image = Image.open(image_path).convert("RGB")
        width, height = image.size
        scale = max(self.width / width, self.height / height)
        image = torchvision.transforms.functional.resize(
            image,
            (round(height*scale), round(width*scale)),
            interpolation=torchvision.transforms.InterpolationMode.BILINEAR
        )
        image = self.image_process(image)
        return image

    def load_data(self, data_id):
        image_1, image_2, instruction = self.metadata[data_id]
        image_1 = self.load_image(image_1)
        image_2 = self.load_image(image_2)
        return {"image_1": image_1, "image_2": image_2, "instruction": instruction}

    def __getitem__(self, data_id):
        if self.random:
            while True:
                try:
                    data_id = (torch.randint(0, len(self.metadata), (1,))[0] + data_id) % len(self.metadata)
                    data = self.load_data(data_id)
                    return data
                except:
                    continue
        else:
            return self.load_data(data_id)

    def __len__(self):
        return self.steps_per_epoch if self.random else len(self.metadata)


class MultiTaskDataset(torch.utils.data.Dataset):
    def __init__(self, dataset_list, dataset_weight, steps_per_epoch=1000):
        self.dataset_list = dataset_list
        self.dataset_weight = torch.tensor(dataset_weight, dtype=torch.float)
        self.steps_per_epoch = steps_per_epoch

    def __getitem__(self, data_id):
        while True:
            try:
                dataset_id = torch.multinomial(self.dataset_weight, 1).tolist()[0]
                data_id = torch.randint(0, len(self.dataset_list[dataset_id]), (1,))[0]
                data = self.dataset_list[dataset_id][data_id]
                return data
            except:
                continue

    def __len__(self):
        return self.steps_per_epoch
0  diffsynth/distributed/__init__.py  (new file)

129  diffsynth/distributed/xdit_context_parallel.py  (new file)

@@ -0,0 +1,129 @@
import torch
from typing import Optional
from einops import rearrange
from xfuser.core.distributed import (get_sequence_parallel_rank,
                                     get_sequence_parallel_world_size,
                                     get_sp_group)
from xfuser.core.long_ctx_attention import xFuserLongContextAttention


def sinusoidal_embedding_1d(dim, position):
    sinusoid = torch.outer(position.type(torch.float64), torch.pow(
        10000, -torch.arange(dim//2, dtype=torch.float64, device=position.device).div(dim//2)))
    x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1)
    return x.to(position.dtype)


def pad_freqs(original_tensor, target_len):
    seq_len, s1, s2 = original_tensor.shape
    pad_size = target_len - seq_len
    padding_tensor = torch.ones(
        pad_size,
        s1,
        s2,
        dtype=original_tensor.dtype,
        device=original_tensor.device)
    padded_tensor = torch.cat([original_tensor, padding_tensor], dim=0)
    return padded_tensor


def rope_apply(x, freqs, num_heads):
    x = rearrange(x, "b s (n d) -> b s n d", n=num_heads)
    s_per_rank = x.shape[1]

    x_out = torch.view_as_complex(x.to(torch.float64).reshape(
        x.shape[0], x.shape[1], x.shape[2], -1, 2))

    sp_size = get_sequence_parallel_world_size()
    sp_rank = get_sequence_parallel_rank()
    freqs = pad_freqs(freqs, s_per_rank * sp_size)
    freqs_rank = freqs[(sp_rank * s_per_rank):((sp_rank + 1) * s_per_rank), :, :]

    x_out = torch.view_as_real(x_out * freqs_rank).flatten(2)
    return x_out.to(x.dtype)


def usp_dit_forward(self,
                    x: torch.Tensor,
                    timestep: torch.Tensor,
                    context: torch.Tensor,
                    clip_feature: Optional[torch.Tensor] = None,
                    y: Optional[torch.Tensor] = None,
                    use_gradient_checkpointing: bool = False,
                    use_gradient_checkpointing_offload: bool = False,
                    **kwargs,
                    ):
    t = self.time_embedding(
        sinusoidal_embedding_1d(self.freq_dim, timestep))
    t_mod = self.time_projection(t).unflatten(1, (6, self.dim))
    context = self.text_embedding(context)

    if self.has_image_input:
        x = torch.cat([x, y], dim=1)  # (b, c_x + c_y, f, h, w)
        clip_embdding = self.img_emb(clip_feature)
        context = torch.cat([clip_embdding, context], dim=1)

    x, (f, h, w) = self.patchify(x)

    freqs = torch.cat([
        self.freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
        self.freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
|
||||||
|
self.freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
|
||||||
|
], dim=-1).reshape(f * h * w, 1, -1).to(x.device)
|
||||||
|
|
||||||
|
def create_custom_forward(module):
|
||||||
|
def custom_forward(*inputs):
|
||||||
|
return module(*inputs)
|
||||||
|
return custom_forward
|
||||||
|
|
||||||
|
# Context Parallel
|
||||||
|
x = torch.chunk(
|
||||||
|
x, get_sequence_parallel_world_size(),
|
||||||
|
dim=1)[get_sequence_parallel_rank()]
|
||||||
|
|
||||||
|
for block in self.blocks:
|
||||||
|
if self.training and use_gradient_checkpointing:
|
||||||
|
if use_gradient_checkpointing_offload:
|
||||||
|
with torch.autograd.graph.save_on_cpu():
|
||||||
|
x = torch.utils.checkpoint.checkpoint(
|
||||||
|
create_custom_forward(block),
|
||||||
|
x, context, t_mod, freqs,
|
||||||
|
use_reentrant=False,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
x = torch.utils.checkpoint.checkpoint(
|
||||||
|
create_custom_forward(block),
|
||||||
|
x, context, t_mod, freqs,
|
||||||
|
use_reentrant=False,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
x = block(x, context, t_mod, freqs)
|
||||||
|
|
||||||
|
x = self.head(x, t)
|
||||||
|
|
||||||
|
# Context Parallel
|
||||||
|
x = get_sp_group().all_gather(x, dim=1)
|
||||||
|
|
||||||
|
# unpatchify
|
||||||
|
x = self.unpatchify(x, (f, h, w))
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
def usp_attn_forward(self, x, freqs):
|
||||||
|
q = self.norm_q(self.q(x))
|
||||||
|
k = self.norm_k(self.k(x))
|
||||||
|
v = self.v(x)
|
||||||
|
|
||||||
|
q = rope_apply(q, freqs, self.num_heads)
|
||||||
|
k = rope_apply(k, freqs, self.num_heads)
|
||||||
|
q = rearrange(q, "b s (n d) -> b s n d", n=self.num_heads)
|
||||||
|
k = rearrange(k, "b s (n d) -> b s n d", n=self.num_heads)
|
||||||
|
v = rearrange(v, "b s (n d) -> b s n d", n=self.num_heads)
|
||||||
|
|
||||||
|
x = xFuserLongContextAttention()(
|
||||||
|
None,
|
||||||
|
query=q,
|
||||||
|
key=k,
|
||||||
|
value=v,
|
||||||
|
)
|
||||||
|
x = x.flatten(2)
|
||||||
|
|
||||||
|
del q, k, v
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
return self.o(x)
|
||||||
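The sequence-parallel RoPE slicing in `rope_apply` above pads the frequency table to `s_per_rank * sp_size` and then hands each rank its contiguous slice. A minimal list-based sketch of that indexing, with illustrative rank/world-size values (not the repository's API):

```python
def pad_freqs_list(freqs, target_len, pad_value=1.0):
    # Mirrors pad_freqs: extend the table with identity rotations (ones)
    # so its length is divisible by the sequence-parallel world size.
    return freqs + [pad_value] * (target_len - len(freqs))

def freqs_for_rank(freqs, s_per_rank, sp_size, sp_rank):
    # Mirrors the slicing in rope_apply: rank r owns positions
    # [r * s_per_rank, (r + 1) * s_per_rank) of the padded table.
    padded = pad_freqs_list(freqs, s_per_rank * sp_size)
    return padded[sp_rank * s_per_rank:(sp_rank + 1) * s_per_rank]

# 10 positions split across 4 ranks with 3 tokens each (2 padding slots).
freqs = list(range(10))
shards = [freqs_for_rank(freqs, 3, 4, r) for r in range(4)]
```

The padding value of one is harmless because multiplying by it leaves the (complex) activations of the padded positions unchanged.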
@@ -5,7 +5,7 @@ import pathlib
 import re
 from copy import deepcopy
 from pathlib import Path
-from turtle import forward
+# from turtle import forward
 from typing import Any, Dict, Optional, Tuple, Union

 import torch
@@ -318,6 +318,8 @@ class FluxControlNetStateDictConverter:
             extra_kwargs = {"num_joint_blocks": 6, "num_single_blocks": 0, "additional_input_dim": 4}
         elif hash_value == "0cfd1740758423a2a854d67c136d1e8c":
             extra_kwargs = {"num_joint_blocks": 4, "num_single_blocks": 1}
+        elif hash_value == "7f9583eb8ba86642abb9a21a4b2c9e16":
+            extra_kwargs = {"num_joint_blocks": 4, "num_single_blocks": 10}
         else:
             extra_kwargs = {}
         return state_dict_, extra_kwargs
@@ -20,10 +20,11 @@ class RoPEEmbedding(torch.nn.Module):
         self.axes_dim = axes_dim


-    def rope(self, pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor:
+    def rope(self, pos: torch.Tensor, dim: int, theta: int, device="cpu") -> torch.Tensor:
         assert dim % 2 == 0, "The dimension must be even."

         scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
+        scale = scale.to(device)
         omega = 1.0 / (theta**scale)

         batch_size, seq_length = pos.shape
@@ -36,9 +37,9 @@ class RoPEEmbedding(torch.nn.Module):
         return out.float()


-    def forward(self, ids):
+    def forward(self, ids, device="cpu"):
         n_axes = ids.shape[-1]
-        emb = torch.cat([self.rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)], dim=-3)
+        emb = torch.cat([self.rope(ids[..., i], self.axes_dim[i], self.theta, device) for i in range(n_axes)], dim=-3)
         return emb.unsqueeze(1)

@@ -628,19 +629,22 @@ class FluxDiTStateDictConverter:
         else:
             pass
         for name in list(state_dict_.keys()):
-            if ".proj_in_besides_attn." in name:
-                name_ = name.replace(".proj_in_besides_attn.", ".to_qkv_mlp.")
+            if "single_blocks." in name and ".a_to_q." in name:
+                mlp = state_dict_.get(name.replace(".a_to_q.", ".proj_in_besides_attn."), None)
+                if mlp is None:
+                    mlp = torch.zeros(4 * state_dict_[name].shape[0],
+                                      *state_dict_[name].shape[1:],
+                                      dtype=state_dict_[name].dtype)
+                else:
+                    state_dict_.pop(name.replace(".a_to_q.", ".proj_in_besides_attn."))
                 param = torch.concat([
-                    state_dict_[name.replace(".proj_in_besides_attn.", f".a_to_q.")],
-                    state_dict_[name.replace(".proj_in_besides_attn.", f".a_to_k.")],
-                    state_dict_[name.replace(".proj_in_besides_attn.", f".a_to_v.")],
-                    state_dict_[name],
+                    state_dict_.pop(name),
+                    state_dict_.pop(name.replace(".a_to_q.", ".a_to_k.")),
+                    state_dict_.pop(name.replace(".a_to_q.", ".a_to_v.")),
+                    mlp,
                 ], dim=0)
+                name_ = name.replace(".a_to_q.", ".to_qkv_mlp.")
                 state_dict_[name_] = param
-                state_dict_.pop(name.replace(".proj_in_besides_attn.", f".a_to_q."))
-                state_dict_.pop(name.replace(".proj_in_besides_attn.", f".a_to_k."))
-                state_dict_.pop(name.replace(".proj_in_besides_attn.", f".a_to_v."))
-                state_dict_.pop(name)
         for name in list(state_dict_.keys()):
             for component in ["a", "b"]:
                 if f".{component}_to_q." in name:
diffsynth/models/flux_infiniteyou.py (new file, 128 lines)
@@ -0,0 +1,128 @@
import math
import torch
import torch.nn as nn


# FFN
def FeedForward(dim, mult=4):
    inner_dim = int(dim * mult)
    return nn.Sequential(
        nn.LayerNorm(dim),
        nn.Linear(dim, inner_dim, bias=False),
        nn.GELU(),
        nn.Linear(inner_dim, dim, bias=False),
    )


def reshape_tensor(x, heads):
    bs, length, width = x.shape
    # (bs, length, width) --> (bs, length, n_heads, dim_per_head)
    x = x.view(bs, length, heads, -1)
    # (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head)
    x = x.transpose(1, 2)
    # (bs, n_heads, length, dim_per_head) --> (bs*n_heads, length, dim_per_head)
    x = x.reshape(bs, heads, length, -1)
    return x


class PerceiverAttention(nn.Module):

    def __init__(self, *, dim, dim_head=64, heads=8):
        super().__init__()
        self.scale = dim_head**-0.5
        self.dim_head = dim_head
        self.heads = heads
        inner_dim = dim_head * heads

        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)

        self.to_q = nn.Linear(dim, inner_dim, bias=False)
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
        self.to_out = nn.Linear(inner_dim, dim, bias=False)

    def forward(self, x, latents):
        """
        Args:
            x (torch.Tensor): image features
                shape (b, n1, D)
            latent (torch.Tensor): latent features
                shape (b, n2, D)
        """
        x = self.norm1(x)
        latents = self.norm2(latents)

        b, l, _ = latents.shape

        q = self.to_q(latents)
        kv_input = torch.cat((x, latents), dim=-2)
        k, v = self.to_kv(kv_input).chunk(2, dim=-1)

        q = reshape_tensor(q, self.heads)
        k = reshape_tensor(k, self.heads)
        v = reshape_tensor(v, self.heads)

        # attention
        scale = 1 / math.sqrt(math.sqrt(self.dim_head))
        weight = (q * scale) @ (k * scale).transpose(-2, -1)  # More stable with f16 than dividing afterwards
        weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
        out = weight @ v

        out = out.permute(0, 2, 1, 3).reshape(b, l, -1)

        return self.to_out(out)


class InfiniteYouImageProjector(nn.Module):

    def __init__(
        self,
        dim=1280,
        depth=4,
        dim_head=64,
        heads=20,
        num_queries=8,
        embedding_dim=512,
        output_dim=4096,
        ff_mult=4,
    ):
        super().__init__()
        self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim**0.5)
        self.proj_in = nn.Linear(embedding_dim, dim)

        self.proj_out = nn.Linear(dim, output_dim)
        self.norm_out = nn.LayerNorm(output_dim)

        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(
                nn.ModuleList([
                    PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
                    FeedForward(dim=dim, mult=ff_mult),
                ]))

    def forward(self, x):
        latents = self.latents.repeat(x.size(0), 1, 1)

        x = self.proj_in(x)

        for attn, ff in self.layers:
            latents = attn(x, latents) + latents
            latents = ff(latents) + latents

        latents = self.proj_out(latents)
        return self.norm_out(latents)

    @staticmethod
    def state_dict_converter():
        return FluxInfiniteYouImageProjectorStateDictConverter()


class FluxInfiniteYouImageProjectorStateDictConverter:

    def __init__(self):
        pass

    def from_diffusers(self, state_dict):
        return state_dict['image_proj']
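One detail worth noting in `PerceiverAttention.forward` above is the `1 / sqrt(sqrt(dim_head))` factor applied to both `q` and `k`: the product of the two factors reproduces the standard `1 / sqrt(dim_head)` attention scaling while keeping intermediate magnitudes smaller, which the code's own comment notes is friendlier to float16. A quick scalar check of that identity:

```python
import math

dim_head = 64
scale = 1 / math.sqrt(math.sqrt(dim_head))
# Applying `scale` to q and to k separately multiplies each q·k logit by
# scale * scale == 1 / sqrt(dim_head), the standard attention scaling,
# without ever forming the large unscaled product.
q_val, k_val = 3.0, 5.0
logit = (q_val * scale) * (k_val * scale)
standard = q_val * k_val / math.sqrt(dim_head)
```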
diffsynth/models/flux_reference_embedder.py (new file, 31 lines)
@@ -0,0 +1,31 @@
from .sd3_dit import TimestepEmbeddings
from .flux_dit import RoPEEmbedding
import torch
from einops import repeat


class FluxReferenceEmbedder(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.pos_embedder = RoPEEmbedding(3072, 10000, [16, 56, 56])
        self.idx_embedder = TimestepEmbeddings(256, 256)
        self.proj = torch.nn.Linear(3072, 3072)

    def forward(self, image_ids, idx, dtype, device):
        pos_emb = self.pos_embedder(image_ids, device=device)
        idx_emb = self.idx_embedder(idx, dtype=dtype).to(device)
        length = pos_emb.shape[2]
        pos_emb = repeat(pos_emb, "B N L C H W -> 1 N (B L) C H W")
        idx_emb = repeat(idx_emb, "B (C H W) -> 1 1 (B L) C H W", C=64, H=2, W=2, L=length)
        image_rotary_emb = pos_emb + idx_emb
        return image_rotary_emb

    def init(self):
        self.idx_embedder.timestep_embedder[-1].load_state_dict({
            "weight": torch.zeros((256, 256)),
            "bias": torch.zeros((256,))
        })
        self.proj.load_state_dict({
            "weight": torch.eye(3072),
            "bias": torch.zeros((3072,))
        })
@@ -4,6 +4,7 @@ from .utils import init_weights_on_device
 from einops import rearrange, repeat
 from tqdm import tqdm
 from typing import Union, Tuple, List
+from .utils import hash_state_dict_keys


 def HunyuanVideoRope(latents):
@@ -236,7 +237,7 @@ class IndividualTokenRefinerBlock(torch.nn.Module):
         x = x + self.mlp(self.norm2(x)) * gate_mlp.unsqueeze(1)

         return x


 class SingleTokenRefiner(torch.nn.Module):
     def __init__(self, in_channels=4096, hidden_size=3072, depth=2):
@@ -269,7 +270,7 @@ class SingleTokenRefiner(torch.nn.Module):
         x = block(x, c, mask)

         return x


 class ModulateDiT(torch.nn.Module):
     def __init__(self, hidden_size, factor=6):
@@ -279,9 +280,14 @@ class ModulateDiT(torch.nn.Module):

     def forward(self, x):
         return self.linear(self.act(x))


-def modulate(x, shift=None, scale=None):
+def modulate(x, shift=None, scale=None, tr_shift=None, tr_scale=None, tr_token=None):
+    if tr_shift is not None:
+        x_zero = x[:, :tr_token] * (1 + tr_scale.unsqueeze(1)) + tr_shift.unsqueeze(1)
+        x_orig = x[:, tr_token:] * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
+        x = torch.concat((x_zero, x_orig), dim=1)
+        return x
     if scale is None and shift is None:
         return x
     elif shift is None:
@@ -290,7 +296,7 @@ def modulate(x, shift=None, scale=None):
         return x + shift.unsqueeze(1)
     else:
         return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)


 def reshape_for_broadcast(
     freqs_cis,
@@ -343,7 +349,7 @@ def rotate_half(x):
         x.float().reshape(*x.shape[:-1], -1, 2).unbind(-1)
     )  # [B, S, H, D//2]
     return torch.stack([-x_imag, x_real], dim=-1).flatten(3)


 def apply_rotary_emb(
     xq: torch.Tensor,
@@ -385,6 +391,15 @@ def attention(q, k, v):
     return x


+def apply_gate(x, gate, tr_gate=None, tr_token=None):
+    if tr_gate is not None:
+        x_zero = x[:, :tr_token] * tr_gate.unsqueeze(1)
+        x_orig = x[:, tr_token:] * gate.unsqueeze(1)
+        return torch.concat((x_zero, x_orig), dim=1)
+    else:
+        return x * gate.unsqueeze(1)
+
+
 class MMDoubleStreamBlockComponent(torch.nn.Module):
     def __init__(self, hidden_size=3072, heads_num=24, mlp_width_ratio=4):
         super().__init__()
@@ -405,11 +420,17 @@ class MMDoubleStreamBlockComponent(torch.nn.Module):
             torch.nn.Linear(hidden_size * mlp_width_ratio, hidden_size)
         )

-    def forward(self, hidden_states, conditioning, freqs_cis=None):
+    def forward(self, hidden_states, conditioning, freqs_cis=None, token_replace_vec=None, tr_token=None):
         mod1_shift, mod1_scale, mod1_gate, mod2_shift, mod2_scale, mod2_gate = self.mod(conditioning).chunk(6, dim=-1)
+        if token_replace_vec is not None:
+            assert tr_token is not None
+            tr_mod1_shift, tr_mod1_scale, tr_mod1_gate, tr_mod2_shift, tr_mod2_scale, tr_mod2_gate = self.mod(token_replace_vec).chunk(6, dim=-1)
+        else:
+            tr_mod1_shift, tr_mod1_scale, tr_mod1_gate, tr_mod2_shift, tr_mod2_scale, tr_mod2_gate = None, None, None, None, None, None

         norm_hidden_states = self.norm1(hidden_states)
-        norm_hidden_states = modulate(norm_hidden_states, shift=mod1_shift, scale=mod1_scale)
+        norm_hidden_states = modulate(norm_hidden_states, shift=mod1_shift, scale=mod1_scale,
+                                      tr_shift=tr_mod1_shift, tr_scale=tr_mod1_scale, tr_token=tr_token)
         qkv = self.to_qkv(norm_hidden_states)
         q, k, v = rearrange(qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num)

@@ -418,15 +439,19 @@ class MMDoubleStreamBlockComponent(torch.nn.Module):

         if freqs_cis is not None:
             q, k = apply_rotary_emb(q, k, freqs_cis, head_first=False)
-        return (q, k, v), (mod1_gate, mod2_shift, mod2_scale, mod2_gate)
+        return (q, k, v), (mod1_gate, mod2_shift, mod2_scale, mod2_gate), (tr_mod1_gate, tr_mod2_shift, tr_mod2_scale, tr_mod2_gate)

-    def process_ff(self, hidden_states, attn_output, mod):
+    def process_ff(self, hidden_states, attn_output, mod, mod_tr=None, tr_token=None):
         mod1_gate, mod2_shift, mod2_scale, mod2_gate = mod
-        hidden_states = hidden_states + self.to_out(attn_output) * mod1_gate.unsqueeze(1)
-        hidden_states = hidden_states + self.ff(modulate(self.norm2(hidden_states), shift=mod2_shift, scale=mod2_scale)) * mod2_gate.unsqueeze(1)
+        if mod_tr is not None:
+            tr_mod1_gate, tr_mod2_shift, tr_mod2_scale, tr_mod2_gate = mod_tr
+        else:
+            tr_mod1_gate, tr_mod2_shift, tr_mod2_scale, tr_mod2_gate = None, None, None, None
+        hidden_states = hidden_states + apply_gate(self.to_out(attn_output), mod1_gate, tr_mod1_gate, tr_token)
+        x = self.ff(modulate(self.norm2(hidden_states), shift=mod2_shift, scale=mod2_scale, tr_shift=tr_mod2_shift, tr_scale=tr_mod2_scale, tr_token=tr_token))
+        hidden_states = hidden_states + apply_gate(x, mod2_gate, tr_mod2_gate, tr_token)
         return hidden_states


 class MMDoubleStreamBlock(torch.nn.Module):
     def __init__(self, hidden_size=3072, heads_num=24, mlp_width_ratio=4):
@@ -434,18 +459,18 @@ class MMDoubleStreamBlock(torch.nn.Module):
         self.component_a = MMDoubleStreamBlockComponent(hidden_size, heads_num, mlp_width_ratio)
         self.component_b = MMDoubleStreamBlockComponent(hidden_size, heads_num, mlp_width_ratio)

-    def forward(self, hidden_states_a, hidden_states_b, conditioning, freqs_cis):
-        (q_a, k_a, v_a), mod_a = self.component_a(hidden_states_a, conditioning, freqs_cis)
-        (q_b, k_b, v_b), mod_b = self.component_b(hidden_states_b, conditioning, freqs_cis=None)
+    def forward(self, hidden_states_a, hidden_states_b, conditioning, freqs_cis, token_replace_vec=None, tr_token=None, split_token=71):
+        (q_a, k_a, v_a), mod_a, mod_tr = self.component_a(hidden_states_a, conditioning, freqs_cis, token_replace_vec, tr_token)
+        (q_b, k_b, v_b), mod_b, _ = self.component_b(hidden_states_b, conditioning, freqs_cis=None)

-        q_a, q_b = torch.concat([q_a, q_b[:, :71]], dim=1), q_b[:, 71:].contiguous()
-        k_a, k_b = torch.concat([k_a, k_b[:, :71]], dim=1), k_b[:, 71:].contiguous()
-        v_a, v_b = torch.concat([v_a, v_b[:, :71]], dim=1), v_b[:, 71:].contiguous()
+        q_a, q_b = torch.concat([q_a, q_b[:, :split_token]], dim=1), q_b[:, split_token:].contiguous()
+        k_a, k_b = torch.concat([k_a, k_b[:, :split_token]], dim=1), k_b[:, split_token:].contiguous()
+        v_a, v_b = torch.concat([v_a, v_b[:, :split_token]], dim=1), v_b[:, split_token:].contiguous()
         attn_output_a = attention(q_a, k_a, v_a)
         attn_output_b = attention(q_b, k_b, v_b)
-        attn_output_a, attn_output_b = attn_output_a[:, :-71].contiguous(), torch.concat([attn_output_a[:, -71:], attn_output_b], dim=1)
+        attn_output_a, attn_output_b = attn_output_a[:, :-split_token].contiguous(), torch.concat([attn_output_a[:, -split_token:], attn_output_b], dim=1)

-        hidden_states_a = self.component_a.process_ff(hidden_states_a, attn_output_a, mod_a)
+        hidden_states_a = self.component_a.process_ff(hidden_states_a, attn_output_a, mod_a, mod_tr, tr_token)
         hidden_states_b = self.component_b.process_ff(hidden_states_b, attn_output_b, mod_b)
         return hidden_states_a, hidden_states_b

@@ -488,7 +513,7 @@ class MMSingleStreamBlockOriginal(torch.nn.Module):

         output = self.linear2(torch.cat((attn_output, self.mlp_act(mlp)), 2))
         return x + output * mod_gate.unsqueeze(1)


 class MMSingleStreamBlock(torch.nn.Module):
     def __init__(self, hidden_size=3072, heads_num=24, mlp_width_ratio=4):
@@ -509,11 +534,17 @@ class MMSingleStreamBlock(torch.nn.Module):
             torch.nn.Linear(hidden_size * mlp_width_ratio, hidden_size, bias=False)
         )

-    def forward(self, hidden_states, conditioning, freqs_cis=None, txt_len=256):
+    def forward(self, hidden_states, conditioning, freqs_cis=None, txt_len=256, token_replace_vec=None, tr_token=None, split_token=71):
         mod_shift, mod_scale, mod_gate = self.mod(conditioning).chunk(3, dim=-1)
+        if token_replace_vec is not None:
+            assert tr_token is not None
+            tr_mod_shift, tr_mod_scale, tr_mod_gate = self.mod(token_replace_vec).chunk(3, dim=-1)
+        else:
+            tr_mod_shift, tr_mod_scale, tr_mod_gate = None, None, None

         norm_hidden_states = self.norm(hidden_states)
-        norm_hidden_states = modulate(norm_hidden_states, shift=mod_shift, scale=mod_scale)
+        norm_hidden_states = modulate(norm_hidden_states, shift=mod_shift, scale=mod_scale,
+                                      tr_shift=tr_mod_shift, tr_scale=tr_mod_scale, tr_token=tr_token)
         qkv = self.to_qkv(norm_hidden_states)

         q, k, v = rearrange(qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num)
@@ -525,16 +556,17 @@ class MMSingleStreamBlock(torch.nn.Module):
         k_a, k_b = k[:, :-txt_len, :, :], k[:, -txt_len:, :, :]
         q_a, k_a = apply_rotary_emb(q_a, k_a, freqs_cis, head_first=False)

-        q_a, q_b = torch.concat([q_a, q_b[:, :71]], dim=1), q_b[:, 71:].contiguous()
-        k_a, k_b = torch.concat([k_a, k_b[:, :71]], dim=1), k_b[:, 71:].contiguous()
-        v_a, v_b = v[:, :-185].contiguous(), v[:, -185:].contiguous()
+        v_len = txt_len - split_token
+        q_a, q_b = torch.concat([q_a, q_b[:, :split_token]], dim=1), q_b[:, split_token:].contiguous()
+        k_a, k_b = torch.concat([k_a, k_b[:, :split_token]], dim=1), k_b[:, split_token:].contiguous()
+        v_a, v_b = v[:, :-v_len].contiguous(), v[:, -v_len:].contiguous()

         attn_output_a = attention(q_a, k_a, v_a)
         attn_output_b = attention(q_b, k_b, v_b)
         attn_output = torch.concat([attn_output_a, attn_output_b], dim=1)

-        hidden_states = hidden_states + self.to_out(attn_output) * mod_gate.unsqueeze(1)
-        hidden_states = hidden_states + self.ff(norm_hidden_states) * mod_gate.unsqueeze(1)
+        hidden_states = hidden_states + apply_gate(self.to_out(attn_output), mod_gate, tr_mod_gate, tr_token)
+        hidden_states = hidden_states + apply_gate(self.ff(norm_hidden_states), mod_gate, tr_mod_gate, tr_token)
         return hidden_states


@@ -555,7 +587,7 @@ class FinalLayer(torch.nn.Module):


 class HunyuanVideoDiT(torch.nn.Module):
-    def __init__(self, in_channels=16, hidden_size=3072, text_dim=4096, num_double_blocks=20, num_single_blocks=40):
+    def __init__(self, in_channels=16, hidden_size=3072, text_dim=4096, num_double_blocks=20, num_single_blocks=40, guidance_embed=True):
         super().__init__()
         self.img_in = PatchEmbed(in_channels=in_channels, embed_dim=hidden_size)
         self.txt_in = SingleTokenRefiner(in_channels=text_dim, hidden_size=hidden_size)
@@ -565,7 +597,7 @@ class HunyuanVideoDiT(torch.nn.Module):
             torch.nn.SiLU(),
             torch.nn.Linear(hidden_size, hidden_size)
         )
-        self.guidance_in = TimestepEmbeddings(256, hidden_size, computation_device="cpu")
+        self.guidance_in = TimestepEmbeddings(256, hidden_size, computation_device="cpu") if guidance_embed else None
         self.double_blocks = torch.nn.ModuleList([MMDoubleStreamBlock(hidden_size) for _ in range(num_double_blocks)])
         self.single_blocks = torch.nn.ModuleList([MMSingleStreamBlock(hidden_size) for _ in range(num_single_blocks)])
         self.final_layer = FinalLayer(hidden_size)
@@ -580,7 +612,7 @@ class HunyuanVideoDiT(torch.nn.Module):
     def unpatchify(self, x, T, H, W):
         x = rearrange(x, "B (T H W) (C pT pH pW) -> B C (T pT) (H pH) (W pW)", H=H, W=W, pT=1, pH=2, pW=2)
         return x

     def enable_block_wise_offload(self, warm_device="cuda", cold_device="cpu"):
         self.warm_device = warm_device
         self.cold_device = cold_device
@@ -610,10 +642,12 @@ class HunyuanVideoDiT(torch.nn.Module):
     ):
         B, C, T, H, W = x.shape

-        vec = self.time_in(t, dtype=torch.float32) + self.vector_in(pooled_prompt_emb) + self.guidance_in(guidance * 1000, dtype=torch.float32)
+        vec = self.time_in(t, dtype=torch.float32) + self.vector_in(pooled_prompt_emb)
+        if self.guidance_in is not None:
+            vec += self.guidance_in(guidance * 1000, dtype=torch.float32)
         img = self.img_in(x)
         txt = self.txt_in(prompt_emb, t, text_mask)

         for block in tqdm(self.double_blocks, desc="Double stream blocks"):
             img, txt = block(img, txt, vec, (freqs_cos, freqs_sin))

@@ -625,7 +659,7 @@ class HunyuanVideoDiT(torch.nn.Module):
         img = self.final_layer(img, vec)
         img = self.unpatchify(img, T=T//1, H=H//2, W=W//2)
         return img


     def enable_auto_offload(self, dtype=torch.bfloat16, device="cuda"):
         def cast_to(weight, dtype=None, device=None, copy=False):
@@ -681,7 +715,7 @@ class HunyuanVideoDiT(torch.nn.Module):
             del x_, weight_, bias_
             torch.cuda.empty_cache()
             return y_

         def block_forward(self, x, **kwargs):
             # This feature can only reduce 2GB VRAM, so we disable it.
             y = torch.zeros(x.shape[:-1] + (self.out_features,), dtype=x.dtype, device=x.device)
@@ -689,19 +723,19 @@ class HunyuanVideoDiT(torch.nn.Module):
             for j in range((self.out_features + self.block_size - 1) // self.block_size):
                 y[..., j * self.block_size: (j + 1) * self.block_size] += self.block_forward_(x, i, j, dtype=x.dtype, device=x.device)
             return y

         def forward(self, x, **kwargs):
             weight, bias = cast_bias_weight(self, x, dtype=self.dtype, device=self.device)
             return torch.nn.functional.linear(x, weight, bias)


 class RMSNorm(torch.nn.Module):
def __init__(self, module, dtype=torch.bfloat16, device="cuda"):
|
def __init__(self, module, dtype=torch.bfloat16, device="cuda"):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.module = module
|
self.module = module
|
||||||
self.dtype = dtype
|
self.dtype = dtype
|
||||||
self.device = device
|
self.device = device
|
||||||
|
|
||||||
def forward(self, hidden_states, **kwargs):
|
def forward(self, hidden_states, **kwargs):
|
||||||
input_dtype = hidden_states.dtype
|
input_dtype = hidden_states.dtype
|
||||||
variance = hidden_states.to(torch.float32).square().mean(-1, keepdim=True)
|
variance = hidden_states.to(torch.float32).square().mean(-1, keepdim=True)
|
||||||
@@ -711,30 +745,30 @@ class HunyuanVideoDiT(torch.nn.Module):
|
|||||||
weight = cast_weight(self.module, hidden_states, dtype=torch.bfloat16, device="cuda")
|
weight = cast_weight(self.module, hidden_states, dtype=torch.bfloat16, device="cuda")
|
||||||
hidden_states = hidden_states * weight
|
hidden_states = hidden_states * weight
|
||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
class Conv3d(torch.nn.Conv3d):
|
class Conv3d(torch.nn.Conv3d):
|
||||||
def __init__(self, *args, dtype=torch.bfloat16, device="cuda", **kwargs):
|
def __init__(self, *args, dtype=torch.bfloat16, device="cuda", **kwargs):
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
self.dtype = dtype
|
self.dtype = dtype
|
||||||
self.device = device
|
self.device = device
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
weight, bias = cast_bias_weight(self, x, dtype=self.dtype, device=self.device)
|
weight, bias = cast_bias_weight(self, x, dtype=self.dtype, device=self.device)
|
||||||
return torch.nn.functional.conv3d(x, weight, bias, self.stride, self.padding, self.dilation, self.groups)
|
return torch.nn.functional.conv3d(x, weight, bias, self.stride, self.padding, self.dilation, self.groups)
|
||||||
|
|
||||||
class LayerNorm(torch.nn.LayerNorm):
|
class LayerNorm(torch.nn.LayerNorm):
|
||||||
def __init__(self, *args, dtype=torch.bfloat16, device="cuda", **kwargs):
|
def __init__(self, *args, dtype=torch.bfloat16, device="cuda", **kwargs):
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
self.dtype = dtype
|
self.dtype = dtype
|
||||||
self.device = device
|
self.device = device
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
if self.weight is not None and self.bias is not None:
|
if self.weight is not None and self.bias is not None:
|
||||||
weight, bias = cast_bias_weight(self, x, dtype=self.dtype, device=self.device)
|
weight, bias = cast_bias_weight(self, x, dtype=self.dtype, device=self.device)
|
||||||
return torch.nn.functional.layer_norm(x, self.normalized_shape, weight, bias, self.eps)
|
return torch.nn.functional.layer_norm(x, self.normalized_shape, weight, bias, self.eps)
|
||||||
else:
|
else:
|
||||||
return torch.nn.functional.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
|
return torch.nn.functional.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
|
||||||
|
|
||||||
def replace_layer(model, dtype=torch.bfloat16, device="cuda"):
|
def replace_layer(model, dtype=torch.bfloat16, device="cuda"):
|
||||||
for name, module in model.named_children():
|
for name, module in model.named_children():
|
||||||
if isinstance(module, torch.nn.Linear):
|
if isinstance(module, torch.nn.Linear):
|
||||||
@@ -777,12 +811,12 @@ class HunyuanVideoDiT(torch.nn.Module):
|
|||||||
return HunyuanVideoDiTStateDictConverter()
|
return HunyuanVideoDiTStateDictConverter()
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class HunyuanVideoDiTStateDictConverter:
|
class HunyuanVideoDiTStateDictConverter:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def from_civitai(self, state_dict):
|
def from_civitai(self, state_dict):
|
||||||
|
origin_hash_key = hash_state_dict_keys(state_dict, with_shape=True)
|
||||||
if "module" in state_dict:
|
if "module" in state_dict:
|
||||||
state_dict = state_dict["module"]
|
state_dict = state_dict["module"]
|
||||||
direct_dict = {
|
direct_dict = {
|
||||||
@@ -882,4 +916,5 @@ class HunyuanVideoDiTStateDictConverter:
|
|||||||
state_dict_[name_] = param
|
state_dict_[name_] = param
|
||||||
else:
|
else:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
return state_dict_
|
return state_dict_
|
||||||
|
|||||||
@@ -1,24 +1,18 @@
-from transformers import LlamaModel, LlamaConfig, DynamicCache
+from transformers import LlamaModel, LlamaConfig, DynamicCache, LlavaForConditionalGeneration
 from copy import deepcopy
 import torch


 class HunyuanVideoLLMEncoder(LlamaModel):

     def __init__(self, config: LlamaConfig):
         super().__init__(config)
         self.auto_offload = False

     def enable_auto_offload(self, **kwargs):
         self.auto_offload = True

-    def forward(
-        self,
-        input_ids,
-        attention_mask,
-        hidden_state_skip_layer=2
-    ):
+    def forward(self, input_ids, attention_mask, hidden_state_skip_layer=2):
         embed_tokens = deepcopy(self.embed_tokens).to(input_ids.device) if self.auto_offload else self.embed_tokens
         inputs_embeds = embed_tokens(input_ids)

@@ -53,3 +47,22 @@ class HunyuanVideoLLMEncoder(LlamaModel):
                 break

         return hidden_states
+
+
+class HunyuanVideoMLLMEncoder(LlavaForConditionalGeneration):
+
+    def __init__(self, config):
+        super().__init__(config)
+        self.auto_offload = False
+
+    def enable_auto_offload(self, **kwargs):
+        self.auto_offload = True
+
+    # TODO: implement the low VRAM inference for MLLM.
+    def forward(self, input_ids, pixel_values, attention_mask, hidden_state_skip_layer=2):
+        outputs = super().forward(input_ids=input_ids,
+                                  attention_mask=attention_mask,
+                                  output_hidden_states=True,
+                                  pixel_values=pixel_values)
+        hidden_state = outputs.hidden_states[-(hidden_state_skip_layer + 1)]
+        return hidden_state
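The new MLLM encoder selects `outputs.hidden_states[-(hidden_state_skip_layer + 1)]`, i.e. the hidden state that lies `hidden_state_skip_layer` layers before the final one. A minimal sketch of that indexing with a stand-in list of per-layer outputs (the strings are illustrative placeholders, not repository values):

```python
# hidden_states as returned by transformers with output_hidden_states=True:
# index 0 is the embedding output, index i is the output of layer i.
hidden_states = ["embeddings", "layer1", "layer2", "layer3", "layer4"]

def select_hidden_state(hidden_states, hidden_state_skip_layer=2):
    # -(skip + 1) counts back from the end: skip=0 -> last layer,
    # skip=2 -> two layers before the last.
    return hidden_states[-(hidden_state_skip_layer + 1)]

print(select_hidden_state(hidden_states, 0))  # layer4
print(select_hidden_state(hidden_states, 2))  # layer2
```

Skipping the last couple of layers is a common trick for text-conditioning encoders, since late layers specialize toward next-token prediction rather than general representations.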
@@ -195,70 +195,73 @@ class FluxLoRAFromCivitai(LoRAFromCivitai):
             "txt.mod": "txt_mod",
         }


 class GeneralLoRAFromPeft:
     def __init__(self):
         self.supported_model_classes = [SDUNet, SDXLUNet, SD3DiT, HunyuanDiT, FluxDiT, CogDiT, WanModel]

-    def fetch_device_dtype_from_state_dict(self, state_dict):
-        device, torch_dtype = None, None
-        for name, param in state_dict.items():
-            device, torch_dtype = param.device, param.dtype
-            break
-        return device, torch_dtype
-
-    def convert_state_dict(self, state_dict, alpha=1.0, target_state_dict={}):
-        device, torch_dtype = self.fetch_device_dtype_from_state_dict(target_state_dict)
-        state_dict_ = {}
-        for key in state_dict:
+    def get_name_dict(self, lora_state_dict):
+        lora_name_dict = {}
+        for key in lora_state_dict:
             if ".lora_B." not in key:
                 continue
-            weight_up = state_dict[key].to(device=device, dtype=torch_dtype)
-            weight_down = state_dict[key.replace(".lora_B.", ".lora_A.")].to(device=device, dtype=torch_dtype)
-            if len(weight_up.shape) == 4:
-                weight_up = weight_up.squeeze(3).squeeze(2)
-                weight_down = weight_down.squeeze(3).squeeze(2)
-                lora_weight = alpha * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3)
-            else:
-                lora_weight = alpha * torch.mm(weight_up, weight_down)
             keys = key.split(".")
             if len(keys) > keys.index("lora_B") + 2:
                 keys.pop(keys.index("lora_B") + 1)
             keys.pop(keys.index("lora_B"))
+            if keys[0] == "diffusion_model":
+                keys.pop(0)
             target_name = ".".join(keys)
-            if target_name not in target_state_dict:
-                return {}
-            state_dict_[target_name] = lora_weight.cpu()
-        return state_dict_
+            lora_name_dict[target_name] = (key, key.replace(".lora_B.", ".lora_A."))
+        return lora_name_dict
+
+    def match(self, model: torch.nn.Module, state_dict_lora):
+        lora_name_dict = self.get_name_dict(state_dict_lora)
+        model_name_dict = {name: None for name, _ in model.named_parameters()}
+        matched_num = sum([i in model_name_dict for i in lora_name_dict])
+        if matched_num == len(lora_name_dict):
+            return "", ""
+        else:
+            return None
+
+    def fetch_device_and_dtype(self, state_dict):
+        device, dtype = None, None
+        for name, param in state_dict.items():
+            device, dtype = param.device, param.dtype
+            break
+        computation_device = device
+        computation_dtype = dtype
+        if computation_device == torch.device("cpu"):
+            if torch.cuda.is_available():
+                computation_device = torch.device("cuda")
+        if computation_dtype == torch.float8_e4m3fn:
+            computation_dtype = torch.float32
+        return device, dtype, computation_device, computation_dtype

     def load(self, model, state_dict_lora, lora_prefix="", alpha=1.0, model_resource=""):
         state_dict_model = model.state_dict()
-        state_dict_lora = self.convert_state_dict(state_dict_lora, alpha=alpha, target_state_dict=state_dict_model)
-        if len(state_dict_lora) > 0:
-            print(f"    {len(state_dict_lora)} tensors are updated.")
-            for name in state_dict_lora:
-                state_dict_model[name] += state_dict_lora[name].to(
-                    dtype=state_dict_model[name].dtype,
-                    device=state_dict_model[name].device
-                )
-        model.load_state_dict(state_dict_model)
-
-    def match(self, model, state_dict_lora):
-        for model_class in self.supported_model_classes:
-            if not isinstance(model, model_class):
-                continue
-            state_dict_model = model.state_dict()
-            try:
-                state_dict_lora_ = self.convert_state_dict(state_dict_lora, alpha=1.0, target_state_dict=state_dict_model)
-                if len(state_dict_lora_) > 0:
-                    return "", ""
-            except:
-                pass
-        return None
+        device, dtype, computation_device, computation_dtype = self.fetch_device_and_dtype(state_dict_model)
+        lora_name_dict = self.get_name_dict(state_dict_lora)
+        for name in lora_name_dict:
+            weight_up = state_dict_lora[lora_name_dict[name][0]].to(device=computation_device, dtype=computation_dtype)
+            weight_down = state_dict_lora[lora_name_dict[name][1]].to(device=computation_device, dtype=computation_dtype)
+            if len(weight_up.shape) == 4:
+                weight_up = weight_up.squeeze(3).squeeze(2)
+                weight_down = weight_down.squeeze(3).squeeze(2)
+                weight_lora = alpha * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3)
+            else:
+                weight_lora = alpha * torch.mm(weight_up, weight_down)
+            weight_model = state_dict_model[name].to(device=computation_device, dtype=computation_dtype)
+            weight_patched = weight_model + weight_lora
+            state_dict_model[name] = weight_patched.to(device=device, dtype=dtype)
+        print(f"    {len(lora_name_dict)} tensors are updated.")
+        model.load_state_dict(state_dict_model)


 class HunyuanVideoLoRAFromCivitai(LoRAFromCivitai):
@@ -362,7 +365,22 @@ class FluxLoRAConverter:
         else:
             state_dict_[name] = param
         return state_dict_
+
+
+class WanLoRAConverter:
+    def __init__(self):
+        pass
+
+    @staticmethod
+    def align_to_opensource_format(state_dict, **kwargs):
+        state_dict = {"diffusion_model." + name.replace(".default.", "."): param for name, param in state_dict.items()}
+        return state_dict
+
+    @staticmethod
+    def align_to_diffsynth_format(state_dict, **kwargs):
+        state_dict = {name.replace("diffusion_model.", "").replace(".lora_A.weight", ".lora_A.default.weight").replace(".lora_B.weight", ".lora_B.default.weight"): param for name, param in state_dict.items()}
+        return state_dict


 def get_lora_loaders():
     return [SDLoRAFromCivitai(), SDXLLoRAFromCivitai(), FluxLoRAFromCivitai(), HunyuanVideoLoRAFromCivitai(), GeneralLoRAFromPeft()]
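The reworked `GeneralLoRAFromPeft.load` above fuses each LoRA pair into the base weight as W' = W + alpha * (B @ A), where B is `lora_B` (out_features x rank) and A is `lora_A` (rank x in_features). A dependency-free sketch of that update on toy matrices (shapes and values are illustrative only):

```python
def matmul(a, b):
    # naive matrix product, enough for this sketch
    return [[sum(a[i][k] * b[k][j] for k in range(len(b)))
             for j in range(len(b[0]))] for i in range(len(a))]

def merge_lora(weight, lora_up, lora_down, alpha=1.0):
    # W' = W + alpha * (B @ A), with B = lora_up, A = lora_down
    delta = matmul(lora_up, lora_down)
    return [[w + alpha * d for w, d in zip(w_row, d_row)]
            for w_row, d_row in zip(weight, delta)]

W = [[1.0, 0.0], [0.0, 1.0]]           # base weight (2x2)
B = [[1.0], [0.0]]                     # lora_B: (out_features, rank=1)
A = [[0.0, 2.0]]                       # lora_A: (rank=1, in_features)
print(merge_lora(W, B, A, alpha=0.5))  # [[1.0, 1.0], [0.0, 1.0]]
```

The real loader additionally moves both operands to a computation device/dtype first (e.g. promoting float8 to float32), then casts the patched weight back, which is why `fetch_device_and_dtype` returns two device/dtype pairs.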
@@ -376,6 +376,7 @@ class ModelManager:
             self.load_lora(file_path_, state_dict=state_dict, lora_alpha=lora_alpha)
         else:
             print(f"Loading LoRA models from file: {file_path}")
+            is_loaded = False
             if len(state_dict) == 0:
                 state_dict = load_state_dict(file_path)
             for model_name, model, model_path in zip(self.model_name, self.model, self.model_path):
@@ -385,7 +386,10 @@ class ModelManager:
                     print(f"    Adding LoRA to {model_name} ({model_path}).")
                     lora_prefix, model_resource = match_results
                     lora.load(model, state_dict, lora_prefix, alpha=lora_alpha, model_resource=model_resource)
+                    is_loaded = True
                     break
+            if not is_loaded:
+                print(f"    Cannot load LoRA: {file_path}")

     def load_model(self, file_path, model_names=None, device=None, torch_dtype=None):
File diff suppressed because it is too large
@@ -228,7 +228,7 @@ class QuickGELU(nn.Module):
 class LayerNorm(nn.LayerNorm):

     def forward(self, x):
-        return super().forward(x.float()).type_as(x)
+        return super().forward(x).type_as(x)


 class SelfAttention(nn.Module):
@@ -256,15 +256,11 @@ class SelfAttention(nn.Module):
         """
         x: [B, L, C].
         """
-        b, s, c, n, d = *x.size(), self.num_heads, self.head_dim

         # compute query, key, value
-        q, k, v = self.to_qkv(x).view(b, s, 3, n, d).unbind(2)
+        q, k, v = self.to_qkv(x).chunk(3, dim=-1)

         # compute attention
-        p = self.attn_dropout if self.training else 0.0
-        x = flash_attention(q, k, v, dropout_p=p, causal=self.causal, version=2)
-        x = x.reshape(b, s, c)
+        x = flash_attention(q, k, v, num_heads=self.num_heads, compatibility_mode=True)

         # output
         x = self.proj(x)
@@ -371,11 +367,11 @@ class AttentionPool(nn.Module):
         b, s, c, n, d = *x.size(), self.num_heads, self.head_dim

         # compute query, key, value
-        q = self.to_q(self.cls_embedding).view(1, 1, n, d).expand(b, -1, -1, -1)
-        k, v = self.to_kv(x).view(b, s, 2, n, d).unbind(2)
+        q = self.to_q(self.cls_embedding).view(1, 1, n*d).expand(b, -1, -1)
+        k, v = self.to_kv(x).chunk(2, dim=-1)

         # compute attention
-        x = flash_attention(q, k, v, version=2)
+        x = flash_attention(q, k, v, num_heads=self.num_heads, compatibility_mode=True)
         x = x.reshape(b, 1, c)

         # output
@@ -878,6 +874,8 @@ class WanImageEncoder(torch.nn.Module):
         videos = self.transforms.transforms[-1](videos.mul_(0.5).add_(0.5))

         # forward
+        dtype = next(iter(self.model.visual.parameters())).dtype
+        videos = videos.to(dtype)
         out = self.model.visual(videos, use_31_block=True)
         return out
diffsynth/models/wan_video_motion_controller.py (new file, 44 lines)
@@ -0,0 +1,44 @@
+import torch
+import torch.nn as nn
+from .wan_video_dit import sinusoidal_embedding_1d
+
+
+
+class WanMotionControllerModel(torch.nn.Module):
+    def __init__(self, freq_dim=256, dim=1536):
+        super().__init__()
+        self.freq_dim = freq_dim
+        self.linear = nn.Sequential(
+            nn.Linear(freq_dim, dim),
+            nn.SiLU(),
+            nn.Linear(dim, dim),
+            nn.SiLU(),
+            nn.Linear(dim, dim * 6),
+        )
+
+    def forward(self, motion_bucket_id):
+        emb = sinusoidal_embedding_1d(self.freq_dim, motion_bucket_id * 10)
+        emb = self.linear(emb)
+        return emb
+
+    def init(self):
+        state_dict = self.linear[-1].state_dict()
+        state_dict = {i: state_dict[i] * 0 for i in state_dict}
+        self.linear[-1].load_state_dict(state_dict)
+
+    @staticmethod
+    def state_dict_converter():
+        return WanMotionControllerModelDictConverter()
+
+
+
+class WanMotionControllerModelDictConverter:
+    def __init__(self):
+        pass
+
+    def from_diffusers(self, state_dict):
+        return state_dict
+
+    def from_civitai(self, state_dict):
+        return state_dict
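The new motion controller embeds a scalar `motion_bucket_id` (scaled by 10) with `sinusoidal_embedding_1d` before the MLP. That helper is imported from `wan_video_dit` and is not shown in this diff; below is a sketch of one common transformer-style sinusoidal embedding, as an assumption about its behavior rather than the repository's exact implementation:

```python
import math

def sinusoidal_embedding_1d(dim, position):
    # Standard sin/cos positional embedding for a single scalar position.
    # The repository's own version (in wan_video_dit) may differ in
    # frequency layout, scaling, or tensor output type.
    half = dim // 2
    freqs = [position / (10000 ** (i / half)) for i in range(half)]
    return [math.sin(f) for f in freqs] + [math.cos(f) for f in freqs]

# motion_bucket_id = 2.5 is scaled by 10 before embedding, as in forward()
emb = sinusoidal_embedding_1d(8, 2.5 * 10)
print(len(emb))  # 8
```

Note also that `init()` zeroes the last linear layer, so a freshly initialized controller outputs an all-zero modulation and leaves the base model's behavior unchanged until training moves those weights.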
@@ -688,7 +688,7 @@ class WanVideoVAE(nn.Module):
                     target_w: target_w + hidden_states_batch.shape[4],
                 ] += mask
         values = values / weight
-        values = values.float().clamp_(-1, 1)
+        values = values.clamp_(-1, 1)
         return values


@@ -740,20 +740,19 @@ class WanVideoVAE(nn.Module):
                     target_w: target_w + hidden_states_batch.shape[4],
                 ] += mask
         values = values / weight
-        values = values.float()
         return values

     def single_encode(self, video, device):
         video = video.to(device)
         x = self.model.encode(video, self.scale)
-        return x.float()
+        return x

     def single_decode(self, hidden_state, device):
         hidden_state = hidden_state.to(device)
         video = self.model.decode(hidden_state, self.scale)
-        return video.float().clamp_(-1, 1)
+        return video.clamp_(-1, 1)

     def encode(self, videos, device, tiled=False, tile_size=(34, 34), tile_stride=(18, 16)):
@@ -1,10 +1,12 @@
|
|||||||
from ..models import ModelManager, FluxDiT, SD3TextEncoder1, FluxTextEncoder2, FluxVAEDecoder, FluxVAEEncoder, FluxIpAdapter
|
from ..models import ModelManager, FluxDiT, SD3TextEncoder1, FluxTextEncoder2, FluxVAEDecoder, FluxVAEEncoder, FluxIpAdapter
|
||||||
from ..controlnets import FluxMultiControlNetManager, ControlNetUnit, ControlNetConfigUnit, Annotator
|
from ..controlnets import FluxMultiControlNetManager, ControlNetUnit, ControlNetConfigUnit, Annotator
|
||||||
|
from ..models.flux_reference_embedder import FluxReferenceEmbedder
|
||||||
from ..prompters import FluxPrompter
|
from ..prompters import FluxPrompter
|
||||||
from ..schedulers import FlowMatchScheduler
|
from ..schedulers import FlowMatchScheduler
|
||||||
from .base import BasePipeline
|
from .base import BasePipeline
|
||||||
from typing import List
|
from typing import List
|
||||||
import torch
|
import torch
|
||||||
|
from einops import rearrange
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
@@ -31,6 +33,8 @@ class FluxImagePipeline(BasePipeline):
|
|||||||
self.controlnet: FluxMultiControlNetManager = None
|
self.controlnet: FluxMultiControlNetManager = None
|
||||||
self.ipadapter: FluxIpAdapter = None
|
self.ipadapter: FluxIpAdapter = None
|
||||||
self.ipadapter_image_encoder: SiglipVisionModel = None
|
self.ipadapter_image_encoder: SiglipVisionModel = None
|
||||||
|
self.infinityou_processor: InfinitYou = None
|
||||||
|
self.reference_embedder: FluxReferenceEmbedder = None
|
||||||
self.model_names = ['text_encoder_1', 'text_encoder_2', 'dit', 'vae_decoder', 'vae_encoder', 'controlnet', 'ipadapter', 'ipadapter_image_encoder']
|
self.model_names = ['text_encoder_1', 'text_encoder_2', 'dit', 'vae_decoder', 'vae_encoder', 'controlnet', 'ipadapter', 'ipadapter_image_encoder']
|
||||||
|
|
||||||
|
|
||||||
@@ -162,6 +166,11 @@ class FluxImagePipeline(BasePipeline):
|
|||||||
self.ipadapter = model_manager.fetch_model("flux_ipadapter")
|
self.ipadapter = model_manager.fetch_model("flux_ipadapter")
|
||||||
self.ipadapter_image_encoder = model_manager.fetch_model("siglip_vision_model")
|
self.ipadapter_image_encoder = model_manager.fetch_model("siglip_vision_model")
|
||||||
|
|
||||||
|
# InfiniteYou
|
||||||
|
self.image_proj_model = model_manager.fetch_model("infiniteyou_image_projector")
|
||||||
|
if self.image_proj_model is not None:
|
||||||
|
self.infinityou_processor = InfinitYou(device=self.device)
|
||||||
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def from_model_manager(model_manager: ModelManager, controlnet_config_units: List[ControlNetConfigUnit]=[], prompt_refiner_classes=[], prompt_extender_classes=[], device=None, torch_dtype=None):
|
def from_model_manager(model_manager: ModelManager, controlnet_config_units: List[ControlNetConfigUnit]=[], prompt_refiner_classes=[], prompt_extender_classes=[], device=None, torch_dtype=None):
|
||||||
@@ -347,6 +356,27 @@ class FluxImagePipeline(BasePipeline):
|
|||||||
prompt_emb_nega = self.encode_prompt(negative_prompt, positive=False, t5_sequence_length=t5_sequence_length) if cfg_scale != 1.0 else None
|
prompt_emb_nega = self.encode_prompt(negative_prompt, positive=False, t5_sequence_length=t5_sequence_length) if cfg_scale != 1.0 else None
|
||||||
prompt_emb_locals = [self.encode_prompt(prompt_local, t5_sequence_length=t5_sequence_length) for prompt_local in local_prompts]
|
prompt_emb_locals = [self.encode_prompt(prompt_local, t5_sequence_length=t5_sequence_length) for prompt_local in local_prompts]
|
||||||
return prompt_emb_posi, prompt_emb_nega, prompt_emb_locals
|
return prompt_emb_posi, prompt_emb_nega, prompt_emb_locals
|
||||||
|
|
||||||
|
|
||||||
|
def prepare_infinite_you(self, id_image, controlnet_image, infinityou_guidance, height, width):
|
||||||
|
if self.infinityou_processor is not None and id_image is not None:
|
||||||
|
return self.infinityou_processor.prepare_infinite_you(self.image_proj_model, id_image, controlnet_image, infinityou_guidance, height, width)
|
||||||
|
else:
|
||||||
|
return {}, controlnet_image
|
||||||
|
|
||||||
|
|
||||||
|
def prepare_reference_images(self, reference_images, tiled=False, tile_size=64, tile_stride=32):
|
||||||
|
if reference_images is not None:
|
||||||
|
hidden_states_ref = []
|
||||||
|
for reference_image in reference_images:
|
||||||
|
self.load_models_to_device(['vae_encoder'])
|
||||||
|
reference_image = self.preprocess_image(reference_image).to(device=self.device, dtype=self.torch_dtype)
|
||||||
|
latents = self.encode_image(reference_image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
|
||||||
|
hidden_states_ref.append(latents)
|
||||||
|
hidden_states_ref = torch.concat(hidden_states_ref, dim=0)
|
||||||
|
return {"hidden_states_ref": hidden_states_ref}
|
||||||
|
else:
|
||||||
|
return {"hidden_states_ref": None}
|
||||||
|
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
@@ -382,6 +412,11 @@ class FluxImagePipeline(BasePipeline):
|
|||||||
eligen_entity_masks=None,
|
eligen_entity_masks=None,
|
||||||
enable_eligen_on_negative=False,
|
enable_eligen_on_negative=False,
|
||||||
enable_eligen_inpaint=False,
|
enable_eligen_inpaint=False,
|
||||||
|
# InfiniteYou
|
||||||
|
infinityou_id_image=None,
|
||||||
|
infinityou_guidance=1.0,
|
||||||
|
# Reference images
|
||||||
|
reference_images=None,
|
||||||
# TeaCache
|
# TeaCache
|
||||||
tea_cache_l1_thresh=None,
|
tea_cache_l1_thresh=None,
|
||||||
# Tile
|
# Tile
|
||||||
@@ -409,6 +444,9 @@ class FluxImagePipeline(BasePipeline):
|
|||||||
# Extra input
|
# Extra input
|
||||||
extra_input = self.prepare_extra_input(latents, guidance=embedded_guidance)
|
extra_input = self.prepare_extra_input(latents, guidance=embedded_guidance)
|
||||||
|
|
||||||
|
# InfiniteYou
|
||||||
|
infiniteyou_kwargs, controlnet_image = self.prepare_infinite_you(infinityou_id_image, controlnet_image, infinityou_guidance, height, width)
|
||||||
|
|
||||||
# Entity control
|
# Entity control
|
||||||
eligen_kwargs_posi, eligen_kwargs_nega, fg_mask, bg_mask = self.prepare_eligen(prompt_emb_nega, eligen_entity_prompts, eligen_entity_masks, width, height, t5_sequence_length, enable_eligen_inpaint, enable_eligen_on_negative, cfg_scale)
|
eligen_kwargs_posi, eligen_kwargs_nega, fg_mask, bg_mask = self.prepare_eligen(prompt_emb_nega, eligen_entity_prompts, eligen_entity_masks, width, height, t5_sequence_length, enable_eligen_inpaint, enable_eligen_on_negative, cfg_scale)
|
||||||
|
|
||||||
@@ -417,6 +455,9 @@ class FluxImagePipeline(BasePipeline):
|
|||||||
|
|
||||||
# ControlNets
|
# ControlNets
|
||||||
controlnet_kwargs_posi, controlnet_kwargs_nega, local_controlnet_kwargs = self.prepare_controlnet(controlnet_image, masks, controlnet_inpaint_mask, tiler_kwargs, enable_controlnet_on_negative)
|
controlnet_kwargs_posi, controlnet_kwargs_nega, local_controlnet_kwargs = self.prepare_controlnet(controlnet_image, masks, controlnet_inpaint_mask, tiler_kwargs, enable_controlnet_on_negative)
|
||||||
|
|
||||||
|
# Reference images
|
||||||
|
reference_kwargs = self.prepare_reference_images(reference_images, **tiler_kwargs)
|
||||||
|
|
||||||
# TeaCache
|
# TeaCache
|
||||||
tea_cache_kwargs = {"tea_cache": TeaCache(num_inference_steps, rel_l1_thresh=tea_cache_l1_thresh) if tea_cache_l1_thresh is not None else None}
|
tea_cache_kwargs = {"tea_cache": TeaCache(num_inference_steps, rel_l1_thresh=tea_cache_l1_thresh) if tea_cache_l1_thresh is not None else None}
|
||||||
@@ -428,9 +469,9 @@ class FluxImagePipeline(BasePipeline):
 
             # Positive side
             inference_callback = lambda prompt_emb_posi, controlnet_kwargs: lets_dance_flux(
-                dit=self.dit, controlnet=self.controlnet,
+                dit=self.dit, controlnet=self.controlnet, reference_embedder=self.reference_embedder,
                 hidden_states=latents, timestep=timestep,
-                **prompt_emb_posi, **tiler_kwargs, **extra_input, **controlnet_kwargs, **ipadapter_kwargs_list_posi, **eligen_kwargs_posi, **tea_cache_kwargs,
+                **prompt_emb_posi, **tiler_kwargs, **extra_input, **controlnet_kwargs, **ipadapter_kwargs_list_posi, **eligen_kwargs_posi, **tea_cache_kwargs, **infiniteyou_kwargs, **reference_kwargs,
             )
             noise_pred_posi = self.control_noise_via_local_prompts(
                 prompt_emb_posi, prompt_emb_locals, masks, mask_scales, inference_callback,
@@ -445,9 +486,9 @@ class FluxImagePipeline(BasePipeline):
             if cfg_scale != 1.0:
                 # Negative side
                 noise_pred_nega = lets_dance_flux(
-                    dit=self.dit, controlnet=self.controlnet,
+                    dit=self.dit, controlnet=self.controlnet, reference_embedder=self.reference_embedder,
                     hidden_states=latents, timestep=timestep,
-                    **prompt_emb_nega, **tiler_kwargs, **extra_input, **controlnet_kwargs_nega, **ipadapter_kwargs_list_nega, **eligen_kwargs_nega,
+                    **prompt_emb_nega, **tiler_kwargs, **extra_input, **controlnet_kwargs_nega, **ipadapter_kwargs_list_nega, **eligen_kwargs_nega, **infiniteyou_kwargs, **reference_kwargs,
                 )
                 noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
             else:
@@ -467,6 +508,58 @@ class FluxImagePipeline(BasePipeline):
         # Offload all models
         self.load_models_to_device([])
         return image
 
 
+class InfinitYou:
+    def __init__(self, device="cuda", torch_dtype=torch.bfloat16):
+        from facexlib.recognition import init_recognition_model
+        from insightface.app import FaceAnalysis
+        self.device = device
+        self.torch_dtype = torch_dtype
+        insightface_root_path = 'models/InfiniteYou/insightface'
+        self.app_640 = FaceAnalysis(name='antelopev2', root=insightface_root_path, providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
+        self.app_640.prepare(ctx_id=0, det_size=(640, 640))
+        self.app_320 = FaceAnalysis(name='antelopev2', root=insightface_root_path, providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
+        self.app_320.prepare(ctx_id=0, det_size=(320, 320))
+        self.app_160 = FaceAnalysis(name='antelopev2', root=insightface_root_path, providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
+        self.app_160.prepare(ctx_id=0, det_size=(160, 160))
+        self.arcface_model = init_recognition_model('arcface', device=self.device)
+
+    def _detect_face(self, id_image_cv2):
+        face_info = self.app_640.get(id_image_cv2)
+        if len(face_info) > 0:
+            return face_info
+        face_info = self.app_320.get(id_image_cv2)
+        if len(face_info) > 0:
+            return face_info
+        face_info = self.app_160.get(id_image_cv2)
+        return face_info
+
+    def extract_arcface_bgr_embedding(self, in_image, landmark):
+        from insightface.utils import face_align
+        arc_face_image = face_align.norm_crop(in_image, landmark=np.array(landmark), image_size=112)
+        arc_face_image = torch.from_numpy(arc_face_image).unsqueeze(0).permute(0, 3, 1, 2) / 255.
+        arc_face_image = 2 * arc_face_image - 1
+        arc_face_image = arc_face_image.contiguous().to(self.device)
+        face_emb = self.arcface_model(arc_face_image)[0]  # [512], normalized
+        return face_emb
+
+    def prepare_infinite_you(self, model, id_image, controlnet_image, infinityou_guidance, height, width):
+        import cv2
+        if id_image is None:
+            return {'id_emb': None}, controlnet_image
+        id_image_cv2 = cv2.cvtColor(np.array(id_image), cv2.COLOR_RGB2BGR)
+        face_info = self._detect_face(id_image_cv2)
+        if len(face_info) == 0:
+            raise ValueError('No face detected in the input ID image')
+        landmark = sorted(face_info, key=lambda x: (x['bbox'][2] - x['bbox'][0]) * (x['bbox'][3] - x['bbox'][1]))[-1]['kps']  # only use the maximum face
+        id_emb = self.extract_arcface_bgr_embedding(id_image_cv2, landmark)
+        id_emb = model(id_emb.unsqueeze(0).reshape([1, -1, 512]).to(dtype=self.torch_dtype))
+        if controlnet_image is None:
+            controlnet_image = Image.fromarray(np.zeros([height, width, 3]).astype(np.uint8))
+        infinityou_guidance = torch.Tensor([infinityou_guidance]).to(device=self.device, dtype=self.torch_dtype)
+        return {'id_emb': id_emb, 'infinityou_guidance': infinityou_guidance}, controlnet_image
+
+
 class TeaCache:
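The `prepare_infinite_you` hunk above keeps only the largest detected face by sorting detections on bounding-box area before extracting landmarks. A minimal sketch of that selection with plain dicts (the `bbox`/`kps` keys mirror insightface's detection format; the values here are made-up placeholders):

```python
# Pick the detection with the largest bbox area, as in prepare_infinite_you.
# bbox is (x1, y1, x2, y2); kps holds the face landmarks.
face_info = [
    {'bbox': [0, 0, 10, 10], 'kps': 'small-face-landmarks'},
    {'bbox': [0, 0, 50, 40], 'kps': 'big-face-landmarks'},
]
largest = sorted(
    face_info,
    key=lambda x: (x['bbox'][2] - x['bbox'][0]) * (x['bbox'][3] - x['bbox'][1]),
)[-1]
landmark = largest['kps']
print(landmark)
```

Sorting ascending and taking `[-1]` is the same "only use the maximum face" trick as the diff.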
@@ -515,6 +608,7 @@ class TeaCache:
 def lets_dance_flux(
     dit: FluxDiT,
     controlnet: FluxMultiControlNetManager = None,
+    reference_embedder: FluxReferenceEmbedder = None,
     hidden_states=None,
     timestep=None,
     prompt_emb=None,
@@ -523,13 +617,17 @@ def lets_dance_flux(
     text_ids=None,
     image_ids=None,
     controlnet_frames=None,
+    hidden_states_ref=None,
     tiled=False,
     tile_size=128,
     tile_stride=64,
     entity_prompt_emb=None,
     entity_masks=None,
     ipadapter_kwargs_list={},
+    id_emb=None,
+    infinityou_guidance=None,
     tea_cache: TeaCache = None,
+    use_gradient_checkpointing=False,
     **kwargs
 ):
     if tiled:
@@ -573,6 +671,9 @@ def lets_dance_flux(
             "tile_size": tile_size,
             "tile_stride": tile_stride,
         }
+        if id_emb is not None:
+            controlnet_text_ids = torch.zeros(id_emb.shape[0], id_emb.shape[1], 3).to(device=hidden_states.device, dtype=hidden_states.dtype)
+            controlnet_extra_kwargs.update({"prompt_emb": id_emb, 'text_ids': controlnet_text_ids, 'guidance': infinityou_guidance})
         controlnet_res_stack, controlnet_single_res_stack = controlnet(
             controlnet_frames, **controlnet_extra_kwargs
         )
@@ -593,28 +694,55 @@ def lets_dance_flux(
         prompt_emb, image_rotary_emb, attention_mask = dit.process_entity_masks(hidden_states, prompt_emb, entity_prompt_emb, entity_masks, text_ids, image_ids)
     else:
         prompt_emb = dit.context_embedder(prompt_emb)
-        image_rotary_emb = dit.pos_embedder(torch.cat((text_ids, image_ids), dim=1))
+        image_rotary_emb = dit.pos_embedder(torch.cat((text_ids, image_ids), dim=1), device=hidden_states.device)
         attention_mask = None
 
+    # Reference images
+    if hidden_states_ref is not None:
+        # RoPE
+        image_ids_ref = dit.prepare_image_ids(hidden_states_ref)
+        idx = torch.arange(0, image_ids_ref.shape[0]).to(dtype=hidden_states.dtype, device=hidden_states.device) * 100
+        image_rotary_emb_ref = reference_embedder(image_ids_ref, idx, dtype=hidden_states.dtype, device=hidden_states.device)
+        image_rotary_emb = torch.cat((image_rotary_emb, image_rotary_emb_ref), dim=2)
+        # hidden_states
+        original_hidden_states_length = hidden_states.shape[1]
+        hidden_states_ref = dit.patchify(hidden_states_ref)
+        hidden_states_ref = dit.x_embedder(hidden_states_ref)
+        hidden_states_ref = rearrange(hidden_states_ref, "B L C -> 1 (B L) C")
+        hidden_states_ref = reference_embedder.proj(hidden_states_ref)
+        hidden_states = torch.cat((hidden_states, hidden_states_ref), dim=1)
+
     # TeaCache
     if tea_cache is not None:
         tea_cache_update = tea_cache.check(dit, hidden_states, conditioning)
     else:
         tea_cache_update = False
 
+    def create_custom_forward(module):
+        def custom_forward(*inputs):
+            return module(*inputs)
+        return custom_forward
+
     if tea_cache_update:
         hidden_states = tea_cache.update(hidden_states)
     else:
         # Joint Blocks
         for block_id, block in enumerate(dit.blocks):
-            hidden_states, prompt_emb = block(
-                hidden_states,
-                prompt_emb,
-                conditioning,
-                image_rotary_emb,
-                attention_mask,
-                ipadapter_kwargs_list=ipadapter_kwargs_list.get(block_id, None)
-            )
+            if use_gradient_checkpointing:
+                hidden_states, prompt_emb = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(block),
+                    hidden_states, prompt_emb, conditioning, image_rotary_emb, attention_mask, ipadapter_kwargs_list.get(block_id, None),
+                    use_reentrant=False,
+                )
+            else:
+                hidden_states, prompt_emb = block(
+                    hidden_states,
+                    prompt_emb,
+                    conditioning,
+                    image_rotary_emb,
+                    attention_mask,
+                    ipadapter_kwargs_list=ipadapter_kwargs_list.get(block_id, None)
+                )
             # ControlNet
             if controlnet is not None and controlnet_frames is not None:
                 hidden_states = hidden_states + controlnet_res_stack[block_id]
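The gradient-checkpointing path above wraps each transformer block in `create_custom_forward` because `torch.utils.checkpoint.checkpoint` passes its inputs positionally to the callable it recomputes. The closure pattern itself is plain Python; a torch-free sketch (with `fake_block` as a stand-in for a block's forward):

```python
# The same wrapper pattern used in the diff: a closure that forwards
# positional inputs to the wrapped module's call.
def create_custom_forward(module):
    def custom_forward(*inputs):
        return module(*inputs)
    return custom_forward

def fake_block(hidden_states, prompt_emb):
    # stands in for a transformer block's forward
    return hidden_states + 1, prompt_emb + 2

forward = create_custom_forward(fake_block)
print(forward(10, 20))  # (11, 22)
```

In the real code the wrapped call runs without storing activations, trading recomputation in the backward pass for lower peak VRAM.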
@@ -623,14 +751,21 @@ def lets_dance_flux(
         hidden_states = torch.cat([prompt_emb, hidden_states], dim=1)
         num_joint_blocks = len(dit.blocks)
         for block_id, block in enumerate(dit.single_blocks):
-            hidden_states, prompt_emb = block(
-                hidden_states,
-                prompt_emb,
-                conditioning,
-                image_rotary_emb,
-                attention_mask,
-                ipadapter_kwargs_list=ipadapter_kwargs_list.get(block_id + num_joint_blocks, None)
-            )
+            if use_gradient_checkpointing:
+                hidden_states, prompt_emb = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(block),
+                    hidden_states, prompt_emb, conditioning, image_rotary_emb, attention_mask, ipadapter_kwargs_list.get(block_id + num_joint_blocks, None),
+                    use_reentrant=False,
+                )
+            else:
+                hidden_states, prompt_emb = block(
+                    hidden_states,
+                    prompt_emb,
+                    conditioning,
+                    image_rotary_emb,
+                    attention_mask,
+                    ipadapter_kwargs_list=ipadapter_kwargs_list.get(block_id + num_joint_blocks, None)
+                )
             # ControlNet
             if controlnet is not None and controlnet_frames is not None:
                 hidden_states[:, prompt_emb.shape[1]:] = hidden_states[:, prompt_emb.shape[1]:] + controlnet_single_res_stack[block_id]
@@ -639,6 +774,8 @@ def lets_dance_flux(
     if tea_cache is not None:
         tea_cache.store(hidden_states)
 
+    if hidden_states_ref is not None:
+        hidden_states = hidden_states[:, :original_hidden_states_length]
     hidden_states = dit.final_norm_out(hidden_states, conditioning)
     hidden_states = dit.final_proj_out(hidden_states)
     hidden_states = dit.unpatchify(hidden_states, height, width)
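The reference-image path concatenates the embedded reference tokens onto the latent sequence for the transformer blocks, then slices them off again (`hidden_states[:, :original_hidden_states_length]`) before the final projection, so only latent tokens are decoded. The bookkeeping in one dimension, with plain lists:

```python
# Sequence-length bookkeeping used by the reference-image path:
# append reference tokens for the transformer blocks, then trim them
# before the output projection.
hidden = ['h0', 'h1', 'h2']      # latent tokens
reference = ['r0', 'r1']         # reference-image tokens
original_len = len(hidden)
hidden = hidden + reference      # what the blocks attend over
# ... transformer blocks run on the joint sequence ...
hidden = hidden[:original_len]   # only latent tokens reach the decoder
print(hidden)
```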
@@ -5,13 +5,13 @@ from ..schedulers.flow_match import FlowMatchScheduler
 from .base import BasePipeline
 from ..prompters import HunyuanVideoPrompter
 import torch
+import torchvision.transforms as transforms
 from einops import rearrange
 import numpy as np
 from PIL import Image
 from tqdm import tqdm
 
 
 class HunyuanVideoPipeline(BasePipeline):
 
     def __init__(self, device="cuda", torch_dtype=torch.float16):
@@ -53,10 +53,58 @@ class HunyuanVideoPipeline(BasePipeline):
         pipe.enable_vram_management()
         return pipe
 
+    def generate_crop_size_list(self, base_size=256, patch_size=32, max_ratio=4.0):
+        num_patches = round((base_size / patch_size)**2)
+        assert max_ratio >= 1.0
+        crop_size_list = []
+        wp, hp = num_patches, 1
+        while wp > 0:
+            if max(wp, hp) / min(wp, hp) <= max_ratio:
+                crop_size_list.append((wp * patch_size, hp * patch_size))
+            if (hp + 1) * wp <= num_patches:
+                hp += 1
+            else:
+                wp -= 1
+        return crop_size_list
+
+    def get_closest_ratio(self, height: float, width: float, ratios: list, buckets: list):
+        aspect_ratio = float(height) / float(width)
+        closest_ratio_id = np.abs(ratios - aspect_ratio).argmin()
+        closest_ratio = min(ratios, key=lambda ratio: abs(float(ratio) - aspect_ratio))
+        return buckets[closest_ratio_id], float(closest_ratio)
+
+    def prepare_vae_images_inputs(self, semantic_images, i2v_resolution="720p"):
+        if i2v_resolution == "720p":
+            bucket_hw_base_size = 960
+        elif i2v_resolution == "540p":
+            bucket_hw_base_size = 720
+        elif i2v_resolution == "360p":
+            bucket_hw_base_size = 480
+        else:
+            raise ValueError(f"i2v_resolution: {i2v_resolution} must be in [360p, 540p, 720p]")
+        origin_size = semantic_images[0].size
+
+        crop_size_list = self.generate_crop_size_list(bucket_hw_base_size, 32)
+        aspect_ratios = np.array([round(float(h) / float(w), 5) for h, w in crop_size_list])
+        closest_size, closest_ratio = self.get_closest_ratio(origin_size[1], origin_size[0], aspect_ratios, crop_size_list)
+        ref_image_transform = transforms.Compose([
+            transforms.Resize(closest_size),
+            transforms.CenterCrop(closest_size),
+            transforms.ToTensor(),
+            transforms.Normalize([0.5], [0.5])
+        ])
+
+        semantic_image_pixel_values = [ref_image_transform(semantic_image) for semantic_image in semantic_images]
+        semantic_image_pixel_values = torch.cat(semantic_image_pixel_values).unsqueeze(0).unsqueeze(2).to(self.device)
+        target_height, target_width = closest_size
+        return semantic_image_pixel_values, target_height, target_width
+
-    def encode_prompt(self, prompt, positive=True, clip_sequence_length=77, llm_sequence_length=256):
+    def encode_prompt(self, prompt, positive=True, clip_sequence_length=77, llm_sequence_length=256, input_images=None):
         prompt_emb, pooled_prompt_emb, text_mask = self.prompter.encode_prompt(
-            prompt, device=self.device, positive=positive, clip_sequence_length=clip_sequence_length, llm_sequence_length=llm_sequence_length
+            prompt, device=self.device, positive=positive, clip_sequence_length=clip_sequence_length, llm_sequence_length=llm_sequence_length, images=input_images
         )
         return {"prompt_emb": prompt_emb, "pooled_prompt_emb": pooled_prompt_emb, "text_mask": text_mask}
 
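`generate_crop_size_list` walks `(wp, hp)` patch-grid shapes under a roughly constant patch budget and a bounded aspect ratio, and `get_closest_ratio` later picks the bucket nearest the input image's aspect ratio. The bucket walk is self-contained Python; rerunning it (logic copied from the hunk, no numpy needed) shows the properties it guarantees:

```python
def generate_crop_size_list(base_size=256, patch_size=32, max_ratio=4.0):
    # Enumerate (width, height) buckets whose patch count stays within
    # (base_size / patch_size)**2 and whose aspect ratio is bounded.
    num_patches = round((base_size / patch_size) ** 2)
    crop_size_list = []
    wp, hp = num_patches, 1
    while wp > 0:
        if max(wp, hp) / min(wp, hp) <= max_ratio:
            crop_size_list.append((wp * patch_size, hp * patch_size))
        if (hp + 1) * wp <= num_patches:
            hp += 1
        else:
            wp -= 1
    return crop_size_list

buckets = generate_crop_size_list(base_size=256, patch_size=32)
# Every bucket respects the ratio bound, and the square bucket is reachable.
assert all(max(w, h) / min(w, h) <= 4.0 for w, h in buckets)
assert (256, 256) in buckets
print(len(buckets))
```

The staircase walk visits, for each width, the tallest height that still fits the patch budget, so wide, square, and tall buckets are all generated in one pass.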
@@ -87,6 +135,9 @@ class HunyuanVideoPipeline(BasePipeline):
         prompt,
         negative_prompt="",
         input_video=None,
+        input_images=None,
+        i2v_resolution="720p",
+        i2v_stability=True,
         denoising_strength=1.0,
         seed=None,
         rand_device=None,
@@ -105,10 +156,17 @@ class HunyuanVideoPipeline(BasePipeline):
     ):
         # Tiler parameters
         tiler_kwargs = {"tile_size": tile_size, "tile_stride": tile_stride}
 
         # Scheduler
         self.scheduler.set_timesteps(num_inference_steps, denoising_strength)
 
+        # encoder input images
+        if input_images is not None:
+            self.load_models_to_device(['vae_encoder'])
+            image_pixel_values, height, width = self.prepare_vae_images_inputs(input_images, i2v_resolution=i2v_resolution)
+            with torch.autocast(device_type=self.device, dtype=torch.float16, enabled=True):
+                image_latents = self.vae_encoder(image_pixel_values)
+
         # Initialize noise
         rand_device = self.device if rand_device is None else rand_device
         noise = self.generate_noise((1, 16, (num_frames - 1) // 4 + 1, height//8, width//8), seed=seed, device=rand_device, dtype=self.torch_dtype).to(self.device)
@@ -118,12 +176,18 @@ class HunyuanVideoPipeline(BasePipeline):
             input_video = torch.stack(input_video, dim=2)
             latents = self.encode_video(input_video, **tiler_kwargs).to(dtype=self.torch_dtype, device=self.device)
             latents = self.scheduler.add_noise(latents, noise, timestep=self.scheduler.timesteps[0])
+        elif input_images is not None and i2v_stability:
+            noise = self.generate_noise((1, 16, (num_frames - 1) // 4 + 1, height//8, width//8), seed=seed, device=rand_device, dtype=image_latents.dtype).to(self.device)
+            t = torch.tensor([0.999]).to(device=self.device)
+            latents = noise * t + image_latents.repeat(1, 1, (num_frames - 1) // 4 + 1, 1, 1) * (1 - t)
+            latents = latents.to(dtype=image_latents.dtype)
         else:
             latents = noise
 
         # Encode prompts
-        self.load_models_to_device(["text_encoder_1"] if self.vram_management else ["text_encoder_1", "text_encoder_2"])
-        prompt_emb_posi = self.encode_prompt(prompt, positive=True)
+        # current mllm does not support vram_management
+        self.load_models_to_device(["text_encoder_1"] if self.vram_management and input_images is None else ["text_encoder_1", "text_encoder_2"])
+        prompt_emb_posi = self.encode_prompt(prompt, positive=True, input_images=input_images)
         if cfg_scale != 1.0:
             prompt_emb_nega = self.encode_prompt(negative_prompt, positive=False)
 
@@ -139,11 +203,16 @@ class HunyuanVideoPipeline(BasePipeline):
             timestep = timestep.unsqueeze(0).to(self.device)
             print(f"Step {progress_id + 1} / {len(self.scheduler.timesteps)}")
 
+            forward_func = lets_dance_hunyuan_video
+            if input_images is not None:
+                latents = torch.concat([image_latents, latents[:, :, 1:, :, :]], dim=2)
+                forward_func = lets_dance_hunyuan_video_i2v
+
             # Inference
             with torch.autocast(device_type=self.device, dtype=self.torch_dtype):
-                noise_pred_posi = lets_dance_hunyuan_video(self.dit, latents, timestep, **prompt_emb_posi, **extra_input, **tea_cache_kwargs)
+                noise_pred_posi = forward_func(self.dit, latents, timestep, **prompt_emb_posi, **extra_input, **tea_cache_kwargs)
                 if cfg_scale != 1.0:
-                    noise_pred_nega = lets_dance_hunyuan_video(self.dit, latents, timestep, **prompt_emb_nega, **extra_input)
+                    noise_pred_nega = forward_func(self.dit, latents, timestep, **prompt_emb_nega, **extra_input)
                     noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
                 else:
                     noise_pred = noise_pred_posi
@@ -163,7 +232,11 @@ class HunyuanVideoPipeline(BasePipeline):
             self.load_models_to_device([] if self.vram_management else ["dit"])
 
             # Scheduler
-            latents = self.scheduler.step(noise_pred, self.scheduler.timesteps[progress_id], latents)
+            if input_images is not None:
+                latents = self.scheduler.step(noise_pred[:, :, 1:, :, :], self.scheduler.timesteps[progress_id], latents[:, :, 1:, :, :])
+                latents = torch.concat([image_latents, latents], dim=2)
+            else:
+                latents = self.scheduler.step(noise_pred, self.scheduler.timesteps[progress_id], latents)
 
         # Decode
         self.load_models_to_device(['vae_decoder'])
@@ -194,7 +267,7 @@ class TeaCache:
         if self.step == 0 or self.step == self.num_inference_steps - 1:
             should_calc = True
             self.accumulated_rel_l1_distance = 0
         else:
             coefficients = [7.33226126e+02, -4.01131952e+02, 6.75869174e+01, -3.14987800e+00, 9.61237896e-02]
             rescale_func = np.poly1d(coefficients)
             self.accumulated_rel_l1_distance += rescale_func(((modulated_inp-self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean()).cpu().item())
@@ -203,14 +276,14 @@ class TeaCache:
         else:
             should_calc = True
             self.accumulated_rel_l1_distance = 0
         self.previous_modulated_input = modulated_inp
         self.step += 1
         if self.step == self.num_inference_steps:
             self.step = 0
         if should_calc:
             self.previous_hidden_states = img.clone()
         return not should_calc
 
     def store(self, hidden_states):
         self.previous_residual = hidden_states - self.previous_hidden_states
         self.previous_hidden_states = None
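TeaCache decides per step whether to skip the transformer and reuse the cached residual: it rescales the relative L1 change of the modulated input through a fitted polynomial, accumulates it, and only recomputes (resetting the accumulator) once the running total crosses `rel_l1_thresh`. A simplified accumulator sketch (the identity `rescale` stands in for `np.poly1d(coefficients)`, and the forced recompute on the *last* step is omitted):

```python
# Skeleton of TeaCache's skip decision: accumulate a rescaled relative-L1
# distance; skip while the total stays under the threshold, recompute and
# reset the accumulator once it crosses.
def tea_cache_steps(distances, rel_l1_thresh, rescale=lambda d: d):
    accumulated = 0.0
    decisions = []
    for step, dist in enumerate(distances):
        if step == 0:                      # first step always computes
            should_calc, accumulated = True, 0.0
        else:
            accumulated += rescale(dist)
            if accumulated < rel_l1_thresh:
                should_calc = False        # reuse cached residual
            else:
                should_calc, accumulated = True, 0.0
        decisions.append(should_calc)
    return decisions

print(tea_cache_steps([0.0, 0.1, 0.1, 0.1, 0.1], rel_l1_thresh=0.25))
# [True, False, False, True, False]
```

A higher threshold skips more steps (faster, less accurate); the `store`/`update` pair in the class above supplies the cached residual used on skipped steps.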
@@ -250,13 +323,70 @@ def lets_dance_hunyuan_video(
         print("TeaCache skip forward.")
         img = tea_cache.update(img)
     else:
+        split_token = int(text_mask.sum(dim=1))
+        txt_len = int(txt.shape[1])
         for block in tqdm(dit.double_blocks, desc="Double stream blocks"):
-            img, txt = block(img, txt, vec, (freqs_cos, freqs_sin))
+            img, txt = block(img, txt, vec, (freqs_cos, freqs_sin), split_token=split_token)
 
         x = torch.concat([img, txt], dim=1)
         for block in tqdm(dit.single_blocks, desc="Single stream blocks"):
-            x = block(x, vec, (freqs_cos, freqs_sin))
+            x = block(x, vec, (freqs_cos, freqs_sin), txt_len=txt_len, split_token=split_token)
-        img = x[:, :-256]
+        img = x[:, :-txt_len]
+
+    if tea_cache is not None:
+        tea_cache.store(img)
+    img = dit.final_layer(img, vec)
+    img = dit.unpatchify(img, T=T//1, H=H//2, W=W//2)
+    return img
+
+
+def lets_dance_hunyuan_video_i2v(
+    dit: HunyuanVideoDiT,
+    x: torch.Tensor,
+    t: torch.Tensor,
+    prompt_emb: torch.Tensor = None,
+    text_mask: torch.Tensor = None,
+    pooled_prompt_emb: torch.Tensor = None,
+    freqs_cos: torch.Tensor = None,
+    freqs_sin: torch.Tensor = None,
+    guidance: torch.Tensor = None,
+    tea_cache: TeaCache = None,
+    **kwargs
+):
+    B, C, T, H, W = x.shape
+    # Uncomment below to keep same as official implementation
+    # guidance = guidance.to(dtype=torch.float32).to(torch.bfloat16)
+    vec = dit.time_in(t, dtype=torch.bfloat16)
+    vec_2 = dit.vector_in(pooled_prompt_emb)
+    vec = vec + vec_2
+    vec = vec + dit.guidance_in(guidance * 1000., dtype=torch.bfloat16)
+
+    token_replace_vec = dit.time_in(torch.zeros_like(t), dtype=torch.bfloat16)
+    tr_token = (H // 2) * (W // 2)
+    token_replace_vec = token_replace_vec + vec_2
+
+    img = dit.img_in(x)
+    txt = dit.txt_in(prompt_emb, t, text_mask)
+
+    # TeaCache
+    if tea_cache is not None:
+        tea_cache_update = tea_cache.check(dit, img, vec)
+    else:
+        tea_cache_update = False
+
+    if tea_cache_update:
+        print("TeaCache skip forward.")
+        img = tea_cache.update(img)
+    else:
+        split_token = int(text_mask.sum(dim=1))
+        txt_len = int(txt.shape[1])
+        for block in tqdm(dit.double_blocks, desc="Double stream blocks"):
+            img, txt = block(img, txt, vec, (freqs_cos, freqs_sin), token_replace_vec, tr_token, split_token)
+
+        x = torch.concat([img, txt], dim=1)
+        for block in tqdm(dit.single_blocks, desc="Single stream blocks"):
+            x = block(x, vec, (freqs_cos, freqs_sin), txt_len, token_replace_vec, tr_token, split_token)
+        img = x[:, :-txt_len]
 
     if tea_cache is not None:
         tea_cache.store(img)
@@ -1,3 +1,4 @@
+import types
 from ..models import ModelManager
 from ..models.wan_video_dit import WanModel
 from ..models.wan_video_text_encoder import WanTextEncoder
@@ -11,11 +12,13 @@ from einops import rearrange
 import numpy as np
 from PIL import Image
 from tqdm import tqdm
+from typing import Optional
 
 from ..vram_management import enable_vram_management, AutoWrappedModule, AutoWrappedLinear
 from ..models.wan_video_text_encoder import T5RelativeEmbedding, T5LayerNorm
-from ..models.wan_video_dit import WanLayerNorm, WanRMSNorm
+from ..models.wan_video_dit import RMSNorm, sinusoidal_embedding_1d
 from ..models.wan_video_vae import RMS_norm, CausalConv3d, Upsample
+from ..models.wan_video_motion_controller import WanMotionControllerModel
 
 
@@ -29,9 +32,11 @@ class WanVideoPipeline(BasePipeline):
         self.image_encoder: WanImageEncoder = None
         self.dit: WanModel = None
         self.vae: WanVideoVAE = None
-        self.model_names = ['text_encoder', 'dit', 'vae']
+        self.motion_controller: WanMotionControllerModel = None
+        self.model_names = ['text_encoder', 'dit', 'vae', 'image_encoder', 'motion_controller']
         self.height_division_factor = 16
         self.width_division_factor = 16
+        self.use_unified_sequence_parallel = False
 
 
     def enable_vram_management(self, num_persistent_param_in_dit=None):
@@ -60,8 +65,7 @@ class WanVideoPipeline(BasePipeline):
                 torch.nn.Linear: AutoWrappedLinear,
                 torch.nn.Conv3d: AutoWrappedModule,
                 torch.nn.LayerNorm: AutoWrappedModule,
-                WanLayerNorm: AutoWrappedModule,
-                WanRMSNorm: AutoWrappedModule,
+                RMSNorm: AutoWrappedModule,
             },
             module_config = dict(
                 offload_dtype=dtype,
@@ -116,7 +120,23 @@ class WanVideoPipeline(BasePipeline):
                 offload_device="cpu",
                 onload_dtype=dtype,
                 onload_device="cpu",
-                computation_dtype=self.torch_dtype,
+                computation_dtype=dtype,
+                computation_device=self.device,
+            ),
+        )
+        if self.motion_controller is not None:
+            dtype = next(iter(self.motion_controller.parameters())).dtype
+            enable_vram_management(
+                self.motion_controller,
+                module_map = {
+                    torch.nn.Linear: AutoWrappedLinear,
+                },
+                module_config = dict(
+                    offload_dtype=dtype,
+                    offload_device="cpu",
+                    onload_dtype=dtype,
+                    onload_device="cpu",
+                    computation_dtype=dtype,
                 computation_device=self.device,
             ),
         )
@@ -132,14 +152,24 @@ class WanVideoPipeline(BasePipeline):
         self.dit = model_manager.fetch_model("wan_video_dit")
         self.vae = model_manager.fetch_model("wan_video_vae")
         self.image_encoder = model_manager.fetch_model("wan_video_image_encoder")
+        self.motion_controller = model_manager.fetch_model("wan_video_motion_controller")


     @staticmethod
-    def from_model_manager(model_manager: ModelManager, torch_dtype=None, device=None):
+    def from_model_manager(model_manager: ModelManager, torch_dtype=None, device=None, use_usp=False):
         if device is None: device = model_manager.device
         if torch_dtype is None: torch_dtype = model_manager.torch_dtype
         pipe = WanVideoPipeline(device=device, torch_dtype=torch_dtype)
         pipe.fetch_models(model_manager)
+        if use_usp:
+            from xfuser.core.distributed import get_sequence_parallel_world_size
+            from ..distributed.xdit_context_parallel import usp_attn_forward, usp_dit_forward
+
+            for block in pipe.dit.blocks:
+                block.self_attn.forward = types.MethodType(usp_attn_forward, block.self_attn)
+            pipe.dit.forward = types.MethodType(usp_dit_forward, pipe.dit)
+            pipe.sp_size = get_sequence_parallel_world_size()
+            pipe.use_unified_sequence_parallel = True
         return pipe

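The `use_usp` branch above rebinds `forward` on live module instances with `types.MethodType`. A minimal standalone sketch of that patching pattern (the `Block` class and `parallel_forward` names here are hypothetical stand-ins, not the repo's `usp_attn_forward`/`usp_dit_forward`):

```python
import types

# Bind a replacement forward onto ONE module instance, leaving the
# class and every other instance untouched -- the same mechanism
# from_model_manager uses to install the sequence-parallel forwards.
class Block:
    def forward(self, x):
        return x + 1

def parallel_forward(self, x):
    # stand-in for a sequence-parallel forward implementation
    return x * 2

blk = Block()
blk.forward = types.MethodType(parallel_forward, blk)
```

After patching, `blk.forward(3)` uses the replacement while a fresh `Block()` still uses the original method.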
@@ -148,22 +178,51 @@ class WanVideoPipeline(BasePipeline):


     def encode_prompt(self, prompt, positive=True):
-        prompt_emb = self.prompter.encode_prompt(prompt, positive=positive)
+        prompt_emb = self.prompter.encode_prompt(prompt, positive=positive, device=self.device)
         return {"context": prompt_emb}


-    def encode_image(self, image, num_frames, height, width):
-        with torch.amp.autocast(dtype=torch.bfloat16, device_type=torch.device(self.device).type):
-            image = self.preprocess_image(image.resize((width, height))).to(self.device)
-            clip_context = self.image_encoder.encode_image([image])
-            msk = torch.ones(1, num_frames, height//8, width//8, device=self.device)
-            msk[:, 1:] = 0
-            msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1)
-            msk = msk.view(1, msk.shape[1] // 4, 4, height//8, width//8)
-            msk = msk.transpose(1, 2)[0]
-            y = self.vae.encode([torch.concat([image.transpose(0, 1), torch.zeros(3, num_frames-1, height, width).to(image.device)], dim=1)], device=self.device)[0]
-            y = torch.concat([msk, y])
-        return {"clip_fea": clip_context, "y": [y]}
+    def encode_image(self, image, end_image, num_frames, height, width):
+        image = self.preprocess_image(image.resize((width, height))).to(self.device)
+        clip_context = self.image_encoder.encode_image([image])
+        msk = torch.ones(1, num_frames, height//8, width//8, device=self.device)
+        msk[:, 1:] = 0
+        if end_image is not None:
+            end_image = self.preprocess_image(end_image.resize((width, height))).to(self.device)
+            vae_input = torch.concat([image.transpose(0,1), torch.zeros(3, num_frames-2, height, width).to(image.device), end_image.transpose(0,1)],dim=1)
+            msk[:, -1:] = 1
+        else:
+            vae_input = torch.concat([image.transpose(0, 1), torch.zeros(3, num_frames-1, height, width).to(image.device)], dim=1)
+
+        msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1)
+        msk = msk.view(1, msk.shape[1] // 4, 4, height//8, width//8)
+        msk = msk.transpose(1, 2)[0]
+
+        y = self.vae.encode([vae_input.to(dtype=self.torch_dtype, device=self.device)], device=self.device)[0]
+        y = torch.concat([msk, y])
+        y = y.unsqueeze(0)
+        clip_context = clip_context.to(dtype=self.torch_dtype, device=self.device)
+        y = y.to(dtype=self.torch_dtype, device=self.device)
+        return {"clip_feature": clip_context, "y": y}
+
+
+    def encode_control_video(self, control_video, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)):
+        control_video = self.preprocess_images(control_video)
+        control_video = torch.stack(control_video, dim=2).to(dtype=self.torch_dtype, device=self.device)
+        latents = self.encode_video(control_video, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=self.torch_dtype, device=self.device)
+        return latents
+
+
+    def prepare_controlnet_kwargs(self, control_video, num_frames, height, width, clip_feature=None, y=None, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)):
+        if control_video is not None:
+            control_latents = self.encode_control_video(control_video, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
+            if clip_feature is None or y is None:
+                clip_feature = torch.zeros((1, 257, 1280), dtype=self.torch_dtype, device=self.device)
+                y = torch.zeros((1, 16, (num_frames - 1) // 4 + 1, height//8, width//8), dtype=self.torch_dtype, device=self.device)
+            else:
+                y = y[:, -16:]
+            y = torch.concat([control_latents, y], dim=1)
+        return {"clip_feature": clip_feature, "y": y}


     def tensor2video(self, frames):
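The mask gymnastics in `encode_image` above can be followed in isolation with toy sizes: frame 0 of the mask is repeated 4 times, then frames are grouped in blocks of 4 so the mask lines up with the VAE's 4x temporal compression to `(num_frames - 1) // 4 + 1` latent frames. A numpy sketch (shapes are illustrative, not the pipeline's real resolutions):

```python
import numpy as np

# Toy-sized walkthrough of the conditioning-mask reshape: 17 frames
# with only frame 0 marked as conditioned.
num_frames, h8, w8 = 17, 4, 4
msk = np.ones((1, num_frames, h8, w8))
msk[:, 1:] = 0                                        # keep only the first frame
msk = np.concatenate([np.repeat(msk[:, 0:1], 4, axis=1), msk[:, 1:]], axis=1)
msk = msk.reshape(1, msk.shape[1] // 4, 4, h8, w8)    # group frames by 4
msk = msk.transpose(0, 2, 1, 3, 4)[0]                 # -> (4, 5, h8, w8)
```

Only the first temporal group (the 4 copies of frame 0) stays non-zero, matching the single conditioned latent frame.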
@@ -174,19 +233,26 @@ class WanVideoPipeline(BasePipeline):


     def prepare_extra_input(self, latents=None):
-        return {"seq_len": latents.shape[2] * latents.shape[3] * latents.shape[4] // 4}
+        return {}


     def encode_video(self, input_video, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)):
-        with torch.amp.autocast(dtype=torch.bfloat16, device_type=torch.device(self.device).type):
-            latents = self.vae.encode(input_video, device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
+        latents = self.vae.encode(input_video, device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
         return latents


     def decode_video(self, latents, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)):
-        with torch.amp.autocast(dtype=torch.bfloat16, device_type=torch.device(self.device).type):
-            frames = self.vae.decode(latents, device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
+        frames = self.vae.decode(latents, device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
         return frames


+    def prepare_unified_sequence_parallel(self):
+        return {"use_unified_sequence_parallel": self.use_unified_sequence_parallel}
+
+
+    def prepare_motion_bucket_id(self, motion_bucket_id):
+        motion_bucket_id = torch.Tensor((motion_bucket_id,)).to(dtype=self.torch_dtype, device=self.device)
+        return {"motion_bucket_id": motion_bucket_id}
+
+
     @torch.no_grad()
@@ -195,7 +261,9 @@ class WanVideoPipeline(BasePipeline):
         prompt,
         negative_prompt="",
         input_image=None,
+        end_image=None,
         input_video=None,
+        control_video=None,
         denoising_strength=1.0,
         seed=None,
         rand_device="cpu",
@@ -205,9 +273,12 @@ class WanVideoPipeline(BasePipeline):
         cfg_scale=5.0,
         num_inference_steps=50,
         sigma_shift=5.0,
+        motion_bucket_id=None,
         tiled=True,
         tile_size=(30, 52),
         tile_stride=(15, 26),
+        tea_cache_l1_thresh=None,
+        tea_cache_model_id="",
         progress_bar_cmd=tqdm,
         progress_bar_st=None,
     ):
@@ -221,15 +292,16 @@ class WanVideoPipeline(BasePipeline):
         tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride}

         # Scheduler
-        self.scheduler.set_timesteps(num_inference_steps, denoising_strength, shift=sigma_shift)
+        self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength, shift=sigma_shift)

         # Initialize noise
-        noise = self.generate_noise((1, 16, (num_frames - 1) // 4 + 1, height//8, width//8), seed=seed, device=rand_device, dtype=torch.float32).to(self.device)
+        noise = self.generate_noise((1, 16, (num_frames - 1) // 4 + 1, height//8, width//8), seed=seed, device=rand_device, dtype=torch.float32)
+        noise = noise.to(dtype=self.torch_dtype, device=self.device)
         if input_video is not None:
             self.load_models_to_device(['vae'])
             input_video = self.preprocess_images(input_video)
-            input_video = torch.stack(input_video, dim=2)
-            latents = self.encode_video(input_video, **tiler_kwargs).to(dtype=noise.dtype, device=noise.device)
+            input_video = torch.stack(input_video, dim=2).to(dtype=self.torch_dtype, device=self.device)
+            latents = self.encode_video(input_video, **tiler_kwargs).to(dtype=self.torch_dtype, device=self.device)
             latents = self.scheduler.add_noise(latents, noise, timestep=self.scheduler.timesteps[0])
         else:
             latents = noise
@@ -243,29 +315,56 @@ class WanVideoPipeline(BasePipeline):
         # Encode image
         if input_image is not None and self.image_encoder is not None:
             self.load_models_to_device(["image_encoder", "vae"])
-            image_emb = self.encode_image(input_image, num_frames, height, width)
+            image_emb = self.encode_image(input_image, end_image, num_frames, height, width)
         else:
             image_emb = {}

+        # ControlNet
+        if control_video is not None:
+            self.load_models_to_device(["image_encoder", "vae"])
+            image_emb = self.prepare_controlnet_kwargs(control_video, num_frames, height, width, **image_emb, **tiler_kwargs)
+
+        # Motion Controller
+        if self.motion_controller is not None and motion_bucket_id is not None:
+            motion_kwargs = self.prepare_motion_bucket_id(motion_bucket_id)
+        else:
+            motion_kwargs = {}
+
         # Extra input
         extra_input = self.prepare_extra_input(latents)

+        # TeaCache
+        tea_cache_posi = {"tea_cache": TeaCache(num_inference_steps, rel_l1_thresh=tea_cache_l1_thresh, model_id=tea_cache_model_id) if tea_cache_l1_thresh is not None else None}
+        tea_cache_nega = {"tea_cache": TeaCache(num_inference_steps, rel_l1_thresh=tea_cache_l1_thresh, model_id=tea_cache_model_id) if tea_cache_l1_thresh is not None else None}
+
+        # Unified Sequence Parallel
+        usp_kwargs = self.prepare_unified_sequence_parallel()
+
         # Denoise
-        self.load_models_to_device(["dit"])
-        with torch.amp.autocast(dtype=torch.bfloat16, device_type=torch.device(self.device).type):
-            for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
-                timestep = timestep.unsqueeze(0).to(dtype=torch.float32, device=self.device)
+        self.load_models_to_device(["dit", "motion_controller"])
+        for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
+            timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device)

             # Inference
-                noise_pred_posi = self.dit(latents, timestep=timestep, **prompt_emb_posi, **image_emb, **extra_input)
-                if cfg_scale != 1.0:
-                    noise_pred_nega = self.dit(latents, timestep=timestep, **prompt_emb_nega, **image_emb, **extra_input)
-                    noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
-                else:
-                    noise_pred = noise_pred_posi
+            noise_pred_posi = model_fn_wan_video(
+                self.dit, motion_controller=self.motion_controller,
+                x=latents, timestep=timestep,
+                **prompt_emb_posi, **image_emb, **extra_input,
+                **tea_cache_posi, **usp_kwargs, **motion_kwargs
+            )
+            if cfg_scale != 1.0:
+                noise_pred_nega = model_fn_wan_video(
+                    self.dit, motion_controller=self.motion_controller,
+                    x=latents, timestep=timestep,
+                    **prompt_emb_nega, **image_emb, **extra_input,
+                    **tea_cache_nega, **usp_kwargs, **motion_kwargs
+                )
+                noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
+            else:
+                noise_pred = noise_pred_posi

             # Scheduler
             latents = self.scheduler.step(noise_pred, self.scheduler.timesteps[progress_id], latents)

         # Decode
         self.load_models_to_device(['vae'])
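The classifier-free guidance arithmetic in the denoising loop above is worth seeing on its own: the final prediction extrapolates from the negative-prompt prediction toward the positive one by `cfg_scale`, and `cfg_scale == 1.0` skips the second forward pass entirely. A minimal sketch:

```python
import numpy as np

# Classifier-free guidance combination, as used in the loop above.
def cfg_combine(noise_pred_posi, noise_pred_nega, cfg_scale):
    if cfg_scale != 1.0:
        return noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
    return noise_pred_posi
```

With `cfg_scale=5.0`, a gap of 1.0 between the two predictions is amplified fivefold beyond the negative prediction.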
@@ -274,3 +373,121 @@ class WanVideoPipeline(BasePipeline):
         frames = self.tensor2video(frames[0])

         return frames
+
+
+
+class TeaCache:
+    def __init__(self, num_inference_steps, rel_l1_thresh, model_id):
+        self.num_inference_steps = num_inference_steps
+        self.step = 0
+        self.accumulated_rel_l1_distance = 0
+        self.previous_modulated_input = None
+        self.rel_l1_thresh = rel_l1_thresh
+        self.previous_residual = None
+        self.previous_hidden_states = None
+
+        self.coefficients_dict = {
+            "Wan2.1-T2V-1.3B": [-5.21862437e+04, 9.23041404e+03, -5.28275948e+02, 1.36987616e+01, -4.99875664e-02],
+            "Wan2.1-T2V-14B": [-3.03318725e+05, 4.90537029e+04, -2.65530556e+03, 5.87365115e+01, -3.15583525e-01],
+            "Wan2.1-I2V-14B-480P": [2.57151496e+05, -3.54229917e+04, 1.40286849e+03, -1.35890334e+01, 1.32517977e-01],
+            "Wan2.1-I2V-14B-720P": [ 8.10705460e+03, 2.13393892e+03, -3.72934672e+02, 1.66203073e+01, -4.17769401e-02],
+        }
+        if model_id not in self.coefficients_dict:
+            supported_model_ids = ", ".join([i for i in self.coefficients_dict])
+            raise ValueError(f"{model_id} is not a supported TeaCache model id. Please choose a valid model id in ({supported_model_ids}).")
+        self.coefficients = self.coefficients_dict[model_id]
+
+    def check(self, dit: WanModel, x, t_mod):
+        modulated_inp = t_mod.clone()
+        if self.step == 0 or self.step == self.num_inference_steps - 1:
+            should_calc = True
+            self.accumulated_rel_l1_distance = 0
+        else:
+            coefficients = self.coefficients
+            rescale_func = np.poly1d(coefficients)
+            self.accumulated_rel_l1_distance += rescale_func(((modulated_inp-self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean()).cpu().item())
+            if self.accumulated_rel_l1_distance < self.rel_l1_thresh:
+                should_calc = False
+            else:
+                should_calc = True
+                self.accumulated_rel_l1_distance = 0
+        self.previous_modulated_input = modulated_inp
+        self.step += 1
+        if self.step == self.num_inference_steps:
+            self.step = 0
+        if should_calc:
+            self.previous_hidden_states = x.clone()
+        return not should_calc
+
+    def store(self, hidden_states):
+        self.previous_residual = hidden_states - self.previous_hidden_states
+        self.previous_hidden_states = None
+
+    def update(self, hidden_states):
+        hidden_states = hidden_states + self.previous_residual
+        return hidden_states
+
+
+
+def model_fn_wan_video(
+    dit: WanModel,
+    motion_controller: WanMotionControllerModel = None,
+    x: torch.Tensor = None,
+    timestep: torch.Tensor = None,
+    context: torch.Tensor = None,
+    clip_feature: Optional[torch.Tensor] = None,
+    y: Optional[torch.Tensor] = None,
+    tea_cache: TeaCache = None,
+    use_unified_sequence_parallel: bool = False,
+    motion_bucket_id: Optional[torch.Tensor] = None,
+    **kwargs,
+):
+    if use_unified_sequence_parallel:
+        import torch.distributed as dist
+        from xfuser.core.distributed import (get_sequence_parallel_rank,
+                                             get_sequence_parallel_world_size,
+                                             get_sp_group)
+
+    t = dit.time_embedding(sinusoidal_embedding_1d(dit.freq_dim, timestep))
+    t_mod = dit.time_projection(t).unflatten(1, (6, dit.dim))
+    if motion_bucket_id is not None and motion_controller is not None:
+        t_mod = t_mod + motion_controller(motion_bucket_id).unflatten(1, (6, dit.dim))
+    context = dit.text_embedding(context)
+
+    if dit.has_image_input:
+        x = torch.cat([x, y], dim=1)  # (b, c_x + c_y, f, h, w)
+        clip_embdding = dit.img_emb(clip_feature)
+        context = torch.cat([clip_embdding, context], dim=1)
+
+    x, (f, h, w) = dit.patchify(x)
+
+    freqs = torch.cat([
+        dit.freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
+        dit.freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
+        dit.freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
+    ], dim=-1).reshape(f * h * w, 1, -1).to(x.device)
+
+    # TeaCache
+    if tea_cache is not None:
+        tea_cache_update = tea_cache.check(dit, x, t_mod)
+    else:
+        tea_cache_update = False
+
+    # blocks
+    if use_unified_sequence_parallel:
+        if dist.is_initialized() and dist.get_world_size() > 1:
+            x = torch.chunk(x, get_sequence_parallel_world_size(), dim=1)[get_sequence_parallel_rank()]
+    if tea_cache_update:
+        x = tea_cache.update(x)
+    else:
+        for block in dit.blocks:
+            x = block(x, context, t_mod, freqs)
+        if tea_cache is not None:
+            tea_cache.store(x)
+
+    x = dit.head(x, t)
+    if use_unified_sequence_parallel:
+        if dist.is_initialized() and dist.get_world_size() > 1:
+            x = get_sp_group().all_gather(x, dim=1)
+    x = dit.unpatchify(x, (f, h, w))
+    return x
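The accumulate-and-skip rule in `TeaCache.check` above can be simulated standalone: each step's relative L1 change of the modulated input is rescaled by a polynomial and accumulated; while the accumulator stays below the threshold, the transformer blocks are skipped and the cached residual is reused. The coefficients in this sketch are illustrative (identity rescaling), not the tuned per-model values from the diff:

```python
import numpy as np

# Decide, per denoising step, whether to run the blocks or reuse the
# cached residual -- mirroring the control flow of TeaCache.check.
def plan_steps(rel_changes, thresh, coeffs=(1.0, 0.0)):
    rescale = np.poly1d(coeffs)          # identity polynomial by default
    acc, plan = 0.0, []
    for step, change in enumerate(rel_changes):
        if step == 0 or step == len(rel_changes) - 1:
            acc = 0.0
            plan.append("compute")       # first and last steps always compute
            continue
        acc += float(rescale(change))
        if acc < thresh:
            plan.append("skip")          # reuse cached residual
        else:
            acc = 0.0
            plan.append("compute")
    return plan
```

Small successive changes are allowed to accumulate across several skipped steps before one full forward pass resets the budget.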
@@ -1,8 +1,9 @@
 from .base_prompter import BasePrompter
 from ..models.sd3_text_encoder import SD3TextEncoder1
-from ..models.hunyuan_video_text_encoder import HunyuanVideoLLMEncoder
-from transformers import CLIPTokenizer, LlamaTokenizerFast
+from ..models.hunyuan_video_text_encoder import HunyuanVideoLLMEncoder, HunyuanVideoMLLMEncoder
+from transformers import CLIPTokenizer, LlamaTokenizerFast, CLIPImageProcessor
 import os, torch
+from typing import Union

 PROMPT_TEMPLATE_ENCODE = (
     "<|start_header_id|>system<|end_header_id|>\n\nDescribe the image by detailing the color, shape, size, texture, "
@@ -18,6 +19,24 @@ PROMPT_TEMPLATE_ENCODE_VIDEO = (
     "5. camera angles, movements, and transitions used in the video:<|eot_id|>"
     "<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>")

+PROMPT_TEMPLATE_ENCODE_I2V = (
+    "<|start_header_id|>system<|end_header_id|>\n\n<image>\nDescribe the image by detailing the color, shape, size, texture, "
+    "quantity, text, spatial relationships of the objects and background:<|eot_id|>"
+    "<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>"
+    "<|start_header_id|>assistant<|end_header_id|>\n\n"
+)
+
+PROMPT_TEMPLATE_ENCODE_VIDEO_I2V = (
+    "<|start_header_id|>system<|end_header_id|>\n\n<image>\nDescribe the video by detailing the following aspects according to the reference image: "
+    "1. The main content and theme of the video."
+    "2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects."
+    "3. Actions, events, behaviors temporal relationships, physical movement changes of the objects."
+    "4. background environment, light, style and atmosphere."
+    "5. camera angles, movements, and transitions used in the video:<|eot_id|>\n\n"
+    "<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>"
+    "<|start_header_id|>assistant<|end_header_id|>\n\n"
+)
+
 PROMPT_TEMPLATE = {
     "dit-llm-encode": {
         "template": PROMPT_TEMPLATE_ENCODE,
@@ -27,6 +46,22 @@ PROMPT_TEMPLATE = {
         "template": PROMPT_TEMPLATE_ENCODE_VIDEO,
         "crop_start": 95,
     },
+    "dit-llm-encode-i2v": {
+        "template": PROMPT_TEMPLATE_ENCODE_I2V,
+        "crop_start": 36,
+        "image_emb_start": 5,
+        "image_emb_end": 581,
+        "image_emb_len": 576,
+        "double_return_token_id": 271
+    },
+    "dit-llm-encode-video-i2v": {
+        "template": PROMPT_TEMPLATE_ENCODE_VIDEO_I2V,
+        "crop_start": 103,
+        "image_emb_start": 5,
+        "image_emb_end": 581,
+        "image_emb_len": 576,
+        "double_return_token_id": 271
+    },
 }

 NEGATIVE_PROMPT = "Aerial view, aerial view, overexposed, low quality, deformation, a poor composition, bad hands, bad teeth, bad eyes, bad limbs, distortion"
@@ -56,9 +91,20 @@ class HunyuanVideoPrompter(BasePrompter):
         self.prompt_template = PROMPT_TEMPLATE['dit-llm-encode']
         self.prompt_template_video = PROMPT_TEMPLATE['dit-llm-encode-video']

-    def fetch_models(self, text_encoder_1: SD3TextEncoder1 = None, text_encoder_2: HunyuanVideoLLMEncoder = None):
+    def fetch_models(self,
+                     text_encoder_1: SD3TextEncoder1 = None,
+                     text_encoder_2: Union[HunyuanVideoLLMEncoder, HunyuanVideoMLLMEncoder] = None):
         self.text_encoder_1 = text_encoder_1
         self.text_encoder_2 = text_encoder_2
+        if isinstance(text_encoder_2, HunyuanVideoMLLMEncoder):
+            # processor
+            # TODO: may need to replace processor with local implementation
+            base_path = os.path.dirname(os.path.dirname(__file__))
+            tokenizer_2_path = os.path.join(base_path, "tokenizer_configs/hunyuan_video/tokenizer_2")
+            self.processor = CLIPImageProcessor.from_pretrained(tokenizer_2_path)
+            # template
+            self.prompt_template = PROMPT_TEMPLATE['dit-llm-encode-i2v']
+            self.prompt_template_video = PROMPT_TEMPLATE['dit-llm-encode-video-i2v']

     def apply_text_to_template(self, text, template):
         assert isinstance(template, str)
@@ -107,8 +153,89 @@ class HunyuanVideoPrompter(BasePrompter):

         return last_hidden_state, attention_mask

+    def encode_prompt_using_mllm(self,
+                                 prompt,
+                                 images,
+                                 max_length,
+                                 device,
+                                 crop_start,
+                                 hidden_state_skip_layer=2,
+                                 use_attention_mask=True,
+                                 image_embed_interleave=4):
+        image_outputs = self.processor(images, return_tensors="pt")["pixel_values"].to(device)
+        max_length += crop_start
+        inputs = self.tokenizer_2(prompt,
+                                  return_tensors="pt",
+                                  padding="max_length",
+                                  max_length=max_length,
+                                  truncation=True)
+        input_ids = inputs.input_ids.to(device)
+        attention_mask = inputs.attention_mask.to(device)
+        last_hidden_state = self.text_encoder_2(input_ids=input_ids,
+                                                attention_mask=attention_mask,
+                                                hidden_state_skip_layer=hidden_state_skip_layer,
+                                                pixel_values=image_outputs)
+
+        text_crop_start = (crop_start - 1 + self.prompt_template_video.get("image_emb_len", 576))
+        image_crop_start = self.prompt_template_video.get("image_emb_start", 5)
+        image_crop_end = self.prompt_template_video.get("image_emb_end", 581)
+        batch_indices, last_double_return_token_indices = torch.where(
+            input_ids == self.prompt_template_video.get("double_return_token_id", 271))
+        if last_double_return_token_indices.shape[0] == 3:
+            # in case the prompt is too long
+            last_double_return_token_indices = torch.cat((
+                last_double_return_token_indices,
+                torch.tensor([input_ids.shape[-1]]),
+            ))
+            batch_indices = torch.cat((batch_indices, torch.tensor([0])))
+        last_double_return_token_indices = (last_double_return_token_indices.reshape(input_ids.shape[0], -1)[:, -1])
+        batch_indices = batch_indices.reshape(input_ids.shape[0], -1)[:, -1]
+        assistant_crop_start = (last_double_return_token_indices - 1 + self.prompt_template_video.get("image_emb_len", 576) - 4)
+        assistant_crop_end = (last_double_return_token_indices - 1 + self.prompt_template_video.get("image_emb_len", 576))
+        attention_mask_assistant_crop_start = (last_double_return_token_indices - 4)
+        attention_mask_assistant_crop_end = last_double_return_token_indices
+        text_last_hidden_state = []
+        text_attention_mask = []
+        image_last_hidden_state = []
+        image_attention_mask = []
+        for i in range(input_ids.shape[0]):
+            text_last_hidden_state.append(
+                torch.cat([
+                    last_hidden_state[i, text_crop_start:assistant_crop_start[i].item()],
+                    last_hidden_state[i, assistant_crop_end[i].item():],
+                ]))
+            text_attention_mask.append(
+                torch.cat([
+                    attention_mask[
+                        i,
+                        crop_start:attention_mask_assistant_crop_start[i].item(),
+                    ],
+                    attention_mask[i, attention_mask_assistant_crop_end[i].item():],
+                ]) if use_attention_mask else None)
+            image_last_hidden_state.append(last_hidden_state[i, image_crop_start:image_crop_end])
+            image_attention_mask.append(
+                torch.ones(image_last_hidden_state[-1].shape[0]).to(last_hidden_state.device).
+                to(attention_mask.dtype) if use_attention_mask else None)
+
+        text_last_hidden_state = torch.stack(text_last_hidden_state)
+        text_attention_mask = torch.stack(text_attention_mask)
+        image_last_hidden_state = torch.stack(image_last_hidden_state)
+        image_attention_mask = torch.stack(image_attention_mask)
+
+        image_last_hidden_state = image_last_hidden_state[:, ::image_embed_interleave, :]
+        image_attention_mask = image_attention_mask[:, ::image_embed_interleave]
+
+        assert (text_last_hidden_state.shape[0] == text_attention_mask.shape[0] and
+                image_last_hidden_state.shape[0] == image_attention_mask.shape[0])
+
+        last_hidden_state = torch.cat([image_last_hidden_state, text_last_hidden_state], dim=1)
+        attention_mask = torch.cat([image_attention_mask, text_attention_mask], dim=1)
+
+        return last_hidden_state, attention_mask
+
     def encode_prompt(self,
                       prompt,
+                      images=None,
                       positive=True,
                       device="cuda",
                       clip_sequence_length=77,
@@ -116,7 +243,8 @@ class HunyuanVideoPrompter(BasePrompter):
                       data_type='video',
                       use_template=True,
                       hidden_state_skip_layer=2,
-                      use_attention_mask=True):
+                      use_attention_mask=True,
+                      image_embed_interleave=4):

         prompt = self.process_prompt(prompt, positive=positive)

@@ -136,8 +264,12 @@ class HunyuanVideoPrompter(BasePrompter):
         pooled_prompt_emb = self.encode_prompt_using_clip(prompt, clip_sequence_length, device)

         # LLM
-        prompt_emb, attention_mask = self.encode_prompt_using_llm(prompt_formated, llm_sequence_length, device, crop_start,
-                                                                  hidden_state_skip_layer, use_attention_mask)
+        if images is None:
+            prompt_emb, attention_mask = self.encode_prompt_using_llm(prompt_formated, llm_sequence_length, device, crop_start,
+                                                                      hidden_state_skip_layer, use_attention_mask)
+        else:
+            prompt_emb, attention_mask = self.encode_prompt_using_mllm(prompt_formated, images, llm_sequence_length, device,
+                                                                       crop_start, hidden_state_skip_layer, use_attention_mask,
+                                                                       image_embed_interleave)

         return prompt_emb, pooled_prompt_emb, attention_mask
|
|||||||
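The branch introduced in this hunk is a plain presence-of-images dispatch. A minimal standalone sketch of the pattern (the encoder functions here are hypothetical stand-ins, not the repo's `encode_prompt_using_llm`/`encode_prompt_using_mllm`):

```python
def encode_with_llm(prompt):
    # Stand-in for the text-only LLM encoder.
    return f"llm({prompt})"

def encode_with_mllm(prompt, images):
    # Stand-in for the multimodal (MLLM) encoder.
    return f"mllm({prompt}, {len(images)} images)"

def encode_prompt(prompt, images=None):
    # Dispatch on whether reference images were supplied,
    # mirroring the `if images is None` branch in the diff.
    if images is None:
        return encode_with_llm(prompt)
    return encode_with_mllm(prompt, images)

print(encode_prompt("a cat"))           # llm(a cat)
print(encode_prompt("a cat", ["img"]))  # mllm(a cat, 1 images)
```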
@@ -104,5 +104,6 @@ class WanPrompter(BasePrompter):
         mask = mask.to(device)
         seq_lens = mask.gt(0).sum(dim=1).long()
         prompt_emb = self.text_encoder(ids, mask)
-        prompt_emb = [u[:v] for u, v in zip(prompt_emb, seq_lens)]
+        for i, v in enumerate(seq_lens):
+            prompt_emb[:, v:] = 0
         return prompt_emb
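The replaced line used to truncate each embedding to its valid length; the new loop instead zeroes the padded tail in place so every embedding keeps one common shape. The intent can be sketched with plain lists (hypothetical shapes; the real code works on torch tensors, and this sketch zeroes row `i` specifically):

```python
# A (batch=2, seq_len=5, dim=3) "tensor" as nested lists, all ones for illustration.
emb = [[[1.0] * 3 for _ in range(5)] for _ in range(2)]
seq_lens = [3, 5]  # valid (non-padded) length of each sequence

for i, v in enumerate(seq_lens):
    for t in range(v, 5):
        emb[i][t] = [0.0] * 3  # zero every position past the valid length

total_0 = sum(sum(pos) for pos in emb[0])
total_1 = sum(sum(pos) for pos in emb[1])
print(total_0, total_1)  # 9.0 15.0
```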
@@ -37,7 +37,7 @@ class FlowMatchScheduler():
         self.linear_timesteps_weights = bsmntw_weighing


-    def step(self, model_output, timestep, sample, to_final=False):
+    def step(self, model_output, timestep, sample, to_final=False, **kwargs):
         if isinstance(timestep, torch.Tensor):
             timestep = timestep.cpu()
         timestep_id = torch.argmin((self.timesteps - timestep).abs())
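Accepting `**kwargs` in `step` is a small compatibility change: callers can pass options that only some schedulers understand, and the others silently absorb them. A sketch of why that matters at a shared call site (the scheduler classes here are illustrative, not the repo's):

```python
class SchedulerA:
    # Ignores extra options thanks to **kwargs.
    def step(self, model_output, timestep, sample, to_final=False, **kwargs):
        return sample + model_output

class SchedulerB:
    # Actually consumes the extra `noise` option.
    def step(self, model_output, timestep, sample, to_final=False, noise=0, **kwargs):
        return sample + model_output + noise

# One generic call site works with both schedulers.
out_a = SchedulerA().step(1, 0, 10, noise=5)
out_b = SchedulerB().step(1, 0, 10, noise=5)
print(out_a, out_b)  # 11 16
```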
@@ -0,0 +1,45 @@
+{
+    "_valid_processor_keys": [
+        "images",
+        "do_resize",
+        "size",
+        "resample",
+        "do_center_crop",
+        "crop_size",
+        "do_rescale",
+        "rescale_factor",
+        "do_normalize",
+        "image_mean",
+        "image_std",
+        "do_convert_rgb",
+        "return_tensors",
+        "data_format",
+        "input_data_format"
+    ],
+    "crop_size": {
+        "height": 336,
+        "width": 336
+    },
+    "do_center_crop": true,
+    "do_convert_rgb": true,
+    "do_normalize": true,
+    "do_rescale": true,
+    "do_resize": true,
+    "image_mean": [
+        0.48145466,
+        0.4578275,
+        0.40821073
+    ],
+    "image_processor_type": "CLIPImageProcessor",
+    "image_std": [
+        0.26862954,
+        0.26130258,
+        0.27577711
+    ],
+    "processor_class": "LlavaProcessor",
+    "resample": 3,
+    "rescale_factor": 0.00392156862745098,
+    "size": {
+        "shortest_edge": 336
+    }
+}
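The config above pins down the standard CLIP preprocessing arithmetic: pixel values are multiplied by `rescale_factor` (exactly 1/255) and then normalized per channel with `image_mean`/`image_std`. A quick numeric check of those constants:

```python
rescale_factor = 0.00392156862745098
image_mean = [0.48145466, 0.4578275, 0.40821073]
image_std = [0.26862954, 0.26130258, 0.27577711]

# rescale_factor is 1/255: uint8 pixel values map into [0, 1].
assert abs(rescale_factor - 1 / 255) < 1e-15

# A pure-white pixel (255, 255, 255) after rescale + per-channel normalization.
pixel = [255, 255, 255]
normalized = [(p * rescale_factor - m) / s for p, m, s in zip(pixel, image_mean, image_std)]
print([round(v, 4) for v in normalized])
```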
@@ -290,7 +290,7 @@ def launch_training_task(model, args):
             name="diffsynth_studio",
             config=swanlab_config,
             mode=args.swanlab_mode,
-            logdir=args.output_path,
+            logdir=os.path.join(args.output_path, "swanlog"),
         )
         logger = [swanlab_logger]
     else:
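The `logdir` change is just `os.path.join` routing SwanLab logs into a `swanlog` subdirectory of the training output path, so logs no longer mix with checkpoints at the top level. Illustration with a hypothetical output path:

```python
import os

output_path = "outputs/run1"  # hypothetical --output_path value
logdir = os.path.join(output_path, "swanlog")
print(logdir)
```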
@@ -6,7 +6,7 @@ We propose EliGen, a novel approach that leverages fine-grained entity-level inf

 * Paper: [EliGen: Entity-Level Controlled Image Generation with Regional Attention](https://arxiv.org/abs/2501.01097)
 * Github: [DiffSynth-Studio](https://github.com/modelscope/DiffSynth-Studio)
-* Model: [ModelScope](https://www.modelscope.cn/models/DiffSynth-Studio/Eligen)
+* Model: [ModelScope](https://www.modelscope.cn/models/DiffSynth-Studio/Eligen), [HuggingFace](https://huggingface.co/modelscope/EliGen)
 * Online Demo: [ModelScope EliGen Studio](https://www.modelscope.cn/studios/DiffSynth-Studio/EliGen)
 * Training Dataset: [EliGen Train Set](https://www.modelscope.cn/datasets/DiffSynth-Studio/EliGenTrainSet)

@@ -77,6 +77,11 @@ Demonstration of the styled entity control results with EliGen and IP-Adapter, s
 |-|-|-|-|
 |||||

+We also provide a demo of the styled entity control results with EliGen and specific styled lora, see [./styled_entity_control.py](./styled_entity_control.py) for details. Here is the visualization of EliGen with [Lego dreambooth lora](https://huggingface.co/merve/flux-lego-lora-dreambooth).
+|||||
+|-|-|-|-|
+|||||
+
 ### Entity Transfer
 Demonstration of the entity transfer results with EliGen and In-Context LoRA, see [./entity_transfer.py](./entity_transfer.py) for generation prompts.

@@ -27,11 +27,20 @@ def example(pipe, seeds, example_id, global_prompt, entity_prompts):

 # download and load model
 model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cuda", model_id_list=["FLUX.1-dev"])
+# set download_from_modelscope = False if you want to download model from huggingface
+download_from_modelscope = True
+if download_from_modelscope:
+    model_id = "DiffSynth-Studio/Eligen"
+    downloading_priority = ["ModelScope"]
+else:
+    model_id = "modelscope/EliGen"
+    downloading_priority = ["HuggingFace"]
 model_manager.load_lora(
     download_customized_models(
-        model_id="DiffSynth-Studio/Eligen",
+        model_id=model_id,
         origin_file_path="model_bf16.safetensors",
-        local_dir="models/lora/entity_control"
+        local_dir="models/lora/entity_control",
+        downloading_priority=downloading_priority
     ),
     lora_alpha=1
 )
examples/EntityControl/styled_entity_control.py (new file, 90 lines)
@@ -0,0 +1,90 @@
+from diffsynth import ModelManager, FluxImagePipeline, download_customized_models
+from modelscope import dataset_snapshot_download
+from examples.EntityControl.utils import visualize_masks
+from PIL import Image
+import torch
+
+def example(pipe, seeds, example_id, global_prompt, entity_prompts):
+    dataset_snapshot_download(dataset_id="DiffSynth-Studio/examples_in_diffsynth", local_dir="./", allow_file_pattern=f"data/examples/eligen/entity_control/example_{example_id}/*.png")
+    masks = [Image.open(f"./data/examples/eligen/entity_control/example_{example_id}/{i}.png").convert('RGB') for i in range(len(entity_prompts))]
+    negative_prompt = "worst quality, low quality, monochrome, zombie, interlocked fingers, Aissist, cleavage, nsfw,"
+    for seed in seeds:
+        # generate image
+        image = pipe(
+            prompt=global_prompt,
+            cfg_scale=3.0,
+            negative_prompt=negative_prompt,
+            num_inference_steps=50,
+            embedded_guidance=3.5,
+            seed=seed,
+            height=1024,
+            width=1024,
+            eligen_entity_prompts=entity_prompts,
+            eligen_entity_masks=masks,
+        )
+        image.save(f"styled_eligen_example_{example_id}_{seed}.png")
+        visualize_masks(image, masks, entity_prompts, f"styled_entity_control_example_{example_id}_mask_{seed}.png")
+
+# download and load model
+model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cuda", model_id_list=["FLUX.1-dev"])
+model_manager.load_lora(
+    download_customized_models(
+        model_id="FluxLora/merve-flux-lego-lora-dreambooth",
+        origin_file_path="pytorch_lora_weights.safetensors",
+        local_dir="models/lora/merve-flux-lego-lora-dreambooth"
+    ),
+    lora_alpha=1
+)
+model_manager.load_lora(
+    download_customized_models(
+        model_id="DiffSynth-Studio/Eligen",
+        origin_file_path="model_bf16.safetensors",
+        local_dir="models/lora/entity_control"
+    ),
+    lora_alpha=1
+)
+pipe = FluxImagePipeline.from_model_manager(model_manager)
+
+# example 1
+trigger_word = "lego set in style of TOK, "
+global_prompt = "A breathtaking beauty of Raja Ampat by the late-night moonlight , one beautiful woman from behind wearing a pale blue long dress with soft glow, sitting at the top of a cliff looking towards the beach,pastell light colors, a group of small distant birds flying in far sky, a boat sailing on the sea, best quality, realistic, whimsical, fantastic, splash art, intricate detailed, hyperdetailed, maximalist style, photorealistic, concept art, sharp focus, harmony, serenity, tranquility, soft pastell colors,ambient occlusion, cozy ambient lighting, masterpiece, liiv1, linquivera, metix, mentixis, masterpiece, award winning, view from above\n"
+global_prompt = trigger_word + global_prompt
+entity_prompts = ["cliff", "sea", "moon", "sailing boat", "a seated beautiful woman", "pale blue long dress with soft glow"]
+example(pipe, [0], 1, global_prompt, entity_prompts)
+
+# example 2
+global_prompt = "samurai girl wearing a kimono, she's holding a sword glowing with red flame, her long hair is flowing in the wind, she is looking at a small bird perched on the back of her hand. ultra realist style. maximum image detail. maximum realistic render."
+global_prompt = trigger_word + global_prompt
+entity_prompts = ["flowing hair", "sword glowing with red flame", "A cute bird", "blue belt"]
+example(pipe, [0], 2, global_prompt, entity_prompts)
+
+# example 3
+global_prompt = "Image of a neverending staircase up to a mysterious palace in the sky, The ancient palace stood majestically atop a mist-shrouded mountain, sunrise, two traditional monk walk in the stair looking at the sunrise, fog,see-through, best quality, whimsical, fantastic, splash art, intricate detailed, hyperdetailed, photorealistic, concept art, harmony, serenity, tranquility, ambient occlusion, halation, cozy ambient lighting, dynamic lighting,masterpiece, liiv1, linquivera, metix, mentixis, masterpiece, award winning,"
+global_prompt = trigger_word + global_prompt
+entity_prompts = ["ancient palace", "stone staircase with railings", "a traditional monk", "a traditional monk"]
+example(pipe, [27], 3, global_prompt, entity_prompts)
+
+# example 4
+global_prompt = "A beautiful girl wearing shirt and shorts in the street, holding a sign 'Entity Control'"
+global_prompt = trigger_word + global_prompt
+entity_prompts = ["A beautiful girl", "sign 'Entity Control'", "shorts", "shirt"]
+example(pipe, [21], 4, global_prompt, entity_prompts)
+
+# example 5
+global_prompt = "A captivating, dramatic scene in a painting that exudes mystery and foreboding. A white sky, swirling blue clouds, and a crescent yellow moon illuminate a solitary woman standing near the water's edge. Her long dress flows in the wind, silhouetted against the eerie glow. The water mirrors the fiery sky and moonlight, amplifying the uneasy atmosphere."
+global_prompt = trigger_word + global_prompt
+entity_prompts = ["crescent yellow moon", "a solitary woman", "water", "swirling blue clouds"]
+example(pipe, [0], 5, global_prompt, entity_prompts)
+
+# example 6
+global_prompt = "Snow White and the 6 Dwarfs."
+global_prompt = trigger_word + global_prompt
+entity_prompts = ["Dwarf 1", "Dwarf 2", "Dwarf 3", "Snow White", "Dwarf 4", "Dwarf 5", "Dwarf 6"]
+example(pipe, [8], 6, global_prompt, entity_prompts)
+
+# example 7, same prompt with different seeds
+seeds = range(5, 9)
+global_prompt = "A beautiful woman wearing white dress, holding a mirror, with a warm light background;"
+global_prompt = trigger_word + global_prompt
+entity_prompts = ["A beautiful woman", "mirror", "necklace", "glasses", "earring", "white dress", "jewelry headpiece"]
+example(pipe, seeds, 7, global_prompt, entity_prompts)
@@ -8,6 +8,12 @@
 |24G|[hunyuanvideo_24G.py](hunyuanvideo_24G.py)|129|720*1280|The video is consistent with the original implementation, but it requires 5%~10% more time than [hunyuanvideo_80G.py](hunyuanvideo_80G.py)|
 |6G|[hunyuanvideo_6G.py](hunyuanvideo_6G.py)|129|512*384|The base model doesn't support low resolutions. We recommend users to use some LoRA ([example](https://civitai.com/models/1032126/walking-animation-hunyuan-video)) trained using low resolutions.|

+[HunyuanVideo-I2V](https://github.com/Tencent/HunyuanVideo-I2V) is the image-to-video generation version of HunyuanVideo. We also provide advanced VRAM management for this model.
+|VRAM required|Example script|Frames|Resolution|Note|
+|-|-|-|-|-|
+|80G|[hunyuanvideo_i2v_80G.py](hunyuanvideo_i2v_80G.py)|129|720p|No VRAM management.|
+|24G|[hunyuanvideo_i2v_24G.py](hunyuanvideo_i2v_24G.py)|129|720p|The video is consistent with the original implementation, but it requires 5%~10% more time than [hunyuanvideo_80G.py](hunyuanvideo_80G.py)|
+
 ## Gallery

 Video generated by [hunyuanvideo_80G.py](hunyuanvideo_80G.py) and [hunyuanvideo_24G.py](hunyuanvideo_24G.py):
@@ -21,3 +27,7 @@ https://github.com/user-attachments/assets/2997f107-d02d-4ecb-89bb-5ce1a7f93817
 Video to video generated by [hunyuanvideo_v2v_6G.py](./hunyuanvideo_v2v_6G.py) using [this LoRA](https://civitai.com/models/1032126/walking-animation-hunyuan-video):

 https://github.com/user-attachments/assets/4b89e52e-ce42-434e-aa57-08f09dfa2b10
+
+Video generated by [hunyuanvideo_i2v_80G.py](hunyuanvideo_i2v_80G.py) and [hunyuanvideo_i2v_24G.py](hunyuanvideo_i2v_24G.py):
+
+https://github.com/user-attachments/assets/494f252a-c9af-440d-84ba-a8ddcdcc538a
examples/HunyuanVideo/hunyuanvideo_i2v_24G.py (new file, 43 lines)
@@ -0,0 +1,43 @@
+import torch
+from diffsynth import ModelManager, HunyuanVideoPipeline, download_models, save_video
+from modelscope import dataset_snapshot_download
+from PIL import Image
+
+
+download_models(["HunyuanVideoI2V"])
+model_manager = ModelManager()
+
+# The DiT model is loaded in bfloat16.
+model_manager.load_models(
+    [
+        "models/HunyuanVideoI2V/transformers/mp_rank_00_model_states.pt"
+    ],
+    torch_dtype=torch.bfloat16,
+    device="cpu"
+)
+
+# The other modules are loaded in float16.
+model_manager.load_models(
+    [
+        "models/HunyuanVideoI2V/text_encoder/model.safetensors",
+        "models/HunyuanVideoI2V/text_encoder_2",
+        'models/HunyuanVideoI2V/vae/pytorch_model.pt'
+    ],
+    torch_dtype=torch.float16,
+    device="cpu"
+)
+# The computation device is "cuda".
+pipe = HunyuanVideoPipeline.from_model_manager(model_manager,
+                                               torch_dtype=torch.bfloat16,
+                                               device="cuda",
+                                               enable_vram_management=True)
+
+dataset_snapshot_download(dataset_id="DiffSynth-Studio/examples_in_diffsynth",
+                          local_dir="./",
+                          allow_file_pattern=f"data/examples/hunyuanvideo/*")
+
+i2v_resolution = "720p"
+prompt = "An Asian man with short hair in black tactical uniform and white clothes waves a firework stick."
+images = [Image.open("data/examples/hunyuanvideo/0.jpg").convert('RGB')]
+video = pipe(prompt, input_images=images, num_inference_steps=50, seed=0, i2v_resolution=i2v_resolution)
+save_video(video, f"video_{i2v_resolution}_low_vram.mp4", fps=30, quality=6)
examples/HunyuanVideo/hunyuanvideo_i2v_80G.py (new file, 45 lines)
@@ -0,0 +1,45 @@
+import torch
+from diffsynth import ModelManager, HunyuanVideoPipeline, download_models, save_video
+from modelscope import dataset_snapshot_download
+from PIL import Image
+
+
+download_models(["HunyuanVideoI2V"])
+model_manager = ModelManager()
+
+# The DiT model is loaded in bfloat16.
+model_manager.load_models(
+    [
+        "models/HunyuanVideoI2V/transformers/mp_rank_00_model_states.pt"
+    ],
+    torch_dtype=torch.bfloat16,
+    device="cuda"
+)
+
+# The other modules are loaded in float16.
+model_manager.load_models(
+    [
+        "models/HunyuanVideoI2V/text_encoder/model.safetensors",
+        "models/HunyuanVideoI2V/text_encoder_2",
+        'models/HunyuanVideoI2V/vae/pytorch_model.pt'
+    ],
+    torch_dtype=torch.float16,
+    device="cuda"
+)
+# The computation device is "cuda".
+pipe = HunyuanVideoPipeline.from_model_manager(model_manager,
+                                               torch_dtype=torch.bfloat16,
+                                               device="cuda",
+                                               enable_vram_management=False)
+# Although you have enough VRAM, we still recommend you to enable offload.
+pipe.enable_cpu_offload()
+
+dataset_snapshot_download(dataset_id="DiffSynth-Studio/examples_in_diffsynth",
+                          local_dir="./",
+                          allow_file_pattern=f"data/examples/hunyuanvideo/*")
+
+i2v_resolution = "720p"
+prompt = "An Asian man with short hair in black tactical uniform and white clothes waves a firework stick."
+images = [Image.open("data/examples/hunyuanvideo/0.jpg").convert('RGB')]
+video = pipe(prompt, input_images=images, num_inference_steps=50, seed=0, i2v_resolution=i2v_resolution)
+save_video(video, f"video_{i2v_resolution}.mp4", fps=30, quality=6)
examples/InfiniteYou/README.md (new file, 7 lines)
@@ -0,0 +1,7 @@
+# InfiniteYou: Flexible Photo Recrafting While Preserving Your Identity
+We support the identity preserving feature of InfiniteYou. See [./infiniteyou.py](./infiniteyou.py) for example. The visualization of the result is shown below.
+
+|Identity Image|Generated Image|
+|-|-|
+|||
+|||
examples/InfiniteYou/infiniteyou.py (new file, 58 lines)
@@ -0,0 +1,58 @@
+import importlib
+import torch
+from diffsynth import ModelManager, FluxImagePipeline, download_models, ControlNetConfigUnit
+from modelscope import dataset_snapshot_download
+from PIL import Image
+
+if importlib.util.find_spec("facexlib") is None:
+    raise ImportError("You are using InfiniteYou. It depends on facexlib, which is not installed. Please install it with `pip install facexlib`.")
+if importlib.util.find_spec("insightface") is None:
+    raise ImportError("You are using InfiniteYou. It depends on insightface, which is not installed. Please install it with `pip install insightface`.")
+
+download_models(["InfiniteYou"])
+model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cuda", model_id_list=["FLUX.1-dev"])
+model_manager.load_models([
+    [
+        "models/InfiniteYou/InfuseNetModel/diffusion_pytorch_model-00001-of-00002.safetensors",
+        "models/InfiniteYou/InfuseNetModel/diffusion_pytorch_model-00002-of-00002.safetensors"
+    ],
+    "models/InfiniteYou/image_proj_model.bin",
+])
+
+
+pipe = FluxImagePipeline.from_model_manager(
+    model_manager,
+    controlnet_config_units=[
+        ControlNetConfigUnit(
+            processor_id="none",
+            model_path=[
+                'models/InfiniteYou/InfuseNetModel/diffusion_pytorch_model-00001-of-00002.safetensors',
+                'models/InfiniteYou/InfuseNetModel/diffusion_pytorch_model-00002-of-00002.safetensors'
+            ],
+            scale=1.0
+        )
+    ]
+)
+dataset_snapshot_download(dataset_id="DiffSynth-Studio/examples_in_diffsynth", local_dir="./", allow_file_pattern=f"data/examples/infiniteyou/*")
+
+prompt = "A man, portrait, cinematic"
+id_image = "data/examples/infiniteyou/man.jpg"
+id_image = Image.open(id_image).convert('RGB')
+image = pipe(
+    prompt=prompt, seed=1,
+    infinityou_id_image=id_image, infinityou_guidance=1.0,
+    num_inference_steps=50, embedded_guidance=3.5,
+    height=1024, width=1024,
+)
+image.save("man.jpg")
+
+prompt = "A woman, portrait, cinematic"
+id_image = "data/examples/infiniteyou/woman.jpg"
+id_image = Image.open(id_image).convert('RGB')
+image = pipe(
+    prompt=prompt, seed=1,
+    infinityou_id_image=id_image, infinityou_guidance=1.0,
+    num_inference_steps=50, embedded_guidance=3.5,
+    height=1024, width=1024,
+)
+image.save("woman.jpg")
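The optional-dependency guard at the top of `infiniteyou.py` relies on `importlib.util.find_spec`, which reports whether a package is installed without importing it. Isolated, the pattern looks like this (`require` is an illustrative helper, not part of the repo):

```python
import importlib.util

def require(module_name, hint):
    # Raise a readable error when an optional dependency is missing,
    # without paying the cost of importing it.
    if importlib.util.find_spec(module_name) is None:
        raise ImportError(f"{module_name} is not installed. {hint}")

require("json", "It ships with Python.")  # stdlib module: always present, no error
try:
    require("surely_not_installed_xyz", "Install it with pip.")
except ImportError as e:
    print("caught:", e)
```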
@@ -10,32 +10,52 @@ cd DiffSynth-Studio
 pip install -e .
 ```

-Wan-Video supports multiple Attention implementations. If you have installed any of the following Attention implementations, they will be enabled based on priority.
-
-* [Flash Attention 3](https://github.com/Dao-AILab/flash-attention)
-* [Flash Attention 2](https://github.com/Dao-AILab/flash-attention)
-* [Sage Attention](https://github.com/thu-ml/SageAttention)
-* [torch SDPA](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html) (default. `torch>=2.5.0` is recommended.)
-
-## Inference
-
-### Wan-Video-1.3B-T2V
-
-Wan-Video-1.3B-T2V supports text-to-video and video-to-video. See [`./wan_1.3b_text_to_video.py`](./wan_1.3b_text_to_video.py).
-
-Required VRAM: 6G
-
-https://github.com/user-attachments/assets/124397be-cd6a-4f29-a87c-e4c695aaabb8
-
-Put sunglasses on the dog.
-
-https://github.com/user-attachments/assets/272808d7-fbeb-4747-a6df-14a0860c75fb
-
-### Wan-Video-14B-T2V
-
-Wan-Video-14B-T2V is an enhanced version of Wan-Video-1.3B-T2V, offering greater size and power. To utilize this model, you need additional VRAM. We recommend that users adjust the `torch_dtype` and `num_persistent_param_in_dit` settings to find an optimal balance between speed and VRAM requirements. See [`./wan_14b_text_to_video.py`](./wan_14b_text_to_video.py).
-
-We present a detailed table here. The model is tested on a single A100.
+## Model Zoo
+
+|Developer|Name|Link|Scripts|
+|-|-|-|-|
+|Wan Team|1.3B text-to-video|[Link](https://modelscope.cn/models/Wan-AI/Wan2.1-T2V-1.3B)|[wan_1.3b_text_to_video.py](./wan_1.3b_text_to_video.py)|
+|Wan Team|14B text-to-video|[Link](https://modelscope.cn/models/Wan-AI/Wan2.1-T2V-14B)|[wan_14b_text_to_video.py](./wan_14b_text_to_video.py)|
+|Wan Team|14B image-to-video 480P|[Link](https://modelscope.cn/models/Wan-AI/Wan2.1-I2V-14B-480P)|[wan_14b_image_to_video.py](./wan_14b_image_to_video.py)|
+|Wan Team|14B image-to-video 720P|[Link](https://modelscope.cn/models/Wan-AI/Wan2.1-I2V-14B-720P)|[wan_14b_image_to_video.py](./wan_14b_image_to_video.py)|
+|DiffSynth-Studio Team|1.3B aesthetics LoRA|[Link](https://modelscope.cn/models/DiffSynth-Studio/Wan2.1-1.3b-lora-aesthetics-v1)|Please see the [model card](https://modelscope.cn/models/DiffSynth-Studio/Wan2.1-1.3b-lora-aesthetics-v1).|
+|DiffSynth-Studio Team|1.3B Highres-fix LoRA|[Link](https://modelscope.cn/models/DiffSynth-Studio/Wan2.1-1.3b-lora-highresfix-v1)|Please see the [model card](https://modelscope.cn/models/DiffSynth-Studio/Wan2.1-1.3b-lora-highresfix-v1).|
+|DiffSynth-Studio Team|1.3B ExVideo LoRA|[Link](https://modelscope.cn/models/DiffSynth-Studio/Wan2.1-1.3b-lora-exvideo-v1)|Please see the [model card](https://modelscope.cn/models/DiffSynth-Studio/Wan2.1-1.3b-lora-exvideo-v1).|
+|DiffSynth-Studio Team|1.3B Speed Control adapter|[Link](https://modelscope.cn/models/DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1)|[wan_1.3b_motion_controller.py](./wan_1.3b_motion_controller.py)|
+|PAI Team|1.3B InP|[Link](https://modelscope.cn/models/PAI/Wan2.1-Fun-1.3B-InP)|[wan_fun_InP.py](./wan_fun_InP.py)|
+|PAI Team|14B InP|[Link](https://modelscope.cn/models/PAI/Wan2.1-Fun-14B-InP)|[wan_fun_InP.py](./wan_fun_InP.py)|
+|PAI Team|1.3B Control|[Link](https://modelscope.cn/models/PAI/Wan2.1-Fun-1.3B-Control)|[wan_fun_control.py](./wan_fun_control.py)|
+|PAI Team|14B Control|[Link](https://modelscope.cn/models/PAI/Wan2.1-Fun-14B-Control)|[wan_fun_control.py](./wan_fun_control.py)|
+
+Base model features
+
+||Text-to-video|Image-to-video|End frame|Control|
+|-|-|-|-|-|
+|1.3B text-to-video|✅||||
+|14B text-to-video|✅||||
+|14B image-to-video 480P||✅|||
+|14B image-to-video 720P||✅|||
+|1.3B InP||✅|✅||
+|14B InP||✅|✅||
+|1.3B Control||||✅|
+|14B Control||||✅|
+
+Adapter model compatibility
+
+||1.3B text-to-video|1.3B InP|
+|-|-|-|
+|1.3B aesthetics LoRA|✅||
+|1.3B Highres-fix LoRA|✅||
+|1.3B ExVideo LoRA|✅||
+|1.3B Speed Control adapter|✅|✅|
+
+## VRAM Usage
+
+* Fine-grained offload: We recommend that users adjust the `num_persistent_param_in_dit` settings to find an optimal balance between speed and VRAM requirements. See [`./wan_14b_text_to_video.py`](./wan_14b_text_to_video.py).
+
+* FP8 Quantization: You only need to adjust the `torch_dtype` in the `ModelManager` (not the pipeline!).
+
+We present a detailed table here. The model (14B text-to-video) is tested on a single A100.

 |`torch_dtype`|`num_persistent_param_in_dit`|Speed|Required VRAM|Default Setting|
 |-|-|-|-|-|
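The FP8 option in the VRAM table works because weight storage scales linearly with bytes per element: roughly 2 bytes per parameter in bfloat16 versus 1 byte in float8. A back-of-envelope estimate for a 14B-parameter DiT (weights only; activations and the other modules are ignored):

```python
params = 14e9  # 14B parameters

for dtype, bytes_per_param in [("bfloat16", 2), ("float8_e4m3fn", 1)]:
    gib = params * bytes_per_param / 1024**3
    print(f"{dtype}: ~{gib:.1f} GiB of weights")
```

This is why the table's FP8 rows fit the DiT into far less VRAM than bfloat16, at some cost in precision.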
@@ -45,15 +65,46 @@ We present a detailed table here. The model is tested on a single A100.
 |torch.float8_e4m3fn|None (unlimited)|18.3s/it|24G|yes|
 |torch.float8_e4m3fn|0|24.0s/it|10G||

+**We found that 14B image-to-video model is more sensitive to precision, so when the generated video content experiences issues such as artifacts, please switch to bfloat16 precision and use the `num_persistent_param_in_dit` parameter to control VRAM usage.**
+
+## Efficient Attention Implementation
+
+DiffSynth-Studio supports multiple Attention implementations. If you have installed any of the following Attention implementations, they will be enabled based on priority. However, we recommend to use the default torch SDPA.
+
+* [Flash Attention 3](https://github.com/Dao-AILab/flash-attention)
+* [Flash Attention 2](https://github.com/Dao-AILab/flash-attention)
+* [Sage Attention](https://github.com/thu-ml/SageAttention)
+* [torch SDPA](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html) (default. `torch>=2.5.0` is recommended.)
+
+## Acceleration
+
+We support multiple acceleration solutions:
+* [TeaCache](https://github.com/ali-vilab/TeaCache): See [wan_1.3b_text_to_video_accelerate.py](./wan_1.3b_text_to_video_accelerate.py).
+
+* [Unified Sequence Parallel](https://github.com/xdit-project/xDiT): See [wan_14b_text_to_video_usp.py](./wan_14b_text_to_video_usp.py)
+
+```bash
+pip install xfuser>=0.4.3
+torchrun --standalone --nproc_per_node=8 examples/wanvideo/wan_14b_text_to_video_usp.py
+```
+
+* Tensor Parallel: See [wan_14b_text_to_video_tensor_parallel.py](./wan_14b_text_to_video_tensor_parallel.py).
+
+## Gallery
+
+1.3B text-to-video.
+
+https://github.com/user-attachments/assets/124397be-cd6a-4f29-a87c-e4c695aaabb8
+
+Put sunglasses on the dog.
+
+https://github.com/user-attachments/assets/272808d7-fbeb-4747-a6df-14a0860c75fb
+
+14B text-to-video.
+
 https://github.com/user-attachments/assets/3908bc64-d451-485a-8b61-28f6d32dd92f

-### Wan-Video-14B-I2V
-
-Wan-Video-14B-I2V adds the functionality of image-to-video based on Wan-Video-14B-T2V. The model size remains the same, therefore the speed and VRAM requirements are also consistent. See [`./wan_14b_image_to_video.py`](./wan_14b_image_to_video.py).
-
-**In the sample code, we use the same settings as the T2V 14B model, with FP8 quantization enabled by default. However, we found that this model is more sensitive to precision, so when the generated video content experiences issues such as artifacts, please switch to bfloat16 precision and use the `num_persistent_param_in_dit` parameter to control VRAM usage.**
-
-
+14B image-to-video.

 https://github.com/user-attachments/assets/c0bdd5ca-292f-45ed-b9bc-afe193156e75
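Selecting the first installed backend from the priority list above can be sketched with `importlib.util.find_spec`; the module names below are illustrative placeholders, not necessarily the exact distributions DiffSynth-Studio probes:

```python
import importlib.util

# Priority order mirroring the README: FlashAttention 3/2, SageAttention, then torch SDPA.
CANDIDATES = ["flash_attn_interface", "flash_attn", "sageattention"]
FALLBACK = "torch_sdpa"

def pick_attention_backend(candidates=CANDIDATES, fallback=FALLBACK):
    for name in candidates:
        if importlib.util.find_spec(name) is not None:
            return name  # first installed implementation wins
    return fallback      # torch SDPA needs no extra package

print(pick_attention_backend())
```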
@@ -155,6 +206,12 @@ CUDA_VISIBLE_DEVICES="0" python examples/wanvideo/train_wan_t2v.py \
   --use_gradient_checkpointing
 ```
 
+If you wish to train the 14B model, separate the safetensors files with commas. For example: `models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00001-of-00006.safetensors,models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00002-of-00006.safetensors,models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00003-of-00006.safetensors,models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00004-of-00006.safetensors,models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00005-of-00006.safetensors,models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00006-of-00006.safetensors`.
+
+If you wish to train the image-to-video model, add the extra parameter `--image_encoder_path "models/Wan-AI/Wan2.1-I2V-14B-480P/models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth"`.
+
+For LoRA training, the Wan-1.3B-T2V model requires 16 GB of VRAM to process 81 frames at 480P, while the Wan-14B-T2V model requires 60 GB for the same configuration. To further reduce VRAM requirements by 20%-30%, add the parameter `--use_gradient_checkpointing_offload`.
+
 Step 5: Test
 
 Test LoRA:
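Putting those flags together, a LoRA launch with offloaded gradient checkpointing might look like the sketch below. Flag names are taken from this diff; the dataset, output, and model-path flags of the full command shown earlier in the README are elided here and still need to be supplied.

```shell
CUDA_VISIBLE_DEVICES="0" python examples/wanvideo/train_wan_t2v.py \
  --train_architecture lora \
  --use_gradient_checkpointing \
  --use_gradient_checkpointing_offload
  # ...plus the dataset/model-path flags from the full command above
```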
@@ -7,11 +7,12 @@ from diffsynth import WanVideoPipeline, ModelManager, load_state_dict
 from peft import LoraConfig, inject_adapter_in_model
 import torchvision
 from PIL import Image
+import numpy as np
 
 
 
 class TextVideoDataset(torch.utils.data.Dataset):
-    def __init__(self, base_path, metadata_path, max_num_frames=81, frame_interval=1, num_frames=81, height=480, width=832):
+    def __init__(self, base_path, metadata_path, max_num_frames=81, frame_interval=1, num_frames=81, height=480, width=832, is_i2v=False):
         metadata = pd.read_csv(metadata_path)
         self.path = [os.path.join(base_path, "train", file_name) for file_name in metadata["file_name"]]
         self.text = metadata["text"].to_list()
@@ -21,6 +22,7 @@ class TextVideoDataset(torch.utils.data.Dataset):
         self.num_frames = num_frames
         self.height = height
         self.width = width
+        self.is_i2v = is_i2v
 
         self.frame_process = v2.Compose([
             v2.CenterCrop(size=(height, width)),
@@ -48,10 +50,13 @@ class TextVideoDataset(torch.utils.data.Dataset):
             return None
 
         frames = []
+        first_frame = None
         for frame_id in range(num_frames):
             frame = reader.get_data(start_frame_id + frame_id * interval)
             frame = Image.fromarray(frame)
             frame = self.crop_and_resize(frame)
+            if first_frame is None:
+                first_frame = np.array(frame)
             frame = frame_process(frame)
             frames.append(frame)
         reader.close()
@@ -59,7 +64,10 @@ class TextVideoDataset(torch.utils.data.Dataset):
         frames = torch.stack(frames, dim=0)
         frames = rearrange(frames, "T C H W -> C T H W")
 
-        return frames
+        if self.is_i2v:
+            return frames, first_frame
+        else:
+            return frames
 
 
     def load_video(self, file_path):
@@ -70,7 +78,7 @@ class TextVideoDataset(torch.utils.data.Dataset):
 
     def is_image(self, file_path):
         file_ext_name = file_path.split(".")[-1]
-        if file_ext_name.lower() in ["jpg", "png", "webp"]:
+        if file_ext_name.lower() in ["jpg", "jpeg", "png", "webp"]:
             return True
         return False
 
@@ -78,6 +86,7 @@ class TextVideoDataset(torch.utils.data.Dataset):
     def load_image(self, file_path):
         frame = Image.open(file_path).convert("RGB")
         frame = self.crop_and_resize(frame)
+        first_frame = frame
         frame = self.frame_process(frame)
         frame = rearrange(frame, "C H W -> C 1 H W")
         return frame
@@ -87,10 +96,16 @@ class TextVideoDataset(torch.utils.data.Dataset):
         text = self.text[data_id]
         path = self.path[data_id]
         if self.is_image(path):
+            if self.is_i2v:
+                raise ValueError(f"{path} is not a video. I2V model doesn't support image-to-image training.")
             video = self.load_image(path)
         else:
             video = self.load_video(path)
-        data = {"text": text, "video": video, "path": path}
+        if self.is_i2v:
+            video, first_frame = video
+            data = {"text": text, "video": video, "path": path, "first_frame": first_frame}
+        else:
+            data = {"text": text, "video": video, "path": path}
         return data
 
 
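Schematically, the `__getitem__` branching above packages samples like this toy sketch (`package_sample` is a hypothetical stand-in; no video decoding is performed):

```python
def package_sample(text, path, video, is_i2v):
    # In I2V mode the loader returns (frames, first_frame); the raw first
    # frame is kept in the sample so the image encoder can consume it later.
    if is_i2v:
        video, first_frame = video
        return {"text": text, "video": video, "path": path, "first_frame": first_frame}
    # In T2V mode the sample carries only the preprocessed frames.
    return {"text": text, "video": video, "path": path}

t2v_sample = package_sample("a dog runs", "dog.mp4", "frames", is_i2v=False)
i2v_sample = package_sample("a dog runs", "dog.mp4", ("frames", "frame0"), is_i2v=True)
print(sorted(i2v_sample))  # ['first_frame', 'path', 'text', 'video']
```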
@@ -100,21 +115,35 @@ class TextVideoDataset(torch.utils.data.Dataset):
 
 
 class LightningModelForDataProcess(pl.LightningModule):
-    def __init__(self, text_encoder_path, vae_path, tiled=False, tile_size=(34, 34), tile_stride=(18, 16)):
+    def __init__(self, text_encoder_path, vae_path, image_encoder_path=None, tiled=False, tile_size=(34, 34), tile_stride=(18, 16)):
         super().__init__()
+        model_path = [text_encoder_path, vae_path]
+        if image_encoder_path is not None:
+            model_path.append(image_encoder_path)
         model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu")
-        model_manager.load_models([text_encoder_path, vae_path])
+        model_manager.load_models(model_path)
         self.pipe = WanVideoPipeline.from_model_manager(model_manager)
 
         self.tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride}
 
     def test_step(self, batch, batch_idx):
         text, video, path = batch["text"][0], batch["video"], batch["path"][0]
 
         self.pipe.device = self.device
         if video is not None:
+            # prompt
             prompt_emb = self.pipe.encode_prompt(text)
+            # video
+            video = video.to(dtype=self.pipe.torch_dtype, device=self.pipe.device)
             latents = self.pipe.encode_video(video, **self.tiler_kwargs)[0]
-            data = {"latents": latents, "prompt_emb": prompt_emb}
+            # image
+            if "first_frame" in batch:
+                first_frame = Image.fromarray(batch["first_frame"][0].cpu().numpy())
+                _, _, num_frames, height, width = video.shape
+                image_emb = self.pipe.encode_image(first_frame, num_frames, height, width)
+            else:
+                image_emb = {}
+            data = {"latents": latents, "prompt_emb": prompt_emb, "image_emb": image_emb}
             torch.save(data, path + ".tensors.pth")
 
 
@@ -145,10 +174,21 @@ class TensorDataset(torch.utils.data.Dataset):
 
 
 class LightningModelForTrain(pl.LightningModule):
-    def __init__(self, dit_path, learning_rate=1e-5, lora_rank=4, lora_alpha=4, train_architecture="lora", lora_target_modules="q,k,v,o,ffn.0,ffn.2", init_lora_weights="kaiming", use_gradient_checkpointing=True, pretrained_lora_path=None):
+    def __init__(
+        self,
+        dit_path,
+        learning_rate=1e-5,
+        lora_rank=4, lora_alpha=4, train_architecture="lora", lora_target_modules="q,k,v,o,ffn.0,ffn.2", init_lora_weights="kaiming",
+        use_gradient_checkpointing=True, use_gradient_checkpointing_offload=False,
+        pretrained_lora_path=None
+    ):
         super().__init__()
         model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu")
-        model_manager.load_models([dit_path])
+        if os.path.isfile(dit_path):
+            model_manager.load_models([dit_path])
+        else:
+            dit_path = dit_path.split(",")
+            model_manager.load_models([dit_path])
 
         self.pipe = WanVideoPipeline.from_model_manager(model_manager)
         self.pipe.scheduler.set_timesteps(1000, training=True)
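The `dit_path` branching above can be isolated into a small helper (hypothetical name `resolve_dit_path`; the resulting list is what `ModelManager.load_models` receives):

```python
import os

def resolve_dit_path(dit_path):
    # A single safetensors file is loaded directly; a comma-separated list
    # of shards is split and passed as one grouped entry, mirroring the
    # branching in LightningModelForTrain.__init__.
    if os.path.isfile(dit_path):
        return [dit_path]
    return [dit_path.split(",")]

shards = "model-00001-of-00002.safetensors,model-00002-of-00002.safetensors"
print(resolve_dit_path(shards))
```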
@@ -167,6 +207,7 @@ class LightningModelForTrain(pl.LightningModule):
 
         self.learning_rate = learning_rate
         self.use_gradient_checkpointing = use_gradient_checkpointing
+        self.use_gradient_checkpointing_offload = use_gradient_checkpointing_offload
 
 
     def freeze_parameters(self):
@@ -210,24 +251,30 @@ class LightningModelForTrain(pl.LightningModule):
         # Data
         latents = batch["latents"].to(self.device)
         prompt_emb = batch["prompt_emb"]
-        prompt_emb["context"] = [prompt_emb["context"][0][0].to(self.device)]
+        prompt_emb["context"] = prompt_emb["context"][0].to(self.device)
+        image_emb = batch["image_emb"]
+        if "clip_feature" in image_emb:
+            image_emb["clip_feature"] = image_emb["clip_feature"][0].to(self.device)
+        if "y" in image_emb:
+            image_emb["y"] = image_emb["y"][0].to(self.device)
 
         # Loss
+        self.pipe.device = self.device
         noise = torch.randn_like(latents)
         timestep_id = torch.randint(0, self.pipe.scheduler.num_train_timesteps, (1,))
-        timestep = self.pipe.scheduler.timesteps[timestep_id].to(self.device)
+        timestep = self.pipe.scheduler.timesteps[timestep_id].to(dtype=self.pipe.torch_dtype, device=self.pipe.device)
         extra_input = self.pipe.prepare_extra_input(latents)
         noisy_latents = self.pipe.scheduler.add_noise(latents, noise, timestep)
         training_target = self.pipe.scheduler.training_target(latents, noise, timestep)
 
         # Compute loss
-        with torch.amp.autocast(dtype=torch.bfloat16, device_type=torch.device(self.device).type):
-            noise_pred = self.pipe.denoising_model()(
-                noisy_latents, timestep=timestep, **prompt_emb, **extra_input,
-                use_gradient_checkpointing=self.use_gradient_checkpointing
-            )
+        noise_pred = self.pipe.denoising_model()(
+            noisy_latents, timestep=timestep, **prompt_emb, **extra_input, **image_emb,
+            use_gradient_checkpointing=self.use_gradient_checkpointing,
+            use_gradient_checkpointing_offload=self.use_gradient_checkpointing_offload
+        )
         loss = torch.nn.functional.mse_loss(noise_pred.float(), training_target.float())
         loss = loss * self.pipe.scheduler.training_weight(timestep)
 
         # Record log
         self.log("train_loss", loss, prog_bar=True)
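The loss above depends on what `scheduler.add_noise` and `scheduler.training_target` compute. Assuming a rectified-flow convention (an assumption; the actual scheduler lives in DiffSynth-Studio), the two reduce to a linear interpolation and a constant velocity target:

```python
def add_noise(latents, noise, t):
    # Linear interpolation between data and noise; t in [0, 1].
    return (1 - t) * latents + t * noise

def training_target(latents, noise, t):
    # The velocity the model regresses onto; constant in t for rectified flow.
    return noise - latents

latents, noise, t = 2.0, 0.5, 0.25
print(add_noise(latents, noise, t))        # 1.625
print(training_target(latents, noise, t))  # -1.5
```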
@@ -282,6 +329,12 @@ def parse_args():
         default=None,
         help="Path of text encoder.",
     )
+    parser.add_argument(
+        "--image_encoder_path",
+        type=str,
+        default=None,
+        help="Path of image encoder.",
+    )
     parser.add_argument(
         "--vae_path",
         type=str,
@@ -410,6 +463,12 @@ def parse_args():
         action="store_true",
         help="Whether to use gradient checkpointing.",
     )
+    parser.add_argument(
+        "--use_gradient_checkpointing_offload",
+        default=False,
+        action="store_true",
+        help="Whether to use gradient checkpointing offload.",
+    )
     parser.add_argument(
         "--train_architecture",
         type=str,
@@ -446,7 +505,8 @@ def data_process(args):
         frame_interval=1,
         num_frames=args.num_frames,
         height=args.height,
-        width=args.width
+        width=args.width,
+        is_i2v=args.image_encoder_path is not None
     )
     dataloader = torch.utils.data.DataLoader(
         dataset,
@@ -456,6 +516,7 @@ def data_process(args):
     )
     model = LightningModelForDataProcess(
         text_encoder_path=args.text_encoder_path,
+        image_encoder_path=args.image_encoder_path,
         vae_path=args.vae_path,
         tiled=args.tiled,
         tile_size=(args.tile_size_height, args.tile_size_width),
@@ -490,6 +551,7 @@ def train(args):
         lora_target_modules=args.lora_target_modules,
         init_lora_weights=args.init_lora_weights,
         use_gradient_checkpointing=args.use_gradient_checkpointing,
+        use_gradient_checkpointing_offload=args.use_gradient_checkpointing_offload,
         pretrained_lora_path=args.pretrained_lora_path,
     )
     if args.use_swanlab:
@@ -501,7 +563,7 @@ def train(args):
             name="wan",
             config=swanlab_config,
             mode=args.swanlab_mode,
-            logdir=args.output_path,
+            logdir=os.path.join(args.output_path, "swanlog"),
         )
         logger = [swanlab_logger]
     else:
@@ -510,6 +572,7 @@ def train(args):
         max_epochs=args.max_epochs,
         accelerator="gpu",
         devices="auto",
+        precision="bf16",
         strategy=args.training_strategy,
         default_root_dir=args.output_path,
         accumulate_grad_batches=args.accumulate_grad_batches,
examples/wanvideo/wan_1.3b_motion_controller.py (new file, 41 lines)
@@ -0,0 +1,41 @@
import torch
from diffsynth import ModelManager, WanVideoPipeline, save_video, VideoData
from modelscope import snapshot_download


# Download models
snapshot_download("Wan-AI/Wan2.1-T2V-1.3B", local_dir="models/Wan-AI/Wan2.1-T2V-1.3B")
snapshot_download("DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1", local_dir="models/DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1")

# Load models
model_manager = ModelManager(device="cpu")
model_manager.load_models(
    [
        "models/Wan-AI/Wan2.1-T2V-1.3B/diffusion_pytorch_model.safetensors",
        "models/Wan-AI/Wan2.1-T2V-1.3B/models_t5_umt5-xxl-enc-bf16.pth",
        "models/Wan-AI/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth",
        "models/DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1/model.safetensors",
    ],
    torch_dtype=torch.bfloat16, # You can set `torch_dtype=torch.float8_e4m3fn` to enable FP8 quantization.
)
pipe = WanVideoPipeline.from_model_manager(model_manager, torch_dtype=torch.bfloat16, device="cuda")
pipe.enable_vram_management(num_persistent_param_in_dit=None)

# Text-to-video
video = pipe(
    prompt="纪实摄影风格画面,一只活泼的小狗在绿茵茵的草地上迅速奔跑。小狗毛色棕黄,两只耳朵立起,神情专注而欢快。阳光洒在它身上,使得毛发看上去格外柔软而闪亮。背景是一片开阔的草地,偶尔点缀着几朵野花,远处隐约可见蓝天和几片白云。透视感鲜明,捕捉小狗奔跑时的动感和四周草地的生机。中景侧面移动视角。",
    negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
    num_inference_steps=50,
    seed=1, tiled=True,
    motion_bucket_id=0
)
save_video(video, "video_slow.mp4", fps=15, quality=5)

video = pipe(
    prompt="纪实摄影风格画面,一只活泼的小狗在绿茵茵的草地上迅速奔跑。小狗毛色棕黄,两只耳朵立起,神情专注而欢快。阳光洒在它身上,使得毛发看上去格外柔软而闪亮。背景是一片开阔的草地,偶尔点缀着几朵野花,远处隐约可见蓝天和几片白云。透视感鲜明,捕捉小狗奔跑时的动感和四周草地的生机。中景侧面移动视角。",
    negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
    num_inference_steps=50,
    seed=1, tiled=True,
    motion_bucket_id=100
)
save_video(video, "video_fast.mp4", fps=15, quality=5)
examples/wanvideo/wan_1.3b_text_to_video_accelerate.py (new file, 34 lines)
@@ -0,0 +1,34 @@
import torch
from diffsynth import ModelManager, WanVideoPipeline, save_video, VideoData
from modelscope import snapshot_download


# Download models
snapshot_download("Wan-AI/Wan2.1-T2V-1.3B", local_dir="models/Wan-AI/Wan2.1-T2V-1.3B")

# Load models
model_manager = ModelManager(device="cpu")
model_manager.load_models(
    [
        "models/Wan-AI/Wan2.1-T2V-1.3B/diffusion_pytorch_model.safetensors",
        "models/Wan-AI/Wan2.1-T2V-1.3B/models_t5_umt5-xxl-enc-bf16.pth",
        "models/Wan-AI/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth",
    ],
    torch_dtype=torch.bfloat16, # You can set `torch_dtype=torch.float8_e4m3fn` to enable FP8 quantization.
)
pipe = WanVideoPipeline.from_model_manager(model_manager, torch_dtype=torch.bfloat16, device="cuda")
pipe.enable_vram_management(num_persistent_param_in_dit=None)

# Text-to-video
video = pipe(
    prompt="纪实摄影风格画面,一只活泼的小狗在绿茵茵的草地上迅速奔跑。小狗毛色棕黄,两只耳朵立起,神情专注而欢快。阳光洒在它身上,使得毛发看上去格外柔软而闪亮。背景是一片开阔的草地,偶尔点缀着几朵野花,远处隐约可见蓝天和几片白云。透视感鲜明,捕捉小狗奔跑时的动感和四周草地的生机。中景侧面移动视角。",
    negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
    num_inference_steps=50,
    seed=0, tiled=True,
    # TeaCache parameters
    tea_cache_l1_thresh=0.05, # The larger this value is, the faster the speed, but the worse the visual quality.
    tea_cache_model_id="Wan2.1-T2V-1.3B", # Choose one in (Wan2.1-T2V-1.3B, Wan2.1-T2V-14B, Wan2.1-I2V-14B-480P, Wan2.1-I2V-14B-720P).
)
save_video(video, "video1.mp4", fps=15, quality=5)

# TeaCache doesn't support video-to-video
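The `tea_cache_l1_thresh` knob above governs a reuse decision of roughly this shape (a schematic sketch only; the real TeaCache criterion accumulates a polynomial-rescaled relative L1 distance over timestep-modulated inputs):

```python
def should_reuse_cache(prev_input, cur_input, l1_thresh):
    # Reuse the cached DiT residual when the input changed little since the
    # last fully computed step; recompute otherwise. A larger threshold
    # skips more steps (faster, lower quality).
    l1_change = sum(abs(a - b) for a, b in zip(prev_input, cur_input))
    scale = sum(abs(a) for a in prev_input)
    return l1_change / scale < l1_thresh

print(should_reuse_cache([1.0, 1.0], [1.01, 1.0], l1_thresh=0.05))  # True
print(should_reuse_cache([1.0, 1.0], [2.0, 2.0], l1_thresh=0.05))   # False
```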
@@ -9,6 +9,10 @@ snapshot_download("Wan-AI/Wan2.1-I2V-14B-480P", local_dir="models/Wan-AI/Wan2.1-
 
 # Load models
 model_manager = ModelManager(device="cpu")
+model_manager.load_models(
+    ["models/Wan-AI/Wan2.1-I2V-14B-480P/models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth"],
+    torch_dtype=torch.float32, # Image Encoder is loaded with float32
+)
 model_manager.load_models(
     [
@@ -20,14 +24,13 @@ model_manager.load_models(
         "models/Wan-AI/Wan2.1-I2V-14B-480P/diffusion_pytorch_model-00006-of-00007.safetensors",
         "models/Wan-AI/Wan2.1-I2V-14B-480P/diffusion_pytorch_model-00007-of-00007.safetensors",
         ],
-        "models/Wan-AI/Wan2.1-I2V-14B-480P/models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth",
         "models/Wan-AI/Wan2.1-I2V-14B-480P/models_t5_umt5-xxl-enc-bf16.pth",
         "models/Wan-AI/Wan2.1-I2V-14B-480P/Wan2.1_VAE.pth",
     ],
-    torch_dtype=torch.float8_e4m3fn, # You can set `torch_dtype=torch.bfloat16` to disable FP8 quantization.
+    torch_dtype=torch.bfloat16, # You can set `torch_dtype=torch.float8_e4m3fn` to enable FP8 quantization.
 )
 pipe = WanVideoPipeline.from_model_manager(model_manager, torch_dtype=torch.bfloat16, device="cuda")
-pipe.enable_vram_management(num_persistent_param_in_dit=None) # You can set `num_persistent_param_in_dit` to a small number to reduce VRAM required.
+pipe.enable_vram_management(num_persistent_param_in_dit=6*10**9) # You can set `num_persistent_param_in_dit` to a small number to reduce VRAM required.
 
 # Download example image
 dataset_snapshot_download(
examples/wanvideo/wan_14b_text_to_video_tensor_parallel.py (new file, 149 lines)
@@ -0,0 +1,149 @@
import torch
import lightning as pl
from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel, SequenceParallel, PrepareModuleInput, PrepareModuleOutput
from torch.distributed._tensor import Replicate, Shard
from torch.distributed.tensor.parallel import parallelize_module
from lightning.pytorch.strategies import ModelParallelStrategy
from diffsynth import ModelManager, WanVideoPipeline, save_video
from tqdm import tqdm
from modelscope import snapshot_download


class ToyDataset(torch.utils.data.Dataset):
    def __init__(self, tasks=[]):
        self.tasks = tasks

    def __getitem__(self, data_id):
        return self.tasks[data_id]

    def __len__(self):
        return len(self.tasks)


class LitModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        model_manager = ModelManager(device="cpu")
        model_manager.load_models(
            [
                [
                    "models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00001-of-00006.safetensors",
                    "models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00002-of-00006.safetensors",
                    "models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00003-of-00006.safetensors",
                    "models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00004-of-00006.safetensors",
                    "models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00005-of-00006.safetensors",
                    "models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00006-of-00006.safetensors",
                ],
                "models/Wan-AI/Wan2.1-T2V-14B/models_t5_umt5-xxl-enc-bf16.pth",
                "models/Wan-AI/Wan2.1-T2V-14B/Wan2.1_VAE.pth",
            ],
            torch_dtype=torch.bfloat16,
        )
        self.pipe = WanVideoPipeline.from_model_manager(model_manager, torch_dtype=torch.bfloat16, device="cuda")

    def configure_model(self):
        tp_mesh = self.device_mesh["tensor_parallel"]
        plan = {
            "text_embedding.0": ColwiseParallel(),
            "text_embedding.2": RowwiseParallel(),
            "time_projection.1": ColwiseParallel(output_layouts=Replicate()),
            "blocks.0": PrepareModuleInput(
                input_layouts=(Replicate(), None, None, None),
                desired_input_layouts=(Replicate(), None, None, None),
            ),
            "head": PrepareModuleInput(
                input_layouts=(Replicate(), None),
                desired_input_layouts=(Replicate(), None),
                use_local_output=True,
            )
        }
        self.pipe.dit = parallelize_module(self.pipe.dit, tp_mesh, plan)
        for block_id, block in enumerate(self.pipe.dit.blocks):
            layer_tp_plan = {
                "self_attn": PrepareModuleInput(
                    input_layouts=(Shard(1), Replicate()),
                    desired_input_layouts=(Shard(1), Shard(0)),
                ),
                "self_attn.q": SequenceParallel(),
                "self_attn.k": SequenceParallel(),
                "self_attn.v": SequenceParallel(),
                "self_attn.norm_q": SequenceParallel(),
                "self_attn.norm_k": SequenceParallel(),
                "self_attn.attn": PrepareModuleInput(
                    input_layouts=(Shard(1), Shard(1), Shard(1)),
                    desired_input_layouts=(Shard(2), Shard(2), Shard(2)),
                ),
                "self_attn.o": RowwiseParallel(input_layouts=Shard(2), output_layouts=Replicate()),

                "cross_attn": PrepareModuleInput(
                    input_layouts=(Shard(1), Replicate()),
                    desired_input_layouts=(Shard(1), Replicate()),
                ),
                "cross_attn.q": SequenceParallel(),
                "cross_attn.k": SequenceParallel(),
                "cross_attn.v": SequenceParallel(),
                "cross_attn.norm_q": SequenceParallel(),
                "cross_attn.norm_k": SequenceParallel(),
                "cross_attn.attn": PrepareModuleInput(
                    input_layouts=(Shard(1), Shard(1), Shard(1)),
                    desired_input_layouts=(Shard(2), Shard(2), Shard(2)),
                ),
                "cross_attn.o": RowwiseParallel(input_layouts=Shard(2), output_layouts=Replicate(), use_local_output=False),

                "ffn.0": ColwiseParallel(input_layouts=Shard(1)),
                "ffn.2": RowwiseParallel(output_layouts=Replicate()),

                "norm1": SequenceParallel(use_local_output=True),
                "norm2": SequenceParallel(use_local_output=True),
                "norm3": SequenceParallel(use_local_output=True),
                "gate": PrepareModuleInput(
                    input_layouts=(Shard(1), Replicate(), Replicate()),
                    desired_input_layouts=(Replicate(), Replicate(), Replicate()),
                )
            }
            parallelize_module(
                module=block,
                device_mesh=tp_mesh,
                parallelize_plan=layer_tp_plan,
            )

    def test_step(self, batch):
        data = batch[0]
        data["progress_bar_cmd"] = tqdm if self.local_rank == 0 else lambda x: x
        output_path = data.pop("output_path")
        with torch.no_grad(), torch.inference_mode(False):
            video = self.pipe(**data)
        if self.local_rank == 0:
            save_video(video, output_path, fps=15, quality=5)


if __name__ == "__main__":
    snapshot_download("Wan-AI/Wan2.1-T2V-14B", local_dir="models/Wan-AI/Wan2.1-T2V-14B")
    dataloader = torch.utils.data.DataLoader(
        ToyDataset([
            {
                "prompt": "一名宇航员身穿太空服,面朝镜头骑着一匹机械马在火星表面驰骋。红色的荒凉地表延伸至远方,点缀着巨大的陨石坑和奇特的岩石结构。机械马的步伐稳健,扬起微弱的尘埃,展现出未来科技与原始探索的完美结合。宇航员手持操控装置,目光坚定,仿佛正在开辟人类的新疆域。背景是深邃的宇宙和蔚蓝的地球,画面既科幻又充满希望,让人不禁畅想未来的星际生活。",
                "negative_prompt": "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
                "num_inference_steps": 50,
                "seed": 0,
                "tiled": False,
                "output_path": "video1.mp4",
            },
            {
                "prompt": "一名宇航员身穿太空服,面朝镜头骑着一匹机械马在火星表面驰骋。红色的荒凉地表延伸至远方,点缀着巨大的陨石坑和奇特的岩石结构。机械马的步伐稳健,扬起微弱的尘埃,展现出未来科技与原始探索的完美结合。宇航员手持操控装置,目光坚定,仿佛正在开辟人类的新疆域。背景是深邃的宇宙和蔚蓝的地球,画面既科幻又充满希望,让人不禁畅想未来的星际生活。",
                "negative_prompt": "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
                "num_inference_steps": 50,
                "seed": 1,
                "tiled": False,
                "output_path": "video2.mp4",
            },
        ]),
        collate_fn=lambda x: x
    )
    model = LitModel()
    trainer = pl.Trainer(accelerator="gpu", devices=torch.cuda.device_count(), strategy=ModelParallelStrategy())
    trainer.test(model, dataloader)
58
examples/wanvideo/wan_14b_text_to_video_usp.py
Normal file
58
examples/wanvideo/wan_14b_text_to_video_usp.py
Normal file
@@ -0,0 +1,58 @@
import torch
from diffsynth import ModelManager, WanVideoPipeline, save_video, VideoData
from modelscope import snapshot_download
import torch.distributed as dist


# Download models
snapshot_download("Wan-AI/Wan2.1-T2V-14B", local_dir="models/Wan-AI/Wan2.1-T2V-14B")

# Load models
model_manager = ModelManager(device="cpu")
model_manager.load_models(
    [
        [
            "models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00001-of-00006.safetensors",
            "models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00002-of-00006.safetensors",
            "models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00003-of-00006.safetensors",
            "models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00004-of-00006.safetensors",
            "models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00005-of-00006.safetensors",
            "models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00006-of-00006.safetensors",
        ],
        "models/Wan-AI/Wan2.1-T2V-14B/models_t5_umt5-xxl-enc-bf16.pth",
        "models/Wan-AI/Wan2.1-T2V-14B/Wan2.1_VAE.pth",
    ],
    torch_dtype=torch.float8_e4m3fn,  # You can set `torch_dtype=torch.bfloat16` to disable FP8 quantization.
)

dist.init_process_group(
    backend="nccl",
    init_method="env://",
)
from xfuser.core.distributed import (initialize_model_parallel,
                                     init_distributed_environment)
init_distributed_environment(
    rank=dist.get_rank(), world_size=dist.get_world_size())

initialize_model_parallel(
    sequence_parallel_degree=dist.get_world_size(),
    ring_degree=1,
    ulysses_degree=dist.get_world_size(),
)
torch.cuda.set_device(dist.get_rank())

pipe = WanVideoPipeline.from_model_manager(model_manager,
                                           torch_dtype=torch.bfloat16,
                                           device=f"cuda:{dist.get_rank()}",
                                           use_usp=dist.get_world_size() > 1)
pipe.enable_vram_management(num_persistent_param_in_dit=None)  # You can set `num_persistent_param_in_dit` to a small number to reduce VRAM required.

# Text-to-video
video = pipe(
    prompt="一名宇航员身穿太空服,面朝镜头骑着一匹机械马在火星表面驰骋。红色的荒凉地表延伸至远方,点缀着巨大的陨石坑和奇特的岩石结构。机械马的步伐稳健,扬起微弱的尘埃,展现出未来科技与原始探索的完美结合。宇航员手持操控装置,目光坚定,仿佛正在开辟人类的新疆域。背景是深邃的宇宙和蔚蓝的地球,画面既科幻又充满希望,让人不禁畅想未来的星际生活。",
    negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
    num_inference_steps=50,
    seed=0, tiled=True
)
if dist.get_rank() == 0:
    save_video(video, "video1.mp4", fps=25, quality=5)
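The `initialize_model_parallel(..., ulysses_degree=dist.get_world_size())` call above sets up Ulysses-style sequence parallelism: each rank holds only a slice of the video token sequence, and an all-to-all exchange re-shards activations so that every rank can attend over the full sequence for a subset of attention heads. The following is a toy, single-process illustration of that re-sharding on plain lists, not xfuser's actual implementation; the function name and data layout are illustrative assumptions:

```python
def seq_to_head_all_to_all(per_rank):
    """Simulate the Ulysses all-to-all on plain lists (illustrative only).

    per_rank[r][h] is rank r's shard of the token sequence for head h
    (sequence-sharded, all heads present on every rank). The result
    out[r] holds the FULL sequence, but only for the heads assigned to
    rank r (head-sharded), which is what full-sequence attention needs.
    """
    world = len(per_rank)
    num_heads = len(per_rank[0])
    heads_per_rank = num_heads // world  # assumes num_heads % world == 0
    out = []
    for r in range(world):
        my_heads = range(r * heads_per_rank, (r + 1) * heads_per_rank)
        out.append([
            [tok for src in range(world) for tok in per_rank[src][h]]
            for h in my_heads
        ])
    return out
```

With 2 ranks and 2 heads, rank 0 ends up with the whole sequence for head 0 and rank 1 with the whole sequence for head 1; after attention, a second all-to-all restores the original sequence sharding.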
42  examples/wanvideo/wan_fun_InP.py  Normal file
@@ -0,0 +1,42 @@
import torch
from diffsynth import ModelManager, WanVideoPipeline, save_video, VideoData
from modelscope import snapshot_download, dataset_snapshot_download
from PIL import Image


# Download models
snapshot_download("PAI/Wan2.1-Fun-1.3B-InP", local_dir="models/PAI/Wan2.1-Fun-1.3B-InP")

# Load models
model_manager = ModelManager(device="cpu")
model_manager.load_models(
    [
        "models/PAI/Wan2.1-Fun-1.3B-InP/diffusion_pytorch_model.safetensors",
        "models/PAI/Wan2.1-Fun-1.3B-InP/models_t5_umt5-xxl-enc-bf16.pth",
        "models/PAI/Wan2.1-Fun-1.3B-InP/Wan2.1_VAE.pth",
        "models/PAI/Wan2.1-Fun-1.3B-InP/models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth",
    ],
    torch_dtype=torch.bfloat16,  # You can set `torch_dtype=torch.float8_e4m3fn` to enable FP8 quantization.
)
pipe = WanVideoPipeline.from_model_manager(model_manager, torch_dtype=torch.bfloat16, device="cuda")
pipe.enable_vram_management(num_persistent_param_in_dit=None)

# Download example image
dataset_snapshot_download(
    dataset_id="DiffSynth-Studio/examples_in_diffsynth",
    local_dir="./",
    allow_file_pattern="data/examples/wan/input_image.jpg"
)
image = Image.open("data/examples/wan/input_image.jpg")

# Image-to-video
video = pipe(
    prompt="一艘小船正勇敢地乘风破浪前行。蔚蓝的大海波涛汹涌,白色的浪花拍打着船身,但小船毫不畏惧,坚定地驶向远方。阳光洒在水面上,闪烁着金色的光芒,为这壮丽的场景增添了一抹温暖。镜头拉近,可以看到船上的旗帜迎风飘扬,象征着不屈的精神与冒险的勇气。这段画面充满力量,激励人心,展现了面对挑战时的无畏与执着。",
    negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
    num_inference_steps=50,
    input_image=image,
    # You can input `end_image=xxx` to control the last frame of the video.
    # The model will automatically generate the dynamic content between `input_image` and `end_image`.
    seed=1, tiled=True
)
save_video(video, "video1.mp4", fps=15, quality=5)
40  examples/wanvideo/wan_fun_control.py  Normal file
@@ -0,0 +1,40 @@
import torch
from diffsynth import ModelManager, WanVideoPipeline, save_video, VideoData
from modelscope import snapshot_download, dataset_snapshot_download
from PIL import Image


# Download models
snapshot_download("PAI/Wan2.1-Fun-1.3B-Control", local_dir="models/PAI/Wan2.1-Fun-1.3B-Control")

# Load models
model_manager = ModelManager(device="cpu")
model_manager.load_models(
    [
        "models/PAI/Wan2.1-Fun-1.3B-Control/diffusion_pytorch_model.safetensors",
        "models/PAI/Wan2.1-Fun-1.3B-Control/models_t5_umt5-xxl-enc-bf16.pth",
        "models/PAI/Wan2.1-Fun-1.3B-Control/Wan2.1_VAE.pth",
        "models/PAI/Wan2.1-Fun-1.3B-Control/models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth",
    ],
    torch_dtype=torch.bfloat16,  # You can set `torch_dtype=torch.float8_e4m3fn` to enable FP8 quantization.
)
pipe = WanVideoPipeline.from_model_manager(model_manager, torch_dtype=torch.bfloat16, device="cuda")
pipe.enable_vram_management(num_persistent_param_in_dit=None)

# Download example video
dataset_snapshot_download(
    dataset_id="DiffSynth-Studio/examples_in_diffsynth",
    local_dir="./",
    allow_file_pattern="data/examples/wan/control_video.mp4"
)

# Control-to-video
control_video = VideoData("data/examples/wan/control_video.mp4", height=832, width=576)
video = pipe(
    prompt="扁平风格动漫,一位长发少女优雅起舞。她五官精致,大眼睛明亮有神,黑色长发柔顺光泽。身穿淡蓝色T恤和深蓝色牛仔短裤。背景是粉色。",
    negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
    num_inference_steps=50,
    control_video=control_video, height=832, width=576, num_frames=49,
    seed=1, tiled=True
)
save_video(video, "video1.mp4", fps=15, quality=5)
2  setup.py
@@ -14,7 +14,7 @@ else:
 setup(
     name="diffsynth",
-    version="1.1.2",
+    version="1.1.7",
     description="Enjoy the magic of Diffusion models!",
     author="Artiprocher",
     packages=find_packages(),
241  train_flux_reference.py  Normal file
@@ -0,0 +1,241 @@
from diffsynth import ModelManager, FluxImagePipeline
from diffsynth.trainers.text_to_image import LightningModelForT2ILoRA, add_general_parsers, launch_training_task
from diffsynth.models.lora import FluxLoRAConverter
import torch, os, argparse
import lightning as pl
from diffsynth.data.image_pulse import SingleTaskDataset, MultiTaskDataset
from diffsynth.pipelines.flux_image import lets_dance_flux
from diffsynth.models.flux_reference_embedder import FluxReferenceEmbedder
os.environ["TOKENIZERS_PARALLELISM"] = "True"


class LightningModel(LightningModelForT2ILoRA):
    def __init__(
        self,
        torch_dtype=torch.float16, pretrained_weights=[], preset_lora_path=None,
        learning_rate=1e-4, use_gradient_checkpointing=True,
        lora_rank=4, lora_alpha=4, lora_target_modules="to_q,to_k,to_v,to_out", init_lora_weights="kaiming", pretrained_lora_path=None,
        state_dict_converter=None, quantize=None
    ):
        super().__init__(learning_rate=learning_rate, use_gradient_checkpointing=use_gradient_checkpointing, state_dict_converter=state_dict_converter)
        self.lora_alpha = lora_alpha  # referenced by the state-dict converter in `on_save_checkpoint`
        # Load models
        model_manager = ModelManager(torch_dtype=torch_dtype, device=self.device)
        if quantize is None:
            model_manager.load_models(pretrained_weights)
        else:
            model_manager.load_models(pretrained_weights[1:])
            model_manager.load_model(pretrained_weights[0], torch_dtype=quantize)
        if preset_lora_path is not None:
            preset_lora_path = preset_lora_path.split(",")
            for path in preset_lora_path:
                model_manager.load_lora(path)

        self.pipe = FluxImagePipeline.from_model_manager(model_manager)
        self.pipe.reference_embedder = FluxReferenceEmbedder()
        self.pipe.reference_embedder.init()

        if quantize is not None:
            self.pipe.dit.quantize()

        self.pipe.scheduler.set_timesteps(1000, training=True)

        self.freeze_parameters()
        self.pipe.reference_embedder.requires_grad_(True)
        self.pipe.reference_embedder.train()
        self.pipe.dit.requires_grad_(True)
        self.pipe.dit.train()
        # self.add_lora_to_model(
        #     self.pipe.denoising_model(),
        #     lora_rank=lora_rank,
        #     lora_alpha=lora_alpha,
        #     lora_target_modules=lora_target_modules,
        #     init_lora_weights=init_lora_weights,
        #     pretrained_lora_path=pretrained_lora_path,
        #     state_dict_converter=FluxLoRAConverter.align_to_diffsynth_format
        # )

    def training_step(self, batch, batch_idx):
        # Data
        text, image = batch["instruction"], batch["image_2"]
        image_ref = batch["image_1"]

        # Prepare input parameters
        self.pipe.device = self.device
        prompt_emb = self.pipe.encode_prompt(text, positive=True)
        if "latents" in batch:
            latents = batch["latents"].to(dtype=self.pipe.torch_dtype, device=self.device)
        else:
            latents = self.pipe.vae_encoder(image.to(dtype=self.pipe.torch_dtype, device=self.device))
        noise = torch.randn_like(latents)
        timestep_id = torch.randint(0, self.pipe.scheduler.num_train_timesteps, (1,))
        timestep = self.pipe.scheduler.timesteps[timestep_id].to(self.device)
        extra_input = self.pipe.prepare_extra_input(latents)
        noisy_latents = self.pipe.scheduler.add_noise(latents, noise, timestep)
        training_target = self.pipe.scheduler.training_target(latents, noise, timestep)

        # Reference image
        hidden_states_ref = self.pipe.vae_encoder(image_ref.to(dtype=self.pipe.torch_dtype, device=self.device))

        # Compute loss
        noise_pred = lets_dance_flux(
            self.pipe.denoising_model(),
            reference_embedder=self.pipe.reference_embedder,
            hidden_states_ref=hidden_states_ref,
            hidden_states=noisy_latents, timestep=timestep, **prompt_emb, **extra_input,
            use_gradient_checkpointing=self.use_gradient_checkpointing
        )
        loss = torch.nn.functional.mse_loss(noise_pred.float(), training_target.float())
        loss = loss * self.pipe.scheduler.training_weight(timestep)

        # Record log
        self.log("train_loss", loss, prog_bar=True)
        return loss

    def configure_optimizers(self):
        trainable_modules = filter(lambda p: p.requires_grad, self.pipe.parameters())
        optimizer = torch.optim.AdamW(trainable_modules, lr=self.learning_rate)
        return optimizer

    def on_save_checkpoint(self, checkpoint):
        checkpoint.clear()
        trainable_param_names = list(filter(lambda named_param: named_param[1].requires_grad, self.pipe.named_parameters()))
        trainable_param_names = set([named_param[0] for named_param in trainable_param_names])
        state_dict = self.pipe.state_dict()
        lora_state_dict = {}
        for name, param in state_dict.items():
            if name in trainable_param_names:
                lora_state_dict[name] = param
        if self.state_dict_converter is not None:
            lora_state_dict = self.state_dict_converter(lora_state_dict, alpha=self.lora_alpha)
        checkpoint.update(lora_state_dict)


def parse_args():
    parser = argparse.ArgumentParser(description="Simple example of a training script.")
    parser.add_argument(
        "--pretrained_text_encoder_path",
        type=str,
        default=None,
        required=True,
        help="Path to pretrained text encoder model. For example, `models/FLUX/FLUX.1-dev/text_encoder/model.safetensors`.",
    )
    parser.add_argument(
        "--pretrained_text_encoder_2_path",
        type=str,
        default=None,
        required=True,
        help="Path to pretrained t5 text encoder model. For example, `models/FLUX/FLUX.1-dev/text_encoder_2`.",
    )
    parser.add_argument(
        "--pretrained_dit_path",
        type=str,
        default=None,
        required=True,
        help="Path to pretrained dit model. For example, `models/FLUX/FLUX.1-dev/flux1-dev.safetensors`.",
    )
    parser.add_argument(
        "--pretrained_vae_path",
        type=str,
        default=None,
        required=True,
        help="Path to pretrained vae model. For example, `models/FLUX/FLUX.1-dev/ae.safetensors`.",
    )
    parser.add_argument(
        "--lora_target_modules",
        type=str,
        default="a_to_qkv,b_to_qkv,ff_a.0,ff_a.2,ff_b.0,ff_b.2,a_to_out,b_to_out,proj_out,norm.linear,norm1_a.linear,norm1_b.linear,to_qkv_mlp",
        help="Layers with LoRA modules.",
    )
    parser.add_argument(
        "--align_to_opensource_format",
        default=False,
        action="store_true",
        help="Whether to export lora files aligned with other opensource format.",
    )
    parser.add_argument(
        "--quantize",
        type=str,
        default=None,
        choices=["float8_e4m3fn"],
        help="Whether to use quantization when training the model, and in which format.",
    )
    parser.add_argument(
        "--preset_lora_path",
        type=str,
        default=None,
        help="Preset LoRA path.",
    )
    parser = add_general_parsers(parser)
    args = parser.parse_args()
    return args


if __name__ == '__main__':
    args = parse_args()
    model = LightningModel(
        torch_dtype={"32": torch.float32, "bf16": torch.bfloat16}.get(args.precision, torch.float16),
        pretrained_weights=[args.pretrained_dit_path, args.pretrained_text_encoder_path, args.pretrained_text_encoder_2_path, args.pretrained_vae_path],
        preset_lora_path=args.preset_lora_path,
        learning_rate=args.learning_rate,
        use_gradient_checkpointing=args.use_gradient_checkpointing,
        lora_rank=args.lora_rank,
        lora_alpha=args.lora_alpha,
        lora_target_modules=args.lora_target_modules,
        init_lora_weights=args.init_lora_weights,
        pretrained_lora_path=args.pretrained_lora_path,
        state_dict_converter=FluxLoRAConverter.align_to_opensource_format if args.align_to_opensource_format else None,
        quantize={"float8_e4m3fn": torch.float8_e4m3fn}.get(args.quantize, None),
    )
    # dataset and data loader
    dataset = MultiTaskDataset(
        dataset_list=[
            SingleTaskDataset(
                "/shark/zhongjie/data/image_pulse_datasets/task1/data/dataset_change_add_remove",
                keys=(("image_1", "image_2", "editing_instruction"), ("image_2", "image_1", "reverse_editing_instruction")),
                metadata_path="/shark/zhongjie/data/image_pulse_datasets/task1/data/metadata/20250418_dataset_change_add_remove.json",
                height=512, width=512,
            ),
            SingleTaskDataset(
                "/shark/zhongjie/data/image_pulse_datasets/task1/data/dataset_zoomin_zoomout",
                keys=(("image_1", "image_2", "editing_instruction"), ("image_2", "image_1", "reverse_editing_instruction")),
                metadata_path="/shark/zhongjie/data/image_pulse_datasets/task1/data/metadata/20250418_dataset_zoomin_zoomout.json",
                height=512, width=512,
            ),
            SingleTaskDataset(
                "/shark/zhongjie/data/image_pulse_datasets/task1/data/dataset_style_transfer",
                keys=(("image_1", "image_4", "editing_instruction"), ("image_4", "image_1", "reverse_editing_instruction")),
                metadata_path="/shark/zhongjie/data/image_pulse_datasets/task1/data/metadata/20250418_dataset_style_transfer.json",
                height=512, width=512,
            ),
            SingleTaskDataset(
                "/shark/zhongjie/data/image_pulse_datasets/task1/data/dataset_faceid",
                keys=(("image_1", "image_2", "editing_instruction"), ("image_2", "image_1", "reverse_editing_instruction")),
                metadata_path="/shark/zhongjie/data/image_pulse_datasets/task1/data/metadata/20250418_dataset_faceid.json",
                height=512, width=512,
            ),
        ],
        dataset_weight=(4, 1, 4, 1),
        steps_per_epoch=args.steps_per_epoch,
    )
    train_loader = torch.utils.data.DataLoader(
        dataset,
        shuffle=True,
        batch_size=args.batch_size,
        num_workers=args.dataloader_num_workers
    )
    # train
    trainer = pl.Trainer(
        max_epochs=args.max_epochs,
        accelerator="gpu",
        devices="auto",
        precision=args.precision,
        strategy=args.training_strategy,
        default_root_dir=args.output_path,
        accumulate_grad_batches=args.accumulate_grad_batches,
        callbacks=[pl.pytorch.callbacks.ModelCheckpoint(save_top_k=-1)],
        logger=None,
    )
    trainer.fit(model=model, train_dataloaders=train_loader)
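The `training_step` in this script leans on two scheduler calls, `add_noise` and `training_target`. DiffSynth's flow-matching scheduler implements these internally; under the usual rectified-flow formulation (an assumption here, not a quote of the library code) they reduce to a linear interpolation and a constant velocity target:

```python
def add_noise(x0, eps, sigma):
    # Interpolate clean latents toward pure noise: x_t = (1 - sigma) * x0 + sigma * eps
    return [(1 - sigma) * a + sigma * b for a, b in zip(x0, eps)]

def training_target(x0, eps, sigma):
    # Rectified-flow velocity target v = eps - x0; note it does not depend on sigma.
    return [b - a for a, b in zip(x0, eps)]

# Sanity check on toy 2-element "latents": a perfect prediction of the
# target lets us walk x_t straight back to the clean sample: x0 = x_t - sigma * v.
x0, eps, sigma = [1.0, 2.0], [0.0, 0.0], 0.5
x_t = add_noise(x0, eps, sigma)
v = training_target(x0, eps, sigma)
recovered = [xt - sigma * vi for xt, vi in zip(x_t, v)]
```

The MSE against this velocity target, weighted by `training_weight(timestep)`, is exactly the shape of loss the script computes on real latents.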
248  train_flux_reference_multi_node.py  Normal file
@@ -0,0 +1,248 @@
from diffsynth import ModelManager, FluxImagePipeline
from diffsynth.trainers.text_to_image import LightningModelForT2ILoRA, add_general_parsers, launch_training_task
from diffsynth.models.lora import FluxLoRAConverter
import torch, os, argparse
import lightning as pl
from diffsynth.data.image_pulse import SingleTaskDataset, MultiTaskDataset
from diffsynth.pipelines.flux_image import lets_dance_flux
from diffsynth.models.flux_reference_embedder import FluxReferenceEmbedder
os.environ["TOKENIZERS_PARALLELISM"] = "True"


class LightningModel(LightningModelForT2ILoRA):
    def __init__(
        self,
        torch_dtype=torch.float16, pretrained_weights=[], preset_lora_path=None,
        learning_rate=1e-4, use_gradient_checkpointing=True,
        lora_rank=4, lora_alpha=4, lora_target_modules="to_q,to_k,to_v,to_out", init_lora_weights="kaiming", pretrained_lora_path=None,
        state_dict_converter=None, quantize=None
    ):
        super().__init__(learning_rate=learning_rate, use_gradient_checkpointing=use_gradient_checkpointing, state_dict_converter=state_dict_converter)
        self.lora_alpha = lora_alpha  # referenced by the state-dict converter in `on_save_checkpoint`
        # Load models
        model_manager = ModelManager(torch_dtype=torch_dtype, device=self.device)
        if quantize is None:
            model_manager.load_models(pretrained_weights)
        else:
            model_manager.load_models(pretrained_weights[1:])
            model_manager.load_model(pretrained_weights[0], torch_dtype=quantize)
        if preset_lora_path is not None:
            preset_lora_path = preset_lora_path.split(",")
            for path in preset_lora_path:
                model_manager.load_lora(path)

        self.pipe = FluxImagePipeline.from_model_manager(model_manager)
        self.pipe.reference_embedder = FluxReferenceEmbedder()
        self.pipe.reference_embedder.init()

        if quantize is not None:
            self.pipe.dit.quantize()

        self.pipe.scheduler.set_timesteps(1000, training=True)

        self.freeze_parameters()
        self.pipe.reference_embedder.requires_grad_(True)
        self.pipe.reference_embedder.train()
        self.pipe.dit.requires_grad_(True)
        self.pipe.dit.train()
        # self.add_lora_to_model(
        #     self.pipe.denoising_model(),
        #     lora_rank=lora_rank,
        #     lora_alpha=lora_alpha,
        #     lora_target_modules=lora_target_modules,
        #     init_lora_weights=init_lora_weights,
        #     pretrained_lora_path=pretrained_lora_path,
        #     state_dict_converter=FluxLoRAConverter.align_to_diffsynth_format
        # )

    def training_step(self, batch, batch_idx):
        # Data
        text, image = batch["instruction"], batch["image_2"]
        image_ref = batch["image_1"]

        # Prepare input parameters
        self.pipe.device = self.device
        prompt_emb = self.pipe.encode_prompt(text, positive=True)
        if "latents" in batch:
            latents = batch["latents"].to(dtype=self.pipe.torch_dtype, device=self.device)
        else:
            latents = self.pipe.vae_encoder(image.to(dtype=self.pipe.torch_dtype, device=self.device))
        noise = torch.randn_like(latents)
        timestep_id = torch.randint(0, self.pipe.scheduler.num_train_timesteps, (1,))
        timestep = self.pipe.scheduler.timesteps[timestep_id].to(self.device)
        extra_input = self.pipe.prepare_extra_input(latents)
        noisy_latents = self.pipe.scheduler.add_noise(latents, noise, timestep)
        training_target = self.pipe.scheduler.training_target(latents, noise, timestep)

        # Reference image
        hidden_states_ref = self.pipe.vae_encoder(image_ref.to(dtype=self.pipe.torch_dtype, device=self.device))

        # Compute loss
        noise_pred = lets_dance_flux(
            self.pipe.denoising_model(),
            reference_embedder=self.pipe.reference_embedder,
            hidden_states_ref=hidden_states_ref,
            hidden_states=noisy_latents, timestep=timestep, **prompt_emb, **extra_input,
            use_gradient_checkpointing=self.use_gradient_checkpointing
        )
        loss = torch.nn.functional.mse_loss(noise_pred.float(), training_target.float())
        loss = loss * self.pipe.scheduler.training_weight(timestep)

        # Record log
        self.log("train_loss", loss, prog_bar=True)
        return loss

    def configure_optimizers(self):
        trainable_modules = filter(lambda p: p.requires_grad, self.pipe.parameters())
        optimizer = torch.optim.AdamW(trainable_modules, lr=self.learning_rate)
        return optimizer

    def on_save_checkpoint(self, checkpoint):
        checkpoint.clear()
        trainable_param_names = list(filter(lambda named_param: named_param[1].requires_grad, self.pipe.named_parameters()))
        trainable_param_names = set([named_param[0] for named_param in trainable_param_names])
        state_dict = self.pipe.state_dict()
        lora_state_dict = {}
        for name, param in state_dict.items():
            if name in trainable_param_names:
                lora_state_dict[name] = param
        if self.state_dict_converter is not None:
            lora_state_dict = self.state_dict_converter(lora_state_dict, alpha=self.lora_alpha)
        checkpoint.update(lora_state_dict)


def parse_args():
    parser = argparse.ArgumentParser(description="Simple example of a training script.")
    parser.add_argument(
        "--pretrained_text_encoder_path",
        type=str,
        default=None,
        required=True,
        help="Path to pretrained text encoder model. For example, `models/FLUX/FLUX.1-dev/text_encoder/model.safetensors`.",
    )
    parser.add_argument(
        "--pretrained_text_encoder_2_path",
        type=str,
        default=None,
        required=True,
        help="Path to pretrained t5 text encoder model. For example, `models/FLUX/FLUX.1-dev/text_encoder_2`.",
    )
    parser.add_argument(
        "--pretrained_dit_path",
        type=str,
        default=None,
        required=True,
        help="Path to pretrained dit model. For example, `models/FLUX/FLUX.1-dev/flux1-dev.safetensors`.",
    )
    parser.add_argument(
        "--pretrained_vae_path",
        type=str,
        default=None,
        required=True,
        help="Path to pretrained vae model. For example, `models/FLUX/FLUX.1-dev/ae.safetensors`.",
    )
    parser.add_argument(
        "--lora_target_modules",
        type=str,
        default="a_to_qkv,b_to_qkv,ff_a.0,ff_a.2,ff_b.0,ff_b.2,a_to_out,b_to_out,proj_out,norm.linear,norm1_a.linear,norm1_b.linear,to_qkv_mlp",
        help="Layers with LoRA modules.",
    )
    parser.add_argument(
        "--align_to_opensource_format",
        default=False,
        action="store_true",
        help="Whether to export lora files aligned with other opensource format.",
    )
    parser.add_argument(
        "--quantize",
        type=str,
        default=None,
        choices=["float8_e4m3fn"],
        help="Whether to use quantization when training the model, and in which format.",
    )
    parser.add_argument(
        "--preset_lora_path",
        type=str,
        default=None,
        help="Preset LoRA path.",
    )
    parser.add_argument(
        "--num_nodes",
        type=int,
        default=1,
        help="Num nodes.",
    )
    parser = add_general_parsers(parser)
    args = parser.parse_args()
    return args


if __name__ == '__main__':
    args = parse_args()
    model = LightningModel(
        torch_dtype={"32": torch.float32, "bf16": torch.bfloat16}.get(args.precision, torch.float16),
        pretrained_weights=[args.pretrained_dit_path, args.pretrained_text_encoder_path, args.pretrained_text_encoder_2_path, args.pretrained_vae_path],
        preset_lora_path=args.preset_lora_path,
        learning_rate=args.learning_rate,
        use_gradient_checkpointing=args.use_gradient_checkpointing,
        lora_rank=args.lora_rank,
        lora_alpha=args.lora_alpha,
        lora_target_modules=args.lora_target_modules,
        init_lora_weights=args.init_lora_weights,
        pretrained_lora_path=args.pretrained_lora_path,
        state_dict_converter=FluxLoRAConverter.align_to_opensource_format if args.align_to_opensource_format else None,
        quantize={"float8_e4m3fn": torch.float8_e4m3fn}.get(args.quantize, None),
    )
    # dataset and data loader
    dataset = MultiTaskDataset(
        dataset_list=[
            SingleTaskDataset(
                "/shark/zhongjie/data/image_pulse_datasets/task1/data/dataset_change_add_remove",
                keys=(("image_1", "image_2", "editing_instruction"), ("image_2", "image_1", "reverse_editing_instruction")),
                metadata_path="/shark/zhongjie/data/image_pulse_datasets/task1/data/metadata/20250418_dataset_change_add_remove.json",
                height=512, width=512,
            ),
            SingleTaskDataset(
                "/shark/zhongjie/data/image_pulse_datasets/task1/data/dataset_zoomin_zoomout",
                keys=(("image_1", "image_2", "editing_instruction"), ("image_2", "image_1", "reverse_editing_instruction")),
                metadata_path="/shark/zhongjie/data/image_pulse_datasets/task1/data/metadata/20250418_dataset_zoomin_zoomout.json",
                height=512, width=512,
            ),
            SingleTaskDataset(
                "/shark/zhongjie/data/image_pulse_datasets/task1/data/dataset_style_transfer",
                keys=(("image_1", "image_4", "editing_instruction"), ("image_4", "image_1", "reverse_editing_instruction")),
                metadata_path="/shark/zhongjie/data/image_pulse_datasets/task1/data/metadata/20250418_dataset_style_transfer.json",
                height=512, width=512,
            ),
            SingleTaskDataset(
                "/shark/zhongjie/data/image_pulse_datasets/task1/data/dataset_faceid",
                keys=(("image_1", "image_2", "editing_instruction"), ("image_2", "image_1", "reverse_editing_instruction")),
                metadata_path="/shark/zhongjie/data/image_pulse_datasets/task1/data/metadata/20250418_dataset_faceid.json",
                height=512, width=512,
            ),
        ],
        dataset_weight=(4, 1, 4, 1),
        steps_per_epoch=args.steps_per_epoch,
    )
    train_loader = torch.utils.data.DataLoader(
        dataset,
        shuffle=True,
        batch_size=args.batch_size,
        num_workers=args.dataloader_num_workers
    )
    # train
    trainer = pl.Trainer(
        max_epochs=args.max_epochs,
        accelerator="gpu",
        devices="auto",
        num_nodes=args.num_nodes,
        precision=args.precision,
        strategy="ddp",
        default_root_dir=args.output_path,
        accumulate_grad_batches=args.accumulate_grad_batches,
        callbacks=[pl.pytorch.callbacks.ModelCheckpoint(save_top_k=-1)],
        logger=None,
    )
    trainer.fit(model=model, train_dataloaders=train_loader)
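Both training scripts mix four `SingleTaskDataset`s with `dataset_weight=(4, 1, 4, 1)`. As a deterministic sketch of how such weights can be turned into a per-step dataset schedule (`MultiTaskDataset` may well sample stochastically instead; this is just one way to realize the stated proportions, and the function name is illustrative):

```python
def mixture_schedule(weights, steps):
    """Interleave dataset indices so each appears in proportion to its weight."""
    total = sum(weights)
    credit = [0] * len(weights)
    schedule = []
    for _ in range(steps):
        for i, w in enumerate(weights):
            credit[i] += w          # every dataset earns credit each step
        i = max(range(len(weights)), key=lambda j: credit[j])
        credit[i] -= total          # the picked dataset pays for its slot
        schedule.append(i)
    return schedule
```

Over 10 steps with weights (4, 1, 4, 1), datasets 0 and 2 are each picked 4 times and datasets 1 and 3 once each, matching the 4:1:4:1 ratio exactly while keeping the picks spread out rather than clustered.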