mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-19 14:58:12 +00:00
Compare commits: v1.1.2...wan-models (51 commits)
| SHA1 |
|---|
| 05094710e3 |
| 105eaf0f49 |
| 6cd032e846 |
| 9d8130b48d |
| ce848a3d1a |
| a8ce9fef33 |
| 8da0d183a2 |
| 4b2b3dda94 |
| b1fabbc6b0 |
| e28c246bcc |
| 04d03500ff |
| 39890f023f |
| e425753f79 |
| ca40074d72 |
| 1fd3d67379 |
| 3acd9c73be |
| 32422b49ee |
| 5c4d3185fb |
| 762bcbee58 |
| 6b411ada16 |
| a25bd74d8b |
| fb5fc09bad |
| 3fdba19e02 |
| 4bec2983a9 |
| 03ea27893f |
| 718b45f2af |
| 63a79eeb2a |
| e757013a14 |
| a05f647633 |
| 7604be0301 |
| 945b43492e |
| b548d7caf2 |
| 6e316fd825 |
| 84fb61aaaf |
| 50a9946b57 |
| 384d1a8198 |
| a58c193d0c |
| 34a5ef8c15 |
| 41e3e4e157 |
| e576d71908 |
| 906aadbf1b |
| bf0bf2d5ba |
| fe0fff1399 |
| 50fceb84d2 |
| 100da41034 |
| c382237833 |
| 98ac191750 |
| 2f73dbe7a3 |
| 490d420d82 |
| 0aca943a39 |
| 0dbb3d333f |
@@ -19,7 +19,7 @@ Until now, DiffSynth Studio has supported the following models:

 * [Wan-Video](https://github.com/Wan-Video/Wan2.1)
 * [StepVideo](https://github.com/stepfun-ai/Step-Video-T2V)
-* [HunyuanVideo](https://github.com/Tencent/HunyuanVideo)
+* [HunyuanVideo](https://github.com/Tencent/HunyuanVideo), [HunyuanVideo-I2V]()
 * [CogVideoX](https://huggingface.co/THUDM/CogVideoX-5b)
 * [FLUX](https://huggingface.co/black-forest-labs/FLUX.1-dev)
 * [ExVideo](https://huggingface.co/ECNU-CILab/ExVideo-SVD-128f-v1)

@@ -36,6 +36,7 @@ Until now, DiffSynth Studio has supported the following models:

 * [Stable Diffusion](https://huggingface.co/runwayml/stable-diffusion-v1-5)

 ## News

+- **March 25, 2025** We support HunyuanVideo-I2V, the image-to-video generation version of HunyuanVideo open-sourced by Tencent. Please refer to [./examples/HunyuanVideo/](./examples/HunyuanVideo/) for more details.
 - **February 25, 2025** We support Wan-Video, a collection of SOTA video synthesis models open-sourced by Alibaba. See [./examples/wanvideo/](./examples/wanvideo/).

@@ -43,7 +44,7 @@ Until now, DiffSynth Studio has supported the following models:

 - **December 31, 2024** We propose EliGen, a novel framework for precise entity-level controlled text-to-image generation, complemented by an inpainting fusion pipeline to extend its capabilities to image inpainting tasks. EliGen seamlessly integrates with existing community models, such as IP-Adapter and In-Context LoRA, enhancing its versatility. For more details, see [./examples/EntityControl](./examples/EntityControl/).
   - Paper: [EliGen: Entity-Level Controlled Image Generation with Regional Attention](https://arxiv.org/abs/2501.01097)
-  - Model: [ModelScope](https://www.modelscope.cn/models/DiffSynth-Studio/Eligen)
+  - Model: [ModelScope](https://www.modelscope.cn/models/DiffSynth-Studio/Eligen), [HuggingFace](https://huggingface.co/modelscope/EliGen)
   - Online Demo: [ModelScope EliGen Studio](https://www.modelscope.cn/studios/DiffSynth-Studio/EliGen)
   - Training Dataset: [EliGen Train Set](https://www.modelscope.cn/datasets/DiffSynth-Studio/EliGenTrainSet)
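The image-to-video path reuses the existing `HunyuanVideoPipeline` call signature that appears later in this diff (`input_images`, `i2v_resolution`, `i2v_stability`) together with the new `HunyuanVideoI2V` preset. Below is a rough usage sketch; the import paths and loader calls follow DiffSynth-Studio's usual `ModelManager` / `from_model_manager` workflow and should be treated as assumptions, with ./examples/HunyuanVideo/ as the authoritative script.

```python
import torch
from PIL import Image
# Assumed entry points; see ./examples/HunyuanVideo/ for the reference script.
from diffsynth import ModelManager, HunyuanVideoPipeline, download_models, save_video

download_models(["HunyuanVideoI2V"])  # preset id registered further down in this diff
model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cuda")
model_manager.load_models([
    "models/HunyuanVideoI2V/text_encoder/model.safetensors",
    "models/HunyuanVideoI2V/text_encoder_2",
    "models/HunyuanVideoI2V/vae/pytorch_model.pt",
    "models/HunyuanVideoI2V/transformers/mp_rank_00_model_states.pt",
])
pipe = HunyuanVideoPipeline.from_model_manager(model_manager)

video = pipe(
    prompt="a cat walks across a sunlit room",
    input_images=[Image.open("first_frame.png")],  # the conditioning frame for image-to-video
    i2v_resolution="720p",                         # 360p / 540p / 720p, see prepare_vae_images_inputs below
    num_frames=129,
    num_inference_steps=30,
    seed=0,
)
save_video(video, "video.mp4", fps=30)
```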
@@ -95,6 +95,7 @@ model_loader_configs = [
    (None, "57b02550baab820169365b3ee3afa2c9", ["flux_dit"], [FluxDiT], "civitai"),
    (None, "3394f306c4cbf04334b712bf5aaed95f", ["flux_dit"], [FluxDiT], "civitai"),
    (None, "023f054d918a84ccf503481fd1e3379e", ["flux_dit"], [FluxDiT], "civitai"),
    (None, "605c56eab23e9e2af863ad8f0813a25d", ["flux_dit"], [FluxDiT], "diffusers"),
    (None, "280189ee084bca10f70907bf6ce1649d", ["cog_vae_encoder", "cog_vae_decoder"], [CogVAEEncoder, CogVAEDecoder], "diffusers"),
    (None, "9b9313d104ac4df27991352fec013fd4", ["rife"], [IFNet], "civitai"),
    (None, "6b7116078c4170bfbeaedc8fe71f6649", ["esrgan"], [RRDBNet], "civitai"),
@@ -116,6 +117,7 @@ model_loader_configs = [
    (None, "9269f8db9040a9d860eaca435be61814", ["wan_video_dit"], [WanModel], "civitai"),
    (None, "aafcfd9672c3a2456dc46e1cb6e52c70", ["wan_video_dit"], [WanModel], "civitai"),
    (None, "6bfcfb3b342cb286ce886889d519a77e", ["wan_video_dit"], [WanModel], "civitai"),
    (None, "cb104773c6c2cb6df4f9529ad5c60d0b", ["wan_video_dit"], [WanModel], "diffusers"),
    (None, "9c8818c2cbea55eca56c7b447df170da", ["wan_video_text_encoder"], [WanTextEncoder], "civitai"),
    (None, "5941c53e207d62f20f9025686193c40b", ["wan_video_image_encoder"], [WanImageEncoder], "civitai"),
    (None, "1378ea763357eea97acdef78e65d6d96", ["wan_video_vae"], [WanVideoVAE], "civitai"),
@@ -133,6 +135,7 @@ huggingface_model_loader_configs = [
    ("CogVideoXTransformer3DModel", "diffsynth.models.cog_dit", "cog_dit", "CogDiT"),
    ("SiglipModel", "transformers.models.siglip.modeling_siglip", "siglip_vision_model", "SiglipVisionModel"),
    ("LlamaForCausalLM", "diffsynth.models.hunyuan_video_text_encoder", "hunyuan_video_text_encoder_2", "HunyuanVideoLLMEncoder"),
    ("LlavaForConditionalGeneration", "diffsynth.models.hunyuan_video_text_encoder", "hunyuan_video_text_encoder_2", "HunyuanVideoMLLMEncoder"),
    ("Step1Model", "diffsynth.models.stepvideo_text_encoder", "stepvideo_text_encoder_2", "STEP1TextEncoder"),
]
patch_model_loader_configs = [
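Each tuple above keys a checkpoint by the hash of its state-dict keys and maps it to target model names, model classes, and the converter style ("civitai" vs "diffusers"). The sketch below shows the general dispatch idea; the hashing helper is illustrative only and is not the library's actual `hash_state_dict_keys` implementation.

```python
import hashlib

def hash_state_dict_keys_sketch(state_dict, with_shape=True):
    # Hash a canonical signature of the keys (optionally with shapes) so that
    # architecturally identical checkpoints resolve to the same loader entry.
    items = sorted(f"{k}:{tuple(v.shape) if with_shape else ''}" for k, v in state_dict.items())
    return hashlib.md5("".join(items).encode()).hexdigest()

# Illustrative subset of the table above: hash -> (model names, model classes, source format).
MODEL_LOADER_CONFIGS = {
    "9269f8db9040a9d860eaca435be61814": (["wan_video_dit"], ["WanModel"], "civitai"),
    "1378ea763357eea97acdef78e65d6d96": (["wan_video_vae"], ["WanVideoVAE"], "civitai"),
}

def resolve_loader(state_dict):
    key = hash_state_dict_keys_sketch(state_dict)
    if key not in MODEL_LOADER_CONFIGS:
        raise ValueError(f"unrecognized checkpoint (state-dict hash {key})")
    return MODEL_LOADER_CONFIGS[key]
```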
@@ -675,6 +678,25 @@ preset_models_on_modelscope = {
            "models/HunyuanVideo/transformers/mp_rank_00_model_states.pt"
        ],
    },
    "HunyuanVideoI2V": {
        "file_list": [
            ("AI-ModelScope/clip-vit-large-patch14", "model.safetensors", "models/HunyuanVideoI2V/text_encoder"),
            ("AI-ModelScope/llava-llama-3-8b-v1_1-transformers", "model-00001-of-00004.safetensors", "models/HunyuanVideoI2V/text_encoder_2"),
            ("AI-ModelScope/llava-llama-3-8b-v1_1-transformers", "model-00002-of-00004.safetensors", "models/HunyuanVideoI2V/text_encoder_2"),
            ("AI-ModelScope/llava-llama-3-8b-v1_1-transformers", "model-00003-of-00004.safetensors", "models/HunyuanVideoI2V/text_encoder_2"),
            ("AI-ModelScope/llava-llama-3-8b-v1_1-transformers", "model-00004-of-00004.safetensors", "models/HunyuanVideoI2V/text_encoder_2"),
            ("AI-ModelScope/llava-llama-3-8b-v1_1-transformers", "config.json", "models/HunyuanVideoI2V/text_encoder_2"),
            ("AI-ModelScope/llava-llama-3-8b-v1_1-transformers", "model.safetensors.index.json", "models/HunyuanVideoI2V/text_encoder_2"),
            ("AI-ModelScope/HunyuanVideo-I2V", "hunyuan-video-i2v-720p/vae/pytorch_model.pt", "models/HunyuanVideoI2V/vae"),
            ("AI-ModelScope/HunyuanVideo-I2V", "hunyuan-video-i2v-720p/transformers/mp_rank_00_model_states.pt", "models/HunyuanVideoI2V/transformers")
        ],
        "load_path": [
            "models/HunyuanVideoI2V/text_encoder/model.safetensors",
            "models/HunyuanVideoI2V/text_encoder_2",
            "models/HunyuanVideoI2V/vae/pytorch_model.pt",
            "models/HunyuanVideoI2V/transformers/mp_rank_00_model_states.pt"
        ],
    },
    "HunyuanVideo-fp8": {
        "file_list": [
            ("AI-ModelScope/clip-vit-large-patch14", "model.safetensors", "models/HunyuanVideo/text_encoder"),
@@ -751,4 +773,5 @@ Preset_model_id: TypeAlias = Literal[
    "StableDiffusion3.5-medium",
    "HunyuanVideo",
    "HunyuanVideo-fp8",
    "HunyuanVideoI2V",
]
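Each `file_list` entry is `(ModelScope repo id, file path inside the repo, local target directory)`, and `load_path` tells the loader what to read once the files are in place. A sketch of resolving such a manifest is below; `fetch_file` is a hypothetical stand-in for the actual hub download call.

```python
import os
import shutil

def download_preset(file_list, fetch_file):
    """fetch_file(repo_id, path_in_repo) -> local cached path (hypothetical downloader)."""
    for repo_id, path_in_repo, target_dir in file_list:
        os.makedirs(target_dir, exist_ok=True)
        cached = fetch_file(repo_id, path_in_repo)
        shutil.copy(cached, os.path.join(target_dir, os.path.basename(path_in_repo)))

# Two entries from the HunyuanVideoI2V manifest above:
hunyuan_video_i2v_files = [
    ("AI-ModelScope/clip-vit-large-patch14", "model.safetensors", "models/HunyuanVideoI2V/text_encoder"),
    ("AI-ModelScope/HunyuanVideo-I2V", "hunyuan-video-i2v-720p/vae/pytorch_model.pt", "models/HunyuanVideoI2V/vae"),
]
# download_preset(hunyuan_video_i2v_files, fetch_file=my_hub_downloader)
```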
@@ -1,10 +1,4 @@
from typing_extensions import Literal, TypeAlias
import warnings
with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    from controlnet_aux.processor import (
        CannyDetector, MidasDetector, HEDdetector, LineartDetector, LineartAnimeDetector, OpenposeDetector, NormalBaeDetector
    )


Processor_id: TypeAlias = Literal[
@@ -15,18 +9,25 @@ class Annotator:
    def __init__(self, processor_id: Processor_id, model_path="models/Annotators", detect_resolution=None, device='cuda', skip_processor=False):
        if not skip_processor:
            if processor_id == "canny":
                from controlnet_aux.processor import CannyDetector
                self.processor = CannyDetector()
            elif processor_id == "depth":
                from controlnet_aux.processor import MidasDetector
                self.processor = MidasDetector.from_pretrained(model_path).to(device)
            elif processor_id == "softedge":
                from controlnet_aux.processor import HEDdetector
                self.processor = HEDdetector.from_pretrained(model_path).to(device)
            elif processor_id == "lineart":
                from controlnet_aux.processor import LineartDetector
                self.processor = LineartDetector.from_pretrained(model_path).to(device)
            elif processor_id == "lineart_anime":
                from controlnet_aux.processor import LineartAnimeDetector
                self.processor = LineartAnimeDetector.from_pretrained(model_path).to(device)
            elif processor_id == "openpose":
                from controlnet_aux.processor import OpenposeDetector
                self.processor = OpenposeDetector.from_pretrained(model_path).to(device)
            elif processor_id == "normal":
                from controlnet_aux.processor import NormalBaeDetector
                self.processor = NormalBaeDetector.from_pretrained(model_path).to(device)
            elif processor_id == "tile" or processor_id == "none" or processor_id == "inpaint":
                self.processor = None
@@ -628,19 +628,22 @@ class FluxDiTStateDictConverter:
        else:
            pass
        for name in list(state_dict_.keys()):
            if ".proj_in_besides_attn." in name:
                name_ = name.replace(".proj_in_besides_attn.", ".to_qkv_mlp.")
            if "single_blocks." in name and ".a_to_q." in name:
                mlp = state_dict_.get(name.replace(".a_to_q.", ".proj_in_besides_attn."), None)
                if mlp is None:
                    mlp = torch.zeros(4 * state_dict_[name].shape[0],
                                      *state_dict_[name].shape[1:],
                                      dtype=state_dict_[name].dtype)
                else:
                    state_dict_.pop(name.replace(".a_to_q.", ".proj_in_besides_attn."))
                param = torch.concat([
                    state_dict_[name.replace(".proj_in_besides_attn.", f".a_to_q.")],
                    state_dict_[name.replace(".proj_in_besides_attn.", f".a_to_k.")],
                    state_dict_[name.replace(".proj_in_besides_attn.", f".a_to_v.")],
                    state_dict_[name],
                    state_dict_.pop(name),
                    state_dict_.pop(name.replace(".a_to_q.", ".a_to_k.")),
                    state_dict_.pop(name.replace(".a_to_q.", ".a_to_v.")),
                    mlp,
                ], dim=0)
                name_ = name.replace(".a_to_q.", ".to_qkv_mlp.")
                state_dict_[name_] = param
                state_dict_.pop(name.replace(".proj_in_besides_attn.", f".a_to_q."))
                state_dict_.pop(name.replace(".proj_in_besides_attn.", f".a_to_k."))
                state_dict_.pop(name.replace(".proj_in_besides_attn.", f".a_to_v."))
                state_dict_.pop(name)
        for name in list(state_dict_.keys()):
            for component in ["a", "b"]:
                if f".{component}_to_q." in name:
@@ -4,6 +4,7 @@ from .utils import init_weights_on_device
from einops import rearrange, repeat
from tqdm import tqdm
from typing import Union, Tuple, List
from .utils import hash_state_dict_keys


def HunyuanVideoRope(latents):
@@ -236,7 +237,7 @@ class IndividualTokenRefinerBlock(torch.nn.Module):
        x = x + self.mlp(self.norm2(x)) * gate_mlp.unsqueeze(1)

        return x


class SingleTokenRefiner(torch.nn.Module):
    def __init__(self, in_channels=4096, hidden_size=3072, depth=2):
@@ -269,7 +270,7 @@ class SingleTokenRefiner(torch.nn.Module):
            x = block(x, c, mask)

        return x


class ModulateDiT(torch.nn.Module):
    def __init__(self, hidden_size, factor=6):
@@ -279,9 +280,14 @@ class ModulateDiT(torch.nn.Module):

    def forward(self, x):
        return self.linear(self.act(x))


def modulate(x, shift=None, scale=None):

def modulate(x, shift=None, scale=None, tr_shift=None, tr_scale=None, tr_token=None):
    if tr_shift is not None:
        x_zero = x[:, :tr_token] * (1 + tr_scale.unsqueeze(1)) + tr_shift.unsqueeze(1)
        x_orig = x[:, tr_token:] * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
        x = torch.concat((x_zero, x_orig), dim=1)
        return x
    if scale is None and shift is None:
        return x
    elif shift is None:
@@ -290,7 +296,7 @@ def modulate(x, shift=None, scale=None):
        return x + shift.unsqueeze(1)
    else:
        return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
def reshape_for_broadcast(
    freqs_cis,
@@ -343,7 +349,7 @@ def rotate_half(x):
        x.float().reshape(*x.shape[:-1], -1, 2).unbind(-1)
    ) # [B, S, H, D//2]
    return torch.stack([-x_imag, x_real], dim=-1).flatten(3)


def apply_rotary_emb(
    xq: torch.Tensor,
@@ -385,6 +391,15 @@ def attention(q, k, v):
    return x


def apply_gate(x, gate, tr_gate=None, tr_token=None):
    if tr_gate is not None:
        x_zero = x[:, :tr_token] * tr_gate.unsqueeze(1)
        x_orig = x[:, tr_token:] * gate.unsqueeze(1)
        return torch.concat((x_zero, x_orig), dim=1)
    else:
        return x * gate.unsqueeze(1)
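The new `tr_*` arguments implement the image-to-video "token replace" trick: the first `tr_token` tokens (the latent patches of the conditioning frame, modulated as if at timestep 0 via `token_replace_vec`) get their own shift/scale/gate, while the remaining tokens keep the ordinary timestep modulation. A toy check of that behaviour, using standalone copies of the two helpers above:

```python
import torch

def modulate(x, shift=None, scale=None, tr_shift=None, tr_scale=None, tr_token=None):
    # Same logic as above: first tr_token tokens use the token-replace shift/scale,
    # the rest use the ordinary ones.
    if tr_shift is not None:
        x_zero = x[:, :tr_token] * (1 + tr_scale.unsqueeze(1)) + tr_shift.unsqueeze(1)
        x_orig = x[:, tr_token:] * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
        return torch.concat((x_zero, x_orig), dim=1)
    return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)

def apply_gate(x, gate, tr_gate=None, tr_token=None):
    if tr_gate is not None:
        return torch.concat((x[:, :tr_token] * tr_gate.unsqueeze(1),
                             x[:, tr_token:] * gate.unsqueeze(1)), dim=1)
    return x * gate.unsqueeze(1)

B, L, D, tr_token = 1, 6, 4, 2          # toy sizes: 2 conditioning-frame tokens, 4 ordinary tokens
x = torch.randn(B, L, D)
shift, scale, gate = torch.zeros(B, D), torch.zeros(B, D), torch.ones(B, D)
tr_shift, tr_scale, tr_gate = torch.ones(B, D), torch.zeros(B, D), torch.zeros(B, D)

y = modulate(x, shift, scale, tr_shift, tr_scale, tr_token)
assert torch.allclose(y[:, :tr_token], x[:, :tr_token] + 1)   # conditioning tokens shifted by tr_shift
assert torch.allclose(y[:, tr_token:], x[:, tr_token:])       # other tokens unchanged (shift = scale = 0)

z = apply_gate(x, gate, tr_gate, tr_token)
assert torch.allclose(z[:, :tr_token], torch.zeros_like(z[:, :tr_token]))  # tr_gate = 0 zeroes the update
assert torch.allclose(z[:, tr_token:], x[:, tr_token:])                    # gate = 1 passes the rest through
```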
class MMDoubleStreamBlockComponent(torch.nn.Module):
|
||||
def __init__(self, hidden_size=3072, heads_num=24, mlp_width_ratio=4):
|
||||
super().__init__()
|
||||
@@ -405,11 +420,17 @@ class MMDoubleStreamBlockComponent(torch.nn.Module):
|
||||
torch.nn.Linear(hidden_size * mlp_width_ratio, hidden_size)
|
||||
)
|
||||
|
||||
def forward(self, hidden_states, conditioning, freqs_cis=None):
|
||||
def forward(self, hidden_states, conditioning, freqs_cis=None, token_replace_vec=None, tr_token=None):
|
||||
mod1_shift, mod1_scale, mod1_gate, mod2_shift, mod2_scale, mod2_gate = self.mod(conditioning).chunk(6, dim=-1)
|
||||
if token_replace_vec is not None:
|
||||
assert tr_token is not None
|
||||
tr_mod1_shift, tr_mod1_scale, tr_mod1_gate, tr_mod2_shift, tr_mod2_scale, tr_mod2_gate = self.mod(token_replace_vec).chunk(6, dim=-1)
|
||||
else:
|
||||
tr_mod1_shift, tr_mod1_scale, tr_mod1_gate, tr_mod2_shift, tr_mod2_scale, tr_mod2_gate = None, None, None, None, None, None
|
||||
|
||||
norm_hidden_states = self.norm1(hidden_states)
|
||||
norm_hidden_states = modulate(norm_hidden_states, shift=mod1_shift, scale=mod1_scale)
|
||||
norm_hidden_states = modulate(norm_hidden_states, shift=mod1_shift, scale=mod1_scale,
|
||||
tr_shift=tr_mod1_shift, tr_scale=tr_mod1_scale, tr_token=tr_token)
|
||||
qkv = self.to_qkv(norm_hidden_states)
|
||||
q, k, v = rearrange(qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num)
|
||||
|
||||
@@ -418,15 +439,19 @@ class MMDoubleStreamBlockComponent(torch.nn.Module):
|
||||
|
||||
if freqs_cis is not None:
|
||||
q, k = apply_rotary_emb(q, k, freqs_cis, head_first=False)
|
||||
return (q, k, v), (mod1_gate, mod2_shift, mod2_scale, mod2_gate), (tr_mod1_gate, tr_mod2_shift, tr_mod2_scale, tr_mod2_gate)
|
||||
|
||||
return (q, k, v), (mod1_gate, mod2_shift, mod2_scale, mod2_gate)
|
||||
|
||||
def process_ff(self, hidden_states, attn_output, mod):
|
||||
def process_ff(self, hidden_states, attn_output, mod, mod_tr=None, tr_token=None):
|
||||
mod1_gate, mod2_shift, mod2_scale, mod2_gate = mod
|
||||
hidden_states = hidden_states + self.to_out(attn_output) * mod1_gate.unsqueeze(1)
|
||||
hidden_states = hidden_states + self.ff(modulate(self.norm2(hidden_states), shift=mod2_shift, scale=mod2_scale)) * mod2_gate.unsqueeze(1)
|
||||
if mod_tr is not None:
|
||||
tr_mod1_gate, tr_mod2_shift, tr_mod2_scale, tr_mod2_gate = mod_tr
|
||||
else:
|
||||
tr_mod1_gate, tr_mod2_shift, tr_mod2_scale, tr_mod2_gate = None, None, None, None
|
||||
hidden_states = hidden_states + apply_gate(self.to_out(attn_output), mod1_gate, tr_mod1_gate, tr_token)
|
||||
x = self.ff(modulate(self.norm2(hidden_states), shift=mod2_shift, scale=mod2_scale, tr_shift=tr_mod2_shift, tr_scale=tr_mod2_scale, tr_token=tr_token))
|
||||
hidden_states = hidden_states + apply_gate(x, mod2_gate, tr_mod2_gate, tr_token)
|
||||
return hidden_states
|
||||
|
||||
|
||||
|
||||
class MMDoubleStreamBlock(torch.nn.Module):
|
||||
def __init__(self, hidden_size=3072, heads_num=24, mlp_width_ratio=4):
|
||||
@@ -434,18 +459,18 @@ class MMDoubleStreamBlock(torch.nn.Module):
|
||||
self.component_a = MMDoubleStreamBlockComponent(hidden_size, heads_num, mlp_width_ratio)
|
||||
self.component_b = MMDoubleStreamBlockComponent(hidden_size, heads_num, mlp_width_ratio)
|
||||
|
||||
def forward(self, hidden_states_a, hidden_states_b, conditioning, freqs_cis):
|
||||
(q_a, k_a, v_a), mod_a = self.component_a(hidden_states_a, conditioning, freqs_cis)
|
||||
(q_b, k_b, v_b), mod_b = self.component_b(hidden_states_b, conditioning, freqs_cis=None)
|
||||
def forward(self, hidden_states_a, hidden_states_b, conditioning, freqs_cis, token_replace_vec=None, tr_token=None, split_token=71):
|
||||
(q_a, k_a, v_a), mod_a, mod_tr = self.component_a(hidden_states_a, conditioning, freqs_cis, token_replace_vec, tr_token)
|
||||
(q_b, k_b, v_b), mod_b, _ = self.component_b(hidden_states_b, conditioning, freqs_cis=None)
|
||||
|
||||
q_a, q_b = torch.concat([q_a, q_b[:, :71]], dim=1), q_b[:, 71:].contiguous()
|
||||
k_a, k_b = torch.concat([k_a, k_b[:, :71]], dim=1), k_b[:, 71:].contiguous()
|
||||
v_a, v_b = torch.concat([v_a, v_b[:, :71]], dim=1), v_b[:, 71:].contiguous()
|
||||
q_a, q_b = torch.concat([q_a, q_b[:, :split_token]], dim=1), q_b[:, split_token:].contiguous()
|
||||
k_a, k_b = torch.concat([k_a, k_b[:, :split_token]], dim=1), k_b[:, split_token:].contiguous()
|
||||
v_a, v_b = torch.concat([v_a, v_b[:, :split_token]], dim=1), v_b[:, split_token:].contiguous()
|
||||
attn_output_a = attention(q_a, k_a, v_a)
|
||||
attn_output_b = attention(q_b, k_b, v_b)
|
||||
attn_output_a, attn_output_b = attn_output_a[:, :-71].contiguous(), torch.concat([attn_output_a[:, -71:], attn_output_b], dim=1)
|
||||
attn_output_a, attn_output_b = attn_output_a[:, :-split_token].contiguous(), torch.concat([attn_output_a[:, -split_token:], attn_output_b], dim=1)
|
||||
|
||||
hidden_states_a = self.component_a.process_ff(hidden_states_a, attn_output_a, mod_a)
|
||||
hidden_states_a = self.component_a.process_ff(hidden_states_a, attn_output_a, mod_a, mod_tr, tr_token)
|
||||
hidden_states_b = self.component_b.process_ff(hidden_states_b, attn_output_b, mod_b)
|
||||
return hidden_states_a, hidden_states_b
|
||||
|
||||
@@ -488,7 +513,7 @@ class MMSingleStreamBlockOriginal(torch.nn.Module):
|
||||
|
||||
output = self.linear2(torch.cat((attn_output, self.mlp_act(mlp)), 2))
|
||||
return x + output * mod_gate.unsqueeze(1)
|
||||
|
||||
|
||||
|
||||
class MMSingleStreamBlock(torch.nn.Module):
|
||||
def __init__(self, hidden_size=3072, heads_num=24, mlp_width_ratio=4):
|
||||
@@ -509,11 +534,17 @@ class MMSingleStreamBlock(torch.nn.Module):
|
||||
torch.nn.Linear(hidden_size * mlp_width_ratio, hidden_size, bias=False)
|
||||
)
|
||||
|
||||
def forward(self, hidden_states, conditioning, freqs_cis=None, txt_len=256):
|
||||
def forward(self, hidden_states, conditioning, freqs_cis=None, txt_len=256, token_replace_vec=None, tr_token=None, split_token=71):
|
||||
mod_shift, mod_scale, mod_gate = self.mod(conditioning).chunk(3, dim=-1)
|
||||
if token_replace_vec is not None:
|
||||
assert tr_token is not None
|
||||
tr_mod_shift, tr_mod_scale, tr_mod_gate = self.mod(token_replace_vec).chunk(3, dim=-1)
|
||||
else:
|
||||
tr_mod_shift, tr_mod_scale, tr_mod_gate = None, None, None
|
||||
|
||||
norm_hidden_states = self.norm(hidden_states)
|
||||
norm_hidden_states = modulate(norm_hidden_states, shift=mod_shift, scale=mod_scale)
|
||||
norm_hidden_states = modulate(norm_hidden_states, shift=mod_shift, scale=mod_scale,
|
||||
tr_shift=tr_mod_shift, tr_scale=tr_mod_scale, tr_token=tr_token)
|
||||
qkv = self.to_qkv(norm_hidden_states)
|
||||
|
||||
q, k, v = rearrange(qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num)
|
||||
@@ -525,16 +556,17 @@ class MMSingleStreamBlock(torch.nn.Module):
|
||||
k_a, k_b = k[:, :-txt_len, :, :], k[:, -txt_len:, :, :]
|
||||
q_a, k_a = apply_rotary_emb(q_a, k_a, freqs_cis, head_first=False)
|
||||
|
||||
q_a, q_b = torch.concat([q_a, q_b[:, :71]], dim=1), q_b[:, 71:].contiguous()
|
||||
k_a, k_b = torch.concat([k_a, k_b[:, :71]], dim=1), k_b[:, 71:].contiguous()
|
||||
v_a, v_b = v[:, :-185].contiguous(), v[:, -185:].contiguous()
|
||||
v_len = txt_len - split_token
|
||||
q_a, q_b = torch.concat([q_a, q_b[:, :split_token]], dim=1), q_b[:, split_token:].contiguous()
|
||||
k_a, k_b = torch.concat([k_a, k_b[:, :split_token]], dim=1), k_b[:, split_token:].contiguous()
|
||||
v_a, v_b = v[:, :-v_len].contiguous(), v[:, -v_len:].contiguous()
|
||||
|
||||
attn_output_a = attention(q_a, k_a, v_a)
|
||||
attn_output_b = attention(q_b, k_b, v_b)
|
||||
attn_output = torch.concat([attn_output_a, attn_output_b], dim=1)
|
||||
|
||||
hidden_states = hidden_states + self.to_out(attn_output) * mod_gate.unsqueeze(1)
|
||||
hidden_states = hidden_states + self.ff(norm_hidden_states) * mod_gate.unsqueeze(1)
|
||||
hidden_states = hidden_states + apply_gate(self.to_out(attn_output), mod_gate, tr_mod_gate, tr_token)
|
||||
hidden_states = hidden_states + apply_gate(self.ff(norm_hidden_states), mod_gate, tr_mod_gate, tr_token)
|
||||
return hidden_states
|
||||
|
||||
|
||||
@@ -555,7 +587,7 @@ class FinalLayer(torch.nn.Module):
|
||||
|
||||
|
||||
class HunyuanVideoDiT(torch.nn.Module):
|
||||
def __init__(self, in_channels=16, hidden_size=3072, text_dim=4096, num_double_blocks=20, num_single_blocks=40):
|
||||
def __init__(self, in_channels=16, hidden_size=3072, text_dim=4096, num_double_blocks=20, num_single_blocks=40, guidance_embed=True):
|
||||
super().__init__()
|
||||
self.img_in = PatchEmbed(in_channels=in_channels, embed_dim=hidden_size)
|
||||
self.txt_in = SingleTokenRefiner(in_channels=text_dim, hidden_size=hidden_size)
|
||||
@@ -565,7 +597,7 @@ class HunyuanVideoDiT(torch.nn.Module):
|
||||
torch.nn.SiLU(),
|
||||
torch.nn.Linear(hidden_size, hidden_size)
|
||||
)
|
||||
self.guidance_in = TimestepEmbeddings(256, hidden_size, computation_device="cpu")
|
||||
self.guidance_in = TimestepEmbeddings(256, hidden_size, computation_device="cpu") if guidance_embed else None
|
||||
self.double_blocks = torch.nn.ModuleList([MMDoubleStreamBlock(hidden_size) for _ in range(num_double_blocks)])
|
||||
self.single_blocks = torch.nn.ModuleList([MMSingleStreamBlock(hidden_size) for _ in range(num_single_blocks)])
|
||||
self.final_layer = FinalLayer(hidden_size)
|
||||
@@ -580,7 +612,7 @@ class HunyuanVideoDiT(torch.nn.Module):
|
||||
def unpatchify(self, x, T, H, W):
|
||||
x = rearrange(x, "B (T H W) (C pT pH pW) -> B C (T pT) (H pH) (W pW)", H=H, W=W, pT=1, pH=2, pW=2)
|
||||
return x
|
||||
|
||||
|
||||
def enable_block_wise_offload(self, warm_device="cuda", cold_device="cpu"):
|
||||
self.warm_device = warm_device
|
||||
self.cold_device = cold_device
|
||||
@@ -610,10 +642,12 @@ class HunyuanVideoDiT(torch.nn.Module):
|
||||
):
|
||||
B, C, T, H, W = x.shape
|
||||
|
||||
vec = self.time_in(t, dtype=torch.float32) + self.vector_in(pooled_prompt_emb) + self.guidance_in(guidance * 1000, dtype=torch.float32)
|
||||
vec = self.time_in(t, dtype=torch.float32) + self.vector_in(pooled_prompt_emb)
|
||||
if self.guidance_in is not None:
|
||||
vec += self.guidance_in(guidance * 1000, dtype=torch.float32)
|
||||
img = self.img_in(x)
|
||||
txt = self.txt_in(prompt_emb, t, text_mask)
|
||||
|
||||
|
||||
for block in tqdm(self.double_blocks, desc="Double stream blocks"):
|
||||
img, txt = block(img, txt, vec, (freqs_cos, freqs_sin))
|
||||
|
||||
@@ -625,7 +659,7 @@ class HunyuanVideoDiT(torch.nn.Module):
|
||||
img = self.final_layer(img, vec)
|
||||
img = self.unpatchify(img, T=T//1, H=H//2, W=W//2)
|
||||
return img
|
||||
|
||||
|
||||
|
||||
def enable_auto_offload(self, dtype=torch.bfloat16, device="cuda"):
|
||||
def cast_to(weight, dtype=None, device=None, copy=False):
|
||||
@@ -681,7 +715,7 @@ class HunyuanVideoDiT(torch.nn.Module):
|
||||
del x_, weight_, bias_
|
||||
torch.cuda.empty_cache()
|
||||
return y_
|
||||
|
||||
|
||||
def block_forward(self, x, **kwargs):
|
||||
# This feature can only reduce 2GB VRAM, so we disable it.
|
||||
y = torch.zeros(x.shape[:-1] + (self.out_features,), dtype=x.dtype, device=x.device)
|
||||
@@ -689,19 +723,19 @@ class HunyuanVideoDiT(torch.nn.Module):
|
||||
for j in range((self.out_features + self.block_size - 1) // self.block_size):
|
||||
y[..., j * self.block_size: (j + 1) * self.block_size] += self.block_forward_(x, i, j, dtype=x.dtype, device=x.device)
|
||||
return y
|
||||
|
||||
|
||||
def forward(self, x, **kwargs):
|
||||
weight, bias = cast_bias_weight(self, x, dtype=self.dtype, device=self.device)
|
||||
return torch.nn.functional.linear(x, weight, bias)
|
||||
|
||||
|
||||
|
||||
class RMSNorm(torch.nn.Module):
|
||||
def __init__(self, module, dtype=torch.bfloat16, device="cuda"):
|
||||
super().__init__()
|
||||
self.module = module
|
||||
self.dtype = dtype
|
||||
self.device = device
|
||||
|
||||
|
||||
def forward(self, hidden_states, **kwargs):
|
||||
input_dtype = hidden_states.dtype
|
||||
variance = hidden_states.to(torch.float32).square().mean(-1, keepdim=True)
|
||||
@@ -711,30 +745,30 @@ class HunyuanVideoDiT(torch.nn.Module):
|
||||
weight = cast_weight(self.module, hidden_states, dtype=torch.bfloat16, device="cuda")
|
||||
hidden_states = hidden_states * weight
|
||||
return hidden_states
|
||||
|
||||
|
||||
class Conv3d(torch.nn.Conv3d):
|
||||
def __init__(self, *args, dtype=torch.bfloat16, device="cuda", **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.dtype = dtype
|
||||
self.device = device
|
||||
|
||||
|
||||
def forward(self, x):
|
||||
weight, bias = cast_bias_weight(self, x, dtype=self.dtype, device=self.device)
|
||||
return torch.nn.functional.conv3d(x, weight, bias, self.stride, self.padding, self.dilation, self.groups)
|
||||
|
||||
|
||||
class LayerNorm(torch.nn.LayerNorm):
|
||||
def __init__(self, *args, dtype=torch.bfloat16, device="cuda", **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.dtype = dtype
|
||||
self.device = device
|
||||
|
||||
|
||||
def forward(self, x):
|
||||
if self.weight is not None and self.bias is not None:
|
||||
weight, bias = cast_bias_weight(self, x, dtype=self.dtype, device=self.device)
|
||||
return torch.nn.functional.layer_norm(x, self.normalized_shape, weight, bias, self.eps)
|
||||
else:
|
||||
return torch.nn.functional.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
|
||||
|
||||
|
||||
def replace_layer(model, dtype=torch.bfloat16, device="cuda"):
|
||||
for name, module in model.named_children():
|
||||
if isinstance(module, torch.nn.Linear):
|
||||
@@ -777,12 +811,12 @@ class HunyuanVideoDiT(torch.nn.Module):
|
||||
return HunyuanVideoDiTStateDictConverter()
|
||||
|
||||
|
||||
|
||||
class HunyuanVideoDiTStateDictConverter:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def from_civitai(self, state_dict):
|
||||
origin_hash_key = hash_state_dict_keys(state_dict, with_shape=True)
|
||||
if "module" in state_dict:
|
||||
state_dict = state_dict["module"]
|
||||
direct_dict = {
|
||||
@@ -882,4 +916,5 @@ class HunyuanVideoDiTStateDictConverter:
|
||||
state_dict_[name_] = param
|
||||
else:
|
||||
pass
|
||||
|
||||
return state_dict_
|
||||
|
||||
@@ -1,24 +1,18 @@
|
||||
from transformers import LlamaModel, LlamaConfig, DynamicCache
|
||||
from transformers import LlamaModel, LlamaConfig, DynamicCache, LlavaForConditionalGeneration
|
||||
from copy import deepcopy
|
||||
import torch
|
||||
|
||||
|
||||
class HunyuanVideoLLMEncoder(LlamaModel):
|
||||
|
||||
def __init__(self, config: LlamaConfig):
|
||||
super().__init__(config)
|
||||
self.auto_offload = False
|
||||
|
||||
|
||||
def enable_auto_offload(self, **kwargs):
|
||||
self.auto_offload = True
|
||||
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids,
|
||||
attention_mask,
|
||||
hidden_state_skip_layer=2
|
||||
):
|
||||
def forward(self, input_ids, attention_mask, hidden_state_skip_layer=2):
|
||||
embed_tokens = deepcopy(self.embed_tokens).to(input_ids.device) if self.auto_offload else self.embed_tokens
|
||||
inputs_embeds = embed_tokens(input_ids)
|
||||
|
||||
@@ -53,3 +47,22 @@ class HunyuanVideoLLMEncoder(LlamaModel):
|
||||
break
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class HunyuanVideoMLLMEncoder(LlavaForConditionalGeneration):
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
self.auto_offload = False
|
||||
|
||||
def enable_auto_offload(self, **kwargs):
|
||||
self.auto_offload = True
|
||||
|
||||
# TODO: implement the low VRAM inference for MLLM.
|
||||
def forward(self, input_ids, pixel_values, attention_mask, hidden_state_skip_layer=2):
|
||||
outputs = super().forward(input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
output_hidden_states=True,
|
||||
pixel_values=pixel_values)
|
||||
hidden_state = outputs.hidden_states[-(hidden_state_skip_layer + 1)]
|
||||
return hidden_state
|
||||
|
||||
@@ -195,70 +195,73 @@ class FluxLoRAFromCivitai(LoRAFromCivitai):
|
||||
"txt.mod": "txt_mod",
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
class GeneralLoRAFromPeft:
|
||||
def __init__(self):
|
||||
self.supported_model_classes = [SDUNet, SDXLUNet, SD3DiT, HunyuanDiT, FluxDiT, CogDiT, WanModel]
|
||||
|
||||
|
||||
def fetch_device_dtype_from_state_dict(self, state_dict):
|
||||
device, torch_dtype = None, None
|
||||
for name, param in state_dict.items():
|
||||
device, torch_dtype = param.device, param.dtype
|
||||
break
|
||||
return device, torch_dtype
|
||||
|
||||
|
||||
def convert_state_dict(self, state_dict, alpha=1.0, target_state_dict={}):
|
||||
device, torch_dtype = self.fetch_device_dtype_from_state_dict(target_state_dict)
|
||||
state_dict_ = {}
|
||||
for key in state_dict:
|
||||
|
||||
|
||||
def get_name_dict(self, lora_state_dict):
|
||||
lora_name_dict = {}
|
||||
for key in lora_state_dict:
|
||||
if ".lora_B." not in key:
|
||||
continue
|
||||
weight_up = state_dict[key].to(device=device, dtype=torch_dtype)
|
||||
weight_down = state_dict[key.replace(".lora_B.", ".lora_A.")].to(device=device, dtype=torch_dtype)
|
||||
if len(weight_up.shape) == 4:
|
||||
weight_up = weight_up.squeeze(3).squeeze(2)
|
||||
weight_down = weight_down.squeeze(3).squeeze(2)
|
||||
lora_weight = alpha * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3)
|
||||
else:
|
||||
lora_weight = alpha * torch.mm(weight_up, weight_down)
|
||||
keys = key.split(".")
|
||||
if len(keys) > keys.index("lora_B") + 2:
|
||||
keys.pop(keys.index("lora_B") + 1)
|
||||
keys.pop(keys.index("lora_B"))
|
||||
if keys[0] == "diffusion_model":
|
||||
keys.pop(0)
|
||||
target_name = ".".join(keys)
|
||||
if target_name not in target_state_dict:
|
||||
return {}
|
||||
state_dict_[target_name] = lora_weight.cpu()
|
||||
return state_dict_
|
||||
lora_name_dict[target_name] = (key, key.replace(".lora_B.", ".lora_A."))
|
||||
return lora_name_dict
|
||||
|
||||
|
||||
def match(self, model: torch.nn.Module, state_dict_lora):
|
||||
lora_name_dict = self.get_name_dict(state_dict_lora)
|
||||
model_name_dict = {name: None for name, _ in model.named_parameters()}
|
||||
matched_num = sum([i in model_name_dict for i in lora_name_dict])
|
||||
if matched_num == len(lora_name_dict):
|
||||
return "", ""
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
def fetch_device_and_dtype(self, state_dict):
|
||||
device, dtype = None, None
|
||||
for name, param in state_dict.items():
|
||||
device, dtype = param.device, param.dtype
|
||||
break
|
||||
computation_device = device
|
||||
computation_dtype = dtype
|
||||
if computation_device == torch.device("cpu"):
|
||||
if torch.cuda.is_available():
|
||||
computation_device = torch.device("cuda")
|
||||
if computation_dtype == torch.float8_e4m3fn:
|
||||
computation_dtype = torch.float32
|
||||
return device, dtype, computation_device, computation_dtype
|
||||
|
||||
|
||||
def load(self, model, state_dict_lora, lora_prefix="", alpha=1.0, model_resource=""):
|
||||
state_dict_model = model.state_dict()
|
||||
state_dict_lora = self.convert_state_dict(state_dict_lora, alpha=alpha, target_state_dict=state_dict_model)
|
||||
if len(state_dict_lora) > 0:
|
||||
print(f" {len(state_dict_lora)} tensors are updated.")
|
||||
for name in state_dict_lora:
|
||||
state_dict_model[name] += state_dict_lora[name].to(
|
||||
dtype=state_dict_model[name].dtype,
|
||||
device=state_dict_model[name].device
|
||||
)
|
||||
model.load_state_dict(state_dict_model)
|
||||
device, dtype, computation_device, computation_dtype = self.fetch_device_and_dtype(state_dict_model)
|
||||
lora_name_dict = self.get_name_dict(state_dict_lora)
|
||||
for name in lora_name_dict:
|
||||
weight_up = state_dict_lora[lora_name_dict[name][0]].to(device=computation_device, dtype=computation_dtype)
|
||||
weight_down = state_dict_lora[lora_name_dict[name][1]].to(device=computation_device, dtype=computation_dtype)
|
||||
if len(weight_up.shape) == 4:
|
||||
weight_up = weight_up.squeeze(3).squeeze(2)
|
||||
weight_down = weight_down.squeeze(3).squeeze(2)
|
||||
weight_lora = alpha * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3)
|
||||
else:
|
||||
weight_lora = alpha * torch.mm(weight_up, weight_down)
|
||||
weight_model = state_dict_model[name].to(device=computation_device, dtype=computation_dtype)
|
||||
weight_patched = weight_model + weight_lora
|
||||
state_dict_model[name] = weight_patched.to(device=device, dtype=dtype)
|
||||
print(f" {len(lora_name_dict)} tensors are updated.")
|
||||
model.load_state_dict(state_dict_model)
|
||||
|
||||
|
||||
def match(self, model, state_dict_lora):
|
||||
for model_class in self.supported_model_classes:
|
||||
if not isinstance(model, model_class):
|
||||
continue
|
||||
state_dict_model = model.state_dict()
|
||||
try:
|
||||
state_dict_lora_ = self.convert_state_dict(state_dict_lora, alpha=1.0, target_state_dict=state_dict_model)
|
||||
if len(state_dict_lora_) > 0:
|
||||
return "", ""
|
||||
except:
|
||||
pass
|
||||
return None
|
||||
|
||||
|
||||
class HunyuanVideoLoRAFromCivitai(LoRAFromCivitai):
|
||||
|
||||
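The reworked `GeneralLoRAFromPeft.load` above no longer converts the whole LoRA into a delta state dict; instead it looks up each `lora_A`/`lora_B` pair, moves it to a usable computation device and dtype (CUDA, and float32 when the base weight is fp8), and folds `alpha * (lora_B @ lora_A)` directly into the base parameter. The core arithmetic, as a minimal standalone sketch:

```python
import torch

def merge_lora_weight(weight, lora_A, lora_B, alpha=1.0):
    # weight: (out, in); lora_A: (rank, in); lora_B: (out, rank)
    update = alpha * torch.mm(lora_B, lora_A)
    if weight.dim() == 4:
        # conv weights are handled above by squeezing/unsqueezing the trailing 1x1 dims
        update = update.unsqueeze(2).unsqueeze(3)
    return weight + update.to(dtype=weight.dtype, device=weight.device)

out_features, in_features, rank = 8, 16, 4
w = torch.randn(out_features, in_features)
a = torch.randn(rank, in_features)
b = torch.zeros(out_features, rank)  # a zero lora_B makes the merge a no-op
assert torch.allclose(merge_lora_weight(w, a, b), w)
```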
@@ -376,6 +376,7 @@ class ModelManager:
|
||||
self.load_lora(file_path_, state_dict=state_dict, lora_alpha=lora_alpha)
|
||||
else:
|
||||
print(f"Loading LoRA models from file: {file_path}")
|
||||
is_loaded = False
|
||||
if len(state_dict) == 0:
|
||||
state_dict = load_state_dict(file_path)
|
||||
for model_name, model, model_path in zip(self.model_name, self.model, self.model_path):
|
||||
@@ -385,7 +386,10 @@ class ModelManager:
|
||||
print(f" Adding LoRA to {model_name} ({model_path}).")
|
||||
lora_prefix, model_resource = match_results
|
||||
lora.load(model, state_dict, lora_prefix, alpha=lora_alpha, model_resource=model_resource)
|
||||
is_loaded = True
|
||||
break
|
||||
if not is_loaded:
|
||||
print(f" Cannot load LoRA: {file_path}")
|
||||
|
||||
|
||||
def load_model(self, file_path, model_names=None, device=None, torch_dtype=None):
|
||||
|
||||
diffsynth/models/wan_video_controlnet.py (new file, 204 lines)
@@ -0,0 +1,204 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from typing import Tuple, Optional
|
||||
from einops import rearrange
|
||||
from .wan_video_dit import DiTBlock, precompute_freqs_cis_3d, MLP, sinusoidal_embedding_1d
|
||||
from .utils import hash_state_dict_keys
|
||||
|
||||
|
||||
|
||||
class WanControlNetModel(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
in_dim: int,
|
||||
ffn_dim: int,
|
||||
out_dim: int,
|
||||
text_dim: int,
|
||||
freq_dim: int,
|
||||
eps: float,
|
||||
patch_size: Tuple[int, int, int],
|
||||
num_heads: int,
|
||||
num_layers: int,
|
||||
has_image_input: bool,
|
||||
):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.freq_dim = freq_dim
|
||||
self.has_image_input = has_image_input
|
||||
self.patch_size = patch_size
|
||||
|
||||
self.patch_embedding = nn.Conv3d(
|
||||
in_dim, dim, kernel_size=patch_size, stride=patch_size)
|
||||
self.text_embedding = nn.Sequential(
|
||||
nn.Linear(text_dim, dim),
|
||||
nn.GELU(approximate='tanh'),
|
||||
nn.Linear(dim, dim)
|
||||
)
|
||||
self.time_embedding = nn.Sequential(
|
||||
nn.Linear(freq_dim, dim),
|
||||
nn.SiLU(),
|
||||
nn.Linear(dim, dim)
|
||||
)
|
||||
self.time_projection = nn.Sequential(
|
||||
nn.SiLU(), nn.Linear(dim, dim * 6))
|
||||
self.blocks = nn.ModuleList([
|
||||
DiTBlock(has_image_input, dim, num_heads, ffn_dim, eps)
|
||||
for _ in range(num_layers)
|
||||
])
|
||||
head_dim = dim // num_heads
|
||||
self.freqs = precompute_freqs_cis_3d(head_dim)
|
||||
|
||||
if has_image_input:
|
||||
self.img_emb = MLP(1280, dim) # clip_feature_dim = 1280
|
||||
|
||||
self.controlnet_conv_in = torch.nn.Conv3d(in_channels=in_dim, out_channels=in_dim, kernel_size=1)
|
||||
self.controlnet_blocks = torch.nn.ModuleList([
|
||||
torch.nn.Linear(dim, dim, bias=False)
|
||||
for _ in range(num_layers)
|
||||
])
|
||||
|
||||
def patchify(self, x: torch.Tensor):
|
||||
x = self.patch_embedding(x)
|
||||
grid_size = x.shape[2:]
|
||||
x = rearrange(x, 'b c f h w -> b (f h w) c').contiguous()
|
||||
return x, grid_size # x, grid_size: (f, h, w)
|
||||
|
||||
def unpatchify(self, x: torch.Tensor, grid_size: torch.Tensor):
|
||||
return rearrange(
|
||||
x, 'b (f h w) (x y z c) -> b c (f x) (h y) (w z)',
|
||||
f=grid_size[0], h=grid_size[1], w=grid_size[2],
|
||||
x=self.patch_size[0], y=self.patch_size[1], z=self.patch_size[2]
|
||||
)
|
||||
|
||||
def forward(self,
|
||||
x: torch.Tensor,
|
||||
timestep: torch.Tensor,
|
||||
context: torch.Tensor,
|
||||
clip_feature: Optional[torch.Tensor] = None,
|
||||
y: Optional[torch.Tensor] = None,
|
||||
controlnet_conditioning: Optional[torch.Tensor] = None,
|
||||
use_gradient_checkpointing: bool = False,
|
||||
use_gradient_checkpointing_offload: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
t = self.time_embedding(
|
||||
sinusoidal_embedding_1d(self.freq_dim, timestep))
|
||||
t_mod = self.time_projection(t).unflatten(1, (6, self.dim))
|
||||
context = self.text_embedding(context)
|
||||
|
||||
if self.has_image_input:
|
||||
x = torch.cat([x, y], dim=1) # (b, c_x + c_y, f, h, w)
|
||||
clip_embedding = self.img_emb(clip_feature)
context = torch.cat([clip_embedding, context], dim=1)
|
||||
|
||||
x = x + self.controlnet_conv_in(controlnet_conditioning)
|
||||
x, (f, h, w) = self.patchify(x)
|
||||
|
||||
freqs = torch.cat([
|
||||
self.freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
|
||||
self.freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
|
||||
self.freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
|
||||
], dim=-1).reshape(f * h * w, 1, -1).to(x.device)
|
||||
|
||||
def create_custom_forward(module):
|
||||
def custom_forward(*inputs):
|
||||
return module(*inputs)
|
||||
return custom_forward
|
||||
|
||||
res_stack = []
|
||||
for block in self.blocks:
|
||||
if self.training and use_gradient_checkpointing:
|
||||
if use_gradient_checkpointing_offload:
|
||||
with torch.autograd.graph.save_on_cpu():
|
||||
x = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(block),
|
||||
x, context, t_mod, freqs,
|
||||
use_reentrant=False,
|
||||
)
|
||||
else:
|
||||
x = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(block),
|
||||
x, context, t_mod, freqs,
|
||||
use_reentrant=False,
|
||||
)
|
||||
else:
|
||||
x = block(x, context, t_mod, freqs)
|
||||
res_stack.append(x)
|
||||
|
||||
controlnet_res_stack = [block(res) for block, res in zip(self.controlnet_blocks, res_stack)]
|
||||
return controlnet_res_stack
|
||||
|
||||
@staticmethod
|
||||
def state_dict_converter():
|
||||
return WanControlNetModelStateDictConverter()
|
||||
|
||||
|
||||
class WanControlNetModelStateDictConverter:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def from_diffusers(self, state_dict):
|
||||
return state_dict
|
||||
|
||||
def from_civitai(self, state_dict):
|
||||
return state_dict
|
||||
|
||||
def from_base_model(self, state_dict):
|
||||
if hash_state_dict_keys(state_dict) == "9269f8db9040a9d860eaca435be61814":
|
||||
config = {
|
||||
"has_image_input": False,
|
||||
"patch_size": [1, 2, 2],
|
||||
"in_dim": 16,
|
||||
"dim": 1536,
|
||||
"ffn_dim": 8960,
|
||||
"freq_dim": 256,
|
||||
"text_dim": 4096,
|
||||
"out_dim": 16,
|
||||
"num_heads": 12,
|
||||
"num_layers": 30,
|
||||
"eps": 1e-6
|
||||
}
|
||||
elif hash_state_dict_keys(state_dict) == "aafcfd9672c3a2456dc46e1cb6e52c70":
|
||||
config = {
|
||||
"has_image_input": False,
|
||||
"patch_size": [1, 2, 2],
|
||||
"in_dim": 16,
|
||||
"dim": 5120,
|
||||
"ffn_dim": 13824,
|
||||
"freq_dim": 256,
|
||||
"text_dim": 4096,
|
||||
"out_dim": 16,
|
||||
"num_heads": 40,
|
||||
"num_layers": 40,
|
||||
"eps": 1e-6
|
||||
}
|
||||
elif hash_state_dict_keys(state_dict) == "6bfcfb3b342cb286ce886889d519a77e":
|
||||
config = {
|
||||
"has_image_input": True,
|
||||
"patch_size": [1, 2, 2],
|
||||
"in_dim": 36,
|
||||
"dim": 5120,
|
||||
"ffn_dim": 13824,
|
||||
"freq_dim": 256,
|
||||
"text_dim": 4096,
|
||||
"out_dim": 16,
|
||||
"num_heads": 40,
|
||||
"num_layers": 40,
|
||||
"eps": 1e-6
|
||||
}
|
||||
else:
|
||||
config = {}
|
||||
state_dict_ = {}
|
||||
dtype, device = None, None
|
||||
for name, param in state_dict.items():
|
||||
if name.startswith("head."):
|
||||
continue
|
||||
state_dict_[name] = param
|
||||
dtype, device = param.dtype, param.device
|
||||
for block_id in range(config["num_layers"]):
|
||||
zeros = torch.zeros((config["dim"], config["dim"]), dtype=dtype, device=device)
|
||||
state_dict_[f"controlnet_blocks.{block_id}.weight"] = zeros.clone()
|
||||
state_dict_["controlnet_conv_in.weight"] = torch.zeros((config["in_dim"], config["in_dim"], 1, 1, 1), dtype=dtype, device=device)
|
||||
state_dict_["controlnet_conv_in.bias"] = torch.zeros((config["in_dim"],), dtype=dtype, device=device)
|
||||
return state_dict_, config
|
||||
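`from_base_model` turns an existing Wan DiT state dict into a ControlNet checkpoint: it copies everything except the `head.*` weights, then appends zero-filled `controlnet_blocks.*` projections and a zero `controlnet_conv_in`, so the untrained ControlNet injects nothing into the base model at first (the usual ControlNet zero-init trick). A compact illustration of why a zero projection makes the extra branch a no-op at initialisation:

```python
import torch

dim, tokens = 8, 5
hidden = torch.randn(1, tokens, dim)                # residual coming out of a DiT block
zero_proj = torch.nn.Linear(dim, dim, bias=False)
torch.nn.init.zeros_(zero_proj.weight)              # matches the zero-filled controlnet_blocks above

control_residual = zero_proj(hidden)
assert torch.count_nonzero(control_residual) == 0   # added to the main branch, it changes nothing yet
```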
(File diff suppressed because it is too large.)
@@ -228,7 +228,7 @@ class QuickGELU(nn.Module):
|
||||
class LayerNorm(nn.LayerNorm):
|
||||
|
||||
def forward(self, x):
|
||||
return super().forward(x.float()).type_as(x)
|
||||
return super().forward(x).type_as(x)
|
||||
|
||||
|
||||
class SelfAttention(nn.Module):
|
||||
@@ -256,15 +256,11 @@ class SelfAttention(nn.Module):
|
||||
"""
|
||||
x: [B, L, C].
|
||||
"""
|
||||
b, s, c, n, d = *x.size(), self.num_heads, self.head_dim
|
||||
|
||||
# compute query, key, value
|
||||
q, k, v = self.to_qkv(x).view(b, s, 3, n, d).unbind(2)
|
||||
q, k, v = self.to_qkv(x).chunk(3, dim=-1)
|
||||
|
||||
# compute attention
|
||||
p = self.attn_dropout if self.training else 0.0
|
||||
x = flash_attention(q, k, v, dropout_p=p, causal=self.causal, version=2)
|
||||
x = x.reshape(b, s, c)
|
||||
x = flash_attention(q, k, v, num_heads=self.num_heads, compatibility_mode=True)
|
||||
|
||||
# output
|
||||
x = self.proj(x)
|
||||
@@ -371,11 +367,11 @@ class AttentionPool(nn.Module):
|
||||
b, s, c, n, d = *x.size(), self.num_heads, self.head_dim
|
||||
|
||||
# compute query, key, value
|
||||
q = self.to_q(self.cls_embedding).view(1, 1, n, d).expand(b, -1, -1, -1)
|
||||
k, v = self.to_kv(x).view(b, s, 2, n, d).unbind(2)
|
||||
q = self.to_q(self.cls_embedding).view(1, 1, n*d).expand(b, -1, -1)
|
||||
k, v = self.to_kv(x).chunk(2, dim=-1)
|
||||
|
||||
# compute attention
|
||||
x = flash_attention(q, k, v, version=2)
|
||||
x = flash_attention(q, k, v, num_heads=self.num_heads, compatibility_mode=True)
|
||||
x = x.reshape(b, 1, c)
|
||||
|
||||
# output
|
||||
@@ -878,6 +874,8 @@ class WanImageEncoder(torch.nn.Module):
|
||||
videos = self.transforms.transforms[-1](videos.mul_(0.5).add_(0.5))
|
||||
|
||||
# forward
|
||||
dtype = next(iter(self.model.visual.parameters())).dtype
|
||||
videos = videos.to(dtype)
|
||||
out = self.model.visual(videos, use_31_block=True)
|
||||
return out
|
||||
|
||||
|
||||
diffsynth/models/wan_video_motion_controller.py (new file, 27 lines)
@@ -0,0 +1,27 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from .wan_video_dit import sinusoidal_embedding_1d
|
||||
|
||||
|
||||
|
||||
class WanMotionControllerModel(torch.nn.Module):
|
||||
def __init__(self, freq_dim=256, dim=1536):
|
||||
super().__init__()
|
||||
self.freq_dim = freq_dim
|
||||
self.linear = nn.Sequential(
|
||||
nn.Linear(freq_dim, dim),
|
||||
nn.SiLU(),
|
||||
nn.Linear(dim, dim),
|
||||
nn.SiLU(),
|
||||
nn.Linear(dim, dim * 6),
|
||||
)
|
||||
|
||||
def forward(self, motion_bucket_id):
|
||||
emb = sinusoidal_embedding_1d(self.freq_dim, motion_bucket_id * 10)
|
||||
emb = self.linear(emb)
|
||||
return emb
|
||||
|
||||
def init(self):
|
||||
state_dict = self.linear[-1].state_dict()
|
||||
state_dict = {i: state_dict[i] * 0 for i in state_dict}
|
||||
self.linear[-1].load_state_dict(state_dict)
|
||||
@@ -688,7 +688,7 @@ class WanVideoVAE(nn.Module):
|
||||
target_w: target_w + hidden_states_batch.shape[4],
|
||||
] += mask
|
||||
values = values / weight
|
||||
values = values.float().clamp_(-1, 1)
|
||||
values = values.clamp_(-1, 1)
|
||||
return values
|
||||
|
||||
|
||||
@@ -740,20 +740,19 @@ class WanVideoVAE(nn.Module):
|
||||
target_w: target_w + hidden_states_batch.shape[4],
|
||||
] += mask
|
||||
values = values / weight
|
||||
values = values.float()
|
||||
return values
|
||||
|
||||
|
||||
def single_encode(self, video, device):
|
||||
video = video.to(device)
|
||||
x = self.model.encode(video, self.scale)
|
||||
return x.float()
|
||||
return x
|
||||
|
||||
|
||||
def single_decode(self, hidden_state, device):
|
||||
hidden_state = hidden_state.to(device)
|
||||
video = self.model.decode(hidden_state, self.scale)
|
||||
return video.float().clamp_(-1, 1)
|
||||
return video.clamp_(-1, 1)
|
||||
|
||||
|
||||
def encode(self, videos, device, tiled=False, tile_size=(34, 34), tile_stride=(18, 16)):
|
||||
|
||||
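The tiled encode/decode path above accumulates each tile's output into `values` together with a blending `mask` into `weight`, then divides; the other change simply keeps the result in the compute dtype instead of forcing float32. A 1-D sketch of that overlap-blending scheme (a hypothetical helper, not the library code):

```python
import torch

def blend_tiles(tiles, positions, length, tile_size):
    # tiles: list of (tile_size,) tensors processed independently; positions: their start offsets.
    values = torch.zeros(length)
    weight = torch.zeros(length)
    ramp = torch.linspace(0, 1, tile_size)
    mask = torch.minimum(ramp, ramp.flip(0)) + 1e-3       # feathered mask, highest in the tile centre
    for tile, pos in zip(tiles, positions):
        values[pos:pos + tile_size] += tile * mask
        weight[pos:pos + tile_size] += mask
    return values / weight                                # same normalisation as `values = values / weight`

a, b = torch.ones(8), torch.ones(8) * 3
out = blend_tiles([a, b], positions=[0, 4], length=12, tile_size=8)
assert torch.allclose(out[:4], torch.ones(4))             # only tile a covers the first 4 samples
assert torch.allclose(out[-4:], torch.ones(4) * 3)        # only tile b covers the last 4
```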
@@ -5,13 +5,13 @@ from ..schedulers.flow_match import FlowMatchScheduler
|
||||
from .base import BasePipeline
|
||||
from ..prompters import HunyuanVideoPrompter
|
||||
import torch
|
||||
import torchvision.transforms as transforms
|
||||
from einops import rearrange
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
|
||||
class HunyuanVideoPipeline(BasePipeline):
|
||||
|
||||
def __init__(self, device="cuda", torch_dtype=torch.float16):
|
||||
@@ -53,10 +53,58 @@ class HunyuanVideoPipeline(BasePipeline):
|
||||
pipe.enable_vram_management()
|
||||
return pipe
|
||||
|
||||
def generate_crop_size_list(self, base_size=256, patch_size=32, max_ratio=4.0):
|
||||
num_patches = round((base_size / patch_size)**2)
|
||||
assert max_ratio >= 1.0
|
||||
crop_size_list = []
|
||||
wp, hp = num_patches, 1
|
||||
while wp > 0:
|
||||
if max(wp, hp) / min(wp, hp) <= max_ratio:
|
||||
crop_size_list.append((wp * patch_size, hp * patch_size))
|
||||
if (hp + 1) * wp <= num_patches:
|
||||
hp += 1
|
||||
else:
|
||||
wp -= 1
|
||||
return crop_size_list
|
||||
|
||||
def encode_prompt(self, prompt, positive=True, clip_sequence_length=77, llm_sequence_length=256):
|
||||
|
||||
def get_closest_ratio(self, height: float, width: float, ratios: list, buckets: list):
|
||||
aspect_ratio = float(height) / float(width)
|
||||
closest_ratio_id = np.abs(ratios - aspect_ratio).argmin()
|
||||
closest_ratio = min(ratios, key=lambda ratio: abs(float(ratio) - aspect_ratio))
|
||||
return buckets[closest_ratio_id], float(closest_ratio)
|
||||
|
||||
|
||||
def prepare_vae_images_inputs(self, semantic_images, i2v_resolution="720p"):
|
||||
if i2v_resolution == "720p":
|
||||
bucket_hw_base_size = 960
|
||||
elif i2v_resolution == "540p":
|
||||
bucket_hw_base_size = 720
|
||||
elif i2v_resolution == "360p":
|
||||
bucket_hw_base_size = 480
|
||||
else:
|
||||
raise ValueError(f"i2v_resolution: {i2v_resolution} must be in [360p, 540p, 720p]")
|
||||
origin_size = semantic_images[0].size
|
||||
|
||||
crop_size_list = self.generate_crop_size_list(bucket_hw_base_size, 32)
|
||||
aspect_ratios = np.array([round(float(h) / float(w), 5) for h, w in crop_size_list])
|
||||
closest_size, closest_ratio = self.get_closest_ratio(origin_size[1], origin_size[0], aspect_ratios, crop_size_list)
|
||||
ref_image_transform = transforms.Compose([
|
||||
transforms.Resize(closest_size),
|
||||
transforms.CenterCrop(closest_size),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize([0.5], [0.5])
|
||||
])
|
||||
|
||||
semantic_image_pixel_values = [ref_image_transform(semantic_image) for semantic_image in semantic_images]
|
||||
semantic_image_pixel_values = torch.cat(semantic_image_pixel_values).unsqueeze(0).unsqueeze(2).to(self.device)
|
||||
target_height, target_width = closest_size
|
||||
return semantic_image_pixel_values, target_height, target_width
|
||||
|
||||
|
||||
def encode_prompt(self, prompt, positive=True, clip_sequence_length=77, llm_sequence_length=256, input_images=None):
|
||||
prompt_emb, pooled_prompt_emb, text_mask = self.prompter.encode_prompt(
|
||||
prompt, device=self.device, positive=positive, clip_sequence_length=clip_sequence_length, llm_sequence_length=llm_sequence_length
|
||||
prompt, device=self.device, positive=positive, clip_sequence_length=clip_sequence_length, llm_sequence_length=llm_sequence_length, images=input_images
|
||||
)
|
||||
return {"prompt_emb": prompt_emb, "pooled_prompt_emb": pooled_prompt_emb, "text_mask": text_mask}
|
||||
|
||||
@@ -87,6 +135,9 @@ class HunyuanVideoPipeline(BasePipeline):
|
||||
prompt,
|
||||
negative_prompt="",
|
||||
input_video=None,
|
||||
input_images=None,
|
||||
i2v_resolution="720p",
|
||||
i2v_stability=True,
|
||||
denoising_strength=1.0,
|
||||
seed=None,
|
||||
rand_device=None,
|
||||
@@ -105,10 +156,17 @@ class HunyuanVideoPipeline(BasePipeline):
|
||||
):
|
||||
# Tiler parameters
|
||||
tiler_kwargs = {"tile_size": tile_size, "tile_stride": tile_stride}
|
||||
|
||||
|
||||
# Scheduler
|
||||
self.scheduler.set_timesteps(num_inference_steps, denoising_strength)
|
||||
|
||||
# encoder input images
|
||||
if input_images is not None:
|
||||
self.load_models_to_device(['vae_encoder'])
|
||||
image_pixel_values, height, width = self.prepare_vae_images_inputs(input_images, i2v_resolution=i2v_resolution)
|
||||
with torch.autocast(device_type=self.device, dtype=torch.float16, enabled=True):
|
||||
image_latents = self.vae_encoder(image_pixel_values)
|
||||
|
||||
# Initialize noise
|
||||
rand_device = self.device if rand_device is None else rand_device
|
||||
noise = self.generate_noise((1, 16, (num_frames - 1) // 4 + 1, height//8, width//8), seed=seed, device=rand_device, dtype=self.torch_dtype).to(self.device)
|
||||
@@ -118,12 +176,18 @@ class HunyuanVideoPipeline(BasePipeline):
|
||||
input_video = torch.stack(input_video, dim=2)
|
||||
latents = self.encode_video(input_video, **tiler_kwargs).to(dtype=self.torch_dtype, device=self.device)
|
||||
latents = self.scheduler.add_noise(latents, noise, timestep=self.scheduler.timesteps[0])
|
||||
elif input_images is not None and i2v_stability:
|
||||
noise = self.generate_noise((1, 16, (num_frames - 1) // 4 + 1, height//8, width//8), seed=seed, device=rand_device, dtype=image_latents.dtype).to(self.device)
|
||||
t = torch.tensor([0.999]).to(device=self.device)
|
||||
latents = noise * t + image_latents.repeat(1, 1, (num_frames - 1) // 4 + 1, 1, 1) * (1 - t)
|
||||
latents = latents.to(dtype=image_latents.dtype)
|
||||
else:
|
||||
latents = noise
|
||||
|
||||
|
||||
# Encode prompts
|
||||
self.load_models_to_device(["text_encoder_1"] if self.vram_management else ["text_encoder_1", "text_encoder_2"])
|
||||
prompt_emb_posi = self.encode_prompt(prompt, positive=True)
|
||||
# current mllm does not support vram_management
|
||||
self.load_models_to_device(["text_encoder_1"] if self.vram_management and input_images is None else ["text_encoder_1", "text_encoder_2"])
|
||||
prompt_emb_posi = self.encode_prompt(prompt, positive=True, input_images=input_images)
|
||||
if cfg_scale != 1.0:
|
||||
prompt_emb_nega = self.encode_prompt(negative_prompt, positive=False)
|
||||
|
||||
@@ -139,11 +203,16 @@ class HunyuanVideoPipeline(BasePipeline):
|
||||
timestep = timestep.unsqueeze(0).to(self.device)
|
||||
print(f"Step {progress_id + 1} / {len(self.scheduler.timesteps)}")
|
||||
|
||||
forward_func = lets_dance_hunyuan_video
|
||||
if input_images is not None:
|
||||
latents = torch.concat([image_latents, latents[:, :, 1:, :, :]], dim=2)
|
||||
forward_func = lets_dance_hunyuan_video_i2v
|
||||
|
||||
# Inference
|
||||
with torch.autocast(device_type=self.device, dtype=self.torch_dtype):
|
||||
noise_pred_posi = lets_dance_hunyuan_video(self.dit, latents, timestep, **prompt_emb_posi, **extra_input, **tea_cache_kwargs)
|
||||
noise_pred_posi = forward_func(self.dit, latents, timestep, **prompt_emb_posi, **extra_input, **tea_cache_kwargs)
|
||||
if cfg_scale != 1.0:
|
||||
noise_pred_nega = lets_dance_hunyuan_video(self.dit, latents, timestep, **prompt_emb_nega, **extra_input)
|
||||
noise_pred_nega = forward_func(self.dit, latents, timestep, **prompt_emb_nega, **extra_input)
|
||||
noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
|
||||
else:
|
||||
noise_pred = noise_pred_posi
|
||||
@@ -163,7 +232,11 @@ class HunyuanVideoPipeline(BasePipeline):
|
||||
self.load_models_to_device([] if self.vram_management else ["dit"])
|
||||
|
||||
# Scheduler
|
||||
latents = self.scheduler.step(noise_pred, self.scheduler.timesteps[progress_id], latents)
|
||||
if input_images is not None:
|
||||
latents = self.scheduler.step(noise_pred[:, :, 1:, :, :], self.scheduler.timesteps[progress_id], latents[:, :, 1:, :, :])
|
||||
latents = torch.concat([image_latents, latents], dim=2)
|
||||
else:
|
||||
latents = self.scheduler.step(noise_pred, self.scheduler.timesteps[progress_id], latents)
|
||||
|
||||
# Decode
|
||||
self.load_models_to_device(['vae_decoder'])
|
||||
@@ -194,7 +267,7 @@ class TeaCache:
|
||||
if self.step == 0 or self.step == self.num_inference_steps - 1:
|
||||
should_calc = True
|
||||
self.accumulated_rel_l1_distance = 0
|
||||
else:
|
||||
else:
|
||||
coefficients = [7.33226126e+02, -4.01131952e+02, 6.75869174e+01, -3.14987800e+00, 9.61237896e-02]
|
||||
rescale_func = np.poly1d(coefficients)
|
||||
self.accumulated_rel_l1_distance += rescale_func(((modulated_inp-self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean()).cpu().item())
|
||||
@@ -203,14 +276,14 @@ class TeaCache:
|
||||
else:
|
||||
should_calc = True
|
||||
self.accumulated_rel_l1_distance = 0
|
||||
self.previous_modulated_input = modulated_inp
|
||||
self.previous_modulated_input = modulated_inp
|
||||
self.step += 1
|
||||
if self.step == self.num_inference_steps:
|
||||
self.step = 0
|
||||
if should_calc:
|
||||
self.previous_hidden_states = img.clone()
|
||||
return not should_calc
|
||||
|
||||
|
||||
def store(self, hidden_states):
|
||||
self.previous_residual = hidden_states - self.previous_hidden_states
|
||||
self.previous_hidden_states = None
|
||||
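TeaCache decides per denoising step whether the DiT forward can be skipped: it feeds the relative L1 change of the modulated input through a fitted polynomial (`np.poly1d(coefficients)`), accumulates the result, and only recomputes when the accumulated value crosses a threshold; otherwise it re-applies the residual stored from the last full forward. A stripped-down sketch of that control flow (the threshold and coefficients here are placeholders, not the tuned values above):

```python
import numpy as np

class TinyTeaCache:
    def __init__(self, num_steps, threshold, coefficients):
        self.num_steps, self.threshold = num_steps, threshold
        self.rescale = np.poly1d(coefficients)
        self.step, self.accumulated = 0, 0.0
        self.prev_input = None

    def should_skip(self, modulated_input):
        # modulated_input: a scalar stand-in for the tensor statistic used above.
        if self.step == 0 or self.step == self.num_steps - 1 or self.prev_input is None:
            skip = False                                   # always compute on the first and last step
            self.accumulated = 0.0
        else:
            rel_l1 = abs(modulated_input - self.prev_input) / (abs(self.prev_input) + 1e-8)
            self.accumulated += float(self.rescale(rel_l1))
            skip = self.accumulated < self.threshold       # small accumulated change => reuse cached residual
            if not skip:
                self.accumulated = 0.0
        self.prev_input = modulated_input
        self.step = (self.step + 1) % self.num_steps
        return skip
```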
@@ -250,13 +323,70 @@ def lets_dance_hunyuan_video(
|
||||
print("TeaCache skip forward.")
|
||||
img = tea_cache.update(img)
|
||||
else:
|
||||
split_token = int(text_mask.sum(dim=1))
|
||||
txt_len = int(txt.shape[1])
|
||||
for block in tqdm(dit.double_blocks, desc="Double stream blocks"):
|
||||
img, txt = block(img, txt, vec, (freqs_cos, freqs_sin))
|
||||
|
||||
img, txt = block(img, txt, vec, (freqs_cos, freqs_sin), split_token=split_token)
|
||||
|
||||
x = torch.concat([img, txt], dim=1)
|
||||
for block in tqdm(dit.single_blocks, desc="Single stream blocks"):
|
||||
x = block(x, vec, (freqs_cos, freqs_sin))
|
||||
img = x[:, :-256]
|
||||
x = block(x, vec, (freqs_cos, freqs_sin), txt_len=txt_len, split_token=split_token)
|
||||
img = x[:, :-txt_len]
|
||||
|
||||
if tea_cache is not None:
|
||||
tea_cache.store(img)
|
||||
img = dit.final_layer(img, vec)
|
||||
img = dit.unpatchify(img, T=T//1, H=H//2, W=W//2)
|
||||
return img
|
||||
|
||||
|
||||
def lets_dance_hunyuan_video_i2v(
|
||||
dit: HunyuanVideoDiT,
|
||||
x: torch.Tensor,
|
||||
t: torch.Tensor,
|
||||
prompt_emb: torch.Tensor = None,
|
||||
text_mask: torch.Tensor = None,
|
||||
pooled_prompt_emb: torch.Tensor = None,
|
||||
freqs_cos: torch.Tensor = None,
|
||||
freqs_sin: torch.Tensor = None,
|
||||
guidance: torch.Tensor = None,
|
||||
tea_cache: TeaCache = None,
|
||||
**kwargs
|
||||
):
|
||||
B, C, T, H, W = x.shape
|
||||
# Uncomment below to keep same as official implementation
|
||||
# guidance = guidance.to(dtype=torch.float32).to(torch.bfloat16)
|
||||
vec = dit.time_in(t, dtype=torch.bfloat16)
|
||||
vec_2 = dit.vector_in(pooled_prompt_emb)
|
||||
vec = vec + vec_2
|
||||
vec = vec + dit.guidance_in(guidance * 1000., dtype=torch.bfloat16)
|
||||
|
||||
token_replace_vec = dit.time_in(torch.zeros_like(t), dtype=torch.bfloat16)
|
||||
tr_token = (H // 2) * (W // 2)
|
||||
token_replace_vec = token_replace_vec + vec_2
|
||||
|
||||
img = dit.img_in(x)
|
||||
txt = dit.txt_in(prompt_emb, t, text_mask)
|
||||
|
||||
# TeaCache
|
||||
if tea_cache is not None:
|
||||
tea_cache_update = tea_cache.check(dit, img, vec)
|
||||
else:
|
||||
tea_cache_update = False
|
||||
|
||||
if tea_cache_update:
|
||||
print("TeaCache skip forward.")
|
||||
img = tea_cache.update(img)
|
||||
else:
|
||||
split_token = int(text_mask.sum(dim=1))
|
||||
txt_len = int(txt.shape[1])
|
||||
for block in tqdm(dit.double_blocks, desc="Double stream blocks"):
|
||||
img, txt = block(img, txt, vec, (freqs_cos, freqs_sin), token_replace_vec, tr_token, split_token)
|
||||
|
||||
x = torch.concat([img, txt], dim=1)
|
||||
for block in tqdm(dit.single_blocks, desc="Single stream blocks"):
|
||||
x = block(x, vec, (freqs_cos, freqs_sin), txt_len, token_replace_vec, tr_token, split_token)
|
||||
img = x[:, :-txt_len]
|
||||
|
||||
if tea_cache is not None:
|
||||
tea_cache.store(img)
|
||||
|
||||
@@ -11,11 +11,14 @@ from einops import rearrange
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
from tqdm import tqdm
|
||||
from typing import Optional
|
||||
|
||||
from ..vram_management import enable_vram_management, AutoWrappedModule, AutoWrappedLinear
|
||||
from ..models.wan_video_text_encoder import T5RelativeEmbedding, T5LayerNorm
|
||||
from ..models.wan_video_dit import WanLayerNorm, WanRMSNorm
|
||||
from ..models.wan_video_dit import RMSNorm, sinusoidal_embedding_1d
|
||||
from ..models.wan_video_vae import RMS_norm, CausalConv3d, Upsample
|
||||
from ..models.wan_video_controlnet import WanControlNetModel
|
||||
from ..models.wan_video_motion_controller import WanMotionControllerModel
|
||||
|
||||
|
||||
|
||||
@@ -29,7 +32,9 @@ class WanVideoPipeline(BasePipeline):
|
||||
self.image_encoder: WanImageEncoder = None
|
||||
self.dit: WanModel = None
|
||||
self.vae: WanVideoVAE = None
|
||||
self.model_names = ['text_encoder', 'dit', 'vae']
|
||||
self.controlnet: WanControlNetModel = None
|
||||
self.motion_controller: WanMotionControllerModel = None
|
||||
self.model_names = ['text_encoder', 'dit', 'vae', 'image_encoder', 'controlnet', 'motion_controller']
|
||||
self.height_division_factor = 16
|
||||
self.width_division_factor = 16
|
||||
|
||||
@@ -60,8 +65,7 @@ class WanVideoPipeline(BasePipeline):
|
||||
torch.nn.Linear: AutoWrappedLinear,
|
||||
torch.nn.Conv3d: AutoWrappedModule,
|
||||
torch.nn.LayerNorm: AutoWrappedModule,
|
||||
WanLayerNorm: AutoWrappedModule,
|
||||
WanRMSNorm: AutoWrappedModule,
|
||||
RMSNorm: AutoWrappedModule,
|
||||
},
|
||||
module_config = dict(
|
||||
offload_dtype=dtype,
|
||||
@@ -116,7 +120,7 @@ class WanVideoPipeline(BasePipeline):
|
||||
offload_device="cpu",
|
||||
onload_dtype=dtype,
|
||||
onload_device="cpu",
|
||||
computation_dtype=self.torch_dtype,
|
||||
computation_dtype=dtype,
|
||||
computation_device=self.device,
|
||||
),
|
||||
)
|
||||
@@ -153,17 +157,21 @@ class WanVideoPipeline(BasePipeline):
|
||||
|
||||
|
||||
def encode_image(self, image, num_frames, height, width):
|
||||
with torch.amp.autocast(dtype=torch.bfloat16, device_type=torch.device(self.device).type):
|
||||
image = self.preprocess_image(image.resize((width, height))).to(self.device)
|
||||
clip_context = self.image_encoder.encode_image([image])
|
||||
msk = torch.ones(1, num_frames, height//8, width//8, device=self.device)
|
||||
msk[:, 1:] = 0
|
||||
msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1)
|
||||
msk = msk.view(1, msk.shape[1] // 4, 4, height//8, width//8)
|
||||
msk = msk.transpose(1, 2)[0]
|
||||
y = self.vae.encode([torch.concat([image.transpose(0, 1), torch.zeros(3, num_frames-1, height, width).to(image.device)], dim=1)], device=self.device)[0]
|
||||
y = torch.concat([msk, y])
|
||||
return {"clip_fea": clip_context, "y": [y]}
|
||||
image = self.preprocess_image(image.resize((width, height))).to(self.device)
|
||||
clip_context = self.image_encoder.encode_image([image])
|
||||
msk = torch.ones(1, num_frames, height//8, width//8, device=self.device)
|
||||
msk[:, 1:] = 0
|
||||
msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1)
|
||||
msk = msk.view(1, msk.shape[1] // 4, 4, height//8, width//8)
|
||||
msk = msk.transpose(1, 2)[0]
|
||||
|
||||
vae_input = torch.concat([image.transpose(0, 1), torch.zeros(3, num_frames-1, height, width).to(image.device)], dim=1)
|
||||
y = self.vae.encode([vae_input.to(dtype=self.torch_dtype, device=self.device)], device=self.device)[0]
|
||||
y = torch.concat([msk, y])
|
||||
y = y.unsqueeze(0)
|
||||
clip_context = clip_context.to(dtype=self.torch_dtype, device=self.device)
|
||||
y = y.to(dtype=self.torch_dtype, device=self.device)
|
||||
return {"clip_feature": clip_context, "y": y}
|
||||
|
||||
|
||||
def tensor2video(self, frames):
|
||||
@@ -174,19 +182,27 @@ class WanVideoPipeline(BasePipeline):
|
||||
|
||||
|
||||
def prepare_extra_input(self, latents=None):
|
||||
return {"seq_len": latents.shape[2] * latents.shape[3] * latents.shape[4] // 4}
|
||||
return {}
|
||||
|
||||
|
||||
def encode_video(self, input_video, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)):
|
||||
with torch.amp.autocast(dtype=torch.bfloat16, device_type=torch.device(self.device).type):
|
||||
latents = self.vae.encode(input_video, device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
|
||||
latents = self.vae.encode(input_video, device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
|
||||
return latents
|
||||
|
||||
|
||||
def decode_video(self, latents, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)):
|
||||
with torch.amp.autocast(dtype=torch.bfloat16, device_type=torch.device(self.device).type):
|
||||
frames = self.vae.decode(latents, device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
|
||||
frames = self.vae.decode(latents, device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
|
||||
return frames
|
||||
|
||||
|
||||
def prepare_controlnet(self, controlnet_frames, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)):
|
||||
controlnet_conditioning = self.encode_video(controlnet_frames, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=self.torch_dtype, device=self.device)
|
||||
return {"controlnet_conditioning": controlnet_conditioning}
|
||||
|
||||
|
||||
def prepare_motion_bucket_id(self, motion_bucket_id):
|
||||
motion_bucket_id = torch.Tensor((motion_bucket_id,)).to(dtype=self.torch_dtype, device=self.device)
|
||||
return {"motion_bucket_id": motion_bucket_id}
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
@@ -205,9 +221,13 @@ class WanVideoPipeline(BasePipeline):
|
||||
cfg_scale=5.0,
|
||||
num_inference_steps=50,
|
||||
sigma_shift=5.0,
|
||||
motion_bucket_id=None,
|
||||
tiled=True,
|
||||
tile_size=(30, 52),
|
||||
tile_stride=(15, 26),
|
||||
tea_cache_l1_thresh=None,
|
||||
tea_cache_model_id="",
|
||||
controlnet_frames=None,
|
||||
progress_bar_cmd=tqdm,
|
||||
progress_bar_st=None,
|
||||
):
|
||||
@@ -221,15 +241,16 @@ class WanVideoPipeline(BasePipeline):
|
||||
tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride}
|
||||
|
||||
# Scheduler
|
||||
self.scheduler.set_timesteps(num_inference_steps, denoising_strength, shift=sigma_shift)
|
||||
self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength, shift=sigma_shift)
|
||||
|
||||
# Initialize noise
|
||||
noise = self.generate_noise((1, 16, (num_frames - 1) // 4 + 1, height//8, width//8), seed=seed, device=rand_device, dtype=torch.float32).to(self.device)
|
||||
noise = self.generate_noise((1, 16, (num_frames - 1) // 4 + 1, height//8, width//8), seed=seed, device=rand_device, dtype=torch.float32)
|
||||
noise = noise.to(dtype=self.torch_dtype, device=self.device)
|
||||
if input_video is not None:
|
||||
self.load_models_to_device(['vae'])
|
||||
input_video = self.preprocess_images(input_video)
|
||||
input_video = torch.stack(input_video, dim=2)
|
||||
latents = self.encode_video(input_video, **tiler_kwargs).to(dtype=noise.dtype, device=noise.device)
|
||||
input_video = torch.stack(input_video, dim=2).to(dtype=self.torch_dtype, device=self.device)
|
||||
latents = self.encode_video(input_video, **tiler_kwargs).to(dtype=self.torch_dtype, device=self.device)
|
||||
latents = self.scheduler.add_noise(latents, noise, timestep=self.scheduler.timesteps[0])
|
||||
else:
|
||||
latents = noise
|
||||
@@ -247,25 +268,53 @@ class WanVideoPipeline(BasePipeline):
|
||||
else:
|
||||
image_emb = {}
|
||||
|
||||
# ControlNet
|
||||
if self.controlnet is not None and controlnet_frames is not None:
|
||||
self.load_models_to_device(['vae', 'controlnet'])
|
||||
controlnet_frames = self.preprocess_images(controlnet_frames)
|
||||
controlnet_frames = torch.stack(controlnet_frames, dim=2).to(dtype=self.torch_dtype, device=self.device)
|
||||
controlnet_kwargs = self.prepare_controlnet(controlnet_frames)
|
||||
else:
|
||||
controlnet_kwargs = {}
|
||||
|
||||
# Motion Controller
|
||||
if self.motion_controller is not None and motion_bucket_id is not None:
|
||||
motion_kwargs = self.prepare_motion_bucket_id(motion_bucket_id)
|
||||
else:
|
||||
motion_kwargs = {}
|
||||
|
||||
# Extra input
|
||||
extra_input = self.prepare_extra_input(latents)
|
||||
|
||||
# TeaCache
|
||||
tea_cache_posi = {"tea_cache": TeaCache(num_inference_steps, rel_l1_thresh=tea_cache_l1_thresh, model_id=tea_cache_model_id) if tea_cache_l1_thresh is not None else None}
|
||||
tea_cache_nega = {"tea_cache": TeaCache(num_inference_steps, rel_l1_thresh=tea_cache_l1_thresh, model_id=tea_cache_model_id) if tea_cache_l1_thresh is not None else None}
|
||||
|
||||
# Denoise
|
||||
self.load_models_to_device(["dit"])
|
||||
with torch.amp.autocast(dtype=torch.bfloat16, device_type=torch.device(self.device).type):
|
||||
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
|
||||
timestep = timestep.unsqueeze(0).to(dtype=torch.float32, device=self.device)
|
||||
self.load_models_to_device(["dit", "controlnet", "motion_controller"])
|
||||
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
|
||||
timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device)
|
||||
|
||||
# Inference
|
||||
noise_pred_posi = self.dit(latents, timestep=timestep, **prompt_emb_posi, **image_emb, **extra_input)
|
||||
if cfg_scale != 1.0:
|
||||
noise_pred_nega = self.dit(latents, timestep=timestep, **prompt_emb_nega, **image_emb, **extra_input)
|
||||
noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
|
||||
else:
|
||||
noise_pred = noise_pred_posi
|
||||
# Inference
|
||||
noise_pred_posi = model_fn_wan_video(
|
||||
self.dit, controlnet=self.controlnet, motion_controller=self.motion_controller,
|
||||
x=latents, timestep=timestep,
|
||||
**prompt_emb_posi, **image_emb, **extra_input,
|
||||
**tea_cache_posi, **controlnet_kwargs, **motion_kwargs,
|
||||
)
|
||||
if cfg_scale != 1.0:
|
||||
noise_pred_nega = model_fn_wan_video(
|
||||
self.dit, controlnet=self.controlnet, motion_controller=self.motion_controller,
|
||||
x=latents, timestep=timestep,
|
||||
**prompt_emb_nega, **image_emb, **extra_input,
|
||||
**tea_cache_nega, **controlnet_kwargs, **motion_kwargs,
|
||||
)
|
||||
noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
|
||||
else:
|
||||
noise_pred = noise_pred_posi
|
||||
|
||||
# Scheduler
|
||||
latents = self.scheduler.step(noise_pred, self.scheduler.timesteps[progress_id], latents)
|
||||
# Scheduler
|
||||
latents = self.scheduler.step(noise_pred, self.scheduler.timesteps[progress_id], latents)
|
||||
|
||||
# Decode
|
||||
self.load_models_to_device(['vae'])
|
||||
@@ -274,3 +323,145 @@ class WanVideoPipeline(BasePipeline):
|
||||
frames = self.tensor2video(frames[0])
|
||||
|
||||
return frames
|
||||
|
||||
|
||||
|
||||
class TeaCache:
|
||||
def __init__(self, num_inference_steps, rel_l1_thresh, model_id):
|
||||
self.num_inference_steps = num_inference_steps
|
||||
self.step = 0
|
||||
self.accumulated_rel_l1_distance = 0
|
||||
self.previous_modulated_input = None
|
||||
self.rel_l1_thresh = rel_l1_thresh
|
||||
self.previous_residual = None
|
||||
self.previous_hidden_states = None
|
||||
|
||||
self.coefficients_dict = {
|
||||
"Wan2.1-T2V-1.3B": [-5.21862437e+04, 9.23041404e+03, -5.28275948e+02, 1.36987616e+01, -4.99875664e-02],
|
||||
"Wan2.1-T2V-14B": [-3.03318725e+05, 4.90537029e+04, -2.65530556e+03, 5.87365115e+01, -3.15583525e-01],
|
||||
"Wan2.1-I2V-14B-480P": [2.57151496e+05, -3.54229917e+04, 1.40286849e+03, -1.35890334e+01, 1.32517977e-01],
|
||||
"Wan2.1-I2V-14B-720P": [ 8.10705460e+03, 2.13393892e+03, -3.72934672e+02, 1.66203073e+01, -4.17769401e-02],
|
||||
}
|
||||
if model_id not in self.coefficients_dict:
|
||||
supported_model_ids = ", ".join([i for i in self.coefficients_dict])
|
||||
raise ValueError(f"{model_id} is not a supported TeaCache model id. Please choose a valid model id in ({supported_model_ids}).")
|
||||
self.coefficients = self.coefficients_dict[model_id]
|
||||
|
||||
def check(self, dit: WanModel, x, t_mod):
|
||||
modulated_inp = t_mod.clone()
|
||||
if self.step == 0 or self.step == self.num_inference_steps - 1:
|
||||
should_calc = True
|
||||
self.accumulated_rel_l1_distance = 0
|
||||
else:
|
||||
coefficients = self.coefficients
|
||||
rescale_func = np.poly1d(coefficients)
|
||||
self.accumulated_rel_l1_distance += rescale_func(((modulated_inp-self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean()).cpu().item())
|
||||
if self.accumulated_rel_l1_distance < self.rel_l1_thresh:
|
||||
should_calc = False
|
||||
else:
|
||||
should_calc = True
|
||||
self.accumulated_rel_l1_distance = 0
|
||||
self.previous_modulated_input = modulated_inp
|
||||
self.step += 1
|
||||
if self.step == self.num_inference_steps:
|
||||
self.step = 0
|
||||
if should_calc:
|
||||
self.previous_hidden_states = x.clone()
|
||||
return not should_calc
|
||||
|
||||
def store(self, hidden_states):
|
||||
self.previous_residual = hidden_states - self.previous_hidden_states
|
||||
self.previous_hidden_states = None
|
||||
|
||||
def update(self, hidden_states):
|
||||
hidden_states = hidden_states + self.previous_residual
|
||||
return hidden_states
|
||||
|
||||
|
||||
|
||||
def model_fn_wan_video(
|
||||
dit: WanModel,
|
||||
controlnet: WanControlNetModel = None,
|
||||
motion_controller: WanMotionControllerModel = None,
|
||||
x: torch.Tensor = None,
|
||||
timestep: torch.Tensor = None,
|
||||
context: torch.Tensor = None,
|
||||
clip_feature: Optional[torch.Tensor] = None,
|
||||
y: Optional[torch.Tensor] = None,
|
||||
tea_cache: TeaCache = None,
|
||||
controlnet_conditioning: Optional[torch.Tensor] = None,
|
||||
motion_bucket_id: Optional[torch.Tensor] = None,
|
||||
use_gradient_checkpointing: bool = False,
|
||||
use_gradient_checkpointing_offload: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
# ControlNet
|
||||
if controlnet is not None and controlnet_conditioning is not None:
|
||||
controlnet_res_stack = controlnet(
|
||||
x, timestep=timestep, context=context, clip_feature=clip_feature, y=y,
|
||||
controlnet_conditioning=controlnet_conditioning,
|
||||
use_gradient_checkpointing=use_gradient_checkpointing,
|
||||
use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
|
||||
)
|
||||
else:
|
||||
controlnet_res_stack = None
|
||||
|
||||
t = dit.time_embedding(sinusoidal_embedding_1d(dit.freq_dim, timestep))
|
||||
t_mod = dit.time_projection(t).unflatten(1, (6, dit.dim))
|
||||
if motion_bucket_id is not None and motion_controller is not None:
|
||||
t_mod = t_mod + motion_controller(motion_bucket_id).unflatten(1, (6, dit.dim))
|
||||
context = dit.text_embedding(context)
|
||||
|
||||
if dit.has_image_input:
|
||||
x = torch.cat([x, y], dim=1) # (b, c_x + c_y, f, h, w)
|
||||
clip_embdding = dit.img_emb(clip_feature)
|
||||
context = torch.cat([clip_embdding, context], dim=1)
|
||||
|
||||
x, (f, h, w) = dit.patchify(x)
|
||||
|
||||
freqs = torch.cat([
|
||||
dit.freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
|
||||
dit.freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
|
||||
dit.freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
|
||||
], dim=-1).reshape(f * h * w, 1, -1).to(x.device)
|
||||
|
||||
# TeaCache
|
||||
if tea_cache is not None:
|
||||
tea_cache_update = tea_cache.check(dit, x, t_mod)
|
||||
else:
|
||||
tea_cache_update = False
|
||||
|
||||
def create_custom_forward(module):
|
||||
def custom_forward(*inputs):
|
||||
return module(*inputs)
|
||||
return custom_forward
|
||||
|
||||
if tea_cache_update:
|
||||
x = tea_cache.update(x)
|
||||
else:
|
||||
# blocks
|
||||
for block_id, block in enumerate(dit.blocks):
|
||||
if dit.training and use_gradient_checkpointing:
|
||||
if use_gradient_checkpointing_offload:
|
||||
with torch.autograd.graph.save_on_cpu():
|
||||
x = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(block),
|
||||
x, context, t_mod, freqs,
|
||||
use_reentrant=False,
|
||||
)
|
||||
else:
|
||||
x = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(block),
|
||||
x, context, t_mod, freqs,
|
||||
use_reentrant=False,
|
||||
)
|
||||
else:
|
||||
x = block(x, context, t_mod, freqs)
|
||||
if controlnet_res_stack is not None:
|
||||
x = x + controlnet_res_stack[block_id]
|
||||
if tea_cache is not None:
|
||||
tea_cache.store(x)
|
||||
|
||||
x = dit.head(x, t)
|
||||
x = dit.unpatchify(x, (f, h, w))
|
||||
return x
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
from .base_prompter import BasePrompter
|
||||
from ..models.sd3_text_encoder import SD3TextEncoder1
|
||||
from ..models.hunyuan_video_text_encoder import HunyuanVideoLLMEncoder
|
||||
from transformers import CLIPTokenizer, LlamaTokenizerFast
|
||||
from ..models.hunyuan_video_text_encoder import HunyuanVideoLLMEncoder, HunyuanVideoMLLMEncoder
|
||||
from transformers import CLIPTokenizer, LlamaTokenizerFast, CLIPImageProcessor
|
||||
import os, torch
|
||||
from typing import Union
|
||||
|
||||
PROMPT_TEMPLATE_ENCODE = (
|
||||
"<|start_header_id|>system<|end_header_id|>\n\nDescribe the image by detailing the color, shape, size, texture, "
|
||||
@@ -18,6 +19,24 @@ PROMPT_TEMPLATE_ENCODE_VIDEO = (
|
||||
"5. camera angles, movements, and transitions used in the video:<|eot_id|>"
|
||||
"<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>")
|
||||
|
||||
PROMPT_TEMPLATE_ENCODE_I2V = (
|
||||
"<|start_header_id|>system<|end_header_id|>\n\n<image>\nDescribe the image by detailing the color, shape, size, texture, "
|
||||
"quantity, text, spatial relationships of the objects and background:<|eot_id|>"
|
||||
"<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>"
|
||||
"<|start_header_id|>assistant<|end_header_id|>\n\n"
|
||||
)
|
||||
|
||||
PROMPT_TEMPLATE_ENCODE_VIDEO_I2V = (
|
||||
"<|start_header_id|>system<|end_header_id|>\n\n<image>\nDescribe the video by detailing the following aspects according to the reference image: "
|
||||
"1. The main content and theme of the video."
|
||||
"2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects."
|
||||
"3. Actions, events, behaviors temporal relationships, physical movement changes of the objects."
|
||||
"4. background environment, light, style and atmosphere."
|
||||
"5. camera angles, movements, and transitions used in the video:<|eot_id|>\n\n"
|
||||
"<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>"
|
||||
"<|start_header_id|>assistant<|end_header_id|>\n\n"
|
||||
)
|
||||
|
||||
PROMPT_TEMPLATE = {
|
||||
"dit-llm-encode": {
|
||||
"template": PROMPT_TEMPLATE_ENCODE,
|
||||
@@ -27,6 +46,22 @@ PROMPT_TEMPLATE = {
|
||||
"template": PROMPT_TEMPLATE_ENCODE_VIDEO,
|
||||
"crop_start": 95,
|
||||
},
|
||||
"dit-llm-encode-i2v": {
|
||||
"template": PROMPT_TEMPLATE_ENCODE_I2V,
|
||||
"crop_start": 36,
|
||||
"image_emb_start": 5,
|
||||
"image_emb_end": 581,
|
||||
"image_emb_len": 576,
|
||||
"double_return_token_id": 271
|
||||
},
|
||||
"dit-llm-encode-video-i2v": {
|
||||
"template": PROMPT_TEMPLATE_ENCODE_VIDEO_I2V,
|
||||
"crop_start": 103,
|
||||
"image_emb_start": 5,
|
||||
"image_emb_end": 581,
|
||||
"image_emb_len": 576,
|
||||
"double_return_token_id": 271
|
||||
},
|
||||
}
|
||||
|
||||
NEGATIVE_PROMPT = "Aerial view, aerial view, overexposed, low quality, deformation, a poor composition, bad hands, bad teeth, bad eyes, bad limbs, distortion"
|
||||
@@ -56,9 +91,20 @@ class HunyuanVideoPrompter(BasePrompter):
|
||||
self.prompt_template = PROMPT_TEMPLATE['dit-llm-encode']
|
||||
self.prompt_template_video = PROMPT_TEMPLATE['dit-llm-encode-video']
|
||||
|
||||
def fetch_models(self, text_encoder_1: SD3TextEncoder1 = None, text_encoder_2: HunyuanVideoLLMEncoder = None):
|
||||
def fetch_models(self,
|
||||
text_encoder_1: SD3TextEncoder1 = None,
|
||||
text_encoder_2: Union[HunyuanVideoLLMEncoder, HunyuanVideoMLLMEncoder] = None):
|
||||
self.text_encoder_1 = text_encoder_1
|
||||
self.text_encoder_2 = text_encoder_2
|
||||
if isinstance(text_encoder_2, HunyuanVideoMLLMEncoder):
|
||||
# processor
|
||||
# TODO: may need to replace processor with local implementation
|
||||
base_path = os.path.dirname(os.path.dirname(__file__))
|
||||
tokenizer_2_path = os.path.join(base_path, "tokenizer_configs/hunyuan_video/tokenizer_2")
|
||||
self.processor = CLIPImageProcessor.from_pretrained(tokenizer_2_path)
|
||||
# template
|
||||
self.prompt_template = PROMPT_TEMPLATE['dit-llm-encode-i2v']
|
||||
self.prompt_template_video = PROMPT_TEMPLATE['dit-llm-encode-video-i2v']
|
||||
|
||||
def apply_text_to_template(self, text, template):
|
||||
assert isinstance(template, str)
|
||||
@@ -107,8 +153,89 @@ class HunyuanVideoPrompter(BasePrompter):
|
||||
|
||||
return last_hidden_state, attention_mask
|
||||
|
||||
def encode_prompt_using_mllm(self,
|
||||
prompt,
|
||||
images,
|
||||
max_length,
|
||||
device,
|
||||
crop_start,
|
||||
hidden_state_skip_layer=2,
|
||||
use_attention_mask=True,
|
||||
image_embed_interleave=4):
|
||||
image_outputs = self.processor(images, return_tensors="pt")["pixel_values"].to(device)
|
||||
max_length += crop_start
|
||||
inputs = self.tokenizer_2(prompt,
|
||||
return_tensors="pt",
|
||||
padding="max_length",
|
||||
max_length=max_length,
|
||||
truncation=True)
|
||||
input_ids = inputs.input_ids.to(device)
|
||||
attention_mask = inputs.attention_mask.to(device)
|
||||
last_hidden_state = self.text_encoder_2(input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
hidden_state_skip_layer=hidden_state_skip_layer,
|
||||
pixel_values=image_outputs)
|
||||
|
||||
text_crop_start = (crop_start - 1 + self.prompt_template_video.get("image_emb_len", 576))
|
||||
image_crop_start = self.prompt_template_video.get("image_emb_start", 5)
|
||||
image_crop_end = self.prompt_template_video.get("image_emb_end", 581)
|
||||
batch_indices, last_double_return_token_indices = torch.where(
|
||||
input_ids == self.prompt_template_video.get("double_return_token_id", 271))
|
||||
if last_double_return_token_indices.shape[0] == 3:
|
||||
# in case the prompt is too long
|
||||
last_double_return_token_indices = torch.cat((
|
||||
last_double_return_token_indices,
|
||||
torch.tensor([input_ids.shape[-1]]),
|
||||
))
|
||||
batch_indices = torch.cat((batch_indices, torch.tensor([0])))
|
||||
last_double_return_token_indices = (last_double_return_token_indices.reshape(input_ids.shape[0], -1)[:, -1])
|
||||
batch_indices = batch_indices.reshape(input_ids.shape[0], -1)[:, -1]
|
||||
assistant_crop_start = (last_double_return_token_indices - 1 + self.prompt_template_video.get("image_emb_len", 576) - 4)
|
||||
assistant_crop_end = (last_double_return_token_indices - 1 + self.prompt_template_video.get("image_emb_len", 576))
|
||||
attention_mask_assistant_crop_start = (last_double_return_token_indices - 4)
|
||||
attention_mask_assistant_crop_end = last_double_return_token_indices
|
||||
text_last_hidden_state = []
|
||||
text_attention_mask = []
|
||||
image_last_hidden_state = []
|
||||
image_attention_mask = []
|
||||
for i in range(input_ids.shape[0]):
|
||||
text_last_hidden_state.append(
|
||||
torch.cat([
|
||||
last_hidden_state[i, text_crop_start:assistant_crop_start[i].item()],
|
||||
last_hidden_state[i, assistant_crop_end[i].item():],
|
||||
]))
|
||||
text_attention_mask.append(
|
||||
torch.cat([
|
||||
attention_mask[
|
||||
i,
|
||||
crop_start:attention_mask_assistant_crop_start[i].item(),
|
||||
],
|
||||
attention_mask[i, attention_mask_assistant_crop_end[i].item():],
|
||||
]) if use_attention_mask else None)
|
||||
image_last_hidden_state.append(last_hidden_state[i, image_crop_start:image_crop_end])
|
||||
image_attention_mask.append(
|
||||
torch.ones(image_last_hidden_state[-1].shape[0]).to(last_hidden_state.device).
|
||||
to(attention_mask.dtype) if use_attention_mask else None)
|
||||
|
||||
text_last_hidden_state = torch.stack(text_last_hidden_state)
|
||||
text_attention_mask = torch.stack(text_attention_mask)
|
||||
image_last_hidden_state = torch.stack(image_last_hidden_state)
|
||||
image_attention_mask = torch.stack(image_attention_mask)
|
||||
|
||||
image_last_hidden_state = image_last_hidden_state[:, ::image_embed_interleave, :]
|
||||
image_attention_mask = image_attention_mask[:, ::image_embed_interleave]
|
||||
|
||||
assert (text_last_hidden_state.shape[0] == text_attention_mask.shape[0] and
|
||||
image_last_hidden_state.shape[0] == image_attention_mask.shape[0])
|
||||
|
||||
last_hidden_state = torch.cat([image_last_hidden_state, text_last_hidden_state], dim=1)
|
||||
attention_mask = torch.cat([image_attention_mask, text_attention_mask], dim=1)
|
||||
|
||||
return last_hidden_state, attention_mask
|
||||
|
||||
def encode_prompt(self,
|
||||
prompt,
|
||||
images=None,
|
||||
positive=True,
|
||||
device="cuda",
|
||||
clip_sequence_length=77,
|
||||
@@ -116,7 +243,8 @@ class HunyuanVideoPrompter(BasePrompter):
|
||||
data_type='video',
|
||||
use_template=True,
|
||||
hidden_state_skip_layer=2,
|
||||
use_attention_mask=True):
|
||||
use_attention_mask=True,
|
||||
image_embed_interleave=4):
|
||||
|
||||
prompt = self.process_prompt(prompt, positive=positive)
|
||||
|
||||
@@ -136,8 +264,12 @@ class HunyuanVideoPrompter(BasePrompter):
|
||||
pooled_prompt_emb = self.encode_prompt_using_clip(prompt, clip_sequence_length, device)
|
||||
|
||||
# LLM
|
||||
prompt_emb, attention_mask = self.encode_prompt_using_llm(
|
||||
prompt_formated, llm_sequence_length, device, crop_start,
|
||||
hidden_state_skip_layer, use_attention_mask)
|
||||
if images is None:
|
||||
prompt_emb, attention_mask = self.encode_prompt_using_llm(prompt_formated, llm_sequence_length, device, crop_start,
|
||||
hidden_state_skip_layer, use_attention_mask)
|
||||
else:
|
||||
prompt_emb, attention_mask = self.encode_prompt_using_mllm(prompt_formated, images, llm_sequence_length, device,
|
||||
crop_start, hidden_state_skip_layer, use_attention_mask,
|
||||
image_embed_interleave)
|
||||
|
||||
return prompt_emb, pooled_prompt_emb, attention_mask
|
||||
|
||||
@@ -104,5 +104,6 @@ class WanPrompter(BasePrompter):
|
||||
mask = mask.to(device)
|
||||
seq_lens = mask.gt(0).sum(dim=1).long()
|
||||
prompt_emb = self.text_encoder(ids, mask)
|
||||
prompt_emb = [u[:v] for u, v in zip(prompt_emb, seq_lens)]
|
||||
for i, v in enumerate(seq_lens):
|
||||
prompt_emb[:, v:] = 0
|
||||
return prompt_emb
|
||||
|
||||
@@ -37,7 +37,7 @@ class FlowMatchScheduler():
|
||||
self.linear_timesteps_weights = bsmntw_weighing
|
||||
|
||||
|
||||
def step(self, model_output, timestep, sample, to_final=False):
|
||||
def step(self, model_output, timestep, sample, to_final=False, **kwargs):
|
||||
if isinstance(timestep, torch.Tensor):
|
||||
timestep = timestep.cpu()
|
||||
timestep_id = torch.argmin((self.timesteps - timestep).abs())
|
||||
|
||||
@@ -0,0 +1,45 @@
|
||||
{
|
||||
"_valid_processor_keys": [
|
||||
"images",
|
||||
"do_resize",
|
||||
"size",
|
||||
"resample",
|
||||
"do_center_crop",
|
||||
"crop_size",
|
||||
"do_rescale",
|
||||
"rescale_factor",
|
||||
"do_normalize",
|
||||
"image_mean",
|
||||
"image_std",
|
||||
"do_convert_rgb",
|
||||
"return_tensors",
|
||||
"data_format",
|
||||
"input_data_format"
|
||||
],
|
||||
"crop_size": {
|
||||
"height": 336,
|
||||
"width": 336
|
||||
},
|
||||
"do_center_crop": true,
|
||||
"do_convert_rgb": true,
|
||||
"do_normalize": true,
|
||||
"do_rescale": true,
|
||||
"do_resize": true,
|
||||
"image_mean": [
|
||||
0.48145466,
|
||||
0.4578275,
|
||||
0.40821073
|
||||
],
|
||||
"image_processor_type": "CLIPImageProcessor",
|
||||
"image_std": [
|
||||
0.26862954,
|
||||
0.26130258,
|
||||
0.27577711
|
||||
],
|
||||
"processor_class": "LlavaProcessor",
|
||||
"resample": 3,
|
||||
"rescale_factor": 0.00392156862745098,
|
||||
"size": {
|
||||
"shortest_edge": 336
|
||||
}
|
||||
}
|
||||
@@ -290,7 +290,7 @@ def launch_training_task(model, args):
|
||||
name="diffsynth_studio",
|
||||
config=swanlab_config,
|
||||
mode=args.swanlab_mode,
|
||||
logdir=args.output_path,
|
||||
logdir=os.path.join(args.output_path, "swanlog"),
|
||||
)
|
||||
logger = [swanlab_logger]
|
||||
else:
|
||||
|
||||
@@ -6,7 +6,7 @@ We propose EliGen, a novel approach that leverages fine-grained entity-level inf
|
||||
|
||||
* Paper: [EliGen: Entity-Level Controlled Image Generation with Regional Attention](https://arxiv.org/abs/2501.01097)
|
||||
* Github: [DiffSynth-Studio](https://github.com/modelscope/DiffSynth-Studio)
|
||||
* Model: [ModelScope](https://www.modelscope.cn/models/DiffSynth-Studio/Eligen)
|
||||
* Model: [ModelScope](https://www.modelscope.cn/models/DiffSynth-Studio/Eligen), [HuggingFace](https://huggingface.co/modelscope/EliGen)
|
||||
* Online Demo: [ModelScope EliGen Studio](https://www.modelscope.cn/studios/DiffSynth-Studio/EliGen)
|
||||
* Training Dataset: [EliGen Train Set](https://www.modelscope.cn/datasets/DiffSynth-Studio/EliGenTrainSet)
|
||||
|
||||
@@ -77,6 +77,11 @@ Demonstration of the styled entity control results with EliGen and IP-Adapter, s
|
||||
|-|-|-|-|
|
||||
|||||
|
||||
|
||||
We also provide a demo of styled entity control with EliGen combined with a style-specific LoRA; see [./styled_entity_control.py](./styled_entity_control.py) for details. Here is the visualization of EliGen with the [Lego DreamBooth LoRA](https://huggingface.co/merve/flux-lego-lora-dreambooth).
|
||||
|||||
|
||||
|-|-|-|-|
|
||||
|||||
|
||||
|
||||
### Entity Transfer
|
||||
Demonstration of the entity transfer results with EliGen and In-Context LoRA, see [./entity_transfer.py](./entity_transfer.py) for generation prompts.
|
||||
|
||||
|
||||
@@ -27,11 +27,20 @@ def example(pipe, seeds, example_id, global_prompt, entity_prompts):
|
||||
|
||||
# download and load model
|
||||
model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cuda", model_id_list=["FLUX.1-dev"])
|
||||
# set download_from_modelscope = False if you want to download model from huggingface
|
||||
download_from_modelscope = True
|
||||
if download_from_modelscope:
|
||||
model_id = "DiffSynth-Studio/Eligen"
|
||||
downloading_priority = ["ModelScope"]
|
||||
else:
|
||||
model_id = "modelscope/EliGen"
|
||||
downloading_priority = ["HuggingFace"]
|
||||
model_manager.load_lora(
|
||||
download_customized_models(
|
||||
model_id="DiffSynth-Studio/Eligen",
|
||||
model_id=model_id,
|
||||
origin_file_path="model_bf16.safetensors",
|
||||
local_dir="models/lora/entity_control"
|
||||
local_dir="models/lora/entity_control",
|
||||
downloading_priority=downloading_priority
|
||||
),
|
||||
lora_alpha=1
|
||||
)
|
||||
|
||||
examples/EntityControl/styled_entity_control.py (new file, 90 lines)
@@ -0,0 +1,90 @@
|
||||
from diffsynth import ModelManager, FluxImagePipeline, download_customized_models
|
||||
from modelscope import dataset_snapshot_download
|
||||
from examples.EntityControl.utils import visualize_masks
|
||||
from PIL import Image
|
||||
import torch
|
||||
|
||||
def example(pipe, seeds, example_id, global_prompt, entity_prompts):
|
||||
dataset_snapshot_download(dataset_id="DiffSynth-Studio/examples_in_diffsynth", local_dir="./", allow_file_pattern=f"data/examples/eligen/entity_control/example_{example_id}/*.png")
|
||||
masks = [Image.open(f"./data/examples/eligen/entity_control/example_{example_id}/{i}.png").convert('RGB') for i in range(len(entity_prompts))]
|
||||
negative_prompt = "worst quality, low quality, monochrome, zombie, interlocked fingers, Aissist, cleavage, nsfw,"
|
||||
for seed in seeds:
|
||||
# generate image
|
||||
image = pipe(
|
||||
prompt=global_prompt,
|
||||
cfg_scale=3.0,
|
||||
negative_prompt=negative_prompt,
|
||||
num_inference_steps=50,
|
||||
embedded_guidance=3.5,
|
||||
seed=seed,
|
||||
height=1024,
|
||||
width=1024,
|
||||
eligen_entity_prompts=entity_prompts,
|
||||
eligen_entity_masks=masks,
|
||||
)
|
||||
image.save(f"styled_eligen_example_{example_id}_{seed}.png")
|
||||
visualize_masks(image, masks, entity_prompts, f"styled_entity_control_example_{example_id}_mask_{seed}.png")
|
||||
|
||||
# download and load model
|
||||
model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cuda", model_id_list=["FLUX.1-dev"])
|
||||
model_manager.load_lora(
|
||||
download_customized_models(
|
||||
model_id="FluxLora/merve-flux-lego-lora-dreambooth",
|
||||
origin_file_path="pytorch_lora_weights.safetensors",
|
||||
local_dir="models/lora/merve-flux-lego-lora-dreambooth"
|
||||
),
|
||||
lora_alpha=1
|
||||
)
|
||||
model_manager.load_lora(
|
||||
download_customized_models(
|
||||
model_id="DiffSynth-Studio/Eligen",
|
||||
origin_file_path="model_bf16.safetensors",
|
||||
local_dir="models/lora/entity_control"
|
||||
),
|
||||
lora_alpha=1
|
||||
)
|
||||
pipe = FluxImagePipeline.from_model_manager(model_manager)
|
||||
|
||||
# example 1
|
||||
trigger_word = "lego set in style of TOK, "
|
||||
global_prompt = "A breathtaking beauty of Raja Ampat by the late-night moonlight , one beautiful woman from behind wearing a pale blue long dress with soft glow, sitting at the top of a cliff looking towards the beach,pastell light colors, a group of small distant birds flying in far sky, a boat sailing on the sea, best quality, realistic, whimsical, fantastic, splash art, intricate detailed, hyperdetailed, maximalist style, photorealistic, concept art, sharp focus, harmony, serenity, tranquility, soft pastell colors,ambient occlusion, cozy ambient lighting, masterpiece, liiv1, linquivera, metix, mentixis, masterpiece, award winning, view from above\n"
|
||||
global_prompt = trigger_word + global_prompt
|
||||
entity_prompts = ["cliff", "sea", "moon", "sailing boat", "a seated beautiful woman", "pale blue long dress with soft glow"]
|
||||
example(pipe, [0], 1, global_prompt, entity_prompts)
|
||||
|
||||
# example 2
|
||||
global_prompt = "samurai girl wearing a kimono, she's holding a sword glowing with red flame, her long hair is flowing in the wind, she is looking at a small bird perched on the back of her hand. ultra realist style. maximum image detail. maximum realistic render."
|
||||
global_prompt = trigger_word + global_prompt
|
||||
entity_prompts = ["flowing hair", "sword glowing with red flame", "A cute bird", "blue belt"]
|
||||
example(pipe, [0], 2, global_prompt, entity_prompts)
|
||||
|
||||
# example 3
|
||||
global_prompt = "Image of a neverending staircase up to a mysterious palace in the sky, The ancient palace stood majestically atop a mist-shrouded mountain, sunrise, two traditional monk walk in the stair looking at the sunrise, fog,see-through, best quality, whimsical, fantastic, splash art, intricate detailed, hyperdetailed, photorealistic, concept art, harmony, serenity, tranquility, ambient occlusion, halation, cozy ambient lighting, dynamic lighting,masterpiece, liiv1, linquivera, metix, mentixis, masterpiece, award winning,"
|
||||
global_prompt = trigger_word + global_prompt
|
||||
entity_prompts = ["ancient palace", "stone staircase with railings", "a traditional monk", "a traditional monk"]
|
||||
example(pipe, [27], 3, global_prompt, entity_prompts)
|
||||
|
||||
# example 4
|
||||
global_prompt = "A beautiful girl wearing shirt and shorts in the street, holding a sign 'Entity Control'"
|
||||
global_prompt = trigger_word + global_prompt
|
||||
entity_prompts = ["A beautiful girl", "sign 'Entity Control'", "shorts", "shirt"]
|
||||
example(pipe, [21], 4, global_prompt, entity_prompts)
|
||||
|
||||
# example 5
|
||||
global_prompt = "A captivating, dramatic scene in a painting that exudes mystery and foreboding. A white sky, swirling blue clouds, and a crescent yellow moon illuminate a solitary woman standing near the water's edge. Her long dress flows in the wind, silhouetted against the eerie glow. The water mirrors the fiery sky and moonlight, amplifying the uneasy atmosphere."
|
||||
global_prompt = trigger_word + global_prompt
|
||||
entity_prompts = ["crescent yellow moon", "a solitary woman", "water", "swirling blue clouds"]
|
||||
example(pipe, [0], 5, global_prompt, entity_prompts)
|
||||
|
||||
# example 6
|
||||
global_prompt = "Snow White and the 6 Dwarfs."
|
||||
global_prompt = trigger_word + global_prompt
|
||||
entity_prompts = ["Dwarf 1", "Dwarf 2", "Dwarf 3", "Snow White", "Dwarf 4", "Dwarf 5", "Dwarf 6"]
|
||||
example(pipe, [8], 6, global_prompt, entity_prompts)
|
||||
|
||||
# example 7, same prompt with different seeds
|
||||
seeds = range(5, 9)
|
||||
global_prompt = "A beautiful woman wearing white dress, holding a mirror, with a warm light background;"
|
||||
global_prompt = trigger_word + global_prompt
|
||||
entity_prompts = ["A beautiful woman", "mirror", "necklace", "glasses", "earring", "white dress", "jewelry headpiece"]
|
||||
example(pipe, seeds, 7, global_prompt, entity_prompts)
|
||||
@@ -8,6 +8,12 @@
|
||||
|24G|[hunyuanvideo_24G.py](hunyuanvideo_24G.py)|129|720*1280|The video is consistent with the original implementation, but it requires 5%~10% more time than [hunyuanvideo_80G.py](hunyuanvideo_80G.py)|
|
||||
|6G|[hunyuanvideo_6G.py](hunyuanvideo_6G.py)|129|512*384|The base model doesn't support low resolutions. We recommend using a LoRA trained at low resolutions ([example](https://civitai.com/models/1032126/walking-animation-hunyuan-video)).|
|
||||
|
||||
[HunyuanVideo-I2V](https://github.com/Tencent/HunyuanVideo-I2V) is the image-to-video generation version of HunyuanVideo. We also provide advanced VRAM management for this model.
|
||||
|VRAM required|Example script|Frames|Resolution|Note|
|
||||
|-|-|-|-|-|
|
||||
|80G|[hunyuanvideo_i2v_80G.py](hunyuanvideo_i2v_80G.py)|129|720p|No VRAM management.|
|
||||
|24G|[hunyuanvideo_i2v_24G.py](hunyuanvideo_i2v_24G.py)|129|720p|The video is consistent with the original implementation, but it requires 5%~10% more time than [hunyuanvideo_i2v_80G.py](hunyuanvideo_i2v_80G.py).|
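A condensed sketch of the image-to-video call, following the example scripts below (the input image path and the prompt are placeholders):

```python
import torch
from PIL import Image
from diffsynth import ModelManager, HunyuanVideoPipeline, download_models, save_video

download_models(["HunyuanVideoI2V"])
model_manager = ModelManager()
# DiT in bfloat16, the other modules in float16, as in the example scripts.
model_manager.load_models(
    ["models/HunyuanVideoI2V/transformers/mp_rank_00_model_states.pt"],
    torch_dtype=torch.bfloat16, device="cpu")
model_manager.load_models(
    ["models/HunyuanVideoI2V/text_encoder/model.safetensors",
     "models/HunyuanVideoI2V/text_encoder_2",
     "models/HunyuanVideoI2V/vae/pytorch_model.pt"],
    torch_dtype=torch.float16, device="cpu")
pipe = HunyuanVideoPipeline.from_model_manager(
    model_manager, torch_dtype=torch.bfloat16, device="cuda", enable_vram_management=True)

image = Image.open("input.jpg").convert("RGB")  # placeholder input image
video = pipe("your prompt here", input_images=[image],
             num_inference_steps=50, seed=0, i2v_resolution="720p")
save_video(video, "video_720p.mp4", fps=30, quality=6)
```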
|
||||
|
||||
## Gallery
|
||||
|
||||
Video generated by [hunyuanvideo_80G.py](hunyuanvideo_80G.py) and [hunyuanvideo_24G.py](hunyuanvideo_24G.py):
|
||||
@@ -21,3 +27,7 @@ https://github.com/user-attachments/assets/2997f107-d02d-4ecb-89bb-5ce1a7f93817
|
||||
Video to video generated by [hunyuanvideo_v2v_6G.py](./hunyuanvideo_v2v_6G.py) using [this LoRA](https://civitai.com/models/1032126/walking-animation-hunyuan-video):
|
||||
|
||||
https://github.com/user-attachments/assets/4b89e52e-ce42-434e-aa57-08f09dfa2b10
|
||||
|
||||
Video generated by [hunyuanvideo_i2v_80G.py](hunyuanvideo_i2v_80G.py) and [hunyuanvideo_i2v_24G.py](hunyuanvideo_i2v_24G.py):
|
||||
|
||||
https://github.com/user-attachments/assets/494f252a-c9af-440d-84ba-a8ddcdcc538a
|
||||
|
||||
examples/HunyuanVideo/hunyuanvideo_i2v_24G.py (new file, 43 lines)
@@ -0,0 +1,43 @@
|
||||
import torch
|
||||
from diffsynth import ModelManager, HunyuanVideoPipeline, download_models, save_video
|
||||
from modelscope import dataset_snapshot_download
|
||||
from PIL import Image
|
||||
|
||||
|
||||
download_models(["HunyuanVideoI2V"])
|
||||
model_manager = ModelManager()
|
||||
|
||||
# The DiT model is loaded in bfloat16.
|
||||
model_manager.load_models(
|
||||
[
|
||||
"models/HunyuanVideoI2V/transformers/mp_rank_00_model_states.pt"
|
||||
],
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cpu"
|
||||
)
|
||||
|
||||
# The other modules are loaded in float16.
|
||||
model_manager.load_models(
|
||||
[
|
||||
"models/HunyuanVideoI2V/text_encoder/model.safetensors",
|
||||
"models/HunyuanVideoI2V/text_encoder_2",
|
||||
'models/HunyuanVideoI2V/vae/pytorch_model.pt'
|
||||
],
|
||||
torch_dtype=torch.float16,
|
||||
device="cpu"
|
||||
)
|
||||
# The computation device is "cuda".
|
||||
pipe = HunyuanVideoPipeline.from_model_manager(model_manager,
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
enable_vram_management=True)
|
||||
|
||||
dataset_snapshot_download(dataset_id="DiffSynth-Studio/examples_in_diffsynth",
|
||||
local_dir="./",
|
||||
allow_file_pattern=f"data/examples/hunyuanvideo/*")
|
||||
|
||||
i2v_resolution = "720p"
|
||||
prompt = "An Asian man with short hair in black tactical uniform and white clothes waves a firework stick."
|
||||
images = [Image.open("data/examples/hunyuanvideo/0.jpg").convert('RGB')]
|
||||
video = pipe(prompt, input_images=images, num_inference_steps=50, seed=0, i2v_resolution=i2v_resolution)
|
||||
save_video(video, f"video_{i2v_resolution}_low_vram.mp4", fps=30, quality=6)
|
||||
examples/HunyuanVideo/hunyuanvideo_i2v_80G.py (new file, 45 lines)
@@ -0,0 +1,45 @@
|
||||
import torch
|
||||
from diffsynth import ModelManager, HunyuanVideoPipeline, download_models, save_video
|
||||
from modelscope import dataset_snapshot_download
|
||||
from PIL import Image
|
||||
|
||||
|
||||
download_models(["HunyuanVideoI2V"])
|
||||
model_manager = ModelManager()
|
||||
|
||||
# The DiT model is loaded in bfloat16.
|
||||
model_manager.load_models(
|
||||
[
|
||||
"models/HunyuanVideoI2V/transformers/mp_rank_00_model_states.pt"
|
||||
],
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda"
|
||||
)
|
||||
|
||||
# The other modules are loaded in float16.
|
||||
model_manager.load_models(
|
||||
[
|
||||
"models/HunyuanVideoI2V/text_encoder/model.safetensors",
|
||||
"models/HunyuanVideoI2V/text_encoder_2",
|
||||
'models/HunyuanVideoI2V/vae/pytorch_model.pt'
|
||||
],
|
||||
torch_dtype=torch.float16,
|
||||
device="cuda"
|
||||
)
|
||||
# The computation device is "cuda".
|
||||
pipe = HunyuanVideoPipeline.from_model_manager(model_manager,
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
enable_vram_management=False)
|
||||
# Although you have enough VRAM, we still recommend you to enable offload.
|
||||
pipe.enable_cpu_offload()
|
||||
|
||||
dataset_snapshot_download(dataset_id="DiffSynth-Studio/examples_in_diffsynth",
|
||||
local_dir="./",
|
||||
allow_file_pattern=f"data/examples/hunyuanvideo/*")
|
||||
|
||||
i2v_resolution = "720p"
|
||||
prompt = "An Asian man with short hair in black tactical uniform and white clothes waves a firework stick."
|
||||
images = [Image.open("data/examples/hunyuanvideo/0.jpg").convert('RGB')]
|
||||
video = pipe(prompt, input_images=images, num_inference_steps=50, seed=0, i2v_resolution=i2v_resolution)
|
||||
save_video(video, f"video_{i2v_resolution}.mp4", fps=30, quality=6)
|
||||
@@ -31,6 +31,8 @@ Put sunglasses on the dog.
|
||||
|
||||
https://github.com/user-attachments/assets/272808d7-fbeb-4747-a6df-14a0860c75fb
|
||||
|
||||
[TeaCache](https://github.com/ali-vilab/TeaCache) is supported in both the T2V and I2V models and can significantly improve inference efficiency. See [`./wan_1.3b_text_to_video_accelerate.py`](./wan_1.3b_text_to_video_accelerate.py) and the sketch below.
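A minimal sketch of how the new TeaCache options are passed to the pipeline. The model paths and the threshold value are placeholders; `tea_cache_l1_thresh` and `tea_cache_model_id` are the parameters added to `WanVideoPipeline.__call__` in this change, and the model id must be one registered in the `TeaCache` class.

```python
import torch
from diffsynth import ModelManager, WanVideoPipeline, save_video

# Load the Wan2.1-T2V-1.3B weights as in the other Wan examples (paths are placeholders).
model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu")
model_manager.load_models([
    "models/Wan-AI/Wan2.1-T2V-1.3B/diffusion_pytorch_model.safetensors",
    "models/Wan-AI/Wan2.1-T2V-1.3B/models_t5_umt5-xxl-enc-bf16.pth",
    "models/Wan-AI/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth",
])
pipe = WanVideoPipeline.from_model_manager(model_manager, torch_dtype=torch.bfloat16, device="cuda")

video = pipe(
    prompt="a corgi running on the beach, sunny day, best quality",
    num_inference_steps=50,
    seed=0,
    # TeaCache is enabled by passing a threshold; a larger value skips more steps
    # (faster, lower fidelity). The model id selects the rescaling coefficients
    # registered in the TeaCache class.
    tea_cache_l1_thresh=0.05,
    tea_cache_model_id="Wan2.1-T2V-1.3B",
)
save_video(video, "video_tea_cache.mp4", fps=15, quality=5)
```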
|
||||
|
||||
### Wan-Video-14B-T2V
|
||||
|
||||
Wan-Video-14B-T2V is a larger, more capable version of Wan-Video-1.3B-T2V, and it requires more VRAM. We recommend adjusting the `torch_dtype` and `num_persistent_param_in_dit` settings to find the best balance between speed and VRAM usage, as in the sketch after this paragraph. See [`./wan_14b_text_to_video.py`](./wan_14b_text_to_video.py).
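A minimal sketch of these VRAM-management knobs. The checkpoint paths and the persistent-parameter budget are placeholders, and `enable_vram_management(num_persistent_param_in_dit=...)` is assumed to be available on `WanVideoPipeline` as in the referenced example script:

```python
import torch
from diffsynth import ModelManager, WanVideoPipeline

model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu")
model_manager.load_models([
    # The sharded 14B DiT checkpoints are passed together as one nested list (placeholder paths).
    [f"models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-{i:05d}-of-00006.safetensors" for i in range(1, 7)],
    "models/Wan-AI/Wan2.1-T2V-14B/models_t5_umt5-xxl-enc-bf16.pth",
    "models/Wan-AI/Wan2.1-T2V-14B/Wan2.1_VAE.pth",
])
pipe = WanVideoPipeline.from_model_manager(model_manager, torch_dtype=torch.bfloat16, device="cuda")

# Keep only part of the DiT weights resident on the GPU; smaller budgets save VRAM
# but slow inference down, and None keeps everything on the GPU.
pipe.enable_vram_management(num_persistent_param_in_dit=6 * 10**9)

video = pipe(prompt="a cat playing piano on a stage, cinematic lighting", num_inference_steps=50, seed=0)
```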
|
||||
@@ -47,6 +49,8 @@ We present a detailed table here. The model is tested on a single A100.
|
||||
|
||||
https://github.com/user-attachments/assets/3908bc64-d451-485a-8b61-28f6d32dd92f
|
||||
|
||||
The tensor-parallel module for Wan-Video-14B-T2V is still under development. An example script is provided in [`./wan_14b_text_to_video_tensor_parallel.py`](./wan_14b_text_to_video_tensor_parallel.py).
|
||||
|
||||
### Wan-Video-14B-I2V
|
||||
|
||||
Wan-Video-14B-I2V adds image-to-video functionality on top of Wan-Video-14B-T2V. The model size remains the same, so the speed and VRAM requirements are also the same. See [`./wan_14b_image_to_video.py`](./wan_14b_image_to_video.py).
|
||||
@@ -155,6 +159,12 @@ CUDA_VISIBLE_DEVICES="0" python examples/wanvideo/train_wan_t2v.py \
|
||||
--use_gradient_checkpointing
|
||||
```
|
||||
|
||||
If you wish to train the 14B model, please separate the safetensor files with a comma. For example: `models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00001-of-00006.safetensors,models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00002-of-00006.safetensors,models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00003-of-00006.safetensors,models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00004-of-00006.safetensors,models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00005-of-00006.safetensors,models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00006-of-00006.safetensors`.
|
||||
|
||||
If you wish to train the image-to-video model, please add an extra parameter `--image_encoder_path "models/Wan-AI/Wan2.1-I2V-14B-480P/models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth"`.
|
||||
|
||||
For LoRA training, the Wan-1.3B-T2V model requires 16G of VRAM for processing 81 frames at 480P, while the Wan-14B-T2V model requires 60G of VRAM for the same configuration. To further reduce VRAM requirements by 20%-30%, you can include the parameter `--use_gradient_checkpointing_offload`.
|
||||
|
||||
Step 5: Test
|
||||
|
||||
Test LoRA:
|
||||
|
||||
@@ -7,13 +7,17 @@ from diffsynth import WanVideoPipeline, ModelManager, load_state_dict
|
||||
from peft import LoraConfig, inject_adapter_in_model
|
||||
import torchvision
|
||||
from PIL import Image
|
||||
import numpy as np
|
||||
|
||||
|
||||
|
||||
class TextVideoDataset(torch.utils.data.Dataset):
|
||||
def __init__(self, base_path, metadata_path, max_num_frames=81, frame_interval=1, num_frames=81, height=480, width=832):
|
||||
def __init__(self, base_path, metadata_path, max_num_frames=81, frame_interval=1, num_frames=81, height=480, width=832, is_i2v=False, target_fps=None):
|
||||
metadata = pd.read_csv(metadata_path)
|
||||
self.path = [os.path.join(base_path, "train", file_name) for file_name in metadata["file_name"]]
|
||||
if os.path.exists(os.path.join(base_path, "train")):
|
||||
self.path = [os.path.join(base_path, "train", file_name) for file_name in metadata["file_name"]]
|
||||
else:
|
||||
self.path = [os.path.join(base_path, file_name) for file_name in metadata["file_name"]]
|
||||
self.text = metadata["text"].to_list()
|
||||
|
||||
self.max_num_frames = max_num_frames
|
||||
@@ -21,6 +25,8 @@ class TextVideoDataset(torch.utils.data.Dataset):
|
||||
self.num_frames = num_frames
|
||||
self.height = height
|
||||
self.width = width
|
||||
self.is_i2v = is_i2v
|
||||
self.target_fps = target_fps
|
||||
|
||||
self.frame_process = v2.Compose([
|
||||
v2.CenterCrop(size=(height, width)),
|
||||
@@ -48,10 +54,13 @@ class TextVideoDataset(torch.utils.data.Dataset):
|
||||
return None
|
||||
|
||||
frames = []
|
||||
first_frame = None
|
||||
for frame_id in range(num_frames):
|
||||
frame = reader.get_data(start_frame_id + frame_id * interval)
|
||||
frame = Image.fromarray(frame)
|
||||
frame = self.crop_and_resize(frame)
|
||||
if first_frame is None:
|
||||
first_frame = np.array(frame)
|
||||
frame = frame_process(frame)
|
||||
frames.append(frame)
|
||||
reader.close()
|
||||
@@ -59,18 +68,28 @@ class TextVideoDataset(torch.utils.data.Dataset):
|
||||
frames = torch.stack(frames, dim=0)
|
||||
frames = rearrange(frames, "T C H W -> C T H W")
|
||||
|
||||
return frames
|
||||
if self.is_i2v:
|
||||
return frames, first_frame
|
||||
else:
|
||||
return frames
|
||||
|
||||
|
||||
def load_video(self, file_path):
|
||||
start_frame_id = torch.randint(0, self.max_num_frames - (self.num_frames - 1) * self.frame_interval, (1,))[0]
|
||||
frames = self.load_frames_using_imageio(file_path, self.max_num_frames, start_frame_id, self.frame_interval, self.num_frames, self.frame_process)
|
||||
start_frame_id = 0
|
||||
if self.target_fps is None:
|
||||
frame_interval = self.frame_interval
|
||||
else:
|
||||
reader = imageio.get_reader(file_path)
|
||||
fps = reader.get_meta_data()["fps"]
|
||||
reader.close()
|
||||
frame_interval = max(round(fps / self.target_fps), 1)
|
||||
frames = self.load_frames_using_imageio(file_path, self.max_num_frames, start_frame_id, frame_interval, self.num_frames, self.frame_process)
|
||||
return frames
|
||||
|
||||
|
||||
def is_image(self, file_path):
|
||||
file_ext_name = file_path.split(".")[-1]
|
||||
if file_ext_name.lower() in ["jpg", "png", "webp"]:
|
||||
if file_ext_name.lower() in ["jpg", "jpeg", "png", "webp"]:
|
||||
return True
|
||||
return False
|
||||
|
||||
@@ -78,6 +97,7 @@ class TextVideoDataset(torch.utils.data.Dataset):
|
||||
def load_image(self, file_path):
|
||||
frame = Image.open(file_path).convert("RGB")
|
||||
frame = self.crop_and_resize(frame)
|
||||
first_frame = frame
|
||||
frame = self.frame_process(frame)
|
||||
frame = rearrange(frame, "C H W -> C 1 H W")
|
||||
return frame
|
||||
@@ -86,11 +106,20 @@ class TextVideoDataset(torch.utils.data.Dataset):
|
||||
def __getitem__(self, data_id):
|
||||
text = self.text[data_id]
|
||||
path = self.path[data_id]
|
||||
if self.is_image(path):
|
||||
video = self.load_image(path)
|
||||
else:
|
||||
video = self.load_video(path)
|
||||
data = {"text": text, "video": video, "path": path}
|
||||
try:
|
||||
if self.is_image(path):
|
||||
if self.is_i2v:
|
||||
raise ValueError(f"{path} is not a video. I2V model doesn't support image-to-image training.")
|
||||
video = self.load_image(path)
|
||||
else:
|
||||
video = self.load_video(path)
|
||||
if self.is_i2v:
|
||||
video, first_frame = video
|
||||
data = {"text": text, "video": video, "path": path, "first_frame": first_frame}
|
||||
else:
|
||||
data = {"text": text, "video": video, "path": path}
|
||||
except:
|
||||
data = None
|
||||
return data
|
||||
|
||||
|
||||
@@ -100,43 +129,82 @@ class TextVideoDataset(torch.utils.data.Dataset):
|
||||
|
||||
|
||||
class LightningModelForDataProcess(pl.LightningModule):
|
||||
def __init__(self, text_encoder_path, vae_path, tiled=False, tile_size=(34, 34), tile_stride=(18, 16)):
|
||||
def __init__(self, text_encoder_path, vae_path, image_encoder_path=None, tiled=False, tile_size=(34, 34), tile_stride=(18, 16), redirected_tensor_path=None):
|
||||
super().__init__()
|
||||
model_path = [text_encoder_path, vae_path]
|
||||
if image_encoder_path is not None:
|
||||
model_path.append(image_encoder_path)
|
||||
model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu")
|
||||
model_manager.load_models([text_encoder_path, vae_path])
|
||||
model_manager.load_models(model_path)
|
||||
self.pipe = WanVideoPipeline.from_model_manager(model_manager)
|
||||
|
||||
self.tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride}
|
||||
self.redirected_tensor_path = redirected_tensor_path
|
||||
|
||||
def test_step(self, batch, batch_idx):
|
||||
text, video, path = batch["text"][0], batch["video"], batch["path"][0]
|
||||
data = batch[0]
|
||||
if data is None or data["video"] is None:
|
||||
return
|
||||
text, video, path = data["text"], data["video"].unsqueeze(0), data["path"]
|
||||
|
||||
self.pipe.device = self.device
|
||||
if video is not None:
|
||||
# prompt
|
||||
prompt_emb = self.pipe.encode_prompt(text)
|
||||
# video
|
||||
video = video.to(dtype=self.pipe.torch_dtype, device=self.pipe.device)
|
||||
latents = self.pipe.encode_video(video, **self.tiler_kwargs)[0]
|
||||
data = {"latents": latents, "prompt_emb": prompt_emb}
|
||||
# image
|
||||
if "first_frame" in batch:
|
||||
first_frame = Image.fromarray(batch["first_frame"][0].cpu().numpy())
|
||||
_, _, num_frames, height, width = video.shape
|
||||
image_emb = self.pipe.encode_image(first_frame, num_frames, height, width)
|
||||
else:
|
||||
image_emb = {}
|
||||
data = {"latents": latents, "prompt_emb": prompt_emb, "image_emb": image_emb}
|
||||
if self.redirected_tensor_path is not None:
|
||||
path = path.replace("/", "_").replace("\\", "_")
|
||||
path = os.path.join(self.redirected_tensor_path, path)
|
||||
torch.save(data, path + ".tensors.pth")
|
||||
|
||||
|
||||
|
||||
class TensorDataset(torch.utils.data.Dataset):
|
||||
def __init__(self, base_path, metadata_path, steps_per_epoch):
|
||||
metadata = pd.read_csv(metadata_path)
|
||||
self.path = [os.path.join(base_path, "train", file_name) for file_name in metadata["file_name"]]
|
||||
print(len(self.path), "videos in metadata.")
|
||||
self.path = [i + ".tensors.pth" for i in self.path if os.path.exists(i + ".tensors.pth")]
|
||||
def __init__(self, base_path, metadata_path=None, steps_per_epoch=1000, redirected_tensor_path=None):
|
||||
if os.path.exists(metadata_path):
|
||||
metadata = pd.read_csv(metadata_path)
|
||||
self.path = [os.path.join(base_path, "train", file_name) for file_name in metadata["file_name"]]
|
||||
print(len(self.path), "videos in metadata.")
|
||||
if redirected_tensor_path is None:
|
||||
self.path = [i + ".tensors.pth" for i in self.path if os.path.exists(i + ".tensors.pth")]
|
||||
else:
|
||||
cached_path = []
|
||||
for path in self.path:
|
||||
path = path.replace("/", "_").replace("\\", "_")
|
||||
path = os.path.join(redirected_tensor_path, path)
|
||||
if os.path.exists(path + ".tensors.pth"):
|
||||
cached_path.append(path + ".tensors.pth")
|
||||
self.path = cached_path
|
||||
else:
|
||||
print("Cannot find metadata.csv. Trying to search for tensor files.")
|
||||
self.path = [os.path.join(base_path, i) for i in os.listdir(base_path) if i.endswith(".tensors.pth")]
|
||||
print(len(self.path), "tensors cached in metadata.")
|
||||
assert len(self.path) > 0
|
||||
|
||||
self.steps_per_epoch = steps_per_epoch
|
||||
self.redirected_tensor_path = redirected_tensor_path
|
||||
|
||||
|
||||
def __getitem__(self, index):
|
||||
data_id = torch.randint(0, len(self.path), (1,))[0]
|
||||
data_id = (data_id + index) % len(self.path) # For fixed seed.
|
||||
path = self.path[data_id]
|
||||
data = torch.load(path, weights_only=True, map_location="cpu")
|
||||
return data
|
||||
while True:
|
||||
try:
|
||||
data_id = torch.randint(0, len(self.path), (1,))[0]
|
||||
data_id = (data_id + index) % len(self.path) # For fixed seed.
|
||||
path = self.path[data_id]
|
||||
data = torch.load(path, weights_only=True, map_location="cpu")
|
||||
return data
|
||||
except:
|
||||
continue
|
||||
|
||||
|
||||
def __len__(self):
|
||||
@@ -145,10 +213,21 @@ class TensorDataset(torch.utils.data.Dataset):


class LightningModelForTrain(pl.LightningModule):
    def __init__(self, dit_path, learning_rate=1e-5, lora_rank=4, lora_alpha=4, train_architecture="lora", lora_target_modules="q,k,v,o,ffn.0,ffn.2", init_lora_weights="kaiming", use_gradient_checkpointing=True, pretrained_lora_path=None):
    def __init__(
        self,
        dit_path,
        learning_rate=1e-5,
        lora_rank=4, lora_alpha=4, train_architecture="lora", lora_target_modules="q,k,v,o,ffn.0,ffn.2", init_lora_weights="kaiming",
        use_gradient_checkpointing=True, use_gradient_checkpointing_offload=False,
        pretrained_lora_path=None
    ):
        super().__init__()
        model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu")
        model_manager.load_models([dit_path])
        if os.path.isfile(dit_path):
            model_manager.load_models([dit_path])
        else:
            dit_path = dit_path.split(",")
            model_manager.load_models([dit_path])

        self.pipe = WanVideoPipeline.from_model_manager(model_manager)
        self.pipe.scheduler.set_timesteps(1000, training=True)
@@ -167,6 +246,7 @@ class LightningModelForTrain(pl.LightningModule):

        self.learning_rate = learning_rate
        self.use_gradient_checkpointing = use_gradient_checkpointing
        self.use_gradient_checkpointing_offload = use_gradient_checkpointing_offload


    def freeze_parameters(self):
@@ -210,24 +290,30 @@ class LightningModelForTrain(pl.LightningModule):
        # Data
        latents = batch["latents"].to(self.device)
        prompt_emb = batch["prompt_emb"]
        prompt_emb["context"] = [prompt_emb["context"][0][0].to(self.device)]
        prompt_emb["context"] = prompt_emb["context"][0].to(self.device)
        image_emb = batch["image_emb"]
        if "clip_feature" in image_emb:
            image_emb["clip_feature"] = image_emb["clip_feature"][0].to(self.device)
        if "y" in image_emb:
            image_emb["y"] = image_emb["y"][0].to(self.device)

        # Loss
        self.pipe.device = self.device
        noise = torch.randn_like(latents)
        timestep_id = torch.randint(0, self.pipe.scheduler.num_train_timesteps, (1,))
        timestep = self.pipe.scheduler.timesteps[timestep_id].to(self.device)
        timestep = self.pipe.scheduler.timesteps[timestep_id].to(dtype=self.pipe.torch_dtype, device=self.pipe.device)
        extra_input = self.pipe.prepare_extra_input(latents)
        noisy_latents = self.pipe.scheduler.add_noise(latents, noise, timestep)
        training_target = self.pipe.scheduler.training_target(latents, noise, timestep)

        # Compute loss
        with torch.amp.autocast(dtype=torch.bfloat16, device_type=torch.device(self.device).type):
            noise_pred = self.pipe.denoising_model()(
                noisy_latents, timestep=timestep, **prompt_emb, **extra_input,
                use_gradient_checkpointing=self.use_gradient_checkpointing
            )
            loss = torch.nn.functional.mse_loss(noise_pred.float(), training_target.float())
            loss = loss * self.pipe.scheduler.training_weight(timestep)
        noise_pred = self.pipe.denoising_model()(
            noisy_latents, timestep=timestep, **prompt_emb, **extra_input, **image_emb,
            use_gradient_checkpointing=self.use_gradient_checkpointing,
            use_gradient_checkpointing_offload=self.use_gradient_checkpointing_offload
        )
        loss = torch.nn.functional.mse_loss(noise_pred.float(), training_target.float())
        loss = loss * self.pipe.scheduler.training_weight(timestep)

        # Record log
        self.log("train_loss", loss, prog_bar=True)
@@ -276,12 +362,30 @@ def parse_args():
        default="./",
        help="Path to save the model.",
    )
    parser.add_argument(
        "--metadata_path",
        type=str,
        default=None,
        help="Path to metadata.csv.",
    )
    parser.add_argument(
        "--redirected_tensor_path",
        type=str,
        default=None,
        help="Path to save cached tensors.",
    )
    parser.add_argument(
        "--text_encoder_path",
        type=str,
        default=None,
        help="Path of text encoder.",
    )
    parser.add_argument(
        "--image_encoder_path",
        type=str,
        default=None,
        help="Path of image encoder.",
    )
    parser.add_argument(
        "--vae_path",
        type=str,
@@ -336,6 +440,12 @@ def parse_args():
        default=81,
        help="Number of frames.",
    )
    parser.add_argument(
        "--target_fps",
        type=int,
        default=None,
        help="Expected FPS for sampling frames.",
    )
    parser.add_argument(
        "--height",
        type=int,
@@ -410,6 +520,12 @@ def parse_args():
        action="store_true",
        help="Whether to use gradient checkpointing.",
    )
    parser.add_argument(
        "--use_gradient_checkpointing_offload",
        default=False,
        action="store_true",
        help="Whether to use gradient checkpointing offload.",
    )
    parser.add_argument(
        "--train_architecture",
        type=str,
@@ -441,25 +557,30 @@ def parse_args():
def data_process(args):
    dataset = TextVideoDataset(
        args.dataset_path,
        os.path.join(args.dataset_path, "metadata.csv"),
        os.path.join(args.dataset_path, "metadata.csv") if args.metadata_path is None else args.metadata_path,
        max_num_frames=args.num_frames,
        frame_interval=1,
        num_frames=args.num_frames,
        height=args.height,
        width=args.width
        width=args.width,
        is_i2v=args.image_encoder_path is not None,
        target_fps=args.target_fps,
    )
    dataloader = torch.utils.data.DataLoader(
        dataset,
        shuffle=False,
        batch_size=1,
        num_workers=args.dataloader_num_workers
        num_workers=args.dataloader_num_workers,
        collate_fn=lambda x: x,
    )
    model = LightningModelForDataProcess(
        text_encoder_path=args.text_encoder_path,
        image_encoder_path=args.image_encoder_path,
        vae_path=args.vae_path,
        tiled=args.tiled,
        tile_size=(args.tile_size_height, args.tile_size_width),
        tile_stride=(args.tile_stride_height, args.tile_stride_width),
        redirected_tensor_path=args.redirected_tensor_path,
    )
    trainer = pl.Trainer(
        accelerator="gpu",
@@ -472,8 +593,9 @@ def data_process(args):
def train(args):
    dataset = TensorDataset(
        args.dataset_path,
        os.path.join(args.dataset_path, "metadata.csv"),
        os.path.join(args.dataset_path, "metadata.csv") if args.metadata_path is None else args.metadata_path,
        steps_per_epoch=args.steps_per_epoch,
        redirected_tensor_path=args.redirected_tensor_path,
    )
    dataloader = torch.utils.data.DataLoader(
        dataset,
@@ -490,6 +612,7 @@ def train(args):
        lora_target_modules=args.lora_target_modules,
        init_lora_weights=args.init_lora_weights,
        use_gradient_checkpointing=args.use_gradient_checkpointing,
        use_gradient_checkpointing_offload=args.use_gradient_checkpointing_offload,
        pretrained_lora_path=args.pretrained_lora_path,
    )
    if args.use_swanlab:
@@ -501,7 +624,7 @@ def train(args):
            name="wan",
            config=swanlab_config,
            mode=args.swanlab_mode,
            logdir=args.output_path,
            logdir=os.path.join(args.output_path, "swanlog"),
        )
        logger = [swanlab_logger]
    else:
@@ -510,6 +633,7 @@ def train(args):
        max_epochs=args.max_epochs,
        accelerator="gpu",
        devices="auto",
        precision="bf16",
        strategy=args.training_strategy,
        default_root_dir=args.output_path,
        accumulate_grad_batches=args.accumulate_grad_batches,

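The hunks above route cached tensors either next to each video (the original `<video>.tensors.pth` behaviour) or into a separate directory passed via `--redirected_tensor_path`, flattening the video path into a single file name. A minimal sketch of that mapping, using a hypothetical video path and cache directory chosen only for illustration:

import os

# Hypothetical inputs, for illustration only.
video_path = "data/train/clip_0001.mp4"
redirected_tensor_path = "cache/tensors"

# Mirror of the logic in LightningModelForDataProcess.test_step and TensorDataset.__init__:
# flatten the path separators into underscores, then append ".tensors.pth".
flat_name = video_path.replace("/", "_").replace("\\", "_")
cache_file = os.path.join(redirected_tensor_path, flat_name) + ".tensors.pth"
print(cache_file)  # cache/tensors/data_train_clip_0001.mp4.tensors.pth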
626
examples/wanvideo/train_wan_t2v_controlnet.py
Normal file
@@ -0,0 +1,626 @@
|
||||
import torch, os, imageio, argparse
|
||||
from torchvision.transforms import v2
|
||||
from einops import rearrange
|
||||
import lightning as pl
|
||||
import pandas as pd
|
||||
from diffsynth import WanVideoPipeline, ModelManager, load_state_dict
|
||||
from peft import LoraConfig, inject_adapter_in_model
|
||||
import torchvision
|
||||
from PIL import Image
|
||||
import numpy as np
|
||||
from diffsynth.models.wan_video_controlnet import WanControlNetModel
|
||||
from diffsynth.pipelines.wan_video import model_fn_wan_video
|
||||
|
||||
|
||||
|
||||
class TextVideoDataset(torch.utils.data.Dataset):
|
||||
def __init__(self, base_path, metadata_path, max_num_frames=81, frame_interval=1, num_frames=81, height=480, width=832, is_i2v=False, target_fps=None):
|
||||
metadata = pd.read_csv(metadata_path)
|
||||
self.path = [os.path.join(base_path, "train", file_name) for file_name in metadata["file_name"]]
|
||||
self.controlnet_path = [os.path.join(base_path, file_name) for file_name in metadata["controlnet_file_name"]]
|
||||
self.text = metadata["text"].to_list()
|
||||
|
||||
self.max_num_frames = max_num_frames
|
||||
self.frame_interval = frame_interval
|
||||
self.num_frames = num_frames
|
||||
self.height = height
|
||||
self.width = width
|
||||
self.is_i2v = is_i2v
|
||||
self.target_fps = target_fps
|
||||
|
||||
self.frame_process = v2.Compose([
|
||||
v2.CenterCrop(size=(height, width)),
|
||||
v2.Resize(size=(height, width), antialias=True),
|
||||
v2.ToTensor(),
|
||||
v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
|
||||
])
|
||||
|
||||
|
||||
def crop_and_resize(self, image):
|
||||
width, height = image.size
|
||||
scale = max(self.width / width, self.height / height)
|
||||
image = torchvision.transforms.functional.resize(
|
||||
image,
|
||||
(round(height*scale), round(width*scale)),
|
||||
interpolation=torchvision.transforms.InterpolationMode.BILINEAR
|
||||
)
|
||||
return image
|
||||
|
||||
|
||||
def load_frames_using_imageio(self, file_path, max_num_frames, start_frame_id, interval, num_frames, frame_process):
|
||||
reader = imageio.get_reader(file_path)
|
||||
if reader.count_frames() < max_num_frames or reader.count_frames() - 1 < start_frame_id + (num_frames - 1) * interval:
|
||||
reader.close()
|
||||
return None
|
||||
|
||||
frames = []
|
||||
first_frame = None
|
||||
for frame_id in range(num_frames):
|
||||
frame = reader.get_data(start_frame_id + frame_id * interval)
|
||||
frame = Image.fromarray(frame)
|
||||
frame = self.crop_and_resize(frame)
|
||||
if first_frame is None:
|
||||
first_frame = np.array(frame)
|
||||
frame = frame_process(frame)
|
||||
frames.append(frame)
|
||||
reader.close()
|
||||
|
||||
frames = torch.stack(frames, dim=0)
|
||||
frames = rearrange(frames, "T C H W -> C T H W")
|
||||
|
||||
if self.is_i2v:
|
||||
return frames, first_frame
|
||||
else:
|
||||
return frames
|
||||
|
||||
|
||||
def load_video(self, file_path):
|
||||
start_frame_id = 0
|
||||
if self.target_fps is None:
|
||||
frame_interval = self.frame_interval
|
||||
else:
|
||||
reader = imageio.get_reader(file_path)
|
||||
fps = reader.get_meta_data()["fps"]
|
||||
reader.close()
|
||||
frame_interval = max(round(fps / self.target_fps), 1)
|
||||
frames = self.load_frames_using_imageio(file_path, self.max_num_frames, start_frame_id, frame_interval, self.num_frames, self.frame_process)
|
||||
return frames
|
||||
|
||||
|
||||
def is_image(self, file_path):
|
||||
file_ext_name = file_path.split(".")[-1]
|
||||
if file_ext_name.lower() in ["jpg", "jpeg", "png", "webp"]:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def load_image(self, file_path):
|
||||
frame = Image.open(file_path).convert("RGB")
|
||||
frame = self.crop_and_resize(frame)
|
||||
frame = self.frame_process(frame)
|
||||
frame = rearrange(frame, "C H W -> C 1 H W")
|
||||
return frame
|
||||
|
||||
|
||||
def __getitem__(self, data_id):
|
||||
text = self.text[data_id]
|
||||
path = self.path[data_id]
|
||||
controlnet_path = self.controlnet_path[data_id]
|
||||
try:
|
||||
if self.is_image(path):
|
||||
if self.is_i2v:
|
||||
raise ValueError(f"{path} is not a video. I2V model doesn't support image-to-image training.")
|
||||
video = self.load_image(path)
|
||||
else:
|
||||
video = self.load_video(path)
|
||||
controlnet_frames = self.load_video(controlnet_path)
|
||||
if self.is_i2v:
|
||||
video, first_frame = video
|
||||
data = {"text": text, "video": video, "path": path, "first_frame": first_frame}
|
||||
else:
|
||||
data = {"text": text, "video": video, "path": path, "controlnet_frames": controlnet_frames}
|
||||
except:
|
||||
data = None
|
||||
return data

    def __len__(self):
        return len(self.path)


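The ControlNet variant of `TextVideoDataset` above reads three columns from `metadata.csv`: `file_name` (the training video, resolved under `train/`), `controlnet_file_name` (the matching control video, resolved relative to the dataset root), and `text`. A hedged sketch of how such a file could be prepared with pandas; the paths and caption are placeholders rather than files shipped with the repository:

import pandas as pd

# Placeholder rows; adjust paths and captions to your own dataset.
metadata = pd.DataFrame({
    "file_name": ["clip_0001.mp4"],                     # loaded as train/clip_0001.mp4
    "controlnet_file_name": ["control/clip_0001.mp4"],  # loaded relative to the dataset root
    "text": ["a dog running on the grass"],
})
metadata.to_csv("metadata.csv", index=False)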
class LightningModelForDataProcess(pl.LightningModule):
|
||||
def __init__(self, text_encoder_path, vae_path, image_encoder_path=None, tiled=False, tile_size=(34, 34), tile_stride=(18, 16), redirected_tensor_path=None):
|
||||
super().__init__()
|
||||
model_path = [text_encoder_path, vae_path]
|
||||
if image_encoder_path is not None:
|
||||
model_path.append(image_encoder_path)
|
||||
model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu")
|
||||
model_manager.load_models(model_path)
|
||||
self.pipe = WanVideoPipeline.from_model_manager(model_manager)
|
||||
|
||||
self.tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride}
|
||||
self.redirected_tensor_path = redirected_tensor_path
|
||||
|
||||
    def test_step(self, batch, batch_idx):
        data = batch[0]
        if data is None or data["video"] is None:
            return
        text, video, path = data["text"], data["video"].unsqueeze(0), data["path"]
        controlnet_frames = data["controlnet_frames"].unsqueeze(0)

        self.pipe.device = self.device
        if video is not None:
            # prompt
            prompt_emb = self.pipe.encode_prompt(text)
            # video
            video = video.to(dtype=self.pipe.torch_dtype, device=self.pipe.device)
            latents = self.pipe.encode_video(video, **self.tiler_kwargs)[0]
            # ControlNet video
            controlnet_frames = controlnet_frames.to(dtype=self.pipe.torch_dtype, device=self.pipe.device)
            controlnet_kwargs = self.pipe.prepare_controlnet(controlnet_frames, **self.tiler_kwargs)
            controlnet_kwargs["controlnet_conditioning"] = controlnet_kwargs["controlnet_conditioning"][0]
            # image
            if "first_frame" in batch:
                first_frame = Image.fromarray(batch["first_frame"][0].cpu().numpy())
                _, _, num_frames, height, width = video.shape
                image_emb = self.pipe.encode_image(first_frame, num_frames, height, width)
            else:
                image_emb = {}
            data = {"latents": latents, "prompt_emb": prompt_emb, "image_emb": image_emb, "controlnet_kwargs": controlnet_kwargs}
            if self.redirected_tensor_path is not None:
                path = path.replace("/", "_").replace("\\", "_")
                path = os.path.join(self.redirected_tensor_path, path)
            torch.save(data, path + ".tensors.pth")


class TensorDataset(torch.utils.data.Dataset):
|
||||
def __init__(self, base_path, metadata_path=None, steps_per_epoch=1000, redirected_tensor_path=None):
|
||||
if os.path.exists(metadata_path):
|
||||
metadata = pd.read_csv(metadata_path)
|
||||
self.path = [os.path.join(base_path, "train", file_name) for file_name in metadata["file_name"]]
|
||||
print(len(self.path), "videos in metadata.")
|
||||
if redirected_tensor_path is None:
|
||||
self.path = [i + ".tensors.pth" for i in self.path if os.path.exists(i + ".tensors.pth")]
|
||||
else:
|
||||
cached_path = []
|
||||
for path in self.path:
|
||||
path = path.replace("/", "_").replace("\\", "_")
|
||||
path = os.path.join(redirected_tensor_path, path)
|
||||
if os.path.exists(path + ".tensors.pth"):
|
||||
cached_path.append(path + ".tensors.pth")
|
||||
self.path = cached_path
|
||||
else:
|
||||
print("Cannot find metadata.csv. Trying to search for tensor files.")
|
||||
self.path = [os.path.join(base_path, i) for i in os.listdir(base_path) if i.endswith(".tensors.pth")]
|
||||
print(len(self.path), "tensors cached in metadata.")
|
||||
assert len(self.path) > 0
|
||||
|
||||
self.steps_per_epoch = steps_per_epoch
|
||||
self.redirected_tensor_path = redirected_tensor_path
|
||||
|
||||
|
||||
def __getitem__(self, index):
|
||||
while True:
|
||||
try:
|
||||
data_id = torch.randint(0, len(self.path), (1,))[0]
|
||||
data_id = (data_id + index) % len(self.path) # For fixed seed.
|
||||
path = self.path[data_id]
|
||||
data = torch.load(path, weights_only=True, map_location="cpu")
|
||||
return data
|
||||
except:
|
||||
continue
|
||||
|
||||
|
||||
def __len__(self):
|
||||
return self.steps_per_epoch
|
||||
|
||||
|
||||
|
||||
class LightningModelForTrain(pl.LightningModule):
    def __init__(
        self,
        dit_path,
        learning_rate=1e-5,
        lora_rank=4, lora_alpha=4, train_architecture="lora", lora_target_modules="q,k,v,o,ffn.0,ffn.2", init_lora_weights="kaiming",
        use_gradient_checkpointing=True, use_gradient_checkpointing_offload=False,
        pretrained_lora_path=None
    ):
        super().__init__()
        model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu")
        if os.path.isfile(dit_path):
            model_manager.load_models([dit_path])
        else:
            dit_path = dit_path.split(",")
            model_manager.load_models([dit_path])

        self.pipe = WanVideoPipeline.from_model_manager(model_manager)
        self.pipe.scheduler.set_timesteps(1000, training=True)
        self.freeze_parameters()

        state_dict = load_state_dict(dit_path, torch_dtype=torch.bfloat16)
        state_dict, config = WanControlNetModel.state_dict_converter().from_base_model(state_dict)
        self.pipe.controlnet = WanControlNetModel(**config).to(torch.bfloat16)
        self.pipe.controlnet.load_state_dict(state_dict)
        self.pipe.controlnet.train()
        self.pipe.controlnet.requires_grad_(True)

        self.learning_rate = learning_rate
        self.use_gradient_checkpointing = use_gradient_checkpointing
        self.use_gradient_checkpointing_offload = use_gradient_checkpointing_offload


    def freeze_parameters(self):
        # Freeze parameters
        self.pipe.requires_grad_(False)
        self.pipe.eval()
        self.pipe.denoising_model().train()


    def training_step(self, batch, batch_idx):
        # Data
        latents = batch["latents"].to(self.device)
        controlnet_kwargs = batch["controlnet_kwargs"]
        controlnet_kwargs["controlnet_conditioning"] = controlnet_kwargs["controlnet_conditioning"].to(self.device)
        prompt_emb = batch["prompt_emb"]
        prompt_emb["context"] = prompt_emb["context"][0].to(self.device)
        image_emb = batch["image_emb"]
        if "clip_feature" in image_emb:
            image_emb["clip_feature"] = image_emb["clip_feature"][0].to(self.device)
        if "y" in image_emb:
            image_emb["y"] = image_emb["y"][0].to(self.device)

        # Loss
        self.pipe.device = self.device
        noise = torch.randn_like(latents)
        timestep_id = torch.randint(0, self.pipe.scheduler.num_train_timesteps, (1,))
        timestep = self.pipe.scheduler.timesteps[timestep_id].to(dtype=self.pipe.torch_dtype, device=self.pipe.device)
        extra_input = self.pipe.prepare_extra_input(latents)
        noisy_latents = self.pipe.scheduler.add_noise(latents, noise, timestep)
        training_target = self.pipe.scheduler.training_target(latents, noise, timestep)

        # Compute loss
        noise_pred = model_fn_wan_video(
            dit=self.pipe.dit, controlnet=self.pipe.controlnet,
            x=noisy_latents, timestep=timestep, **prompt_emb, **extra_input, **image_emb, **controlnet_kwargs,
            use_gradient_checkpointing=self.use_gradient_checkpointing,
            use_gradient_checkpointing_offload=self.use_gradient_checkpointing_offload
        )
        loss = torch.nn.functional.mse_loss(noise_pred.float(), training_target.float())
        loss = loss * self.pipe.scheduler.training_weight(timestep)

        # Record log
        self.log("train_loss", loss, prog_bar=True)
        return loss


    def configure_optimizers(self):
        trainable_modules = filter(lambda p: p.requires_grad, self.pipe.controlnet.parameters())
        optimizer = torch.optim.AdamW(trainable_modules, lr=self.learning_rate)
        return optimizer


    def on_save_checkpoint(self, checkpoint):
        checkpoint.clear()
        trainable_param_names = list(filter(lambda named_param: named_param[1].requires_grad, self.pipe.controlnet.named_parameters()))
        trainable_param_names = set([named_param[0] for named_param in trainable_param_names])
        state_dict = self.pipe.controlnet.state_dict()
        lora_state_dict = {}
        for name, param in state_dict.items():
            if name in trainable_param_names:
                lora_state_dict[name] = param
        checkpoint.update(lora_state_dict)


def parse_args():
|
||||
parser = argparse.ArgumentParser(description="Simple example of a training script.")
|
||||
parser.add_argument(
|
||||
"--task",
|
||||
type=str,
|
||||
default="data_process",
|
||||
required=True,
|
||||
choices=["data_process", "train"],
|
||||
help="Task. `data_process` or `train`.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dataset_path",
|
||||
type=str,
|
||||
default=None,
|
||||
required=True,
|
||||
help="The path of the Dataset.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output_path",
|
||||
type=str,
|
||||
default="./",
|
||||
help="Path to save the model.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--metadata_path",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Path to metadata.csv.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--redirected_tensor_path",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Path to save cached tensors.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--text_encoder_path",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Path of text encoder.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--image_encoder_path",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Path of image encoder.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--vae_path",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Path of VAE.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dit_path",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Path of DiT.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--tiled",
|
||||
default=False,
|
||||
action="store_true",
|
||||
help="Whether enable tile encode in VAE. This option can reduce VRAM required.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--tile_size_height",
|
||||
type=int,
|
||||
default=34,
|
||||
help="Tile size (height) in VAE.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--tile_size_width",
|
||||
type=int,
|
||||
default=34,
|
||||
help="Tile size (width) in VAE.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--tile_stride_height",
|
||||
type=int,
|
||||
default=18,
|
||||
help="Tile stride (height) in VAE.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--tile_stride_width",
|
||||
type=int,
|
||||
default=16,
|
||||
help="Tile stride (width) in VAE.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--steps_per_epoch",
|
||||
type=int,
|
||||
default=500,
|
||||
help="Number of steps per epoch.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num_frames",
|
||||
type=int,
|
||||
default=81,
|
||||
help="Number of frames.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--target_fps",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Expected FPS for sampling frames.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--height",
|
||||
type=int,
|
||||
default=480,
|
||||
help="Image height.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--width",
|
||||
type=int,
|
||||
default=832,
|
||||
help="Image width.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dataloader_num_workers",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--learning_rate",
|
||||
type=float,
|
||||
default=1e-5,
|
||||
help="Learning rate.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--accumulate_grad_batches",
|
||||
type=int,
|
||||
default=1,
|
||||
help="The number of batches in gradient accumulation.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max_epochs",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Number of epochs.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--lora_target_modules",
|
||||
type=str,
|
||||
default="q,k,v,o,ffn.0,ffn.2",
|
||||
help="Layers with LoRA modules.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--init_lora_weights",
|
||||
type=str,
|
||||
default="kaiming",
|
||||
choices=["gaussian", "kaiming"],
|
||||
help="The initializing method of LoRA weight.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--training_strategy",
|
||||
type=str,
|
||||
default="auto",
|
||||
choices=["auto", "deepspeed_stage_1", "deepspeed_stage_2", "deepspeed_stage_3"],
|
||||
help="Training strategy",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--lora_rank",
|
||||
type=int,
|
||||
default=4,
|
||||
help="The dimension of the LoRA update matrices.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--lora_alpha",
|
||||
type=float,
|
||||
default=4.0,
|
||||
help="The weight of the LoRA update matrices.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--use_gradient_checkpointing",
|
||||
default=False,
|
||||
action="store_true",
|
||||
help="Whether to use gradient checkpointing.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--use_gradient_checkpointing_offload",
|
||||
default=False,
|
||||
action="store_true",
|
||||
help="Whether to use gradient checkpointing offload.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--train_architecture",
|
||||
type=str,
|
||||
default="lora",
|
||||
choices=["lora", "full"],
|
||||
help="Model structure to train. LoRA training or full training.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pretrained_lora_path",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Pretrained LoRA path. Required if the training is resumed.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--use_swanlab",
|
||||
default=False,
|
||||
action="store_true",
|
||||
help="Whether to use SwanLab logger.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--swanlab_mode",
|
||||
default=None,
|
||||
help="SwanLab mode (cloud or local).",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
|
||||
def data_process(args):
|
||||
dataset = TextVideoDataset(
|
||||
args.dataset_path,
|
||||
os.path.join(args.dataset_path, "metadata.csv") if args.metadata_path is None else args.metadata_path,
|
||||
max_num_frames=args.num_frames,
|
||||
frame_interval=1,
|
||||
num_frames=args.num_frames,
|
||||
height=args.height,
|
||||
width=args.width,
|
||||
is_i2v=args.image_encoder_path is not None,
|
||||
target_fps=args.target_fps,
|
||||
)
|
||||
dataloader = torch.utils.data.DataLoader(
|
||||
dataset,
|
||||
shuffle=False,
|
||||
batch_size=1,
|
||||
num_workers=args.dataloader_num_workers,
|
||||
collate_fn=lambda x: x,
|
||||
)
|
||||
model = LightningModelForDataProcess(
|
||||
text_encoder_path=args.text_encoder_path,
|
||||
image_encoder_path=args.image_encoder_path,
|
||||
vae_path=args.vae_path,
|
||||
tiled=args.tiled,
|
||||
tile_size=(args.tile_size_height, args.tile_size_width),
|
||||
tile_stride=(args.tile_stride_height, args.tile_stride_width),
|
||||
redirected_tensor_path=args.redirected_tensor_path,
|
||||
)
|
||||
trainer = pl.Trainer(
|
||||
accelerator="gpu",
|
||||
devices="auto",
|
||||
default_root_dir=args.output_path,
|
||||
)
|
||||
trainer.test(model, dataloader)
|
||||
|
||||
|
||||
def train(args):
|
||||
dataset = TensorDataset(
|
||||
args.dataset_path,
|
||||
os.path.join(args.dataset_path, "metadata.csv") if args.metadata_path is None else args.metadata_path,
|
||||
steps_per_epoch=args.steps_per_epoch,
|
||||
redirected_tensor_path=args.redirected_tensor_path,
|
||||
)
|
||||
dataloader = torch.utils.data.DataLoader(
|
||||
dataset,
|
||||
shuffle=True,
|
||||
batch_size=1,
|
||||
num_workers=args.dataloader_num_workers
|
||||
)
|
||||
model = LightningModelForTrain(
|
||||
dit_path=args.dit_path,
|
||||
learning_rate=args.learning_rate,
|
||||
train_architecture=args.train_architecture,
|
||||
lora_rank=args.lora_rank,
|
||||
lora_alpha=args.lora_alpha,
|
||||
lora_target_modules=args.lora_target_modules,
|
||||
init_lora_weights=args.init_lora_weights,
|
||||
use_gradient_checkpointing=args.use_gradient_checkpointing,
|
||||
use_gradient_checkpointing_offload=args.use_gradient_checkpointing_offload,
|
||||
pretrained_lora_path=args.pretrained_lora_path,
|
||||
)
|
||||
if args.use_swanlab:
|
||||
from swanlab.integration.pytorch_lightning import SwanLabLogger
|
||||
swanlab_config = {"UPPERFRAMEWORK": "DiffSynth-Studio"}
|
||||
swanlab_config.update(vars(args))
|
||||
swanlab_logger = SwanLabLogger(
|
||||
project="wan",
|
||||
name="wan",
|
||||
config=swanlab_config,
|
||||
mode=args.swanlab_mode,
|
||||
logdir=os.path.join(args.output_path, "swanlog"),
|
||||
)
|
||||
logger = [swanlab_logger]
|
||||
else:
|
||||
logger = None
|
||||
trainer = pl.Trainer(
|
||||
max_epochs=args.max_epochs,
|
||||
accelerator="gpu",
|
||||
devices="auto",
|
||||
precision="bf16",
|
||||
strategy=args.training_strategy,
|
||||
default_root_dir=args.output_path,
|
||||
accumulate_grad_batches=args.accumulate_grad_batches,
|
||||
callbacks=[pl.pytorch.callbacks.ModelCheckpoint(save_top_k=-1)],
|
||||
logger=logger,
|
||||
)
|
||||
trainer.fit(model, dataloader)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
args = parse_args()
|
||||
if args.task == "data_process":
|
||||
data_process(args)
|
||||
elif args.task == "train":
|
||||
train(args)
|
||||
691
examples/wanvideo/train_wan_t2v_motion.py
Normal file
@@ -0,0 +1,691 @@
|
||||
import torch, os, imageio, argparse
|
||||
from torchvision.transforms import v2
|
||||
from einops import rearrange
|
||||
import lightning as pl
|
||||
import pandas as pd
|
||||
from diffsynth import WanVideoPipeline, ModelManager, load_state_dict
|
||||
from diffsynth.models.wan_video_motion_controller import WanMotionControllerModel
|
||||
from diffsynth.pipelines.wan_video import model_fn_wan_video
|
||||
from peft import LoraConfig, inject_adapter_in_model
|
||||
import torchvision
|
||||
from PIL import Image
|
||||
import numpy as np
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
|
||||
class TextVideoDataset(torch.utils.data.Dataset):
|
||||
def __init__(self, base_path, metadata_path, max_num_frames=81, frame_interval=1, num_frames=81, height=480, width=832, is_i2v=False, target_fps=None):
|
||||
metadata = pd.read_csv(metadata_path)
|
||||
self.path = [os.path.join(base_path, file_name) for file_name in metadata["file_name"]]
|
||||
self.text = metadata["text"].to_list()
|
||||
|
||||
self.max_num_frames = max_num_frames
|
||||
self.frame_interval = frame_interval
|
||||
self.num_frames = num_frames
|
||||
self.height = height
|
||||
self.width = width
|
||||
self.is_i2v = is_i2v
|
||||
self.target_fps = target_fps
|
||||
|
||||
self.frame_process = v2.Compose([
|
||||
v2.CenterCrop(size=(height, width)),
|
||||
v2.Resize(size=(height, width), antialias=True),
|
||||
v2.ToTensor(),
|
||||
v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
|
||||
])
|
||||
|
||||
|
||||
def crop_and_resize(self, image):
|
||||
width, height = image.size
|
||||
scale = max(self.width / width, self.height / height)
|
||||
image = torchvision.transforms.functional.resize(
|
||||
image,
|
||||
(round(height*scale), round(width*scale)),
|
||||
interpolation=torchvision.transforms.InterpolationMode.BILINEAR
|
||||
)
|
||||
return image
|
||||
|
||||
|
||||
def load_frames_using_imageio(self, file_path, max_num_frames, start_frame_id, interval, num_frames, frame_process):
|
||||
reader = imageio.get_reader(file_path)
|
||||
if reader.count_frames() < max_num_frames or reader.count_frames() - 1 < start_frame_id + (num_frames - 1) * interval:
|
||||
reader.close()
|
||||
return None
|
||||
|
||||
frames = []
|
||||
first_frame = None
|
||||
for frame_id in range(num_frames):
|
||||
frame = reader.get_data(start_frame_id + frame_id * interval)
|
||||
frame = Image.fromarray(frame)
|
||||
frame = self.crop_and_resize(frame)
|
||||
if first_frame is None:
|
||||
first_frame = np.array(frame)
|
||||
frame = frame_process(frame)
|
||||
frames.append(frame)
|
||||
reader.close()
|
||||
|
||||
frames = torch.stack(frames, dim=0)
|
||||
frames = rearrange(frames, "T C H W -> C T H W")
|
||||
|
||||
if self.is_i2v:
|
||||
return frames, first_frame
|
||||
else:
|
||||
return frames
|
||||
|
||||
|
||||
def load_video(self, file_path):
|
||||
start_frame_id = 0
|
||||
if self.target_fps is None:
|
||||
frame_interval = self.frame_interval
|
||||
else:
|
||||
reader = imageio.get_reader(file_path)
|
||||
fps = reader.get_meta_data()["fps"]
|
||||
reader.close()
|
||||
frame_interval = max(round(fps / self.target_fps), 1)
|
||||
frames = self.load_frames_using_imageio(file_path, self.max_num_frames, start_frame_id, frame_interval, self.num_frames, self.frame_process)
|
||||
return frames
|
||||
|
||||
|
||||
def is_image(self, file_path):
|
||||
file_ext_name = file_path.split(".")[-1]
|
||||
if file_ext_name.lower() in ["jpg", "jpeg", "png", "webp"]:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def load_image(self, file_path):
|
||||
frame = Image.open(file_path).convert("RGB")
|
||||
frame = self.crop_and_resize(frame)
|
||||
first_frame = frame
|
||||
frame = self.frame_process(frame)
|
||||
frame = rearrange(frame, "C H W -> C 1 H W")
|
||||
return frame
|
||||
|
||||
|
||||
def __getitem__(self, data_id):
|
||||
text = self.text[data_id]
|
||||
path = self.path[data_id]
|
||||
try:
|
||||
if self.is_image(path):
|
||||
if self.is_i2v:
|
||||
raise ValueError(f"{path} is not a video. I2V model doesn't support image-to-image training.")
|
||||
video = self.load_image(path)
|
||||
else:
|
||||
video = self.load_video(path)
|
||||
if self.is_i2v:
|
||||
video, first_frame = video
|
||||
data = {"text": text, "video": video, "path": path, "first_frame": first_frame}
|
||||
else:
|
||||
data = {"text": text, "video": video, "path": path}
|
||||
except:
|
||||
data = None
|
||||
return data
|
||||
|
||||
|
||||
def __len__(self):
|
||||
return len(self.path)
|
||||
|
||||
|
||||
|
||||
class LightningModelForDataProcess(pl.LightningModule):
|
||||
def __init__(self, text_encoder_path, vae_path, image_encoder_path=None, tiled=False, tile_size=(34, 34), tile_stride=(18, 16), redirected_tensor_path=None):
|
||||
super().__init__()
|
||||
model_path = [text_encoder_path, vae_path]
|
||||
if image_encoder_path is not None:
|
||||
model_path.append(image_encoder_path)
|
||||
model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu")
|
||||
model_manager.load_models(model_path)
|
||||
self.pipe = WanVideoPipeline.from_model_manager(model_manager)
|
||||
|
||||
self.tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride}
|
||||
self.redirected_tensor_path = redirected_tensor_path
|
||||
|
||||
def test_step(self, batch, batch_idx):
|
||||
data = batch[0]
|
||||
if data is None or data["video"] is None:
|
||||
return
|
||||
text, video, path = data["text"], data["video"].unsqueeze(0), data["path"]
|
||||
|
||||
self.pipe.device = self.device
|
||||
if video is not None:
|
||||
# prompt
|
||||
prompt_emb = self.pipe.encode_prompt(text)
|
||||
# video
|
||||
video = video.to(dtype=self.pipe.torch_dtype, device=self.pipe.device)
|
||||
latents = self.pipe.encode_video(video, **self.tiler_kwargs)[0]
|
||||
# image
|
||||
if "first_frame" in batch:
|
||||
first_frame = Image.fromarray(batch["first_frame"][0].cpu().numpy())
|
||||
_, _, num_frames, height, width = video.shape
|
||||
image_emb = self.pipe.encode_image(first_frame, num_frames, height, width)
|
||||
else:
|
||||
image_emb = {}
|
||||
data = {"latents": latents, "prompt_emb": prompt_emb, "image_emb": image_emb}
|
||||
if self.redirected_tensor_path is not None:
|
||||
path = path.replace("/", "_").replace("\\", "_")
|
||||
path = os.path.join(self.redirected_tensor_path, path)
|
||||
torch.save(data, path + ".tensors.pth")
|
||||
|
||||
|
||||
|
||||
class TensorDataset(torch.utils.data.Dataset):
|
||||
def __init__(self, base_path, metadata_path=None, steps_per_epoch=1000, redirected_tensor_path=None):
|
||||
if os.path.exists(metadata_path):
|
||||
metadata = pd.read_csv(metadata_path)
|
||||
self.path = [os.path.join(base_path, "train", file_name) for file_name in metadata["file_name"]]
|
||||
print(len(self.path), "videos in metadata.")
|
||||
if redirected_tensor_path is None:
|
||||
self.path = [i + ".tensors.pth" for i in self.path if os.path.exists(i + ".tensors.pth")]
|
||||
else:
|
||||
cached_path = []
|
||||
for path in self.path:
|
||||
path = path.replace("/", "_").replace("\\", "_")
|
||||
path = os.path.join(redirected_tensor_path, path)
|
||||
if os.path.exists(path + ".tensors.pth"):
|
||||
cached_path.append(path + ".tensors.pth")
|
||||
self.path = cached_path
|
||||
else:
|
||||
print("Cannot find metadata.csv. Trying to search for tensor files.")
|
||||
self.path = [os.path.join(base_path, i) for i in os.listdir(base_path) if i.endswith(".tensors.pth")]
|
||||
print(len(self.path), "tensors cached in metadata.")
|
||||
assert len(self.path) > 0
|
||||
|
||||
self.steps_per_epoch = steps_per_epoch
|
||||
self.redirected_tensor_path = redirected_tensor_path
|
||||
|
||||
|
||||
def __getitem__(self, index):
|
||||
while True:
|
||||
try:
|
||||
data_id = torch.randint(0, len(self.path), (1,))[0]
|
||||
data_id = (data_id + index) % len(self.path) # For fixed seed.
|
||||
path = self.path[data_id]
|
||||
data = torch.load(path, weights_only=True, map_location="cpu")
|
||||
return data
|
||||
except:
|
||||
continue
|
||||
|
||||
|
||||
def __len__(self):
|
||||
return self.steps_per_epoch
|
||||
|
||||
|
||||
|
||||
class LightningModelForTrain(pl.LightningModule):
    def __init__(
        self,
        dit_path,
        learning_rate=1e-5,
        lora_rank=4, lora_alpha=4, train_architecture="lora", lora_target_modules="q,k,v,o,ffn.0,ffn.2", init_lora_weights="kaiming",
        use_gradient_checkpointing=True, use_gradient_checkpointing_offload=False,
        pretrained_lora_path=None
    ):
        super().__init__()
        model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu")
        if os.path.isfile(dit_path):
            model_manager.load_models([dit_path])
        else:
            dit_path = dit_path.split(",")
            model_manager.load_models([dit_path])

        self.pipe = WanVideoPipeline.from_model_manager(model_manager)
        self.pipe.scheduler.set_timesteps(1000, training=True)
        self.freeze_parameters()

        self.pipe.motion_controller = WanMotionControllerModel().to(torch.bfloat16)
        self.pipe.motion_controller.init()
        self.pipe.motion_controller.requires_grad_(True)
        self.pipe.motion_controller.train()
        self.motion_bucket_manager = MotionBucketManager()

        self.learning_rate = learning_rate
        self.use_gradient_checkpointing = use_gradient_checkpointing
        self.use_gradient_checkpointing_offload = use_gradient_checkpointing_offload


    def freeze_parameters(self):
        # Freeze parameters
        self.pipe.requires_grad_(False)
        self.pipe.eval()
        self.pipe.dit.train()


def add_lora_to_model(self, model, lora_rank=4, lora_alpha=4, lora_target_modules="q,k,v,o,ffn.0,ffn.2", init_lora_weights="kaiming", pretrained_lora_path=None, state_dict_converter=None):
|
||||
# Add LoRA to UNet
|
||||
self.lora_alpha = lora_alpha
|
||||
if init_lora_weights == "kaiming":
|
||||
init_lora_weights = True
|
||||
|
||||
lora_config = LoraConfig(
|
||||
r=lora_rank,
|
||||
lora_alpha=lora_alpha,
|
||||
init_lora_weights=init_lora_weights,
|
||||
target_modules=lora_target_modules.split(","),
|
||||
)
|
||||
model = inject_adapter_in_model(lora_config, model)
|
||||
for param in model.parameters():
|
||||
# Upcast LoRA parameters into fp32
|
||||
if param.requires_grad:
|
||||
param.data = param.to(torch.float32)
|
||||
|
||||
# Lora pretrained lora weights
|
||||
if pretrained_lora_path is not None:
|
||||
state_dict = load_state_dict(pretrained_lora_path)
|
||||
if state_dict_converter is not None:
|
||||
state_dict = state_dict_converter(state_dict)
|
||||
missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
|
||||
all_keys = [i for i, _ in model.named_parameters()]
|
||||
num_updated_keys = len(all_keys) - len(missing_keys)
|
||||
num_unexpected_keys = len(unexpected_keys)
|
||||
print(f"{num_updated_keys} parameters are loaded from {pretrained_lora_path}. {num_unexpected_keys} parameters are unexpected.")
|
||||
|
||||
|
||||
    def training_step(self, batch, batch_idx):
        # Data
        latents = batch["latents"].to(self.device)
        prompt_emb = batch["prompt_emb"]
        prompt_emb["context"] = prompt_emb["context"][0].to(self.device)
        image_emb = batch["image_emb"]
        if "clip_feature" in image_emb:
            image_emb["clip_feature"] = image_emb["clip_feature"][0].to(self.device)
        if "y" in image_emb:
            image_emb["y"] = image_emb["y"][0].to(self.device)

        # Loss
        self.pipe.device = self.device
        noise = torch.randn_like(latents)
        timestep_id = torch.randint(0, self.pipe.scheduler.num_train_timesteps, (1,))
        timestep = self.pipe.scheduler.timesteps[timestep_id].to(dtype=self.pipe.torch_dtype, device=self.pipe.device)
        extra_input = self.pipe.prepare_extra_input(latents)
        noisy_latents = self.pipe.scheduler.add_noise(latents, noise, timestep)
        training_target = self.pipe.scheduler.training_target(latents, noise, timestep)
        motion_bucket_id = self.motion_bucket_manager(latents)
        motion_bucket_kwargs = self.pipe.prepare_motion_bucket_id(motion_bucket_id)

        # Compute loss
        noise_pred = model_fn_wan_video(
            dit=self.pipe.dit, motion_controller=self.pipe.motion_controller,
            x=noisy_latents, timestep=timestep, **prompt_emb, **extra_input, **image_emb, **motion_bucket_kwargs,
            use_gradient_checkpointing=self.use_gradient_checkpointing,
            use_gradient_checkpointing_offload=self.use_gradient_checkpointing_offload
        )
        loss = torch.nn.functional.mse_loss(noise_pred.float(), training_target.float())
        loss = loss * self.pipe.scheduler.training_weight(timestep)

        # Record log
        self.log("train_loss", loss, prog_bar=True)
        return loss


    def configure_optimizers(self):
        trainable_modules = filter(lambda p: p.requires_grad, self.pipe.motion_controller.parameters())
        optimizer = torch.optim.AdamW(trainable_modules, lr=self.learning_rate)
        return optimizer


    def on_save_checkpoint(self, checkpoint):
        checkpoint.clear()
        trainable_param_names = list(filter(lambda named_param: named_param[1].requires_grad, self.pipe.motion_controller.named_parameters()))
        trainable_param_names = set([named_param[0] for named_param in trainable_param_names])
        state_dict = self.pipe.motion_controller.state_dict()
        lora_state_dict = {}
        for name, param in state_dict.items():
            if name in trainable_param_names:
                lora_state_dict[name] = param
        checkpoint.update(lora_state_dict)


class MotionBucketManager:
|
||||
def __init__(self):
|
||||
self.thresholds = [
|
||||
0.093750000, 0.094726562, 0.100585938, 0.100585938, 0.108886719, 0.109375000, 0.118652344, 0.127929688, 0.127929688, 0.130859375,
|
||||
0.133789062, 0.137695312, 0.138671875, 0.138671875, 0.139648438, 0.143554688, 0.143554688, 0.147460938, 0.149414062, 0.149414062,
|
||||
0.152343750, 0.153320312, 0.154296875, 0.154296875, 0.157226562, 0.163085938, 0.163085938, 0.164062500, 0.165039062, 0.166992188,
|
||||
0.173828125, 0.179687500, 0.180664062, 0.184570312, 0.187500000, 0.188476562, 0.188476562, 0.189453125, 0.189453125, 0.202148438,
|
||||
0.206054688, 0.210937500, 0.210937500, 0.211914062, 0.214843750, 0.214843750, 0.216796875, 0.216796875, 0.216796875, 0.218750000,
|
||||
0.218750000, 0.221679688, 0.222656250, 0.227539062, 0.229492188, 0.230468750, 0.236328125, 0.243164062, 0.243164062, 0.245117188,
|
||||
0.253906250, 0.253906250, 0.255859375, 0.259765625, 0.275390625, 0.275390625, 0.277343750, 0.279296875, 0.279296875, 0.279296875,
|
||||
0.292968750, 0.292968750, 0.302734375, 0.306640625, 0.312500000, 0.312500000, 0.326171875, 0.330078125, 0.332031250, 0.332031250,
|
||||
0.337890625, 0.343750000, 0.343750000, 0.351562500, 0.355468750, 0.357421875, 0.361328125, 0.367187500, 0.382812500, 0.388671875,
|
||||
0.392578125, 0.392578125, 0.392578125, 0.404296875, 0.404296875, 0.425781250, 0.433593750, 0.507812500, 0.519531250, 0.539062500,
|
||||
]
|
||||
|
||||
    def get_motion_score(self, frames):
        score = frames[:, :, 1:, :, :].std(dim=2).mean().tolist()
        return score

    def get_bucket_id(self, motion_score):
        for bucket_id in range(len(self.thresholds) - 1):
            if self.thresholds[bucket_id + 1] > motion_score:
                return bucket_id
        return len(self.thresholds)

    def __call__(self, frames):
        score = self.get_motion_score(frames)
        bucket_id = self.get_bucket_id(score)
        return bucket_id


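`MotionBucketManager` turns the temporal standard deviation of the cached latents into a bucket id by scanning the percentile thresholds listed in `__init__`. A small sketch of that mapping on synthetic latents, assuming only the methods defined in this file; the tensor shape and scale are made up for illustration:

import torch

manager = MotionBucketManager()
# Fake latents with shape (batch, channels, frames, height, width), for illustration only.
latents = torch.randn(1, 16, 21, 60, 104) * 0.2
score = manager.get_motion_score(latents)   # std over the time axis, averaged to a scalar
bucket_id = manager.get_bucket_id(score)    # first bucket whose upper threshold exceeds the score
print(score, bucket_id)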
def parse_args():
|
||||
parser = argparse.ArgumentParser(description="Simple example of a training script.")
|
||||
parser.add_argument(
|
||||
"--task",
|
||||
type=str,
|
||||
default="data_process",
|
||||
required=True,
|
||||
choices=["data_process", "train"],
|
||||
help="Task. `data_process` or `train`.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dataset_path",
|
||||
type=str,
|
||||
default=None,
|
||||
required=True,
|
||||
help="The path of the Dataset.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output_path",
|
||||
type=str,
|
||||
default="./",
|
||||
help="Path to save the model.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--metadata_path",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Path to metadata.csv.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--redirected_tensor_path",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Path to save cached tensors.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--text_encoder_path",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Path of text encoder.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--image_encoder_path",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Path of image encoder.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--vae_path",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Path of VAE.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dit_path",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Path of DiT.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--tiled",
|
||||
default=False,
|
||||
action="store_true",
|
||||
help="Whether enable tile encode in VAE. This option can reduce VRAM required.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--tile_size_height",
|
||||
type=int,
|
||||
default=34,
|
||||
help="Tile size (height) in VAE.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--tile_size_width",
|
||||
type=int,
|
||||
default=34,
|
||||
help="Tile size (width) in VAE.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--tile_stride_height",
|
||||
type=int,
|
||||
default=18,
|
||||
help="Tile stride (height) in VAE.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--tile_stride_width",
|
||||
type=int,
|
||||
default=16,
|
||||
help="Tile stride (width) in VAE.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--steps_per_epoch",
|
||||
type=int,
|
||||
default=500,
|
||||
help="Number of steps per epoch.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num_frames",
|
||||
type=int,
|
||||
default=81,
|
||||
help="Number of frames.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--target_fps",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Expected FPS for sampling frames.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--height",
|
||||
type=int,
|
||||
default=480,
|
||||
help="Image height.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--width",
|
||||
type=int,
|
||||
default=832,
|
||||
help="Image width.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dataloader_num_workers",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--learning_rate",
|
||||
type=float,
|
||||
default=1e-5,
|
||||
help="Learning rate.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--accumulate_grad_batches",
|
||||
type=int,
|
||||
default=1,
|
||||
help="The number of batches in gradient accumulation.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max_epochs",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Number of epochs.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--lora_target_modules",
|
||||
type=str,
|
||||
default="q,k,v,o,ffn.0,ffn.2",
|
||||
help="Layers with LoRA modules.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--init_lora_weights",
|
||||
type=str,
|
||||
default="kaiming",
|
||||
choices=["gaussian", "kaiming"],
|
||||
help="The initializing method of LoRA weight.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--training_strategy",
|
||||
type=str,
|
||||
default="auto",
|
||||
choices=["auto", "deepspeed_stage_1", "deepspeed_stage_2", "deepspeed_stage_3"],
|
||||
help="Training strategy",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--lora_rank",
|
||||
type=int,
|
||||
default=4,
|
||||
help="The dimension of the LoRA update matrices.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--lora_alpha",
|
||||
type=float,
|
||||
default=4.0,
|
||||
help="The weight of the LoRA update matrices.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--use_gradient_checkpointing",
|
||||
default=False,
|
||||
action="store_true",
|
||||
help="Whether to use gradient checkpointing.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--use_gradient_checkpointing_offload",
|
||||
default=False,
|
||||
action="store_true",
|
||||
help="Whether to use gradient checkpointing offload.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--train_architecture",
|
||||
type=str,
|
||||
default="lora",
|
||||
choices=["lora", "full"],
|
||||
help="Model structure to train. LoRA training or full training.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pretrained_lora_path",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Pretrained LoRA path. Required if the training is resumed.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--use_swanlab",
|
||||
default=False,
|
||||
action="store_true",
|
||||
help="Whether to use SwanLab logger.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--swanlab_mode",
|
||||
default=None,
|
||||
help="SwanLab mode (cloud or local).",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
|
||||
def data_process(args):
|
||||
dataset = TextVideoDataset(
|
||||
args.dataset_path,
|
||||
os.path.join(args.dataset_path, "metadata.csv") if args.metadata_path is None else args.metadata_path,
|
||||
max_num_frames=args.num_frames,
|
||||
frame_interval=1,
|
||||
num_frames=args.num_frames,
|
||||
height=args.height,
|
||||
width=args.width,
|
||||
is_i2v=args.image_encoder_path is not None,
|
||||
target_fps=args.target_fps,
|
||||
)
|
||||
dataloader = torch.utils.data.DataLoader(
|
||||
dataset,
|
||||
shuffle=False,
|
||||
batch_size=1,
|
||||
num_workers=args.dataloader_num_workers,
|
||||
collate_fn=lambda x: x,
|
||||
)
|
||||
model = LightningModelForDataProcess(
|
||||
text_encoder_path=args.text_encoder_path,
|
||||
image_encoder_path=args.image_encoder_path,
|
||||
vae_path=args.vae_path,
|
||||
tiled=args.tiled,
|
||||
tile_size=(args.tile_size_height, args.tile_size_width),
|
||||
tile_stride=(args.tile_stride_height, args.tile_stride_width),
|
||||
redirected_tensor_path=args.redirected_tensor_path,
|
||||
)
|
||||
trainer = pl.Trainer(
|
||||
accelerator="gpu",
|
||||
devices="auto",
|
||||
default_root_dir=args.output_path,
|
||||
)
|
||||
trainer.test(model, dataloader)
|
||||
|
||||
|
||||
def get_motion_thresholds(dataloader):
|
||||
scores = []
|
||||
for data in tqdm(dataloader):
|
||||
scores.append(data["latents"][:, :, 1:, :, :].std(dim=2).mean().tolist())
|
||||
scores = sorted(scores)
|
||||
for i in range(100):
|
||||
s = scores[int(i/100 * len(scores))]
|
||||
print("%.9f" % s, end=", ")
|
||||
|
||||
|
||||
def train(args):
|
||||
dataset = TensorDataset(
|
||||
args.dataset_path,
|
||||
os.path.join(args.dataset_path, "metadata.csv") if args.metadata_path is None else args.metadata_path,
|
||||
steps_per_epoch=args.steps_per_epoch,
|
||||
redirected_tensor_path=args.redirected_tensor_path,
|
||||
)
|
||||
dataloader = torch.utils.data.DataLoader(
|
||||
dataset,
|
||||
shuffle=True,
|
||||
batch_size=1,
|
||||
num_workers=args.dataloader_num_workers
|
||||
)
|
||||
model = LightningModelForTrain(
|
||||
dit_path=args.dit_path,
|
||||
learning_rate=args.learning_rate,
|
||||
train_architecture=args.train_architecture,
|
||||
lora_rank=args.lora_rank,
|
||||
lora_alpha=args.lora_alpha,
|
||||
lora_target_modules=args.lora_target_modules,
|
||||
init_lora_weights=args.init_lora_weights,
|
||||
use_gradient_checkpointing=args.use_gradient_checkpointing,
|
||||
use_gradient_checkpointing_offload=args.use_gradient_checkpointing_offload,
|
||||
pretrained_lora_path=args.pretrained_lora_path,
|
||||
)
|
||||
if args.use_swanlab:
|
||||
from swanlab.integration.pytorch_lightning import SwanLabLogger
|
||||
swanlab_config = {"UPPERFRAMEWORK": "DiffSynth-Studio"}
|
||||
swanlab_config.update(vars(args))
|
||||
swanlab_logger = SwanLabLogger(
|
||||
project="wan",
|
||||
name="wan",
|
||||
config=swanlab_config,
|
||||
mode=args.swanlab_mode,
|
||||
logdir=os.path.join(args.output_path, "swanlog"),
|
||||
)
|
||||
logger = [swanlab_logger]
|
||||
else:
|
||||
logger = None
|
||||
trainer = pl.Trainer(
|
||||
max_epochs=args.max_epochs,
|
||||
accelerator="gpu",
|
||||
devices="auto",
|
||||
precision="bf16",
|
||||
strategy=args.training_strategy,
|
||||
default_root_dir=args.output_path,
|
||||
accumulate_grad_batches=args.accumulate_grad_batches,
|
||||
callbacks=[pl.pytorch.callbacks.ModelCheckpoint(save_top_k=-1)],
|
||||
logger=logger,
|
||||
)
|
||||
trainer.fit(model, dataloader)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
args = parse_args()
|
||||
if args.task == "data_process":
|
||||
data_process(args)
|
||||
elif args.task == "train":
|
||||
train(args)
|
||||
34
examples/wanvideo/wan_1.3b_text_to_video_accelerate.py
Normal file
@@ -0,0 +1,34 @@
|
||||
import torch
from diffsynth import ModelManager, WanVideoPipeline, save_video, VideoData
from modelscope import snapshot_download


# Download models
snapshot_download("Wan-AI/Wan2.1-T2V-1.3B", local_dir="models/Wan-AI/Wan2.1-T2V-1.3B")

# Load models
model_manager = ModelManager(device="cpu")
model_manager.load_models(
    [
        "models/Wan-AI/Wan2.1-T2V-1.3B/diffusion_pytorch_model.safetensors",
        "models/Wan-AI/Wan2.1-T2V-1.3B/models_t5_umt5-xxl-enc-bf16.pth",
        "models/Wan-AI/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth",
    ],
    torch_dtype=torch.bfloat16, # You can set `torch_dtype=torch.float8_e4m3fn` to enable FP8 quantization.
)
pipe = WanVideoPipeline.from_model_manager(model_manager, torch_dtype=torch.bfloat16, device="cuda")
pipe.enable_vram_management(num_persistent_param_in_dit=None)

# Text-to-video
video = pipe(
    prompt="纪实摄影风格画面,一只活泼的小狗在绿茵茵的草地上迅速奔跑。小狗毛色棕黄,两只耳朵立起,神情专注而欢快。阳光洒在它身上,使得毛发看上去格外柔软而闪亮。背景是一片开阔的草地,偶尔点缀着几朵野花,远处隐约可见蓝天和几片白云。透视感鲜明,捕捉小狗奔跑时的动感和四周草地的生机。中景侧面移动视角。",
    negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
    num_inference_steps=50,
    seed=0, tiled=True,
    # TeaCache parameters
    tea_cache_l1_thresh=0.05, # The larger this value is, the faster the speed, but the worse the visual quality.
    tea_cache_model_id="Wan2.1-T2V-1.3B", # Choose one in (Wan2.1-T2V-1.3B, Wan2.1-T2V-14B, Wan2.1-I2V-14B-480P, Wan2.1-I2V-14B-720P).
)
save_video(video, "video1.mp4", fps=15, quality=5)

# TeaCache doesn't support video-to-video
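The `tea_cache_l1_thresh` value trades visual quality for speed. A hedged way to check the trade-off on your own prompt is to render the same seed twice, once without TeaCache and once with a looser threshold; the snippet reuses only the `pipe` and `save_video` calls shown in this example, the 0.08 threshold and the "..." prompt strings are placeholders rather than recommended values:

# Baseline without TeaCache: simply omit the tea_cache_* arguments.
video_ref = pipe(prompt="...", negative_prompt="...", num_inference_steps=50, seed=0, tiled=True)
save_video(video_ref, "video_no_teacache.mp4", fps=15, quality=5)

# Looser threshold: faster, potentially lower visual quality.
video_fast = pipe(
    prompt="...", negative_prompt="...", num_inference_steps=50, seed=0, tiled=True,
    tea_cache_l1_thresh=0.08, tea_cache_model_id="Wan2.1-T2V-1.3B",
)
save_video(video_fast, "video_teacache_0.08.mp4", fps=15, quality=5)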
@@ -9,6 +9,10 @@ snapshot_download("Wan-AI/Wan2.1-I2V-14B-480P", local_dir="models/Wan-AI/Wan2.1-

# Load models
model_manager = ModelManager(device="cpu")
model_manager.load_models(
    ["models/Wan-AI/Wan2.1-I2V-14B-480P/models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth"],
    torch_dtype=torch.float32, # Image Encoder is loaded with float32
)
model_manager.load_models(
    [
        [
@@ -20,14 +24,13 @@ model_manager.load_models(
            "models/Wan-AI/Wan2.1-I2V-14B-480P/diffusion_pytorch_model-00006-of-00007.safetensors",
            "models/Wan-AI/Wan2.1-I2V-14B-480P/diffusion_pytorch_model-00007-of-00007.safetensors",
        ],
        "models/Wan-AI/Wan2.1-I2V-14B-480P/models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth",
        "models/Wan-AI/Wan2.1-I2V-14B-480P/models_t5_umt5-xxl-enc-bf16.pth",
        "models/Wan-AI/Wan2.1-I2V-14B-480P/Wan2.1_VAE.pth",
    ],
    torch_dtype=torch.float8_e4m3fn, # You can set `torch_dtype=torch.bfloat16` to disable FP8 quantization.
    torch_dtype=torch.bfloat16, # You can set `torch_dtype=torch.float8_e4m3fn` to enable FP8 quantization.
)
pipe = WanVideoPipeline.from_model_manager(model_manager, torch_dtype=torch.bfloat16, device="cuda")
pipe.enable_vram_management(num_persistent_param_in_dit=None) # You can set `num_persistent_param_in_dit` to a small number to reduce VRAM required.
pipe.enable_vram_management(num_persistent_param_in_dit=6*10**9) # You can set `num_persistent_param_in_dit` to a small number to reduce VRAM required.

# Download example image
dataset_snapshot_download(
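The hunk above only shows the changed model-loading and VRAM-management lines of the image-to-video example; the truncated `dataset_snapshot_download(` call fetches the example image used as the first frame. A minimal sketch of the generation call that would follow, assuming the image path is illustrative and that `input_image` is the parameter name accepted by your `WanVideoPipeline` version:

```python
from PIL import Image

# Hypothetical path to the downloaded example image.
image = Image.open("data/examples/wan/input_image.jpg")
video = pipe(
    prompt="the subject slowly turns toward the camera and smiles",
    negative_prompt="static, blurry, low quality",
    input_image=image,   # first-frame condition for image-to-video
    num_inference_steps=50,
    seed=0, tiled=True,
)
save_video(video, "video_i2v.mp4", fps=15, quality=5)
```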
125  examples/wanvideo/wan_14b_text_to_video_tensor_parallel.py  Normal file
@@ -0,0 +1,125 @@
import torch
import lightning as pl
from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel, SequenceParallel, PrepareModuleInput, PrepareModuleOutput
from torch.distributed._tensor import Replicate, Shard
from torch.distributed.tensor.parallel import parallelize_module
from lightning.pytorch.strategies import ModelParallelStrategy
from diffsynth import ModelManager, WanVideoPipeline, save_video
from tqdm import tqdm
from modelscope import snapshot_download


class ToyDataset(torch.utils.data.Dataset):
    def __init__(self, tasks=[]):
        self.tasks = tasks

    def __getitem__(self, data_id):
        return self.tasks[data_id]

    def __len__(self):
        return len(self.tasks)


class LitModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        model_manager = ModelManager(device="cpu")
        model_manager.load_models(
            [
                [
                    "models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00001-of-00006.safetensors",
                    "models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00002-of-00006.safetensors",
                    "models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00003-of-00006.safetensors",
                    "models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00004-of-00006.safetensors",
                    "models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00005-of-00006.safetensors",
                    "models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00006-of-00006.safetensors",
                ],
                "models/Wan-AI/Wan2.1-T2V-14B/models_t5_umt5-xxl-enc-bf16.pth",
                "models/Wan-AI/Wan2.1-T2V-14B/Wan2.1_VAE.pth",
            ],
            torch_dtype=torch.bfloat16,
        )
        self.pipe = WanVideoPipeline.from_model_manager(model_manager, torch_dtype=torch.bfloat16, device="cuda")

    def configure_model(self):
        tp_mesh = self.device_mesh["tensor_parallel"]
        for block_id, block in enumerate(self.pipe.dit.blocks):
            layer_tp_plan = {
                "self_attn": PrepareModuleInput(
                    input_layouts=(Replicate(), Replicate()),
                    desired_input_layouts=(Replicate(), Shard(0)),
                ),
                "self_attn.q": SequenceParallel(),
                "self_attn.k": SequenceParallel(),
                "self_attn.v": SequenceParallel(),
                "self_attn.norm_q": SequenceParallel(),
                "self_attn.norm_k": SequenceParallel(),
                "self_attn.attn": PrepareModuleInput(
                    input_layouts=(Shard(1), Shard(1), Shard(1)),
                    desired_input_layouts=(Shard(2), Shard(2), Shard(2)),
                ),
                "self_attn.o": ColwiseParallel(output_layouts=Replicate()),

                "cross_attn": PrepareModuleInput(
                    input_layouts=(Replicate(), Replicate()),
                    desired_input_layouts=(Replicate(), Replicate()),
                ),
                "cross_attn.q": SequenceParallel(),
                "cross_attn.k": SequenceParallel(),
                "cross_attn.v": SequenceParallel(),
                "cross_attn.norm_q": SequenceParallel(),
                "cross_attn.norm_k": SequenceParallel(),
                "cross_attn.attn": PrepareModuleInput(
                    input_layouts=(Shard(1), Shard(1), Shard(1)),
                    desired_input_layouts=(Shard(2), Shard(2), Shard(2)),
                ),
                "cross_attn.o": ColwiseParallel(output_layouts=Replicate()),

                "ffn.0": ColwiseParallel(),
                "ffn.2": RowwiseParallel(),
            }
            parallelize_module(
                module=block,
                device_mesh=tp_mesh,
                parallelize_plan=layer_tp_plan,
            )

    def test_step(self, batch):
        data = batch[0]
        data["progress_bar_cmd"] = tqdm if self.local_rank == 0 else lambda x: x
        output_path = data.pop("output_path")
        with torch.no_grad(), torch.inference_mode(False):
            video = self.pipe(**data)
        if self.local_rank == 0:
            save_video(video, output_path, fps=15, quality=5)


if __name__ == "__main__":
    snapshot_download("Wan-AI/Wan2.1-T2V-14B", local_dir="models/Wan-AI/Wan2.1-T2V-14B")
    dataloader = torch.utils.data.DataLoader(
        ToyDataset([
            {
                "prompt": "一名宇航员身穿太空服,面朝镜头骑着一匹机械马在火星表面驰骋。红色的荒凉地表延伸至远方,点缀着巨大的陨石坑和奇特的岩石结构。机械马的步伐稳健,扬起微弱的尘埃,展现出未来科技与原始探索的完美结合。宇航员手持操控装置,目光坚定,仿佛正在开辟人类的新疆域。背景是深邃的宇宙和蔚蓝的地球,画面既科幻又充满希望,让人不禁畅想未来的星际生活。",
                "negative_prompt": "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
                "num_inference_steps": 50,
                "seed": 0,
                "tiled": False,
                "output_path": "video1.mp4",
            },
            {
                "prompt": "一名宇航员身穿太空服,面朝镜头骑着一匹机械马在火星表面驰骋。红色的荒凉地表延伸至远方,点缀着巨大的陨石坑和奇特的岩石结构。机械马的步伐稳健,扬起微弱的尘埃,展现出未来科技与原始探索的完美结合。宇航员手持操控装置,目光坚定,仿佛正在开辟人类的新疆域。背景是深邃的宇宙和蔚蓝的地球,画面既科幻又充满希望,让人不禁畅想未来的星际生活。",
                "negative_prompt": "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
                "num_inference_steps": 50,
                "seed": 1,
                "tiled": False,
                "output_path": "video2.mp4",
            },
        ]),
        collate_fn=lambda x: x
    )
    model = LitModel()
    trainer = pl.Trainer(accelerator="gpu", devices=torch.cuda.device_count(), strategy=ModelParallelStrategy())
    trainer.test(model, dataloader)
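The per-block plan above uses standard PyTorch tensor-parallel primitives, so the same idea can be exercised on a toy module to see what the `"ffn.0"` / `"ffn.2"` entries do. A minimal sketch, assuming two or more GPUs and a `torchrun --nproc_per_node=2 tp_toy.py` launch; the module and file names are illustrative and not part of DiffSynth-Studio:

```python
import os
import torch
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import parallelize_module, ColwiseParallel, RowwiseParallel


class ToyFFN(nn.Module):
    def __init__(self, dim=1024, hidden=4096):
        super().__init__()
        # Same layout as the "ffn" entries in the Wan block plan: Linear -> GELU -> Linear.
        self.ffn = nn.Sequential(nn.Linear(dim, hidden), nn.GELU(), nn.Linear(hidden, dim))

    def forward(self, x):
        return self.ffn(x)


if __name__ == "__main__":
    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
    # 1-D tensor-parallel mesh over all visible GPUs.
    mesh = init_device_mesh("cuda", (torch.cuda.device_count(),))
    model = ToyFFN().cuda()
    # Shard the first Linear column-wise and the second row-wise, so the hidden
    # activation stays sharded across ranks and the final output is replicated.
    parallelize_module(model, mesh, {"ffn.0": ColwiseParallel(), "ffn.2": RowwiseParallel()})
    x = torch.randn(2, 16, 1024, device="cuda")
    y = model(x)
    print(y.shape)  # torch.Size([2, 16, 1024]) on every rank
```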