support hunyuanvideo_i2v

This commit is contained in:
mi804
2025-03-11 16:20:09 +08:00
parent 945b43492e
commit 4bec2983a9
9 changed files with 327 additions and 161 deletions

View File

@@ -19,7 +19,7 @@ Until now, DiffSynth Studio has supported the following models:
* [Wan-Video](https://github.com/Wan-Video/Wan2.1) * [Wan-Video](https://github.com/Wan-Video/Wan2.1)
* [StepVideo](https://github.com/stepfun-ai/Step-Video-T2V) * [StepVideo](https://github.com/stepfun-ai/Step-Video-T2V)
* [HunyuanVideo](https://github.com/Tencent/HunyuanVideo) * [HunyuanVideo](https://github.com/Tencent/HunyuanVideo), [HunyuanVideo-I2V](https://github.com/Tencent/HunyuanVideo-I2V)
* [CogVideoX](https://huggingface.co/THUDM/CogVideoX-5b) * [CogVideoX](https://huggingface.co/THUDM/CogVideoX-5b)
* [FLUX](https://huggingface.co/black-forest-labs/FLUX.1-dev) * [FLUX](https://huggingface.co/black-forest-labs/FLUX.1-dev)
* [ExVideo](https://huggingface.co/ECNU-CILab/ExVideo-SVD-128f-v1) * [ExVideo](https://huggingface.co/ECNU-CILab/ExVideo-SVD-128f-v1)
@@ -36,6 +36,7 @@ Until now, DiffSynth Studio has supported the following models:
* [Stable Diffusion](https://huggingface.co/runwayml/stable-diffusion-v1-5) * [Stable Diffusion](https://huggingface.co/runwayml/stable-diffusion-v1-5)
## News ## News
- **March 25, 2025** We support HunyuanVideo-I2V, the image-to-video generation version of HunyuanVideo open-sourced by Tencent. Please refer to [./examples/HunyuanVideo/](./examples/HunyuanVideo/) for more details.
- **February 25, 2025** We support Wan-Video, a collection of SOTA video synthesis models open-sourced by Alibaba. See [./examples/wanvideo/](./examples/wanvideo/). - **February 25, 2025** We support Wan-Video, a collection of SOTA video synthesis models open-sourced by Alibaba. See [./examples/wanvideo/](./examples/wanvideo/).

View File

@@ -112,7 +112,6 @@ model_loader_configs = [
(None, "aeb82dce778a03dcb4d726cb03f3c43f", ["hunyuan_video_vae_decoder", "hunyuan_video_vae_encoder"], [HunyuanVideoVAEDecoder, HunyuanVideoVAEEncoder], "diffusers"), (None, "aeb82dce778a03dcb4d726cb03f3c43f", ["hunyuan_video_vae_decoder", "hunyuan_video_vae_encoder"], [HunyuanVideoVAEDecoder, HunyuanVideoVAEEncoder], "diffusers"),
(None, "b9588f02e78f5ccafc9d7c0294e46308", ["hunyuan_video_dit"], [HunyuanVideoDiT], "civitai"), (None, "b9588f02e78f5ccafc9d7c0294e46308", ["hunyuan_video_dit"], [HunyuanVideoDiT], "civitai"),
(None, "84ef4bd4757f60e906b54aa6a7815dc6", ["hunyuan_video_dit"], [HunyuanVideoDiT], "civitai"), (None, "84ef4bd4757f60e906b54aa6a7815dc6", ["hunyuan_video_dit"], [HunyuanVideoDiT], "civitai"),
(None, "ae3c22aaa28bfae6f3688f796c9814ae", ["hunyuan_video_dit"], [HunyuanVideoDiT], "civitai"),
(None, "68beaf8429b7c11aa8ca05b1bd0058bd", ["stepvideo_vae"], [StepVideoVAE], "civitai"), (None, "68beaf8429b7c11aa8ca05b1bd0058bd", ["stepvideo_vae"], [StepVideoVAE], "civitai"),
(None, "5c0216a2132b082c10cb7a0e0377e681", ["stepvideo_dit"], [StepVideoModel], "civitai"), (None, "5c0216a2132b082c10cb7a0e0377e681", ["stepvideo_dit"], [StepVideoModel], "civitai"),
(None, "9269f8db9040a9d860eaca435be61814", ["wan_video_dit"], [WanModel], "civitai"), (None, "9269f8db9040a9d860eaca435be61814", ["wan_video_dit"], [WanModel], "civitai"),

View File

@@ -237,7 +237,7 @@ class IndividualTokenRefinerBlock(torch.nn.Module):
x = x + self.mlp(self.norm2(x)) * gate_mlp.unsqueeze(1) x = x + self.mlp(self.norm2(x)) * gate_mlp.unsqueeze(1)
return x return x
class SingleTokenRefiner(torch.nn.Module): class SingleTokenRefiner(torch.nn.Module):
def __init__(self, in_channels=4096, hidden_size=3072, depth=2): def __init__(self, in_channels=4096, hidden_size=3072, depth=2):
@@ -270,7 +270,7 @@ class SingleTokenRefiner(torch.nn.Module):
x = block(x, c, mask) x = block(x, c, mask)
return x return x
class ModulateDiT(torch.nn.Module): class ModulateDiT(torch.nn.Module):
def __init__(self, hidden_size, factor=6): def __init__(self, hidden_size, factor=6):
@@ -280,9 +280,14 @@ class ModulateDiT(torch.nn.Module):
def forward(self, x): def forward(self, x):
return self.linear(self.act(x)) return self.linear(self.act(x))
def modulate(x, shift=None, scale=None):
def modulate(x, shift=None, scale=None, tr_shift=None, tr_scale=None, tr_token=None):
if tr_shift is not None:
x_zero = x[:, :tr_token] * (1 + tr_scale.unsqueeze(1)) + tr_shift.unsqueeze(1)
x_orig = x[:, tr_token:] * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
x = torch.concat((x_zero, x_orig), dim=1)
return x
if scale is None and shift is None: if scale is None and shift is None:
return x return x
elif shift is None: elif shift is None:
@@ -291,7 +296,7 @@ def modulate(x, shift=None, scale=None):
return x + shift.unsqueeze(1) return x + shift.unsqueeze(1)
else: else:
return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1) return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
def reshape_for_broadcast( def reshape_for_broadcast(
freqs_cis, freqs_cis,
@@ -344,7 +349,7 @@ def rotate_half(x):
x.float().reshape(*x.shape[:-1], -1, 2).unbind(-1) x.float().reshape(*x.shape[:-1], -1, 2).unbind(-1)
) # [B, S, H, D//2] ) # [B, S, H, D//2]
return torch.stack([-x_imag, x_real], dim=-1).flatten(3) return torch.stack([-x_imag, x_real], dim=-1).flatten(3)
def apply_rotary_emb( def apply_rotary_emb(
xq: torch.Tensor, xq: torch.Tensor,
@@ -386,6 +391,15 @@ def attention(q, k, v):
return x return x
def apply_gate(x, gate, tr_gate=None, tr_token=None):
    """Scale ``x`` by a gate, optionally using a separate gate for leading tokens.

    Args:
        x: Tensor indexed as ``x[:, tokens, ...]``; dimension 1 is the one
            that gets split at ``tr_token``.
        gate: Gate applied after ``unsqueeze(1)`` so it broadcasts along
            dimension 1 (presumably shaped (batch, hidden) -- confirm with
            the caller).
        tr_gate: Optional alternative gate for the first ``tr_token``
            positions of dimension 1. When ``None`` the whole of ``x`` is
            scaled by ``gate``.
        tr_token: Number of leading positions along dimension 1 that use
            ``tr_gate``. Required whenever ``tr_gate`` is given.

    Returns:
        A tensor with the same shape as ``x``.

    Raises:
        ValueError: If ``tr_gate`` is given without ``tr_token``.
    """
    if tr_gate is None:
        return x * gate.unsqueeze(1)
    if tr_token is None:
        # Slicing with None would select the full sequence on both sides and
        # silently return a tensor with twice as many tokens.
        raise ValueError("tr_token must be provided when tr_gate is given")
    x_zero = x[:, :tr_token] * tr_gate.unsqueeze(1)
    x_orig = x[:, tr_token:] * gate.unsqueeze(1)
    return torch.concat((x_zero, x_orig), dim=1)
class MMDoubleStreamBlockComponent(torch.nn.Module): class MMDoubleStreamBlockComponent(torch.nn.Module):
def __init__(self, hidden_size=3072, heads_num=24, mlp_width_ratio=4): def __init__(self, hidden_size=3072, heads_num=24, mlp_width_ratio=4):
super().__init__() super().__init__()
@@ -406,11 +420,17 @@ class MMDoubleStreamBlockComponent(torch.nn.Module):
torch.nn.Linear(hidden_size * mlp_width_ratio, hidden_size) torch.nn.Linear(hidden_size * mlp_width_ratio, hidden_size)
) )
def forward(self, hidden_states, conditioning, freqs_cis=None): def forward(self, hidden_states, conditioning, freqs_cis=None, token_replace_vec=None, tr_token=None):
mod1_shift, mod1_scale, mod1_gate, mod2_shift, mod2_scale, mod2_gate = self.mod(conditioning).chunk(6, dim=-1) mod1_shift, mod1_scale, mod1_gate, mod2_shift, mod2_scale, mod2_gate = self.mod(conditioning).chunk(6, dim=-1)
if token_replace_vec is not None:
assert tr_token is not None
tr_mod1_shift, tr_mod1_scale, tr_mod1_gate, tr_mod2_shift, tr_mod2_scale, tr_mod2_gate = self.mod(token_replace_vec).chunk(6, dim=-1)
else:
tr_mod1_shift, tr_mod1_scale, tr_mod1_gate, tr_mod2_shift, tr_mod2_scale, tr_mod2_gate = None, None, None, None, None, None
norm_hidden_states = self.norm1(hidden_states) norm_hidden_states = self.norm1(hidden_states)
norm_hidden_states = modulate(norm_hidden_states, shift=mod1_shift, scale=mod1_scale) norm_hidden_states = modulate(norm_hidden_states, shift=mod1_shift, scale=mod1_scale,
tr_shift=tr_mod1_shift, tr_scale=tr_mod1_scale, tr_token=tr_token)
qkv = self.to_qkv(norm_hidden_states) qkv = self.to_qkv(norm_hidden_states)
q, k, v = rearrange(qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num) q, k, v = rearrange(qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num)
@@ -419,15 +439,19 @@ class MMDoubleStreamBlockComponent(torch.nn.Module):
if freqs_cis is not None: if freqs_cis is not None:
q, k = apply_rotary_emb(q, k, freqs_cis, head_first=False) q, k = apply_rotary_emb(q, k, freqs_cis, head_first=False)
return (q, k, v), (mod1_gate, mod2_shift, mod2_scale, mod2_gate), (tr_mod1_gate, tr_mod2_shift, tr_mod2_scale, tr_mod2_gate)
return (q, k, v), (mod1_gate, mod2_shift, mod2_scale, mod2_gate) def process_ff(self, hidden_states, attn_output, mod, mod_tr=None, tr_token=None):
def process_ff(self, hidden_states, attn_output, mod):
mod1_gate, mod2_shift, mod2_scale, mod2_gate = mod mod1_gate, mod2_shift, mod2_scale, mod2_gate = mod
hidden_states = hidden_states + self.to_out(attn_output) * mod1_gate.unsqueeze(1) if mod_tr is not None:
hidden_states = hidden_states + self.ff(modulate(self.norm2(hidden_states), shift=mod2_shift, scale=mod2_scale)) * mod2_gate.unsqueeze(1) tr_mod1_gate, tr_mod2_shift, tr_mod2_scale, tr_mod2_gate = mod_tr
else:
tr_mod1_gate, tr_mod2_shift, tr_mod2_scale, tr_mod2_gate = None, None, None, None
hidden_states = hidden_states + apply_gate(self.to_out(attn_output), mod1_gate, tr_mod1_gate, tr_token)
x = self.ff(modulate(self.norm2(hidden_states), shift=mod2_shift, scale=mod2_scale, tr_shift=tr_mod2_shift, tr_scale=tr_mod2_scale, tr_token=tr_token))
hidden_states = hidden_states + apply_gate(x, mod2_gate, tr_mod2_gate, tr_token)
return hidden_states return hidden_states
class MMDoubleStreamBlock(torch.nn.Module): class MMDoubleStreamBlock(torch.nn.Module):
def __init__(self, hidden_size=3072, heads_num=24, mlp_width_ratio=4): def __init__(self, hidden_size=3072, heads_num=24, mlp_width_ratio=4):
@@ -435,18 +459,18 @@ class MMDoubleStreamBlock(torch.nn.Module):
self.component_a = MMDoubleStreamBlockComponent(hidden_size, heads_num, mlp_width_ratio) self.component_a = MMDoubleStreamBlockComponent(hidden_size, heads_num, mlp_width_ratio)
self.component_b = MMDoubleStreamBlockComponent(hidden_size, heads_num, mlp_width_ratio) self.component_b = MMDoubleStreamBlockComponent(hidden_size, heads_num, mlp_width_ratio)
def forward(self, hidden_states_a, hidden_states_b, conditioning, freqs_cis): def forward(self, hidden_states_a, hidden_states_b, conditioning, freqs_cis, token_replace_vec=None, tr_token=None, split_token=71):
(q_a, k_a, v_a), mod_a = self.component_a(hidden_states_a, conditioning, freqs_cis) (q_a, k_a, v_a), mod_a, mod_tr = self.component_a(hidden_states_a, conditioning, freqs_cis, token_replace_vec, tr_token)
(q_b, k_b, v_b), mod_b = self.component_b(hidden_states_b, conditioning, freqs_cis=None) (q_b, k_b, v_b), mod_b, _ = self.component_b(hidden_states_b, conditioning, freqs_cis=None)
q_a, q_b = torch.concat([q_a, q_b[:, :71]], dim=1), q_b[:, 71:].contiguous() q_a, q_b = torch.concat([q_a, q_b[:, :split_token]], dim=1), q_b[:, split_token:].contiguous()
k_a, k_b = torch.concat([k_a, k_b[:, :71]], dim=1), k_b[:, 71:].contiguous() k_a, k_b = torch.concat([k_a, k_b[:, :split_token]], dim=1), k_b[:, split_token:].contiguous()
v_a, v_b = torch.concat([v_a, v_b[:, :71]], dim=1), v_b[:, 71:].contiguous() v_a, v_b = torch.concat([v_a, v_b[:, :split_token]], dim=1), v_b[:, split_token:].contiguous()
attn_output_a = attention(q_a, k_a, v_a) attn_output_a = attention(q_a, k_a, v_a)
attn_output_b = attention(q_b, k_b, v_b) attn_output_b = attention(q_b, k_b, v_b)
attn_output_a, attn_output_b = attn_output_a[:, :-71].contiguous(), torch.concat([attn_output_a[:, -71:], attn_output_b], dim=1) attn_output_a, attn_output_b = attn_output_a[:, :-split_token].contiguous(), torch.concat([attn_output_a[:, -split_token:], attn_output_b], dim=1)
hidden_states_a = self.component_a.process_ff(hidden_states_a, attn_output_a, mod_a) hidden_states_a = self.component_a.process_ff(hidden_states_a, attn_output_a, mod_a, mod_tr, tr_token)
hidden_states_b = self.component_b.process_ff(hidden_states_b, attn_output_b, mod_b) hidden_states_b = self.component_b.process_ff(hidden_states_b, attn_output_b, mod_b)
return hidden_states_a, hidden_states_b return hidden_states_a, hidden_states_b
@@ -489,7 +513,7 @@ class MMSingleStreamBlockOriginal(torch.nn.Module):
output = self.linear2(torch.cat((attn_output, self.mlp_act(mlp)), 2)) output = self.linear2(torch.cat((attn_output, self.mlp_act(mlp)), 2))
return x + output * mod_gate.unsqueeze(1) return x + output * mod_gate.unsqueeze(1)
class MMSingleStreamBlock(torch.nn.Module): class MMSingleStreamBlock(torch.nn.Module):
def __init__(self, hidden_size=3072, heads_num=24, mlp_width_ratio=4): def __init__(self, hidden_size=3072, heads_num=24, mlp_width_ratio=4):
@@ -510,11 +534,17 @@ class MMSingleStreamBlock(torch.nn.Module):
torch.nn.Linear(hidden_size * mlp_width_ratio, hidden_size, bias=False) torch.nn.Linear(hidden_size * mlp_width_ratio, hidden_size, bias=False)
) )
def forward(self, hidden_states, conditioning, freqs_cis=None, txt_len=256): def forward(self, hidden_states, conditioning, freqs_cis=None, txt_len=256, token_replace_vec=None, tr_token=None, split_token=71):
mod_shift, mod_scale, mod_gate = self.mod(conditioning).chunk(3, dim=-1) mod_shift, mod_scale, mod_gate = self.mod(conditioning).chunk(3, dim=-1)
if token_replace_vec is not None:
assert tr_token is not None
tr_mod_shift, tr_mod_scale, tr_mod_gate = self.mod(token_replace_vec).chunk(3, dim=-1)
else:
tr_mod_shift, tr_mod_scale, tr_mod_gate = None, None, None
norm_hidden_states = self.norm(hidden_states) norm_hidden_states = self.norm(hidden_states)
norm_hidden_states = modulate(norm_hidden_states, shift=mod_shift, scale=mod_scale) norm_hidden_states = modulate(norm_hidden_states, shift=mod_shift, scale=mod_scale,
tr_shift=tr_mod_shift, tr_scale=tr_mod_scale, tr_token=tr_token)
qkv = self.to_qkv(norm_hidden_states) qkv = self.to_qkv(norm_hidden_states)
q, k, v = rearrange(qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num) q, k, v = rearrange(qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num)
@@ -526,16 +556,17 @@ class MMSingleStreamBlock(torch.nn.Module):
k_a, k_b = k[:, :-txt_len, :, :], k[:, -txt_len:, :, :] k_a, k_b = k[:, :-txt_len, :, :], k[:, -txt_len:, :, :]
q_a, k_a = apply_rotary_emb(q_a, k_a, freqs_cis, head_first=False) q_a, k_a = apply_rotary_emb(q_a, k_a, freqs_cis, head_first=False)
q_a, q_b = torch.concat([q_a, q_b[:, :71]], dim=1), q_b[:, 71:].contiguous() v_len = txt_len - split_token
k_a, k_b = torch.concat([k_a, k_b[:, :71]], dim=1), k_b[:, 71:].contiguous() q_a, q_b = torch.concat([q_a, q_b[:, :split_token]], dim=1), q_b[:, split_token:].contiguous()
v_a, v_b = v[:, :-185].contiguous(), v[:, -185:].contiguous() k_a, k_b = torch.concat([k_a, k_b[:, :split_token]], dim=1), k_b[:, split_token:].contiguous()
v_a, v_b = v[:, :-v_len].contiguous(), v[:, -v_len:].contiguous()
attn_output_a = attention(q_a, k_a, v_a) attn_output_a = attention(q_a, k_a, v_a)
attn_output_b = attention(q_b, k_b, v_b) attn_output_b = attention(q_b, k_b, v_b)
attn_output = torch.concat([attn_output_a, attn_output_b], dim=1) attn_output = torch.concat([attn_output_a, attn_output_b], dim=1)
hidden_states = hidden_states + self.to_out(attn_output) * mod_gate.unsqueeze(1) hidden_states = hidden_states + apply_gate(self.to_out(attn_output), mod_gate, tr_mod_gate, tr_token)
hidden_states = hidden_states + self.ff(norm_hidden_states) * mod_gate.unsqueeze(1) hidden_states = hidden_states + apply_gate(self.ff(norm_hidden_states), mod_gate, tr_mod_gate, tr_token)
return hidden_states return hidden_states
@@ -581,7 +612,7 @@ class HunyuanVideoDiT(torch.nn.Module):
def unpatchify(self, x, T, H, W): def unpatchify(self, x, T, H, W):
x = rearrange(x, "B (T H W) (C pT pH pW) -> B C (T pT) (H pH) (W pW)", H=H, W=W, pT=1, pH=2, pW=2) x = rearrange(x, "B (T H W) (C pT pH pW) -> B C (T pT) (H pH) (W pW)", H=H, W=W, pT=1, pH=2, pW=2)
return x return x
def enable_block_wise_offload(self, warm_device="cuda", cold_device="cpu"): def enable_block_wise_offload(self, warm_device="cuda", cold_device="cpu"):
self.warm_device = warm_device self.warm_device = warm_device
self.cold_device = cold_device self.cold_device = cold_device
@@ -616,7 +647,7 @@ class HunyuanVideoDiT(torch.nn.Module):
vec += self.guidance_in(guidance * 1000, dtype=torch.float32) vec += self.guidance_in(guidance * 1000, dtype=torch.float32)
img = self.img_in(x) img = self.img_in(x)
txt = self.txt_in(prompt_emb, t, text_mask) txt = self.txt_in(prompt_emb, t, text_mask)
for block in tqdm(self.double_blocks, desc="Double stream blocks"): for block in tqdm(self.double_blocks, desc="Double stream blocks"):
img, txt = block(img, txt, vec, (freqs_cos, freqs_sin)) img, txt = block(img, txt, vec, (freqs_cos, freqs_sin))
@@ -628,7 +659,7 @@ class HunyuanVideoDiT(torch.nn.Module):
img = self.final_layer(img, vec) img = self.final_layer(img, vec)
img = self.unpatchify(img, T=T//1, H=H//2, W=W//2) img = self.unpatchify(img, T=T//1, H=H//2, W=W//2)
return img return img
def enable_auto_offload(self, dtype=torch.bfloat16, device="cuda"): def enable_auto_offload(self, dtype=torch.bfloat16, device="cuda"):
def cast_to(weight, dtype=None, device=None, copy=False): def cast_to(weight, dtype=None, device=None, copy=False):
@@ -684,7 +715,7 @@ class HunyuanVideoDiT(torch.nn.Module):
del x_, weight_, bias_ del x_, weight_, bias_
torch.cuda.empty_cache() torch.cuda.empty_cache()
return y_ return y_
def block_forward(self, x, **kwargs): def block_forward(self, x, **kwargs):
# This feature can only reduce 2GB VRAM, so we disable it. # This feature can only reduce 2GB VRAM, so we disable it.
y = torch.zeros(x.shape[:-1] + (self.out_features,), dtype=x.dtype, device=x.device) y = torch.zeros(x.shape[:-1] + (self.out_features,), dtype=x.dtype, device=x.device)
@@ -692,19 +723,19 @@ class HunyuanVideoDiT(torch.nn.Module):
for j in range((self.out_features + self.block_size - 1) // self.block_size): for j in range((self.out_features + self.block_size - 1) // self.block_size):
y[..., j * self.block_size: (j + 1) * self.block_size] += self.block_forward_(x, i, j, dtype=x.dtype, device=x.device) y[..., j * self.block_size: (j + 1) * self.block_size] += self.block_forward_(x, i, j, dtype=x.dtype, device=x.device)
return y return y
def forward(self, x, **kwargs): def forward(self, x, **kwargs):
weight, bias = cast_bias_weight(self, x, dtype=self.dtype, device=self.device) weight, bias = cast_bias_weight(self, x, dtype=self.dtype, device=self.device)
return torch.nn.functional.linear(x, weight, bias) return torch.nn.functional.linear(x, weight, bias)
class RMSNorm(torch.nn.Module): class RMSNorm(torch.nn.Module):
def __init__(self, module, dtype=torch.bfloat16, device="cuda"): def __init__(self, module, dtype=torch.bfloat16, device="cuda"):
super().__init__() super().__init__()
self.module = module self.module = module
self.dtype = dtype self.dtype = dtype
self.device = device self.device = device
def forward(self, hidden_states, **kwargs): def forward(self, hidden_states, **kwargs):
input_dtype = hidden_states.dtype input_dtype = hidden_states.dtype
variance = hidden_states.to(torch.float32).square().mean(-1, keepdim=True) variance = hidden_states.to(torch.float32).square().mean(-1, keepdim=True)
@@ -714,30 +745,30 @@ class HunyuanVideoDiT(torch.nn.Module):
weight = cast_weight(self.module, hidden_states, dtype=torch.bfloat16, device="cuda") weight = cast_weight(self.module, hidden_states, dtype=torch.bfloat16, device="cuda")
hidden_states = hidden_states * weight hidden_states = hidden_states * weight
return hidden_states return hidden_states
class Conv3d(torch.nn.Conv3d): class Conv3d(torch.nn.Conv3d):
def __init__(self, *args, dtype=torch.bfloat16, device="cuda", **kwargs): def __init__(self, *args, dtype=torch.bfloat16, device="cuda", **kwargs):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self.dtype = dtype self.dtype = dtype
self.device = device self.device = device
def forward(self, x): def forward(self, x):
weight, bias = cast_bias_weight(self, x, dtype=self.dtype, device=self.device) weight, bias = cast_bias_weight(self, x, dtype=self.dtype, device=self.device)
return torch.nn.functional.conv3d(x, weight, bias, self.stride, self.padding, self.dilation, self.groups) return torch.nn.functional.conv3d(x, weight, bias, self.stride, self.padding, self.dilation, self.groups)
class LayerNorm(torch.nn.LayerNorm): class LayerNorm(torch.nn.LayerNorm):
def __init__(self, *args, dtype=torch.bfloat16, device="cuda", **kwargs): def __init__(self, *args, dtype=torch.bfloat16, device="cuda", **kwargs):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self.dtype = dtype self.dtype = dtype
self.device = device self.device = device
def forward(self, x): def forward(self, x):
if self.weight is not None and self.bias is not None: if self.weight is not None and self.bias is not None:
weight, bias = cast_bias_weight(self, x, dtype=self.dtype, device=self.device) weight, bias = cast_bias_weight(self, x, dtype=self.dtype, device=self.device)
return torch.nn.functional.layer_norm(x, self.normalized_shape, weight, bias, self.eps) return torch.nn.functional.layer_norm(x, self.normalized_shape, weight, bias, self.eps)
else: else:
return torch.nn.functional.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps) return torch.nn.functional.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
def replace_layer(model, dtype=torch.bfloat16, device="cuda"): def replace_layer(model, dtype=torch.bfloat16, device="cuda"):
for name, module in model.named_children(): for name, module in model.named_children():
if isinstance(module, torch.nn.Linear): if isinstance(module, torch.nn.Linear):
@@ -780,7 +811,6 @@ class HunyuanVideoDiT(torch.nn.Module):
return HunyuanVideoDiTStateDictConverter() return HunyuanVideoDiTStateDictConverter()
class HunyuanVideoDiTStateDictConverter: class HunyuanVideoDiTStateDictConverter:
def __init__(self): def __init__(self):
pass pass
@@ -886,6 +916,5 @@ class HunyuanVideoDiTStateDictConverter:
state_dict_[name_] = param state_dict_[name_] = param
else: else:
pass pass
if origin_hash_key == "ae3c22aaa28bfae6f3688f796c9814ae":
return state_dict_, {"in_channels": 33, "guidance_embed":False}
return state_dict_ return state_dict_

View File

@@ -5,13 +5,13 @@ from ..schedulers.flow_match import FlowMatchScheduler
from .base import BasePipeline from .base import BasePipeline
from ..prompters import HunyuanVideoPrompter from ..prompters import HunyuanVideoPrompter
import torch import torch
import torchvision.transforms as transforms
from einops import rearrange from einops import rearrange
import numpy as np import numpy as np
from PIL import Image from PIL import Image
from tqdm import tqdm from tqdm import tqdm
class HunyuanVideoPipeline(BasePipeline): class HunyuanVideoPipeline(BasePipeline):
def __init__(self, device="cuda", torch_dtype=torch.float16): def __init__(self, device="cuda", torch_dtype=torch.float16):
@@ -53,10 +53,58 @@ class HunyuanVideoPipeline(BasePipeline):
pipe.enable_vram_management() pipe.enable_vram_management()
return pipe return pipe
def generate_crop_size_list(self, base_size=256, patch_size=32, max_ratio=4.0):
    """Enumerate candidate crop sizes that fit within a fixed patch budget.

    The budget is ``round((base_size / patch_size) ** 2)`` patches. Starting
    from a single row spanning the whole budget, the walk adds a row whenever
    the grid still fits the budget and otherwise drops a column, recording
    every grid whose long/short side ratio does not exceed ``max_ratio``.

    Args:
        base_size: Side length whose square defines the patch budget.
        patch_size: Side length of one patch; every returned size is a
            multiple of it.
        max_ratio: Largest allowed long-side / short-side ratio (>= 1.0).

    Returns:
        List of ``(first_count * patch_size, second_count * patch_size)``
        tuples in visiting order.
    """
    num_patches = round((base_size / patch_size)**2)
    assert max_ratio >= 1.0
    sizes = []
    cols, rows = num_patches, 1
    while cols > 0:
        longer, shorter = max(cols, rows), min(cols, rows)
        if longer / shorter <= max_ratio:
            sizes.append((cols * patch_size, rows * patch_size))
        # Prefer adding a row while the grid stays inside the budget;
        # otherwise shrink the column count.
        fits_another_row = (rows + 1) * cols <= num_patches
        if fits_another_row:
            rows += 1
        else:
            cols -= 1
    return sizes
def encode_prompt(self, prompt, positive=True, clip_sequence_length=77, llm_sequence_length=256):
def get_closest_ratio(self, height: float, width: float, ratios: list, buckets: list):
    """Pick the bucket whose aspect ratio is nearest to ``height / width``.

    Args:
        height: Source height.
        width: Source width (must be non-zero).
        ratios: Candidate aspect ratios, one per bucket. May be a list or a
            NumPy array; it is converted to a float array internally.
        buckets: Sequence aligned index-for-index with ``ratios``.

    Returns:
        ``(bucket, ratio)`` where ``bucket`` is the entry of ``buckets`` at
        the position of the nearest ratio and ``ratio`` is that ratio as a
        Python float. On ties the first candidate wins.
    """
    aspect_ratio = float(height) / float(width)
    # Coerce so that a plain list (as annotated) supports the subtraction.
    ratio_array = np.asarray(ratios, dtype=float)
    closest_ratio_id = int(np.abs(ratio_array - aspect_ratio).argmin())
    return buckets[closest_ratio_id], float(ratio_array[closest_ratio_id])
def prepare_vae_images_inputs(self, semantic_images, i2v_resolution="720p"):
    """Resize reference images to the closest size bucket and batch them.

    Args:
        semantic_images: Non-empty sequence of images exposing ``.size``;
            index 0 of ``.size`` is passed as width and index 1 as height
            (presumably PIL images -- confirm with the caller).
        i2v_resolution: One of ``"720p"``, ``"540p"`` or ``"360p"``; selects
            the base size used to build the bucket list.

    Returns:
        ``(pixel_values, target_height, target_width)``. ``pixel_values`` is
        the concatenation of the transformed images with singleton
        dimensions inserted at positions 0 and 2, moved to ``self.device``.

    Raises:
        ValueError: If ``i2v_resolution`` is not a supported value.
    """
    # The resolution is validated before the images are touched.
    for resolution_name, base_size in (("720p", 960), ("540p", 720), ("360p", 480)):
        if i2v_resolution == resolution_name:
            bucket_hw_base_size = base_size
            break
    else:
        raise ValueError(f"i2v_resolution: {i2v_resolution} must be in [360p, 540p, 720p]")

    # Only the first image decides the bucket; every image gets that size.
    image_size = semantic_images[0].size
    buckets = self.generate_crop_size_list(bucket_hw_base_size, 32)
    bucket_ratios = np.array([round(float(h) / float(w), 5) for h, w in buckets])
    closest_size, _ = self.get_closest_ratio(image_size[1], image_size[0], bucket_ratios, buckets)

    preprocess = transforms.Compose([
        transforms.Resize(closest_size),
        transforms.CenterCrop(closest_size),
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5])
    ])
    frames = [preprocess(image) for image in semantic_images]
    # NOTE(review): torch.cat joins along dim 0, so several images would be
    # stacked on the first tensor dimension -- confirm multi-image use.
    pixel_values = torch.cat(frames).unsqueeze(0).unsqueeze(2).to(self.device)

    target_height, target_width = closest_size
    return pixel_values, target_height, target_width
def encode_prompt(self, prompt, positive=True, clip_sequence_length=77, llm_sequence_length=256, input_images=None):
prompt_emb, pooled_prompt_emb, text_mask = self.prompter.encode_prompt( prompt_emb, pooled_prompt_emb, text_mask = self.prompter.encode_prompt(
prompt, device=self.device, positive=positive, clip_sequence_length=clip_sequence_length, llm_sequence_length=llm_sequence_length prompt, device=self.device, positive=positive, clip_sequence_length=clip_sequence_length, llm_sequence_length=llm_sequence_length, images=input_images
) )
return {"prompt_emb": prompt_emb, "pooled_prompt_emb": pooled_prompt_emb, "text_mask": text_mask} return {"prompt_emb": prompt_emb, "pooled_prompt_emb": pooled_prompt_emb, "text_mask": text_mask}
@@ -87,6 +135,9 @@ class HunyuanVideoPipeline(BasePipeline):
prompt, prompt,
negative_prompt="", negative_prompt="",
input_video=None, input_video=None,
input_images=None,
i2v_resolution="720p",
i2v_stability=True,
denoising_strength=1.0, denoising_strength=1.0,
seed=None, seed=None,
rand_device=None, rand_device=None,
@@ -105,10 +156,17 @@ class HunyuanVideoPipeline(BasePipeline):
): ):
# Tiler parameters # Tiler parameters
tiler_kwargs = {"tile_size": tile_size, "tile_stride": tile_stride} tiler_kwargs = {"tile_size": tile_size, "tile_stride": tile_stride}
# Scheduler # Scheduler
self.scheduler.set_timesteps(num_inference_steps, denoising_strength) self.scheduler.set_timesteps(num_inference_steps, denoising_strength)
# encoder input images
if input_images is not None:
self.load_models_to_device(['vae_encoder'])
image_pixel_values, height, width = self.prepare_vae_images_inputs(input_images, i2v_resolution=i2v_resolution)
with torch.autocast(device_type=self.device, dtype=torch.float16, enabled=True):
image_latents = self.vae_encoder(image_pixel_values)
# Initialize noise # Initialize noise
rand_device = self.device if rand_device is None else rand_device rand_device = self.device if rand_device is None else rand_device
noise = self.generate_noise((1, 16, (num_frames - 1) // 4 + 1, height//8, width//8), seed=seed, device=rand_device, dtype=self.torch_dtype).to(self.device) noise = self.generate_noise((1, 16, (num_frames - 1) // 4 + 1, height//8, width//8), seed=seed, device=rand_device, dtype=self.torch_dtype).to(self.device)
@@ -118,12 +176,18 @@ class HunyuanVideoPipeline(BasePipeline):
input_video = torch.stack(input_video, dim=2) input_video = torch.stack(input_video, dim=2)
latents = self.encode_video(input_video, **tiler_kwargs).to(dtype=self.torch_dtype, device=self.device) latents = self.encode_video(input_video, **tiler_kwargs).to(dtype=self.torch_dtype, device=self.device)
latents = self.scheduler.add_noise(latents, noise, timestep=self.scheduler.timesteps[0]) latents = self.scheduler.add_noise(latents, noise, timestep=self.scheduler.timesteps[0])
elif input_images is not None and i2v_stability:
noise = self.generate_noise((1, 16, (num_frames - 1) // 4 + 1, height//8, width//8), seed=seed, device=rand_device, dtype=image_latents.dtype).to(self.device)
t = torch.tensor([0.999]).to(device=self.device)
latents = noise * t + image_latents.repeat(1, 1, (num_frames - 1) // 4 + 1, 1, 1) * (1 - t)
latents = latents.to(dtype=image_latents.dtype)
else: else:
latents = noise latents = noise
# Encode prompts # Encode prompts
self.load_models_to_device(["text_encoder_1"] if self.vram_management else ["text_encoder_1", "text_encoder_2"]) # current mllm does not support vram_management
prompt_emb_posi = self.encode_prompt(prompt, positive=True) self.load_models_to_device(["text_encoder_1"] if self.vram_management and input_images is None else ["text_encoder_1", "text_encoder_2"])
prompt_emb_posi = self.encode_prompt(prompt, positive=True, input_images=input_images)
if cfg_scale != 1.0: if cfg_scale != 1.0:
prompt_emb_nega = self.encode_prompt(negative_prompt, positive=False) prompt_emb_nega = self.encode_prompt(negative_prompt, positive=False)
@@ -139,11 +203,16 @@ class HunyuanVideoPipeline(BasePipeline):
timestep = timestep.unsqueeze(0).to(self.device) timestep = timestep.unsqueeze(0).to(self.device)
print(f"Step {progress_id + 1} / {len(self.scheduler.timesteps)}") print(f"Step {progress_id + 1} / {len(self.scheduler.timesteps)}")
forward_func = lets_dance_hunyuan_video
if input_images is not None:
latents = torch.concat([image_latents, latents[:, :, 1:, :, :]], dim=2)
forward_func = lets_dance_hunyuan_video_i2v
# Inference # Inference
with torch.autocast(device_type=self.device, dtype=self.torch_dtype): with torch.autocast(device_type=self.device, dtype=self.torch_dtype):
noise_pred_posi = lets_dance_hunyuan_video(self.dit, latents, timestep, **prompt_emb_posi, **extra_input, **tea_cache_kwargs) noise_pred_posi = forward_func(self.dit, latents, timestep, **prompt_emb_posi, **extra_input, **tea_cache_kwargs)
if cfg_scale != 1.0: if cfg_scale != 1.0:
noise_pred_nega = lets_dance_hunyuan_video(self.dit, latents, timestep, **prompt_emb_nega, **extra_input) noise_pred_nega = forward_func(self.dit, latents, timestep, **prompt_emb_nega, **extra_input)
noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega) noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
else: else:
noise_pred = noise_pred_posi noise_pred = noise_pred_posi
@@ -163,7 +232,11 @@ class HunyuanVideoPipeline(BasePipeline):
self.load_models_to_device([] if self.vram_management else ["dit"]) self.load_models_to_device([] if self.vram_management else ["dit"])
# Scheduler # Scheduler
latents = self.scheduler.step(noise_pred, self.scheduler.timesteps[progress_id], latents) if input_images is not None:
latents = self.scheduler.step(noise_pred[:, :, 1:, :, :], self.scheduler.timesteps[progress_id], latents[:, :, 1:, :, :])
latents = torch.concat([image_latents, latents], dim=2)
else:
latents = self.scheduler.step(noise_pred, self.scheduler.timesteps[progress_id], latents)
# Decode # Decode
self.load_models_to_device(['vae_decoder']) self.load_models_to_device(['vae_decoder'])
@@ -194,7 +267,7 @@ class TeaCache:
if self.step == 0 or self.step == self.num_inference_steps - 1: if self.step == 0 or self.step == self.num_inference_steps - 1:
should_calc = True should_calc = True
self.accumulated_rel_l1_distance = 0 self.accumulated_rel_l1_distance = 0
else: else:
coefficients = [7.33226126e+02, -4.01131952e+02, 6.75869174e+01, -3.14987800e+00, 9.61237896e-02] coefficients = [7.33226126e+02, -4.01131952e+02, 6.75869174e+01, -3.14987800e+00, 9.61237896e-02]
rescale_func = np.poly1d(coefficients) rescale_func = np.poly1d(coefficients)
self.accumulated_rel_l1_distance += rescale_func(((modulated_inp-self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean()).cpu().item()) self.accumulated_rel_l1_distance += rescale_func(((modulated_inp-self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean()).cpu().item())
@@ -203,14 +276,14 @@ class TeaCache:
else: else:
should_calc = True should_calc = True
self.accumulated_rel_l1_distance = 0 self.accumulated_rel_l1_distance = 0
self.previous_modulated_input = modulated_inp self.previous_modulated_input = modulated_inp
self.step += 1 self.step += 1
if self.step == self.num_inference_steps: if self.step == self.num_inference_steps:
self.step = 0 self.step = 0
if should_calc: if should_calc:
self.previous_hidden_states = img.clone() self.previous_hidden_states = img.clone()
return not should_calc return not should_calc
def store(self, hidden_states): def store(self, hidden_states):
self.previous_residual = hidden_states - self.previous_hidden_states self.previous_residual = hidden_states - self.previous_hidden_states
self.previous_hidden_states = None self.previous_hidden_states = None
@@ -250,13 +323,70 @@ def lets_dance_hunyuan_video(
print("TeaCache skip forward.") print("TeaCache skip forward.")
img = tea_cache.update(img) img = tea_cache.update(img)
else: else:
split_token = int(text_mask.sum(dim=1))
txt_len = int(txt.shape[1])
for block in tqdm(dit.double_blocks, desc="Double stream blocks"): for block in tqdm(dit.double_blocks, desc="Double stream blocks"):
img, txt = block(img, txt, vec, (freqs_cos, freqs_sin)) img, txt = block(img, txt, vec, (freqs_cos, freqs_sin), split_token=split_token)
x = torch.concat([img, txt], dim=1) x = torch.concat([img, txt], dim=1)
for block in tqdm(dit.single_blocks, desc="Single stream blocks"): for block in tqdm(dit.single_blocks, desc="Single stream blocks"):
x = block(x, vec, (freqs_cos, freqs_sin)) x = block(x, vec, (freqs_cos, freqs_sin), txt_len=txt_len, split_token=split_token)
img = x[:, :-256] img = x[:, :-txt_len]
if tea_cache is not None:
tea_cache.store(img)
img = dit.final_layer(img, vec)
img = dit.unpatchify(img, T=T//1, H=H//2, W=W//2)
return img
def lets_dance_hunyuan_video_i2v(
dit: HunyuanVideoDiT,
x: torch.Tensor,
t: torch.Tensor,
prompt_emb: torch.Tensor = None,
text_mask: torch.Tensor = None,
pooled_prompt_emb: torch.Tensor = None,
freqs_cos: torch.Tensor = None,
freqs_sin: torch.Tensor = None,
guidance: torch.Tensor = None,
tea_cache: TeaCache = None,
**kwargs
):
B, C, T, H, W = x.shape
# Uncomment below to keep same as official implementation
# guidance = guidance.to(dtype=torch.float32).to(torch.bfloat16)
vec = dit.time_in(t, dtype=torch.bfloat16)
vec_2 = dit.vector_in(pooled_prompt_emb)
vec = vec + vec_2
vec = vec + dit.guidance_in(guidance * 1000., dtype=torch.bfloat16)
token_replace_vec = dit.time_in(torch.zeros_like(t), dtype=torch.bfloat16)
tr_token = (H // 2) * (W // 2)
token_replace_vec = token_replace_vec + vec_2
img = dit.img_in(x)
txt = dit.txt_in(prompt_emb, t, text_mask)
# TeaCache
if tea_cache is not None:
tea_cache_update = tea_cache.check(dit, img, vec)
else:
tea_cache_update = False
if tea_cache_update:
print("TeaCache skip forward.")
img = tea_cache.update(img)
else:
split_token = int(text_mask.sum(dim=1))
txt_len = int(txt.shape[1])
for block in tqdm(dit.double_blocks, desc="Double stream blocks"):
img, txt = block(img, txt, vec, (freqs_cos, freqs_sin), token_replace_vec, tr_token, split_token)
x = torch.concat([img, txt], dim=1)
for block in tqdm(dit.single_blocks, desc="Single stream blocks"):
x = block(x, vec, (freqs_cos, freqs_sin), txt_len, token_replace_vec, tr_token, split_token)
img = x[:, :-txt_len]
if tea_cache is not None: if tea_cache is not None:
tea_cache.store(img) tea_cache.store(img)

View File

@@ -87,7 +87,6 @@ class HunyuanVideoPrompter(BasePrompter):
self.tokenizer_2 = LlamaTokenizerFast.from_pretrained(tokenizer_2_path, padding_side='right') self.tokenizer_2 = LlamaTokenizerFast.from_pretrained(tokenizer_2_path, padding_side='right')
self.text_encoder_1: SD3TextEncoder1 = None self.text_encoder_1: SD3TextEncoder1 = None
self.text_encoder_2: HunyuanVideoLLMEncoder = None self.text_encoder_2: HunyuanVideoLLMEncoder = None
self.i2v_mode = False
self.prompt_template = PROMPT_TEMPLATE['dit-llm-encode'] self.prompt_template = PROMPT_TEMPLATE['dit-llm-encode']
self.prompt_template_video = PROMPT_TEMPLATE['dit-llm-encode-video'] self.prompt_template_video = PROMPT_TEMPLATE['dit-llm-encode-video']
@@ -106,8 +105,6 @@ class HunyuanVideoPrompter(BasePrompter):
# template # template
self.prompt_template = PROMPT_TEMPLATE['dit-llm-encode-i2v'] self.prompt_template = PROMPT_TEMPLATE['dit-llm-encode-i2v']
self.prompt_template_video = PROMPT_TEMPLATE['dit-llm-encode-video-i2v'] self.prompt_template_video = PROMPT_TEMPLATE['dit-llm-encode-video-i2v']
# mode setting
self.i2v_mode = True
def apply_text_to_template(self, text, template): def apply_text_to_template(self, text, template):
assert isinstance(template, str) assert isinstance(template, str)
@@ -164,10 +161,8 @@ class HunyuanVideoPrompter(BasePrompter):
crop_start, crop_start,
hidden_state_skip_layer=2, hidden_state_skip_layer=2,
use_attention_mask=True, use_attention_mask=True,
image_embed_interleave=2): image_embed_interleave=4):
image_outputs = self.processor(images, return_tensors="pt")[ image_outputs = self.processor(images, return_tensors="pt")["pixel_values"].to(device)
"pixel_values"
].to(device)
max_length += crop_start max_length += crop_start
inputs = self.tokenizer_2(prompt, inputs = self.tokenizer_2(prompt,
return_tensors="pt", return_tensors="pt",
@@ -248,7 +243,8 @@ class HunyuanVideoPrompter(BasePrompter):
data_type='video', data_type='video',
use_template=True, use_template=True,
hidden_state_skip_layer=2, hidden_state_skip_layer=2,
use_attention_mask=True): use_attention_mask=True,
image_embed_interleave=4):
prompt = self.process_prompt(prompt, positive=positive) prompt = self.process_prompt(prompt, positive=positive)
@@ -273,6 +269,7 @@ class HunyuanVideoPrompter(BasePrompter):
hidden_state_skip_layer, use_attention_mask) hidden_state_skip_layer, use_attention_mask)
else: else:
prompt_emb, attention_mask = self.encode_prompt_using_mllm(prompt_formated, images, llm_sequence_length, device, prompt_emb, attention_mask = self.encode_prompt_using_mllm(prompt_formated, images, llm_sequence_length, device,
crop_start, hidden_state_skip_layer, use_attention_mask) crop_start, hidden_state_skip_layer, use_attention_mask,
image_embed_interleave)
return prompt_emb, pooled_prompt_emb, attention_mask return prompt_emb, pooled_prompt_emb, attention_mask

View File

@@ -8,6 +8,12 @@
|24G|[hunyuanvideo_24G.py](hunyuanvideo_24G.py)|129|720*1280|The video is consistent with the original implementation, but it requires 5%~10% more time than [hunyuanvideo_80G.py](hunyuanvideo_80G.py)| |24G|[hunyuanvideo_24G.py](hunyuanvideo_24G.py)|129|720*1280|The video is consistent with the original implementation, but it requires 5%~10% more time than [hunyuanvideo_80G.py](hunyuanvideo_80G.py)|
|6G|[hunyuanvideo_6G.py](hunyuanvideo_6G.py)|129|512*384|The base model doesn't support low resolutions. We recommend users to use some LoRA ([example](https://civitai.com/models/1032126/walking-animation-hunyuan-video)) trained using low resolutions.| |6G|[hunyuanvideo_6G.py](hunyuanvideo_6G.py)|129|512*384|The base model doesn't support low resolutions. We recommend users to use some LoRA ([example](https://civitai.com/models/1032126/walking-animation-hunyuan-video)) trained using low resolutions.|
[HunyuanVideo-I2V](https://github.com/Tencent/HunyuanVideo-I2V) is the image-to-video generation version of HunyuanVideo. We also provide advanced VRAM management for this model.
|VRAM required|Example script|Frames|Resolution|Note|
|-|-|-|-|-|
|80G|[hunyuanvideo_i2v_80G.py](hunyuanvideo_i2v_80G.py)|129|720p|No VRAM management.|
|24G|[hunyuanvideo_i2v_24G.py](hunyuanvideo_i2v_24G.py)|129|720p|The video is consistent with the original implementation, but it requires 5%~10% more time than [hunyuanvideo_i2v_80G.py](hunyuanvideo_i2v_80G.py).|
## Gallery ## Gallery
Video generated by [hunyuanvideo_80G.py](hunyuanvideo_80G.py) and [hunyuanvideo_24G.py](hunyuanvideo_24G.py): Video generated by [hunyuanvideo_80G.py](hunyuanvideo_80G.py) and [hunyuanvideo_24G.py](hunyuanvideo_24G.py):
@@ -21,3 +27,7 @@ https://github.com/user-attachments/assets/2997f107-d02d-4ecb-89bb-5ce1a7f93817
Video to video generated by [hunyuanvideo_v2v_6G.py](./hunyuanvideo_v2v_6G.py) using [this LoRA](https://civitai.com/models/1032126/walking-animation-hunyuan-video): Video to video generated by [hunyuanvideo_v2v_6G.py](./hunyuanvideo_v2v_6G.py) using [this LoRA](https://civitai.com/models/1032126/walking-animation-hunyuan-video):
https://github.com/user-attachments/assets/4b89e52e-ce42-434e-aa57-08f09dfa2b10 https://github.com/user-attachments/assets/4b89e52e-ce42-434e-aa57-08f09dfa2b10
Video generated by [hunyuanvideo_i2v_80G.py](hunyuanvideo_i2v_80G.py) and [hunyuanvideo_i2v_24G.py](hunyuanvideo_i2v_24G.py):
https://github.com/user-attachments/assets/494f252a-c9af-440d-84ba-a8ddcdcc538a

View File

@@ -1,88 +0,0 @@
import torch
from diffsynth import ModelManager, HunyuanVideoPipeline, download_models, save_video
from diffsynth.prompters.hunyuan_video_prompter import HunyuanVideoPrompter
from PIL import Image
import numpy as np
import torchvision.transforms as transforms
def generate_crop_size_list(base_size=256, patch_size=32, max_ratio=4.0):
    """Enumerate crop sizes whose patch count fits a square patch budget.

    The budget is ``round((base_size / patch_size) ** 2)`` patches. The walk
    starts at ``(wp, hp) = (budget, 1)`` and, at each step, either grows
    ``hp`` (while ``(hp + 1) * wp`` still fits the budget) or shrinks ``wp``.
    Every visited pair whose long/short side ratio is at most ``max_ratio``
    is recorded.

    Args:
        base_size: Side length of the reference square, in pixels.
        patch_size: Side length of one patch, in pixels.
        max_ratio: Largest allowed ``max(wp, hp) / min(wp, hp)``; must be
            at least 1.0.

    Returns:
        A list of ``(wp * patch_size, hp * patch_size)`` tuples in visiting
        order.

    Raises:
        ValueError: If ``max_ratio`` is below 1.0. (Previously an ``assert``,
            which is skipped under ``python -O`` and then silently produced
            an empty list.)
    """
    if not max_ratio >= 1.0:
        raise ValueError(f"max_ratio must be >= 1.0, got {max_ratio}")
    num_patches = round((base_size / patch_size) ** 2)
    crop_size_list = []
    wp, hp = num_patches, 1
    while wp > 0:
        if max(wp, hp) / min(wp, hp) <= max_ratio:
            crop_size_list.append((wp * patch_size, hp * patch_size))
        # Prefer growing hp while the patch budget allows it; otherwise
        # shrink wp, which is what eventually terminates the loop.
        if (hp + 1) * wp <= num_patches:
            hp += 1
        else:
            wp -= 1
    return crop_size_list
def get_closest_ratio(height: float, width: float, ratios: list, buckets: list):
    """Pick the bucket whose aspect ratio is nearest to ``height / width``.

    Args:
        height: Image height.
        width: Image width; must be non-zero.
        ratios: Candidate aspect ratios, index-aligned with ``buckets``.
            Any sequence or ndarray of numbers is accepted.
        buckets: Candidate sizes, one per entry of ``ratios``.

    Returns:
        A ``(bucket, ratio)`` tuple: the selected entry of ``buckets`` and
        its aspect ratio as a Python ``float``. On ties the first candidate
        wins.

    Raises:
        ZeroDivisionError: If ``width`` is zero.
        ValueError: If ``ratios`` is empty.
    """
    aspect_ratio = float(height) / float(width)
    # Coerce once so plain lists work too (the original subtraction required
    # an ndarray despite the ``list`` annotation).
    ratio_array = np.asarray(ratios, dtype=float)
    closest_ratio_id = int(np.abs(ratio_array - aspect_ratio).argmin())
    # Reuse the same index for the ratio instead of searching a second time.
    return buckets[closest_ratio_id], float(ratio_array[closest_ratio_id])
def prepare_vae_inputs(semantic_images, i2v_resolution="720p"):
    """Resize, crop and normalize reference images into one batched tensor.

    The resolution label selects a bucket base size (720p -> 960,
    540p -> 720, 360p -> 480). The crop size is the bucket from
    ``generate_crop_size_list(base, 32)`` whose aspect ratio is closest to
    that of the first image.

    Args:
        semantic_images: Non-empty sequence of images exposing ``.size``
            (presumably PIL images, given the torchvision transforms applied
            -- confirm against the caller).
        i2v_resolution: One of ``"360p"``, ``"540p"``, ``"720p"``.

    Returns:
        The transformed images concatenated along dim 0, with singleton
        dimensions inserted at positions 0 and 2.

    Raises:
        ValueError: If ``i2v_resolution`` is not a supported label.
    """
    for label, base in (("720p", 960), ("540p", 720), ("360p", 480)):
        if i2v_resolution == label:
            bucket_hw_base_size = base
            break
    else:
        raise ValueError(f"i2v_resolution: {i2v_resolution} must be in [360p, 540p, 720p]")

    # ``.size`` is read as (index 0, index 1); index 1 is passed as height.
    first_size = semantic_images[0].size
    candidates = generate_crop_size_list(bucket_hw_base_size, 32)
    candidate_ratios = np.array(
        [round(float(h) / float(w), 5) for h, w in candidates]
    )
    closest_size, _ = get_closest_ratio(
        first_size[1], first_size[0], candidate_ratios, candidates
    )

    pipeline = transforms.Compose(
        [
            transforms.Resize(closest_size),
            transforms.CenterCrop(closest_size),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5]),
        ]
    )
    pixel_values = []
    for image in semantic_images:
        pixel_values.append(pipeline(image))
    stacked = torch.cat(pixel_values)
    return stacked.unsqueeze(0).unsqueeze(2)
model_manager = ModelManager()
# The DiT checkpoint is loaded in bfloat16.
model_manager.load_models(
[
"models/HunyuanVideoI2V/transformers/mp_rank_00_model_states.pt"
],
torch_dtype=torch.bfloat16, # you can use torch_dtype=torch.float8_e4m3fn to enable quantization.
device="cuda"
)
# The remaining checkpoints (text encoders and VAE, judging by their paths) are loaded in float16.
model_manager.load_models(
[
"models/HunyuanVideo/text_encoder/model.safetensors",
"models/HunyuanVideoI2V/text_encoder_2",
'models/HunyuanVideoI2V/vae/pytorch_model.pt'
],
torch_dtype=torch.float16,
device="cuda"
)
# The computation device is "cuda".
pipe = HunyuanVideoPipeline.from_model_manager(
model_manager,
torch_dtype=torch.bfloat16,
device="cuda",
enable_vram_management=False
)
# Even with enough VRAM, enabling CPU offload is still recommended.
pipe.enable_cpu_offload()
# Only prints an empty line; this script never calls the pipeline.
print()

View File

@@ -0,0 +1,43 @@
import torch
from diffsynth import ModelManager, HunyuanVideoPipeline, download_models, save_video
from modelscope import dataset_snapshot_download
from PIL import Image
# Fetch the HunyuanVideo-I2V checkpoints referenced by the paths below.
download_models(["HunyuanVideoI2V"])
model_manager = ModelManager()

# The DiT model is loaded in bfloat16 on the CPU; the pipeline below is
# created with VRAM management enabled.
model_manager.load_models(
    [
        "models/HunyuanVideoI2V/transformers/mp_rank_00_model_states.pt"
    ],
    torch_dtype=torch.bfloat16,
    device="cpu"
)

# The other modules are loaded in float16.
model_manager.load_models(
    [
        "models/HunyuanVideoI2V/text_encoder/model.safetensors",
        "models/HunyuanVideoI2V/text_encoder_2",
        'models/HunyuanVideoI2V/vae/pytorch_model.pt'
    ],
    torch_dtype=torch.float16,
    device="cpu"
)

# The computation device is "cuda".
pipe = HunyuanVideoPipeline.from_model_manager(
    model_manager,
    torch_dtype=torch.bfloat16,
    device="cuda",
    enable_vram_management=True
)

# Download the example input image used below.
# (Plain string: the pattern has no placeholders, so no f-prefix is needed.)
dataset_snapshot_download(
    dataset_id="DiffSynth-Studio/examples_in_diffsynth",
    local_dir="./",
    allow_file_pattern="data/examples/hunyuanvideo/*"
)

i2v_resolution = "720p"
prompt = "An Asian man with short hair in black tactical uniform and white clothes waves a firework stick."
images = [Image.open("data/examples/hunyuanvideo/0.jpg").convert('RGB')]
video = pipe(prompt, input_images=images, num_inference_steps=50, seed=0, i2v_resolution=i2v_resolution)
save_video(video, f"video_{i2v_resolution}_low_vram.mp4", fps=30, quality=6)

View File

@@ -0,0 +1,45 @@
import torch
from diffsynth import ModelManager, HunyuanVideoPipeline, download_models, save_video
from modelscope import dataset_snapshot_download
from PIL import Image
# Fetch the HunyuanVideo-I2V checkpoints referenced by the paths below.
download_models(["HunyuanVideoI2V"])
model_manager = ModelManager()

# The DiT model is loaded in bfloat16.
model_manager.load_models(
    [
        "models/HunyuanVideoI2V/transformers/mp_rank_00_model_states.pt"
    ],
    torch_dtype=torch.bfloat16,
    device="cuda"
)

# The other modules are loaded in float16.
model_manager.load_models(
    [
        "models/HunyuanVideoI2V/text_encoder/model.safetensors",
        "models/HunyuanVideoI2V/text_encoder_2",
        'models/HunyuanVideoI2V/vae/pytorch_model.pt'
    ],
    torch_dtype=torch.float16,
    device="cuda"
)

# The computation device is "cuda".
pipe = HunyuanVideoPipeline.from_model_manager(
    model_manager,
    torch_dtype=torch.bfloat16,
    device="cuda",
    enable_vram_management=False
)

# Even with enough VRAM, enabling CPU offload is still recommended.
pipe.enable_cpu_offload()

# Download the example input image used below.
# (Plain string: the pattern has no placeholders, so no f-prefix is needed.)
dataset_snapshot_download(
    dataset_id="DiffSynth-Studio/examples_in_diffsynth",
    local_dir="./",
    allow_file_pattern="data/examples/hunyuanvideo/*"
)

i2v_resolution = "720p"
prompt = "An Asian man with short hair in black tactical uniform and white clothes waves a firework stick."
images = [Image.open("data/examples/hunyuanvideo/0.jpg").convert('RGB')]
video = pipe(prompt, input_images=images, num_inference_steps=50, seed=0, i2v_resolution=i2v_resolution)
save_video(video, f"video_{i2v_resolution}.mp4", fps=30, quality=6)