mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-19 14:58:12 +00:00
Compare commits
67 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
71eee780fb | ||
|
|
4864453e0a | ||
|
|
c5a32f76c2 | ||
|
|
c4ed3d3e4b | ||
|
|
803ddcccc7 | ||
|
|
4cd51fecf2 | ||
|
|
3b0211a547 | ||
|
|
e88328d152 | ||
|
|
52896fa8dd | ||
|
|
c7035ad911 | ||
|
|
070811e517 | ||
|
|
7e010d88a5 | ||
|
|
4e43d4d461 | ||
|
|
d7efe7e539 | ||
|
|
633f789c47 | ||
|
|
88607f404e | ||
|
|
6d405b669c | ||
|
|
d0fed6ba72 | ||
|
|
64eaa0d76a | ||
|
|
3dc28f428f | ||
|
|
3c8a3fe2e1 | ||
|
|
e28c246bcc | ||
|
|
04d03500ff | ||
|
|
54081bdcbb | ||
|
|
d8b250607a | ||
|
|
1e58e6ef82 | ||
|
|
42cb7d96bb | ||
|
|
39890f023f | ||
|
|
e425753f79 | ||
|
|
ca40074d72 | ||
|
|
1fd3d67379 | ||
|
|
3acd9c73be | ||
|
|
32422b49ee | ||
|
|
5c4d3185fb | ||
|
|
762bcbee58 | ||
|
|
6b411ada16 | ||
|
|
a25bd74d8b | ||
|
|
fb5fc09bad | ||
|
|
3fdba19e02 | ||
|
|
4bec2983a9 | ||
|
|
03ea27893f | ||
|
|
718b45f2af | ||
|
|
63a79eeb2a | ||
|
|
e757013a14 | ||
|
|
a05f647633 | ||
|
|
7604be0301 | ||
|
|
945b43492e | ||
|
|
b548d7caf2 | ||
|
|
6e316fd825 | ||
|
|
84fb61aaaf | ||
|
|
50a9946b57 | ||
|
|
384d1a8198 | ||
|
|
a58c193d0c | ||
|
|
34a5ef8c15 | ||
|
|
41e3e4e157 | ||
|
|
e576d71908 | ||
|
|
906aadbf1b | ||
|
|
bf0bf2d5ba | ||
|
|
fe0fff1399 | ||
|
|
50fceb84d2 | ||
|
|
100da41034 | ||
|
|
c382237833 | ||
|
|
98ac191750 | ||
|
|
2f73dbe7a3 | ||
|
|
490d420d82 | ||
|
|
0aca943a39 | ||
|
|
0dbb3d333f |
21
README.md
21
README.md
@@ -13,13 +13,19 @@ Document: https://diffsynth-studio.readthedocs.io/zh-cn/latest/index.html
|
||||
|
||||
## Introduction
|
||||
|
||||
DiffSynth Studio is a Diffusion engine. We have restructured architectures including Text Encoder, UNet, VAE, among others, maintaining compatibility with models from the open-source community while enhancing computational performance. We provide many interesting features. Enjoy the magic of Diffusion models!
|
||||
Welcome to the magic world of Diffusion models!
|
||||
|
||||
Until now, DiffSynth Studio has supported the following models:
|
||||
DiffSynth consists of two open-source projects:
|
||||
* [DiffSynth-Studio](https://github.com/modelscope/DiffSynth-Studio): Focused on aggressive technological exploration. Targeted at academia. Provides more cutting-edge technical support and novel inference capabilities.
|
||||
* [DiffSynth-Engine](https://github.com/modelscope/DiffSynth-Engine): Focused on stable model deployment. Geared towards industry. Offers better engineering support, higher computational performance, and more stable functionality.
|
||||
|
||||
DiffSynth-Studio is an open-source project aimed at exploring innovations in AIGC technology. We have integrated numerous open-source Diffusion models, including FLUX and Wan, among others. Through this open-source project, we hope to connect models within the open-source community and explore new technologies based on diffusion models.
|
||||
|
||||
Until now, DiffSynth-Studio has supported the following models:
|
||||
|
||||
* [Wan-Video](https://github.com/Wan-Video/Wan2.1)
|
||||
* [StepVideo](https://github.com/stepfun-ai/Step-Video-T2V)
|
||||
* [HunyuanVideo](https://github.com/Tencent/HunyuanVideo)
|
||||
* [HunyuanVideo](https://github.com/Tencent/HunyuanVideo), [HunyuanVideo-I2V]()
|
||||
* [CogVideoX](https://huggingface.co/THUDM/CogVideoX-5b)
|
||||
* [FLUX](https://huggingface.co/black-forest-labs/FLUX.1-dev)
|
||||
* [ExVideo](https://huggingface.co/ECNU-CILab/ExVideo-SVD-128f-v1)
|
||||
@@ -36,6 +42,11 @@ Until now, DiffSynth Studio has supported the following models:
|
||||
* [Stable Diffusion](https://huggingface.co/runwayml/stable-diffusion-v1-5)
|
||||
|
||||
## News
|
||||
- **March 31, 2025** We support InfiniteYou, an identity preserving method for FLUX. Please refer to [./examples/InfiniteYou/](./examples/InfiniteYou/) for more details.
|
||||
|
||||
- **March 25, 2025** 🔥🔥🔥 Our new open-source project, [DiffSynth-Engine](https://github.com/modelscope/DiffSynth-Engine), is now open-sourced! Focused on stable model deployment. Geared towards industry. Offers better engineering support, higher computational performance, and more stable functionality.
|
||||
|
||||
- **March 13, 2025** We support HunyuanVideo-I2V, the image-to-video generation version of HunyuanVideo open-sourced by Tencent. Please refer to [./examples/HunyuanVideo/](./examples/HunyuanVideo/) for more details.
|
||||
|
||||
- **February 25, 2025** We support Wan-Video, a collection of SOTA video synthesis models open-sourced by Alibaba. See [./examples/wanvideo/](./examples/wanvideo/).
|
||||
|
||||
@@ -43,7 +54,7 @@ Until now, DiffSynth Studio has supported the following models:
|
||||
|
||||
- **December 31, 2024** We propose EliGen, a novel framework for precise entity-level controlled text-to-image generation, complemented by an inpainting fusion pipeline to extend its capabilities to image inpainting tasks. EliGen seamlessly integrates with existing community models, such as IP-Adapter and In-Context LoRA, enhancing its versatility. For more details, see [./examples/EntityControl](./examples/EntityControl/).
|
||||
- Paper: [EliGen: Entity-Level Controlled Image Generation with Regional Attention](https://arxiv.org/abs/2501.01097)
|
||||
- Model: [ModelScope](https://www.modelscope.cn/models/DiffSynth-Studio/Eligen)
|
||||
- Model: [ModelScope](https://www.modelscope.cn/models/DiffSynth-Studio/Eligen), [HuggingFace](https://huggingface.co/modelscope/EliGen)
|
||||
- Online Demo: [ModelScope EliGen Studio](https://www.modelscope.cn/studios/DiffSynth-Studio/EliGen)
|
||||
- Training Dataset: [EliGen Train Set](https://www.modelscope.cn/datasets/DiffSynth-Studio/EliGenTrainSet)
|
||||
|
||||
@@ -72,7 +83,7 @@ Until now, DiffSynth Studio has supported the following models:
|
||||
- Enable CFG and highres-fix to improve visual quality. See [here](/examples/image_synthesis/README.md)
|
||||
- LoRA, ControlNet, and additional models will be available soon.
|
||||
|
||||
- **June 21, 2024.** 🔥🔥🔥 We propose ExVideo, a post-tuning technique aimed at enhancing the capability of video generation models. We have extended Stable Video Diffusion to achieve the generation of long videos up to 128 frames.
|
||||
- **June 21, 2024.** We propose ExVideo, a post-tuning technique aimed at enhancing the capability of video generation models. We have extended Stable Video Diffusion to achieve the generation of long videos up to 128 frames.
|
||||
- [Project Page](https://ecnu-cilab.github.io/ExVideoProjectPage/)
|
||||
- Source code is released in this repo. See [`examples/ExVideo`](./examples/ExVideo/).
|
||||
- Models are released on [HuggingFace](https://huggingface.co/ECNU-CILab/ExVideo-SVD-128f-v1) and [ModelScope](https://modelscope.cn/models/ECNU-CILab/ExVideo-SVD-128f-v1).
|
||||
|
||||
@@ -37,6 +37,7 @@ from ..models.flux_text_encoder import FluxTextEncoder2
|
||||
from ..models.flux_vae import FluxVAEEncoder, FluxVAEDecoder
|
||||
from ..models.flux_controlnet import FluxControlNet
|
||||
from ..models.flux_ipadapter import FluxIpAdapter
|
||||
from ..models.flux_infiniteyou import InfiniteYouImageProjector
|
||||
|
||||
from ..models.cog_vae import CogVAEEncoder, CogVAEDecoder
|
||||
from ..models.cog_dit import CogDiT
|
||||
@@ -95,6 +96,7 @@ model_loader_configs = [
|
||||
(None, "57b02550baab820169365b3ee3afa2c9", ["flux_dit"], [FluxDiT], "civitai"),
|
||||
(None, "3394f306c4cbf04334b712bf5aaed95f", ["flux_dit"], [FluxDiT], "civitai"),
|
||||
(None, "023f054d918a84ccf503481fd1e3379e", ["flux_dit"], [FluxDiT], "civitai"),
|
||||
(None, "605c56eab23e9e2af863ad8f0813a25d", ["flux_dit"], [FluxDiT], "diffusers"),
|
||||
(None, "280189ee084bca10f70907bf6ce1649d", ["cog_vae_encoder", "cog_vae_decoder"], [CogVAEEncoder, CogVAEDecoder], "diffusers"),
|
||||
(None, "9b9313d104ac4df27991352fec013fd4", ["rife"], [IFNet], "civitai"),
|
||||
(None, "6b7116078c4170bfbeaedc8fe71f6649", ["esrgan"], [RRDBNet], "civitai"),
|
||||
@@ -103,6 +105,8 @@ model_loader_configs = [
|
||||
(None, "b001c89139b5f053c715fe772362dd2a", ["flux_controlnet"], [FluxControlNet], "diffusers"),
|
||||
(None, "52357cb26250681367488a8954c271e8", ["flux_controlnet"], [FluxControlNet], "diffusers"),
|
||||
(None, "0cfd1740758423a2a854d67c136d1e8c", ["flux_controlnet"], [FluxControlNet], "diffusers"),
|
||||
(None, "7f9583eb8ba86642abb9a21a4b2c9e16", ["flux_controlnet"], [FluxControlNet], "diffusers"),
|
||||
(None, "c07c0f04f5ff55e86b4e937c7a40d481", ["infiniteyou_image_projector"], [InfiniteYouImageProjector], "diffusers"),
|
||||
(None, "4daaa66cc656a8fe369908693dad0a35", ["flux_ipadapter"], [FluxIpAdapter], "diffusers"),
|
||||
(None, "51aed3d27d482fceb5e0739b03060e8f", ["sd3_dit", "sd3_vae_encoder", "sd3_vae_decoder"], [SD3DiT, SD3VAEEncoder, SD3VAEDecoder], "civitai"),
|
||||
(None, "98cc34ccc5b54ae0e56bdea8688dcd5a", ["sd3_text_encoder_2"], [SD3TextEncoder2], "civitai"),
|
||||
@@ -116,6 +120,7 @@ model_loader_configs = [
|
||||
(None, "9269f8db9040a9d860eaca435be61814", ["wan_video_dit"], [WanModel], "civitai"),
|
||||
(None, "aafcfd9672c3a2456dc46e1cb6e52c70", ["wan_video_dit"], [WanModel], "civitai"),
|
||||
(None, "6bfcfb3b342cb286ce886889d519a77e", ["wan_video_dit"], [WanModel], "civitai"),
|
||||
(None, "cb104773c6c2cb6df4f9529ad5c60d0b", ["wan_video_dit"], [WanModel], "diffusers"),
|
||||
(None, "9c8818c2cbea55eca56c7b447df170da", ["wan_video_text_encoder"], [WanTextEncoder], "civitai"),
|
||||
(None, "5941c53e207d62f20f9025686193c40b", ["wan_video_image_encoder"], [WanImageEncoder], "civitai"),
|
||||
(None, "1378ea763357eea97acdef78e65d6d96", ["wan_video_vae"], [WanVideoVAE], "civitai"),
|
||||
@@ -133,6 +138,7 @@ huggingface_model_loader_configs = [
|
||||
("CogVideoXTransformer3DModel", "diffsynth.models.cog_dit", "cog_dit", "CogDiT"),
|
||||
("SiglipModel", "transformers.models.siglip.modeling_siglip", "siglip_vision_model", "SiglipVisionModel"),
|
||||
("LlamaForCausalLM", "diffsynth.models.hunyuan_video_text_encoder", "hunyuan_video_text_encoder_2", "HunyuanVideoLLMEncoder"),
|
||||
("LlavaForConditionalGeneration", "diffsynth.models.hunyuan_video_text_encoder", "hunyuan_video_text_encoder_2", "HunyuanVideoMLLMEncoder"),
|
||||
("Step1Model", "diffsynth.models.stepvideo_text_encoder", "stepvideo_text_encoder_2", "STEP1TextEncoder"),
|
||||
]
|
||||
patch_model_loader_configs = [
|
||||
@@ -595,6 +601,25 @@ preset_models_on_modelscope = {
|
||||
"models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/image_encoder",
|
||||
],
|
||||
},
|
||||
"InfiniteYou":{
|
||||
"file_list":[
|
||||
("ByteDance/InfiniteYou", "infu_flux_v1.0/aes_stage2/InfuseNetModel/diffusion_pytorch_model-00001-of-00002.safetensors", "models/InfiniteYou/InfuseNetModel"),
|
||||
("ByteDance/InfiniteYou", "infu_flux_v1.0/aes_stage2/InfuseNetModel/diffusion_pytorch_model-00002-of-00002.safetensors", "models/InfiniteYou/InfuseNetModel"),
|
||||
("ByteDance/InfiniteYou", "infu_flux_v1.0/aes_stage2/image_proj_model.bin", "models/InfiniteYou"),
|
||||
("ByteDance/InfiniteYou", "supports/insightface/models/antelopev2/1k3d68.onnx", "models/InfiniteYou/insightface/models/antelopev2"),
|
||||
("ByteDance/InfiniteYou", "supports/insightface/models/antelopev2/2d106det.onnx", "models/InfiniteYou/insightface/models/antelopev2"),
|
||||
("ByteDance/InfiniteYou", "supports/insightface/models/antelopev2/genderage.onnx", "models/InfiniteYou/insightface/models/antelopev2"),
|
||||
("ByteDance/InfiniteYou", "supports/insightface/models/antelopev2/glintr100.onnx", "models/InfiniteYou/insightface/models/antelopev2"),
|
||||
("ByteDance/InfiniteYou", "supports/insightface/models/antelopev2/scrfd_10g_bnkps.onnx", "models/InfiniteYou/insightface/models/antelopev2"),
|
||||
],
|
||||
"load_path":[
|
||||
[
|
||||
"models/InfiniteYou/InfuseNetModel/diffusion_pytorch_model-00001-of-00002.safetensors",
|
||||
"models/InfiniteYou/InfuseNetModel/diffusion_pytorch_model-00002-of-00002.safetensors"
|
||||
],
|
||||
"models/InfiniteYou/image_proj_model.bin",
|
||||
],
|
||||
},
|
||||
# ESRGAN
|
||||
"ESRGAN_x4": [
|
||||
("AI-ModelScope/Real-ESRGAN", "RealESRGAN_x4.pth", "models/ESRGAN"),
|
||||
@@ -675,6 +700,25 @@ preset_models_on_modelscope = {
|
||||
"models/HunyuanVideo/transformers/mp_rank_00_model_states.pt"
|
||||
],
|
||||
},
|
||||
"HunyuanVideoI2V":{
|
||||
"file_list": [
|
||||
("AI-ModelScope/clip-vit-large-patch14", "model.safetensors", "models/HunyuanVideoI2V/text_encoder"),
|
||||
("AI-ModelScope/llava-llama-3-8b-v1_1-transformers", "model-00001-of-00004.safetensors", "models/HunyuanVideoI2V/text_encoder_2"),
|
||||
("AI-ModelScope/llava-llama-3-8b-v1_1-transformers", "model-00002-of-00004.safetensors", "models/HunyuanVideoI2V/text_encoder_2"),
|
||||
("AI-ModelScope/llava-llama-3-8b-v1_1-transformers", "model-00003-of-00004.safetensors", "models/HunyuanVideoI2V/text_encoder_2"),
|
||||
("AI-ModelScope/llava-llama-3-8b-v1_1-transformers", "model-00004-of-00004.safetensors", "models/HunyuanVideoI2V/text_encoder_2"),
|
||||
("AI-ModelScope/llava-llama-3-8b-v1_1-transformers", "config.json", "models/HunyuanVideoI2V/text_encoder_2"),
|
||||
("AI-ModelScope/llava-llama-3-8b-v1_1-transformers", "model.safetensors.index.json", "models/HunyuanVideoI2V/text_encoder_2"),
|
||||
("AI-ModelScope/HunyuanVideo-I2V", "hunyuan-video-i2v-720p/vae/pytorch_model.pt", "models/HunyuanVideoI2V/vae"),
|
||||
("AI-ModelScope/HunyuanVideo-I2V", "hunyuan-video-i2v-720p/transformers/mp_rank_00_model_states.pt", "models/HunyuanVideoI2V/transformers")
|
||||
],
|
||||
"load_path": [
|
||||
"models/HunyuanVideoI2V/text_encoder/model.safetensors",
|
||||
"models/HunyuanVideoI2V/text_encoder_2",
|
||||
"models/HunyuanVideoI2V/vae/pytorch_model.pt",
|
||||
"models/HunyuanVideoI2V/transformers/mp_rank_00_model_states.pt"
|
||||
],
|
||||
},
|
||||
"HunyuanVideo-fp8":{
|
||||
"file_list": [
|
||||
("AI-ModelScope/clip-vit-large-patch14", "model.safetensors", "models/HunyuanVideo/text_encoder"),
|
||||
@@ -735,6 +779,7 @@ Preset_model_id: TypeAlias = Literal[
|
||||
"Shakker-Labs/FLUX.1-dev-ControlNet-Depth",
|
||||
"Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro",
|
||||
"InstantX/FLUX.1-dev-IP-Adapter",
|
||||
"InfiniteYou",
|
||||
"SDXL_lora_zyd232_ChineseInkStyle_SDXL_v1_0",
|
||||
"QwenPrompt",
|
||||
"OmostPrompt",
|
||||
@@ -751,4 +796,5 @@ Preset_model_id: TypeAlias = Literal[
|
||||
"StableDiffusion3.5-medium",
|
||||
"HunyuanVideo",
|
||||
"HunyuanVideo-fp8",
|
||||
"HunyuanVideoI2V",
|
||||
]
|
||||
|
||||
@@ -1,10 +1,4 @@
|
||||
from typing_extensions import Literal, TypeAlias
|
||||
import warnings
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore")
|
||||
from controlnet_aux.processor import (
|
||||
CannyDetector, MidasDetector, HEDdetector, LineartDetector, LineartAnimeDetector, OpenposeDetector, NormalBaeDetector
|
||||
)
|
||||
|
||||
|
||||
Processor_id: TypeAlias = Literal[
|
||||
@@ -15,18 +9,25 @@ class Annotator:
|
||||
def __init__(self, processor_id: Processor_id, model_path="models/Annotators", detect_resolution=None, device='cuda', skip_processor=False):
|
||||
if not skip_processor:
|
||||
if processor_id == "canny":
|
||||
from controlnet_aux.processor import CannyDetector
|
||||
self.processor = CannyDetector()
|
||||
elif processor_id == "depth":
|
||||
from controlnet_aux.processor import MidasDetector
|
||||
self.processor = MidasDetector.from_pretrained(model_path).to(device)
|
||||
elif processor_id == "softedge":
|
||||
from controlnet_aux.processor import HEDdetector
|
||||
self.processor = HEDdetector.from_pretrained(model_path).to(device)
|
||||
elif processor_id == "lineart":
|
||||
from controlnet_aux.processor import LineartDetector
|
||||
self.processor = LineartDetector.from_pretrained(model_path).to(device)
|
||||
elif processor_id == "lineart_anime":
|
||||
from controlnet_aux.processor import LineartAnimeDetector
|
||||
self.processor = LineartAnimeDetector.from_pretrained(model_path).to(device)
|
||||
elif processor_id == "openpose":
|
||||
from controlnet_aux.processor import OpenposeDetector
|
||||
self.processor = OpenposeDetector.from_pretrained(model_path).to(device)
|
||||
elif processor_id == "normal":
|
||||
from controlnet_aux.processor import NormalBaeDetector
|
||||
self.processor = NormalBaeDetector.from_pretrained(model_path).to(device)
|
||||
elif processor_id == "tile" or processor_id == "none" or processor_id == "inpaint":
|
||||
self.processor = None
|
||||
|
||||
0
diffsynth/distributed/__init__.py
Normal file
0
diffsynth/distributed/__init__.py
Normal file
129
diffsynth/distributed/xdit_context_parallel.py
Normal file
129
diffsynth/distributed/xdit_context_parallel.py
Normal file
@@ -0,0 +1,129 @@
|
||||
import torch
|
||||
from typing import Optional
|
||||
from einops import rearrange
|
||||
from xfuser.core.distributed import (get_sequence_parallel_rank,
|
||||
get_sequence_parallel_world_size,
|
||||
get_sp_group)
|
||||
from xfuser.core.long_ctx_attention import xFuserLongContextAttention
|
||||
|
||||
def sinusoidal_embedding_1d(dim, position):
|
||||
sinusoid = torch.outer(position.type(torch.float64), torch.pow(
|
||||
10000, -torch.arange(dim//2, dtype=torch.float64, device=position.device).div(dim//2)))
|
||||
x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1)
|
||||
return x.to(position.dtype)
|
||||
|
||||
def pad_freqs(original_tensor, target_len):
|
||||
seq_len, s1, s2 = original_tensor.shape
|
||||
pad_size = target_len - seq_len
|
||||
padding_tensor = torch.ones(
|
||||
pad_size,
|
||||
s1,
|
||||
s2,
|
||||
dtype=original_tensor.dtype,
|
||||
device=original_tensor.device)
|
||||
padded_tensor = torch.cat([original_tensor, padding_tensor], dim=0)
|
||||
return padded_tensor
|
||||
|
||||
def rope_apply(x, freqs, num_heads):
|
||||
x = rearrange(x, "b s (n d) -> b s n d", n=num_heads)
|
||||
s_per_rank = x.shape[1]
|
||||
|
||||
x_out = torch.view_as_complex(x.to(torch.float64).reshape(
|
||||
x.shape[0], x.shape[1], x.shape[2], -1, 2))
|
||||
|
||||
sp_size = get_sequence_parallel_world_size()
|
||||
sp_rank = get_sequence_parallel_rank()
|
||||
freqs = pad_freqs(freqs, s_per_rank * sp_size)
|
||||
freqs_rank = freqs[(sp_rank * s_per_rank):((sp_rank + 1) * s_per_rank), :, :]
|
||||
|
||||
x_out = torch.view_as_real(x_out * freqs_rank).flatten(2)
|
||||
return x_out.to(x.dtype)
|
||||
|
||||
def usp_dit_forward(self,
|
||||
x: torch.Tensor,
|
||||
timestep: torch.Tensor,
|
||||
context: torch.Tensor,
|
||||
clip_feature: Optional[torch.Tensor] = None,
|
||||
y: Optional[torch.Tensor] = None,
|
||||
use_gradient_checkpointing: bool = False,
|
||||
use_gradient_checkpointing_offload: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
t = self.time_embedding(
|
||||
sinusoidal_embedding_1d(self.freq_dim, timestep))
|
||||
t_mod = self.time_projection(t).unflatten(1, (6, self.dim))
|
||||
context = self.text_embedding(context)
|
||||
|
||||
if self.has_image_input:
|
||||
x = torch.cat([x, y], dim=1) # (b, c_x + c_y, f, h, w)
|
||||
clip_embdding = self.img_emb(clip_feature)
|
||||
context = torch.cat([clip_embdding, context], dim=1)
|
||||
|
||||
x, (f, h, w) = self.patchify(x)
|
||||
|
||||
freqs = torch.cat([
|
||||
self.freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
|
||||
self.freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
|
||||
self.freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
|
||||
], dim=-1).reshape(f * h * w, 1, -1).to(x.device)
|
||||
|
||||
def create_custom_forward(module):
|
||||
def custom_forward(*inputs):
|
||||
return module(*inputs)
|
||||
return custom_forward
|
||||
|
||||
# Context Parallel
|
||||
x = torch.chunk(
|
||||
x, get_sequence_parallel_world_size(),
|
||||
dim=1)[get_sequence_parallel_rank()]
|
||||
|
||||
for block in self.blocks:
|
||||
if self.training and use_gradient_checkpointing:
|
||||
if use_gradient_checkpointing_offload:
|
||||
with torch.autograd.graph.save_on_cpu():
|
||||
x = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(block),
|
||||
x, context, t_mod, freqs,
|
||||
use_reentrant=False,
|
||||
)
|
||||
else:
|
||||
x = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(block),
|
||||
x, context, t_mod, freqs,
|
||||
use_reentrant=False,
|
||||
)
|
||||
else:
|
||||
x = block(x, context, t_mod, freqs)
|
||||
|
||||
x = self.head(x, t)
|
||||
|
||||
# Context Parallel
|
||||
x = get_sp_group().all_gather(x, dim=1)
|
||||
|
||||
# unpatchify
|
||||
x = self.unpatchify(x, (f, h, w))
|
||||
return x
|
||||
|
||||
|
||||
def usp_attn_forward(self, x, freqs):
|
||||
q = self.norm_q(self.q(x))
|
||||
k = self.norm_k(self.k(x))
|
||||
v = self.v(x)
|
||||
|
||||
q = rope_apply(q, freqs, self.num_heads)
|
||||
k = rope_apply(k, freqs, self.num_heads)
|
||||
q = rearrange(q, "b s (n d) -> b s n d", n=self.num_heads)
|
||||
k = rearrange(k, "b s (n d) -> b s n d", n=self.num_heads)
|
||||
v = rearrange(v, "b s (n d) -> b s n d", n=self.num_heads)
|
||||
|
||||
x = xFuserLongContextAttention()(
|
||||
None,
|
||||
query=q,
|
||||
key=k,
|
||||
value=v,
|
||||
)
|
||||
x = x.flatten(2)
|
||||
|
||||
del q, k, v
|
||||
torch.cuda.empty_cache()
|
||||
return self.o(x)
|
||||
@@ -5,7 +5,7 @@ import pathlib
|
||||
import re
|
||||
from copy import deepcopy
|
||||
from pathlib import Path
|
||||
from turtle import forward
|
||||
# from turtle import forward
|
||||
from typing import Any, Dict, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
|
||||
@@ -318,6 +318,8 @@ class FluxControlNetStateDictConverter:
|
||||
extra_kwargs = {"num_joint_blocks": 6, "num_single_blocks": 0, "additional_input_dim": 4}
|
||||
elif hash_value == "0cfd1740758423a2a854d67c136d1e8c":
|
||||
extra_kwargs = {"num_joint_blocks": 4, "num_single_blocks": 1}
|
||||
elif hash_value == "7f9583eb8ba86642abb9a21a4b2c9e16":
|
||||
extra_kwargs = {"num_joint_blocks": 4, "num_single_blocks": 10}
|
||||
else:
|
||||
extra_kwargs = {}
|
||||
return state_dict_, extra_kwargs
|
||||
|
||||
@@ -628,19 +628,22 @@ class FluxDiTStateDictConverter:
|
||||
else:
|
||||
pass
|
||||
for name in list(state_dict_.keys()):
|
||||
if ".proj_in_besides_attn." in name:
|
||||
name_ = name.replace(".proj_in_besides_attn.", ".to_qkv_mlp.")
|
||||
if "single_blocks." in name and ".a_to_q." in name:
|
||||
mlp = state_dict_.get(name.replace(".a_to_q.", ".proj_in_besides_attn."), None)
|
||||
if mlp is None:
|
||||
mlp = torch.zeros(4 * state_dict_[name].shape[0],
|
||||
*state_dict_[name].shape[1:],
|
||||
dtype=state_dict_[name].dtype)
|
||||
else:
|
||||
state_dict_.pop(name.replace(".a_to_q.", ".proj_in_besides_attn."))
|
||||
param = torch.concat([
|
||||
state_dict_[name.replace(".proj_in_besides_attn.", f".a_to_q.")],
|
||||
state_dict_[name.replace(".proj_in_besides_attn.", f".a_to_k.")],
|
||||
state_dict_[name.replace(".proj_in_besides_attn.", f".a_to_v.")],
|
||||
state_dict_[name],
|
||||
state_dict_.pop(name),
|
||||
state_dict_.pop(name.replace(".a_to_q.", ".a_to_k.")),
|
||||
state_dict_.pop(name.replace(".a_to_q.", ".a_to_v.")),
|
||||
mlp,
|
||||
], dim=0)
|
||||
name_ = name.replace(".a_to_q.", ".to_qkv_mlp.")
|
||||
state_dict_[name_] = param
|
||||
state_dict_.pop(name.replace(".proj_in_besides_attn.", f".a_to_q."))
|
||||
state_dict_.pop(name.replace(".proj_in_besides_attn.", f".a_to_k."))
|
||||
state_dict_.pop(name.replace(".proj_in_besides_attn.", f".a_to_v."))
|
||||
state_dict_.pop(name)
|
||||
for name in list(state_dict_.keys()):
|
||||
for component in ["a", "b"]:
|
||||
if f".{component}_to_q." in name:
|
||||
|
||||
128
diffsynth/models/flux_infiniteyou.py
Normal file
128
diffsynth/models/flux_infiniteyou.py
Normal file
@@ -0,0 +1,128 @@
|
||||
import math
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
# FFN
|
||||
def FeedForward(dim, mult=4):
|
||||
inner_dim = int(dim * mult)
|
||||
return nn.Sequential(
|
||||
nn.LayerNorm(dim),
|
||||
nn.Linear(dim, inner_dim, bias=False),
|
||||
nn.GELU(),
|
||||
nn.Linear(inner_dim, dim, bias=False),
|
||||
)
|
||||
|
||||
|
||||
def reshape_tensor(x, heads):
|
||||
bs, length, width = x.shape
|
||||
#(bs, length, width) --> (bs, length, n_heads, dim_per_head)
|
||||
x = x.view(bs, length, heads, -1)
|
||||
# (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head)
|
||||
x = x.transpose(1, 2)
|
||||
# (bs, n_heads, length, dim_per_head) --> (bs*n_heads, length, dim_per_head)
|
||||
x = x.reshape(bs, heads, length, -1)
|
||||
return x
|
||||
|
||||
|
||||
class PerceiverAttention(nn.Module):
|
||||
|
||||
def __init__(self, *, dim, dim_head=64, heads=8):
|
||||
super().__init__()
|
||||
self.scale = dim_head**-0.5
|
||||
self.dim_head = dim_head
|
||||
self.heads = heads
|
||||
inner_dim = dim_head * heads
|
||||
|
||||
self.norm1 = nn.LayerNorm(dim)
|
||||
self.norm2 = nn.LayerNorm(dim)
|
||||
|
||||
self.to_q = nn.Linear(dim, inner_dim, bias=False)
|
||||
self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
|
||||
self.to_out = nn.Linear(inner_dim, dim, bias=False)
|
||||
|
||||
def forward(self, x, latents):
|
||||
"""
|
||||
Args:
|
||||
x (torch.Tensor): image features
|
||||
shape (b, n1, D)
|
||||
latent (torch.Tensor): latent features
|
||||
shape (b, n2, D)
|
||||
"""
|
||||
x = self.norm1(x)
|
||||
latents = self.norm2(latents)
|
||||
|
||||
b, l, _ = latents.shape
|
||||
|
||||
q = self.to_q(latents)
|
||||
kv_input = torch.cat((x, latents), dim=-2)
|
||||
k, v = self.to_kv(kv_input).chunk(2, dim=-1)
|
||||
|
||||
q = reshape_tensor(q, self.heads)
|
||||
k = reshape_tensor(k, self.heads)
|
||||
v = reshape_tensor(v, self.heads)
|
||||
|
||||
# attention
|
||||
scale = 1 / math.sqrt(math.sqrt(self.dim_head))
|
||||
weight = (q * scale) @ (k * scale).transpose(-2, -1) # More stable with f16 than dividing afterwards
|
||||
weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
|
||||
out = weight @ v
|
||||
|
||||
out = out.permute(0, 2, 1, 3).reshape(b, l, -1)
|
||||
|
||||
return self.to_out(out)
|
||||
|
||||
|
||||
class InfiniteYouImageProjector(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim=1280,
|
||||
depth=4,
|
||||
dim_head=64,
|
||||
heads=20,
|
||||
num_queries=8,
|
||||
embedding_dim=512,
|
||||
output_dim=4096,
|
||||
ff_mult=4,
|
||||
):
|
||||
super().__init__()
|
||||
self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim**0.5)
|
||||
self.proj_in = nn.Linear(embedding_dim, dim)
|
||||
|
||||
self.proj_out = nn.Linear(dim, output_dim)
|
||||
self.norm_out = nn.LayerNorm(output_dim)
|
||||
|
||||
self.layers = nn.ModuleList([])
|
||||
for _ in range(depth):
|
||||
self.layers.append(
|
||||
nn.ModuleList([
|
||||
PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
|
||||
FeedForward(dim=dim, mult=ff_mult),
|
||||
]))
|
||||
|
||||
def forward(self, x):
|
||||
|
||||
latents = self.latents.repeat(x.size(0), 1, 1)
|
||||
|
||||
x = self.proj_in(x)
|
||||
|
||||
for attn, ff in self.layers:
|
||||
latents = attn(x, latents) + latents
|
||||
latents = ff(latents) + latents
|
||||
|
||||
latents = self.proj_out(latents)
|
||||
return self.norm_out(latents)
|
||||
|
||||
@staticmethod
|
||||
def state_dict_converter():
|
||||
return FluxInfiniteYouImageProjectorStateDictConverter()
|
||||
|
||||
|
||||
class FluxInfiniteYouImageProjectorStateDictConverter:
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def from_diffusers(self, state_dict):
|
||||
return state_dict['image_proj']
|
||||
@@ -4,6 +4,7 @@ from .utils import init_weights_on_device
|
||||
from einops import rearrange, repeat
|
||||
from tqdm import tqdm
|
||||
from typing import Union, Tuple, List
|
||||
from .utils import hash_state_dict_keys
|
||||
|
||||
|
||||
def HunyuanVideoRope(latents):
|
||||
@@ -236,7 +237,7 @@ class IndividualTokenRefinerBlock(torch.nn.Module):
|
||||
x = x + self.mlp(self.norm2(x)) * gate_mlp.unsqueeze(1)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
|
||||
class SingleTokenRefiner(torch.nn.Module):
|
||||
def __init__(self, in_channels=4096, hidden_size=3072, depth=2):
|
||||
@@ -269,7 +270,7 @@ class SingleTokenRefiner(torch.nn.Module):
|
||||
x = block(x, c, mask)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
|
||||
class ModulateDiT(torch.nn.Module):
|
||||
def __init__(self, hidden_size, factor=6):
|
||||
@@ -279,9 +280,14 @@ class ModulateDiT(torch.nn.Module):
|
||||
|
||||
def forward(self, x):
|
||||
return self.linear(self.act(x))
|
||||
|
||||
|
||||
def modulate(x, shift=None, scale=None):
|
||||
|
||||
def modulate(x, shift=None, scale=None, tr_shift=None, tr_scale=None, tr_token=None):
|
||||
if tr_shift is not None:
|
||||
x_zero = x[:, :tr_token] * (1 + tr_scale.unsqueeze(1)) + tr_shift.unsqueeze(1)
|
||||
x_orig = x[:, tr_token:] * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
|
||||
x = torch.concat((x_zero, x_orig), dim=1)
|
||||
return x
|
||||
if scale is None and shift is None:
|
||||
return x
|
||||
elif shift is None:
|
||||
@@ -290,7 +296,7 @@ def modulate(x, shift=None, scale=None):
|
||||
return x + shift.unsqueeze(1)
|
||||
else:
|
||||
return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
|
||||
|
||||
|
||||
|
||||
def reshape_for_broadcast(
|
||||
freqs_cis,
|
||||
@@ -343,7 +349,7 @@ def rotate_half(x):
|
||||
x.float().reshape(*x.shape[:-1], -1, 2).unbind(-1)
|
||||
) # [B, S, H, D//2]
|
||||
return torch.stack([-x_imag, x_real], dim=-1).flatten(3)
|
||||
|
||||
|
||||
|
||||
def apply_rotary_emb(
|
||||
xq: torch.Tensor,
|
||||
@@ -385,6 +391,15 @@ def attention(q, k, v):
|
||||
return x
|
||||
|
||||
|
||||
def apply_gate(x, gate, tr_gate=None, tr_token=None):
|
||||
if tr_gate is not None:
|
||||
x_zero = x[:, :tr_token] * tr_gate.unsqueeze(1)
|
||||
x_orig = x[:, tr_token:] * gate.unsqueeze(1)
|
||||
return torch.concat((x_zero, x_orig), dim=1)
|
||||
else:
|
||||
return x * gate.unsqueeze(1)
|
||||
|
||||
|
||||
class MMDoubleStreamBlockComponent(torch.nn.Module):
|
||||
def __init__(self, hidden_size=3072, heads_num=24, mlp_width_ratio=4):
|
||||
super().__init__()
|
||||
@@ -405,11 +420,17 @@ class MMDoubleStreamBlockComponent(torch.nn.Module):
|
||||
torch.nn.Linear(hidden_size * mlp_width_ratio, hidden_size)
|
||||
)
|
||||
|
||||
def forward(self, hidden_states, conditioning, freqs_cis=None):
|
||||
def forward(self, hidden_states, conditioning, freqs_cis=None, token_replace_vec=None, tr_token=None):
|
||||
mod1_shift, mod1_scale, mod1_gate, mod2_shift, mod2_scale, mod2_gate = self.mod(conditioning).chunk(6, dim=-1)
|
||||
if token_replace_vec is not None:
|
||||
assert tr_token is not None
|
||||
tr_mod1_shift, tr_mod1_scale, tr_mod1_gate, tr_mod2_shift, tr_mod2_scale, tr_mod2_gate = self.mod(token_replace_vec).chunk(6, dim=-1)
|
||||
else:
|
||||
tr_mod1_shift, tr_mod1_scale, tr_mod1_gate, tr_mod2_shift, tr_mod2_scale, tr_mod2_gate = None, None, None, None, None, None
|
||||
|
||||
norm_hidden_states = self.norm1(hidden_states)
|
||||
norm_hidden_states = modulate(norm_hidden_states, shift=mod1_shift, scale=mod1_scale)
|
||||
norm_hidden_states = modulate(norm_hidden_states, shift=mod1_shift, scale=mod1_scale,
|
||||
tr_shift=tr_mod1_shift, tr_scale=tr_mod1_scale, tr_token=tr_token)
|
||||
qkv = self.to_qkv(norm_hidden_states)
|
||||
q, k, v = rearrange(qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num)
|
||||
|
||||
@@ -418,15 +439,19 @@ class MMDoubleStreamBlockComponent(torch.nn.Module):
|
||||
|
||||
if freqs_cis is not None:
|
||||
q, k = apply_rotary_emb(q, k, freqs_cis, head_first=False)
|
||||
return (q, k, v), (mod1_gate, mod2_shift, mod2_scale, mod2_gate), (tr_mod1_gate, tr_mod2_shift, tr_mod2_scale, tr_mod2_gate)
|
||||
|
||||
return (q, k, v), (mod1_gate, mod2_shift, mod2_scale, mod2_gate)
|
||||
|
||||
def process_ff(self, hidden_states, attn_output, mod):
|
||||
def process_ff(self, hidden_states, attn_output, mod, mod_tr=None, tr_token=None):
|
||||
mod1_gate, mod2_shift, mod2_scale, mod2_gate = mod
|
||||
hidden_states = hidden_states + self.to_out(attn_output) * mod1_gate.unsqueeze(1)
|
||||
hidden_states = hidden_states + self.ff(modulate(self.norm2(hidden_states), shift=mod2_shift, scale=mod2_scale)) * mod2_gate.unsqueeze(1)
|
||||
if mod_tr is not None:
|
||||
tr_mod1_gate, tr_mod2_shift, tr_mod2_scale, tr_mod2_gate = mod_tr
|
||||
else:
|
||||
tr_mod1_gate, tr_mod2_shift, tr_mod2_scale, tr_mod2_gate = None, None, None, None
|
||||
hidden_states = hidden_states + apply_gate(self.to_out(attn_output), mod1_gate, tr_mod1_gate, tr_token)
|
||||
x = self.ff(modulate(self.norm2(hidden_states), shift=mod2_shift, scale=mod2_scale, tr_shift=tr_mod2_shift, tr_scale=tr_mod2_scale, tr_token=tr_token))
|
||||
hidden_states = hidden_states + apply_gate(x, mod2_gate, tr_mod2_gate, tr_token)
|
||||
return hidden_states
|
||||
|
||||
|
||||
|
||||
class MMDoubleStreamBlock(torch.nn.Module):
|
||||
def __init__(self, hidden_size=3072, heads_num=24, mlp_width_ratio=4):
|
||||
@@ -434,18 +459,18 @@ class MMDoubleStreamBlock(torch.nn.Module):
|
||||
self.component_a = MMDoubleStreamBlockComponent(hidden_size, heads_num, mlp_width_ratio)
|
||||
self.component_b = MMDoubleStreamBlockComponent(hidden_size, heads_num, mlp_width_ratio)
|
||||
|
||||
def forward(self, hidden_states_a, hidden_states_b, conditioning, freqs_cis):
|
||||
(q_a, k_a, v_a), mod_a = self.component_a(hidden_states_a, conditioning, freqs_cis)
|
||||
(q_b, k_b, v_b), mod_b = self.component_b(hidden_states_b, conditioning, freqs_cis=None)
|
||||
def forward(self, hidden_states_a, hidden_states_b, conditioning, freqs_cis, token_replace_vec=None, tr_token=None, split_token=71):
|
||||
(q_a, k_a, v_a), mod_a, mod_tr = self.component_a(hidden_states_a, conditioning, freqs_cis, token_replace_vec, tr_token)
|
||||
(q_b, k_b, v_b), mod_b, _ = self.component_b(hidden_states_b, conditioning, freqs_cis=None)
|
||||
|
||||
q_a, q_b = torch.concat([q_a, q_b[:, :71]], dim=1), q_b[:, 71:].contiguous()
|
||||
k_a, k_b = torch.concat([k_a, k_b[:, :71]], dim=1), k_b[:, 71:].contiguous()
|
||||
v_a, v_b = torch.concat([v_a, v_b[:, :71]], dim=1), v_b[:, 71:].contiguous()
|
||||
q_a, q_b = torch.concat([q_a, q_b[:, :split_token]], dim=1), q_b[:, split_token:].contiguous()
|
||||
k_a, k_b = torch.concat([k_a, k_b[:, :split_token]], dim=1), k_b[:, split_token:].contiguous()
|
||||
v_a, v_b = torch.concat([v_a, v_b[:, :split_token]], dim=1), v_b[:, split_token:].contiguous()
|
||||
attn_output_a = attention(q_a, k_a, v_a)
|
||||
attn_output_b = attention(q_b, k_b, v_b)
|
||||
attn_output_a, attn_output_b = attn_output_a[:, :-71].contiguous(), torch.concat([attn_output_a[:, -71:], attn_output_b], dim=1)
|
||||
attn_output_a, attn_output_b = attn_output_a[:, :-split_token].contiguous(), torch.concat([attn_output_a[:, -split_token:], attn_output_b], dim=1)
|
||||
|
||||
hidden_states_a = self.component_a.process_ff(hidden_states_a, attn_output_a, mod_a)
|
||||
hidden_states_a = self.component_a.process_ff(hidden_states_a, attn_output_a, mod_a, mod_tr, tr_token)
|
||||
hidden_states_b = self.component_b.process_ff(hidden_states_b, attn_output_b, mod_b)
|
||||
return hidden_states_a, hidden_states_b
|
||||
|
||||
@@ -488,7 +513,7 @@ class MMSingleStreamBlockOriginal(torch.nn.Module):
|
||||
|
||||
output = self.linear2(torch.cat((attn_output, self.mlp_act(mlp)), 2))
|
||||
return x + output * mod_gate.unsqueeze(1)
|
||||
|
||||
|
||||
|
||||
class MMSingleStreamBlock(torch.nn.Module):
|
||||
def __init__(self, hidden_size=3072, heads_num=24, mlp_width_ratio=4):
|
||||
@@ -509,11 +534,17 @@ class MMSingleStreamBlock(torch.nn.Module):
|
||||
torch.nn.Linear(hidden_size * mlp_width_ratio, hidden_size, bias=False)
|
||||
)
|
||||
|
||||
def forward(self, hidden_states, conditioning, freqs_cis=None, txt_len=256):
|
||||
def forward(self, hidden_states, conditioning, freqs_cis=None, txt_len=256, token_replace_vec=None, tr_token=None, split_token=71):
|
||||
mod_shift, mod_scale, mod_gate = self.mod(conditioning).chunk(3, dim=-1)
|
||||
if token_replace_vec is not None:
|
||||
assert tr_token is not None
|
||||
tr_mod_shift, tr_mod_scale, tr_mod_gate = self.mod(token_replace_vec).chunk(3, dim=-1)
|
||||
else:
|
||||
tr_mod_shift, tr_mod_scale, tr_mod_gate = None, None, None
|
||||
|
||||
norm_hidden_states = self.norm(hidden_states)
|
||||
norm_hidden_states = modulate(norm_hidden_states, shift=mod_shift, scale=mod_scale)
|
||||
norm_hidden_states = modulate(norm_hidden_states, shift=mod_shift, scale=mod_scale,
|
||||
tr_shift=tr_mod_shift, tr_scale=tr_mod_scale, tr_token=tr_token)
|
||||
qkv = self.to_qkv(norm_hidden_states)
|
||||
|
||||
q, k, v = rearrange(qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num)
|
||||
@@ -525,16 +556,17 @@ class MMSingleStreamBlock(torch.nn.Module):
|
||||
k_a, k_b = k[:, :-txt_len, :, :], k[:, -txt_len:, :, :]
|
||||
q_a, k_a = apply_rotary_emb(q_a, k_a, freqs_cis, head_first=False)
|
||||
|
||||
q_a, q_b = torch.concat([q_a, q_b[:, :71]], dim=1), q_b[:, 71:].contiguous()
|
||||
k_a, k_b = torch.concat([k_a, k_b[:, :71]], dim=1), k_b[:, 71:].contiguous()
|
||||
v_a, v_b = v[:, :-185].contiguous(), v[:, -185:].contiguous()
|
||||
v_len = txt_len - split_token
|
||||
q_a, q_b = torch.concat([q_a, q_b[:, :split_token]], dim=1), q_b[:, split_token:].contiguous()
|
||||
k_a, k_b = torch.concat([k_a, k_b[:, :split_token]], dim=1), k_b[:, split_token:].contiguous()
|
||||
v_a, v_b = v[:, :-v_len].contiguous(), v[:, -v_len:].contiguous()
|
||||
|
||||
attn_output_a = attention(q_a, k_a, v_a)
|
||||
attn_output_b = attention(q_b, k_b, v_b)
|
||||
attn_output = torch.concat([attn_output_a, attn_output_b], dim=1)
|
||||
|
||||
hidden_states = hidden_states + self.to_out(attn_output) * mod_gate.unsqueeze(1)
|
||||
hidden_states = hidden_states + self.ff(norm_hidden_states) * mod_gate.unsqueeze(1)
|
||||
hidden_states = hidden_states + apply_gate(self.to_out(attn_output), mod_gate, tr_mod_gate, tr_token)
|
||||
hidden_states = hidden_states + apply_gate(self.ff(norm_hidden_states), mod_gate, tr_mod_gate, tr_token)
|
||||
return hidden_states
|
||||
|
||||
|
||||
@@ -555,7 +587,7 @@ class FinalLayer(torch.nn.Module):
|
||||
|
||||
|
||||
class HunyuanVideoDiT(torch.nn.Module):
|
||||
def __init__(self, in_channels=16, hidden_size=3072, text_dim=4096, num_double_blocks=20, num_single_blocks=40):
|
||||
def __init__(self, in_channels=16, hidden_size=3072, text_dim=4096, num_double_blocks=20, num_single_blocks=40, guidance_embed=True):
|
||||
super().__init__()
|
||||
self.img_in = PatchEmbed(in_channels=in_channels, embed_dim=hidden_size)
|
||||
self.txt_in = SingleTokenRefiner(in_channels=text_dim, hidden_size=hidden_size)
|
||||
@@ -565,7 +597,7 @@ class HunyuanVideoDiT(torch.nn.Module):
|
||||
torch.nn.SiLU(),
|
||||
torch.nn.Linear(hidden_size, hidden_size)
|
||||
)
|
||||
self.guidance_in = TimestepEmbeddings(256, hidden_size, computation_device="cpu")
|
||||
self.guidance_in = TimestepEmbeddings(256, hidden_size, computation_device="cpu") if guidance_embed else None
|
||||
self.double_blocks = torch.nn.ModuleList([MMDoubleStreamBlock(hidden_size) for _ in range(num_double_blocks)])
|
||||
self.single_blocks = torch.nn.ModuleList([MMSingleStreamBlock(hidden_size) for _ in range(num_single_blocks)])
|
||||
self.final_layer = FinalLayer(hidden_size)
|
||||
@@ -580,7 +612,7 @@ class HunyuanVideoDiT(torch.nn.Module):
|
||||
def unpatchify(self, x, T, H, W):
|
||||
x = rearrange(x, "B (T H W) (C pT pH pW) -> B C (T pT) (H pH) (W pW)", H=H, W=W, pT=1, pH=2, pW=2)
|
||||
return x
|
||||
|
||||
|
||||
def enable_block_wise_offload(self, warm_device="cuda", cold_device="cpu"):
|
||||
self.warm_device = warm_device
|
||||
self.cold_device = cold_device
|
||||
@@ -610,10 +642,12 @@ class HunyuanVideoDiT(torch.nn.Module):
|
||||
):
|
||||
B, C, T, H, W = x.shape
|
||||
|
||||
vec = self.time_in(t, dtype=torch.float32) + self.vector_in(pooled_prompt_emb) + self.guidance_in(guidance * 1000, dtype=torch.float32)
|
||||
vec = self.time_in(t, dtype=torch.float32) + self.vector_in(pooled_prompt_emb)
|
||||
if self.guidance_in is not None:
|
||||
vec += self.guidance_in(guidance * 1000, dtype=torch.float32)
|
||||
img = self.img_in(x)
|
||||
txt = self.txt_in(prompt_emb, t, text_mask)
|
||||
|
||||
|
||||
for block in tqdm(self.double_blocks, desc="Double stream blocks"):
|
||||
img, txt = block(img, txt, vec, (freqs_cos, freqs_sin))
|
||||
|
||||
@@ -625,7 +659,7 @@ class HunyuanVideoDiT(torch.nn.Module):
|
||||
img = self.final_layer(img, vec)
|
||||
img = self.unpatchify(img, T=T//1, H=H//2, W=W//2)
|
||||
return img
|
||||
|
||||
|
||||
|
||||
def enable_auto_offload(self, dtype=torch.bfloat16, device="cuda"):
|
||||
def cast_to(weight, dtype=None, device=None, copy=False):
|
||||
@@ -681,7 +715,7 @@ class HunyuanVideoDiT(torch.nn.Module):
|
||||
del x_, weight_, bias_
|
||||
torch.cuda.empty_cache()
|
||||
return y_
|
||||
|
||||
|
||||
def block_forward(self, x, **kwargs):
|
||||
# This feature can only reduce 2GB VRAM, so we disable it.
|
||||
y = torch.zeros(x.shape[:-1] + (self.out_features,), dtype=x.dtype, device=x.device)
|
||||
@@ -689,19 +723,19 @@ class HunyuanVideoDiT(torch.nn.Module):
|
||||
for j in range((self.out_features + self.block_size - 1) // self.block_size):
|
||||
y[..., j * self.block_size: (j + 1) * self.block_size] += self.block_forward_(x, i, j, dtype=x.dtype, device=x.device)
|
||||
return y
|
||||
|
||||
|
||||
def forward(self, x, **kwargs):
|
||||
weight, bias = cast_bias_weight(self, x, dtype=self.dtype, device=self.device)
|
||||
return torch.nn.functional.linear(x, weight, bias)
|
||||
|
||||
|
||||
|
||||
class RMSNorm(torch.nn.Module):
|
||||
def __init__(self, module, dtype=torch.bfloat16, device="cuda"):
|
||||
super().__init__()
|
||||
self.module = module
|
||||
self.dtype = dtype
|
||||
self.device = device
|
||||
|
||||
|
||||
def forward(self, hidden_states, **kwargs):
|
||||
input_dtype = hidden_states.dtype
|
||||
variance = hidden_states.to(torch.float32).square().mean(-1, keepdim=True)
|
||||
@@ -711,30 +745,30 @@ class HunyuanVideoDiT(torch.nn.Module):
|
||||
weight = cast_weight(self.module, hidden_states, dtype=torch.bfloat16, device="cuda")
|
||||
hidden_states = hidden_states * weight
|
||||
return hidden_states
|
||||
|
||||
|
||||
class Conv3d(torch.nn.Conv3d):
|
||||
def __init__(self, *args, dtype=torch.bfloat16, device="cuda", **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.dtype = dtype
|
||||
self.device = device
|
||||
|
||||
|
||||
def forward(self, x):
|
||||
weight, bias = cast_bias_weight(self, x, dtype=self.dtype, device=self.device)
|
||||
return torch.nn.functional.conv3d(x, weight, bias, self.stride, self.padding, self.dilation, self.groups)
|
||||
|
||||
|
||||
class LayerNorm(torch.nn.LayerNorm):
|
||||
def __init__(self, *args, dtype=torch.bfloat16, device="cuda", **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.dtype = dtype
|
||||
self.device = device
|
||||
|
||||
|
||||
def forward(self, x):
|
||||
if self.weight is not None and self.bias is not None:
|
||||
weight, bias = cast_bias_weight(self, x, dtype=self.dtype, device=self.device)
|
||||
return torch.nn.functional.layer_norm(x, self.normalized_shape, weight, bias, self.eps)
|
||||
else:
|
||||
return torch.nn.functional.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
|
||||
|
||||
|
||||
def replace_layer(model, dtype=torch.bfloat16, device="cuda"):
|
||||
for name, module in model.named_children():
|
||||
if isinstance(module, torch.nn.Linear):
|
||||
@@ -777,12 +811,12 @@ class HunyuanVideoDiT(torch.nn.Module):
|
||||
return HunyuanVideoDiTStateDictConverter()
|
||||
|
||||
|
||||
|
||||
class HunyuanVideoDiTStateDictConverter:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def from_civitai(self, state_dict):
|
||||
origin_hash_key = hash_state_dict_keys(state_dict, with_shape=True)
|
||||
if "module" in state_dict:
|
||||
state_dict = state_dict["module"]
|
||||
direct_dict = {
|
||||
@@ -882,4 +916,5 @@ class HunyuanVideoDiTStateDictConverter:
|
||||
state_dict_[name_] = param
|
||||
else:
|
||||
pass
|
||||
|
||||
return state_dict_
|
||||
|
||||
@@ -1,24 +1,18 @@
|
||||
from transformers import LlamaModel, LlamaConfig, DynamicCache
|
||||
from transformers import LlamaModel, LlamaConfig, DynamicCache, LlavaForConditionalGeneration
|
||||
from copy import deepcopy
|
||||
import torch
|
||||
|
||||
|
||||
class HunyuanVideoLLMEncoder(LlamaModel):
|
||||
|
||||
def __init__(self, config: LlamaConfig):
|
||||
super().__init__(config)
|
||||
self.auto_offload = False
|
||||
|
||||
|
||||
def enable_auto_offload(self, **kwargs):
|
||||
self.auto_offload = True
|
||||
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids,
|
||||
attention_mask,
|
||||
hidden_state_skip_layer=2
|
||||
):
|
||||
def forward(self, input_ids, attention_mask, hidden_state_skip_layer=2):
|
||||
embed_tokens = deepcopy(self.embed_tokens).to(input_ids.device) if self.auto_offload else self.embed_tokens
|
||||
inputs_embeds = embed_tokens(input_ids)
|
||||
|
||||
@@ -53,3 +47,22 @@ class HunyuanVideoLLMEncoder(LlamaModel):
|
||||
break
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class HunyuanVideoMLLMEncoder(LlavaForConditionalGeneration):
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
self.auto_offload = False
|
||||
|
||||
def enable_auto_offload(self, **kwargs):
|
||||
self.auto_offload = True
|
||||
|
||||
# TODO: implement the low VRAM inference for MLLM.
|
||||
def forward(self, input_ids, pixel_values, attention_mask, hidden_state_skip_layer=2):
|
||||
outputs = super().forward(input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
output_hidden_states=True,
|
||||
pixel_values=pixel_values)
|
||||
hidden_state = outputs.hidden_states[-(hidden_state_skip_layer + 1)]
|
||||
return hidden_state
|
||||
|
||||
@@ -195,70 +195,73 @@ class FluxLoRAFromCivitai(LoRAFromCivitai):
|
||||
"txt.mod": "txt_mod",
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
class GeneralLoRAFromPeft:
|
||||
def __init__(self):
|
||||
self.supported_model_classes = [SDUNet, SDXLUNet, SD3DiT, HunyuanDiT, FluxDiT, CogDiT, WanModel]
|
||||
|
||||
|
||||
def fetch_device_dtype_from_state_dict(self, state_dict):
|
||||
device, torch_dtype = None, None
|
||||
for name, param in state_dict.items():
|
||||
device, torch_dtype = param.device, param.dtype
|
||||
break
|
||||
return device, torch_dtype
|
||||
|
||||
|
||||
def convert_state_dict(self, state_dict, alpha=1.0, target_state_dict={}):
|
||||
device, torch_dtype = self.fetch_device_dtype_from_state_dict(target_state_dict)
|
||||
state_dict_ = {}
|
||||
for key in state_dict:
|
||||
|
||||
|
||||
def get_name_dict(self, lora_state_dict):
|
||||
lora_name_dict = {}
|
||||
for key in lora_state_dict:
|
||||
if ".lora_B." not in key:
|
||||
continue
|
||||
weight_up = state_dict[key].to(device=device, dtype=torch_dtype)
|
||||
weight_down = state_dict[key.replace(".lora_B.", ".lora_A.")].to(device=device, dtype=torch_dtype)
|
||||
if len(weight_up.shape) == 4:
|
||||
weight_up = weight_up.squeeze(3).squeeze(2)
|
||||
weight_down = weight_down.squeeze(3).squeeze(2)
|
||||
lora_weight = alpha * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3)
|
||||
else:
|
||||
lora_weight = alpha * torch.mm(weight_up, weight_down)
|
||||
keys = key.split(".")
|
||||
if len(keys) > keys.index("lora_B") + 2:
|
||||
keys.pop(keys.index("lora_B") + 1)
|
||||
keys.pop(keys.index("lora_B"))
|
||||
if keys[0] == "diffusion_model":
|
||||
keys.pop(0)
|
||||
target_name = ".".join(keys)
|
||||
if target_name not in target_state_dict:
|
||||
return {}
|
||||
state_dict_[target_name] = lora_weight.cpu()
|
||||
return state_dict_
|
||||
lora_name_dict[target_name] = (key, key.replace(".lora_B.", ".lora_A."))
|
||||
return lora_name_dict
|
||||
|
||||
|
||||
def match(self, model: torch.nn.Module, state_dict_lora):
|
||||
lora_name_dict = self.get_name_dict(state_dict_lora)
|
||||
model_name_dict = {name: None for name, _ in model.named_parameters()}
|
||||
matched_num = sum([i in model_name_dict for i in lora_name_dict])
|
||||
if matched_num == len(lora_name_dict):
|
||||
return "", ""
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
def fetch_device_and_dtype(self, state_dict):
|
||||
device, dtype = None, None
|
||||
for name, param in state_dict.items():
|
||||
device, dtype = param.device, param.dtype
|
||||
break
|
||||
computation_device = device
|
||||
computation_dtype = dtype
|
||||
if computation_device == torch.device("cpu"):
|
||||
if torch.cuda.is_available():
|
||||
computation_device = torch.device("cuda")
|
||||
if computation_dtype == torch.float8_e4m3fn:
|
||||
computation_dtype = torch.float32
|
||||
return device, dtype, computation_device, computation_dtype
|
||||
|
||||
|
||||
def load(self, model, state_dict_lora, lora_prefix="", alpha=1.0, model_resource=""):
|
||||
state_dict_model = model.state_dict()
|
||||
state_dict_lora = self.convert_state_dict(state_dict_lora, alpha=alpha, target_state_dict=state_dict_model)
|
||||
if len(state_dict_lora) > 0:
|
||||
print(f" {len(state_dict_lora)} tensors are updated.")
|
||||
for name in state_dict_lora:
|
||||
state_dict_model[name] += state_dict_lora[name].to(
|
||||
dtype=state_dict_model[name].dtype,
|
||||
device=state_dict_model[name].device
|
||||
)
|
||||
model.load_state_dict(state_dict_model)
|
||||
device, dtype, computation_device, computation_dtype = self.fetch_device_and_dtype(state_dict_model)
|
||||
lora_name_dict = self.get_name_dict(state_dict_lora)
|
||||
for name in lora_name_dict:
|
||||
weight_up = state_dict_lora[lora_name_dict[name][0]].to(device=computation_device, dtype=computation_dtype)
|
||||
weight_down = state_dict_lora[lora_name_dict[name][1]].to(device=computation_device, dtype=computation_dtype)
|
||||
if len(weight_up.shape) == 4:
|
||||
weight_up = weight_up.squeeze(3).squeeze(2)
|
||||
weight_down = weight_down.squeeze(3).squeeze(2)
|
||||
weight_lora = alpha * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3)
|
||||
else:
|
||||
weight_lora = alpha * torch.mm(weight_up, weight_down)
|
||||
weight_model = state_dict_model[name].to(device=computation_device, dtype=computation_dtype)
|
||||
weight_patched = weight_model + weight_lora
|
||||
state_dict_model[name] = weight_patched.to(device=device, dtype=dtype)
|
||||
print(f" {len(lora_name_dict)} tensors are updated.")
|
||||
model.load_state_dict(state_dict_model)
|
||||
|
||||
|
||||
def match(self, model, state_dict_lora):
|
||||
for model_class in self.supported_model_classes:
|
||||
if not isinstance(model, model_class):
|
||||
continue
|
||||
state_dict_model = model.state_dict()
|
||||
try:
|
||||
state_dict_lora_ = self.convert_state_dict(state_dict_lora, alpha=1.0, target_state_dict=state_dict_model)
|
||||
if len(state_dict_lora_) > 0:
|
||||
return "", ""
|
||||
except:
|
||||
pass
|
||||
return None
|
||||
|
||||
|
||||
class HunyuanVideoLoRAFromCivitai(LoRAFromCivitai):
|
||||
@@ -362,7 +365,22 @@ class FluxLoRAConverter:
|
||||
else:
|
||||
state_dict_[name] = param
|
||||
return state_dict_
|
||||
|
||||
|
||||
class WanLoRAConverter:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def align_to_opensource_format(state_dict, **kwargs):
|
||||
state_dict = {"diffusion_model." + name.replace(".default.", "."): param for name, param in state_dict.items()}
|
||||
return state_dict
|
||||
|
||||
@staticmethod
|
||||
def align_to_diffsynth_format(state_dict, **kwargs):
|
||||
state_dict = {name.replace("diffusion_model.", "").replace(".lora_A.weight", ".lora_A.default.weight").replace(".lora_B.weight", ".lora_B.default.weight"): param for name, param in state_dict.items()}
|
||||
return state_dict
|
||||
|
||||
|
||||
def get_lora_loaders():
|
||||
return [SDLoRAFromCivitai(), SDXLLoRAFromCivitai(), FluxLoRAFromCivitai(), HunyuanVideoLoRAFromCivitai(), GeneralLoRAFromPeft()]
|
||||
|
||||
@@ -376,6 +376,7 @@ class ModelManager:
|
||||
self.load_lora(file_path_, state_dict=state_dict, lora_alpha=lora_alpha)
|
||||
else:
|
||||
print(f"Loading LoRA models from file: {file_path}")
|
||||
is_loaded = False
|
||||
if len(state_dict) == 0:
|
||||
state_dict = load_state_dict(file_path)
|
||||
for model_name, model, model_path in zip(self.model_name, self.model, self.model_path):
|
||||
@@ -385,7 +386,10 @@ class ModelManager:
|
||||
print(f" Adding LoRA to {model_name} ({model_path}).")
|
||||
lora_prefix, model_resource = match_results
|
||||
lora.load(model, state_dict, lora_prefix, alpha=lora_alpha, model_resource=model_resource)
|
||||
is_loaded = True
|
||||
break
|
||||
if not is_loaded:
|
||||
print(f" Cannot load LoRA: {file_path}")
|
||||
|
||||
|
||||
def load_model(self, file_path, model_names=None, device=None, torch_dtype=None):
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -228,7 +228,7 @@ class QuickGELU(nn.Module):
|
||||
class LayerNorm(nn.LayerNorm):
|
||||
|
||||
def forward(self, x):
|
||||
return super().forward(x.float()).type_as(x)
|
||||
return super().forward(x).type_as(x)
|
||||
|
||||
|
||||
class SelfAttention(nn.Module):
|
||||
@@ -256,15 +256,11 @@ class SelfAttention(nn.Module):
|
||||
"""
|
||||
x: [B, L, C].
|
||||
"""
|
||||
b, s, c, n, d = *x.size(), self.num_heads, self.head_dim
|
||||
|
||||
# compute query, key, value
|
||||
q, k, v = self.to_qkv(x).view(b, s, 3, n, d).unbind(2)
|
||||
q, k, v = self.to_qkv(x).chunk(3, dim=-1)
|
||||
|
||||
# compute attention
|
||||
p = self.attn_dropout if self.training else 0.0
|
||||
x = flash_attention(q, k, v, dropout_p=p, causal=self.causal, version=2)
|
||||
x = x.reshape(b, s, c)
|
||||
x = flash_attention(q, k, v, num_heads=self.num_heads, compatibility_mode=True)
|
||||
|
||||
# output
|
||||
x = self.proj(x)
|
||||
@@ -371,11 +367,11 @@ class AttentionPool(nn.Module):
|
||||
b, s, c, n, d = *x.size(), self.num_heads, self.head_dim
|
||||
|
||||
# compute query, key, value
|
||||
q = self.to_q(self.cls_embedding).view(1, 1, n, d).expand(b, -1, -1, -1)
|
||||
k, v = self.to_kv(x).view(b, s, 2, n, d).unbind(2)
|
||||
q = self.to_q(self.cls_embedding).view(1, 1, n*d).expand(b, -1, -1)
|
||||
k, v = self.to_kv(x).chunk(2, dim=-1)
|
||||
|
||||
# compute attention
|
||||
x = flash_attention(q, k, v, version=2)
|
||||
x = flash_attention(q, k, v, num_heads=self.num_heads, compatibility_mode=True)
|
||||
x = x.reshape(b, 1, c)
|
||||
|
||||
# output
|
||||
@@ -878,6 +874,8 @@ class WanImageEncoder(torch.nn.Module):
|
||||
videos = self.transforms.transforms[-1](videos.mul_(0.5).add_(0.5))
|
||||
|
||||
# forward
|
||||
dtype = next(iter(self.model.visual.parameters())).dtype
|
||||
videos = videos.to(dtype)
|
||||
out = self.model.visual(videos, use_31_block=True)
|
||||
return out
|
||||
|
||||
|
||||
@@ -688,7 +688,7 @@ class WanVideoVAE(nn.Module):
|
||||
target_w: target_w + hidden_states_batch.shape[4],
|
||||
] += mask
|
||||
values = values / weight
|
||||
values = values.float().clamp_(-1, 1)
|
||||
values = values.clamp_(-1, 1)
|
||||
return values
|
||||
|
||||
|
||||
@@ -740,20 +740,19 @@ class WanVideoVAE(nn.Module):
|
||||
target_w: target_w + hidden_states_batch.shape[4],
|
||||
] += mask
|
||||
values = values / weight
|
||||
values = values.float()
|
||||
return values
|
||||
|
||||
|
||||
def single_encode(self, video, device):
|
||||
video = video.to(device)
|
||||
x = self.model.encode(video, self.scale)
|
||||
return x.float()
|
||||
return x
|
||||
|
||||
|
||||
def single_decode(self, hidden_state, device):
|
||||
hidden_state = hidden_state.to(device)
|
||||
video = self.model.decode(hidden_state, self.scale)
|
||||
return video.float().clamp_(-1, 1)
|
||||
return video.clamp_(-1, 1)
|
||||
|
||||
|
||||
def encode(self, videos, device, tiled=False, tile_size=(34, 34), tile_stride=(18, 16)):
|
||||
|
||||
@@ -31,6 +31,7 @@ class FluxImagePipeline(BasePipeline):
|
||||
self.controlnet: FluxMultiControlNetManager = None
|
||||
self.ipadapter: FluxIpAdapter = None
|
||||
self.ipadapter_image_encoder: SiglipVisionModel = None
|
||||
self.infinityou_processor: InfinitYou = None
|
||||
self.model_names = ['text_encoder_1', 'text_encoder_2', 'dit', 'vae_decoder', 'vae_encoder', 'controlnet', 'ipadapter', 'ipadapter_image_encoder']
|
||||
|
||||
|
||||
@@ -162,6 +163,11 @@ class FluxImagePipeline(BasePipeline):
|
||||
self.ipadapter = model_manager.fetch_model("flux_ipadapter")
|
||||
self.ipadapter_image_encoder = model_manager.fetch_model("siglip_vision_model")
|
||||
|
||||
# InfiniteYou
|
||||
self.image_proj_model = model_manager.fetch_model("infiniteyou_image_projector")
|
||||
if self.image_proj_model is not None:
|
||||
self.infinityou_processor = InfinitYou(device=self.device)
|
||||
|
||||
|
||||
@staticmethod
|
||||
def from_model_manager(model_manager: ModelManager, controlnet_config_units: List[ControlNetConfigUnit]=[], prompt_refiner_classes=[], prompt_extender_classes=[], device=None, torch_dtype=None):
|
||||
@@ -347,6 +353,13 @@ class FluxImagePipeline(BasePipeline):
|
||||
prompt_emb_nega = self.encode_prompt(negative_prompt, positive=False, t5_sequence_length=t5_sequence_length) if cfg_scale != 1.0 else None
|
||||
prompt_emb_locals = [self.encode_prompt(prompt_local, t5_sequence_length=t5_sequence_length) for prompt_local in local_prompts]
|
||||
return prompt_emb_posi, prompt_emb_nega, prompt_emb_locals
|
||||
|
||||
|
||||
def prepare_infinite_you(self, id_image, controlnet_image, infinityou_guidance, height, width):
|
||||
if self.infinityou_processor is not None and id_image is not None:
|
||||
return self.infinityou_processor.prepare_infinite_you(self.image_proj_model, id_image, controlnet_image, infinityou_guidance, height, width)
|
||||
else:
|
||||
return {}, controlnet_image
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
@@ -382,6 +395,9 @@ class FluxImagePipeline(BasePipeline):
|
||||
eligen_entity_masks=None,
|
||||
enable_eligen_on_negative=False,
|
||||
enable_eligen_inpaint=False,
|
||||
# InfiniteYou
|
||||
infinityou_id_image=None,
|
||||
infinityou_guidance=1.0,
|
||||
# TeaCache
|
||||
tea_cache_l1_thresh=None,
|
||||
# Tile
|
||||
@@ -409,6 +425,9 @@ class FluxImagePipeline(BasePipeline):
|
||||
# Extra input
|
||||
extra_input = self.prepare_extra_input(latents, guidance=embedded_guidance)
|
||||
|
||||
# InfiniteYou
|
||||
infiniteyou_kwargs, controlnet_image = self.prepare_infinite_you(infinityou_id_image, controlnet_image, infinityou_guidance, height, width)
|
||||
|
||||
# Entity control
|
||||
eligen_kwargs_posi, eligen_kwargs_nega, fg_mask, bg_mask = self.prepare_eligen(prompt_emb_nega, eligen_entity_prompts, eligen_entity_masks, width, height, t5_sequence_length, enable_eligen_inpaint, enable_eligen_on_negative, cfg_scale)
|
||||
|
||||
@@ -430,7 +449,7 @@ class FluxImagePipeline(BasePipeline):
|
||||
inference_callback = lambda prompt_emb_posi, controlnet_kwargs: lets_dance_flux(
|
||||
dit=self.dit, controlnet=self.controlnet,
|
||||
hidden_states=latents, timestep=timestep,
|
||||
**prompt_emb_posi, **tiler_kwargs, **extra_input, **controlnet_kwargs, **ipadapter_kwargs_list_posi, **eligen_kwargs_posi, **tea_cache_kwargs,
|
||||
**prompt_emb_posi, **tiler_kwargs, **extra_input, **controlnet_kwargs, **ipadapter_kwargs_list_posi, **eligen_kwargs_posi, **tea_cache_kwargs, **infiniteyou_kwargs
|
||||
)
|
||||
noise_pred_posi = self.control_noise_via_local_prompts(
|
||||
prompt_emb_posi, prompt_emb_locals, masks, mask_scales, inference_callback,
|
||||
@@ -447,7 +466,7 @@ class FluxImagePipeline(BasePipeline):
|
||||
noise_pred_nega = lets_dance_flux(
|
||||
dit=self.dit, controlnet=self.controlnet,
|
||||
hidden_states=latents, timestep=timestep,
|
||||
**prompt_emb_nega, **tiler_kwargs, **extra_input, **controlnet_kwargs_nega, **ipadapter_kwargs_list_nega, **eligen_kwargs_nega,
|
||||
**prompt_emb_nega, **tiler_kwargs, **extra_input, **controlnet_kwargs_nega, **ipadapter_kwargs_list_nega, **eligen_kwargs_nega, **infiniteyou_kwargs,
|
||||
)
|
||||
noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
|
||||
else:
|
||||
@@ -467,6 +486,58 @@ class FluxImagePipeline(BasePipeline):
|
||||
# Offload all models
|
||||
self.load_models_to_device([])
|
||||
return image
|
||||
|
||||
|
||||
|
||||
class InfinitYou:
|
||||
def __init__(self, device="cuda", torch_dtype=torch.bfloat16):
|
||||
from facexlib.recognition import init_recognition_model
|
||||
from insightface.app import FaceAnalysis
|
||||
self.device = device
|
||||
self.torch_dtype = torch_dtype
|
||||
insightface_root_path = 'models/InfiniteYou/insightface'
|
||||
self.app_640 = FaceAnalysis(name='antelopev2', root=insightface_root_path, providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
|
||||
self.app_640.prepare(ctx_id=0, det_size=(640, 640))
|
||||
self.app_320 = FaceAnalysis(name='antelopev2', root=insightface_root_path, providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
|
||||
self.app_320.prepare(ctx_id=0, det_size=(320, 320))
|
||||
self.app_160 = FaceAnalysis(name='antelopev2', root=insightface_root_path, providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
|
||||
self.app_160.prepare(ctx_id=0, det_size=(160, 160))
|
||||
self.arcface_model = init_recognition_model('arcface', device=self.device)
|
||||
|
||||
def _detect_face(self, id_image_cv2):
|
||||
face_info = self.app_640.get(id_image_cv2)
|
||||
if len(face_info) > 0:
|
||||
return face_info
|
||||
face_info = self.app_320.get(id_image_cv2)
|
||||
if len(face_info) > 0:
|
||||
return face_info
|
||||
face_info = self.app_160.get(id_image_cv2)
|
||||
return face_info
|
||||
|
||||
def extract_arcface_bgr_embedding(self, in_image, landmark):
|
||||
from insightface.utils import face_align
|
||||
arc_face_image = face_align.norm_crop(in_image, landmark=np.array(landmark), image_size=112)
|
||||
arc_face_image = torch.from_numpy(arc_face_image).unsqueeze(0).permute(0, 3, 1, 2) / 255.
|
||||
arc_face_image = 2 * arc_face_image - 1
|
||||
arc_face_image = arc_face_image.contiguous().to(self.device)
|
||||
face_emb = self.arcface_model(arc_face_image)[0] # [512], normalized
|
||||
return face_emb
|
||||
|
||||
def prepare_infinite_you(self, model, id_image, controlnet_image, infinityou_guidance, height, width):
|
||||
import cv2
|
||||
if id_image is None:
|
||||
return {'id_emb': None}, controlnet_image
|
||||
id_image_cv2 = cv2.cvtColor(np.array(id_image), cv2.COLOR_RGB2BGR)
|
||||
face_info = self._detect_face(id_image_cv2)
|
||||
if len(face_info) == 0:
|
||||
raise ValueError('No face detected in the input ID image')
|
||||
landmark = sorted(face_info, key=lambda x:(x['bbox'][2]-x['bbox'][0])*(x['bbox'][3]-x['bbox'][1]))[-1]['kps'] # only use the maximum face
|
||||
id_emb = self.extract_arcface_bgr_embedding(id_image_cv2, landmark)
|
||||
id_emb = model(id_emb.unsqueeze(0).reshape([1, -1, 512]).to(dtype=self.torch_dtype))
|
||||
if controlnet_image is None:
|
||||
controlnet_image = Image.fromarray(np.zeros([height, width, 3]).astype(np.uint8))
|
||||
infinityou_guidance = torch.Tensor([infinityou_guidance]).to(device=self.device, dtype=self.torch_dtype)
|
||||
return {'id_emb': id_emb, 'infinityou_guidance': infinityou_guidance}, controlnet_image
|
||||
|
||||
|
||||
class TeaCache:
|
||||
@@ -529,6 +600,8 @@ def lets_dance_flux(
|
||||
entity_prompt_emb=None,
|
||||
entity_masks=None,
|
||||
ipadapter_kwargs_list={},
|
||||
id_emb=None,
|
||||
infinityou_guidance=None,
|
||||
tea_cache: TeaCache = None,
|
||||
**kwargs
|
||||
):
|
||||
@@ -573,6 +646,9 @@ def lets_dance_flux(
|
||||
"tile_size": tile_size,
|
||||
"tile_stride": tile_stride,
|
||||
}
|
||||
if id_emb is not None:
|
||||
controlnet_text_ids = torch.zeros(id_emb.shape[0], id_emb.shape[1], 3).to(device=hidden_states.device, dtype=hidden_states.dtype)
|
||||
controlnet_extra_kwargs.update({"prompt_emb": id_emb, 'text_ids': controlnet_text_ids, 'guidance': infinityou_guidance})
|
||||
controlnet_res_stack, controlnet_single_res_stack = controlnet(
|
||||
controlnet_frames, **controlnet_extra_kwargs
|
||||
)
|
||||
|
||||
@@ -5,13 +5,13 @@ from ..schedulers.flow_match import FlowMatchScheduler
|
||||
from .base import BasePipeline
|
||||
from ..prompters import HunyuanVideoPrompter
|
||||
import torch
|
||||
import torchvision.transforms as transforms
|
||||
from einops import rearrange
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
|
||||
class HunyuanVideoPipeline(BasePipeline):
|
||||
|
||||
def __init__(self, device="cuda", torch_dtype=torch.float16):
|
||||
@@ -53,10 +53,58 @@ class HunyuanVideoPipeline(BasePipeline):
|
||||
pipe.enable_vram_management()
|
||||
return pipe
|
||||
|
||||
def generate_crop_size_list(self, base_size=256, patch_size=32, max_ratio=4.0):
|
||||
num_patches = round((base_size / patch_size)**2)
|
||||
assert max_ratio >= 1.0
|
||||
crop_size_list = []
|
||||
wp, hp = num_patches, 1
|
||||
while wp > 0:
|
||||
if max(wp, hp) / min(wp, hp) <= max_ratio:
|
||||
crop_size_list.append((wp * patch_size, hp * patch_size))
|
||||
if (hp + 1) * wp <= num_patches:
|
||||
hp += 1
|
||||
else:
|
||||
wp -= 1
|
||||
return crop_size_list
|
||||
|
||||
def encode_prompt(self, prompt, positive=True, clip_sequence_length=77, llm_sequence_length=256):
|
||||
|
||||
def get_closest_ratio(self, height: float, width: float, ratios: list, buckets: list):
|
||||
aspect_ratio = float(height) / float(width)
|
||||
closest_ratio_id = np.abs(ratios - aspect_ratio).argmin()
|
||||
closest_ratio = min(ratios, key=lambda ratio: abs(float(ratio) - aspect_ratio))
|
||||
return buckets[closest_ratio_id], float(closest_ratio)
|
||||
|
||||
|
||||
def prepare_vae_images_inputs(self, semantic_images, i2v_resolution="720p"):
|
||||
if i2v_resolution == "720p":
|
||||
bucket_hw_base_size = 960
|
||||
elif i2v_resolution == "540p":
|
||||
bucket_hw_base_size = 720
|
||||
elif i2v_resolution == "360p":
|
||||
bucket_hw_base_size = 480
|
||||
else:
|
||||
raise ValueError(f"i2v_resolution: {i2v_resolution} must be in [360p, 540p, 720p]")
|
||||
origin_size = semantic_images[0].size
|
||||
|
||||
crop_size_list = self.generate_crop_size_list(bucket_hw_base_size, 32)
|
||||
aspect_ratios = np.array([round(float(h) / float(w), 5) for h, w in crop_size_list])
|
||||
closest_size, closest_ratio = self.get_closest_ratio(origin_size[1], origin_size[0], aspect_ratios, crop_size_list)
|
||||
ref_image_transform = transforms.Compose([
|
||||
transforms.Resize(closest_size),
|
||||
transforms.CenterCrop(closest_size),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize([0.5], [0.5])
|
||||
])
|
||||
|
||||
semantic_image_pixel_values = [ref_image_transform(semantic_image) for semantic_image in semantic_images]
|
||||
semantic_image_pixel_values = torch.cat(semantic_image_pixel_values).unsqueeze(0).unsqueeze(2).to(self.device)
|
||||
target_height, target_width = closest_size
|
||||
return semantic_image_pixel_values, target_height, target_width
|
||||
|
||||
|
||||
def encode_prompt(self, prompt, positive=True, clip_sequence_length=77, llm_sequence_length=256, input_images=None):
|
||||
prompt_emb, pooled_prompt_emb, text_mask = self.prompter.encode_prompt(
|
||||
prompt, device=self.device, positive=positive, clip_sequence_length=clip_sequence_length, llm_sequence_length=llm_sequence_length
|
||||
prompt, device=self.device, positive=positive, clip_sequence_length=clip_sequence_length, llm_sequence_length=llm_sequence_length, images=input_images
|
||||
)
|
||||
return {"prompt_emb": prompt_emb, "pooled_prompt_emb": pooled_prompt_emb, "text_mask": text_mask}
|
||||
|
||||
@@ -87,6 +135,9 @@ class HunyuanVideoPipeline(BasePipeline):
|
||||
prompt,
|
||||
negative_prompt="",
|
||||
input_video=None,
|
||||
input_images=None,
|
||||
i2v_resolution="720p",
|
||||
i2v_stability=True,
|
||||
denoising_strength=1.0,
|
||||
seed=None,
|
||||
rand_device=None,
|
||||
@@ -105,10 +156,17 @@ class HunyuanVideoPipeline(BasePipeline):
|
||||
):
|
||||
# Tiler parameters
|
||||
tiler_kwargs = {"tile_size": tile_size, "tile_stride": tile_stride}
|
||||
|
||||
|
||||
# Scheduler
|
||||
self.scheduler.set_timesteps(num_inference_steps, denoising_strength)
|
||||
|
||||
# encoder input images
|
||||
if input_images is not None:
|
||||
self.load_models_to_device(['vae_encoder'])
|
||||
image_pixel_values, height, width = self.prepare_vae_images_inputs(input_images, i2v_resolution=i2v_resolution)
|
||||
with torch.autocast(device_type=self.device, dtype=torch.float16, enabled=True):
|
||||
image_latents = self.vae_encoder(image_pixel_values)
|
||||
|
||||
# Initialize noise
|
||||
rand_device = self.device if rand_device is None else rand_device
|
||||
noise = self.generate_noise((1, 16, (num_frames - 1) // 4 + 1, height//8, width//8), seed=seed, device=rand_device, dtype=self.torch_dtype).to(self.device)
|
||||
@@ -118,12 +176,18 @@ class HunyuanVideoPipeline(BasePipeline):
|
||||
input_video = torch.stack(input_video, dim=2)
|
||||
latents = self.encode_video(input_video, **tiler_kwargs).to(dtype=self.torch_dtype, device=self.device)
|
||||
latents = self.scheduler.add_noise(latents, noise, timestep=self.scheduler.timesteps[0])
|
||||
elif input_images is not None and i2v_stability:
|
||||
noise = self.generate_noise((1, 16, (num_frames - 1) // 4 + 1, height//8, width//8), seed=seed, device=rand_device, dtype=image_latents.dtype).to(self.device)
|
||||
t = torch.tensor([0.999]).to(device=self.device)
|
||||
latents = noise * t + image_latents.repeat(1, 1, (num_frames - 1) // 4 + 1, 1, 1) * (1 - t)
|
||||
latents = latents.to(dtype=image_latents.dtype)
|
||||
else:
|
||||
latents = noise
|
||||
|
||||
|
||||
# Encode prompts
|
||||
self.load_models_to_device(["text_encoder_1"] if self.vram_management else ["text_encoder_1", "text_encoder_2"])
|
||||
prompt_emb_posi = self.encode_prompt(prompt, positive=True)
|
||||
# current mllm does not support vram_management
|
||||
self.load_models_to_device(["text_encoder_1"] if self.vram_management and input_images is None else ["text_encoder_1", "text_encoder_2"])
|
||||
prompt_emb_posi = self.encode_prompt(prompt, positive=True, input_images=input_images)
|
||||
if cfg_scale != 1.0:
|
||||
prompt_emb_nega = self.encode_prompt(negative_prompt, positive=False)
|
||||
|
||||
@@ -139,11 +203,16 @@ class HunyuanVideoPipeline(BasePipeline):
|
||||
timestep = timestep.unsqueeze(0).to(self.device)
|
||||
print(f"Step {progress_id + 1} / {len(self.scheduler.timesteps)}")
|
||||
|
||||
forward_func = lets_dance_hunyuan_video
|
||||
if input_images is not None:
|
||||
latents = torch.concat([image_latents, latents[:, :, 1:, :, :]], dim=2)
|
||||
forward_func = lets_dance_hunyuan_video_i2v
|
||||
|
||||
# Inference
|
||||
with torch.autocast(device_type=self.device, dtype=self.torch_dtype):
|
||||
noise_pred_posi = lets_dance_hunyuan_video(self.dit, latents, timestep, **prompt_emb_posi, **extra_input, **tea_cache_kwargs)
|
||||
noise_pred_posi = forward_func(self.dit, latents, timestep, **prompt_emb_posi, **extra_input, **tea_cache_kwargs)
|
||||
if cfg_scale != 1.0:
|
||||
noise_pred_nega = lets_dance_hunyuan_video(self.dit, latents, timestep, **prompt_emb_nega, **extra_input)
|
||||
noise_pred_nega = forward_func(self.dit, latents, timestep, **prompt_emb_nega, **extra_input)
|
||||
noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
|
||||
else:
|
||||
noise_pred = noise_pred_posi
|
||||
@@ -163,7 +232,11 @@ class HunyuanVideoPipeline(BasePipeline):
|
||||
self.load_models_to_device([] if self.vram_management else ["dit"])
|
||||
|
||||
# Scheduler
|
||||
latents = self.scheduler.step(noise_pred, self.scheduler.timesteps[progress_id], latents)
|
||||
if input_images is not None:
|
||||
latents = self.scheduler.step(noise_pred[:, :, 1:, :, :], self.scheduler.timesteps[progress_id], latents[:, :, 1:, :, :])
|
||||
latents = torch.concat([image_latents, latents], dim=2)
|
||||
else:
|
||||
latents = self.scheduler.step(noise_pred, self.scheduler.timesteps[progress_id], latents)
|
||||
|
||||
# Decode
|
||||
self.load_models_to_device(['vae_decoder'])
|
||||
@@ -194,7 +267,7 @@ class TeaCache:
|
||||
if self.step == 0 or self.step == self.num_inference_steps - 1:
|
||||
should_calc = True
|
||||
self.accumulated_rel_l1_distance = 0
|
||||
else:
|
||||
else:
|
||||
coefficients = [7.33226126e+02, -4.01131952e+02, 6.75869174e+01, -3.14987800e+00, 9.61237896e-02]
|
||||
rescale_func = np.poly1d(coefficients)
|
||||
self.accumulated_rel_l1_distance += rescale_func(((modulated_inp-self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean()).cpu().item())
|
||||
@@ -203,14 +276,14 @@ class TeaCache:
|
||||
else:
|
||||
should_calc = True
|
||||
self.accumulated_rel_l1_distance = 0
|
||||
self.previous_modulated_input = modulated_inp
|
||||
self.previous_modulated_input = modulated_inp
|
||||
self.step += 1
|
||||
if self.step == self.num_inference_steps:
|
||||
self.step = 0
|
||||
if should_calc:
|
||||
self.previous_hidden_states = img.clone()
|
||||
return not should_calc
|
||||
|
||||
|
||||
def store(self, hidden_states):
|
||||
self.previous_residual = hidden_states - self.previous_hidden_states
|
||||
self.previous_hidden_states = None
|
||||
@@ -250,13 +323,70 @@ def lets_dance_hunyuan_video(
|
||||
print("TeaCache skip forward.")
|
||||
img = tea_cache.update(img)
|
||||
else:
|
||||
split_token = int(text_mask.sum(dim=1))
|
||||
txt_len = int(txt.shape[1])
|
||||
for block in tqdm(dit.double_blocks, desc="Double stream blocks"):
|
||||
img, txt = block(img, txt, vec, (freqs_cos, freqs_sin))
|
||||
|
||||
img, txt = block(img, txt, vec, (freqs_cos, freqs_sin), split_token=split_token)
|
||||
|
||||
x = torch.concat([img, txt], dim=1)
|
||||
for block in tqdm(dit.single_blocks, desc="Single stream blocks"):
|
||||
x = block(x, vec, (freqs_cos, freqs_sin))
|
||||
img = x[:, :-256]
|
||||
x = block(x, vec, (freqs_cos, freqs_sin), txt_len=txt_len, split_token=split_token)
|
||||
img = x[:, :-txt_len]
|
||||
|
||||
if tea_cache is not None:
|
||||
tea_cache.store(img)
|
||||
img = dit.final_layer(img, vec)
|
||||
img = dit.unpatchify(img, T=T//1, H=H//2, W=W//2)
|
||||
return img
|
||||
|
||||
|
||||
def lets_dance_hunyuan_video_i2v(
|
||||
dit: HunyuanVideoDiT,
|
||||
x: torch.Tensor,
|
||||
t: torch.Tensor,
|
||||
prompt_emb: torch.Tensor = None,
|
||||
text_mask: torch.Tensor = None,
|
||||
pooled_prompt_emb: torch.Tensor = None,
|
||||
freqs_cos: torch.Tensor = None,
|
||||
freqs_sin: torch.Tensor = None,
|
||||
guidance: torch.Tensor = None,
|
||||
tea_cache: TeaCache = None,
|
||||
**kwargs
|
||||
):
|
||||
B, C, T, H, W = x.shape
|
||||
# Uncomment below to keep same as official implementation
|
||||
# guidance = guidance.to(dtype=torch.float32).to(torch.bfloat16)
|
||||
vec = dit.time_in(t, dtype=torch.bfloat16)
|
||||
vec_2 = dit.vector_in(pooled_prompt_emb)
|
||||
vec = vec + vec_2
|
||||
vec = vec + dit.guidance_in(guidance * 1000., dtype=torch.bfloat16)
|
||||
|
||||
token_replace_vec = dit.time_in(torch.zeros_like(t), dtype=torch.bfloat16)
|
||||
tr_token = (H // 2) * (W // 2)
|
||||
token_replace_vec = token_replace_vec + vec_2
|
||||
|
||||
img = dit.img_in(x)
|
||||
txt = dit.txt_in(prompt_emb, t, text_mask)
|
||||
|
||||
# TeaCache
|
||||
if tea_cache is not None:
|
||||
tea_cache_update = tea_cache.check(dit, img, vec)
|
||||
else:
|
||||
tea_cache_update = False
|
||||
|
||||
if tea_cache_update:
|
||||
print("TeaCache skip forward.")
|
||||
img = tea_cache.update(img)
|
||||
else:
|
||||
split_token = int(text_mask.sum(dim=1))
|
||||
txt_len = int(txt.shape[1])
|
||||
for block in tqdm(dit.double_blocks, desc="Double stream blocks"):
|
||||
img, txt = block(img, txt, vec, (freqs_cos, freqs_sin), token_replace_vec, tr_token, split_token)
|
||||
|
||||
x = torch.concat([img, txt], dim=1)
|
||||
for block in tqdm(dit.single_blocks, desc="Single stream blocks"):
|
||||
x = block(x, vec, (freqs_cos, freqs_sin), txt_len, token_replace_vec, tr_token, split_token)
|
||||
img = x[:, :-txt_len]
|
||||
|
||||
if tea_cache is not None:
|
||||
tea_cache.store(img)
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import types
|
||||
from ..models import ModelManager
|
||||
from ..models.wan_video_dit import WanModel
|
||||
from ..models.wan_video_text_encoder import WanTextEncoder
|
||||
@@ -11,10 +12,11 @@ from einops import rearrange
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
from tqdm import tqdm
|
||||
from typing import Optional
|
||||
|
||||
from ..vram_management import enable_vram_management, AutoWrappedModule, AutoWrappedLinear
|
||||
from ..models.wan_video_text_encoder import T5RelativeEmbedding, T5LayerNorm
|
||||
from ..models.wan_video_dit import WanLayerNorm, WanRMSNorm
|
||||
from ..models.wan_video_dit import RMSNorm, sinusoidal_embedding_1d
|
||||
from ..models.wan_video_vae import RMS_norm, CausalConv3d, Upsample
|
||||
|
||||
|
||||
@@ -29,9 +31,10 @@ class WanVideoPipeline(BasePipeline):
|
||||
self.image_encoder: WanImageEncoder = None
|
||||
self.dit: WanModel = None
|
||||
self.vae: WanVideoVAE = None
|
||||
self.model_names = ['text_encoder', 'dit', 'vae']
|
||||
self.model_names = ['text_encoder', 'dit', 'vae', 'image_encoder']
|
||||
self.height_division_factor = 16
|
||||
self.width_division_factor = 16
|
||||
self.use_unified_sequence_parallel = False
|
||||
|
||||
|
||||
def enable_vram_management(self, num_persistent_param_in_dit=None):
|
||||
@@ -60,8 +63,7 @@ class WanVideoPipeline(BasePipeline):
|
||||
torch.nn.Linear: AutoWrappedLinear,
|
||||
torch.nn.Conv3d: AutoWrappedModule,
|
||||
torch.nn.LayerNorm: AutoWrappedModule,
|
||||
WanLayerNorm: AutoWrappedModule,
|
||||
WanRMSNorm: AutoWrappedModule,
|
||||
RMSNorm: AutoWrappedModule,
|
||||
},
|
||||
module_config = dict(
|
||||
offload_dtype=dtype,
|
||||
@@ -116,7 +118,7 @@ class WanVideoPipeline(BasePipeline):
|
||||
offload_device="cpu",
|
||||
onload_dtype=dtype,
|
||||
onload_device="cpu",
|
||||
computation_dtype=self.torch_dtype,
|
||||
computation_dtype=dtype,
|
||||
computation_device=self.device,
|
||||
),
|
||||
)
|
||||
@@ -135,11 +137,20 @@ class WanVideoPipeline(BasePipeline):
|
||||
|
||||
|
||||
@staticmethod
|
||||
def from_model_manager(model_manager: ModelManager, torch_dtype=None, device=None):
|
||||
def from_model_manager(model_manager: ModelManager, torch_dtype=None, device=None, use_usp=False):
|
||||
if device is None: device = model_manager.device
|
||||
if torch_dtype is None: torch_dtype = model_manager.torch_dtype
|
||||
pipe = WanVideoPipeline(device=device, torch_dtype=torch_dtype)
|
||||
pipe.fetch_models(model_manager)
|
||||
if use_usp:
|
||||
from xfuser.core.distributed import get_sequence_parallel_world_size
|
||||
from ..distributed.xdit_context_parallel import usp_attn_forward, usp_dit_forward
|
||||
|
||||
for block in pipe.dit.blocks:
|
||||
block.self_attn.forward = types.MethodType(usp_attn_forward, block.self_attn)
|
||||
pipe.dit.forward = types.MethodType(usp_dit_forward, pipe.dit)
|
||||
pipe.sp_size = get_sequence_parallel_world_size()
|
||||
pipe.use_unified_sequence_parallel = True
|
||||
return pipe
|
||||
|
||||
|
||||
@@ -148,22 +159,26 @@ class WanVideoPipeline(BasePipeline):
|
||||
|
||||
|
||||
def encode_prompt(self, prompt, positive=True):
|
||||
prompt_emb = self.prompter.encode_prompt(prompt, positive=positive)
|
||||
prompt_emb = self.prompter.encode_prompt(prompt, positive=positive, device=self.device)
|
||||
return {"context": prompt_emb}
|
||||
|
||||
|
||||
def encode_image(self, image, num_frames, height, width):
|
||||
with torch.amp.autocast(dtype=torch.bfloat16, device_type=torch.device(self.device).type):
|
||||
image = self.preprocess_image(image.resize((width, height))).to(self.device)
|
||||
clip_context = self.image_encoder.encode_image([image])
|
||||
msk = torch.ones(1, num_frames, height//8, width//8, device=self.device)
|
||||
msk[:, 1:] = 0
|
||||
msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1)
|
||||
msk = msk.view(1, msk.shape[1] // 4, 4, height//8, width//8)
|
||||
msk = msk.transpose(1, 2)[0]
|
||||
y = self.vae.encode([torch.concat([image.transpose(0, 1), torch.zeros(3, num_frames-1, height, width).to(image.device)], dim=1)], device=self.device)[0]
|
||||
y = torch.concat([msk, y])
|
||||
return {"clip_fea": clip_context, "y": [y]}
|
||||
image = self.preprocess_image(image.resize((width, height))).to(self.device)
|
||||
clip_context = self.image_encoder.encode_image([image])
|
||||
msk = torch.ones(1, num_frames, height//8, width//8, device=self.device)
|
||||
msk[:, 1:] = 0
|
||||
msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1)
|
||||
msk = msk.view(1, msk.shape[1] // 4, 4, height//8, width//8)
|
||||
msk = msk.transpose(1, 2)[0]
|
||||
|
||||
vae_input = torch.concat([image.transpose(0, 1), torch.zeros(3, num_frames-1, height, width).to(image.device)], dim=1)
|
||||
y = self.vae.encode([vae_input.to(dtype=self.torch_dtype, device=self.device)], device=self.device)[0]
|
||||
y = torch.concat([msk, y])
|
||||
y = y.unsqueeze(0)
|
||||
clip_context = clip_context.to(dtype=self.torch_dtype, device=self.device)
|
||||
y = y.to(dtype=self.torch_dtype, device=self.device)
|
||||
return {"clip_feature": clip_context, "y": y}
|
||||
|
||||
|
||||
def tensor2video(self, frames):
|
||||
@@ -174,19 +189,21 @@ class WanVideoPipeline(BasePipeline):
|
||||
|
||||
|
||||
def prepare_extra_input(self, latents=None):
|
||||
return {"seq_len": latents.shape[2] * latents.shape[3] * latents.shape[4] // 4}
|
||||
return {}
|
||||
|
||||
|
||||
def encode_video(self, input_video, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)):
|
||||
with torch.amp.autocast(dtype=torch.bfloat16, device_type=torch.device(self.device).type):
|
||||
latents = self.vae.encode(input_video, device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
|
||||
latents = self.vae.encode(input_video, device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
|
||||
return latents
|
||||
|
||||
|
||||
def decode_video(self, latents, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)):
|
||||
with torch.amp.autocast(dtype=torch.bfloat16, device_type=torch.device(self.device).type):
|
||||
frames = self.vae.decode(latents, device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
|
||||
frames = self.vae.decode(latents, device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
|
||||
return frames
|
||||
|
||||
|
||||
def prepare_unified_sequence_parallel(self):
|
||||
return {"use_unified_sequence_parallel": self.use_unified_sequence_parallel}
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
@@ -208,6 +225,8 @@ class WanVideoPipeline(BasePipeline):
|
||||
tiled=True,
|
||||
tile_size=(30, 52),
|
||||
tile_stride=(15, 26),
|
||||
tea_cache_l1_thresh=None,
|
||||
tea_cache_model_id="",
|
||||
progress_bar_cmd=tqdm,
|
||||
progress_bar_st=None,
|
||||
):
|
||||
@@ -221,15 +240,16 @@ class WanVideoPipeline(BasePipeline):
|
||||
tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride}
|
||||
|
||||
# Scheduler
|
||||
self.scheduler.set_timesteps(num_inference_steps, denoising_strength, shift=sigma_shift)
|
||||
self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength, shift=sigma_shift)
|
||||
|
||||
# Initialize noise
|
||||
noise = self.generate_noise((1, 16, (num_frames - 1) // 4 + 1, height//8, width//8), seed=seed, device=rand_device, dtype=torch.float32).to(self.device)
|
||||
noise = self.generate_noise((1, 16, (num_frames - 1) // 4 + 1, height//8, width//8), seed=seed, device=rand_device, dtype=torch.float32)
|
||||
noise = noise.to(dtype=self.torch_dtype, device=self.device)
|
||||
if input_video is not None:
|
||||
self.load_models_to_device(['vae'])
|
||||
input_video = self.preprocess_images(input_video)
|
||||
input_video = torch.stack(input_video, dim=2)
|
||||
latents = self.encode_video(input_video, **tiler_kwargs).to(dtype=noise.dtype, device=noise.device)
|
||||
input_video = torch.stack(input_video, dim=2).to(dtype=self.torch_dtype, device=self.device)
|
||||
latents = self.encode_video(input_video, **tiler_kwargs).to(dtype=self.torch_dtype, device=self.device)
|
||||
latents = self.scheduler.add_noise(latents, noise, timestep=self.scheduler.timesteps[0])
|
||||
else:
|
||||
latents = noise
|
||||
@@ -249,23 +269,29 @@ class WanVideoPipeline(BasePipeline):
|
||||
|
||||
# Extra input
|
||||
extra_input = self.prepare_extra_input(latents)
|
||||
|
||||
# TeaCache
|
||||
tea_cache_posi = {"tea_cache": TeaCache(num_inference_steps, rel_l1_thresh=tea_cache_l1_thresh, model_id=tea_cache_model_id) if tea_cache_l1_thresh is not None else None}
|
||||
tea_cache_nega = {"tea_cache": TeaCache(num_inference_steps, rel_l1_thresh=tea_cache_l1_thresh, model_id=tea_cache_model_id) if tea_cache_l1_thresh is not None else None}
|
||||
|
||||
# Unified Sequence Parallel
|
||||
usp_kwargs = self.prepare_unified_sequence_parallel()
|
||||
|
||||
# Denoise
|
||||
self.load_models_to_device(["dit"])
|
||||
with torch.amp.autocast(dtype=torch.bfloat16, device_type=torch.device(self.device).type):
|
||||
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
|
||||
timestep = timestep.unsqueeze(0).to(dtype=torch.float32, device=self.device)
|
||||
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
|
||||
timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device)
|
||||
|
||||
# Inference
|
||||
noise_pred_posi = self.dit(latents, timestep=timestep, **prompt_emb_posi, **image_emb, **extra_input)
|
||||
if cfg_scale != 1.0:
|
||||
noise_pred_nega = self.dit(latents, timestep=timestep, **prompt_emb_nega, **image_emb, **extra_input)
|
||||
noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
|
||||
else:
|
||||
noise_pred = noise_pred_posi
|
||||
# Inference
|
||||
noise_pred_posi = model_fn_wan_video(self.dit, latents, timestep=timestep, **prompt_emb_posi, **image_emb, **extra_input, **tea_cache_posi, **usp_kwargs)
|
||||
if cfg_scale != 1.0:
|
||||
noise_pred_nega = model_fn_wan_video(self.dit, latents, timestep=timestep, **prompt_emb_nega, **image_emb, **extra_input, **tea_cache_nega, **usp_kwargs)
|
||||
noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
|
||||
else:
|
||||
noise_pred = noise_pred_posi
|
||||
|
||||
# Scheduler
|
||||
latents = self.scheduler.step(noise_pred, self.scheduler.timesteps[progress_id], latents)
|
||||
# Scheduler
|
||||
latents = self.scheduler.step(noise_pred, self.scheduler.timesteps[progress_id], latents)
|
||||
|
||||
# Decode
|
||||
self.load_models_to_device(['vae'])
|
||||
@@ -274,3 +300,117 @@ class WanVideoPipeline(BasePipeline):
|
||||
frames = self.tensor2video(frames[0])
|
||||
|
||||
return frames
|
||||
|
||||
|
||||
|
||||
class TeaCache:
|
||||
def __init__(self, num_inference_steps, rel_l1_thresh, model_id):
|
||||
self.num_inference_steps = num_inference_steps
|
||||
self.step = 0
|
||||
self.accumulated_rel_l1_distance = 0
|
||||
self.previous_modulated_input = None
|
||||
self.rel_l1_thresh = rel_l1_thresh
|
||||
self.previous_residual = None
|
||||
self.previous_hidden_states = None
|
||||
|
||||
self.coefficients_dict = {
|
||||
"Wan2.1-T2V-1.3B": [-5.21862437e+04, 9.23041404e+03, -5.28275948e+02, 1.36987616e+01, -4.99875664e-02],
|
||||
"Wan2.1-T2V-14B": [-3.03318725e+05, 4.90537029e+04, -2.65530556e+03, 5.87365115e+01, -3.15583525e-01],
|
||||
"Wan2.1-I2V-14B-480P": [2.57151496e+05, -3.54229917e+04, 1.40286849e+03, -1.35890334e+01, 1.32517977e-01],
|
||||
"Wan2.1-I2V-14B-720P": [ 8.10705460e+03, 2.13393892e+03, -3.72934672e+02, 1.66203073e+01, -4.17769401e-02],
|
||||
}
|
||||
if model_id not in self.coefficients_dict:
|
||||
supported_model_ids = ", ".join([i for i in self.coefficients_dict])
|
||||
raise ValueError(f"{model_id} is not a supported TeaCache model id. Please choose a valid model id in ({supported_model_ids}).")
|
||||
self.coefficients = self.coefficients_dict[model_id]
|
||||
|
||||
def check(self, dit: WanModel, x, t_mod):
|
||||
modulated_inp = t_mod.clone()
|
||||
if self.step == 0 or self.step == self.num_inference_steps - 1:
|
||||
should_calc = True
|
||||
self.accumulated_rel_l1_distance = 0
|
||||
else:
|
||||
coefficients = self.coefficients
|
||||
rescale_func = np.poly1d(coefficients)
|
||||
self.accumulated_rel_l1_distance += rescale_func(((modulated_inp-self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean()).cpu().item())
|
||||
if self.accumulated_rel_l1_distance < self.rel_l1_thresh:
|
||||
should_calc = False
|
||||
else:
|
||||
should_calc = True
|
||||
self.accumulated_rel_l1_distance = 0
|
||||
self.previous_modulated_input = modulated_inp
|
||||
self.step += 1
|
||||
if self.step == self.num_inference_steps:
|
||||
self.step = 0
|
||||
if should_calc:
|
||||
self.previous_hidden_states = x.clone()
|
||||
return not should_calc
|
||||
|
||||
def store(self, hidden_states):
|
||||
self.previous_residual = hidden_states - self.previous_hidden_states
|
||||
self.previous_hidden_states = None
|
||||
|
||||
def update(self, hidden_states):
|
||||
hidden_states = hidden_states + self.previous_residual
|
||||
return hidden_states
|
||||
|
||||
|
||||
|
||||
def model_fn_wan_video(
|
||||
dit: WanModel,
|
||||
x: torch.Tensor,
|
||||
timestep: torch.Tensor,
|
||||
context: torch.Tensor,
|
||||
clip_feature: Optional[torch.Tensor] = None,
|
||||
y: Optional[torch.Tensor] = None,
|
||||
tea_cache: TeaCache = None,
|
||||
use_unified_sequence_parallel: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
if use_unified_sequence_parallel:
|
||||
import torch.distributed as dist
|
||||
from xfuser.core.distributed import (get_sequence_parallel_rank,
|
||||
get_sequence_parallel_world_size,
|
||||
get_sp_group)
|
||||
|
||||
t = dit.time_embedding(sinusoidal_embedding_1d(dit.freq_dim, timestep))
|
||||
t_mod = dit.time_projection(t).unflatten(1, (6, dit.dim))
|
||||
context = dit.text_embedding(context)
|
||||
|
||||
if dit.has_image_input:
|
||||
x = torch.cat([x, y], dim=1) # (b, c_x + c_y, f, h, w)
|
||||
clip_embdding = dit.img_emb(clip_feature)
|
||||
context = torch.cat([clip_embdding, context], dim=1)
|
||||
|
||||
x, (f, h, w) = dit.patchify(x)
|
||||
|
||||
freqs = torch.cat([
|
||||
dit.freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
|
||||
dit.freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
|
||||
dit.freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
|
||||
], dim=-1).reshape(f * h * w, 1, -1).to(x.device)
|
||||
|
||||
# TeaCache
|
||||
if tea_cache is not None:
|
||||
tea_cache_update = tea_cache.check(dit, x, t_mod)
|
||||
else:
|
||||
tea_cache_update = False
|
||||
|
||||
# blocks
|
||||
if use_unified_sequence_parallel:
|
||||
if dist.is_initialized() and dist.get_world_size() > 1:
|
||||
x = torch.chunk(x, get_sequence_parallel_world_size(), dim=1)[get_sequence_parallel_rank()]
|
||||
if tea_cache_update:
|
||||
x = tea_cache.update(x)
|
||||
else:
|
||||
for block in dit.blocks:
|
||||
x = block(x, context, t_mod, freqs)
|
||||
if tea_cache is not None:
|
||||
tea_cache.store(x)
|
||||
|
||||
x = dit.head(x, t)
|
||||
if use_unified_sequence_parallel:
|
||||
if dist.is_initialized() and dist.get_world_size() > 1:
|
||||
x = get_sp_group().all_gather(x, dim=1)
|
||||
x = dit.unpatchify(x, (f, h, w))
|
||||
return x
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
from .base_prompter import BasePrompter
|
||||
from ..models.sd3_text_encoder import SD3TextEncoder1
|
||||
from ..models.hunyuan_video_text_encoder import HunyuanVideoLLMEncoder
|
||||
from transformers import CLIPTokenizer, LlamaTokenizerFast
|
||||
from ..models.hunyuan_video_text_encoder import HunyuanVideoLLMEncoder, HunyuanVideoMLLMEncoder
|
||||
from transformers import CLIPTokenizer, LlamaTokenizerFast, CLIPImageProcessor
|
||||
import os, torch
|
||||
from typing import Union
|
||||
|
||||
PROMPT_TEMPLATE_ENCODE = (
|
||||
"<|start_header_id|>system<|end_header_id|>\n\nDescribe the image by detailing the color, shape, size, texture, "
|
||||
@@ -18,6 +19,24 @@ PROMPT_TEMPLATE_ENCODE_VIDEO = (
|
||||
"5. camera angles, movements, and transitions used in the video:<|eot_id|>"
|
||||
"<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>")
|
||||
|
||||
PROMPT_TEMPLATE_ENCODE_I2V = (
|
||||
"<|start_header_id|>system<|end_header_id|>\n\n<image>\nDescribe the image by detailing the color, shape, size, texture, "
|
||||
"quantity, text, spatial relationships of the objects and background:<|eot_id|>"
|
||||
"<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>"
|
||||
"<|start_header_id|>assistant<|end_header_id|>\n\n"
|
||||
)
|
||||
|
||||
PROMPT_TEMPLATE_ENCODE_VIDEO_I2V = (
|
||||
"<|start_header_id|>system<|end_header_id|>\n\n<image>\nDescribe the video by detailing the following aspects according to the reference image: "
|
||||
"1. The main content and theme of the video."
|
||||
"2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects."
|
||||
"3. Actions, events, behaviors temporal relationships, physical movement changes of the objects."
|
||||
"4. background environment, light, style and atmosphere."
|
||||
"5. camera angles, movements, and transitions used in the video:<|eot_id|>\n\n"
|
||||
"<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>"
|
||||
"<|start_header_id|>assistant<|end_header_id|>\n\n"
|
||||
)
|
||||
|
||||
PROMPT_TEMPLATE = {
|
||||
"dit-llm-encode": {
|
||||
"template": PROMPT_TEMPLATE_ENCODE,
|
||||
@@ -27,6 +46,22 @@ PROMPT_TEMPLATE = {
|
||||
"template": PROMPT_TEMPLATE_ENCODE_VIDEO,
|
||||
"crop_start": 95,
|
||||
},
|
||||
"dit-llm-encode-i2v": {
|
||||
"template": PROMPT_TEMPLATE_ENCODE_I2V,
|
||||
"crop_start": 36,
|
||||
"image_emb_start": 5,
|
||||
"image_emb_end": 581,
|
||||
"image_emb_len": 576,
|
||||
"double_return_token_id": 271
|
||||
},
|
||||
"dit-llm-encode-video-i2v": {
|
||||
"template": PROMPT_TEMPLATE_ENCODE_VIDEO_I2V,
|
||||
"crop_start": 103,
|
||||
"image_emb_start": 5,
|
||||
"image_emb_end": 581,
|
||||
"image_emb_len": 576,
|
||||
"double_return_token_id": 271
|
||||
},
|
||||
}
|
||||
|
||||
NEGATIVE_PROMPT = "Aerial view, aerial view, overexposed, low quality, deformation, a poor composition, bad hands, bad teeth, bad eyes, bad limbs, distortion"
|
||||
@@ -56,9 +91,20 @@ class HunyuanVideoPrompter(BasePrompter):
|
||||
self.prompt_template = PROMPT_TEMPLATE['dit-llm-encode']
|
||||
self.prompt_template_video = PROMPT_TEMPLATE['dit-llm-encode-video']
|
||||
|
||||
def fetch_models(self, text_encoder_1: SD3TextEncoder1 = None, text_encoder_2: HunyuanVideoLLMEncoder = None):
|
||||
def fetch_models(self,
|
||||
text_encoder_1: SD3TextEncoder1 = None,
|
||||
text_encoder_2: Union[HunyuanVideoLLMEncoder, HunyuanVideoMLLMEncoder] = None):
|
||||
self.text_encoder_1 = text_encoder_1
|
||||
self.text_encoder_2 = text_encoder_2
|
||||
if isinstance(text_encoder_2, HunyuanVideoMLLMEncoder):
|
||||
# processor
|
||||
# TODO: may need to replace processor with local implementation
|
||||
base_path = os.path.dirname(os.path.dirname(__file__))
|
||||
tokenizer_2_path = os.path.join(base_path, "tokenizer_configs/hunyuan_video/tokenizer_2")
|
||||
self.processor = CLIPImageProcessor.from_pretrained(tokenizer_2_path)
|
||||
# template
|
||||
self.prompt_template = PROMPT_TEMPLATE['dit-llm-encode-i2v']
|
||||
self.prompt_template_video = PROMPT_TEMPLATE['dit-llm-encode-video-i2v']
|
||||
|
||||
def apply_text_to_template(self, text, template):
|
||||
assert isinstance(template, str)
|
||||
@@ -107,8 +153,89 @@ class HunyuanVideoPrompter(BasePrompter):
|
||||
|
||||
return last_hidden_state, attention_mask
|
||||
|
||||
def encode_prompt_using_mllm(self,
|
||||
prompt,
|
||||
images,
|
||||
max_length,
|
||||
device,
|
||||
crop_start,
|
||||
hidden_state_skip_layer=2,
|
||||
use_attention_mask=True,
|
||||
image_embed_interleave=4):
|
||||
image_outputs = self.processor(images, return_tensors="pt")["pixel_values"].to(device)
|
||||
max_length += crop_start
|
||||
inputs = self.tokenizer_2(prompt,
|
||||
return_tensors="pt",
|
||||
padding="max_length",
|
||||
max_length=max_length,
|
||||
truncation=True)
|
||||
input_ids = inputs.input_ids.to(device)
|
||||
attention_mask = inputs.attention_mask.to(device)
|
||||
last_hidden_state = self.text_encoder_2(input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
hidden_state_skip_layer=hidden_state_skip_layer,
|
||||
pixel_values=image_outputs)
|
||||
|
||||
text_crop_start = (crop_start - 1 + self.prompt_template_video.get("image_emb_len", 576))
|
||||
image_crop_start = self.prompt_template_video.get("image_emb_start", 5)
|
||||
image_crop_end = self.prompt_template_video.get("image_emb_end", 581)
|
||||
batch_indices, last_double_return_token_indices = torch.where(
|
||||
input_ids == self.prompt_template_video.get("double_return_token_id", 271))
|
||||
if last_double_return_token_indices.shape[0] == 3:
|
||||
# in case the prompt is too long
|
||||
last_double_return_token_indices = torch.cat((
|
||||
last_double_return_token_indices,
|
||||
torch.tensor([input_ids.shape[-1]]),
|
||||
))
|
||||
batch_indices = torch.cat((batch_indices, torch.tensor([0])))
|
||||
last_double_return_token_indices = (last_double_return_token_indices.reshape(input_ids.shape[0], -1)[:, -1])
|
||||
batch_indices = batch_indices.reshape(input_ids.shape[0], -1)[:, -1]
|
||||
assistant_crop_start = (last_double_return_token_indices - 1 + self.prompt_template_video.get("image_emb_len", 576) - 4)
|
||||
assistant_crop_end = (last_double_return_token_indices - 1 + self.prompt_template_video.get("image_emb_len", 576))
|
||||
attention_mask_assistant_crop_start = (last_double_return_token_indices - 4)
|
||||
attention_mask_assistant_crop_end = last_double_return_token_indices
|
||||
text_last_hidden_state = []
|
||||
text_attention_mask = []
|
||||
image_last_hidden_state = []
|
||||
image_attention_mask = []
|
||||
for i in range(input_ids.shape[0]):
|
||||
text_last_hidden_state.append(
|
||||
torch.cat([
|
||||
last_hidden_state[i, text_crop_start:assistant_crop_start[i].item()],
|
||||
last_hidden_state[i, assistant_crop_end[i].item():],
|
||||
]))
|
||||
text_attention_mask.append(
|
||||
torch.cat([
|
||||
attention_mask[
|
||||
i,
|
||||
crop_start:attention_mask_assistant_crop_start[i].item(),
|
||||
],
|
||||
attention_mask[i, attention_mask_assistant_crop_end[i].item():],
|
||||
]) if use_attention_mask else None)
|
||||
image_last_hidden_state.append(last_hidden_state[i, image_crop_start:image_crop_end])
|
||||
image_attention_mask.append(
|
||||
torch.ones(image_last_hidden_state[-1].shape[0]).to(last_hidden_state.device).
|
||||
to(attention_mask.dtype) if use_attention_mask else None)
|
||||
|
||||
text_last_hidden_state = torch.stack(text_last_hidden_state)
|
||||
text_attention_mask = torch.stack(text_attention_mask)
|
||||
image_last_hidden_state = torch.stack(image_last_hidden_state)
|
||||
image_attention_mask = torch.stack(image_attention_mask)
|
||||
|
||||
image_last_hidden_state = image_last_hidden_state[:, ::image_embed_interleave, :]
|
||||
image_attention_mask = image_attention_mask[:, ::image_embed_interleave]
|
||||
|
||||
assert (text_last_hidden_state.shape[0] == text_attention_mask.shape[0] and
|
||||
image_last_hidden_state.shape[0] == image_attention_mask.shape[0])
|
||||
|
||||
last_hidden_state = torch.cat([image_last_hidden_state, text_last_hidden_state], dim=1)
|
||||
attention_mask = torch.cat([image_attention_mask, text_attention_mask], dim=1)
|
||||
|
||||
return last_hidden_state, attention_mask
|
||||
|
||||
def encode_prompt(self,
|
||||
prompt,
|
||||
images=None,
|
||||
positive=True,
|
||||
device="cuda",
|
||||
clip_sequence_length=77,
|
||||
@@ -116,7 +243,8 @@ class HunyuanVideoPrompter(BasePrompter):
|
||||
data_type='video',
|
||||
use_template=True,
|
||||
hidden_state_skip_layer=2,
|
||||
use_attention_mask=True):
|
||||
use_attention_mask=True,
|
||||
image_embed_interleave=4):
|
||||
|
||||
prompt = self.process_prompt(prompt, positive=positive)
|
||||
|
||||
@@ -136,8 +264,12 @@ class HunyuanVideoPrompter(BasePrompter):
|
||||
pooled_prompt_emb = self.encode_prompt_using_clip(prompt, clip_sequence_length, device)
|
||||
|
||||
# LLM
|
||||
prompt_emb, attention_mask = self.encode_prompt_using_llm(
|
||||
prompt_formated, llm_sequence_length, device, crop_start,
|
||||
hidden_state_skip_layer, use_attention_mask)
|
||||
if images is None:
|
||||
prompt_emb, attention_mask = self.encode_prompt_using_llm(prompt_formated, llm_sequence_length, device, crop_start,
|
||||
hidden_state_skip_layer, use_attention_mask)
|
||||
else:
|
||||
prompt_emb, attention_mask = self.encode_prompt_using_mllm(prompt_formated, images, llm_sequence_length, device,
|
||||
crop_start, hidden_state_skip_layer, use_attention_mask,
|
||||
image_embed_interleave)
|
||||
|
||||
return prompt_emb, pooled_prompt_emb, attention_mask
|
||||
|
||||
@@ -104,5 +104,6 @@ class WanPrompter(BasePrompter):
|
||||
mask = mask.to(device)
|
||||
seq_lens = mask.gt(0).sum(dim=1).long()
|
||||
prompt_emb = self.text_encoder(ids, mask)
|
||||
prompt_emb = [u[:v] for u, v in zip(prompt_emb, seq_lens)]
|
||||
for i, v in enumerate(seq_lens):
|
||||
prompt_emb[:, v:] = 0
|
||||
return prompt_emb
|
||||
|
||||
@@ -37,7 +37,7 @@ class FlowMatchScheduler():
|
||||
self.linear_timesteps_weights = bsmntw_weighing
|
||||
|
||||
|
||||
def step(self, model_output, timestep, sample, to_final=False):
|
||||
def step(self, model_output, timestep, sample, to_final=False, **kwargs):
|
||||
if isinstance(timestep, torch.Tensor):
|
||||
timestep = timestep.cpu()
|
||||
timestep_id = torch.argmin((self.timesteps - timestep).abs())
|
||||
|
||||
@@ -0,0 +1,45 @@
|
||||
{
|
||||
"_valid_processor_keys": [
|
||||
"images",
|
||||
"do_resize",
|
||||
"size",
|
||||
"resample",
|
||||
"do_center_crop",
|
||||
"crop_size",
|
||||
"do_rescale",
|
||||
"rescale_factor",
|
||||
"do_normalize",
|
||||
"image_mean",
|
||||
"image_std",
|
||||
"do_convert_rgb",
|
||||
"return_tensors",
|
||||
"data_format",
|
||||
"input_data_format"
|
||||
],
|
||||
"crop_size": {
|
||||
"height": 336,
|
||||
"width": 336
|
||||
},
|
||||
"do_center_crop": true,
|
||||
"do_convert_rgb": true,
|
||||
"do_normalize": true,
|
||||
"do_rescale": true,
|
||||
"do_resize": true,
|
||||
"image_mean": [
|
||||
0.48145466,
|
||||
0.4578275,
|
||||
0.40821073
|
||||
],
|
||||
"image_processor_type": "CLIPImageProcessor",
|
||||
"image_std": [
|
||||
0.26862954,
|
||||
0.26130258,
|
||||
0.27577711
|
||||
],
|
||||
"processor_class": "LlavaProcessor",
|
||||
"resample": 3,
|
||||
"rescale_factor": 0.00392156862745098,
|
||||
"size": {
|
||||
"shortest_edge": 336
|
||||
}
|
||||
}
|
||||
@@ -290,7 +290,7 @@ def launch_training_task(model, args):
|
||||
name="diffsynth_studio",
|
||||
config=swanlab_config,
|
||||
mode=args.swanlab_mode,
|
||||
logdir=args.output_path,
|
||||
logdir=os.path.join(args.output_path, "swanlog"),
|
||||
)
|
||||
logger = [swanlab_logger]
|
||||
else:
|
||||
|
||||
@@ -6,7 +6,7 @@ We propose EliGen, a novel approach that leverages fine-grained entity-level inf
|
||||
|
||||
* Paper: [EliGen: Entity-Level Controlled Image Generation with Regional Attention](https://arxiv.org/abs/2501.01097)
|
||||
* Github: [DiffSynth-Studio](https://github.com/modelscope/DiffSynth-Studio)
|
||||
* Model: [ModelScope](https://www.modelscope.cn/models/DiffSynth-Studio/Eligen)
|
||||
* Model: [ModelScope](https://www.modelscope.cn/models/DiffSynth-Studio/Eligen), [HuggingFace](https://huggingface.co/modelscope/EliGen)
|
||||
* Online Demo: [ModelScope EliGen Studio](https://www.modelscope.cn/studios/DiffSynth-Studio/EliGen)
|
||||
* Training Dataset: [EliGen Train Set](https://www.modelscope.cn/datasets/DiffSynth-Studio/EliGenTrainSet)
|
||||
|
||||
@@ -77,6 +77,11 @@ Demonstration of the styled entity control results with EliGen and IP-Adapter, s
|
||||
|-|-|-|-|
|
||||
|||||
|
||||
|
||||
We also provide a demo of the styled entity control results with EliGen and specific styled lora, see [./styled_entity_control.py](./styled_entity_control.py) for details. Here is the visualization of EliGen with [Lego dreambooth lora](https://huggingface.co/merve/flux-lego-lora-dreambooth).
|
||||
|||||
|
||||
|-|-|-|-|
|
||||
|||||
|
||||
|
||||
### Entity Transfer
|
||||
Demonstration of the entity transfer results with EliGen and In-Context LoRA, see [./entity_transfer.py](./entity_transfer.py) for generation prompts.
|
||||
|
||||
|
||||
@@ -27,11 +27,20 @@ def example(pipe, seeds, example_id, global_prompt, entity_prompts):
|
||||
|
||||
# download and load model
|
||||
model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cuda", model_id_list=["FLUX.1-dev"])
|
||||
# set download_from_modelscope = False if you want to download model from huggingface
|
||||
download_from_modelscope = True
|
||||
if download_from_modelscope:
|
||||
model_id = "DiffSynth-Studio/Eligen"
|
||||
downloading_priority = ["ModelScope"]
|
||||
else:
|
||||
model_id = "modelscope/EliGen"
|
||||
downloading_priority = ["HuggingFace"]
|
||||
model_manager.load_lora(
|
||||
download_customized_models(
|
||||
model_id="DiffSynth-Studio/Eligen",
|
||||
model_id=model_id,
|
||||
origin_file_path="model_bf16.safetensors",
|
||||
local_dir="models/lora/entity_control"
|
||||
local_dir="models/lora/entity_control",
|
||||
downloading_priority=downloading_priority
|
||||
),
|
||||
lora_alpha=1
|
||||
)
|
||||
|
||||
90
examples/EntityControl/styled_entity_control.py
Normal file
90
examples/EntityControl/styled_entity_control.py
Normal file
@@ -0,0 +1,90 @@
|
||||
from diffsynth import ModelManager, FluxImagePipeline, download_customized_models
|
||||
from modelscope import dataset_snapshot_download
|
||||
from examples.EntityControl.utils import visualize_masks
|
||||
from PIL import Image
|
||||
import torch
|
||||
|
||||
def example(pipe, seeds, example_id, global_prompt, entity_prompts):
|
||||
dataset_snapshot_download(dataset_id="DiffSynth-Studio/examples_in_diffsynth", local_dir="./", allow_file_pattern=f"data/examples/eligen/entity_control/example_{example_id}/*.png")
|
||||
masks = [Image.open(f"./data/examples/eligen/entity_control/example_{example_id}/{i}.png").convert('RGB') for i in range(len(entity_prompts))]
|
||||
negative_prompt = "worst quality, low quality, monochrome, zombie, interlocked fingers, Aissist, cleavage, nsfw,"
|
||||
for seed in seeds:
|
||||
# generate image
|
||||
image = pipe(
|
||||
prompt=global_prompt,
|
||||
cfg_scale=3.0,
|
||||
negative_prompt=negative_prompt,
|
||||
num_inference_steps=50,
|
||||
embedded_guidance=3.5,
|
||||
seed=seed,
|
||||
height=1024,
|
||||
width=1024,
|
||||
eligen_entity_prompts=entity_prompts,
|
||||
eligen_entity_masks=masks,
|
||||
)
|
||||
image.save(f"styled_eligen_example_{example_id}_{seed}.png")
|
||||
visualize_masks(image, masks, entity_prompts, f"styled_entity_control_example_{example_id}_mask_{seed}.png")
|
||||
|
||||
# download and load model
|
||||
model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cuda", model_id_list=["FLUX.1-dev"])
|
||||
model_manager.load_lora(
|
||||
download_customized_models(
|
||||
model_id="FluxLora/merve-flux-lego-lora-dreambooth",
|
||||
origin_file_path="pytorch_lora_weights.safetensors",
|
||||
local_dir="models/lora/merve-flux-lego-lora-dreambooth"
|
||||
),
|
||||
lora_alpha=1
|
||||
)
|
||||
model_manager.load_lora(
|
||||
download_customized_models(
|
||||
model_id="DiffSynth-Studio/Eligen",
|
||||
origin_file_path="model_bf16.safetensors",
|
||||
local_dir="models/lora/entity_control"
|
||||
),
|
||||
lora_alpha=1
|
||||
)
|
||||
pipe = FluxImagePipeline.from_model_manager(model_manager)
|
||||
|
||||
# example 1
|
||||
trigger_word = "lego set in style of TOK, "
|
||||
global_prompt = "A breathtaking beauty of Raja Ampat by the late-night moonlight , one beautiful woman from behind wearing a pale blue long dress with soft glow, sitting at the top of a cliff looking towards the beach,pastell light colors, a group of small distant birds flying in far sky, a boat sailing on the sea, best quality, realistic, whimsical, fantastic, splash art, intricate detailed, hyperdetailed, maximalist style, photorealistic, concept art, sharp focus, harmony, serenity, tranquility, soft pastell colors,ambient occlusion, cozy ambient lighting, masterpiece, liiv1, linquivera, metix, mentixis, masterpiece, award winning, view from above\n"
|
||||
global_prompt = trigger_word + global_prompt
|
||||
entity_prompts = ["cliff", "sea", "moon", "sailing boat", "a seated beautiful woman", "pale blue long dress with soft glow"]
|
||||
example(pipe, [0], 1, global_prompt, entity_prompts)
|
||||
|
||||
# example 2
|
||||
global_prompt = "samurai girl wearing a kimono, she's holding a sword glowing with red flame, her long hair is flowing in the wind, she is looking at a small bird perched on the back of her hand. ultra realist style. maximum image detail. maximum realistic render."
|
||||
global_prompt = trigger_word + global_prompt
|
||||
entity_prompts = ["flowing hair", "sword glowing with red flame", "A cute bird", "blue belt"]
|
||||
example(pipe, [0], 2, global_prompt, entity_prompts)
|
||||
|
||||
# example 3
|
||||
global_prompt = "Image of a neverending staircase up to a mysterious palace in the sky, The ancient palace stood majestically atop a mist-shrouded mountain, sunrise, two traditional monk walk in the stair looking at the sunrise, fog,see-through, best quality, whimsical, fantastic, splash art, intricate detailed, hyperdetailed, photorealistic, concept art, harmony, serenity, tranquility, ambient occlusion, halation, cozy ambient lighting, dynamic lighting,masterpiece, liiv1, linquivera, metix, mentixis, masterpiece, award winning,"
|
||||
global_prompt = trigger_word + global_prompt
|
||||
entity_prompts = ["ancient palace", "stone staircase with railings", "a traditional monk", "a traditional monk"]
|
||||
example(pipe, [27], 3, global_prompt, entity_prompts)
|
||||
|
||||
# example 4
|
||||
global_prompt = "A beautiful girl wearing shirt and shorts in the street, holding a sign 'Entity Control'"
|
||||
global_prompt = trigger_word + global_prompt
|
||||
entity_prompts = ["A beautiful girl", "sign 'Entity Control'", "shorts", "shirt"]
|
||||
example(pipe, [21], 4, global_prompt, entity_prompts)
|
||||
|
||||
# example 5
|
||||
global_prompt = "A captivating, dramatic scene in a painting that exudes mystery and foreboding. A white sky, swirling blue clouds, and a crescent yellow moon illuminate a solitary woman standing near the water's edge. Her long dress flows in the wind, silhouetted against the eerie glow. The water mirrors the fiery sky and moonlight, amplifying the uneasy atmosphere."
|
||||
global_prompt = trigger_word + global_prompt
|
||||
entity_prompts = ["crescent yellow moon", "a solitary woman", "water", "swirling blue clouds"]
|
||||
example(pipe, [0], 5, global_prompt, entity_prompts)
|
||||
|
||||
# example 6
|
||||
global_prompt = "Snow White and the 6 Dwarfs."
|
||||
global_prompt = trigger_word + global_prompt
|
||||
entity_prompts = ["Dwarf 1", "Dwarf 2", "Dwarf 3", "Snow White", "Dwarf 4", "Dwarf 5", "Dwarf 6"]
|
||||
example(pipe, [8], 6, global_prompt, entity_prompts)
|
||||
|
||||
# example 7, same prompt with different seeds
|
||||
seeds = range(5, 9)
|
||||
global_prompt = "A beautiful woman wearing white dress, holding a mirror, with a warm light background;"
|
||||
global_prompt = trigger_word + global_prompt
|
||||
entity_prompts = ["A beautiful woman", "mirror", "necklace", "glasses", "earring", "white dress", "jewelry headpiece"]
|
||||
example(pipe, seeds, 7, global_prompt, entity_prompts)
|
||||
@@ -8,6 +8,12 @@
|
||||
|24G|[hunyuanvideo_24G.py](hunyuanvideo_24G.py)|129|720*1280|The video is consistent with the original implementation, but it requires 5%~10% more time than [hunyuanvideo_80G.py](hunyuanvideo_80G.py)|
|
||||
|6G|[hunyuanvideo_6G.py](hunyuanvideo_6G.py)|129|512*384|The base model doesn't support low resolutions. We recommend users to use some LoRA ([example](https://civitai.com/models/1032126/walking-animation-hunyuan-video)) trained using low resolutions.|
|
||||
|
||||
[HunyuanVideo-I2V](https://github.com/Tencent/HunyuanVideo-I2V) is the image-to-video generation version of HunyuanVideo. We also provide advanced VRAM management for this model.
|
||||
|VRAM required|Example script|Frames|Resolution|Note|
|
||||
|-|-|-|-|-|
|
||||
|80G|[hunyuanvideo_i2v_80G.py](hunyuanvideo_i2v_80G.py)|129|720p|No VRAM management.|
|
||||
|24G|[hunyuanvideo_i2v_24G.py](hunyuanvideo_i2v_24G.py)|129|720p|The video is consistent with the original implementation, but it requires 5%~10% more time than [hunyuanvideo_80G.py](hunyuanvideo_80G.py)|
|
||||
|
||||
## Gallery
|
||||
|
||||
Video generated by [hunyuanvideo_80G.py](hunyuanvideo_80G.py) and [hunyuanvideo_24G.py](hunyuanvideo_24G.py):
|
||||
@@ -21,3 +27,7 @@ https://github.com/user-attachments/assets/2997f107-d02d-4ecb-89bb-5ce1a7f93817
|
||||
Video to video generated by [hunyuanvideo_v2v_6G.py](./hunyuanvideo_v2v_6G.py) using [this LoRA](https://civitai.com/models/1032126/walking-animation-hunyuan-video):
|
||||
|
||||
https://github.com/user-attachments/assets/4b89e52e-ce42-434e-aa57-08f09dfa2b10
|
||||
|
||||
Video generated by [hunyuanvideo_i2v_80G.py](hunyuanvideo_i2v_80G.py) and [hunyuanvideo_i2v_24G.py](hunyuanvideo_i2v_24G.py):
|
||||
|
||||
https://github.com/user-attachments/assets/494f252a-c9af-440d-84ba-a8ddcdcc538a
|
||||
|
||||
43
examples/HunyuanVideo/hunyuanvideo_i2v_24G.py
Normal file
43
examples/HunyuanVideo/hunyuanvideo_i2v_24G.py
Normal file
@@ -0,0 +1,43 @@
|
||||
import torch
|
||||
from diffsynth import ModelManager, HunyuanVideoPipeline, download_models, save_video
|
||||
from modelscope import dataset_snapshot_download
|
||||
from PIL import Image
|
||||
|
||||
|
||||
download_models(["HunyuanVideoI2V"])
|
||||
model_manager = ModelManager()
|
||||
|
||||
# The DiT model is loaded in bfloat16.
|
||||
model_manager.load_models(
|
||||
[
|
||||
"models/HunyuanVideoI2V/transformers/mp_rank_00_model_states.pt"
|
||||
],
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cpu"
|
||||
)
|
||||
|
||||
# The other modules are loaded in float16.
|
||||
model_manager.load_models(
|
||||
[
|
||||
"models/HunyuanVideoI2V/text_encoder/model.safetensors",
|
||||
"models/HunyuanVideoI2V/text_encoder_2",
|
||||
'models/HunyuanVideoI2V/vae/pytorch_model.pt'
|
||||
],
|
||||
torch_dtype=torch.float16,
|
||||
device="cpu"
|
||||
)
|
||||
# The computation device is "cuda".
|
||||
pipe = HunyuanVideoPipeline.from_model_manager(model_manager,
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
enable_vram_management=True)
|
||||
|
||||
dataset_snapshot_download(dataset_id="DiffSynth-Studio/examples_in_diffsynth",
|
||||
local_dir="./",
|
||||
allow_file_pattern=f"data/examples/hunyuanvideo/*")
|
||||
|
||||
i2v_resolution = "720p"
|
||||
prompt = "An Asian man with short hair in black tactical uniform and white clothes waves a firework stick."
|
||||
images = [Image.open("data/examples/hunyuanvideo/0.jpg").convert('RGB')]
|
||||
video = pipe(prompt, input_images=images, num_inference_steps=50, seed=0, i2v_resolution=i2v_resolution)
|
||||
save_video(video, f"video_{i2v_resolution}_low_vram.mp4", fps=30, quality=6)
|
||||
45
examples/HunyuanVideo/hunyuanvideo_i2v_80G.py
Normal file
45
examples/HunyuanVideo/hunyuanvideo_i2v_80G.py
Normal file
@@ -0,0 +1,45 @@
|
||||
import torch
|
||||
from diffsynth import ModelManager, HunyuanVideoPipeline, download_models, save_video
|
||||
from modelscope import dataset_snapshot_download
|
||||
from PIL import Image
|
||||
|
||||
|
||||
download_models(["HunyuanVideoI2V"])
|
||||
model_manager = ModelManager()
|
||||
|
||||
# The DiT model is loaded in bfloat16.
|
||||
model_manager.load_models(
|
||||
[
|
||||
"models/HunyuanVideoI2V/transformers/mp_rank_00_model_states.pt"
|
||||
],
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda"
|
||||
)
|
||||
|
||||
# The other modules are loaded in float16.
|
||||
model_manager.load_models(
|
||||
[
|
||||
"models/HunyuanVideoI2V/text_encoder/model.safetensors",
|
||||
"models/HunyuanVideoI2V/text_encoder_2",
|
||||
'models/HunyuanVideoI2V/vae/pytorch_model.pt'
|
||||
],
|
||||
torch_dtype=torch.float16,
|
||||
device="cuda"
|
||||
)
|
||||
# The computation device is "cuda".
|
||||
pipe = HunyuanVideoPipeline.from_model_manager(model_manager,
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
enable_vram_management=False)
|
||||
# Although you have enough VRAM, we still recommend you to enable offload.
|
||||
pipe.enable_cpu_offload()
|
||||
|
||||
dataset_snapshot_download(dataset_id="DiffSynth-Studio/examples_in_diffsynth",
|
||||
local_dir="./",
|
||||
allow_file_pattern=f"data/examples/hunyuanvideo/*")
|
||||
|
||||
i2v_resolution = "720p"
|
||||
prompt = "An Asian man with short hair in black tactical uniform and white clothes waves a firework stick."
|
||||
images = [Image.open("data/examples/hunyuanvideo/0.jpg").convert('RGB')]
|
||||
video = pipe(prompt, input_images=images, num_inference_steps=50, seed=0, i2v_resolution=i2v_resolution)
|
||||
save_video(video, f"video_{i2v_resolution}.mp4", fps=30, quality=6)
|
||||
7
examples/InfiniteYou/README.md
Normal file
7
examples/InfiniteYou/README.md
Normal file
@@ -0,0 +1,7 @@
|
||||
# InfiniteYou: Flexible Photo Recrafting While Preserving Your Identity
|
||||
We support the identity preserving feature of InfiniteYou. See [./infiniteyou.py](./infiniteyou.py) for example. The visualization of the result is shown below.
|
||||
|
||||
|Identity Image|Generated Image|
|
||||
|-|-|
|
||||
|||
|
||||
|||
|
||||
58
examples/InfiniteYou/infiniteyou.py
Normal file
58
examples/InfiniteYou/infiniteyou.py
Normal file
@@ -0,0 +1,58 @@
|
||||
import importlib
|
||||
import torch
|
||||
from diffsynth import ModelManager, FluxImagePipeline, download_models, ControlNetConfigUnit
|
||||
from modelscope import dataset_snapshot_download
|
||||
from PIL import Image
|
||||
|
||||
if importlib.util.find_spec("facexlib") is None:
|
||||
raise ImportError("You are using InifiniteYou. It depends on facexlib, which is not installed. Please install it with `pip install facexlib`.")
|
||||
if importlib.util.find_spec("insightface") is None:
|
||||
raise ImportError("You are using InifiniteYou. It depends on insightface, which is not installed. Please install it with `pip install insightface`.")
|
||||
|
||||
download_models(["InfiniteYou"])
|
||||
model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cuda", model_id_list=["FLUX.1-dev"])
|
||||
model_manager.load_models([
|
||||
[
|
||||
"models/InfiniteYou/InfuseNetModel/diffusion_pytorch_model-00001-of-00002.safetensors",
|
||||
"models/InfiniteYou/InfuseNetModel/diffusion_pytorch_model-00002-of-00002.safetensors"
|
||||
],
|
||||
"models/InfiniteYou/image_proj_model.bin",
|
||||
])
|
||||
|
||||
|
||||
pipe = FluxImagePipeline.from_model_manager(
|
||||
model_manager,
|
||||
controlnet_config_units=[
|
||||
ControlNetConfigUnit(
|
||||
processor_id="none",
|
||||
model_path=[
|
||||
'models/InfiniteYou/InfuseNetModel/diffusion_pytorch_model-00001-of-00002.safetensors',
|
||||
'models/InfiniteYou/InfuseNetModel/diffusion_pytorch_model-00002-of-00002.safetensors'
|
||||
],
|
||||
scale=1.0
|
||||
)
|
||||
]
|
||||
)
|
||||
dataset_snapshot_download(dataset_id="DiffSynth-Studio/examples_in_diffsynth", local_dir="./", allow_file_pattern=f"data/examples/infiniteyou/*")
|
||||
|
||||
prompt = "A man, portrait, cinematic"
|
||||
id_image = "data/examples/infiniteyou/man.jpg"
|
||||
id_image = Image.open(id_image).convert('RGB')
|
||||
image = pipe(
|
||||
prompt=prompt, seed=1,
|
||||
infinityou_id_image=id_image, infinityou_guidance=1.0,
|
||||
num_inference_steps=50, embedded_guidance=3.5,
|
||||
height=1024, width=1024,
|
||||
)
|
||||
image.save("man.jpg")
|
||||
|
||||
prompt = "A woman, portrait, cinematic"
|
||||
id_image = "data/examples/infiniteyou/woman.jpg"
|
||||
id_image = Image.open(id_image).convert('RGB')
|
||||
image = pipe(
|
||||
prompt=prompt, seed=1,
|
||||
infinityou_id_image=id_image, infinityou_guidance=1.0,
|
||||
num_inference_steps=50, embedded_guidance=3.5,
|
||||
height=1024, width=1024,
|
||||
)
|
||||
image.save("woman.jpg")
|
||||
@@ -31,6 +31,8 @@ Put sunglasses on the dog.
|
||||
|
||||
https://github.com/user-attachments/assets/272808d7-fbeb-4747-a6df-14a0860c75fb
|
||||
|
||||
[TeaCache](https://github.com/ali-vilab/TeaCache) is supported in both T2V and I2V models. It can significantly improve the efficiency. See [`./wan_1.3b_text_to_video_accelerate.py`](./wan_1.3b_text_to_video_accelerate.py).
|
||||
|
||||
### Wan-Video-14B-T2V
|
||||
|
||||
Wan-Video-14B-T2V is an enhanced version of Wan-Video-1.3B-T2V, offering greater size and power. To utilize this model, you need additional VRAM. We recommend that users adjust the `torch_dtype` and `num_persistent_param_in_dit` settings to find an optimal balance between speed and VRAM requirements. See [`./wan_14b_text_to_video.py`](./wan_14b_text_to_video.py).
|
||||
@@ -47,6 +49,22 @@ We present a detailed table here. The model is tested on a single A100.
|
||||
|
||||
https://github.com/user-attachments/assets/3908bc64-d451-485a-8b61-28f6d32dd92f
|
||||
|
||||
### Parallel Inference
|
||||
|
||||
1. Unified Sequence Parallel (USP)
|
||||
|
||||
```bash
|
||||
pip install xfuser>=0.4.3
|
||||
```
|
||||
|
||||
```bash
|
||||
torchrun --standalone --nproc_per_node=8 examples/wanvideo/wan_14b_text_to_video_usp.py
|
||||
```
|
||||
|
||||
2. Tensor Parallel
|
||||
|
||||
Tensor parallel module of Wan-Video-14B-T2V is still under development. An example script is provided in [`./wan_14b_text_to_video_tensor_parallel.py`](./wan_14b_text_to_video_tensor_parallel.py).
|
||||
|
||||
### Wan-Video-14B-I2V
|
||||
|
||||
Wan-Video-14B-I2V adds the functionality of image-to-video based on Wan-Video-14B-T2V. The model size remains the same, therefore the speed and VRAM requirements are also consistent. See [`./wan_14b_image_to_video.py`](./wan_14b_image_to_video.py).
|
||||
@@ -155,6 +173,12 @@ CUDA_VISIBLE_DEVICES="0" python examples/wanvideo/train_wan_t2v.py \
|
||||
--use_gradient_checkpointing
|
||||
```
|
||||
|
||||
If you wish to train the 14B model, please separate the safetensor files with a comma. For example: `models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00001-of-00006.safetensors,models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00002-of-00006.safetensors,models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00003-of-00006.safetensors,models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00004-of-00006.safetensors,models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00005-of-00006.safetensors,models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00006-of-00006.safetensors`.
|
||||
|
||||
If you wish to train the image-to-video model, please add an extra parameter `--image_encoder_path "models/Wan-AI/Wan2.1-I2V-14B-480P/models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth"`.
|
||||
|
||||
For LoRA training, the Wan-1.3B-T2V model requires 16G of VRAM for processing 81 frames at 480P, while the Wan-14B-T2V model requires 60G of VRAM for the same configuration. To further reduce VRAM requirements by 20%-30%, you can include the parameter `--use_gradient_checkpointing_offload`.
|
||||
|
||||
Step 5: Test
|
||||
|
||||
Test LoRA:
|
||||
|
||||
@@ -7,11 +7,12 @@ from diffsynth import WanVideoPipeline, ModelManager, load_state_dict
|
||||
from peft import LoraConfig, inject_adapter_in_model
|
||||
import torchvision
|
||||
from PIL import Image
|
||||
import numpy as np
|
||||
|
||||
|
||||
|
||||
class TextVideoDataset(torch.utils.data.Dataset):
|
||||
def __init__(self, base_path, metadata_path, max_num_frames=81, frame_interval=1, num_frames=81, height=480, width=832):
|
||||
def __init__(self, base_path, metadata_path, max_num_frames=81, frame_interval=1, num_frames=81, height=480, width=832, is_i2v=False):
|
||||
metadata = pd.read_csv(metadata_path)
|
||||
self.path = [os.path.join(base_path, "train", file_name) for file_name in metadata["file_name"]]
|
||||
self.text = metadata["text"].to_list()
|
||||
@@ -21,6 +22,7 @@ class TextVideoDataset(torch.utils.data.Dataset):
|
||||
self.num_frames = num_frames
|
||||
self.height = height
|
||||
self.width = width
|
||||
self.is_i2v = is_i2v
|
||||
|
||||
self.frame_process = v2.Compose([
|
||||
v2.CenterCrop(size=(height, width)),
|
||||
@@ -48,10 +50,13 @@ class TextVideoDataset(torch.utils.data.Dataset):
|
||||
return None
|
||||
|
||||
frames = []
|
||||
first_frame = None
|
||||
for frame_id in range(num_frames):
|
||||
frame = reader.get_data(start_frame_id + frame_id * interval)
|
||||
frame = Image.fromarray(frame)
|
||||
frame = self.crop_and_resize(frame)
|
||||
if first_frame is None:
|
||||
first_frame = np.array(frame)
|
||||
frame = frame_process(frame)
|
||||
frames.append(frame)
|
||||
reader.close()
|
||||
@@ -59,7 +64,10 @@ class TextVideoDataset(torch.utils.data.Dataset):
|
||||
frames = torch.stack(frames, dim=0)
|
||||
frames = rearrange(frames, "T C H W -> C T H W")
|
||||
|
||||
return frames
|
||||
if self.is_i2v:
|
||||
return frames, first_frame
|
||||
else:
|
||||
return frames
|
||||
|
||||
|
||||
def load_video(self, file_path):
|
||||
@@ -70,7 +78,7 @@ class TextVideoDataset(torch.utils.data.Dataset):
|
||||
|
||||
def is_image(self, file_path):
|
||||
file_ext_name = file_path.split(".")[-1]
|
||||
if file_ext_name.lower() in ["jpg", "png", "webp"]:
|
||||
if file_ext_name.lower() in ["jpg", "jpeg", "png", "webp"]:
|
||||
return True
|
||||
return False
|
||||
|
||||
@@ -78,6 +86,7 @@ class TextVideoDataset(torch.utils.data.Dataset):
|
||||
def load_image(self, file_path):
|
||||
frame = Image.open(file_path).convert("RGB")
|
||||
frame = self.crop_and_resize(frame)
|
||||
first_frame = frame
|
||||
frame = self.frame_process(frame)
|
||||
frame = rearrange(frame, "C H W -> C 1 H W")
|
||||
return frame
|
||||
@@ -87,10 +96,16 @@ class TextVideoDataset(torch.utils.data.Dataset):
|
||||
text = self.text[data_id]
|
||||
path = self.path[data_id]
|
||||
if self.is_image(path):
|
||||
if self.is_i2v:
|
||||
raise ValueError(f"{path} is not a video. I2V model doesn't support image-to-image training.")
|
||||
video = self.load_image(path)
|
||||
else:
|
||||
video = self.load_video(path)
|
||||
data = {"text": text, "video": video, "path": path}
|
||||
if self.is_i2v:
|
||||
video, first_frame = video
|
||||
data = {"text": text, "video": video, "path": path, "first_frame": first_frame}
|
||||
else:
|
||||
data = {"text": text, "video": video, "path": path}
|
||||
return data
|
||||
|
||||
|
||||
@@ -100,21 +115,35 @@ class TextVideoDataset(torch.utils.data.Dataset):
|
||||
|
||||
|
||||
class LightningModelForDataProcess(pl.LightningModule):
|
||||
def __init__(self, text_encoder_path, vae_path, tiled=False, tile_size=(34, 34), tile_stride=(18, 16)):
|
||||
def __init__(self, text_encoder_path, vae_path, image_encoder_path=None, tiled=False, tile_size=(34, 34), tile_stride=(18, 16)):
|
||||
super().__init__()
|
||||
model_path = [text_encoder_path, vae_path]
|
||||
if image_encoder_path is not None:
|
||||
model_path.append(image_encoder_path)
|
||||
model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu")
|
||||
model_manager.load_models([text_encoder_path, vae_path])
|
||||
model_manager.load_models(model_path)
|
||||
self.pipe = WanVideoPipeline.from_model_manager(model_manager)
|
||||
|
||||
self.tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride}
|
||||
|
||||
def test_step(self, batch, batch_idx):
|
||||
text, video, path = batch["text"][0], batch["video"], batch["path"][0]
|
||||
|
||||
self.pipe.device = self.device
|
||||
if video is not None:
|
||||
# prompt
|
||||
prompt_emb = self.pipe.encode_prompt(text)
|
||||
# video
|
||||
video = video.to(dtype=self.pipe.torch_dtype, device=self.pipe.device)
|
||||
latents = self.pipe.encode_video(video, **self.tiler_kwargs)[0]
|
||||
data = {"latents": latents, "prompt_emb": prompt_emb}
|
||||
# image
|
||||
if "first_frame" in batch:
|
||||
first_frame = Image.fromarray(batch["first_frame"][0].cpu().numpy())
|
||||
_, _, num_frames, height, width = video.shape
|
||||
image_emb = self.pipe.encode_image(first_frame, num_frames, height, width)
|
||||
else:
|
||||
image_emb = {}
|
||||
data = {"latents": latents, "prompt_emb": prompt_emb, "image_emb": image_emb}
|
||||
torch.save(data, path + ".tensors.pth")
|
||||
|
||||
|
||||
@@ -145,10 +174,21 @@ class TensorDataset(torch.utils.data.Dataset):
|
||||
|
||||
|
||||
class LightningModelForTrain(pl.LightningModule):
|
||||
def __init__(self, dit_path, learning_rate=1e-5, lora_rank=4, lora_alpha=4, train_architecture="lora", lora_target_modules="q,k,v,o,ffn.0,ffn.2", init_lora_weights="kaiming", use_gradient_checkpointing=True, pretrained_lora_path=None):
|
||||
def __init__(
|
||||
self,
|
||||
dit_path,
|
||||
learning_rate=1e-5,
|
||||
lora_rank=4, lora_alpha=4, train_architecture="lora", lora_target_modules="q,k,v,o,ffn.0,ffn.2", init_lora_weights="kaiming",
|
||||
use_gradient_checkpointing=True, use_gradient_checkpointing_offload=False,
|
||||
pretrained_lora_path=None
|
||||
):
|
||||
super().__init__()
|
||||
model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu")
|
||||
model_manager.load_models([dit_path])
|
||||
if os.path.isfile(dit_path):
|
||||
model_manager.load_models([dit_path])
|
||||
else:
|
||||
dit_path = dit_path.split(",")
|
||||
model_manager.load_models([dit_path])
|
||||
|
||||
self.pipe = WanVideoPipeline.from_model_manager(model_manager)
|
||||
self.pipe.scheduler.set_timesteps(1000, training=True)
|
||||
@@ -167,6 +207,7 @@ class LightningModelForTrain(pl.LightningModule):
|
||||
|
||||
self.learning_rate = learning_rate
|
||||
self.use_gradient_checkpointing = use_gradient_checkpointing
|
||||
self.use_gradient_checkpointing_offload = use_gradient_checkpointing_offload
|
||||
|
||||
|
||||
def freeze_parameters(self):
|
||||
@@ -210,24 +251,30 @@ class LightningModelForTrain(pl.LightningModule):
|
||||
# Data
|
||||
latents = batch["latents"].to(self.device)
|
||||
prompt_emb = batch["prompt_emb"]
|
||||
prompt_emb["context"] = [prompt_emb["context"][0][0].to(self.device)]
|
||||
|
||||
prompt_emb["context"] = prompt_emb["context"][0].to(self.device)
|
||||
image_emb = batch["image_emb"]
|
||||
if "clip_feature" in image_emb:
|
||||
image_emb["clip_feature"] = image_emb["clip_feature"][0].to(self.device)
|
||||
if "y" in image_emb:
|
||||
image_emb["y"] = image_emb["y"][0].to(self.device)
|
||||
|
||||
# Loss
|
||||
self.pipe.device = self.device
|
||||
noise = torch.randn_like(latents)
|
||||
timestep_id = torch.randint(0, self.pipe.scheduler.num_train_timesteps, (1,))
|
||||
timestep = self.pipe.scheduler.timesteps[timestep_id].to(self.device)
|
||||
timestep = self.pipe.scheduler.timesteps[timestep_id].to(dtype=self.pipe.torch_dtype, device=self.pipe.device)
|
||||
extra_input = self.pipe.prepare_extra_input(latents)
|
||||
noisy_latents = self.pipe.scheduler.add_noise(latents, noise, timestep)
|
||||
training_target = self.pipe.scheduler.training_target(latents, noise, timestep)
|
||||
|
||||
# Compute loss
|
||||
with torch.amp.autocast(dtype=torch.bfloat16, device_type=torch.device(self.device).type):
|
||||
noise_pred = self.pipe.denoising_model()(
|
||||
noisy_latents, timestep=timestep, **prompt_emb, **extra_input,
|
||||
use_gradient_checkpointing=self.use_gradient_checkpointing
|
||||
)
|
||||
loss = torch.nn.functional.mse_loss(noise_pred.float(), training_target.float())
|
||||
loss = loss * self.pipe.scheduler.training_weight(timestep)
|
||||
noise_pred = self.pipe.denoising_model()(
|
||||
noisy_latents, timestep=timestep, **prompt_emb, **extra_input, **image_emb,
|
||||
use_gradient_checkpointing=self.use_gradient_checkpointing,
|
||||
use_gradient_checkpointing_offload=self.use_gradient_checkpointing_offload
|
||||
)
|
||||
loss = torch.nn.functional.mse_loss(noise_pred.float(), training_target.float())
|
||||
loss = loss * self.pipe.scheduler.training_weight(timestep)
|
||||
|
||||
# Record log
|
||||
self.log("train_loss", loss, prog_bar=True)
|
||||
@@ -282,6 +329,12 @@ def parse_args():
|
||||
default=None,
|
||||
help="Path of text encoder.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--image_encoder_path",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Path of image encoder.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--vae_path",
|
||||
type=str,
|
||||
@@ -410,6 +463,12 @@ def parse_args():
|
||||
action="store_true",
|
||||
help="Whether to use gradient checkpointing.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--use_gradient_checkpointing_offload",
|
||||
default=False,
|
||||
action="store_true",
|
||||
help="Whether to use gradient checkpointing offload.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--train_architecture",
|
||||
type=str,
|
||||
@@ -446,7 +505,8 @@ def data_process(args):
|
||||
frame_interval=1,
|
||||
num_frames=args.num_frames,
|
||||
height=args.height,
|
||||
width=args.width
|
||||
width=args.width,
|
||||
is_i2v=args.image_encoder_path is not None
|
||||
)
|
||||
dataloader = torch.utils.data.DataLoader(
|
||||
dataset,
|
||||
@@ -456,6 +516,7 @@ def data_process(args):
|
||||
)
|
||||
model = LightningModelForDataProcess(
|
||||
text_encoder_path=args.text_encoder_path,
|
||||
image_encoder_path=args.image_encoder_path,
|
||||
vae_path=args.vae_path,
|
||||
tiled=args.tiled,
|
||||
tile_size=(args.tile_size_height, args.tile_size_width),
|
||||
@@ -490,6 +551,7 @@ def train(args):
|
||||
lora_target_modules=args.lora_target_modules,
|
||||
init_lora_weights=args.init_lora_weights,
|
||||
use_gradient_checkpointing=args.use_gradient_checkpointing,
|
||||
use_gradient_checkpointing_offload=args.use_gradient_checkpointing_offload,
|
||||
pretrained_lora_path=args.pretrained_lora_path,
|
||||
)
|
||||
if args.use_swanlab:
|
||||
@@ -501,7 +563,7 @@ def train(args):
|
||||
name="wan",
|
||||
config=swanlab_config,
|
||||
mode=args.swanlab_mode,
|
||||
logdir=args.output_path,
|
||||
logdir=os.path.join(args.output_path, "swanlog"),
|
||||
)
|
||||
logger = [swanlab_logger]
|
||||
else:
|
||||
@@ -510,6 +572,7 @@ def train(args):
|
||||
max_epochs=args.max_epochs,
|
||||
accelerator="gpu",
|
||||
devices="auto",
|
||||
precision="bf16",
|
||||
strategy=args.training_strategy,
|
||||
default_root_dir=args.output_path,
|
||||
accumulate_grad_batches=args.accumulate_grad_batches,
|
||||
|
||||
34
examples/wanvideo/wan_1.3b_text_to_video_accelerate.py
Normal file
34
examples/wanvideo/wan_1.3b_text_to_video_accelerate.py
Normal file
@@ -0,0 +1,34 @@
|
||||
import torch
|
||||
from diffsynth import ModelManager, WanVideoPipeline, save_video, VideoData
|
||||
from modelscope import snapshot_download
|
||||
|
||||
|
||||
# Download models
|
||||
snapshot_download("Wan-AI/Wan2.1-T2V-1.3B", local_dir="models/Wan-AI/Wan2.1-T2V-1.3B")
|
||||
|
||||
# Load models
|
||||
model_manager = ModelManager(device="cpu")
|
||||
model_manager.load_models(
|
||||
[
|
||||
"models/Wan-AI/Wan2.1-T2V-1.3B/diffusion_pytorch_model.safetensors",
|
||||
"models/Wan-AI/Wan2.1-T2V-1.3B/models_t5_umt5-xxl-enc-bf16.pth",
|
||||
"models/Wan-AI/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth",
|
||||
],
|
||||
torch_dtype=torch.bfloat16, # You can set `torch_dtype=torch.float8_e4m3fn` to enable FP8 quantization.
|
||||
)
|
||||
pipe = WanVideoPipeline.from_model_manager(model_manager, torch_dtype=torch.bfloat16, device="cuda")
|
||||
pipe.enable_vram_management(num_persistent_param_in_dit=None)
|
||||
|
||||
# Text-to-video
|
||||
video = pipe(
|
||||
prompt="纪实摄影风格画面,一只活泼的小狗在绿茵茵的草地上迅速奔跑。小狗毛色棕黄,两只耳朵立起,神情专注而欢快。阳光洒在它身上,使得毛发看上去格外柔软而闪亮。背景是一片开阔的草地,偶尔点缀着几朵野花,远处隐约可见蓝天和几片白云。透视感鲜明,捕捉小狗奔跑时的动感和四周草地的生机。中景侧面移动视角。",
|
||||
negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
|
||||
num_inference_steps=50,
|
||||
seed=0, tiled=True,
|
||||
# TeaCache parameters
|
||||
tea_cache_l1_thresh=0.05, # The larger this value is, the faster the speed, but the worse the visual quality.
|
||||
tea_cache_model_id="Wan2.1-T2V-1.3B", # Choose one in (Wan2.1-T2V-1.3B, Wan2.1-T2V-14B, Wan2.1-I2V-14B-480P, Wan2.1-I2V-14B-720P).
|
||||
)
|
||||
save_video(video, "video1.mp4", fps=15, quality=5)
|
||||
|
||||
# TeaCache doesn't support video-to-video
|
||||
@@ -9,6 +9,10 @@ snapshot_download("Wan-AI/Wan2.1-I2V-14B-480P", local_dir="models/Wan-AI/Wan2.1-
|
||||
|
||||
# Load models
|
||||
model_manager = ModelManager(device="cpu")
|
||||
model_manager.load_models(
|
||||
["models/Wan-AI/Wan2.1-I2V-14B-480P/models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth"],
|
||||
torch_dtype=torch.float32, # Image Encoder is loaded with float32
|
||||
)
|
||||
model_manager.load_models(
|
||||
[
|
||||
[
|
||||
@@ -20,14 +24,13 @@ model_manager.load_models(
|
||||
"models/Wan-AI/Wan2.1-I2V-14B-480P/diffusion_pytorch_model-00006-of-00007.safetensors",
|
||||
"models/Wan-AI/Wan2.1-I2V-14B-480P/diffusion_pytorch_model-00007-of-00007.safetensors",
|
||||
],
|
||||
"models/Wan-AI/Wan2.1-I2V-14B-480P/models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth",
|
||||
"models/Wan-AI/Wan2.1-I2V-14B-480P/models_t5_umt5-xxl-enc-bf16.pth",
|
||||
"models/Wan-AI/Wan2.1-I2V-14B-480P/Wan2.1_VAE.pth",
|
||||
],
|
||||
torch_dtype=torch.float8_e4m3fn, # You can set `torch_dtype=torch.bfloat16` to disable FP8 quantization.
|
||||
torch_dtype=torch.bfloat16, # You can set `torch_dtype=torch.float8_e4m3fn` to enable FP8 quantization.
|
||||
)
|
||||
pipe = WanVideoPipeline.from_model_manager(model_manager, torch_dtype=torch.bfloat16, device="cuda")
|
||||
pipe.enable_vram_management(num_persistent_param_in_dit=None) # You can set `num_persistent_param_in_dit` to a small number to reduce VRAM required.
|
||||
pipe.enable_vram_management(num_persistent_param_in_dit=6*10**9) # You can set `num_persistent_param_in_dit` to a small number to reduce VRAM required.
|
||||
|
||||
# Download example image
|
||||
dataset_snapshot_download(
|
||||
|
||||
149
examples/wanvideo/wan_14b_text_to_video_tensor_parallel.py
Normal file
149
examples/wanvideo/wan_14b_text_to_video_tensor_parallel.py
Normal file
@@ -0,0 +1,149 @@
|
||||
import torch
|
||||
import lightning as pl
|
||||
from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel, SequenceParallel, PrepareModuleInput, PrepareModuleOutput
|
||||
from torch.distributed._tensor import Replicate, Shard
|
||||
from torch.distributed.tensor.parallel import parallelize_module
|
||||
from lightning.pytorch.strategies import ModelParallelStrategy
|
||||
from diffsynth import ModelManager, WanVideoPipeline, save_video
|
||||
from tqdm import tqdm
|
||||
from modelscope import snapshot_download
|
||||
|
||||
|
||||
|
||||
class ToyDataset(torch.utils.data.Dataset):
|
||||
def __init__(self, tasks=[]):
|
||||
self.tasks = tasks
|
||||
|
||||
def __getitem__(self, data_id):
|
||||
return self.tasks[data_id]
|
||||
|
||||
def __len__(self):
|
||||
return len(self.tasks)
|
||||
|
||||
|
||||
class LitModel(pl.LightningModule):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
model_manager = ModelManager(device="cpu")
|
||||
model_manager.load_models(
|
||||
[
|
||||
[
|
||||
"models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00001-of-00006.safetensors",
|
||||
"models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00002-of-00006.safetensors",
|
||||
"models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00003-of-00006.safetensors",
|
||||
"models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00004-of-00006.safetensors",
|
||||
"models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00005-of-00006.safetensors",
|
||||
"models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00006-of-00006.safetensors",
|
||||
],
|
||||
"models/Wan-AI/Wan2.1-T2V-14B/models_t5_umt5-xxl-enc-bf16.pth",
|
||||
"models/Wan-AI/Wan2.1-T2V-14B/Wan2.1_VAE.pth",
|
||||
],
|
||||
torch_dtype=torch.bfloat16,
|
||||
)
|
||||
self.pipe = WanVideoPipeline.from_model_manager(model_manager, torch_dtype=torch.bfloat16, device="cuda")
|
||||
|
||||
def configure_model(self):
|
||||
tp_mesh = self.device_mesh["tensor_parallel"]
|
||||
plan = {
|
||||
"text_embedding.0": ColwiseParallel(),
|
||||
"text_embedding.2": RowwiseParallel(),
|
||||
"time_projection.1": ColwiseParallel(output_layouts=Replicate()),
|
||||
"text_embedding.0": ColwiseParallel(),
|
||||
"text_embedding.2": RowwiseParallel(),
|
||||
"blocks.0": PrepareModuleInput(
|
||||
input_layouts=(Replicate(), None, None, None),
|
||||
desired_input_layouts=(Replicate(), None, None, None),
|
||||
),
|
||||
"head": PrepareModuleInput(
|
||||
input_layouts=(Replicate(), None),
|
||||
desired_input_layouts=(Replicate(), None),
|
||||
use_local_output=True,
|
||||
)
|
||||
}
|
||||
self.pipe.dit = parallelize_module(self.pipe.dit, tp_mesh, plan)
|
||||
for block_id, block in enumerate(self.pipe.dit.blocks):
|
||||
layer_tp_plan = {
|
||||
"self_attn": PrepareModuleInput(
|
||||
input_layouts=(Shard(1), Replicate()),
|
||||
desired_input_layouts=(Shard(1), Shard(0)),
|
||||
),
|
||||
"self_attn.q": SequenceParallel(),
|
||||
"self_attn.k": SequenceParallel(),
|
||||
"self_attn.v": SequenceParallel(),
|
||||
"self_attn.norm_q": SequenceParallel(),
|
||||
"self_attn.norm_k": SequenceParallel(),
|
||||
"self_attn.attn": PrepareModuleInput(
|
||||
input_layouts=(Shard(1), Shard(1), Shard(1)),
|
||||
desired_input_layouts=(Shard(2), Shard(2), Shard(2)),
|
||||
),
|
||||
"self_attn.o": RowwiseParallel(input_layouts=Shard(2), output_layouts=Replicate()),
|
||||
|
||||
"cross_attn": PrepareModuleInput(
|
||||
input_layouts=(Shard(1), Replicate()),
|
||||
desired_input_layouts=(Shard(1), Replicate()),
|
||||
),
|
||||
"cross_attn.q": SequenceParallel(),
|
||||
"cross_attn.k": SequenceParallel(),
|
||||
"cross_attn.v": SequenceParallel(),
|
||||
"cross_attn.norm_q": SequenceParallel(),
|
||||
"cross_attn.norm_k": SequenceParallel(),
|
||||
"cross_attn.attn": PrepareModuleInput(
|
||||
input_layouts=(Shard(1), Shard(1), Shard(1)),
|
||||
desired_input_layouts=(Shard(2), Shard(2), Shard(2)),
|
||||
),
|
||||
"cross_attn.o": RowwiseParallel(input_layouts=Shard(2), output_layouts=Replicate(), use_local_output=False),
|
||||
|
||||
"ffn.0": ColwiseParallel(input_layouts=Shard(1)),
|
||||
"ffn.2": RowwiseParallel(output_layouts=Replicate()),
|
||||
|
||||
"norm1": SequenceParallel(use_local_output=True),
|
||||
"norm2": SequenceParallel(use_local_output=True),
|
||||
"norm3": SequenceParallel(use_local_output=True),
|
||||
"gate": PrepareModuleInput(
|
||||
input_layouts=(Shard(1), Replicate(), Replicate()),
|
||||
desired_input_layouts=(Replicate(), Replicate(), Replicate()),
|
||||
)
|
||||
}
|
||||
parallelize_module(
|
||||
module=block,
|
||||
device_mesh=tp_mesh,
|
||||
parallelize_plan=layer_tp_plan,
|
||||
)
|
||||
|
||||
|
||||
def test_step(self, batch):
|
||||
data = batch[0]
|
||||
data["progress_bar_cmd"] = tqdm if self.local_rank == 0 else lambda x: x
|
||||
output_path = data.pop("output_path")
|
||||
with torch.no_grad(), torch.inference_mode(False):
|
||||
video = self.pipe(**data)
|
||||
if self.local_rank == 0:
|
||||
save_video(video, output_path, fps=15, quality=5)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
snapshot_download("Wan-AI/Wan2.1-T2V-14B", local_dir="models/Wan-AI/Wan2.1-T2V-14B")
|
||||
dataloader = torch.utils.data.DataLoader(
|
||||
ToyDataset([
|
||||
{
|
||||
"prompt": "一名宇航员身穿太空服,面朝镜头骑着一匹机械马在火星表面驰骋。红色的荒凉地表延伸至远方,点缀着巨大的陨石坑和奇特的岩石结构。机械马的步伐稳健,扬起微弱的尘埃,展现出未来科技与原始探索的完美结合。宇航员手持操控装置,目光坚定,仿佛正在开辟人类的新疆域。背景是深邃的宇宙和蔚蓝的地球,画面既科幻又充满希望,让人不禁畅想未来的星际生活。",
|
||||
"negative_prompt": "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
|
||||
"num_inference_steps": 50,
|
||||
"seed": 0,
|
||||
"tiled": False,
|
||||
"output_path": "video1.mp4",
|
||||
},
|
||||
{
|
||||
"prompt": "一名宇航员身穿太空服,面朝镜头骑着一匹机械马在火星表面驰骋。红色的荒凉地表延伸至远方,点缀着巨大的陨石坑和奇特的岩石结构。机械马的步伐稳健,扬起微弱的尘埃,展现出未来科技与原始探索的完美结合。宇航员手持操控装置,目光坚定,仿佛正在开辟人类的新疆域。背景是深邃的宇宙和蔚蓝的地球,画面既科幻又充满希望,让人不禁畅想未来的星际生活。",
|
||||
"negative_prompt": "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
|
||||
"num_inference_steps": 50,
|
||||
"seed": 1,
|
||||
"tiled": False,
|
||||
"output_path": "video2.mp4",
|
||||
},
|
||||
]),
|
||||
collate_fn=lambda x: x
|
||||
)
|
||||
model = LitModel()
|
||||
trainer = pl.Trainer(accelerator="gpu", devices=torch.cuda.device_count(), strategy=ModelParallelStrategy())
|
||||
trainer.test(model, dataloader)
|
||||
58
examples/wanvideo/wan_14b_text_to_video_usp.py
Normal file
58
examples/wanvideo/wan_14b_text_to_video_usp.py
Normal file
@@ -0,0 +1,58 @@
|
||||
import torch
|
||||
from diffsynth import ModelManager, WanVideoPipeline, save_video, VideoData
|
||||
from modelscope import snapshot_download
|
||||
import torch.distributed as dist
|
||||
|
||||
|
||||
# Download models
|
||||
snapshot_download("Wan-AI/Wan2.1-T2V-14B", local_dir="models/Wan-AI/Wan2.1-T2V-14B")
|
||||
|
||||
# Load models
|
||||
model_manager = ModelManager(device="cpu")
|
||||
model_manager.load_models(
|
||||
[
|
||||
[
|
||||
"models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00001-of-00006.safetensors",
|
||||
"models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00002-of-00006.safetensors",
|
||||
"models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00003-of-00006.safetensors",
|
||||
"models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00004-of-00006.safetensors",
|
||||
"models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00005-of-00006.safetensors",
|
||||
"models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00006-of-00006.safetensors",
|
||||
],
|
||||
"models/Wan-AI/Wan2.1-T2V-14B/models_t5_umt5-xxl-enc-bf16.pth",
|
||||
"models/Wan-AI/Wan2.1-T2V-14B/Wan2.1_VAE.pth",
|
||||
],
|
||||
torch_dtype=torch.float8_e4m3fn, # You can set `torch_dtype=torch.bfloat16` to disable FP8 quantization.
|
||||
)
|
||||
|
||||
dist.init_process_group(
|
||||
backend="nccl",
|
||||
init_method="env://",
|
||||
)
|
||||
from xfuser.core.distributed import (initialize_model_parallel,
|
||||
init_distributed_environment)
|
||||
init_distributed_environment(
|
||||
rank=dist.get_rank(), world_size=dist.get_world_size())
|
||||
|
||||
initialize_model_parallel(
|
||||
sequence_parallel_degree=dist.get_world_size(),
|
||||
ring_degree=1,
|
||||
ulysses_degree=dist.get_world_size(),
|
||||
)
|
||||
torch.cuda.set_device(dist.get_rank())
|
||||
|
||||
pipe = WanVideoPipeline.from_model_manager(model_manager,
|
||||
torch_dtype=torch.bfloat16,
|
||||
device=f"cuda:{dist.get_rank()}",
|
||||
use_usp=True if dist.get_world_size() > 1 else False)
|
||||
pipe.enable_vram_management(num_persistent_param_in_dit=None) # You can set `num_persistent_param_in_dit` to a small number to reduce VRAM required.
|
||||
|
||||
# Text-to-video
|
||||
video = pipe(
|
||||
prompt="一名宇航员身穿太空服,面朝镜头骑着一匹机械马在火星表面驰骋。红色的荒凉地表延伸至远方,点缀着巨大的陨石坑和奇特的岩石结构。机械马的步伐稳健,扬起微弱的尘埃,展现出未来科技与原始探索的完美结合。宇航员手持操控装置,目光坚定,仿佛正在开辟人类的新疆域。背景是深邃的宇宙和蔚蓝的地球,画面既科幻又充满希望,让人不禁畅想未来的星际生活。",
|
||||
negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
|
||||
num_inference_steps=50,
|
||||
seed=0, tiled=True
|
||||
)
|
||||
if dist.get_rank() == 0:
|
||||
save_video(video, "video1.mp4", fps=25, quality=5)
|
||||
Reference in New Issue
Block a user