Compare commits

...

51 Commits

Author SHA1 Message Date
Artiprocher
05094710e3 support motion controller 2025-03-24 19:07:58 +08:00
Artiprocher
105eaf0f49 controlnet 2025-03-21 11:09:56 +08:00
Artiprocher
6cd032e846 skip bad files 2025-03-19 14:49:18 +08:00
Artiprocher
9d8130b48d ignore metadata 2025-03-19 11:36:07 +08:00
Artiprocher
ce848a3d1a bugfix 2025-03-18 19:36:58 +08:00
Artiprocher
a8ce9fef33 support redirected tensor path 2025-03-18 19:24:27 +08:00
Artiprocher
8da0d183a2 support target fps 2025-03-18 17:31:15 +08:00
Artiprocher
4b2b3dda94 support target fps 2025-03-18 17:30:13 +08:00
Artiprocher
b1fabbc6b0 skip bad videos 2025-03-18 17:24:39 +08:00
Zhongjie Duan
e28c246bcc Merge pull request #457 from modelscope/wan-tp
support wan tensor parallel (preview)
2025-03-17 19:53:17 +08:00
Artiprocher
04d03500ff support wan tensor parallel (preview) 2025-03-17 19:39:45 +08:00
Zhongjie Duan
39890f023f Merge pull request #448 from modelscope/wan-teacache
support teacache in wan
2025-03-14 18:21:20 +08:00
Artiprocher
e425753f79 support teacache in wan 2025-03-14 17:45:52 +08:00
Zhongjie Duan
ca40074d72 Merge pull request #447 from modelscope/lora
Lora
2025-03-14 15:34:22 +08:00
Artiprocher
1fd3d67379 improve lora loading efficiency 2025-03-14 15:15:37 +08:00
Artiprocher
3acd9c73be improve lora loading efficiency 2025-03-14 15:05:54 +08:00
Zhongjie Duan
32422b49ee Merge pull request #436 from mi804/hunyuanvideo_i2v
support hunyuanvideo-i2v
2025-03-13 19:38:11 +08:00
Furkan Gözükara
5c4d3185fb Merge branch 'modelscope:main' into main 2025-03-13 14:22:34 +03:00
Zhongjie Duan
762bcbee58 Merge pull request #444 from modelscope/wan-itv-train
Wan itv train
2025-03-13 15:40:51 +08:00
Zhongjie Duan
6b411ada16 Merge branch 'main' into wan-itv-train 2025-03-13 15:24:59 +08:00
Artiprocher
a25bd74d8b support wan i2v training 2025-03-13 15:14:10 +08:00
Furkan Gözükara
fb5fc09bad Made much much faster than before
enable debug to see every message
2025-03-13 02:30:42 +03:00
Furkan Gözükara
3fdba19e02 Fixes high RAM usage Wan 2.1
Fixes high RAM usage Wan 2.1
2025-03-12 15:49:57 +03:00
mi804
4bec2983a9 support hunyuanvideo_i2v 2025-03-11 16:20:09 +08:00
Zhongjie Duan
03ea27893f Merge pull request #431 from modelscope/wan-update
Wan update
2025-03-10 18:26:32 +08:00
Artiprocher
718b45f2af bugfix 2025-03-10 18:25:23 +08:00
Zhongjie Duan
63a79eeb2a Merge pull request #426 from Zeyi-Lin/main
Modify the swanlab `logdir` location
2025-03-10 17:59:17 +08:00
Artiprocher
e757013a14 vram optimization 2025-03-10 17:47:14 +08:00
Artiprocher
a05f647633 vram optimization 2025-03-10 17:11:11 +08:00
ZeYi Lin
7604be0301 output_path join swanlog 2025-03-08 13:57:08 +08:00
mi804
945b43492e load hunyuani2v model 2025-03-07 17:43:30 +08:00
Artiprocher
b548d7caf2 refactor wan dit 2025-03-07 16:35:26 +08:00
Zhongjie Duan
6e316fd825 Merge pull request #421 from modelscope/wan-update
support diffusers format wan and other lora
2025-03-06 17:41:36 +08:00
Artiprocher
84fb61aaaf support diffusers format wan and other lora 2025-03-06 17:40:21 +08:00
Zhongjie Duan
50a9946b57 Merge pull request #419 from modelscope/wan-update
wan image encoder to fp16
2025-03-06 16:28:55 +08:00
Artiprocher
384d1a8198 wan image encoder to fp16 2025-03-06 16:28:23 +08:00
Zhongjie Duan
a58c193d0c Merge pull request #412 from boopage/patch-1
Update train_wan_t2v.py - include .jpeg for image detection
2025-03-06 12:46:43 +08:00
boopage
34a5ef8c15 Update train_wan_t2v.py
Included .jpeg extension for image type detection, preventing an error when trying to read the image as a video format
2025-03-05 11:13:11 +01:00
Zhongjie Duan
41e3e4e157 Merge pull request #410 from mi804/dreambooth_lora
support dreambooth lora
2025-03-05 11:48:00 +08:00
mi804
e576d71908 support dreambooth lora 2025-03-05 11:20:10 +08:00
Zhongjie Duan
906aadbf1b Merge pull request #404 from modelscope/wan-examples-update
update wan examples
2025-03-04 21:54:33 +08:00
Artiprocher
bf0bf2d5ba update wan examples 2025-03-04 21:54:04 +08:00
Zhongjie Duan
fe0fff1399 Merge pull request #401 from modelscope/flux-diffusers
support load flux from diffusers
2025-03-04 20:52:07 +08:00
Artiprocher
50fceb84d2 support load flux from diffusers 2025-03-04 20:38:25 +08:00
Zhongjie Duan
100da41034 Merge pull request #400 from mi804/eligen
update eligen model from huggingface
2025-03-04 20:11:18 +08:00
mi804
c382237833 update eligen from huggingface 2025-03-04 20:04:24 +08:00
Zhongjie Duan
98ac191750 Merge pull request #398 from modelscope/reduce_dependency
reduce dependency
2025-03-04 16:22:29 +08:00
Artiprocher
2f73dbe7a3 reduce dependency 2025-03-04 15:21:00 +08:00
wang96
490d420d82 fix bugs 2025-02-27 15:26:39 +08:00
wang96
0aca943a39 Merge remote-tracking branch 'upstream/main' 2025-02-27 15:23:55 +08:00
wang96
0dbb3d333f feat: support I2V training 2025-02-26 19:50:59 +08:00
33 changed files with 3235 additions and 919 deletions

View File

@@ -19,7 +19,7 @@ Until now, DiffSynth Studio has supported the following models:
* [Wan-Video](https://github.com/Wan-Video/Wan2.1)
* [StepVideo](https://github.com/stepfun-ai/Step-Video-T2V)
* [HunyuanVideo](https://github.com/Tencent/HunyuanVideo)
* [HunyuanVideo](https://github.com/Tencent/HunyuanVideo), [HunyuanVideo-I2V]()
* [CogVideoX](https://huggingface.co/THUDM/CogVideoX-5b)
* [FLUX](https://huggingface.co/black-forest-labs/FLUX.1-dev)
* [ExVideo](https://huggingface.co/ECNU-CILab/ExVideo-SVD-128f-v1)
@@ -36,6 +36,7 @@ Until now, DiffSynth Studio has supported the following models:
* [Stable Diffusion](https://huggingface.co/runwayml/stable-diffusion-v1-5)
## News
- **March 25, 2025** We support HunyuanVideo-I2V, the image-to-video generation version of HunyuanVideo open-sourced by Tencent. Please refer to [./examples/HunyuanVideo/](./examples/HunyuanVideo/) for more details.
- **February 25, 2025** We support Wan-Video, a collection of SOTA video synthesis models open-sourced by Alibaba. See [./examples/wanvideo/](./examples/wanvideo/).
@@ -43,7 +44,7 @@ Until now, DiffSynth Studio has supported the following models:
- **December 31, 2024** We propose EliGen, a novel framework for precise entity-level controlled text-to-image generation, complemented by an inpainting fusion pipeline to extend its capabilities to image inpainting tasks. EliGen seamlessly integrates with existing community models, such as IP-Adapter and In-Context LoRA, enhancing its versatility. For more details, see [./examples/EntityControl](./examples/EntityControl/).
- Paper: [EliGen: Entity-Level Controlled Image Generation with Regional Attention](https://arxiv.org/abs/2501.01097)
- Model: [ModelScope](https://www.modelscope.cn/models/DiffSynth-Studio/Eligen)
- Model: [ModelScope](https://www.modelscope.cn/models/DiffSynth-Studio/Eligen), [HuggingFace](https://huggingface.co/modelscope/EliGen)
- Online Demo: [ModelScope EliGen Studio](https://www.modelscope.cn/studios/DiffSynth-Studio/EliGen)
- Training Dataset: [EliGen Train Set](https://www.modelscope.cn/datasets/DiffSynth-Studio/EliGenTrainSet)

View File

@@ -95,6 +95,7 @@ model_loader_configs = [
(None, "57b02550baab820169365b3ee3afa2c9", ["flux_dit"], [FluxDiT], "civitai"),
(None, "3394f306c4cbf04334b712bf5aaed95f", ["flux_dit"], [FluxDiT], "civitai"),
(None, "023f054d918a84ccf503481fd1e3379e", ["flux_dit"], [FluxDiT], "civitai"),
(None, "605c56eab23e9e2af863ad8f0813a25d", ["flux_dit"], [FluxDiT], "diffusers"),
(None, "280189ee084bca10f70907bf6ce1649d", ["cog_vae_encoder", "cog_vae_decoder"], [CogVAEEncoder, CogVAEDecoder], "diffusers"),
(None, "9b9313d104ac4df27991352fec013fd4", ["rife"], [IFNet], "civitai"),
(None, "6b7116078c4170bfbeaedc8fe71f6649", ["esrgan"], [RRDBNet], "civitai"),
@@ -116,6 +117,7 @@ model_loader_configs = [
(None, "9269f8db9040a9d860eaca435be61814", ["wan_video_dit"], [WanModel], "civitai"),
(None, "aafcfd9672c3a2456dc46e1cb6e52c70", ["wan_video_dit"], [WanModel], "civitai"),
(None, "6bfcfb3b342cb286ce886889d519a77e", ["wan_video_dit"], [WanModel], "civitai"),
(None, "cb104773c6c2cb6df4f9529ad5c60d0b", ["wan_video_dit"], [WanModel], "diffusers"),
(None, "9c8818c2cbea55eca56c7b447df170da", ["wan_video_text_encoder"], [WanTextEncoder], "civitai"),
(None, "5941c53e207d62f20f9025686193c40b", ["wan_video_image_encoder"], [WanImageEncoder], "civitai"),
(None, "1378ea763357eea97acdef78e65d6d96", ["wan_video_vae"], [WanVideoVAE], "civitai"),
@@ -133,6 +135,7 @@ huggingface_model_loader_configs = [
("CogVideoXTransformer3DModel", "diffsynth.models.cog_dit", "cog_dit", "CogDiT"),
("SiglipModel", "transformers.models.siglip.modeling_siglip", "siglip_vision_model", "SiglipVisionModel"),
("LlamaForCausalLM", "diffsynth.models.hunyuan_video_text_encoder", "hunyuan_video_text_encoder_2", "HunyuanVideoLLMEncoder"),
("LlavaForConditionalGeneration", "diffsynth.models.hunyuan_video_text_encoder", "hunyuan_video_text_encoder_2", "HunyuanVideoMLLMEncoder"),
("Step1Model", "diffsynth.models.stepvideo_text_encoder", "stepvideo_text_encoder_2", "STEP1TextEncoder"),
]
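Rows in `huggingface_model_loader_configs` map an architecture name to a `(module, class)` pair; resolving such a pair comes down to a dynamic import plus an attribute lookup. A small sketch of that resolution step, using a stdlib module as a stand-in for `diffsynth.models.*`:

```python
import importlib

def load_model_class(module_path, class_name):
    # Each config row pairs an architecture string with (module, class);
    # resolving it is a dynamic import followed by attribute lookup.
    module = importlib.import_module(module_path)
    return getattr(module, class_name)
```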
patch_model_loader_configs = [
@@ -675,6 +678,25 @@ preset_models_on_modelscope = {
"models/HunyuanVideo/transformers/mp_rank_00_model_states.pt"
],
},
"HunyuanVideoI2V":{
"file_list": [
("AI-ModelScope/clip-vit-large-patch14", "model.safetensors", "models/HunyuanVideoI2V/text_encoder"),
("AI-ModelScope/llava-llama-3-8b-v1_1-transformers", "model-00001-of-00004.safetensors", "models/HunyuanVideoI2V/text_encoder_2"),
("AI-ModelScope/llava-llama-3-8b-v1_1-transformers", "model-00002-of-00004.safetensors", "models/HunyuanVideoI2V/text_encoder_2"),
("AI-ModelScope/llava-llama-3-8b-v1_1-transformers", "model-00003-of-00004.safetensors", "models/HunyuanVideoI2V/text_encoder_2"),
("AI-ModelScope/llava-llama-3-8b-v1_1-transformers", "model-00004-of-00004.safetensors", "models/HunyuanVideoI2V/text_encoder_2"),
("AI-ModelScope/llava-llama-3-8b-v1_1-transformers", "config.json", "models/HunyuanVideoI2V/text_encoder_2"),
("AI-ModelScope/llava-llama-3-8b-v1_1-transformers", "model.safetensors.index.json", "models/HunyuanVideoI2V/text_encoder_2"),
("AI-ModelScope/HunyuanVideo-I2V", "hunyuan-video-i2v-720p/vae/pytorch_model.pt", "models/HunyuanVideoI2V/vae"),
("AI-ModelScope/HunyuanVideo-I2V", "hunyuan-video-i2v-720p/transformers/mp_rank_00_model_states.pt", "models/HunyuanVideoI2V/transformers")
],
"load_path": [
"models/HunyuanVideoI2V/text_encoder/model.safetensors",
"models/HunyuanVideoI2V/text_encoder_2",
"models/HunyuanVideoI2V/vae/pytorch_model.pt",
"models/HunyuanVideoI2V/transformers/mp_rank_00_model_states.pt"
],
},
"HunyuanVideo-fp8":{
"file_list": [
("AI-ModelScope/clip-vit-large-patch14", "model.safetensors", "models/HunyuanVideo/text_encoder"),
@@ -751,4 +773,5 @@ Preset_model_id: TypeAlias = Literal[
"StableDiffusion3.5-medium",
"HunyuanVideo",
"HunyuanVideo-fp8",
"HunyuanVideoI2V",
]
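The loader entries above select a model class by an MD5 fingerprint of the checkpoint's parameter names (the hex strings in the config rows). A minimal torch-free sketch of that detection idea; the exact fingerprint scheme used by `hash_state_dict_keys` is an assumption here:

```python
import hashlib

def hash_state_dict_keys(state_dict, with_shape=True):
    # Build a deterministic fingerprint from the sorted parameter names
    # (optionally including tensor shapes), then MD5 it. Model detection
    # then reduces to a dictionary lookup on this hex digest.
    parts = []
    for name in sorted(state_dict):
        if with_shape:
            parts.append(f"{name}:{tuple(state_dict[name].shape)}")
        else:
            parts.append(name)
    return hashlib.md5(",".join(parts).encode("utf-8")).hexdigest()
```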

View File

@@ -1,10 +1,4 @@
from typing_extensions import Literal, TypeAlias
import warnings
with warnings.catch_warnings():
warnings.simplefilter("ignore")
from controlnet_aux.processor import (
CannyDetector, MidasDetector, HEDdetector, LineartDetector, LineartAnimeDetector, OpenposeDetector, NormalBaeDetector
)
Processor_id: TypeAlias = Literal[
@@ -15,18 +9,25 @@ class Annotator:
def __init__(self, processor_id: Processor_id, model_path="models/Annotators", detect_resolution=None, device='cuda', skip_processor=False):
if not skip_processor:
if processor_id == "canny":
from controlnet_aux.processor import CannyDetector
self.processor = CannyDetector()
elif processor_id == "depth":
from controlnet_aux.processor import MidasDetector
self.processor = MidasDetector.from_pretrained(model_path).to(device)
elif processor_id == "softedge":
from controlnet_aux.processor import HEDdetector
self.processor = HEDdetector.from_pretrained(model_path).to(device)
elif processor_id == "lineart":
from controlnet_aux.processor import LineartDetector
self.processor = LineartDetector.from_pretrained(model_path).to(device)
elif processor_id == "lineart_anime":
from controlnet_aux.processor import LineartAnimeDetector
self.processor = LineartAnimeDetector.from_pretrained(model_path).to(device)
elif processor_id == "openpose":
from controlnet_aux.processor import OpenposeDetector
self.processor = OpenposeDetector.from_pretrained(model_path).to(device)
elif processor_id == "normal":
from controlnet_aux.processor import NormalBaeDetector
self.processor = NormalBaeDetector.from_pretrained(model_path).to(device)
elif processor_id == "tile" or processor_id == "none" or processor_id == "inpaint":
self.processor = None
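The rewrite above moves each `controlnet_aux` import into its matching branch, which is how the "reduce dependency" change avoids importing heavy packages at module load time. A minimal illustration of the pattern, with `json` standing in for the optional dependency:

```python
def build_processor(processor_id):
    # Heavy optional dependencies are imported inside the matching branch,
    # so the package is only required if that processor is actually used.
    if processor_id == "json":
        import json  # deferred: only paid for when this branch runs
        return json.loads
    if processor_id in ("tile", "none", "inpaint"):
        return None
    raise ValueError(f"unsupported processor_id: {processor_id}")
```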

View File

@@ -628,19 +628,22 @@ class FluxDiTStateDictConverter:
else:
pass
for name in list(state_dict_.keys()):
if ".proj_in_besides_attn." in name:
name_ = name.replace(".proj_in_besides_attn.", ".to_qkv_mlp.")
if "single_blocks." in name and ".a_to_q." in name:
mlp = state_dict_.get(name.replace(".a_to_q.", ".proj_in_besides_attn."), None)
if mlp is None:
mlp = torch.zeros(4 * state_dict_[name].shape[0],
*state_dict_[name].shape[1:],
dtype=state_dict_[name].dtype)
else:
state_dict_.pop(name.replace(".a_to_q.", ".proj_in_besides_attn."))
param = torch.concat([
state_dict_[name.replace(".proj_in_besides_attn.", f".a_to_q.")],
state_dict_[name.replace(".proj_in_besides_attn.", f".a_to_k.")],
state_dict_[name.replace(".proj_in_besides_attn.", f".a_to_v.")],
state_dict_[name],
state_dict_.pop(name),
state_dict_.pop(name.replace(".a_to_q.", ".a_to_k.")),
state_dict_.pop(name.replace(".a_to_q.", ".a_to_v.")),
mlp,
], dim=0)
name_ = name.replace(".a_to_q.", ".to_qkv_mlp.")
state_dict_[name_] = param
state_dict_.pop(name.replace(".proj_in_besides_attn.", f".a_to_q."))
state_dict_.pop(name.replace(".proj_in_besides_attn.", f".a_to_k."))
state_dict_.pop(name.replace(".proj_in_besides_attn.", f".a_to_v."))
state_dict_.pop(name)
for name in list(state_dict_.keys()):
for component in ["a", "b"]:
if f".{component}_to_q." in name:
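The converter above fuses the separate `a_to_q`/`a_to_k`/`a_to_v` projections and the MLP input projection into one `to_qkv_mlp` tensor, zero-filling the MLP part when a LoRA checkpoint lacks it. A list-based (torch-free) sketch of that fusion:

```python
def fuse_qkv_mlp(state_dict, prefix):
    # Pop the separate q/k/v projections and append the MLP input
    # projection; if the checkpoint has no proj_in_besides_attn entry,
    # substitute zero rows with 4x the q row count (the fallback above).
    q = state_dict.pop(prefix + ".a_to_q.weight")
    k = state_dict.pop(prefix + ".a_to_k.weight")
    v = state_dict.pop(prefix + ".a_to_v.weight")
    mlp = state_dict.pop(prefix + ".proj_in_besides_attn.weight", None)
    if mlp is None:
        mlp = [[0.0] * len(q[0]) for _ in range(4 * len(q))]
    state_dict[prefix + ".to_qkv_mlp.weight"] = q + k + v + mlp  # concat along dim 0
    return state_dict
```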

View File

@@ -4,6 +4,7 @@ from .utils import init_weights_on_device
from einops import rearrange, repeat
from tqdm import tqdm
from typing import Union, Tuple, List
from .utils import hash_state_dict_keys
def HunyuanVideoRope(latents):
@@ -236,7 +237,7 @@ class IndividualTokenRefinerBlock(torch.nn.Module):
x = x + self.mlp(self.norm2(x)) * gate_mlp.unsqueeze(1)
return x
class SingleTokenRefiner(torch.nn.Module):
def __init__(self, in_channels=4096, hidden_size=3072, depth=2):
@@ -269,7 +270,7 @@ class SingleTokenRefiner(torch.nn.Module):
x = block(x, c, mask)
return x
class ModulateDiT(torch.nn.Module):
def __init__(self, hidden_size, factor=6):
@@ -279,9 +280,14 @@ class ModulateDiT(torch.nn.Module):
def forward(self, x):
return self.linear(self.act(x))
def modulate(x, shift=None, scale=None):
def modulate(x, shift=None, scale=None, tr_shift=None, tr_scale=None, tr_token=None):
if tr_shift is not None:
x_zero = x[:, :tr_token] * (1 + tr_scale.unsqueeze(1)) + tr_shift.unsqueeze(1)
x_orig = x[:, tr_token:] * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
x = torch.concat((x_zero, x_orig), dim=1)
return x
if scale is None and shift is None:
return x
elif shift is None:
@@ -290,7 +296,7 @@ def modulate(x, shift=None, scale=None):
return x + shift.unsqueeze(1)
else:
return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
def reshape_for_broadcast(
freqs_cis,
@@ -343,7 +349,7 @@ def rotate_half(x):
x.float().reshape(*x.shape[:-1], -1, 2).unbind(-1)
) # [B, S, H, D//2]
return torch.stack([-x_imag, x_real], dim=-1).flatten(3)
def apply_rotary_emb(
xq: torch.Tensor,
@@ -385,6 +391,15 @@ def attention(q, k, v):
return x
def apply_gate(x, gate, tr_gate=None, tr_token=None):
if tr_gate is not None:
x_zero = x[:, :tr_token] * tr_gate.unsqueeze(1)
x_orig = x[:, tr_token:] * gate.unsqueeze(1)
return torch.concat((x_zero, x_orig), dim=1)
else:
return x * gate.unsqueeze(1)
class MMDoubleStreamBlockComponent(torch.nn.Module):
def __init__(self, hidden_size=3072, heads_num=24, mlp_width_ratio=4):
super().__init__()
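The new `tr_*` arguments apply a separate shift/scale/gate to the first `tr_token` tokens, which HunyuanVideo-I2V uses for the replaced image-condition tokens. A scalar, torch-free sketch of `modulate` and `apply_gate` on lists of token vectors:

```python
def modulate(x, shift, scale, tr_shift=None, tr_scale=None, tr_token=None):
    # x is a list of token vectors; the first tr_token tokens (the replaced
    # image-condition tokens) get the token-replace shift/scale instead.
    out = []
    for i, tok in enumerate(x):
        if tr_token is not None and i < tr_token:
            s, c = tr_shift, tr_scale
        else:
            s, c = shift, scale
        out.append([v * (1 + c) + s for v in tok])
    return out

def apply_gate(x, gate, tr_gate=None, tr_token=None):
    # Same split for the residual gate.
    if tr_gate is None:
        return [[v * gate for v in tok] for tok in x]
    return ([[v * tr_gate for v in tok] for tok in x[:tr_token]]
            + [[v * gate for v in tok] for tok in x[tr_token:]])
```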
@@ -405,11 +420,17 @@ class MMDoubleStreamBlockComponent(torch.nn.Module):
torch.nn.Linear(hidden_size * mlp_width_ratio, hidden_size)
)
def forward(self, hidden_states, conditioning, freqs_cis=None):
def forward(self, hidden_states, conditioning, freqs_cis=None, token_replace_vec=None, tr_token=None):
mod1_shift, mod1_scale, mod1_gate, mod2_shift, mod2_scale, mod2_gate = self.mod(conditioning).chunk(6, dim=-1)
if token_replace_vec is not None:
assert tr_token is not None
tr_mod1_shift, tr_mod1_scale, tr_mod1_gate, tr_mod2_shift, tr_mod2_scale, tr_mod2_gate = self.mod(token_replace_vec).chunk(6, dim=-1)
else:
tr_mod1_shift, tr_mod1_scale, tr_mod1_gate, tr_mod2_shift, tr_mod2_scale, tr_mod2_gate = None, None, None, None, None, None
norm_hidden_states = self.norm1(hidden_states)
norm_hidden_states = modulate(norm_hidden_states, shift=mod1_shift, scale=mod1_scale)
norm_hidden_states = modulate(norm_hidden_states, shift=mod1_shift, scale=mod1_scale,
tr_shift=tr_mod1_shift, tr_scale=tr_mod1_scale, tr_token=tr_token)
qkv = self.to_qkv(norm_hidden_states)
q, k, v = rearrange(qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num)
@@ -418,15 +439,19 @@ class MMDoubleStreamBlockComponent(torch.nn.Module):
if freqs_cis is not None:
q, k = apply_rotary_emb(q, k, freqs_cis, head_first=False)
return (q, k, v), (mod1_gate, mod2_shift, mod2_scale, mod2_gate), (tr_mod1_gate, tr_mod2_shift, tr_mod2_scale, tr_mod2_gate)
return (q, k, v), (mod1_gate, mod2_shift, mod2_scale, mod2_gate)
def process_ff(self, hidden_states, attn_output, mod):
def process_ff(self, hidden_states, attn_output, mod, mod_tr=None, tr_token=None):
mod1_gate, mod2_shift, mod2_scale, mod2_gate = mod
hidden_states = hidden_states + self.to_out(attn_output) * mod1_gate.unsqueeze(1)
hidden_states = hidden_states + self.ff(modulate(self.norm2(hidden_states), shift=mod2_shift, scale=mod2_scale)) * mod2_gate.unsqueeze(1)
if mod_tr is not None:
tr_mod1_gate, tr_mod2_shift, tr_mod2_scale, tr_mod2_gate = mod_tr
else:
tr_mod1_gate, tr_mod2_shift, tr_mod2_scale, tr_mod2_gate = None, None, None, None
hidden_states = hidden_states + apply_gate(self.to_out(attn_output), mod1_gate, tr_mod1_gate, tr_token)
x = self.ff(modulate(self.norm2(hidden_states), shift=mod2_shift, scale=mod2_scale, tr_shift=tr_mod2_shift, tr_scale=tr_mod2_scale, tr_token=tr_token))
hidden_states = hidden_states + apply_gate(x, mod2_gate, tr_mod2_gate, tr_token)
return hidden_states
class MMDoubleStreamBlock(torch.nn.Module):
def __init__(self, hidden_size=3072, heads_num=24, mlp_width_ratio=4):
@@ -434,18 +459,18 @@ class MMDoubleStreamBlock(torch.nn.Module):
self.component_a = MMDoubleStreamBlockComponent(hidden_size, heads_num, mlp_width_ratio)
self.component_b = MMDoubleStreamBlockComponent(hidden_size, heads_num, mlp_width_ratio)
def forward(self, hidden_states_a, hidden_states_b, conditioning, freqs_cis):
(q_a, k_a, v_a), mod_a = self.component_a(hidden_states_a, conditioning, freqs_cis)
(q_b, k_b, v_b), mod_b = self.component_b(hidden_states_b, conditioning, freqs_cis=None)
def forward(self, hidden_states_a, hidden_states_b, conditioning, freqs_cis, token_replace_vec=None, tr_token=None, split_token=71):
(q_a, k_a, v_a), mod_a, mod_tr = self.component_a(hidden_states_a, conditioning, freqs_cis, token_replace_vec, tr_token)
(q_b, k_b, v_b), mod_b, _ = self.component_b(hidden_states_b, conditioning, freqs_cis=None)
q_a, q_b = torch.concat([q_a, q_b[:, :71]], dim=1), q_b[:, 71:].contiguous()
k_a, k_b = torch.concat([k_a, k_b[:, :71]], dim=1), k_b[:, 71:].contiguous()
v_a, v_b = torch.concat([v_a, v_b[:, :71]], dim=1), v_b[:, 71:].contiguous()
q_a, q_b = torch.concat([q_a, q_b[:, :split_token]], dim=1), q_b[:, split_token:].contiguous()
k_a, k_b = torch.concat([k_a, k_b[:, :split_token]], dim=1), k_b[:, split_token:].contiguous()
v_a, v_b = torch.concat([v_a, v_b[:, :split_token]], dim=1), v_b[:, split_token:].contiguous()
attn_output_a = attention(q_a, k_a, v_a)
attn_output_b = attention(q_b, k_b, v_b)
attn_output_a, attn_output_b = attn_output_a[:, :-71].contiguous(), torch.concat([attn_output_a[:, -71:], attn_output_b], dim=1)
attn_output_a, attn_output_b = attn_output_a[:, :-split_token].contiguous(), torch.concat([attn_output_a[:, -split_token:], attn_output_b], dim=1)
hidden_states_a = self.component_a.process_ff(hidden_states_a, attn_output_a, mod_a)
hidden_states_a = self.component_a.process_ff(hidden_states_a, attn_output_a, mod_a, mod_tr, tr_token)
hidden_states_b = self.component_b.process_ff(hidden_states_b, attn_output_b, mod_b)
return hidden_states_a, hidden_states_b
@@ -488,7 +513,7 @@ class MMSingleStreamBlockOriginal(torch.nn.Module):
output = self.linear2(torch.cat((attn_output, self.mlp_act(mlp)), 2))
return x + output * mod_gate.unsqueeze(1)
class MMSingleStreamBlock(torch.nn.Module):
def __init__(self, hidden_size=3072, heads_num=24, mlp_width_ratio=4):
@@ -509,11 +534,17 @@ class MMSingleStreamBlock(torch.nn.Module):
torch.nn.Linear(hidden_size * mlp_width_ratio, hidden_size, bias=False)
)
def forward(self, hidden_states, conditioning, freqs_cis=None, txt_len=256):
def forward(self, hidden_states, conditioning, freqs_cis=None, txt_len=256, token_replace_vec=None, tr_token=None, split_token=71):
mod_shift, mod_scale, mod_gate = self.mod(conditioning).chunk(3, dim=-1)
if token_replace_vec is not None:
assert tr_token is not None
tr_mod_shift, tr_mod_scale, tr_mod_gate = self.mod(token_replace_vec).chunk(3, dim=-1)
else:
tr_mod_shift, tr_mod_scale, tr_mod_gate = None, None, None
norm_hidden_states = self.norm(hidden_states)
norm_hidden_states = modulate(norm_hidden_states, shift=mod_shift, scale=mod_scale)
norm_hidden_states = modulate(norm_hidden_states, shift=mod_shift, scale=mod_scale,
tr_shift=tr_mod_shift, tr_scale=tr_mod_scale, tr_token=tr_token)
qkv = self.to_qkv(norm_hidden_states)
q, k, v = rearrange(qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num)
@@ -525,16 +556,17 @@ class MMSingleStreamBlock(torch.nn.Module):
k_a, k_b = k[:, :-txt_len, :, :], k[:, -txt_len:, :, :]
q_a, k_a = apply_rotary_emb(q_a, k_a, freqs_cis, head_first=False)
q_a, q_b = torch.concat([q_a, q_b[:, :71]], dim=1), q_b[:, 71:].contiguous()
k_a, k_b = torch.concat([k_a, k_b[:, :71]], dim=1), k_b[:, 71:].contiguous()
v_a, v_b = v[:, :-185].contiguous(), v[:, -185:].contiguous()
v_len = txt_len - split_token
q_a, q_b = torch.concat([q_a, q_b[:, :split_token]], dim=1), q_b[:, split_token:].contiguous()
k_a, k_b = torch.concat([k_a, k_b[:, :split_token]], dim=1), k_b[:, split_token:].contiguous()
v_a, v_b = v[:, :-v_len].contiguous(), v[:, -v_len:].contiguous()
attn_output_a = attention(q_a, k_a, v_a)
attn_output_b = attention(q_b, k_b, v_b)
attn_output = torch.concat([attn_output_a, attn_output_b], dim=1)
hidden_states = hidden_states + self.to_out(attn_output) * mod_gate.unsqueeze(1)
hidden_states = hidden_states + self.ff(norm_hidden_states) * mod_gate.unsqueeze(1)
hidden_states = hidden_states + apply_gate(self.to_out(attn_output), mod_gate, tr_mod_gate, tr_token)
hidden_states = hidden_states + apply_gate(self.ff(norm_hidden_states), mod_gate, tr_mod_gate, tr_token)
return hidden_states
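Both stream blocks move the first `split_token` text tokens into the image sequence before attention; the previously hard-coded 71 becomes a parameter. The token exchange itself is simple, sketched here on plain lists:

```python
def exchange_tokens(seq_a, seq_b, split_token=71):
    # The first split_token tokens of seq_b (text) join seq_a (image)
    # before attention; the remainder of seq_b is attended separately.
    return seq_a + seq_b[:split_token], seq_b[split_token:]
```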
@@ -555,7 +587,7 @@ class FinalLayer(torch.nn.Module):
class HunyuanVideoDiT(torch.nn.Module):
def __init__(self, in_channels=16, hidden_size=3072, text_dim=4096, num_double_blocks=20, num_single_blocks=40):
def __init__(self, in_channels=16, hidden_size=3072, text_dim=4096, num_double_blocks=20, num_single_blocks=40, guidance_embed=True):
super().__init__()
self.img_in = PatchEmbed(in_channels=in_channels, embed_dim=hidden_size)
self.txt_in = SingleTokenRefiner(in_channels=text_dim, hidden_size=hidden_size)
@@ -565,7 +597,7 @@ class HunyuanVideoDiT(torch.nn.Module):
torch.nn.SiLU(),
torch.nn.Linear(hidden_size, hidden_size)
)
self.guidance_in = TimestepEmbeddings(256, hidden_size, computation_device="cpu")
self.guidance_in = TimestepEmbeddings(256, hidden_size, computation_device="cpu") if guidance_embed else None
self.double_blocks = torch.nn.ModuleList([MMDoubleStreamBlock(hidden_size) for _ in range(num_double_blocks)])
self.single_blocks = torch.nn.ModuleList([MMSingleStreamBlock(hidden_size) for _ in range(num_single_blocks)])
self.final_layer = FinalLayer(hidden_size)
@@ -580,7 +612,7 @@ class HunyuanVideoDiT(torch.nn.Module):
def unpatchify(self, x, T, H, W):
x = rearrange(x, "B (T H W) (C pT pH pW) -> B C (T pT) (H pH) (W pW)", H=H, W=W, pT=1, pH=2, pW=2)
return x
def enable_block_wise_offload(self, warm_device="cuda", cold_device="cpu"):
self.warm_device = warm_device
self.cold_device = cold_device
@@ -610,10 +642,12 @@ class HunyuanVideoDiT(torch.nn.Module):
):
B, C, T, H, W = x.shape
vec = self.time_in(t, dtype=torch.float32) + self.vector_in(pooled_prompt_emb) + self.guidance_in(guidance * 1000, dtype=torch.float32)
vec = self.time_in(t, dtype=torch.float32) + self.vector_in(pooled_prompt_emb)
if self.guidance_in is not None:
vec += self.guidance_in(guidance * 1000, dtype=torch.float32)
img = self.img_in(x)
txt = self.txt_in(prompt_emb, t, text_mask)
for block in tqdm(self.double_blocks, desc="Double stream blocks"):
img, txt = block(img, txt, vec, (freqs_cos, freqs_sin))
@@ -625,7 +659,7 @@ class HunyuanVideoDiT(torch.nn.Module):
img = self.final_layer(img, vec)
img = self.unpatchify(img, T=T//1, H=H//2, W=W//2)
return img
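With `guidance_embed=False` (the I2V checkpoint), `guidance_in` is `None` and the guidance term is simply skipped instead of crashing on a missing module. A toy stand-in for the `time_in`/`guidance_in` combination; the `* 0.5` embedding is invented purely for this sketch:

```python
class TimestepVector:
    # Stand-in for the time_in / guidance_in combination above: when the
    # checkpoint ships no guidance embedding, guidance_in is None and the
    # guidance contribution is skipped.
    def __init__(self, guidance_embed=True):
        self.guidance_in = (lambda g: g * 0.5) if guidance_embed else None  # toy embedding

    def __call__(self, t_vec, guidance):
        vec = t_vec
        if self.guidance_in is not None:
            vec = vec + self.guidance_in(guidance * 1000)
        return vec
```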
def enable_auto_offload(self, dtype=torch.bfloat16, device="cuda"):
def cast_to(weight, dtype=None, device=None, copy=False):
@@ -681,7 +715,7 @@ class HunyuanVideoDiT(torch.nn.Module):
del x_, weight_, bias_
torch.cuda.empty_cache()
return y_
def block_forward(self, x, **kwargs):
# This feature can only reduce 2GB VRAM, so we disable it.
y = torch.zeros(x.shape[:-1] + (self.out_features,), dtype=x.dtype, device=x.device)
@@ -689,19 +723,19 @@ class HunyuanVideoDiT(torch.nn.Module):
for j in range((self.out_features + self.block_size - 1) // self.block_size):
y[..., j * self.block_size: (j + 1) * self.block_size] += self.block_forward_(x, i, j, dtype=x.dtype, device=x.device)
return y
def forward(self, x, **kwargs):
weight, bias = cast_bias_weight(self, x, dtype=self.dtype, device=self.device)
return torch.nn.functional.linear(x, weight, bias)
class RMSNorm(torch.nn.Module):
def __init__(self, module, dtype=torch.bfloat16, device="cuda"):
super().__init__()
self.module = module
self.dtype = dtype
self.device = device
def forward(self, hidden_states, **kwargs):
input_dtype = hidden_states.dtype
variance = hidden_states.to(torch.float32).square().mean(-1, keepdim=True)
@@ -711,30 +745,30 @@ class HunyuanVideoDiT(torch.nn.Module):
weight = cast_weight(self.module, hidden_states, dtype=torch.bfloat16, device="cuda")
hidden_states = hidden_states * weight
return hidden_states
class Conv3d(torch.nn.Conv3d):
def __init__(self, *args, dtype=torch.bfloat16, device="cuda", **kwargs):
super().__init__(*args, **kwargs)
self.dtype = dtype
self.device = device
def forward(self, x):
weight, bias = cast_bias_weight(self, x, dtype=self.dtype, device=self.device)
return torch.nn.functional.conv3d(x, weight, bias, self.stride, self.padding, self.dilation, self.groups)
class LayerNorm(torch.nn.LayerNorm):
def __init__(self, *args, dtype=torch.bfloat16, device="cuda", **kwargs):
super().__init__(*args, **kwargs)
self.dtype = dtype
self.device = device
def forward(self, x):
if self.weight is not None and self.bias is not None:
weight, bias = cast_bias_weight(self, x, dtype=self.dtype, device=self.device)
return torch.nn.functional.layer_norm(x, self.normalized_shape, weight, bias, self.eps)
else:
return torch.nn.functional.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
def replace_layer(model, dtype=torch.bfloat16, device="cuda"):
for name, module in model.named_children():
if isinstance(module, torch.nn.Linear):
@@ -777,12 +811,12 @@ class HunyuanVideoDiT(torch.nn.Module):
return HunyuanVideoDiTStateDictConverter()
class HunyuanVideoDiTStateDictConverter:
def __init__(self):
pass
def from_civitai(self, state_dict):
origin_hash_key = hash_state_dict_keys(state_dict, with_shape=True)
if "module" in state_dict:
state_dict = state_dict["module"]
direct_dict = {
@@ -882,4 +916,5 @@ class HunyuanVideoDiTStateDictConverter:
state_dict_[name_] = param
else:
pass
return state_dict_

View File

@@ -1,24 +1,18 @@
from transformers import LlamaModel, LlamaConfig, DynamicCache
from transformers import LlamaModel, LlamaConfig, DynamicCache, LlavaForConditionalGeneration
from copy import deepcopy
import torch
class HunyuanVideoLLMEncoder(LlamaModel):
def __init__(self, config: LlamaConfig):
super().__init__(config)
self.auto_offload = False
def enable_auto_offload(self, **kwargs):
self.auto_offload = True
def forward(
self,
input_ids,
attention_mask,
hidden_state_skip_layer=2
):
def forward(self, input_ids, attention_mask, hidden_state_skip_layer=2):
embed_tokens = deepcopy(self.embed_tokens).to(input_ids.device) if self.auto_offload else self.embed_tokens
inputs_embeds = embed_tokens(input_ids)
@@ -53,3 +47,22 @@ class HunyuanVideoLLMEncoder(LlamaModel):
break
return hidden_states
class HunyuanVideoMLLMEncoder(LlavaForConditionalGeneration):
def __init__(self, config):
super().__init__(config)
self.auto_offload = False
def enable_auto_offload(self, **kwargs):
self.auto_offload = True
# TODO: implement the low VRAM inference for MLLM.
def forward(self, input_ids, pixel_values, attention_mask, hidden_state_skip_layer=2):
outputs = super().forward(input_ids=input_ids,
attention_mask=attention_mask,
output_hidden_states=True,
pixel_values=pixel_values)
hidden_state = outputs.hidden_states[-(hidden_state_skip_layer + 1)]
return hidden_state
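`hidden_states[-(hidden_state_skip_layer + 1)]` picks an intermediate layer output, counting back from the final layer of the `output_hidden_states=True` tuple. A one-line sketch with plain values in place of tensors:

```python
def select_hidden_state(hidden_states, hidden_state_skip_layer=2):
    # hidden_states holds one entry per layer (embeddings first, final
    # layer last); skipping k layers from the end means indexing -(k + 1).
    return hidden_states[-(hidden_state_skip_layer + 1)]
```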

View File

@@ -195,70 +195,73 @@ class FluxLoRAFromCivitai(LoRAFromCivitai):
"txt.mod": "txt_mod",
}
class GeneralLoRAFromPeft:
def __init__(self):
self.supported_model_classes = [SDUNet, SDXLUNet, SD3DiT, HunyuanDiT, FluxDiT, CogDiT, WanModel]
def fetch_device_dtype_from_state_dict(self, state_dict):
device, torch_dtype = None, None
for name, param in state_dict.items():
device, torch_dtype = param.device, param.dtype
break
return device, torch_dtype
def convert_state_dict(self, state_dict, alpha=1.0, target_state_dict={}):
device, torch_dtype = self.fetch_device_dtype_from_state_dict(target_state_dict)
state_dict_ = {}
for key in state_dict:
def get_name_dict(self, lora_state_dict):
lora_name_dict = {}
for key in lora_state_dict:
if ".lora_B." not in key:
continue
weight_up = state_dict[key].to(device=device, dtype=torch_dtype)
weight_down = state_dict[key.replace(".lora_B.", ".lora_A.")].to(device=device, dtype=torch_dtype)
if len(weight_up.shape) == 4:
weight_up = weight_up.squeeze(3).squeeze(2)
weight_down = weight_down.squeeze(3).squeeze(2)
lora_weight = alpha * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3)
else:
lora_weight = alpha * torch.mm(weight_up, weight_down)
keys = key.split(".")
if len(keys) > keys.index("lora_B") + 2:
keys.pop(keys.index("lora_B") + 1)
keys.pop(keys.index("lora_B"))
if keys[0] == "diffusion_model":
keys.pop(0)
target_name = ".".join(keys)
if target_name not in target_state_dict:
return {}
state_dict_[target_name] = lora_weight.cpu()
return state_dict_
lora_name_dict[target_name] = (key, key.replace(".lora_B.", ".lora_A."))
return lora_name_dict
def match(self, model: torch.nn.Module, state_dict_lora):
lora_name_dict = self.get_name_dict(state_dict_lora)
model_name_dict = {name: None for name, _ in model.named_parameters()}
matched_num = sum([i in model_name_dict for i in lora_name_dict])
if matched_num == len(lora_name_dict):
return "", ""
else:
return None
def fetch_device_and_dtype(self, state_dict):
device, dtype = None, None
for name, param in state_dict.items():
device, dtype = param.device, param.dtype
break
computation_device = device
computation_dtype = dtype
if computation_device == torch.device("cpu"):
if torch.cuda.is_available():
computation_device = torch.device("cuda")
if computation_dtype == torch.float8_e4m3fn:
computation_dtype = torch.float32
return device, dtype, computation_device, computation_dtype
def load(self, model, state_dict_lora, lora_prefix="", alpha=1.0, model_resource=""):
state_dict_model = model.state_dict()
state_dict_lora = self.convert_state_dict(state_dict_lora, alpha=alpha, target_state_dict=state_dict_model)
if len(state_dict_lora) > 0:
print(f" {len(state_dict_lora)} tensors are updated.")
for name in state_dict_lora:
state_dict_model[name] += state_dict_lora[name].to(
dtype=state_dict_model[name].dtype,
device=state_dict_model[name].device
)
model.load_state_dict(state_dict_model)
device, dtype, computation_device, computation_dtype = self.fetch_device_and_dtype(state_dict_model)
lora_name_dict = self.get_name_dict(state_dict_lora)
for name in lora_name_dict:
weight_up = state_dict_lora[lora_name_dict[name][0]].to(device=computation_device, dtype=computation_dtype)
weight_down = state_dict_lora[lora_name_dict[name][1]].to(device=computation_device, dtype=computation_dtype)
if len(weight_up.shape) == 4:
weight_up = weight_up.squeeze(3).squeeze(2)
weight_down = weight_down.squeeze(3).squeeze(2)
weight_lora = alpha * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3)
else:
weight_lora = alpha * torch.mm(weight_up, weight_down)
weight_model = state_dict_model[name].to(device=computation_device, dtype=computation_dtype)
weight_patched = weight_model + weight_lora
state_dict_model[name] = weight_patched.to(device=device, dtype=dtype)
print(f" {len(lora_name_dict)} tensors are updated.")
model.load_state_dict(state_dict_model)
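The patch applied to each matched tensor is the standard LoRA update W' = W + alpha * (B @ A). A minimal pure-Python sketch of that arithmetic on toy nested-list matrices (not the library code, which uses torch tensors):

```python
def matmul(a, b):
    # Naive matrix multiply over nested lists.
    return [[sum(a[i][k] * b[k][j] for k in range(len(b)))
             for j in range(len(b[0]))] for i in range(len(a))]

def lora_patch(weight, up, down, alpha=1.0):
    # weight: (out, in); up (lora_B): (out, r); down (lora_A): (r, in).
    delta = matmul(up, down)
    return [[w + alpha * d for w, d in zip(wr, dr)]
            for wr, dr in zip(weight, delta)]

W = [[1.0, 0.0], [0.0, 1.0]]
B = [[1.0], [2.0]]   # rank-1 up-projection
A = [[0.5, 0.5]]     # rank-1 down-projection
print(lora_patch(W, B, A, alpha=1.0))
```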
def match(self, model, state_dict_lora):
for model_class in self.supported_model_classes:
if not isinstance(model, model_class):
continue
state_dict_model = model.state_dict()
try:
state_dict_lora_ = self.convert_state_dict(state_dict_lora, alpha=1.0, target_state_dict=state_dict_model)
if len(state_dict_lora_) > 0:
return "", ""
except:
pass
return None
class HunyuanVideoLoRAFromCivitai(LoRAFromCivitai):


@@ -376,6 +376,7 @@ class ModelManager:
self.load_lora(file_path_, state_dict=state_dict, lora_alpha=lora_alpha)
else:
print(f"Loading LoRA models from file: {file_path}")
is_loaded = False
if len(state_dict) == 0:
state_dict = load_state_dict(file_path)
for model_name, model, model_path in zip(self.model_name, self.model, self.model_path):
@@ -385,7 +386,10 @@ class ModelManager:
print(f" Adding LoRA to {model_name} ({model_path}).")
lora_prefix, model_resource = match_results
lora.load(model, state_dict, lora_prefix, alpha=lora_alpha, model_resource=model_resource)
is_loaded = True
break
if not is_loaded:
print(f" Cannot load LoRA: {file_path}")
def load_model(self, file_path, model_names=None, device=None, torch_dtype=None):


@@ -0,0 +1,204 @@
import torch
import torch.nn as nn
from typing import Tuple, Optional
from einops import rearrange
from .wan_video_dit import DiTBlock, precompute_freqs_cis_3d, MLP, sinusoidal_embedding_1d
from .utils import hash_state_dict_keys
class WanControlNetModel(torch.nn.Module):
def __init__(
self,
dim: int,
in_dim: int,
ffn_dim: int,
out_dim: int,
text_dim: int,
freq_dim: int,
eps: float,
patch_size: Tuple[int, int, int],
num_heads: int,
num_layers: int,
has_image_input: bool,
):
super().__init__()
self.dim = dim
self.freq_dim = freq_dim
self.has_image_input = has_image_input
self.patch_size = patch_size
self.patch_embedding = nn.Conv3d(
in_dim, dim, kernel_size=patch_size, stride=patch_size)
self.text_embedding = nn.Sequential(
nn.Linear(text_dim, dim),
nn.GELU(approximate='tanh'),
nn.Linear(dim, dim)
)
self.time_embedding = nn.Sequential(
nn.Linear(freq_dim, dim),
nn.SiLU(),
nn.Linear(dim, dim)
)
self.time_projection = nn.Sequential(
nn.SiLU(), nn.Linear(dim, dim * 6))
self.blocks = nn.ModuleList([
DiTBlock(has_image_input, dim, num_heads, ffn_dim, eps)
for _ in range(num_layers)
])
head_dim = dim // num_heads
self.freqs = precompute_freqs_cis_3d(head_dim)
if has_image_input:
self.img_emb = MLP(1280, dim) # clip_feature_dim = 1280
self.controlnet_conv_in = torch.nn.Conv3d(in_channels=in_dim, out_channels=in_dim, kernel_size=1)
self.controlnet_blocks = torch.nn.ModuleList([
torch.nn.Linear(dim, dim, bias=False)
for _ in range(num_layers)
])
def patchify(self, x: torch.Tensor):
x = self.patch_embedding(x)
grid_size = x.shape[2:]
x = rearrange(x, 'b c f h w -> b (f h w) c').contiguous()
return x, grid_size # x, grid_size: (f, h, w)
def unpatchify(self, x: torch.Tensor, grid_size: torch.Tensor):
return rearrange(
x, 'b (f h w) (x y z c) -> b c (f x) (h y) (w z)',
f=grid_size[0], h=grid_size[1], w=grid_size[2],
x=self.patch_size[0], y=self.patch_size[1], z=self.patch_size[2]
)
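The rearrange in unpatchify assumes the (f h w) axis was flattened in row-major order, so the token index of the patch at (f_i, h_i, w_i) is (f_i * h + h_i) * w + w_i. A quick pure-Python sketch of that indexing on a toy grid:

```python
def token_index(f_i, h_i, w_i, h, w):
    # Row-major flattening used by 'b c f h w -> b (f h w) c'.
    return (f_i * h + h_i) * w + w_i

# Enumerating a 2x2x3 grid in (f, h, w) order visits tokens contiguously.
f, h, w = 2, 2, 3
order = [token_index(fi, hi, wi, h, w)
         for fi in range(f) for hi in range(h) for wi in range(w)]
print(order)  # 0..11 in sequence
```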
def forward(self,
x: torch.Tensor,
timestep: torch.Tensor,
context: torch.Tensor,
clip_feature: Optional[torch.Tensor] = None,
y: Optional[torch.Tensor] = None,
controlnet_conditioning: Optional[torch.Tensor] = None,
use_gradient_checkpointing: bool = False,
use_gradient_checkpointing_offload: bool = False,
**kwargs,
):
t = self.time_embedding(
sinusoidal_embedding_1d(self.freq_dim, timestep))
t_mod = self.time_projection(t).unflatten(1, (6, self.dim))
context = self.text_embedding(context)
if self.has_image_input:
x = torch.cat([x, y], dim=1) # (b, c_x + c_y, f, h, w)
            clip_embedding = self.img_emb(clip_feature)
            context = torch.cat([clip_embedding, context], dim=1)
x = x + self.controlnet_conv_in(controlnet_conditioning)
x, (f, h, w) = self.patchify(x)
freqs = torch.cat([
self.freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
self.freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
self.freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
], dim=-1).reshape(f * h * w, 1, -1).to(x.device)
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs)
return custom_forward
res_stack = []
for block in self.blocks:
if self.training and use_gradient_checkpointing:
if use_gradient_checkpointing_offload:
with torch.autograd.graph.save_on_cpu():
x = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
x, context, t_mod, freqs,
use_reentrant=False,
)
else:
x = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
x, context, t_mod, freqs,
use_reentrant=False,
)
else:
x = block(x, context, t_mod, freqs)
res_stack.append(x)
controlnet_res_stack = [block(res) for block, res in zip(self.controlnet_blocks, res_stack)]
return controlnet_res_stack
@staticmethod
def state_dict_converter():
return WanControlNetModelStateDictConverter()
class WanControlNetModelStateDictConverter:
def __init__(self):
pass
def from_diffusers(self, state_dict):
return state_dict
def from_civitai(self, state_dict):
return state_dict
def from_base_model(self, state_dict):
if hash_state_dict_keys(state_dict) == "9269f8db9040a9d860eaca435be61814":
config = {
"has_image_input": False,
"patch_size": [1, 2, 2],
"in_dim": 16,
"dim": 1536,
"ffn_dim": 8960,
"freq_dim": 256,
"text_dim": 4096,
"out_dim": 16,
"num_heads": 12,
"num_layers": 30,
"eps": 1e-6
}
elif hash_state_dict_keys(state_dict) == "aafcfd9672c3a2456dc46e1cb6e52c70":
config = {
"has_image_input": False,
"patch_size": [1, 2, 2],
"in_dim": 16,
"dim": 5120,
"ffn_dim": 13824,
"freq_dim": 256,
"text_dim": 4096,
"out_dim": 16,
"num_heads": 40,
"num_layers": 40,
"eps": 1e-6
}
elif hash_state_dict_keys(state_dict) == "6bfcfb3b342cb286ce886889d519a77e":
config = {
"has_image_input": True,
"patch_size": [1, 2, 2],
"in_dim": 36,
"dim": 5120,
"ffn_dim": 13824,
"freq_dim": 256,
"text_dim": 4096,
"out_dim": 16,
"num_heads": 40,
"num_layers": 40,
"eps": 1e-6
}
else:
config = {}
state_dict_ = {}
dtype, device = None, None
for name, param in state_dict.items():
if name.startswith("head."):
continue
state_dict_[name] = param
dtype, device = param.dtype, param.device
for block_id in range(config["num_layers"]):
zeros = torch.zeros((config["dim"], config["dim"]), dtype=dtype, device=device)
state_dict_[f"controlnet_blocks.{block_id}.weight"] = zeros.clone()
state_dict_["controlnet_conv_in.weight"] = torch.zeros((config["in_dim"], config["in_dim"], 1, 1, 1), dtype=dtype, device=device)
state_dict_["controlnet_conv_in.bias"] = torch.zeros((config["in_dim"],), dtype=dtype, device=device)
return state_dict_, config
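Note the deliberate zero initialization above: with controlnet_conv_in and every controlnet_blocks projection set to zero, the ControlNet's residuals vanish and the combined model reproduces the base model's behavior at step 0, so training starts from a known-good state. A toy pure-Python illustration of that principle (not the actual tensors):

```python
def controlnet_residual(hidden, proj_weight):
    # A zero projection maps any residual to zero, so x + residual == x.
    return [h * proj_weight for h in hidden]

hidden = [0.3, -1.2, 0.7]
residual = controlnet_residual(hidden, proj_weight=0.0)
patched = [h + r for h, r in zip(hidden, residual)]
print(patched == hidden)  # the ControlNet is a no-op at initialization
```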

File diff suppressed because it is too large.


@@ -228,7 +228,7 @@ class QuickGELU(nn.Module):
class LayerNorm(nn.LayerNorm):
def forward(self, x):
return super().forward(x.float()).type_as(x)
return super().forward(x).type_as(x)
class SelfAttention(nn.Module):
@@ -256,15 +256,11 @@ class SelfAttention(nn.Module):
"""
x: [B, L, C].
"""
b, s, c, n, d = *x.size(), self.num_heads, self.head_dim
# compute query, key, value
q, k, v = self.to_qkv(x).view(b, s, 3, n, d).unbind(2)
q, k, v = self.to_qkv(x).chunk(3, dim=-1)
# compute attention
p = self.attn_dropout if self.training else 0.0
x = flash_attention(q, k, v, dropout_p=p, causal=self.causal, version=2)
x = x.reshape(b, s, c)
x = flash_attention(q, k, v, num_heads=self.num_heads, compatibility_mode=True)
# output
x = self.proj(x)
@@ -371,11 +367,11 @@ class AttentionPool(nn.Module):
b, s, c, n, d = *x.size(), self.num_heads, self.head_dim
# compute query, key, value
q = self.to_q(self.cls_embedding).view(1, 1, n, d).expand(b, -1, -1, -1)
k, v = self.to_kv(x).view(b, s, 2, n, d).unbind(2)
q = self.to_q(self.cls_embedding).view(1, 1, n*d).expand(b, -1, -1)
k, v = self.to_kv(x).chunk(2, dim=-1)
# compute attention
x = flash_attention(q, k, v, version=2)
x = flash_attention(q, k, v, num_heads=self.num_heads, compatibility_mode=True)
x = x.reshape(b, 1, c)
# output
@@ -878,6 +874,8 @@ class WanImageEncoder(torch.nn.Module):
videos = self.transforms.transforms[-1](videos.mul_(0.5).add_(0.5))
# forward
dtype = next(iter(self.model.visual.parameters())).dtype
videos = videos.to(dtype)
out = self.model.visual(videos, use_31_block=True)
return out


@@ -0,0 +1,27 @@
import torch
import torch.nn as nn
from .wan_video_dit import sinusoidal_embedding_1d
class WanMotionControllerModel(torch.nn.Module):
def __init__(self, freq_dim=256, dim=1536):
super().__init__()
self.freq_dim = freq_dim
self.linear = nn.Sequential(
nn.Linear(freq_dim, dim),
nn.SiLU(),
nn.Linear(dim, dim),
nn.SiLU(),
nn.Linear(dim, dim * 6),
)
def forward(self, motion_bucket_id):
emb = sinusoidal_embedding_1d(self.freq_dim, motion_bucket_id * 10)
emb = self.linear(emb)
return emb
def init(self):
state_dict = self.linear[-1].state_dict()
state_dict = {i: state_dict[i] * 0 for i in state_dict}
self.linear[-1].load_state_dict(state_dict)
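sinusoidal_embedding_1d follows the usual transformer recipe: sin and cos terms at geometrically spaced frequencies. A pure-Python sketch of that idea, assuming the standard formulation (the library's exact term ordering and frequency base may differ):

```python
import math

def sinusoidal_embedding_1d(dim, position, theta=10000.0):
    # Standard transformer-style 1D embedding: half sin terms, half cos
    # terms, with frequencies decaying geometrically from 1 to 1/theta.
    half = dim // 2
    freqs = [theta ** (-i / half) for i in range(half)]
    return ([math.sin(position * f) for f in freqs] +
            [math.cos(position * f) for f in freqs])

# motion_bucket_id=5 is embedded at scaled position 5 * 10, as in forward().
emb = sinusoidal_embedding_1d(256, position=5 * 10)
print(len(emb))  # 256
```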


@@ -688,7 +688,7 @@ class WanVideoVAE(nn.Module):
target_w: target_w + hidden_states_batch.shape[4],
] += mask
values = values / weight
values = values.float().clamp_(-1, 1)
values = values.clamp_(-1, 1)
return values
@@ -740,20 +740,19 @@ class WanVideoVAE(nn.Module):
target_w: target_w + hidden_states_batch.shape[4],
] += mask
values = values / weight
values = values.float()
return values
def single_encode(self, video, device):
video = video.to(device)
x = self.model.encode(video, self.scale)
return x.float()
return x
def single_decode(self, hidden_state, device):
hidden_state = hidden_state.to(device)
video = self.model.decode(hidden_state, self.scale)
return video.float().clamp_(-1, 1)
return video.clamp_(-1, 1)
def encode(self, videos, device, tiled=False, tile_size=(34, 34), tile_stride=(18, 16)):


@@ -5,13 +5,13 @@ from ..schedulers.flow_match import FlowMatchScheduler
from .base import BasePipeline
from ..prompters import HunyuanVideoPrompter
import torch
import torchvision.transforms as transforms
from einops import rearrange
import numpy as np
from PIL import Image
from tqdm import tqdm
class HunyuanVideoPipeline(BasePipeline):
def __init__(self, device="cuda", torch_dtype=torch.float16):
@@ -53,10 +53,58 @@ class HunyuanVideoPipeline(BasePipeline):
pipe.enable_vram_management()
return pipe
def generate_crop_size_list(self, base_size=256, patch_size=32, max_ratio=4.0):
num_patches = round((base_size / patch_size)**2)
assert max_ratio >= 1.0
crop_size_list = []
wp, hp = num_patches, 1
while wp > 0:
if max(wp, hp) / min(wp, hp) <= max_ratio:
crop_size_list.append((wp * patch_size, hp * patch_size))
if (hp + 1) * wp <= num_patches:
hp += 1
else:
wp -= 1
return crop_size_list
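For intuition, the walk above sweeps width-in-patches down from num_patches while growing height-in-patches whenever the patch budget allows, keeping the aspect ratio bounded. A self-contained copy of the function (mirroring the code above) makes the output concrete for a small base size:

```python
def generate_crop_size_list(base_size=256, patch_size=32, max_ratio=4.0):
    # Enumerate (w, h) buckets whose patch count stays within the budget
    # (base_size / patch_size)^2 and whose aspect ratio is <= max_ratio.
    num_patches = round((base_size / patch_size) ** 2)
    crop_size_list = []
    wp, hp = num_patches, 1
    while wp > 0:
        if max(wp, hp) / min(wp, hp) <= max_ratio:
            crop_size_list.append((wp * patch_size, hp * patch_size))
        if (hp + 1) * wp <= num_patches:
            hp += 1
        else:
            wp -= 1
    return crop_size_list

print(generate_crop_size_list(base_size=64, patch_size=32))
# [(128, 32), (96, 32), (64, 32), (64, 64), (32, 64), (32, 96), (32, 128)]
```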
def encode_prompt(self, prompt, positive=True, clip_sequence_length=77, llm_sequence_length=256):
def get_closest_ratio(self, height: float, width: float, ratios: list, buckets: list):
aspect_ratio = float(height) / float(width)
closest_ratio_id = np.abs(ratios - aspect_ratio).argmin()
closest_ratio = min(ratios, key=lambda ratio: abs(float(ratio) - aspect_ratio))
return buckets[closest_ratio_id], float(closest_ratio)
def prepare_vae_images_inputs(self, semantic_images, i2v_resolution="720p"):
if i2v_resolution == "720p":
bucket_hw_base_size = 960
elif i2v_resolution == "540p":
bucket_hw_base_size = 720
elif i2v_resolution == "360p":
bucket_hw_base_size = 480
else:
raise ValueError(f"i2v_resolution: {i2v_resolution} must be in [360p, 540p, 720p]")
origin_size = semantic_images[0].size
crop_size_list = self.generate_crop_size_list(bucket_hw_base_size, 32)
aspect_ratios = np.array([round(float(h) / float(w), 5) for h, w in crop_size_list])
closest_size, closest_ratio = self.get_closest_ratio(origin_size[1], origin_size[0], aspect_ratios, crop_size_list)
ref_image_transform = transforms.Compose([
transforms.Resize(closest_size),
transforms.CenterCrop(closest_size),
transforms.ToTensor(),
transforms.Normalize([0.5], [0.5])
])
semantic_image_pixel_values = [ref_image_transform(semantic_image) for semantic_image in semantic_images]
semantic_image_pixel_values = torch.cat(semantic_image_pixel_values).unsqueeze(0).unsqueeze(2).to(self.device)
target_height, target_width = closest_size
return semantic_image_pixel_values, target_height, target_width
def encode_prompt(self, prompt, positive=True, clip_sequence_length=77, llm_sequence_length=256, input_images=None):
prompt_emb, pooled_prompt_emb, text_mask = self.prompter.encode_prompt(
prompt, device=self.device, positive=positive, clip_sequence_length=clip_sequence_length, llm_sequence_length=llm_sequence_length
prompt, device=self.device, positive=positive, clip_sequence_length=clip_sequence_length, llm_sequence_length=llm_sequence_length, images=input_images
)
return {"prompt_emb": prompt_emb, "pooled_prompt_emb": pooled_prompt_emb, "text_mask": text_mask}
@@ -87,6 +135,9 @@ class HunyuanVideoPipeline(BasePipeline):
prompt,
negative_prompt="",
input_video=None,
input_images=None,
i2v_resolution="720p",
i2v_stability=True,
denoising_strength=1.0,
seed=None,
rand_device=None,
@@ -105,10 +156,17 @@ class HunyuanVideoPipeline(BasePipeline):
):
# Tiler parameters
tiler_kwargs = {"tile_size": tile_size, "tile_stride": tile_stride}
# Scheduler
self.scheduler.set_timesteps(num_inference_steps, denoising_strength)
# encoder input images
if input_images is not None:
self.load_models_to_device(['vae_encoder'])
image_pixel_values, height, width = self.prepare_vae_images_inputs(input_images, i2v_resolution=i2v_resolution)
with torch.autocast(device_type=self.device, dtype=torch.float16, enabled=True):
image_latents = self.vae_encoder(image_pixel_values)
# Initialize noise
rand_device = self.device if rand_device is None else rand_device
noise = self.generate_noise((1, 16, (num_frames - 1) // 4 + 1, height//8, width//8), seed=seed, device=rand_device, dtype=self.torch_dtype).to(self.device)
@@ -118,12 +176,18 @@ class HunyuanVideoPipeline(BasePipeline):
input_video = torch.stack(input_video, dim=2)
latents = self.encode_video(input_video, **tiler_kwargs).to(dtype=self.torch_dtype, device=self.device)
latents = self.scheduler.add_noise(latents, noise, timestep=self.scheduler.timesteps[0])
elif input_images is not None and i2v_stability:
noise = self.generate_noise((1, 16, (num_frames - 1) // 4 + 1, height//8, width//8), seed=seed, device=rand_device, dtype=image_latents.dtype).to(self.device)
t = torch.tensor([0.999]).to(device=self.device)
latents = noise * t + image_latents.repeat(1, 1, (num_frames - 1) // 4 + 1, 1, 1) * (1 - t)
latents = latents.to(dtype=image_latents.dtype)
else:
latents = noise
# Encode prompts
self.load_models_to_device(["text_encoder_1"] if self.vram_management else ["text_encoder_1", "text_encoder_2"])
prompt_emb_posi = self.encode_prompt(prompt, positive=True)
# current mllm does not support vram_management
self.load_models_to_device(["text_encoder_1"] if self.vram_management and input_images is None else ["text_encoder_1", "text_encoder_2"])
prompt_emb_posi = self.encode_prompt(prompt, positive=True, input_images=input_images)
if cfg_scale != 1.0:
prompt_emb_nega = self.encode_prompt(negative_prompt, positive=False)
@@ -139,11 +203,16 @@ class HunyuanVideoPipeline(BasePipeline):
timestep = timestep.unsqueeze(0).to(self.device)
print(f"Step {progress_id + 1} / {len(self.scheduler.timesteps)}")
forward_func = lets_dance_hunyuan_video
if input_images is not None:
latents = torch.concat([image_latents, latents[:, :, 1:, :, :]], dim=2)
forward_func = lets_dance_hunyuan_video_i2v
# Inference
with torch.autocast(device_type=self.device, dtype=self.torch_dtype):
noise_pred_posi = lets_dance_hunyuan_video(self.dit, latents, timestep, **prompt_emb_posi, **extra_input, **tea_cache_kwargs)
noise_pred_posi = forward_func(self.dit, latents, timestep, **prompt_emb_posi, **extra_input, **tea_cache_kwargs)
if cfg_scale != 1.0:
noise_pred_nega = lets_dance_hunyuan_video(self.dit, latents, timestep, **prompt_emb_nega, **extra_input)
noise_pred_nega = forward_func(self.dit, latents, timestep, **prompt_emb_nega, **extra_input)
noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
else:
noise_pred = noise_pred_posi
@@ -163,7 +232,11 @@ class HunyuanVideoPipeline(BasePipeline):
self.load_models_to_device([] if self.vram_management else ["dit"])
# Scheduler
latents = self.scheduler.step(noise_pred, self.scheduler.timesteps[progress_id], latents)
if input_images is not None:
latents = self.scheduler.step(noise_pred[:, :, 1:, :, :], self.scheduler.timesteps[progress_id], latents[:, :, 1:, :, :])
latents = torch.concat([image_latents, latents], dim=2)
else:
latents = self.scheduler.step(noise_pred, self.scheduler.timesteps[progress_id], latents)
# Decode
self.load_models_to_device(['vae_decoder'])
@@ -194,7 +267,7 @@ class TeaCache:
if self.step == 0 or self.step == self.num_inference_steps - 1:
should_calc = True
self.accumulated_rel_l1_distance = 0
else:
else:
coefficients = [7.33226126e+02, -4.01131952e+02, 6.75869174e+01, -3.14987800e+00, 9.61237896e-02]
rescale_func = np.poly1d(coefficients)
self.accumulated_rel_l1_distance += rescale_func(((modulated_inp-self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean()).cpu().item())
@@ -203,14 +276,14 @@ class TeaCache:
else:
should_calc = True
self.accumulated_rel_l1_distance = 0
self.previous_modulated_input = modulated_inp
self.previous_modulated_input = modulated_inp
self.step += 1
if self.step == self.num_inference_steps:
self.step = 0
if should_calc:
self.previous_hidden_states = img.clone()
return not should_calc
def store(self, hidden_states):
self.previous_residual = hidden_states - self.previous_hidden_states
self.previous_hidden_states = None
@@ -250,13 +323,70 @@ def lets_dance_hunyuan_video(
print("TeaCache skip forward.")
img = tea_cache.update(img)
else:
split_token = int(text_mask.sum(dim=1))
txt_len = int(txt.shape[1])
for block in tqdm(dit.double_blocks, desc="Double stream blocks"):
img, txt = block(img, txt, vec, (freqs_cos, freqs_sin))
img, txt = block(img, txt, vec, (freqs_cos, freqs_sin), split_token=split_token)
x = torch.concat([img, txt], dim=1)
for block in tqdm(dit.single_blocks, desc="Single stream blocks"):
x = block(x, vec, (freqs_cos, freqs_sin))
img = x[:, :-256]
x = block(x, vec, (freqs_cos, freqs_sin), txt_len=txt_len, split_token=split_token)
img = x[:, :-txt_len]
if tea_cache is not None:
tea_cache.store(img)
img = dit.final_layer(img, vec)
img = dit.unpatchify(img, T=T//1, H=H//2, W=W//2)
return img
def lets_dance_hunyuan_video_i2v(
dit: HunyuanVideoDiT,
x: torch.Tensor,
t: torch.Tensor,
prompt_emb: torch.Tensor = None,
text_mask: torch.Tensor = None,
pooled_prompt_emb: torch.Tensor = None,
freqs_cos: torch.Tensor = None,
freqs_sin: torch.Tensor = None,
guidance: torch.Tensor = None,
tea_cache: TeaCache = None,
**kwargs
):
B, C, T, H, W = x.shape
# Uncomment below to keep same as official implementation
# guidance = guidance.to(dtype=torch.float32).to(torch.bfloat16)
vec = dit.time_in(t, dtype=torch.bfloat16)
vec_2 = dit.vector_in(pooled_prompt_emb)
vec = vec + vec_2
vec = vec + dit.guidance_in(guidance * 1000., dtype=torch.bfloat16)
token_replace_vec = dit.time_in(torch.zeros_like(t), dtype=torch.bfloat16)
tr_token = (H // 2) * (W // 2)
token_replace_vec = token_replace_vec + vec_2
img = dit.img_in(x)
txt = dit.txt_in(prompt_emb, t, text_mask)
# TeaCache
if tea_cache is not None:
tea_cache_update = tea_cache.check(dit, img, vec)
else:
tea_cache_update = False
if tea_cache_update:
print("TeaCache skip forward.")
img = tea_cache.update(img)
else:
split_token = int(text_mask.sum(dim=1))
txt_len = int(txt.shape[1])
for block in tqdm(dit.double_blocks, desc="Double stream blocks"):
img, txt = block(img, txt, vec, (freqs_cos, freqs_sin), token_replace_vec, tr_token, split_token)
x = torch.concat([img, txt], dim=1)
for block in tqdm(dit.single_blocks, desc="Single stream blocks"):
x = block(x, vec, (freqs_cos, freqs_sin), txt_len, token_replace_vec, tr_token, split_token)
img = x[:, :-txt_len]
if tea_cache is not None:
tea_cache.store(img)


@@ -11,11 +11,14 @@ from einops import rearrange
import numpy as np
from PIL import Image
from tqdm import tqdm
from typing import Optional
from ..vram_management import enable_vram_management, AutoWrappedModule, AutoWrappedLinear
from ..models.wan_video_text_encoder import T5RelativeEmbedding, T5LayerNorm
from ..models.wan_video_dit import WanLayerNorm, WanRMSNorm
from ..models.wan_video_dit import RMSNorm, sinusoidal_embedding_1d
from ..models.wan_video_vae import RMS_norm, CausalConv3d, Upsample
from ..models.wan_video_controlnet import WanControlNetModel
from ..models.wan_video_motion_controller import WanMotionControllerModel
@@ -29,7 +32,9 @@ class WanVideoPipeline(BasePipeline):
self.image_encoder: WanImageEncoder = None
self.dit: WanModel = None
self.vae: WanVideoVAE = None
self.model_names = ['text_encoder', 'dit', 'vae']
self.controlnet: WanControlNetModel = None
self.motion_controller: WanMotionControllerModel = None
self.model_names = ['text_encoder', 'dit', 'vae', 'image_encoder', 'controlnet', 'motion_controller']
self.height_division_factor = 16
self.width_division_factor = 16
@@ -60,8 +65,7 @@ class WanVideoPipeline(BasePipeline):
torch.nn.Linear: AutoWrappedLinear,
torch.nn.Conv3d: AutoWrappedModule,
torch.nn.LayerNorm: AutoWrappedModule,
WanLayerNorm: AutoWrappedModule,
WanRMSNorm: AutoWrappedModule,
RMSNorm: AutoWrappedModule,
},
module_config = dict(
offload_dtype=dtype,
@@ -116,7 +120,7 @@ class WanVideoPipeline(BasePipeline):
offload_device="cpu",
onload_dtype=dtype,
onload_device="cpu",
computation_dtype=self.torch_dtype,
computation_dtype=dtype,
computation_device=self.device,
),
)
@@ -153,17 +157,21 @@ class WanVideoPipeline(BasePipeline):
def encode_image(self, image, num_frames, height, width):
with torch.amp.autocast(dtype=torch.bfloat16, device_type=torch.device(self.device).type):
image = self.preprocess_image(image.resize((width, height))).to(self.device)
clip_context = self.image_encoder.encode_image([image])
msk = torch.ones(1, num_frames, height//8, width//8, device=self.device)
msk[:, 1:] = 0
msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1)
msk = msk.view(1, msk.shape[1] // 4, 4, height//8, width//8)
msk = msk.transpose(1, 2)[0]
y = self.vae.encode([torch.concat([image.transpose(0, 1), torch.zeros(3, num_frames-1, height, width).to(image.device)], dim=1)], device=self.device)[0]
y = torch.concat([msk, y])
return {"clip_fea": clip_context, "y": [y]}
image = self.preprocess_image(image.resize((width, height))).to(self.device)
clip_context = self.image_encoder.encode_image([image])
msk = torch.ones(1, num_frames, height//8, width//8, device=self.device)
msk[:, 1:] = 0
msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1)
msk = msk.view(1, msk.shape[1] // 4, 4, height//8, width//8)
msk = msk.transpose(1, 2)[0]
vae_input = torch.concat([image.transpose(0, 1), torch.zeros(3, num_frames-1, height, width).to(image.device)], dim=1)
y = self.vae.encode([vae_input.to(dtype=self.torch_dtype, device=self.device)], device=self.device)[0]
y = torch.concat([msk, y])
y = y.unsqueeze(0)
clip_context = clip_context.to(dtype=self.torch_dtype, device=self.device)
y = y.to(dtype=self.torch_dtype, device=self.device)
return {"clip_feature": clip_context, "y": y}
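The mask construction above reflects Wan's 4x temporal compression: the first pixel frame maps to one latent frame (its mask entry is repeated 4 times), and every subsequent group of 4 pixel frames collapses into one latent frame. A pure-Python sketch of that frame bookkeeping:

```python
def latent_frame_count(num_frames):
    # One latent frame for the first pixel frame, then one per 4 more.
    return (num_frames - 1) // 4 + 1

def mask_per_frame(num_frames):
    # 1 = conditioning frame, 0 = frames to generate; the first entry is
    # repeated 4x so the mask length equals 4 * latent_frame_count.
    msk = [1] + [0] * (num_frames - 1)
    return [msk[0]] * 4 + msk[1:]

print(latent_frame_count(81))        # 21
print(len(mask_per_frame(81)) // 4)  # 21
```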
def tensor2video(self, frames):
@@ -174,19 +182,27 @@ class WanVideoPipeline(BasePipeline):
def prepare_extra_input(self, latents=None):
return {"seq_len": latents.shape[2] * latents.shape[3] * latents.shape[4] // 4}
return {}
def encode_video(self, input_video, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)):
with torch.amp.autocast(dtype=torch.bfloat16, device_type=torch.device(self.device).type):
latents = self.vae.encode(input_video, device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
latents = self.vae.encode(input_video, device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
return latents
def decode_video(self, latents, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)):
with torch.amp.autocast(dtype=torch.bfloat16, device_type=torch.device(self.device).type):
frames = self.vae.decode(latents, device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
frames = self.vae.decode(latents, device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
return frames
def prepare_controlnet(self, controlnet_frames, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)):
controlnet_conditioning = self.encode_video(controlnet_frames, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=self.torch_dtype, device=self.device)
return {"controlnet_conditioning": controlnet_conditioning}
def prepare_motion_bucket_id(self, motion_bucket_id):
motion_bucket_id = torch.Tensor((motion_bucket_id,)).to(dtype=self.torch_dtype, device=self.device)
return {"motion_bucket_id": motion_bucket_id}
@torch.no_grad()
@@ -205,9 +221,13 @@ class WanVideoPipeline(BasePipeline):
cfg_scale=5.0,
num_inference_steps=50,
sigma_shift=5.0,
motion_bucket_id=None,
tiled=True,
tile_size=(30, 52),
tile_stride=(15, 26),
tea_cache_l1_thresh=None,
tea_cache_model_id="",
controlnet_frames=None,
progress_bar_cmd=tqdm,
progress_bar_st=None,
):
@@ -221,15 +241,16 @@ class WanVideoPipeline(BasePipeline):
tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride}
# Scheduler
self.scheduler.set_timesteps(num_inference_steps, denoising_strength, shift=sigma_shift)
self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength, shift=sigma_shift)
# Initialize noise
noise = self.generate_noise((1, 16, (num_frames - 1) // 4 + 1, height//8, width//8), seed=seed, device=rand_device, dtype=torch.float32).to(self.device)
noise = self.generate_noise((1, 16, (num_frames - 1) // 4 + 1, height//8, width//8), seed=seed, device=rand_device, dtype=torch.float32)
noise = noise.to(dtype=self.torch_dtype, device=self.device)
if input_video is not None:
self.load_models_to_device(['vae'])
input_video = self.preprocess_images(input_video)
input_video = torch.stack(input_video, dim=2)
latents = self.encode_video(input_video, **tiler_kwargs).to(dtype=noise.dtype, device=noise.device)
input_video = torch.stack(input_video, dim=2).to(dtype=self.torch_dtype, device=self.device)
latents = self.encode_video(input_video, **tiler_kwargs).to(dtype=self.torch_dtype, device=self.device)
latents = self.scheduler.add_noise(latents, noise, timestep=self.scheduler.timesteps[0])
else:
latents = noise
@@ -247,25 +268,53 @@ class WanVideoPipeline(BasePipeline):
else:
image_emb = {}
# ControlNet
if self.controlnet is not None and controlnet_frames is not None:
self.load_models_to_device(['vae', 'controlnet'])
controlnet_frames = self.preprocess_images(controlnet_frames)
controlnet_frames = torch.stack(controlnet_frames, dim=2).to(dtype=self.torch_dtype, device=self.device)
controlnet_kwargs = self.prepare_controlnet(controlnet_frames)
else:
controlnet_kwargs = {}
# Motion Controller
if self.motion_controller is not None and motion_bucket_id is not None:
motion_kwargs = self.prepare_motion_bucket_id(motion_bucket_id)
else:
motion_kwargs = {}
# Extra input
extra_input = self.prepare_extra_input(latents)
# TeaCache
tea_cache_posi = {"tea_cache": TeaCache(num_inference_steps, rel_l1_thresh=tea_cache_l1_thresh, model_id=tea_cache_model_id) if tea_cache_l1_thresh is not None else None}
tea_cache_nega = {"tea_cache": TeaCache(num_inference_steps, rel_l1_thresh=tea_cache_l1_thresh, model_id=tea_cache_model_id) if tea_cache_l1_thresh is not None else None}
# Denoise
self.load_models_to_device(["dit"])
with torch.amp.autocast(dtype=torch.bfloat16, device_type=torch.device(self.device).type):
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
timestep = timestep.unsqueeze(0).to(dtype=torch.float32, device=self.device)
self.load_models_to_device(["dit", "controlnet", "motion_controller"])
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device)
# Inference
noise_pred_posi = self.dit(latents, timestep=timestep, **prompt_emb_posi, **image_emb, **extra_input)
if cfg_scale != 1.0:
noise_pred_nega = self.dit(latents, timestep=timestep, **prompt_emb_nega, **image_emb, **extra_input)
noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
else:
noise_pred = noise_pred_posi
# Inference
noise_pred_posi = model_fn_wan_video(
self.dit, controlnet=self.controlnet, motion_controller=self.motion_controller,
x=latents, timestep=timestep,
**prompt_emb_posi, **image_emb, **extra_input,
**tea_cache_posi, **controlnet_kwargs, **motion_kwargs,
)
if cfg_scale != 1.0:
noise_pred_nega = model_fn_wan_video(
self.dit, controlnet=self.controlnet, motion_controller=self.motion_controller,
x=latents, timestep=timestep,
**prompt_emb_nega, **image_emb, **extra_input,
**tea_cache_nega, **controlnet_kwargs, **motion_kwargs,
)
noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
else:
noise_pred = noise_pred_posi
# Scheduler
latents = self.scheduler.step(noise_pred, self.scheduler.timesteps[progress_id], latents)
# Scheduler
latents = self.scheduler.step(noise_pred, self.scheduler.timesteps[progress_id], latents)
# Decode
self.load_models_to_device(['vae'])
@@ -274,3 +323,145 @@ class WanVideoPipeline(BasePipeline):
frames = self.tensor2video(frames[0])
return frames
class TeaCache:
def __init__(self, num_inference_steps, rel_l1_thresh, model_id):
self.num_inference_steps = num_inference_steps
self.step = 0
self.accumulated_rel_l1_distance = 0
self.previous_modulated_input = None
self.rel_l1_thresh = rel_l1_thresh
self.previous_residual = None
self.previous_hidden_states = None
self.coefficients_dict = {
"Wan2.1-T2V-1.3B": [-5.21862437e+04, 9.23041404e+03, -5.28275948e+02, 1.36987616e+01, -4.99875664e-02],
"Wan2.1-T2V-14B": [-3.03318725e+05, 4.90537029e+04, -2.65530556e+03, 5.87365115e+01, -3.15583525e-01],
"Wan2.1-I2V-14B-480P": [2.57151496e+05, -3.54229917e+04, 1.40286849e+03, -1.35890334e+01, 1.32517977e-01],
"Wan2.1-I2V-14B-720P": [ 8.10705460e+03, 2.13393892e+03, -3.72934672e+02, 1.66203073e+01, -4.17769401e-02],
}
if model_id not in self.coefficients_dict:
supported_model_ids = ", ".join([i for i in self.coefficients_dict])
raise ValueError(f"{model_id} is not a supported TeaCache model id. Please choose a valid model id in ({supported_model_ids}).")
self.coefficients = self.coefficients_dict[model_id]
def check(self, dit: WanModel, x, t_mod):
modulated_inp = t_mod.clone()
if self.step == 0 or self.step == self.num_inference_steps - 1:
should_calc = True
self.accumulated_rel_l1_distance = 0
else:
coefficients = self.coefficients
rescale_func = np.poly1d(coefficients)
self.accumulated_rel_l1_distance += rescale_func(((modulated_inp-self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean()).cpu().item())
if self.accumulated_rel_l1_distance < self.rel_l1_thresh:
should_calc = False
else:
should_calc = True
self.accumulated_rel_l1_distance = 0
self.previous_modulated_input = modulated_inp
self.step += 1
if self.step == self.num_inference_steps:
self.step = 0
if should_calc:
self.previous_hidden_states = x.clone()
return not should_calc
def store(self, hidden_states):
self.previous_residual = hidden_states - self.previous_hidden_states
self.previous_hidden_states = None
def update(self, hidden_states):
hidden_states = hidden_states + self.previous_residual
return hidden_states
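`TeaCache` trades a little accuracy for speed: on each step a model-specific polynomial rescales the relative L1 change of the modulated input, and the transformer blocks are skipped (reusing the cached residual) until the accumulated value crosses `rel_l1_thresh`. A stdlib-only sketch of that decision loop, using a made-up polynomial and distance sequence rather than the real degree-4 coefficients:

```python
# Hypothetical rescale polynomial (the real ones are per-model, degree 4) and
# illustrative per-step relative L1 distances; neither is a measured value.
coeffs = [2.0, 0.1]                      # rescale(d) = 2.0 * d + 0.1
rel_l1 = [0.05, 0.02, 0.01, 0.20, 0.03]
thresh, accumulated = 0.3, 0.0

def rescale(d, coeffs=coeffs):
    # Horner evaluation, equivalent to np.poly1d(coeffs)(d).
    acc = 0.0
    for c in coeffs:
        acc = acc * d + c
    return acc

decisions = []
for step, d in enumerate(rel_l1):
    if step == 0 or step == len(rel_l1) - 1:
        should_calc, accumulated = True, 0.0   # always compute first and last step
    else:
        accumulated += rescale(d)
        should_calc = accumulated >= thresh
        if should_calc:
            accumulated = 0.0                  # reset once a full pass is forced
    decisions.append(should_calc)
print(decisions)  # [True, False, False, True, True]
```

Steps marked `False` reuse `previous_residual` via `update()`; steps marked `True` run the full block stack and refresh the cache via `store()`.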
def model_fn_wan_video(
dit: WanModel,
controlnet: WanControlNetModel = None,
motion_controller: WanMotionControllerModel = None,
x: torch.Tensor = None,
timestep: torch.Tensor = None,
context: torch.Tensor = None,
clip_feature: Optional[torch.Tensor] = None,
y: Optional[torch.Tensor] = None,
tea_cache: TeaCache = None,
controlnet_conditioning: Optional[torch.Tensor] = None,
motion_bucket_id: Optional[torch.Tensor] = None,
use_gradient_checkpointing: bool = False,
use_gradient_checkpointing_offload: bool = False,
**kwargs,
):
# ControlNet
if controlnet is not None and controlnet_conditioning is not None:
controlnet_res_stack = controlnet(
x, timestep=timestep, context=context, clip_feature=clip_feature, y=y,
controlnet_conditioning=controlnet_conditioning,
use_gradient_checkpointing=use_gradient_checkpointing,
use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
)
else:
controlnet_res_stack = None
t = dit.time_embedding(sinusoidal_embedding_1d(dit.freq_dim, timestep))
t_mod = dit.time_projection(t).unflatten(1, (6, dit.dim))
if motion_bucket_id is not None and motion_controller is not None:
t_mod = t_mod + motion_controller(motion_bucket_id).unflatten(1, (6, dit.dim))
context = dit.text_embedding(context)
if dit.has_image_input:
x = torch.cat([x, y], dim=1) # (b, c_x + c_y, f, h, w)
clip_embedding = dit.img_emb(clip_feature)
context = torch.cat([clip_embedding, context], dim=1)
x, (f, h, w) = dit.patchify(x)
freqs = torch.cat([
dit.freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
dit.freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
dit.freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
], dim=-1).reshape(f * h * w, 1, -1).to(x.device)
# TeaCache
if tea_cache is not None:
tea_cache_update = tea_cache.check(dit, x, t_mod)
else:
tea_cache_update = False
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs)
return custom_forward
if tea_cache_update:
x = tea_cache.update(x)
else:
# blocks
for block_id, block in enumerate(dit.blocks):
if dit.training and use_gradient_checkpointing:
if use_gradient_checkpointing_offload:
with torch.autograd.graph.save_on_cpu():
x = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
x, context, t_mod, freqs,
use_reentrant=False,
)
else:
x = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
x, context, t_mod, freqs,
use_reentrant=False,
)
else:
x = block(x, context, t_mod, freqs)
if controlnet_res_stack is not None:
x = x + controlnet_res_stack[block_id]
if tea_cache is not None:
tea_cache.store(x)
x = dit.head(x, t)
x = dit.unpatchify(x, (f, h, w))
return x
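The `freqs` construction in `model_fn_wan_video` builds 3D rotary position embeddings by expanding three per-axis frequency tables over the full (frame, height, width) grid and concatenating them along the channel dimension. The shape bookkeeping can be sketched with plain nested lists (a tiny hypothetical grid; the real tables come from `dit.freqs` and the real code uses `view`/`expand`/`reshape`):

```python
# Hypothetical tiny grid with one frequency channel per axis for brevity.
f, h, w = 2, 2, 3
freq_f = [[10.0], [11.0]]
freq_h = [[20.0], [21.0]]
freq_w = [[30.0], [31.0], [32.0]]

# Broadcast each axis table over the whole grid, then concatenate channels:
# this mirrors view(...).expand(f, h, w, -1) followed by cat(dim=-1) and the
# final reshape to (f * h * w, 1, -1).
freqs = [
    freq_f[i] + freq_h[j] + freq_w[k]
    for i in range(f) for j in range(h) for k in range(w)
]
print(len(freqs), freqs[0], freqs[-1])
```

Each of the `f * h * w` patch tokens ends up with a frequency vector that encodes its frame, row, and column position simultaneously.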

View File

@@ -1,8 +1,9 @@
from .base_prompter import BasePrompter
from ..models.sd3_text_encoder import SD3TextEncoder1
from ..models.hunyuan_video_text_encoder import HunyuanVideoLLMEncoder
from transformers import CLIPTokenizer, LlamaTokenizerFast
from ..models.hunyuan_video_text_encoder import HunyuanVideoLLMEncoder, HunyuanVideoMLLMEncoder
from transformers import CLIPTokenizer, LlamaTokenizerFast, CLIPImageProcessor
import os, torch
from typing import Union
PROMPT_TEMPLATE_ENCODE = (
"<|start_header_id|>system<|end_header_id|>\n\nDescribe the image by detailing the color, shape, size, texture, "
@@ -18,6 +19,24 @@ PROMPT_TEMPLATE_ENCODE_VIDEO = (
"5. camera angles, movements, and transitions used in the video:<|eot_id|>"
"<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>")
PROMPT_TEMPLATE_ENCODE_I2V = (
"<|start_header_id|>system<|end_header_id|>\n\n<image>\nDescribe the image by detailing the color, shape, size, texture, "
"quantity, text, spatial relationships of the objects and background:<|eot_id|>"
"<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>"
"<|start_header_id|>assistant<|end_header_id|>\n\n"
)
PROMPT_TEMPLATE_ENCODE_VIDEO_I2V = (
"<|start_header_id|>system<|end_header_id|>\n\n<image>\nDescribe the video by detailing the following aspects according to the reference image: "
"1. The main content and theme of the video."
"2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects."
"3. Actions, events, behaviors temporal relationships, physical movement changes of the objects."
"4. background environment, light, style and atmosphere."
"5. camera angles, movements, and transitions used in the video:<|eot_id|>\n\n"
"<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>"
"<|start_header_id|>assistant<|end_header_id|>\n\n"
)
PROMPT_TEMPLATE = {
"dit-llm-encode": {
"template": PROMPT_TEMPLATE_ENCODE,
@@ -27,6 +46,22 @@ PROMPT_TEMPLATE = {
"template": PROMPT_TEMPLATE_ENCODE_VIDEO,
"crop_start": 95,
},
"dit-llm-encode-i2v": {
"template": PROMPT_TEMPLATE_ENCODE_I2V,
"crop_start": 36,
"image_emb_start": 5,
"image_emb_end": 581,
"image_emb_len": 576,
"double_return_token_id": 271
},
"dit-llm-encode-video-i2v": {
"template": PROMPT_TEMPLATE_ENCODE_VIDEO_I2V,
"crop_start": 103,
"image_emb_start": 5,
"image_emb_end": 581,
"image_emb_len": 576,
"double_return_token_id": 271
},
}
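The new i2v entries reserve a fixed block of 576 image-embedding positions inside the encoded sequence. The offsets are self-consistent (`image_emb_end - image_emb_start == image_emb_len`), and the slicing in `encode_prompt_using_mllm` depends on that; a quick check of the arithmetic, mirroring the constants above:

```python
# Mirrors the "dit-llm-encode-video-i2v" entry above.
template_cfg = {
    "crop_start": 103,
    "image_emb_start": 5,
    "image_emb_end": 581,
    "image_emb_len": 576,
}
# The reserved image block spans exactly image_emb_len positions.
assert template_cfg["image_emb_end"] - template_cfg["image_emb_start"] == template_cfg["image_emb_len"]

# Where the text portion starts in the LLM hidden states once the image block
# is spliced in (matches text_crop_start in encode_prompt_using_mllm).
text_crop_start = template_cfg["crop_start"] - 1 + template_cfg["image_emb_len"]
print(text_crop_start)  # 678
```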
NEGATIVE_PROMPT = "Aerial view, aerial view, overexposed, low quality, deformation, a poor composition, bad hands, bad teeth, bad eyes, bad limbs, distortion"
@@ -56,9 +91,20 @@ class HunyuanVideoPrompter(BasePrompter):
self.prompt_template = PROMPT_TEMPLATE['dit-llm-encode']
self.prompt_template_video = PROMPT_TEMPLATE['dit-llm-encode-video']
def fetch_models(self, text_encoder_1: SD3TextEncoder1 = None, text_encoder_2: HunyuanVideoLLMEncoder = None):
def fetch_models(self,
text_encoder_1: SD3TextEncoder1 = None,
text_encoder_2: Union[HunyuanVideoLLMEncoder, HunyuanVideoMLLMEncoder] = None):
self.text_encoder_1 = text_encoder_1
self.text_encoder_2 = text_encoder_2
if isinstance(text_encoder_2, HunyuanVideoMLLMEncoder):
# processor
# TODO: may need to replace processor with local implementation
base_path = os.path.dirname(os.path.dirname(__file__))
tokenizer_2_path = os.path.join(base_path, "tokenizer_configs/hunyuan_video/tokenizer_2")
self.processor = CLIPImageProcessor.from_pretrained(tokenizer_2_path)
# template
self.prompt_template = PROMPT_TEMPLATE['dit-llm-encode-i2v']
self.prompt_template_video = PROMPT_TEMPLATE['dit-llm-encode-video-i2v']
def apply_text_to_template(self, text, template):
assert isinstance(template, str)
@@ -107,8 +153,89 @@ class HunyuanVideoPrompter(BasePrompter):
return last_hidden_state, attention_mask
def encode_prompt_using_mllm(self,
prompt,
images,
max_length,
device,
crop_start,
hidden_state_skip_layer=2,
use_attention_mask=True,
image_embed_interleave=4):
image_outputs = self.processor(images, return_tensors="pt")["pixel_values"].to(device)
max_length += crop_start
inputs = self.tokenizer_2(prompt,
return_tensors="pt",
padding="max_length",
max_length=max_length,
truncation=True)
input_ids = inputs.input_ids.to(device)
attention_mask = inputs.attention_mask.to(device)
last_hidden_state = self.text_encoder_2(input_ids=input_ids,
attention_mask=attention_mask,
hidden_state_skip_layer=hidden_state_skip_layer,
pixel_values=image_outputs)
text_crop_start = (crop_start - 1 + self.prompt_template_video.get("image_emb_len", 576))
image_crop_start = self.prompt_template_video.get("image_emb_start", 5)
image_crop_end = self.prompt_template_video.get("image_emb_end", 581)
batch_indices, last_double_return_token_indices = torch.where(
input_ids == self.prompt_template_video.get("double_return_token_id", 271))
if last_double_return_token_indices.shape[0] == 3:
# in case the prompt is too long
last_double_return_token_indices = torch.cat((
last_double_return_token_indices,
torch.tensor([input_ids.shape[-1]]),
))
batch_indices = torch.cat((batch_indices, torch.tensor([0])))
last_double_return_token_indices = (last_double_return_token_indices.reshape(input_ids.shape[0], -1)[:, -1])
batch_indices = batch_indices.reshape(input_ids.shape[0], -1)[:, -1]
assistant_crop_start = (last_double_return_token_indices - 1 + self.prompt_template_video.get("image_emb_len", 576) - 4)
assistant_crop_end = (last_double_return_token_indices - 1 + self.prompt_template_video.get("image_emb_len", 576))
attention_mask_assistant_crop_start = (last_double_return_token_indices - 4)
attention_mask_assistant_crop_end = last_double_return_token_indices
text_last_hidden_state = []
text_attention_mask = []
image_last_hidden_state = []
image_attention_mask = []
for i in range(input_ids.shape[0]):
text_last_hidden_state.append(
torch.cat([
last_hidden_state[i, text_crop_start:assistant_crop_start[i].item()],
last_hidden_state[i, assistant_crop_end[i].item():],
]))
text_attention_mask.append(
torch.cat([
attention_mask[
i,
crop_start:attention_mask_assistant_crop_start[i].item(),
],
attention_mask[i, attention_mask_assistant_crop_end[i].item():],
]) if use_attention_mask else None)
image_last_hidden_state.append(last_hidden_state[i, image_crop_start:image_crop_end])
image_attention_mask.append(
torch.ones(image_last_hidden_state[-1].shape[0]).to(last_hidden_state.device).
to(attention_mask.dtype) if use_attention_mask else None)
text_last_hidden_state = torch.stack(text_last_hidden_state)
text_attention_mask = torch.stack(text_attention_mask)
image_last_hidden_state = torch.stack(image_last_hidden_state)
image_attention_mask = torch.stack(image_attention_mask)
image_last_hidden_state = image_last_hidden_state[:, ::image_embed_interleave, :]
image_attention_mask = image_attention_mask[:, ::image_embed_interleave]
assert (text_last_hidden_state.shape[0] == text_attention_mask.shape[0] and
image_last_hidden_state.shape[0] == image_attention_mask.shape[0])
last_hidden_state = torch.cat([image_last_hidden_state, text_last_hidden_state], dim=1)
attention_mask = torch.cat([image_attention_mask, text_attention_mask], dim=1)
return last_hidden_state, attention_mask
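`encode_prompt_using_mllm` splices the 576 image tokens back in front of the text tokens, but first downsamples them with `[:, ::image_embed_interleave, :]`. The resulting sequence length is easy to verify with plain slicing (toy token lists standing in for the batched hidden states):

```python
# Toy stand-ins: 576 image token positions and 40 text token positions.
image_tokens = list(range(576))
text_tokens = list(range(1000, 1040))
image_embed_interleave = 4

# Keep every 4th image token, then place the image block before the text block,
# mirroring the final torch.cat([...], dim=1).
kept = image_tokens[::image_embed_interleave]
sequence = kept + text_tokens
print(len(kept), len(sequence))  # 144 184
```

So with the default interleave of 4, the image contributes 144 positions to the conditioning sequence instead of 576, which keeps the context length manageable.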
def encode_prompt(self,
prompt,
images=None,
positive=True,
device="cuda",
clip_sequence_length=77,
@@ -116,7 +243,8 @@ class HunyuanVideoPrompter(BasePrompter):
data_type='video',
use_template=True,
hidden_state_skip_layer=2,
use_attention_mask=True):
use_attention_mask=True,
image_embed_interleave=4):
prompt = self.process_prompt(prompt, positive=positive)
@@ -136,8 +264,12 @@ class HunyuanVideoPrompter(BasePrompter):
pooled_prompt_emb = self.encode_prompt_using_clip(prompt, clip_sequence_length, device)
# LLM
prompt_emb, attention_mask = self.encode_prompt_using_llm(
prompt_formated, llm_sequence_length, device, crop_start,
hidden_state_skip_layer, use_attention_mask)
if images is None:
prompt_emb, attention_mask = self.encode_prompt_using_llm(prompt_formated, llm_sequence_length, device, crop_start,
hidden_state_skip_layer, use_attention_mask)
else:
prompt_emb, attention_mask = self.encode_prompt_using_mllm(prompt_formated, images, llm_sequence_length, device,
crop_start, hidden_state_skip_layer, use_attention_mask,
image_embed_interleave)
return prompt_emb, pooled_prompt_emb, attention_mask

View File

@@ -104,5 +104,6 @@ class WanPrompter(BasePrompter):
mask = mask.to(device)
seq_lens = mask.gt(0).sum(dim=1).long()
prompt_emb = self.text_encoder(ids, mask)
prompt_emb = [u[:v] for u, v in zip(prompt_emb, seq_lens)]
for i, v in enumerate(seq_lens):
prompt_emb[i, v:] = 0
return prompt_emb
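This hunk replaces per-sample trimming (which produced ragged sequence lengths) with zero-padding past each sample's true length, so every sample in the batch keeps the same shape. A sketch of the masking on a tiny nested-list "tensor":

```python
# Toy batch: 2 samples, max length 5, embedding dim flattened to scalars.
prompt_emb = [[1.0, 2.0, 3.0, 9.0, 9.0],
              [4.0, 5.0, 9.0, 9.0, 9.0]]
seq_lens = [3, 2]   # true (unpadded) lengths per sample

# Zero everything past each sample's own length instead of trimming it off.
for i, v in enumerate(seq_lens):
    for j in range(v, len(prompt_emb[i])):
        prompt_emb[i][j] = 0.0
print(prompt_emb)  # [[1.0, 2.0, 3.0, 0.0, 0.0], [4.0, 5.0, 0.0, 0.0, 0.0]]
```

Uniform shapes are what make the embeddings safe to shard, which matters for the tensor-parallel preview added in the same batch of commits.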

View File

@@ -37,7 +37,7 @@ class FlowMatchScheduler():
self.linear_timesteps_weights = bsmntw_weighing
def step(self, model_output, timestep, sample, to_final=False):
def step(self, model_output, timestep, sample, to_final=False, **kwargs):
if isinstance(timestep, torch.Tensor):
timestep = timestep.cpu()
timestep_id = torch.argmin((self.timesteps - timestep).abs())
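`step` resolves a possibly off-schedule timestep to the nearest entry in the discrete schedule via `argmin(|timesteps - t|)`. The same lookup in plain Python, with a toy descending schedule in place of the real flow-matching one:

```python
timesteps = [1000.0, 750.0, 500.0, 250.0, 0.0]  # toy descending schedule

def nearest_timestep_id(t, timesteps=timesteps):
    # Equivalent to torch.argmin((self.timesteps - timestep).abs()).
    return min(range(len(timesteps)), key=lambda i: abs(timesteps[i] - t))

print(nearest_timestep_id(740.0), nearest_timestep_id(130.0))  # 1 3
```

This makes the scheduler tolerant of callers that pass a continuous timestep value rather than an exact schedule entry.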

View File

@@ -0,0 +1,45 @@
{
"_valid_processor_keys": [
"images",
"do_resize",
"size",
"resample",
"do_center_crop",
"crop_size",
"do_rescale",
"rescale_factor",
"do_normalize",
"image_mean",
"image_std",
"do_convert_rgb",
"return_tensors",
"data_format",
"input_data_format"
],
"crop_size": {
"height": 336,
"width": 336
},
"do_center_crop": true,
"do_convert_rgb": true,
"do_normalize": true,
"do_rescale": true,
"do_resize": true,
"image_mean": [
0.48145466,
0.4578275,
0.40821073
],
"image_processor_type": "CLIPImageProcessor",
"image_std": [
0.26862954,
0.26130258,
0.27577711
],
"processor_class": "LlavaProcessor",
"resample": 3,
"rescale_factor": 0.00392156862745098,
"size": {
"shortest_edge": 336
}
}
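This processor config is the standard CLIP preprocessing pipeline: resize the shortest edge to 336, center-crop to 336x336, rescale by 1/255, then normalize with the CLIP mean/std. The per-pixel arithmetic, sketched without PIL or transformers:

```python
# Constants copied from the config above.
rescale_factor = 0.00392156862745098          # == 1 / 255
image_mean = [0.48145466, 0.4578275, 0.40821073]
image_std = [0.26862954, 0.26130258, 0.27577711]

def preprocess_pixel(rgb):
    # uint8 channel values -> rescale to [0, 1] -> per-channel CLIP normalization.
    return [(c * rescale_factor - m) / s
            for c, m, s in zip(rgb, image_mean, image_std)]

out = preprocess_pixel([255, 0, 128])
print([round(v, 3) for v in out])
```

The real `CLIPImageProcessor` applies the same rescale and normalize steps tensor-wide after resizing and cropping.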

View File

@@ -290,7 +290,7 @@ def launch_training_task(model, args):
name="diffsynth_studio",
config=swanlab_config,
mode=args.swanlab_mode,
logdir=args.output_path,
logdir=os.path.join(args.output_path, "swanlog"),
)
logger = [swanlab_logger]
else:

View File

@@ -6,7 +6,7 @@ We propose EliGen, a novel approach that leverages fine-grained entity-level inf
* Paper: [EliGen: Entity-Level Controlled Image Generation with Regional Attention](https://arxiv.org/abs/2501.01097)
* Github: [DiffSynth-Studio](https://github.com/modelscope/DiffSynth-Studio)
* Model: [ModelScope](https://www.modelscope.cn/models/DiffSynth-Studio/Eligen)
* Model: [ModelScope](https://www.modelscope.cn/models/DiffSynth-Studio/Eligen), [HuggingFace](https://huggingface.co/modelscope/EliGen)
* Online Demo: [ModelScope EliGen Studio](https://www.modelscope.cn/studios/DiffSynth-Studio/EliGen)
* Training Dataset: [EliGen Train Set](https://www.modelscope.cn/datasets/DiffSynth-Studio/EliGenTrainSet)
@@ -77,6 +77,11 @@ Demonstration of the styled entity control results with EliGen and IP-Adapter, s
|-|-|-|-|
|![image_1_base](https://github.com/user-attachments/assets/5e2dd3ab-37d3-4f58-8e02-ee2f9b238604)|![result1](https://github.com/user-attachments/assets/0f6711a2-572a-41b3-938a-95deff6d732d)|![result2](https://github.com/user-attachments/assets/ce2e66e5-1fdf-44e8-bca7-555d805a50b1)|![result3](https://github.com/user-attachments/assets/ad2da233-2f7c-4065-ab57-b2d84dc2c0e2)|
We also provide a demo of the styled entity control results with EliGen and specific styled lora, see [./styled_entity_control.py](./styled_entity_control.py) for details. Here is the visualization of EliGen with [Lego dreambooth lora](https://huggingface.co/merve/flux-lego-lora-dreambooth).
|![image_1_base](https://github.com/user-attachments/assets/35fb60f5-48ef-4f22-95d8-f9e732a5f63f)|![result1](https://github.com/user-attachments/assets/441d700f-f0b1-40e0-8848-4db23520972c)|![result2](https://github.com/user-attachments/assets/c8fd4498-3c55-48ab-9abf-3a092a90c878)|![result3](https://github.com/user-attachments/assets/181ba2bb-62cf-41a8-9e3a-20ed8a7a672f)|
|-|-|-|-|
|![image_1_base](https://github.com/user-attachments/assets/70a3f578-8c7e-4b40-954d-8fc94d4f3ae9)|![result1](https://github.com/user-attachments/assets/65670717-6136-4594-84e5-2307fc20753d)|![result2](https://github.com/user-attachments/assets/5ec7a5bd-f2c9-4b2e-8a4e-d2655ec8036c)|![result3](https://github.com/user-attachments/assets/56f00192-9553-45a6-a971-511b9f5b1480)|
### Entity Transfer
Demonstration of the entity transfer results with EliGen and In-Context LoRA, see [./entity_transfer.py](./entity_transfer.py) for generation prompts.

View File

@@ -27,11 +27,20 @@ def example(pipe, seeds, example_id, global_prompt, entity_prompts):
# download and load model
model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cuda", model_id_list=["FLUX.1-dev"])
# set download_from_modelscope = False if you want to download model from huggingface
download_from_modelscope = True
if download_from_modelscope:
model_id = "DiffSynth-Studio/Eligen"
downloading_priority = ["ModelScope"]
else:
model_id = "modelscope/EliGen"
downloading_priority = ["HuggingFace"]
model_manager.load_lora(
download_customized_models(
model_id="DiffSynth-Studio/Eligen",
model_id=model_id,
origin_file_path="model_bf16.safetensors",
local_dir="models/lora/entity_control"
local_dir="models/lora/entity_control",
downloading_priority=downloading_priority
),
lora_alpha=1
)

View File

@@ -0,0 +1,90 @@
from diffsynth import ModelManager, FluxImagePipeline, download_customized_models
from modelscope import dataset_snapshot_download
from examples.EntityControl.utils import visualize_masks
from PIL import Image
import torch
def example(pipe, seeds, example_id, global_prompt, entity_prompts):
dataset_snapshot_download(dataset_id="DiffSynth-Studio/examples_in_diffsynth", local_dir="./", allow_file_pattern=f"data/examples/eligen/entity_control/example_{example_id}/*.png")
masks = [Image.open(f"./data/examples/eligen/entity_control/example_{example_id}/{i}.png").convert('RGB') for i in range(len(entity_prompts))]
negative_prompt = "worst quality, low quality, monochrome, zombie, interlocked fingers, Aissist, cleavage, nsfw,"
for seed in seeds:
# generate image
image = pipe(
prompt=global_prompt,
cfg_scale=3.0,
negative_prompt=negative_prompt,
num_inference_steps=50,
embedded_guidance=3.5,
seed=seed,
height=1024,
width=1024,
eligen_entity_prompts=entity_prompts,
eligen_entity_masks=masks,
)
image.save(f"styled_eligen_example_{example_id}_{seed}.png")
visualize_masks(image, masks, entity_prompts, f"styled_entity_control_example_{example_id}_mask_{seed}.png")
# download and load model
model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cuda", model_id_list=["FLUX.1-dev"])
model_manager.load_lora(
download_customized_models(
model_id="FluxLora/merve-flux-lego-lora-dreambooth",
origin_file_path="pytorch_lora_weights.safetensors",
local_dir="models/lora/merve-flux-lego-lora-dreambooth"
),
lora_alpha=1
)
model_manager.load_lora(
download_customized_models(
model_id="DiffSynth-Studio/Eligen",
origin_file_path="model_bf16.safetensors",
local_dir="models/lora/entity_control"
),
lora_alpha=1
)
pipe = FluxImagePipeline.from_model_manager(model_manager)
# example 1
trigger_word = "lego set in style of TOK, "
global_prompt = "A breathtaking beauty of Raja Ampat by the late-night moonlight , one beautiful woman from behind wearing a pale blue long dress with soft glow, sitting at the top of a cliff looking towards the beach,pastell light colors, a group of small distant birds flying in far sky, a boat sailing on the sea, best quality, realistic, whimsical, fantastic, splash art, intricate detailed, hyperdetailed, maximalist style, photorealistic, concept art, sharp focus, harmony, serenity, tranquility, soft pastell colors,ambient occlusion, cozy ambient lighting, masterpiece, liiv1, linquivera, metix, mentixis, masterpiece, award winning, view from above\n"
global_prompt = trigger_word + global_prompt
entity_prompts = ["cliff", "sea", "moon", "sailing boat", "a seated beautiful woman", "pale blue long dress with soft glow"]
example(pipe, [0], 1, global_prompt, entity_prompts)
# example 2
global_prompt = "samurai girl wearing a kimono, she's holding a sword glowing with red flame, her long hair is flowing in the wind, she is looking at a small bird perched on the back of her hand. ultra realist style. maximum image detail. maximum realistic render."
global_prompt = trigger_word + global_prompt
entity_prompts = ["flowing hair", "sword glowing with red flame", "A cute bird", "blue belt"]
example(pipe, [0], 2, global_prompt, entity_prompts)
# example 3
global_prompt = "Image of a neverending staircase up to a mysterious palace in the sky, The ancient palace stood majestically atop a mist-shrouded mountain, sunrise, two traditional monk walk in the stair looking at the sunrise, fog,see-through, best quality, whimsical, fantastic, splash art, intricate detailed, hyperdetailed, photorealistic, concept art, harmony, serenity, tranquility, ambient occlusion, halation, cozy ambient lighting, dynamic lighting,masterpiece, liiv1, linquivera, metix, mentixis, masterpiece, award winning,"
global_prompt = trigger_word + global_prompt
entity_prompts = ["ancient palace", "stone staircase with railings", "a traditional monk", "a traditional monk"]
example(pipe, [27], 3, global_prompt, entity_prompts)
# example 4
global_prompt = "A beautiful girl wearing shirt and shorts in the street, holding a sign 'Entity Control'"
global_prompt = trigger_word + global_prompt
entity_prompts = ["A beautiful girl", "sign 'Entity Control'", "shorts", "shirt"]
example(pipe, [21], 4, global_prompt, entity_prompts)
# example 5
global_prompt = "A captivating, dramatic scene in a painting that exudes mystery and foreboding. A white sky, swirling blue clouds, and a crescent yellow moon illuminate a solitary woman standing near the water's edge. Her long dress flows in the wind, silhouetted against the eerie glow. The water mirrors the fiery sky and moonlight, amplifying the uneasy atmosphere."
global_prompt = trigger_word + global_prompt
entity_prompts = ["crescent yellow moon", "a solitary woman", "water", "swirling blue clouds"]
example(pipe, [0], 5, global_prompt, entity_prompts)
# example 6
global_prompt = "Snow White and the 6 Dwarfs."
global_prompt = trigger_word + global_prompt
entity_prompts = ["Dwarf 1", "Dwarf 2", "Dwarf 3", "Snow White", "Dwarf 4", "Dwarf 5", "Dwarf 6"]
example(pipe, [8], 6, global_prompt, entity_prompts)
# example 7, same prompt with different seeds
seeds = range(5, 9)
global_prompt = "A beautiful woman wearing white dress, holding a mirror, with a warm light background;"
global_prompt = trigger_word + global_prompt
entity_prompts = ["A beautiful woman", "mirror", "necklace", "glasses", "earring", "white dress", "jewelry headpiece"]
example(pipe, seeds, 7, global_prompt, entity_prompts)

View File

@@ -8,6 +8,12 @@
|24G|[hunyuanvideo_24G.py](hunyuanvideo_24G.py)|129|720*1280|The video is consistent with the original implementation, but it requires 5%~10% more time than [hunyuanvideo_80G.py](hunyuanvideo_80G.py)|
|6G|[hunyuanvideo_6G.py](hunyuanvideo_6G.py)|129|512*384|The base model doesn't support low resolutions. We recommend using a LoRA ([example](https://civitai.com/models/1032126/walking-animation-hunyuan-video)) trained at low resolutions.|
[HunyuanVideo-I2V](https://github.com/Tencent/HunyuanVideo-I2V) is the image-to-video generation version of HunyuanVideo. We also provide advanced VRAM management for this model.
|VRAM required|Example script|Frames|Resolution|Note|
|-|-|-|-|-|
|80G|[hunyuanvideo_i2v_80G.py](hunyuanvideo_i2v_80G.py)|129|720p|No VRAM management.|
|24G|[hunyuanvideo_i2v_24G.py](hunyuanvideo_i2v_24G.py)|129|720p|The video is consistent with the original implementation, but it requires 5%~10% more time than [hunyuanvideo_i2v_80G.py](hunyuanvideo_i2v_80G.py).|
## Gallery
Video generated by [hunyuanvideo_80G.py](hunyuanvideo_80G.py) and [hunyuanvideo_24G.py](hunyuanvideo_24G.py):
@@ -21,3 +27,7 @@ https://github.com/user-attachments/assets/2997f107-d02d-4ecb-89bb-5ce1a7f93817
Video to video generated by [hunyuanvideo_v2v_6G.py](./hunyuanvideo_v2v_6G.py) using [this LoRA](https://civitai.com/models/1032126/walking-animation-hunyuan-video):
https://github.com/user-attachments/assets/4b89e52e-ce42-434e-aa57-08f09dfa2b10
Video generated by [hunyuanvideo_i2v_80G.py](hunyuanvideo_i2v_80G.py) and [hunyuanvideo_i2v_24G.py](hunyuanvideo_i2v_24G.py):
https://github.com/user-attachments/assets/494f252a-c9af-440d-84ba-a8ddcdcc538a

View File

@@ -0,0 +1,43 @@
import torch
from diffsynth import ModelManager, HunyuanVideoPipeline, download_models, save_video
from modelscope import dataset_snapshot_download
from PIL import Image
download_models(["HunyuanVideoI2V"])
model_manager = ModelManager()
# The DiT model is loaded in bfloat16.
model_manager.load_models(
[
"models/HunyuanVideoI2V/transformers/mp_rank_00_model_states.pt"
],
torch_dtype=torch.bfloat16,
device="cpu"
)
# The other modules are loaded in float16.
model_manager.load_models(
[
"models/HunyuanVideoI2V/text_encoder/model.safetensors",
"models/HunyuanVideoI2V/text_encoder_2",
'models/HunyuanVideoI2V/vae/pytorch_model.pt'
],
torch_dtype=torch.float16,
device="cpu"
)
# The computation device is "cuda".
pipe = HunyuanVideoPipeline.from_model_manager(model_manager,
torch_dtype=torch.bfloat16,
device="cuda",
enable_vram_management=True)
dataset_snapshot_download(dataset_id="DiffSynth-Studio/examples_in_diffsynth",
local_dir="./",
allow_file_pattern="data/examples/hunyuanvideo/*")
i2v_resolution = "720p"
prompt = "An Asian man with short hair in black tactical uniform and white clothes waves a firework stick."
images = [Image.open("data/examples/hunyuanvideo/0.jpg").convert('RGB')]
video = pipe(prompt, input_images=images, num_inference_steps=50, seed=0, i2v_resolution=i2v_resolution)
save_video(video, f"video_{i2v_resolution}_low_vram.mp4", fps=30, quality=6)

View File

@@ -0,0 +1,45 @@
import torch
from diffsynth import ModelManager, HunyuanVideoPipeline, download_models, save_video
from modelscope import dataset_snapshot_download
from PIL import Image
download_models(["HunyuanVideoI2V"])
model_manager = ModelManager()
# The DiT model is loaded in bfloat16.
model_manager.load_models(
[
"models/HunyuanVideoI2V/transformers/mp_rank_00_model_states.pt"
],
torch_dtype=torch.bfloat16,
device="cuda"
)
# The other modules are loaded in float16.
model_manager.load_models(
[
"models/HunyuanVideoI2V/text_encoder/model.safetensors",
"models/HunyuanVideoI2V/text_encoder_2",
'models/HunyuanVideoI2V/vae/pytorch_model.pt'
],
torch_dtype=torch.float16,
device="cuda"
)
# The computation device is "cuda".
pipe = HunyuanVideoPipeline.from_model_manager(model_manager,
torch_dtype=torch.bfloat16,
device="cuda",
enable_vram_management=False)
# Even if you have enough VRAM, we still recommend enabling CPU offload.
pipe.enable_cpu_offload()
dataset_snapshot_download(dataset_id="DiffSynth-Studio/examples_in_diffsynth",
local_dir="./",
allow_file_pattern="data/examples/hunyuanvideo/*")
i2v_resolution = "720p"
prompt = "An Asian man with short hair in black tactical uniform and white clothes waves a firework stick."
images = [Image.open("data/examples/hunyuanvideo/0.jpg").convert('RGB')]
video = pipe(prompt, input_images=images, num_inference_steps=50, seed=0, i2v_resolution=i2v_resolution)
save_video(video, f"video_{i2v_resolution}.mp4", fps=30, quality=6)

View File

@@ -31,6 +31,8 @@ Put sunglasses on the dog.
https://github.com/user-attachments/assets/272808d7-fbeb-4747-a6df-14a0860c75fb
[TeaCache](https://github.com/ali-vilab/TeaCache) is supported in both the T2V and I2V models and can significantly improve inference efficiency. See [`./wan_1.3b_text_to_video_accelerate.py`](./wan_1.3b_text_to_video_accelerate.py).
### Wan-Video-14B-T2V
Wan-Video-14B-T2V is an enhanced version of Wan-Video-1.3B-T2V with a larger model size and stronger generation capability. To use this model, you need additional VRAM. We recommend that users adjust the `torch_dtype` and `num_persistent_param_in_dit` settings to find an optimal balance between speed and VRAM requirements. See [`./wan_14b_text_to_video.py`](./wan_14b_text_to_video.py).
@@ -47,6 +49,8 @@ We present a detailed table here. The model is tested on a single A100.
https://github.com/user-attachments/assets/3908bc64-d451-485a-8b61-28f6d32dd92f
The tensor parallel module for Wan-Video-14B-T2V is still under development. An example script is provided in [`./wan_14b_text_to_video_tensor_parallel.py`](./wan_14b_text_to_video_tensor_parallel.py).
### Wan-Video-14B-I2V
Wan-Video-14B-I2V adds the functionality of image-to-video based on Wan-Video-14B-T2V. The model size remains the same, therefore the speed and VRAM requirements are also consistent. See [`./wan_14b_image_to_video.py`](./wan_14b_image_to_video.py).
@@ -155,6 +159,12 @@ CUDA_VISIBLE_DEVICES="0" python examples/wanvideo/train_wan_t2v.py \
--use_gradient_checkpointing
```
If you wish to train the 14B model, please separate the safetensor files with a comma. For example: `models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00001-of-00006.safetensors,models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00002-of-00006.safetensors,models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00003-of-00006.safetensors,models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00004-of-00006.safetensors,models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00005-of-00006.safetensors,models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00006-of-00006.safetensors`.
If you wish to train the image-to-video model, please add an extra parameter `--image_encoder_path "models/Wan-AI/Wan2.1-I2V-14B-480P/models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth"`.
For LoRA training, the Wan-1.3B-T2V model requires 16G of VRAM for processing 81 frames at 480P, while the Wan-14B-T2V model requires 60G of VRAM for the same configuration. To further reduce VRAM requirements by 20%-30%, you can include the parameter `--use_gradient_checkpointing_offload`.
Step 5: Test
Test LoRA:

View File

@@ -7,13 +7,17 @@ from diffsynth import WanVideoPipeline, ModelManager, load_state_dict
from peft import LoraConfig, inject_adapter_in_model
import torchvision
from PIL import Image
import numpy as np
class TextVideoDataset(torch.utils.data.Dataset):
def __init__(self, base_path, metadata_path, max_num_frames=81, frame_interval=1, num_frames=81, height=480, width=832):
def __init__(self, base_path, metadata_path, max_num_frames=81, frame_interval=1, num_frames=81, height=480, width=832, is_i2v=False, target_fps=None):
metadata = pd.read_csv(metadata_path)
self.path = [os.path.join(base_path, "train", file_name) for file_name in metadata["file_name"]]
if os.path.exists(os.path.join(base_path, "train")):
self.path = [os.path.join(base_path, "train", file_name) for file_name in metadata["file_name"]]
else:
self.path = [os.path.join(base_path, file_name) for file_name in metadata["file_name"]]
self.text = metadata["text"].to_list()
self.max_num_frames = max_num_frames
@@ -21,6 +25,8 @@ class TextVideoDataset(torch.utils.data.Dataset):
self.num_frames = num_frames
self.height = height
self.width = width
self.is_i2v = is_i2v
self.target_fps = target_fps
self.frame_process = v2.Compose([
v2.CenterCrop(size=(height, width)),
@@ -48,10 +54,13 @@ class TextVideoDataset(torch.utils.data.Dataset):
return None
frames = []
first_frame = None
for frame_id in range(num_frames):
frame = reader.get_data(start_frame_id + frame_id * interval)
frame = Image.fromarray(frame)
frame = self.crop_and_resize(frame)
if first_frame is None:
first_frame = np.array(frame)
frame = frame_process(frame)
frames.append(frame)
reader.close()
@@ -59,18 +68,28 @@ class TextVideoDataset(torch.utils.data.Dataset):
frames = torch.stack(frames, dim=0)
frames = rearrange(frames, "T C H W -> C T H W")
return frames
if self.is_i2v:
return frames, first_frame
else:
return frames
def load_video(self, file_path):
start_frame_id = torch.randint(0, self.max_num_frames - (self.num_frames - 1) * self.frame_interval, (1,))[0]
frames = self.load_frames_using_imageio(file_path, self.max_num_frames, start_frame_id, self.frame_interval, self.num_frames, self.frame_process)
start_frame_id = 0
if self.target_fps is None:
frame_interval = self.frame_interval
else:
reader = imageio.get_reader(file_path)
fps = reader.get_meta_data()["fps"]
reader.close()
frame_interval = max(round(fps / self.target_fps), 1)
frames = self.load_frames_using_imageio(file_path, self.max_num_frames, start_frame_id, frame_interval, self.num_frames, self.frame_process)
return frames
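The new `target_fps` path converts a desired output frame rate into a sampling stride: a source at `fps` is subsampled every `round(fps / target_fps)` frames, clamped to at least 1 so a slower-than-target source is still read frame by frame. For example:

```python
def frame_interval_for(fps, target_fps):
    # Stride through source frames so sampled frames approximate target_fps,
    # mirroring max(round(fps / self.target_fps), 1) in load_video.
    return max(round(fps / target_fps), 1)

print(frame_interval_for(30, 15),   # take every 2nd frame
      frame_interval_for(24, 16),   # round(1.5) -> 2 under Python's half-to-even rounding
      frame_interval_for(12, 24))   # source slower than target: clamp to 1
```

Note that `round` uses half-to-even rounding, so exact `.5` ratios resolve to the nearest even stride rather than always rounding up.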
def is_image(self, file_path):
file_ext_name = file_path.split(".")[-1]
if file_ext_name.lower() in ["jpg", "jpeg", "png", "webp"]:
return True
return False
@@ -78,6 +97,7 @@ class TextVideoDataset(torch.utils.data.Dataset):
def load_image(self, file_path):
frame = Image.open(file_path).convert("RGB")
frame = self.crop_and_resize(frame)
first_frame = frame
frame = self.frame_process(frame)
frame = rearrange(frame, "C H W -> C 1 H W")
return frame
@@ -86,11 +106,20 @@ class TextVideoDataset(torch.utils.data.Dataset):
def __getitem__(self, data_id):
text = self.text[data_id]
path = self.path[data_id]
try:
if self.is_image(path):
if self.is_i2v:
raise ValueError(f"{path} is not a video. I2V model doesn't support image-to-image training.")
video = self.load_image(path)
else:
video = self.load_video(path)
if self.is_i2v:
video, first_frame = video
data = {"text": text, "video": video, "path": path, "first_frame": first_frame}
else:
data = {"text": text, "video": video, "path": path}
except Exception:
data = None
return data
@@ -100,43 +129,82 @@ class TextVideoDataset(torch.utils.data.Dataset):
class LightningModelForDataProcess(pl.LightningModule):
def __init__(self, text_encoder_path, vae_path, image_encoder_path=None, tiled=False, tile_size=(34, 34), tile_stride=(18, 16), redirected_tensor_path=None):
super().__init__()
model_path = [text_encoder_path, vae_path]
if image_encoder_path is not None:
model_path.append(image_encoder_path)
model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu")
model_manager.load_models(model_path)
self.pipe = WanVideoPipeline.from_model_manager(model_manager)
self.tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride}
self.redirected_tensor_path = redirected_tensor_path
def test_step(self, batch, batch_idx):
data = batch[0]
if data is None or data["video"] is None:
return
text, video, path = data["text"], data["video"].unsqueeze(0), data["path"]
self.pipe.device = self.device
if video is not None:
# prompt
prompt_emb = self.pipe.encode_prompt(text)
# video
video = video.to(dtype=self.pipe.torch_dtype, device=self.pipe.device)
latents = self.pipe.encode_video(video, **self.tiler_kwargs)[0]
# image
if "first_frame" in data:
first_frame = Image.fromarray(data["first_frame"])
_, _, num_frames, height, width = video.shape
image_emb = self.pipe.encode_image(first_frame, num_frames, height, width)
else:
image_emb = {}
data = {"latents": latents, "prompt_emb": prompt_emb, "image_emb": image_emb}
if self.redirected_tensor_path is not None:
path = path.replace("/", "_").replace("\\", "_")
path = os.path.join(self.redirected_tensor_path, path)
torch.save(data, path + ".tensors.pth")
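The redirected cache flattens a video's full path into a single file name by replacing path separators, so tensors from nested dataset folders land in one directory without name collisions. The mapping in isolation (the function name `redirect_cache_path` is illustrative):

```python
import os

def redirect_cache_path(video_path, redirected_root):
    # Flatten both POSIX and Windows separators into underscores,
    # so "data/train/clip.mp4" becomes "data_train_clip.mp4.tensors.pth".
    flat = video_path.replace("/", "_").replace("\\", "_")
    return os.path.join(redirected_root, flat + ".tensors.pth")
```

The same transformation is applied again at load time in `TensorDataset`, so the cache directory can live anywhere, detached from the dataset tree.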
class TensorDataset(torch.utils.data.Dataset):
def __init__(self, base_path, metadata_path=None, steps_per_epoch=1000, redirected_tensor_path=None):
if metadata_path is not None and os.path.exists(metadata_path):
metadata = pd.read_csv(metadata_path)
self.path = [os.path.join(base_path, "train", file_name) for file_name in metadata["file_name"]]
print(len(self.path), "videos in metadata.")
if redirected_tensor_path is None:
self.path = [i + ".tensors.pth" for i in self.path if os.path.exists(i + ".tensors.pth")]
else:
cached_path = []
for path in self.path:
path = path.replace("/", "_").replace("\\", "_")
path = os.path.join(redirected_tensor_path, path)
if os.path.exists(path + ".tensors.pth"):
cached_path.append(path + ".tensors.pth")
self.path = cached_path
else:
print("Cannot find metadata.csv. Trying to search for tensor files.")
self.path = [os.path.join(base_path, i) for i in os.listdir(base_path) if i.endswith(".tensors.pth")]
print(len(self.path), "cached tensor files found.")
assert len(self.path) > 0
self.steps_per_epoch = steps_per_epoch
self.redirected_tensor_path = redirected_tensor_path
def __getitem__(self, index):
while True:
try:
data_id = torch.randint(0, len(self.path), (1,))[0]
data_id = (data_id + index) % len(self.path) # For fixed seed.
path = self.path[data_id]
data = torch.load(path, weights_only=True, map_location="cpu")
return data
except Exception:
continue
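The index scheme above decouples epoch length from dataset size: a random base offset plus the dataloader step index, wrapped modulo the number of cached files, so `steps_per_epoch` can exceed the file count while a fixed seed still reproduces the same sequence. In plain Python (using `random` in place of `torch.randint`, purely for illustration):

```python
import random

def sample_ids(num_steps, num_files, seed):
    # Reproducible file indices for one epoch: random base offset plus
    # step index, wrapped modulo the number of cached tensor files.
    rng = random.Random(seed)
    return [(rng.randrange(num_files) + step) % num_files for step in range(num_steps)]
```

Under the same seed the epoch order is identical run to run, which is what the "For fixed seed." comment in the script is relying on.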
def __len__(self):
@@ -145,10 +213,21 @@ class TensorDataset(torch.utils.data.Dataset):
class LightningModelForTrain(pl.LightningModule):
def __init__(
self,
dit_path,
learning_rate=1e-5,
lora_rank=4, lora_alpha=4, train_architecture="lora", lora_target_modules="q,k,v,o,ffn.0,ffn.2", init_lora_weights="kaiming",
use_gradient_checkpointing=True, use_gradient_checkpointing_offload=False,
pretrained_lora_path=None
):
super().__init__()
model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu")
if os.path.isfile(dit_path):
model_manager.load_models([dit_path])
else:
dit_path = dit_path.split(",")
model_manager.load_models([dit_path])
self.pipe = WanVideoPipeline.from_model_manager(model_manager)
self.pipe.scheduler.set_timesteps(1000, training=True)
@@ -167,6 +246,7 @@ class LightningModelForTrain(pl.LightningModule):
self.learning_rate = learning_rate
self.use_gradient_checkpointing = use_gradient_checkpointing
self.use_gradient_checkpointing_offload = use_gradient_checkpointing_offload
def freeze_parameters(self):
@@ -210,24 +290,30 @@ class LightningModelForTrain(pl.LightningModule):
# Data
latents = batch["latents"].to(self.device)
prompt_emb = batch["prompt_emb"]
prompt_emb["context"] = prompt_emb["context"][0].to(self.device)
image_emb = batch["image_emb"]
if "clip_feature" in image_emb:
image_emb["clip_feature"] = image_emb["clip_feature"][0].to(self.device)
if "y" in image_emb:
image_emb["y"] = image_emb["y"][0].to(self.device)
# Loss
self.pipe.device = self.device
noise = torch.randn_like(latents)
timestep_id = torch.randint(0, self.pipe.scheduler.num_train_timesteps, (1,))
timestep = self.pipe.scheduler.timesteps[timestep_id].to(dtype=self.pipe.torch_dtype, device=self.pipe.device)
extra_input = self.pipe.prepare_extra_input(latents)
noisy_latents = self.pipe.scheduler.add_noise(latents, noise, timestep)
training_target = self.pipe.scheduler.training_target(latents, noise, timestep)
# Compute loss
noise_pred = self.pipe.denoising_model()(
noisy_latents, timestep=timestep, **prompt_emb, **extra_input, **image_emb,
use_gradient_checkpointing=self.use_gradient_checkpointing,
use_gradient_checkpointing_offload=self.use_gradient_checkpointing_offload
)
loss = torch.nn.functional.mse_loss(noise_pred.float(), training_target.float())
loss = loss * self.pipe.scheduler.training_weight(timestep)
# Record log
self.log("train_loss", loss, prog_bar=True)
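The training step above delegates noising and target construction to the pipeline's scheduler. As a sketch, assuming a rectified-flow formulation (which is an assumption on our part; the actual `add_noise` / `training_target` details live in the DiffSynth scheduler), the step looks like:

```python
import torch

def flow_matching_step(latents, model, t, noise=None):
    # Rectified-flow style objective: interpolate toward noise, regress the
    # velocity (noise - latents). The `noise` argument is exposed for testing.
    if noise is None:
        noise = torch.randn_like(latents)
    noisy = (1 - t) * latents + t * noise   # analogue of scheduler.add_noise
    target = noise - latents                # analogue of scheduler.training_target
    pred = model(noisy, t)
    return torch.nn.functional.mse_loss(pred.float(), target.float())
```

A model that predicts the target exactly drives this loss to zero; the script additionally multiplies by a timestep-dependent `training_weight`.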
@@ -276,12 +362,30 @@ def parse_args():
default="./",
help="Path to save the model.",
)
parser.add_argument(
"--metadata_path",
type=str,
default=None,
help="Path to metadata.csv.",
)
parser.add_argument(
"--redirected_tensor_path",
type=str,
default=None,
help="Path to save cached tensors.",
)
parser.add_argument(
"--text_encoder_path",
type=str,
default=None,
help="Path of text encoder.",
)
parser.add_argument(
"--image_encoder_path",
type=str,
default=None,
help="Path of image encoder.",
)
parser.add_argument(
"--vae_path",
type=str,
@@ -336,6 +440,12 @@ def parse_args():
default=81,
help="Number of frames.",
)
parser.add_argument(
"--target_fps",
type=int,
default=None,
help="Expected FPS for sampling frames.",
)
parser.add_argument(
"--height",
type=int,
@@ -410,6 +520,12 @@ def parse_args():
action="store_true",
help="Whether to use gradient checkpointing.",
)
parser.add_argument(
"--use_gradient_checkpointing_offload",
default=False,
action="store_true",
help="Whether to use gradient checkpointing offload.",
)
parser.add_argument(
"--train_architecture",
type=str,
@@ -441,25 +557,30 @@ def parse_args():
def data_process(args):
dataset = TextVideoDataset(
args.dataset_path,
os.path.join(args.dataset_path, "metadata.csv") if args.metadata_path is None else args.metadata_path,
max_num_frames=args.num_frames,
frame_interval=1,
num_frames=args.num_frames,
height=args.height,
width=args.width,
is_i2v=args.image_encoder_path is not None,
target_fps=args.target_fps,
)
dataloader = torch.utils.data.DataLoader(
dataset,
shuffle=False,
batch_size=1,
num_workers=args.dataloader_num_workers,
collate_fn=lambda x: x,
)
model = LightningModelForDataProcess(
text_encoder_path=args.text_encoder_path,
image_encoder_path=args.image_encoder_path,
vae_path=args.vae_path,
tiled=args.tiled,
tile_size=(args.tile_size_height, args.tile_size_width),
tile_stride=(args.tile_stride_height, args.tile_stride_width),
redirected_tensor_path=args.redirected_tensor_path,
)
trainer = pl.Trainer(
accelerator="gpu",
@@ -472,8 +593,9 @@ def data_process(args):
def train(args):
dataset = TensorDataset(
args.dataset_path,
os.path.join(args.dataset_path, "metadata.csv") if args.metadata_path is None else args.metadata_path,
steps_per_epoch=args.steps_per_epoch,
redirected_tensor_path=args.redirected_tensor_path,
)
dataloader = torch.utils.data.DataLoader(
dataset,
@@ -490,6 +612,7 @@ def train(args):
lora_target_modules=args.lora_target_modules,
init_lora_weights=args.init_lora_weights,
use_gradient_checkpointing=args.use_gradient_checkpointing,
use_gradient_checkpointing_offload=args.use_gradient_checkpointing_offload,
pretrained_lora_path=args.pretrained_lora_path,
)
if args.use_swanlab:
@@ -501,7 +624,7 @@ def train(args):
name="wan",
config=swanlab_config,
mode=args.swanlab_mode,
logdir=os.path.join(args.output_path, "swanlog"),
)
logger = [swanlab_logger]
else:
@@ -510,6 +633,7 @@ def train(args):
max_epochs=args.max_epochs,
accelerator="gpu",
devices="auto",
precision="bf16",
strategy=args.training_strategy,
default_root_dir=args.output_path,
accumulate_grad_batches=args.accumulate_grad_batches,


@@ -0,0 +1,626 @@
import torch, os, imageio, argparse
from torchvision.transforms import v2
from einops import rearrange
import lightning as pl
import pandas as pd
from diffsynth import WanVideoPipeline, ModelManager, load_state_dict
from peft import LoraConfig, inject_adapter_in_model
import torchvision
from PIL import Image
import numpy as np
from diffsynth.models.wan_video_controlnet import WanControlNetModel
from diffsynth.pipelines.wan_video import model_fn_wan_video
class TextVideoDataset(torch.utils.data.Dataset):
def __init__(self, base_path, metadata_path, max_num_frames=81, frame_interval=1, num_frames=81, height=480, width=832, is_i2v=False, target_fps=None):
metadata = pd.read_csv(metadata_path)
self.path = [os.path.join(base_path, "train", file_name) for file_name in metadata["file_name"]]
self.controlnet_path = [os.path.join(base_path, file_name) for file_name in metadata["controlnet_file_name"]]
self.text = metadata["text"].to_list()
self.max_num_frames = max_num_frames
self.frame_interval = frame_interval
self.num_frames = num_frames
self.height = height
self.width = width
self.is_i2v = is_i2v
self.target_fps = target_fps
self.frame_process = v2.Compose([
v2.CenterCrop(size=(height, width)),
v2.Resize(size=(height, width), antialias=True),
v2.ToTensor(),
v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])
def crop_and_resize(self, image):
width, height = image.size
scale = max(self.width / width, self.height / height)
image = torchvision.transforms.functional.resize(
image,
(round(height*scale), round(width*scale)),
interpolation=torchvision.transforms.InterpolationMode.BILINEAR
)
return image
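`crop_and_resize` scales each frame just enough that both dimensions cover the target, and the `CenterCrop` in `frame_process` then trims the overflow. The geometry alone, as a small sketch (the helper name `cover_size` is ours):

```python
def cover_size(src_w, src_h, dst_w, dst_h):
    # Scale so BOTH dimensions reach at least the target size;
    # a subsequent center crop removes whatever overshoots.
    scale = max(dst_w / src_w, dst_h / src_h)
    return round(src_w * scale), round(src_h * scale)
```

For example, a 1920x1080 source targeted at 832x480 is resized to 853x480 and then center-cropped to 832x480, discarding roughly 10 columns on each side rather than distorting the aspect ratio.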
def load_frames_using_imageio(self, file_path, max_num_frames, start_frame_id, interval, num_frames, frame_process):
reader = imageio.get_reader(file_path)
if reader.count_frames() < max_num_frames or reader.count_frames() - 1 < start_frame_id + (num_frames - 1) * interval:
reader.close()
return None
frames = []
first_frame = None
for frame_id in range(num_frames):
frame = reader.get_data(start_frame_id + frame_id * interval)
frame = Image.fromarray(frame)
frame = self.crop_and_resize(frame)
if first_frame is None:
first_frame = np.array(frame)
frame = frame_process(frame)
frames.append(frame)
reader.close()
frames = torch.stack(frames, dim=0)
frames = rearrange(frames, "T C H W -> C T H W")
if self.is_i2v:
return frames, first_frame
else:
return frames
def load_video(self, file_path):
start_frame_id = 0
if self.target_fps is None:
frame_interval = self.frame_interval
else:
reader = imageio.get_reader(file_path)
fps = reader.get_meta_data()["fps"]
reader.close()
frame_interval = max(round(fps / self.target_fps), 1)
frames = self.load_frames_using_imageio(file_path, self.max_num_frames, start_frame_id, frame_interval, self.num_frames, self.frame_process)
return frames
def is_image(self, file_path):
file_ext_name = file_path.split(".")[-1]
if file_ext_name.lower() in ["jpg", "jpeg", "png", "webp"]:
return True
return False
def load_image(self, file_path):
frame = Image.open(file_path).convert("RGB")
frame = self.crop_and_resize(frame)
frame = self.frame_process(frame)
frame = rearrange(frame, "C H W -> C 1 H W")
return frame
def __getitem__(self, data_id):
text = self.text[data_id]
path = self.path[data_id]
controlnet_path = self.controlnet_path[data_id]
try:
if self.is_image(path):
if self.is_i2v:
raise ValueError(f"{path} is not a video. I2V model doesn't support image-to-image training.")
video = self.load_image(path)
else:
video = self.load_video(path)
controlnet_frames = self.load_video(controlnet_path)
if self.is_i2v:
video, first_frame = video
data = {"text": text, "video": video, "path": path, "first_frame": first_frame}
else:
data = {"text": text, "video": video, "path": path, "controlnet_frames": controlnet_frames}
except Exception:
data = None
return data
def __len__(self):
return len(self.path)
class LightningModelForDataProcess(pl.LightningModule):
def __init__(self, text_encoder_path, vae_path, image_encoder_path=None, tiled=False, tile_size=(34, 34), tile_stride=(18, 16), redirected_tensor_path=None):
super().__init__()
model_path = [text_encoder_path, vae_path]
if image_encoder_path is not None:
model_path.append(image_encoder_path)
model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu")
model_manager.load_models(model_path)
self.pipe = WanVideoPipeline.from_model_manager(model_manager)
self.tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride}
self.redirected_tensor_path = redirected_tensor_path
def test_step(self, batch, batch_idx):
data = batch[0]
if data is None or data["video"] is None:
return
text, video, path = data["text"], data["video"].unsqueeze(0), data["path"]
controlnet_frames = data["controlnet_frames"].unsqueeze(0)
self.pipe.device = self.device
if video is not None:
# prompt
prompt_emb = self.pipe.encode_prompt(text)
# video
video = video.to(dtype=self.pipe.torch_dtype, device=self.pipe.device)
latents = self.pipe.encode_video(video, **self.tiler_kwargs)[0]
# ControlNet video
controlnet_frames = controlnet_frames.to(dtype=self.pipe.torch_dtype, device=self.pipe.device)
controlnet_kwargs = self.pipe.prepare_controlnet(controlnet_frames, **self.tiler_kwargs)
controlnet_kwargs["controlnet_conditioning"] = controlnet_kwargs["controlnet_conditioning"][0]
# image
if "first_frame" in data:
first_frame = Image.fromarray(data["first_frame"])
_, _, num_frames, height, width = video.shape
image_emb = self.pipe.encode_image(first_frame, num_frames, height, width)
else:
image_emb = {}
data = {"latents": latents, "prompt_emb": prompt_emb, "image_emb": image_emb, "controlnet_kwargs": controlnet_kwargs}
if self.redirected_tensor_path is not None:
path = path.replace("/", "_").replace("\\", "_")
path = os.path.join(self.redirected_tensor_path, path)
torch.save(data, path + ".tensors.pth")
class TensorDataset(torch.utils.data.Dataset):
def __init__(self, base_path, metadata_path=None, steps_per_epoch=1000, redirected_tensor_path=None):
if metadata_path is not None and os.path.exists(metadata_path):
metadata = pd.read_csv(metadata_path)
self.path = [os.path.join(base_path, "train", file_name) for file_name in metadata["file_name"]]
print(len(self.path), "videos in metadata.")
if redirected_tensor_path is None:
self.path = [i + ".tensors.pth" for i in self.path if os.path.exists(i + ".tensors.pth")]
else:
cached_path = []
for path in self.path:
path = path.replace("/", "_").replace("\\", "_")
path = os.path.join(redirected_tensor_path, path)
if os.path.exists(path + ".tensors.pth"):
cached_path.append(path + ".tensors.pth")
self.path = cached_path
else:
print("Cannot find metadata.csv. Trying to search for tensor files.")
self.path = [os.path.join(base_path, i) for i in os.listdir(base_path) if i.endswith(".tensors.pth")]
print(len(self.path), "cached tensor files found.")
assert len(self.path) > 0
self.steps_per_epoch = steps_per_epoch
self.redirected_tensor_path = redirected_tensor_path
def __getitem__(self, index):
while True:
try:
data_id = torch.randint(0, len(self.path), (1,))[0]
data_id = (data_id + index) % len(self.path) # For fixed seed.
path = self.path[data_id]
data = torch.load(path, weights_only=True, map_location="cpu")
return data
except Exception:
continue
def __len__(self):
return self.steps_per_epoch
class LightningModelForTrain(pl.LightningModule):
def __init__(
self,
dit_path,
learning_rate=1e-5,
lora_rank=4, lora_alpha=4, train_architecture="lora", lora_target_modules="q,k,v,o,ffn.0,ffn.2", init_lora_weights="kaiming",
use_gradient_checkpointing=True, use_gradient_checkpointing_offload=False,
pretrained_lora_path=None
):
super().__init__()
model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu")
if os.path.isfile(dit_path):
model_manager.load_models([dit_path])
else:
dit_path = dit_path.split(",")
model_manager.load_models([dit_path])
self.pipe = WanVideoPipeline.from_model_manager(model_manager)
self.pipe.scheduler.set_timesteps(1000, training=True)
self.freeze_parameters()
state_dict = load_state_dict(dit_path, torch_dtype=torch.bfloat16)
state_dict, config = WanControlNetModel.state_dict_converter().from_base_model(state_dict)
self.pipe.controlnet = WanControlNetModel(**config).to(torch.bfloat16)
self.pipe.controlnet.load_state_dict(state_dict)
self.pipe.controlnet.train()
self.pipe.controlnet.requires_grad_(True)
self.learning_rate = learning_rate
self.use_gradient_checkpointing = use_gradient_checkpointing
self.use_gradient_checkpointing_offload = use_gradient_checkpointing_offload
def freeze_parameters(self):
# Freeze parameters
self.pipe.requires_grad_(False)
self.pipe.eval()
self.pipe.denoising_model().train()
def training_step(self, batch, batch_idx):
# Data
latents = batch["latents"].to(self.device)
controlnet_kwargs = batch["controlnet_kwargs"]
controlnet_kwargs["controlnet_conditioning"] = controlnet_kwargs["controlnet_conditioning"].to(self.device)
prompt_emb = batch["prompt_emb"]
prompt_emb["context"] = prompt_emb["context"][0].to(self.device)
image_emb = batch["image_emb"]
if "clip_feature" in image_emb:
image_emb["clip_feature"] = image_emb["clip_feature"][0].to(self.device)
if "y" in image_emb:
image_emb["y"] = image_emb["y"][0].to(self.device)
# Loss
self.pipe.device = self.device
noise = torch.randn_like(latents)
timestep_id = torch.randint(0, self.pipe.scheduler.num_train_timesteps, (1,))
timestep = self.pipe.scheduler.timesteps[timestep_id].to(dtype=self.pipe.torch_dtype, device=self.pipe.device)
extra_input = self.pipe.prepare_extra_input(latents)
noisy_latents = self.pipe.scheduler.add_noise(latents, noise, timestep)
training_target = self.pipe.scheduler.training_target(latents, noise, timestep)
# Compute loss
noise_pred = model_fn_wan_video(
dit=self.pipe.dit, controlnet=self.pipe.controlnet,
x=noisy_latents, timestep=timestep, **prompt_emb, **extra_input, **image_emb, **controlnet_kwargs,
use_gradient_checkpointing=self.use_gradient_checkpointing,
use_gradient_checkpointing_offload=self.use_gradient_checkpointing_offload
)
loss = torch.nn.functional.mse_loss(noise_pred.float(), training_target.float())
loss = loss * self.pipe.scheduler.training_weight(timestep)
# Record log
self.log("train_loss", loss, prog_bar=True)
return loss
def configure_optimizers(self):
trainable_modules = filter(lambda p: p.requires_grad, self.pipe.controlnet.parameters())
optimizer = torch.optim.AdamW(trainable_modules, lr=self.learning_rate)
return optimizer
def on_save_checkpoint(self, checkpoint):
checkpoint.clear()
trainable_param_names = list(filter(lambda named_param: named_param[1].requires_grad, self.pipe.controlnet.named_parameters()))
trainable_param_names = set([named_param[0] for named_param in trainable_param_names])
state_dict = self.pipe.controlnet.state_dict()
lora_state_dict = {}
for name, param in state_dict.items():
if name in trainable_param_names:
lora_state_dict[name] = param
checkpoint.update(lora_state_dict)
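`on_save_checkpoint` clears the Lightning checkpoint and refills it with only the tensors whose parameters require gradients, so frozen base weights never hit disk. The filtering pattern in isolation, with a toy module standing in for the ControlNet:

```python
import torch

def trainable_state_dict(module):
    # Keep only state-dict entries whose parameters require gradients,
    # matched by name against named_parameters().
    trainable = {name for name, p in module.named_parameters() if p.requires_grad}
    return {name: t for name, t in module.state_dict().items() if name in trainable}
```

With a partially frozen module, only the unfrozen layers survive the filter, which keeps checkpoints small when most of the network is frozen.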
def parse_args():
parser = argparse.ArgumentParser(description="Simple example of a training script.")
parser.add_argument(
"--task",
type=str,
default="data_process",
required=True,
choices=["data_process", "train"],
help="Task. `data_process` or `train`.",
)
parser.add_argument(
"--dataset_path",
type=str,
default=None,
required=True,
help="The path of the dataset.",
)
parser.add_argument(
"--output_path",
type=str,
default="./",
help="Path to save the model.",
)
parser.add_argument(
"--metadata_path",
type=str,
default=None,
help="Path to metadata.csv.",
)
parser.add_argument(
"--redirected_tensor_path",
type=str,
default=None,
help="Path to save cached tensors.",
)
parser.add_argument(
"--text_encoder_path",
type=str,
default=None,
help="Path of text encoder.",
)
parser.add_argument(
"--image_encoder_path",
type=str,
default=None,
help="Path of image encoder.",
)
parser.add_argument(
"--vae_path",
type=str,
default=None,
help="Path of VAE.",
)
parser.add_argument(
"--dit_path",
type=str,
default=None,
help="Path of DiT.",
)
parser.add_argument(
"--tiled",
default=False,
action="store_true",
help="Whether to enable tiled encoding in the VAE. This option reduces VRAM usage.",
)
parser.add_argument(
"--tile_size_height",
type=int,
default=34,
help="Tile size (height) in VAE.",
)
parser.add_argument(
"--tile_size_width",
type=int,
default=34,
help="Tile size (width) in VAE.",
)
parser.add_argument(
"--tile_stride_height",
type=int,
default=18,
help="Tile stride (height) in VAE.",
)
parser.add_argument(
"--tile_stride_width",
type=int,
default=16,
help="Tile stride (width) in VAE.",
)
parser.add_argument(
"--steps_per_epoch",
type=int,
default=500,
help="Number of steps per epoch.",
)
parser.add_argument(
"--num_frames",
type=int,
default=81,
help="Number of frames.",
)
parser.add_argument(
"--target_fps",
type=int,
default=None,
help="Expected FPS for sampling frames.",
)
parser.add_argument(
"--height",
type=int,
default=480,
help="Image height.",
)
parser.add_argument(
"--width",
type=int,
default=832,
help="Image width.",
)
parser.add_argument(
"--dataloader_num_workers",
type=int,
default=1,
help="Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process.",
)
parser.add_argument(
"--learning_rate",
type=float,
default=1e-5,
help="Learning rate.",
)
parser.add_argument(
"--accumulate_grad_batches",
type=int,
default=1,
help="The number of batches in gradient accumulation.",
)
parser.add_argument(
"--max_epochs",
type=int,
default=1,
help="Number of epochs.",
)
parser.add_argument(
"--lora_target_modules",
type=str,
default="q,k,v,o,ffn.0,ffn.2",
help="Layers with LoRA modules.",
)
parser.add_argument(
"--init_lora_weights",
type=str,
default="kaiming",
choices=["gaussian", "kaiming"],
help="The initializing method of LoRA weight.",
)
parser.add_argument(
"--training_strategy",
type=str,
default="auto",
choices=["auto", "deepspeed_stage_1", "deepspeed_stage_2", "deepspeed_stage_3"],
help="Training strategy",
)
parser.add_argument(
"--lora_rank",
type=int,
default=4,
help="The dimension of the LoRA update matrices.",
)
parser.add_argument(
"--lora_alpha",
type=float,
default=4.0,
help="The weight of the LoRA update matrices.",
)
parser.add_argument(
"--use_gradient_checkpointing",
default=False,
action="store_true",
help="Whether to use gradient checkpointing.",
)
parser.add_argument(
"--use_gradient_checkpointing_offload",
default=False,
action="store_true",
help="Whether to use gradient checkpointing offload.",
)
parser.add_argument(
"--train_architecture",
type=str,
default="lora",
choices=["lora", "full"],
help="Model structure to train. LoRA training or full training.",
)
parser.add_argument(
"--pretrained_lora_path",
type=str,
default=None,
help="Pretrained LoRA path. Required if the training is resumed.",
)
parser.add_argument(
"--use_swanlab",
default=False,
action="store_true",
help="Whether to use SwanLab logger.",
)
parser.add_argument(
"--swanlab_mode",
default=None,
help="SwanLab mode (cloud or local).",
)
args = parser.parse_args()
return args
def data_process(args):
dataset = TextVideoDataset(
args.dataset_path,
os.path.join(args.dataset_path, "metadata.csv") if args.metadata_path is None else args.metadata_path,
max_num_frames=args.num_frames,
frame_interval=1,
num_frames=args.num_frames,
height=args.height,
width=args.width,
is_i2v=args.image_encoder_path is not None,
target_fps=args.target_fps,
)
dataloader = torch.utils.data.DataLoader(
dataset,
shuffle=False,
batch_size=1,
num_workers=args.dataloader_num_workers,
collate_fn=lambda x: x,
)
model = LightningModelForDataProcess(
text_encoder_path=args.text_encoder_path,
image_encoder_path=args.image_encoder_path,
vae_path=args.vae_path,
tiled=args.tiled,
tile_size=(args.tile_size_height, args.tile_size_width),
tile_stride=(args.tile_stride_height, args.tile_stride_width),
redirected_tensor_path=args.redirected_tensor_path,
)
trainer = pl.Trainer(
accelerator="gpu",
devices="auto",
default_root_dir=args.output_path,
)
trainer.test(model, dataloader)
def train(args):
dataset = TensorDataset(
args.dataset_path,
os.path.join(args.dataset_path, "metadata.csv") if args.metadata_path is None else args.metadata_path,
steps_per_epoch=args.steps_per_epoch,
redirected_tensor_path=args.redirected_tensor_path,
)
dataloader = torch.utils.data.DataLoader(
dataset,
shuffle=True,
batch_size=1,
num_workers=args.dataloader_num_workers
)
model = LightningModelForTrain(
dit_path=args.dit_path,
learning_rate=args.learning_rate,
train_architecture=args.train_architecture,
lora_rank=args.lora_rank,
lora_alpha=args.lora_alpha,
lora_target_modules=args.lora_target_modules,
init_lora_weights=args.init_lora_weights,
use_gradient_checkpointing=args.use_gradient_checkpointing,
use_gradient_checkpointing_offload=args.use_gradient_checkpointing_offload,
pretrained_lora_path=args.pretrained_lora_path,
)
if args.use_swanlab:
from swanlab.integration.pytorch_lightning import SwanLabLogger
swanlab_config = {"UPPERFRAMEWORK": "DiffSynth-Studio"}
swanlab_config.update(vars(args))
swanlab_logger = SwanLabLogger(
project="wan",
name="wan",
config=swanlab_config,
mode=args.swanlab_mode,
logdir=os.path.join(args.output_path, "swanlog"),
)
logger = [swanlab_logger]
else:
logger = None
trainer = pl.Trainer(
max_epochs=args.max_epochs,
accelerator="gpu",
devices="auto",
precision="bf16",
strategy=args.training_strategy,
default_root_dir=args.output_path,
accumulate_grad_batches=args.accumulate_grad_batches,
callbacks=[pl.pytorch.callbacks.ModelCheckpoint(save_top_k=-1)],
logger=logger,
)
trainer.fit(model, dataloader)
if __name__ == '__main__':
args = parse_args()
if args.task == "data_process":
data_process(args)
elif args.task == "train":
train(args)


@@ -0,0 +1,691 @@
import torch, os, imageio, argparse
from torchvision.transforms import v2
from einops import rearrange
import lightning as pl
import pandas as pd
from diffsynth import WanVideoPipeline, ModelManager, load_state_dict
from diffsynth.models.wan_video_motion_controller import WanMotionControllerModel
from diffsynth.pipelines.wan_video import model_fn_wan_video
from peft import LoraConfig, inject_adapter_in_model
import torchvision
from PIL import Image
import numpy as np
from tqdm import tqdm
class TextVideoDataset(torch.utils.data.Dataset):
def __init__(self, base_path, metadata_path, max_num_frames=81, frame_interval=1, num_frames=81, height=480, width=832, is_i2v=False, target_fps=None):
metadata = pd.read_csv(metadata_path)
self.path = [os.path.join(base_path, file_name) for file_name in metadata["file_name"]]
self.text = metadata["text"].to_list()
self.max_num_frames = max_num_frames
self.frame_interval = frame_interval
self.num_frames = num_frames
self.height = height
self.width = width
self.is_i2v = is_i2v
self.target_fps = target_fps
self.frame_process = v2.Compose([
v2.CenterCrop(size=(height, width)),
v2.Resize(size=(height, width), antialias=True),
v2.ToTensor(),
v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])
def crop_and_resize(self, image):
width, height = image.size
scale = max(self.width / width, self.height / height)
image = torchvision.transforms.functional.resize(
image,
(round(height*scale), round(width*scale)),
interpolation=torchvision.transforms.InterpolationMode.BILINEAR
)
return image
def load_frames_using_imageio(self, file_path, max_num_frames, start_frame_id, interval, num_frames, frame_process):
reader = imageio.get_reader(file_path)
if reader.count_frames() < max_num_frames or reader.count_frames() - 1 < start_frame_id + (num_frames - 1) * interval:
reader.close()
return None
frames = []
first_frame = None
for frame_id in range(num_frames):
frame = reader.get_data(start_frame_id + frame_id * interval)
frame = Image.fromarray(frame)
frame = self.crop_and_resize(frame)
if first_frame is None:
first_frame = np.array(frame)
frame = frame_process(frame)
frames.append(frame)
reader.close()
frames = torch.stack(frames, dim=0)
frames = rearrange(frames, "T C H W -> C T H W")
if self.is_i2v:
return frames, first_frame
else:
return frames
def load_video(self, file_path):
start_frame_id = 0
if self.target_fps is None:
frame_interval = self.frame_interval
else:
reader = imageio.get_reader(file_path)
fps = reader.get_meta_data()["fps"]
reader.close()
frame_interval = max(round(fps / self.target_fps), 1)
frames = self.load_frames_using_imageio(file_path, self.max_num_frames, start_frame_id, frame_interval, self.num_frames, self.frame_process)
return frames
def is_image(self, file_path):
file_ext_name = file_path.split(".")[-1]
if file_ext_name.lower() in ["jpg", "jpeg", "png", "webp"]:
return True
return False
def load_image(self, file_path):
frame = Image.open(file_path).convert("RGB")
frame = self.crop_and_resize(frame)
first_frame = frame
frame = self.frame_process(frame)
frame = rearrange(frame, "C H W -> C 1 H W")
return frame
def __getitem__(self, data_id):
text = self.text[data_id]
path = self.path[data_id]
try:
if self.is_image(path):
if self.is_i2v:
raise ValueError(f"{path} is not a video. The I2V model doesn't support training on still images.")
video = self.load_image(path)
else:
video = self.load_video(path)
if self.is_i2v:
video, first_frame = video
data = {"text": text, "video": video, "path": path, "first_frame": first_frame}
else:
data = {"text": text, "video": video, "path": path}
except Exception:
data = None
return data
def __len__(self):
return len(self.path)
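`TextVideoDataset` above reads captions and relative file paths from a `metadata.csv` with `file_name` and `text` columns under the dataset root. A minimal sketch of producing such a file with pandas (the file names and captions here are hypothetical placeholders):

```python
import pandas as pd

# Build a metadata table with the two columns TextVideoDataset reads:
# "file_name" (relative to the dataset root) and "text" (the caption).
metadata = pd.DataFrame({
    "file_name": ["video_0001.mp4", "video_0002.mp4"],  # hypothetical paths
    "text": ["a dog running on grass", "a boat on a calm lake"],
})
metadata.to_csv("metadata.csv", index=False)
```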
class LightningModelForDataProcess(pl.LightningModule):
def __init__(self, text_encoder_path, vae_path, image_encoder_path=None, tiled=False, tile_size=(34, 34), tile_stride=(18, 16), redirected_tensor_path=None):
super().__init__()
model_path = [text_encoder_path, vae_path]
if image_encoder_path is not None:
model_path.append(image_encoder_path)
model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu")
model_manager.load_models(model_path)
self.pipe = WanVideoPipeline.from_model_manager(model_manager)
self.tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride}
self.redirected_tensor_path = redirected_tensor_path
def test_step(self, batch, batch_idx):
data = batch[0]
if data is None or data["video"] is None:
return
text, video, path = data["text"], data["video"].unsqueeze(0), data["path"]
self.pipe.device = self.device
if video is not None:
# prompt
prompt_emb = self.pipe.encode_prompt(text)
# video
video = video.to(dtype=self.pipe.torch_dtype, device=self.pipe.device)
latents = self.pipe.encode_video(video, **self.tiler_kwargs)[0]
# image
if "first_frame" in data:
# data["first_frame"] is the raw numpy frame saved by the dataset
first_frame = Image.fromarray(data["first_frame"])
_, _, num_frames, height, width = video.shape
image_emb = self.pipe.encode_image(first_frame, num_frames, height, width)
else:
image_emb = {}
data = {"latents": latents, "prompt_emb": prompt_emb, "image_emb": image_emb}
if self.redirected_tensor_path is not None:
path = path.replace("/", "_").replace("\\", "_")
path = os.path.join(self.redirected_tensor_path, path)
torch.save(data, path + ".tensors.pth")
class TensorDataset(torch.utils.data.Dataset):
def __init__(self, base_path, metadata_path=None, steps_per_epoch=1000, redirected_tensor_path=None):
if metadata_path is not None and os.path.exists(metadata_path):
metadata = pd.read_csv(metadata_path)
self.path = [os.path.join(base_path, "train", file_name) for file_name in metadata["file_name"]]
print(len(self.path), "videos in metadata.")
if redirected_tensor_path is None:
self.path = [i + ".tensors.pth" for i in self.path if os.path.exists(i + ".tensors.pth")]
else:
cached_path = []
for path in self.path:
path = path.replace("/", "_").replace("\\", "_")
path = os.path.join(redirected_tensor_path, path)
if os.path.exists(path + ".tensors.pth"):
cached_path.append(path + ".tensors.pth")
self.path = cached_path
else:
print("Cannot find metadata.csv. Searching for tensor files instead.")
self.path = [os.path.join(base_path, i) for i in os.listdir(base_path) if i.endswith(".tensors.pth")]
print(len(self.path), "tensor files found.")
assert len(self.path) > 0
self.steps_per_epoch = steps_per_epoch
self.redirected_tensor_path = redirected_tensor_path
def __getitem__(self, index):
while True:
try:
data_id = torch.randint(0, len(self.path), (1,))[0]
data_id = (data_id + index) % len(self.path) # For fixed seed.
path = self.path[data_id]
data = torch.load(path, weights_only=True, map_location="cpu")
return data
except Exception:
continue
def __len__(self):
return self.steps_per_epoch
class LightningModelForTrain(pl.LightningModule):
def __init__(
self,
dit_path,
learning_rate=1e-5,
lora_rank=4, lora_alpha=4, train_architecture="lora", lora_target_modules="q,k,v,o,ffn.0,ffn.2", init_lora_weights="kaiming",
use_gradient_checkpointing=True, use_gradient_checkpointing_offload=False,
pretrained_lora_path=None
):
super().__init__()
model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu")
if os.path.isfile(dit_path):
model_manager.load_models([dit_path])
else:
dit_path = dit_path.split(",")
model_manager.load_models([dit_path])
self.pipe = WanVideoPipeline.from_model_manager(model_manager)
self.pipe.scheduler.set_timesteps(1000, training=True)
self.freeze_parameters()
self.pipe.motion_controller = WanMotionControllerModel().to(torch.bfloat16)
self.pipe.motion_controller.init()
self.pipe.motion_controller.requires_grad_(True)
self.pipe.motion_controller.train()
self.motion_bucket_manager = MotionBucketManager()
self.learning_rate = learning_rate
self.use_gradient_checkpointing = use_gradient_checkpointing
self.use_gradient_checkpointing_offload = use_gradient_checkpointing_offload
def freeze_parameters(self):
# Freeze parameters
self.pipe.requires_grad_(False)
self.pipe.eval()
self.pipe.dit.train()
def add_lora_to_model(self, model, lora_rank=4, lora_alpha=4, lora_target_modules="q,k,v,o,ffn.0,ffn.2", init_lora_weights="kaiming", pretrained_lora_path=None, state_dict_converter=None):
# Add LoRA to the DiT
self.lora_alpha = lora_alpha
if init_lora_weights == "kaiming":
init_lora_weights = True
lora_config = LoraConfig(
r=lora_rank,
lora_alpha=lora_alpha,
init_lora_weights=init_lora_weights,
target_modules=lora_target_modules.split(","),
)
model = inject_adapter_in_model(lora_config, model)
for param in model.parameters():
# Upcast LoRA parameters into fp32
if param.requires_grad:
param.data = param.to(torch.float32)
# Load pretrained LoRA weights
if pretrained_lora_path is not None:
state_dict = load_state_dict(pretrained_lora_path)
if state_dict_converter is not None:
state_dict = state_dict_converter(state_dict)
missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
all_keys = [i for i, _ in model.named_parameters()]
num_updated_keys = len(all_keys) - len(missing_keys)
num_unexpected_keys = len(unexpected_keys)
print(f"{num_updated_keys} parameters are loaded from {pretrained_lora_path}. {num_unexpected_keys} parameters are unexpected.")
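`add_lora_to_model` delegates the actual injection to `peft`; conceptually, each targeted linear layer `W` gains a trainable low-rank update scaled by `lora_alpha / lora_rank`. A framework-free sketch of that update rule, with toy dimensions (this is an illustration of the math, not the `peft` implementation):

```python
import torch

torch.manual_seed(0)
d, r, alpha = 8, 4, 4.0          # hidden size, lora_rank, lora_alpha
W = torch.randn(d, d)            # frozen base weight
A = torch.randn(r, d) * 0.01     # trainable down-projection (small random init)
B = torch.zeros(d, r)            # trainable up-projection, zero-initialized

x = torch.randn(d)
y = W @ x + (alpha / r) * (B @ (A @ x))  # base output + scaled low-rank delta

# With B zero-initialized, the LoRA branch starts out as a no-op,
# so training begins from the frozen model's behavior.
print(torch.allclose(y, W @ x))  # True
```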
def training_step(self, batch, batch_idx):
# Data
latents = batch["latents"].to(self.device)
prompt_emb = batch["prompt_emb"]
prompt_emb["context"] = prompt_emb["context"][0].to(self.device)
image_emb = batch["image_emb"]
if "clip_feature" in image_emb:
image_emb["clip_feature"] = image_emb["clip_feature"][0].to(self.device)
if "y" in image_emb:
image_emb["y"] = image_emb["y"][0].to(self.device)
# Loss
self.pipe.device = self.device
noise = torch.randn_like(latents)
timestep_id = torch.randint(0, self.pipe.scheduler.num_train_timesteps, (1,))
timestep = self.pipe.scheduler.timesteps[timestep_id].to(dtype=self.pipe.torch_dtype, device=self.pipe.device)
extra_input = self.pipe.prepare_extra_input(latents)
noisy_latents = self.pipe.scheduler.add_noise(latents, noise, timestep)
training_target = self.pipe.scheduler.training_target(latents, noise, timestep)
motion_bucket_id = self.motion_bucket_manager(latents)
motion_bucket_kwargs = self.pipe.prepare_motion_bucket_id(motion_bucket_id)
# Compute loss
noise_pred = model_fn_wan_video(
dit=self.pipe.dit, motion_controller=self.pipe.motion_controller,
x=noisy_latents, timestep=timestep, **prompt_emb, **extra_input, **image_emb, **motion_bucket_kwargs,
use_gradient_checkpointing=self.use_gradient_checkpointing,
use_gradient_checkpointing_offload=self.use_gradient_checkpointing_offload
)
loss = torch.nn.functional.mse_loss(noise_pred.float(), training_target.float())
loss = loss * self.pipe.scheduler.training_weight(timestep)
# Record log
self.log("train_loss", loss, prog_bar=True)
return loss
def configure_optimizers(self):
trainable_modules = filter(lambda p: p.requires_grad, self.pipe.motion_controller.parameters())
optimizer = torch.optim.AdamW(trainable_modules, lr=self.learning_rate)
return optimizer
def on_save_checkpoint(self, checkpoint):
checkpoint.clear()
trainable_param_names = list(filter(lambda named_param: named_param[1].requires_grad, self.pipe.motion_controller.named_parameters()))
trainable_param_names = set([named_param[0] for named_param in trainable_param_names])
state_dict = self.pipe.motion_controller.state_dict()
trainable_state_dict = {name: param for name, param in state_dict.items() if name in trainable_param_names}
checkpoint.update(trainable_state_dict)
class MotionBucketManager:
def __init__(self):
self.thresholds = [
0.093750000, 0.094726562, 0.100585938, 0.100585938, 0.108886719, 0.109375000, 0.118652344, 0.127929688, 0.127929688, 0.130859375,
0.133789062, 0.137695312, 0.138671875, 0.138671875, 0.139648438, 0.143554688, 0.143554688, 0.147460938, 0.149414062, 0.149414062,
0.152343750, 0.153320312, 0.154296875, 0.154296875, 0.157226562, 0.163085938, 0.163085938, 0.164062500, 0.165039062, 0.166992188,
0.173828125, 0.179687500, 0.180664062, 0.184570312, 0.187500000, 0.188476562, 0.188476562, 0.189453125, 0.189453125, 0.202148438,
0.206054688, 0.210937500, 0.210937500, 0.211914062, 0.214843750, 0.214843750, 0.216796875, 0.216796875, 0.216796875, 0.218750000,
0.218750000, 0.221679688, 0.222656250, 0.227539062, 0.229492188, 0.230468750, 0.236328125, 0.243164062, 0.243164062, 0.245117188,
0.253906250, 0.253906250, 0.255859375, 0.259765625, 0.275390625, 0.275390625, 0.277343750, 0.279296875, 0.279296875, 0.279296875,
0.292968750, 0.292968750, 0.302734375, 0.306640625, 0.312500000, 0.312500000, 0.326171875, 0.330078125, 0.332031250, 0.332031250,
0.337890625, 0.343750000, 0.343750000, 0.351562500, 0.355468750, 0.357421875, 0.361328125, 0.367187500, 0.382812500, 0.388671875,
0.392578125, 0.392578125, 0.392578125, 0.404296875, 0.404296875, 0.425781250, 0.433593750, 0.507812500, 0.519531250, 0.539062500,
]
def get_motion_score(self, frames):
score = frames[:, :, 1:, :, :].std(dim=2).mean().tolist()
return score
def get_bucket_id(self, motion_score):
for bucket_id in range(len(self.thresholds) - 1):
if self.thresholds[bucket_id + 1] > motion_score:
return bucket_id
return len(self.thresholds)
def __call__(self, frames):
score = self.get_motion_score(frames)
bucket_id = self.get_bucket_id(score)
return bucket_id
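As a sanity check on the bucketing logic, the sketch below replays `get_bucket_id` against a toy threshold list (the real class uses the 100 empirically measured percentile values above). Note that it mirrors the class exactly, including the fallback that maps scores above every threshold to `len(thresholds)`, one past the last loop index:

```python
# Toy thresholds standing in for the 100 percentile values above.
thresholds = [0.1, 0.2, 0.3, 0.4]

def get_bucket_id(motion_score, thresholds):
    # Return the first bucket whose upper bound exceeds the score.
    for bucket_id in range(len(thresholds) - 1):
        if thresholds[bucket_id + 1] > motion_score:
            return bucket_id
    # Scores above every threshold map to len(thresholds),
    # mirroring the fallback in MotionBucketManager.
    return len(thresholds)

print(get_bucket_id(0.15, thresholds))  # below 0.2 -> bucket 0
print(get_bucket_id(0.35, thresholds))  # between 0.3 and 0.4 -> bucket 2
print(get_bucket_id(0.99, thresholds))  # above all thresholds -> 4
```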
def parse_args():
parser = argparse.ArgumentParser(description="Simple example of a training script.")
parser.add_argument(
"--task",
type=str,
default="data_process",
required=True,
choices=["data_process", "train"],
help="Task. `data_process` or `train`.",
)
parser.add_argument(
"--dataset_path",
type=str,
default=None,
required=True,
help="The path of the dataset.",
)
parser.add_argument(
"--output_path",
type=str,
default="./",
help="Path to save the model.",
)
parser.add_argument(
"--metadata_path",
type=str,
default=None,
help="Path to metadata.csv.",
)
parser.add_argument(
"--redirected_tensor_path",
type=str,
default=None,
help="Path to save cached tensors.",
)
parser.add_argument(
"--text_encoder_path",
type=str,
default=None,
help="Path of text encoder.",
)
parser.add_argument(
"--image_encoder_path",
type=str,
default=None,
help="Path of image encoder.",
)
parser.add_argument(
"--vae_path",
type=str,
default=None,
help="Path of VAE.",
)
parser.add_argument(
"--dit_path",
type=str,
default=None,
help="Path of DiT.",
)
parser.add_argument(
"--tiled",
default=False,
action="store_true",
help="Whether to enable tiled encoding in the VAE. This option reduces the required VRAM.",
)
parser.add_argument(
"--tile_size_height",
type=int,
default=34,
help="Tile size (height) in VAE.",
)
parser.add_argument(
"--tile_size_width",
type=int,
default=34,
help="Tile size (width) in VAE.",
)
parser.add_argument(
"--tile_stride_height",
type=int,
default=18,
help="Tile stride (height) in VAE.",
)
parser.add_argument(
"--tile_stride_width",
type=int,
default=16,
help="Tile stride (width) in VAE.",
)
parser.add_argument(
"--steps_per_epoch",
type=int,
default=500,
help="Number of steps per epoch.",
)
parser.add_argument(
"--num_frames",
type=int,
default=81,
help="Number of frames.",
)
parser.add_argument(
"--target_fps",
type=int,
default=None,
help="Expected FPS for sampling frames.",
)
parser.add_argument(
"--height",
type=int,
default=480,
help="Image height.",
)
parser.add_argument(
"--width",
type=int,
default=832,
help="Image width.",
)
parser.add_argument(
"--dataloader_num_workers",
type=int,
default=1,
help="Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process.",
)
parser.add_argument(
"--learning_rate",
type=float,
default=1e-5,
help="Learning rate.",
)
parser.add_argument(
"--accumulate_grad_batches",
type=int,
default=1,
help="The number of batches in gradient accumulation.",
)
parser.add_argument(
"--max_epochs",
type=int,
default=1,
help="Number of epochs.",
)
parser.add_argument(
"--lora_target_modules",
type=str,
default="q,k,v,o,ffn.0,ffn.2",
help="Layers with LoRA modules.",
)
parser.add_argument(
"--init_lora_weights",
type=str,
default="kaiming",
choices=["gaussian", "kaiming"],
help="The initialization method for LoRA weights.",
)
parser.add_argument(
"--training_strategy",
type=str,
default="auto",
choices=["auto", "deepspeed_stage_1", "deepspeed_stage_2", "deepspeed_stage_3"],
help="Training strategy.",
)
parser.add_argument(
"--lora_rank",
type=int,
default=4,
help="The dimension of the LoRA update matrices.",
)
parser.add_argument(
"--lora_alpha",
type=float,
default=4.0,
help="The scaling factor (alpha) of the LoRA update matrices.",
)
parser.add_argument(
"--use_gradient_checkpointing",
default=False,
action="store_true",
help="Whether to use gradient checkpointing.",
)
parser.add_argument(
"--use_gradient_checkpointing_offload",
default=False,
action="store_true",
help="Whether to use gradient checkpointing offload.",
)
parser.add_argument(
"--train_architecture",
type=str,
default="lora",
choices=["lora", "full"],
help="Model structure to train. LoRA training or full training.",
)
parser.add_argument(
"--pretrained_lora_path",
type=str,
default=None,
help="Path to pretrained LoRA weights. Required if training is resumed.",
)
parser.add_argument(
"--use_swanlab",
default=False,
action="store_true",
help="Whether to use SwanLab logger.",
)
parser.add_argument(
"--swanlab_mode",
default=None,
help="SwanLab mode (cloud or local).",
)
args = parser.parse_args()
return args
def data_process(args):
dataset = TextVideoDataset(
args.dataset_path,
os.path.join(args.dataset_path, "metadata.csv") if args.metadata_path is None else args.metadata_path,
max_num_frames=args.num_frames,
frame_interval=1,
num_frames=args.num_frames,
height=args.height,
width=args.width,
is_i2v=args.image_encoder_path is not None,
target_fps=args.target_fps,
)
dataloader = torch.utils.data.DataLoader(
dataset,
shuffle=False,
batch_size=1,
num_workers=args.dataloader_num_workers,
collate_fn=lambda x: x,
)
model = LightningModelForDataProcess(
text_encoder_path=args.text_encoder_path,
image_encoder_path=args.image_encoder_path,
vae_path=args.vae_path,
tiled=args.tiled,
tile_size=(args.tile_size_height, args.tile_size_width),
tile_stride=(args.tile_stride_height, args.tile_stride_width),
redirected_tensor_path=args.redirected_tensor_path,
)
trainer = pl.Trainer(
accelerator="gpu",
devices="auto",
default_root_dir=args.output_path,
)
trainer.test(model, dataloader)
def get_motion_thresholds(dataloader):
scores = []
for data in tqdm(dataloader):
scores.append(data["latents"][:, :, 1:, :, :].std(dim=2).mean().tolist())
scores = sorted(scores)
for i in range(100):
s = scores[int(i/100 * len(scores))]
print("%.9f" % s, end=", ")
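`get_motion_thresholds` prints the per-percentile motion scores that populate `MotionBucketManager.thresholds`. A minimal standalone sketch of the same percentile extraction over a plain list of scores (synthetic data, for illustration only):

```python
def percentile_thresholds(scores, num_buckets=100):
    # Sort the motion scores and take one representative per percentile,
    # mirroring the indexing used in get_motion_thresholds above.
    scores = sorted(scores)
    return [scores[int(i / num_buckets * len(scores))] for i in range(num_buckets)]

# With 200 evenly spaced scores, each percentile picks every second value.
scores = [i / 200 for i in range(200)]
thresholds = percentile_thresholds(scores)
print(thresholds[0], thresholds[50], thresholds[99])  # 0.0 0.5 0.99
```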
def train(args):
dataset = TensorDataset(
args.dataset_path,
os.path.join(args.dataset_path, "metadata.csv") if args.metadata_path is None else args.metadata_path,
steps_per_epoch=args.steps_per_epoch,
redirected_tensor_path=args.redirected_tensor_path,
)
dataloader = torch.utils.data.DataLoader(
dataset,
shuffle=True,
batch_size=1,
num_workers=args.dataloader_num_workers
)
model = LightningModelForTrain(
dit_path=args.dit_path,
learning_rate=args.learning_rate,
train_architecture=args.train_architecture,
lora_rank=args.lora_rank,
lora_alpha=args.lora_alpha,
lora_target_modules=args.lora_target_modules,
init_lora_weights=args.init_lora_weights,
use_gradient_checkpointing=args.use_gradient_checkpointing,
use_gradient_checkpointing_offload=args.use_gradient_checkpointing_offload,
pretrained_lora_path=args.pretrained_lora_path,
)
if args.use_swanlab:
from swanlab.integration.pytorch_lightning import SwanLabLogger
swanlab_config = {"UPPERFRAMEWORK": "DiffSynth-Studio"}
swanlab_config.update(vars(args))
swanlab_logger = SwanLabLogger(
project="wan",
name="wan",
config=swanlab_config,
mode=args.swanlab_mode,
logdir=os.path.join(args.output_path, "swanlog"),
)
logger = [swanlab_logger]
else:
logger = None
trainer = pl.Trainer(
max_epochs=args.max_epochs,
accelerator="gpu",
devices="auto",
precision="bf16",
strategy=args.training_strategy,
default_root_dir=args.output_path,
accumulate_grad_batches=args.accumulate_grad_batches,
callbacks=[pl.pytorch.callbacks.ModelCheckpoint(save_top_k=-1)],
logger=logger,
)
trainer.fit(model, dataloader)
if __name__ == '__main__':
args = parse_args()
if args.task == "data_process":
data_process(args)
elif args.task == "train":
train(args)


@@ -0,0 +1,34 @@
import torch
from diffsynth import ModelManager, WanVideoPipeline, save_video, VideoData
from modelscope import snapshot_download
# Download models
snapshot_download("Wan-AI/Wan2.1-T2V-1.3B", local_dir="models/Wan-AI/Wan2.1-T2V-1.3B")
# Load models
model_manager = ModelManager(device="cpu")
model_manager.load_models(
[
"models/Wan-AI/Wan2.1-T2V-1.3B/diffusion_pytorch_model.safetensors",
"models/Wan-AI/Wan2.1-T2V-1.3B/models_t5_umt5-xxl-enc-bf16.pth",
"models/Wan-AI/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth",
],
torch_dtype=torch.bfloat16, # You can set `torch_dtype=torch.float8_e4m3fn` to enable FP8 quantization.
)
pipe = WanVideoPipeline.from_model_manager(model_manager, torch_dtype=torch.bfloat16, device="cuda")
pipe.enable_vram_management(num_persistent_param_in_dit=None)
# Text-to-video
video = pipe(
# Prompt (Chinese): documentary-photography style, a lively brown-yellow puppy sprinting across a lush green meadow, ears up, sunlight making its fur look soft and shiny, wildflowers and a faint blue sky in the background; medium side-tracking shot.
prompt="纪实摄影风格画面,一只活泼的小狗在绿茵茵的草地上迅速奔跑。小狗毛色棕黄,两只耳朵立起,神情专注而欢快。阳光洒在它身上,使得毛发看上去格外柔软而闪亮。背景是一片开阔的草地,偶尔点缀着几朵野花,远处隐约可见蓝天和几片白云。透视感鲜明,捕捉小狗奔跑时的动感和四周草地的生机。中景侧面移动视角。",
# Negative prompt (Chinese): oversaturated, overexposed, static, blurry details, subtitles, painting style, gray tones, worst/low quality, JPEG artifacts, deformed or extra limbs, fused fingers, frozen frame, cluttered background, three legs, crowds, walking backwards.
negative_prompt="色调艳丽过曝静态细节模糊不清字幕风格作品画作画面静止整体发灰最差质量低质量JPEG压缩残留丑陋的残缺的多余的手指画得不好的手部画得不好的脸部畸形的毁容的形态畸形的肢体手指融合静止不动的画面杂乱的背景三条腿背景人很多倒着走",
num_inference_steps=50,
seed=0, tiled=True,
# TeaCache parameters
tea_cache_l1_thresh=0.05, # The larger this value is, the faster the speed, but the worse the visual quality.
tea_cache_model_id="Wan2.1-T2V-1.3B", # Choose one in (Wan2.1-T2V-1.3B, Wan2.1-T2V-14B, Wan2.1-I2V-14B-480P, Wan2.1-I2V-14B-720P).
)
save_video(video, "video1.mp4", fps=15, quality=5)
# TeaCache doesn't support video-to-video


@@ -9,6 +9,10 @@ snapshot_download("Wan-AI/Wan2.1-I2V-14B-480P", local_dir="models/Wan-AI/Wan2.1-
# Load models
model_manager = ModelManager(device="cpu")
model_manager.load_models(
["models/Wan-AI/Wan2.1-I2V-14B-480P/models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth"],
torch_dtype=torch.float32, # Image Encoder is loaded with float32
)
model_manager.load_models(
[
[
@@ -20,14 +24,13 @@ model_manager.load_models(
"models/Wan-AI/Wan2.1-I2V-14B-480P/diffusion_pytorch_model-00006-of-00007.safetensors",
"models/Wan-AI/Wan2.1-I2V-14B-480P/diffusion_pytorch_model-00007-of-00007.safetensors",
],
"models/Wan-AI/Wan2.1-I2V-14B-480P/models_t5_umt5-xxl-enc-bf16.pth",
"models/Wan-AI/Wan2.1-I2V-14B-480P/Wan2.1_VAE.pth",
],
torch_dtype=torch.bfloat16, # You can set `torch_dtype=torch.float8_e4m3fn` to enable FP8 quantization.
)
pipe = WanVideoPipeline.from_model_manager(model_manager, torch_dtype=torch.bfloat16, device="cuda")
pipe.enable_vram_management(num_persistent_param_in_dit=6*10**9) # You can set `num_persistent_param_in_dit` to a small number to reduce VRAM required.
# Download example image
dataset_snapshot_download(


@@ -0,0 +1,125 @@
import torch
import lightning as pl
from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel, SequenceParallel, PrepareModuleInput, PrepareModuleOutput
from torch.distributed._tensor import Replicate, Shard
from torch.distributed.tensor.parallel import parallelize_module
from lightning.pytorch.strategies import ModelParallelStrategy
from diffsynth import ModelManager, WanVideoPipeline, save_video
from tqdm import tqdm
from modelscope import snapshot_download
class ToyDataset(torch.utils.data.Dataset):
def __init__(self, tasks=[]):
self.tasks = tasks
def __getitem__(self, data_id):
return self.tasks[data_id]
def __len__(self):
return len(self.tasks)
class LitModel(pl.LightningModule):
def __init__(self):
super().__init__()
model_manager = ModelManager(device="cpu")
model_manager.load_models(
[
[
"models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00001-of-00006.safetensors",
"models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00002-of-00006.safetensors",
"models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00003-of-00006.safetensors",
"models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00004-of-00006.safetensors",
"models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00005-of-00006.safetensors",
"models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00006-of-00006.safetensors",
],
"models/Wan-AI/Wan2.1-T2V-14B/models_t5_umt5-xxl-enc-bf16.pth",
"models/Wan-AI/Wan2.1-T2V-14B/Wan2.1_VAE.pth",
],
torch_dtype=torch.bfloat16,
)
self.pipe = WanVideoPipeline.from_model_manager(model_manager, torch_dtype=torch.bfloat16, device="cuda")
def configure_model(self):
tp_mesh = self.device_mesh["tensor_parallel"]
for block_id, block in enumerate(self.pipe.dit.blocks):
layer_tp_plan = {
"self_attn": PrepareModuleInput(
input_layouts=(Replicate(), Replicate()),
desired_input_layouts=(Replicate(), Shard(0)),
),
"self_attn.q": SequenceParallel(),
"self_attn.k": SequenceParallel(),
"self_attn.v": SequenceParallel(),
"self_attn.norm_q": SequenceParallel(),
"self_attn.norm_k": SequenceParallel(),
"self_attn.attn": PrepareModuleInput(
input_layouts=(Shard(1), Shard(1), Shard(1)),
desired_input_layouts=(Shard(2), Shard(2), Shard(2)),
),
"self_attn.o": ColwiseParallel(output_layouts=Replicate()),
"cross_attn": PrepareModuleInput(
input_layouts=(Replicate(), Replicate()),
desired_input_layouts=(Replicate(), Replicate()),
),
"cross_attn.q": SequenceParallel(),
"cross_attn.k": SequenceParallel(),
"cross_attn.v": SequenceParallel(),
"cross_attn.norm_q": SequenceParallel(),
"cross_attn.norm_k": SequenceParallel(),
"cross_attn.attn": PrepareModuleInput(
input_layouts=(Shard(1), Shard(1), Shard(1)),
desired_input_layouts=(Shard(2), Shard(2), Shard(2)),
),
"cross_attn.o": ColwiseParallel(output_layouts=Replicate()),
"ffn.0": ColwiseParallel(),
"ffn.2": RowwiseParallel(),
}
parallelize_module(
module=block,
device_mesh=tp_mesh,
parallelize_plan=layer_tp_plan,
)
def test_step(self, batch):
data = batch[0]
data["progress_bar_cmd"] = tqdm if self.local_rank == 0 else lambda x: x
output_path = data.pop("output_path")
with torch.no_grad(), torch.inference_mode(False):
video = self.pipe(**data)
if self.local_rank == 0:
save_video(video, output_path, fps=15, quality=5)
if __name__ == "__main__":
snapshot_download("Wan-AI/Wan2.1-T2V-14B", local_dir="models/Wan-AI/Wan2.1-T2V-14B")
dataloader = torch.utils.data.DataLoader(
ToyDataset([
{
# Prompt (Chinese): an astronaut in a spacesuit rides a mechanical horse across the Martian surface toward the camera; red desolate terrain with huge craters and strange rock formations, faint dust kicked up, deep space and a blue Earth in the background; sci-fi yet hopeful.
"prompt": "一名宇航员身穿太空服,面朝镜头骑着一匹机械马在火星表面驰骋。红色的荒凉地表延伸至远方,点缀着巨大的陨石坑和奇特的岩石结构。机械马的步伐稳健,扬起微弱的尘埃,展现出未来科技与原始探索的完美结合。宇航员手持操控装置,目光坚定,仿佛正在开辟人类的新疆域。背景是深邃的宇宙和蔚蓝的地球,画面既科幻又充满希望,让人不禁畅想未来的星际生活。",
# Negative prompt (Chinese): oversaturated, overexposed, static, blurry, subtitles, worst/low quality, JPEG artifacts, deformed limbs, cluttered background, three legs, crowds, walking backwards.
"negative_prompt": "色调艳丽过曝静态细节模糊不清字幕风格作品画作画面静止整体发灰最差质量低质量JPEG压缩残留丑陋的残缺的多余的手指画得不好的手部画得不好的脸部畸形的毁容的形态畸形的肢体手指融合静止不动的画面杂乱的背景三条腿背景人很多倒着走",
"num_inference_steps": 50,
"seed": 0,
"tiled": False,
"output_path": "video1.mp4",
},
{
# Identical prompts to the first task; only the seed and output path differ.
"prompt": "一名宇航员身穿太空服,面朝镜头骑着一匹机械马在火星表面驰骋。红色的荒凉地表延伸至远方,点缀着巨大的陨石坑和奇特的岩石结构。机械马的步伐稳健,扬起微弱的尘埃,展现出未来科技与原始探索的完美结合。宇航员手持操控装置,目光坚定,仿佛正在开辟人类的新疆域。背景是深邃的宇宙和蔚蓝的地球,画面既科幻又充满希望,让人不禁畅想未来的星际生活。",
"negative_prompt": "色调艳丽过曝静态细节模糊不清字幕风格作品画作画面静止整体发灰最差质量低质量JPEG压缩残留丑陋的残缺的多余的手指画得不好的手部画得不好的脸部畸形的毁容的形态畸形的肢体手指融合静止不动的画面杂乱的背景三条腿背景人很多倒着走",
"num_inference_steps": 50,
"seed": 1,
"tiled": False,
"output_path": "video2.mp4",
},
]),
collate_fn=lambda x: x
)
model = LitModel()
trainer = pl.Trainer(accelerator="gpu", devices=torch.cuda.device_count(), strategy=ModelParallelStrategy())
trainer.test(model, dataloader)