mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-18 22:08:13 +00:00
Wan video (#338)
This commit is contained in:
@@ -36,7 +36,9 @@ Until now, DiffSynth Studio has supported the following models:
|
||||
|
||||
## News
|
||||
|
||||
- **February 17, 2025** We support [StepVideo](https://modelscope.cn/models/stepfun-ai/stepvideo-t2v/summary)! State-of-the-art video synthesis model! See [./examples/stepvideo](./examples/stepvideo/).
|
||||
- **February 25, 2025** We support Wan-Video, a collection of video synthesis models open-sourced by Alibaba. See [./examples/wanvideo/](./examples/wanvideo/).
|
||||
|
||||
- **February 17, 2025** We support [StepVideo](https://modelscope.cn/models/stepfun-ai/stepvideo-t2v/summary)! State-of-the-art video synthesis model! See [./examples/stepvideo](./examples/stepvideo/).
|
||||
|
||||
- **December 31, 2024** We propose EliGen, a novel framework for precise entity-level controlled text-to-image generation, complemented by an inpainting fusion pipeline to extend its capabilities to image inpainting tasks. EliGen seamlessly integrates with existing community models, such as IP-Adapter and In-Context LoRA, enhancing its versatility. For more details, see [./examples/EntityControl](./examples/EntityControl/).
|
||||
- Paper: [EliGen: Entity-Level Controlled Image Generation with Regional Attention](https://arxiv.org/abs/2501.01097)
|
||||
@@ -118,7 +120,7 @@ cd DiffSynth-Studio
|
||||
pip install -e .
|
||||
```
|
||||
|
||||
Or install from pypi:
|
||||
Or install from pypi (There is a delay in the update. If you want to experience the latest features, please do not use this installation method.):
|
||||
|
||||
```
|
||||
pip install diffsynth
|
||||
|
||||
@@ -54,6 +54,11 @@ from ..models.hunyuan_video_dit import HunyuanVideoDiT
|
||||
from ..models.stepvideo_vae import StepVideoVAE
|
||||
from ..models.stepvideo_dit import StepVideoModel
|
||||
|
||||
from ..models.wan_video_dit import WanModel
|
||||
from ..models.wan_video_text_encoder import WanTextEncoder
|
||||
from ..models.wan_video_image_encoder import WanImageEncoder
|
||||
from ..models.wan_video_vae import WanVideoVAE
|
||||
|
||||
|
||||
model_loader_configs = [
|
||||
# These configs are provided for detecting model type automatically.
|
||||
@@ -108,6 +113,13 @@ model_loader_configs = [
|
||||
(None, "84ef4bd4757f60e906b54aa6a7815dc6", ["hunyuan_video_dit"], [HunyuanVideoDiT], "civitai"),
|
||||
(None, "68beaf8429b7c11aa8ca05b1bd0058bd", ["stepvideo_vae"], [StepVideoVAE], "civitai"),
|
||||
(None, "5c0216a2132b082c10cb7a0e0377e681", ["stepvideo_dit"], [StepVideoModel], "civitai"),
|
||||
(None, "9269f8db9040a9d860eaca435be61814", ["wan_video_dit"], [WanModel], "civitai"),
|
||||
(None, "aafcfd9672c3a2456dc46e1cb6e52c70", ["wan_video_dit"], [WanModel], "civitai"),
|
||||
(None, "6bfcfb3b342cb286ce886889d519a77e", ["wan_video_dit"], [WanModel], "civitai"),
|
||||
(None, "9c8818c2cbea55eca56c7b447df170da", ["wan_video_text_encoder"], [WanTextEncoder], "civitai"),
|
||||
(None, "5941c53e207d62f20f9025686193c40b", ["wan_video_image_encoder"], [WanImageEncoder], "civitai"),
|
||||
(None, "1378ea763357eea97acdef78e65d6d96", ["wan_video_vae"], [WanVideoVAE], "civitai"),
|
||||
(None, "ccc42284ea13e1ad04693284c7a09be6", ["wan_video_vae"], [WanVideoVAE], "civitai"),
|
||||
]
|
||||
huggingface_model_loader_configs = [
|
||||
# These configs are provided for detecting model type automatically.
|
||||
|
||||
@@ -73,7 +73,6 @@ try:
|
||||
)
|
||||
except Exception as exception:
|
||||
kernels = None
|
||||
logger.warning("Failed to load cpm_kernels:" + str(exception))
|
||||
|
||||
|
||||
class W8A16Linear(torch.autograd.Function):
|
||||
|
||||
@@ -8,6 +8,7 @@ from .flux_dit import FluxDiT
|
||||
from .hunyuan_dit import HunyuanDiT
|
||||
from .cog_dit import CogDiT
|
||||
from .hunyuan_video_dit import HunyuanVideoDiT
|
||||
from .wan_video_dit import WanModel
|
||||
|
||||
|
||||
|
||||
@@ -197,7 +198,7 @@ class FluxLoRAFromCivitai(LoRAFromCivitai):
|
||||
|
||||
class GeneralLoRAFromPeft:
|
||||
def __init__(self):
|
||||
self.supported_model_classes = [SDUNet, SDXLUNet, SD3DiT, HunyuanDiT, FluxDiT, CogDiT]
|
||||
self.supported_model_classes = [SDUNet, SDXLUNet, SD3DiT, HunyuanDiT, FluxDiT, CogDiT, WanModel]
|
||||
|
||||
|
||||
def fetch_device_dtype_from_state_dict(self, state_dict):
|
||||
|
||||
@@ -69,7 +69,9 @@ def load_model_from_single_file(state_dict, model_names, model_classes, model_re
|
||||
model_state_dict, extra_kwargs = state_dict_results, {}
|
||||
torch_dtype = torch.float32 if extra_kwargs.get("upcast_to_float32", False) else torch_dtype
|
||||
with init_weights_on_device():
|
||||
model= model_class(**extra_kwargs)
|
||||
model = model_class(**extra_kwargs)
|
||||
if hasattr(model, "eval"):
|
||||
model = model.eval()
|
||||
model.load_state_dict(model_state_dict, assign=True)
|
||||
model = model.to(dtype=torch_dtype, device=device)
|
||||
loaded_model_names.append(model_name)
|
||||
|
||||
789
diffsynth/models/wan_video_dit.py
Normal file
789
diffsynth/models/wan_video_dit.py
Normal file
@@ -0,0 +1,789 @@
|
||||
import math
|
||||
|
||||
import torch
|
||||
import torch.amp as amp
|
||||
import torch.nn as nn
|
||||
from tqdm import tqdm
|
||||
from .utils import hash_state_dict_keys
|
||||
|
||||
try:
|
||||
import flash_attn_interface
|
||||
FLASH_ATTN_3_AVAILABLE = True
|
||||
except ModuleNotFoundError:
|
||||
FLASH_ATTN_3_AVAILABLE = False
|
||||
|
||||
try:
|
||||
import flash_attn
|
||||
FLASH_ATTN_2_AVAILABLE = True
|
||||
except ModuleNotFoundError:
|
||||
FLASH_ATTN_2_AVAILABLE = False
|
||||
|
||||
import warnings
|
||||
|
||||
|
||||
__all__ = ['WanModel']
|
||||
|
||||
|
||||
def flash_attention(
    q,
    k,
    v,
    q_lens=None,
    k_lens=None,
    dropout_p=0.,
    softmax_scale=None,
    q_scale=None,
    causal=False,
    window_size=(-1, -1),
    deterministic=False,
    dtype=torch.bfloat16,
    version=None,
):
    """
    Variable-length attention backed by FlashAttention 3/2, with an SDPA fallback.

    q: [B, Lq, Nq, C1].
    k: [B, Lk, Nk, C1].
    v: [B, Lk, Nk, C2]. Nq must be divisible by Nk.
    q_lens: [B].
    k_lens: [B].
    dropout_p: float. Dropout probability.
    softmax_scale: float. The scaling of QK^T before applying softmax.
    causal: bool. Whether to apply causal attention mask.
    window_size: (left right). If not (-1, -1), apply sliding window local attention.
    deterministic: bool. If True, slightly slower and uses more memory.
    dtype: torch.dtype. Apply when dtype of q/k/v is not float16/bfloat16.
    """
    half_dtypes = (torch.float16, torch.bfloat16)
    assert dtype in half_dtypes
    assert q.device.type == 'cuda' and q.size(-1) <= 256

    # params
    b, lq, lk, out_dtype = q.size(0), q.size(1), k.size(1), q.dtype

    def half(x):
        return x if x.dtype in half_dtypes else x.to(dtype)

    # preprocess query: pack valid tokens into [sum(q_lens), Nq, C1]
    if q_lens is None:
        q = half(q.flatten(0, 1))
        q_lens = torch.tensor(
            [lq] * b, dtype=torch.int32).to(
                device=q.device, non_blocking=True)
    else:
        q = half(torch.cat([u[:v] for u, v in zip(q, q_lens)]))

    # preprocess key, value
    if k_lens is None:
        k = half(k.flatten(0, 1))
        v = half(v.flatten(0, 1))
        k_lens = torch.tensor(
            [lk] * b, dtype=torch.int32).to(
                device=k.device, non_blocking=True)
    else:
        k = half(torch.cat([u[:v] for u, v in zip(k, k_lens)]))
        v = half(torch.cat([u[:v] for u, v in zip(v, k_lens)]))

    q = q.to(v.dtype)
    k = k.to(v.dtype)

    if q_scale is not None:
        q = q * q_scale

    if version is not None and version == 3 and not FLASH_ATTN_3_AVAILABLE:
        warnings.warn(
            'Flash attention 3 is not available, use flash attention 2 instead.'
        )

    # apply attention
    if (version is None or version == 3) and FLASH_ATTN_3_AVAILABLE:
        # Note: dropout_p, window_size are not supported in FA3 now.
        x = flash_attn_interface.flash_attn_varlen_func(
            q=q,
            k=k,
            v=v,
            cu_seqlens_q=torch.cat([q_lens.new_zeros([1]), q_lens]).cumsum(
                0, dtype=torch.int32).to(q.device, non_blocking=True),
            cu_seqlens_k=torch.cat([k_lens.new_zeros([1]), k_lens]).cumsum(
                0, dtype=torch.int32).to(q.device, non_blocking=True),
            seqused_q=None,
            seqused_k=None,
            max_seqlen_q=lq,
            max_seqlen_k=lk,
            softmax_scale=softmax_scale,
            causal=causal,
            deterministic=deterministic)[0].unflatten(0, (b, lq))
    elif FLASH_ATTN_2_AVAILABLE:
        # Fix: removed leftover debug print() calls from this branch.
        x = flash_attn.flash_attn_varlen_func(
            q=q,
            k=k,
            v=v,
            cu_seqlens_q=torch.cat([q_lens.new_zeros([1]), q_lens]).cumsum(
                0, dtype=torch.int32).to(q.device, non_blocking=True),
            cu_seqlens_k=torch.cat([k_lens.new_zeros([1]), k_lens]).cumsum(
                0, dtype=torch.int32).to(q.device, non_blocking=True),
            max_seqlen_q=lq,
            max_seqlen_k=lk,
            dropout_p=dropout_p,
            softmax_scale=softmax_scale,
            causal=causal,
            window_size=window_size,
            deterministic=deterministic).unflatten(0, (b, lq))
    else:
        # SDPA fallback. NOTE(review): the packed [sum(lens), N, C] tensors are
        # treated here as a single batch of 1, so per-sample padding masks,
        # softmax_scale and window_size are not honored in this branch.
        q = q.unsqueeze(0).transpose(1, 2).to(dtype)
        k = k.unsqueeze(0).transpose(1, 2).to(dtype)
        v = v.unsqueeze(0).transpose(1, 2).to(dtype)
        x = torch.nn.functional.scaled_dot_product_attention(q, k, v)
        x = x.transpose(1, 2).contiguous()

    # output
    return x.type(out_dtype)
|
||||
|
||||
|
||||
def create_sdpa_mask(q, k, q_lens, k_lens, causal=False):
    """
    Build a boolean attention mask for scaled_dot_product_attention.

    q: [B, Lq, ...]; k: [B, Lk, ...].
    q_lens / k_lens: [B] valid lengths per sample (None means full length).
    Returns a [B, Lq, Lk] bool mask on q's device where True means "may attend".
    """
    b, lq, lk = q.size(0), q.size(1), k.size(1)
    if q_lens is None:
        q_lens = torch.tensor([lq] * b, dtype=torch.int32)
    if k_lens is None:
        k_lens = torch.tensor([lk] * b, dtype=torch.int32)

    # True marks positions that must be blocked (padding rows/columns).
    attn_mask = torch.zeros((b, lq, lk), dtype=torch.bool)
    for i in range(b):
        q_len, k_len = q_lens[i], k_lens[i]
        attn_mask[i, q_len:, :] = True
        attn_mask[i, :, k_len:] = True

    if causal:
        # Fix: apply the causal constraint to every batch element instead of
        # only the last one (the original relied on the leaked loop variable i).
        causal_mask = torch.triu(torch.ones((lq, lk), dtype=torch.bool), diagonal=1)
        attn_mask = torch.logical_or(attn_mask, causal_mask.unsqueeze(0))

    # Invert so True = attendable, and move to the query's device.
    attn_mask = attn_mask.logical_not().to(q.device, non_blocking=True)
    return attn_mask
|
||||
|
||||
|
||||
def attention(
    q,
    k,
    v,
    q_lens=None,
    k_lens=None,
    dropout_p=0.,
    softmax_scale=None,
    q_scale=None,
    causal=False,
    window_size=(-1, -1),
    deterministic=False,
    dtype=torch.bfloat16,
    fa_version=None,
):
    """Dispatch to flash_attention when available; otherwise fall back to
    torch's scaled_dot_product_attention (padding masks are dropped there)."""
    if FLASH_ATTN_2_AVAILABLE or FLASH_ATTN_3_AVAILABLE:
        return flash_attention(
            q=q,
            k=k,
            v=v,
            q_lens=q_lens,
            k_lens=k_lens,
            dropout_p=dropout_p,
            softmax_scale=softmax_scale,
            q_scale=q_scale,
            causal=causal,
            window_size=window_size,
            deterministic=deterministic,
            dtype=dtype,
            version=fa_version,
        )

    if q_lens is not None or k_lens is not None:
        warnings.warn('Padding mask is disabled when using scaled_dot_product_attention. It can have a significant impact on performance.')
    attn_mask = None

    # [B, L, N, C] -> [B, N, L, C] for SDPA, then back afterwards.
    q = q.transpose(1, 2).to(dtype)
    k = k.transpose(1, 2).to(dtype)
    v = v.transpose(1, 2).to(dtype)

    out = torch.nn.functional.scaled_dot_product_attention(
        q, k, v, attn_mask=attn_mask, is_causal=causal, dropout_p=dropout_p)

    return out.transpose(1, 2).contiguous()
|
||||
|
||||
|
||||
|
||||
def sinusoidal_embedding_1d(dim, position):
    """Classic sinusoidal positional embedding.

    Returns a float64 tensor of shape [len(position), dim] with the cosine
    terms in the first half of the channels and sine terms in the second.
    """
    assert dim % 2 == 0
    half = dim // 2
    pos = position.type(torch.float64)

    # frequencies 10000^(-i/half) for i in [0, half)
    inv_freq = torch.pow(10000, -torch.arange(half).to(pos).div(half))
    angles = torch.outer(pos, inv_freq)
    return torch.cat([torch.cos(angles), torch.sin(angles)], dim=1)
|
||||
|
||||
|
||||
@amp.autocast(enabled=False, device_type="cuda")
def rope_params(max_seq_len, dim, theta=10000):
    """Precompute complex RoPE rotation factors of shape [max_seq_len, dim // 2]."""
    assert dim % 2 == 0
    # inverse frequencies theta^(-2i/dim), computed in float64
    inv_freq = 1.0 / torch.pow(
        theta, torch.arange(0, dim, 2).to(torch.float64).div(dim))
    angles = torch.outer(torch.arange(max_seq_len), inv_freq)
    # unit-magnitude complex numbers e^{i * angle}
    return torch.polar(torch.ones_like(angles), angles)
|
||||
|
||||
|
||||
@amp.autocast(enabled=False, device_type="cuda")
|
||||
def rope_apply(x, grid_sizes, freqs):
|
||||
n, c = x.size(2), x.size(3) // 2
|
||||
|
||||
# split freqs
|
||||
freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
|
||||
|
||||
# loop over samples
|
||||
output = []
|
||||
for i, (f, h, w) in enumerate(grid_sizes.tolist()):
|
||||
seq_len = f * h * w
|
||||
|
||||
# precompute multipliers
|
||||
x_i = torch.view_as_complex(x[i, :seq_len].to(torch.float64).reshape(
|
||||
seq_len, n, -1, 2))
|
||||
freqs_i = torch.cat([
|
||||
freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
|
||||
freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
|
||||
freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
|
||||
],
|
||||
dim=-1).reshape(seq_len, 1, -1)
|
||||
|
||||
# apply rotary embedding
|
||||
x_i = torch.view_as_real(x_i * freqs_i).flatten(2)
|
||||
x_i = torch.cat([x_i, x[i, seq_len:]])
|
||||
|
||||
# append to collection
|
||||
output.append(x_i)
|
||||
return torch.stack(output).float()
|
||||
|
||||
|
||||
class WanRMSNorm(nn.Module):
    """RMS normalization with a learnable per-channel gain, computed in fp32."""

    def __init__(self, dim, eps=1e-5):
        super().__init__()
        self.dim = dim
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        # x / sqrt(mean(x^2) + eps) over the last dimension
        return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)

    def forward(self, x):
        # Normalize in float32 for stability, cast back, then apply the gain.
        return self._norm(x.float()).type_as(x) * self.weight
|
||||
|
||||
|
||||
class WanLayerNorm(nn.LayerNorm):
    """LayerNorm that always computes in float32 and casts back to the input dtype."""

    def __init__(self, dim, eps=1e-6, elementwise_affine=False):
        super().__init__(dim, elementwise_affine=elementwise_affine, eps=eps)

    def forward(self, x):
        out = super().forward(x.float())
        return out.type_as(x)
|
||||
|
||||
|
||||
class WanSelfAttention(nn.Module):
    """Multi-head self-attention with 3D RoPE and optional QK RMS-normalization."""

    def __init__(self,
                 dim,
                 num_heads,
                 window_size=(-1, -1),
                 qk_norm=True,
                 eps=1e-6):
        assert dim % num_heads == 0
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.window_size = window_size
        self.qk_norm = qk_norm
        self.eps = eps

        # projections
        self.q = nn.Linear(dim, dim)
        self.k = nn.Linear(dim, dim)
        self.v = nn.Linear(dim, dim)
        self.o = nn.Linear(dim, dim)
        self.norm_q = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
        self.norm_k = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()

    def forward(self, x, seq_lens, grid_sizes, freqs):
        batch, seq = x.shape[:2]
        heads, head_dim = self.num_heads, self.head_dim

        # project and reshape to [B, L, N, D]
        query = self.norm_q(self.q(x)).view(batch, seq, heads, head_dim)
        key = self.norm_k(self.k(x)).view(batch, seq, heads, head_dim)
        value = self.v(x).view(batch, seq, heads, head_dim)

        # rotary embedding on q/k, then variable-length attention
        out = flash_attention(
            q=rope_apply(query, grid_sizes, freqs),
            k=rope_apply(key, grid_sizes, freqs),
            v=value,
            k_lens=seq_lens,
            window_size=self.window_size)

        # merge heads and project out
        return self.o(out.flatten(2))
|
||||
|
||||
|
||||
class WanT2VCrossAttention(WanSelfAttention):
    """Cross-attention over the text context for the text-to-video model."""

    def forward(self, x, context, context_lens):
        """
        x: [B, L1, C].
        context: [B, L2, C].
        context_lens: [B].
        """
        batch = x.size(0)
        heads, head_dim = self.num_heads, self.head_dim

        # queries from the video tokens, keys/values from the text context
        query = self.norm_q(self.q(x)).view(batch, -1, heads, head_dim)
        key = self.norm_k(self.k(context)).view(batch, -1, heads, head_dim)
        value = self.v(context).view(batch, -1, heads, head_dim)

        out = flash_attention(query, key, value, k_lens=context_lens)
        return self.o(out.flatten(2))
|
||||
|
||||
|
||||
class WanI2VCrossAttention(WanSelfAttention):
    """Cross-attention mixing text context with 257 CLIP image tokens (i2v)."""

    def __init__(self,
                 dim,
                 num_heads,
                 window_size=(-1, -1),
                 qk_norm=True,
                 eps=1e-6):
        super().__init__(dim, num_heads, window_size, qk_norm, eps)

        # extra projections for the image branch
        self.k_img = nn.Linear(dim, dim)
        self.v_img = nn.Linear(dim, dim)
        # self.alpha = nn.Parameter(torch.zeros((1, )))
        self.norm_k_img = WanRMSNorm(
            dim, eps=eps) if qk_norm else nn.Identity()

    def forward(self, x, context, context_lens):
        """
        x: [B, L1, C].
        context: [B, L2, C].
        context_lens: [B].
        """
        # the first 257 context tokens are CLIP image features, the rest is text
        context_img = context[:, :257]
        context = context[:, 257:]
        batch = x.size(0)
        heads, head_dim = self.num_heads, self.head_dim

        # shared queries; separate key/value projections per branch
        query = self.norm_q(self.q(x)).view(batch, -1, heads, head_dim)
        key = self.norm_k(self.k(context)).view(batch, -1, heads, head_dim)
        value = self.v(context).view(batch, -1, heads, head_dim)
        key_img = self.norm_k_img(self.k_img(context_img)).view(batch, -1, heads, head_dim)
        value_img = self.v_img(context_img).view(batch, -1, heads, head_dim)

        img_out = flash_attention(query, key_img, value_img, k_lens=None)
        txt_out = flash_attention(query, key, value, k_lens=context_lens)

        # sum the two branches before the output projection
        out = txt_out.flatten(2) + img_out.flatten(2)
        return self.o(out)
|
||||
|
||||
|
||||
# Maps the cross-attention type tag used by WanAttentionBlock configs to its
# implementation class.
WANX_CROSSATTENTION_CLASSES = {
    't2v_cross_attn': WanT2VCrossAttention,
    'i2v_cross_attn': WanI2VCrossAttention,
}
|
||||
|
||||
|
||||
class WanAttentionBlock(nn.Module):
    """DiT block: modulated self-attention, cross-attention, and a modulated FFN."""

    def __init__(self,
                 cross_attn_type,
                 dim,
                 ffn_dim,
                 num_heads,
                 window_size=(-1, -1),
                 qk_norm=True,
                 cross_attn_norm=False,
                 eps=1e-6):
        super().__init__()
        self.dim = dim
        self.ffn_dim = ffn_dim
        self.num_heads = num_heads
        self.window_size = window_size
        self.qk_norm = qk_norm
        self.cross_attn_norm = cross_attn_norm
        self.eps = eps

        # layers
        self.norm1 = WanLayerNorm(dim, eps)
        self.self_attn = WanSelfAttention(dim, num_heads, window_size, qk_norm,
                                          eps)
        self.norm3 = WanLayerNorm(
            dim, eps,
            elementwise_affine=True) if cross_attn_norm else nn.Identity()
        self.cross_attn = WANX_CROSSATTENTION_CLASSES[cross_attn_type](
            dim, num_heads, (-1, -1), qk_norm, eps)
        self.norm2 = WanLayerNorm(dim, eps)
        self.ffn = nn.Sequential(
            nn.Linear(dim, ffn_dim), nn.GELU(approximate='tanh'),
            nn.Linear(ffn_dim, dim))

        # AdaLN-style per-block modulation: 6 shift/scale/gate vectors
        self.modulation = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5)

    def forward(
        self,
        x,
        e,
        seq_lens,
        grid_sizes,
        freqs,
        context,
        context_lens,
    ):
        assert e.dtype == torch.float32
        with amp.autocast(dtype=torch.float32, device_type="cuda"):
            # shift/scale/gate pairs for the attention and FFN sub-layers
            mods = (self.modulation.to(dtype=e.dtype, device=e.device) + e).chunk(6, dim=1)
        assert mods[0].dtype == torch.float32

        # self-attention: shift mods[0], scale mods[1], gate mods[2]
        attn_out = self.self_attn(
            self.norm1(x).float() * (1 + mods[1]) + mods[0], seq_lens, grid_sizes,
            freqs)
        with amp.autocast(dtype=torch.float32, device_type="cuda"):
            x = x + attn_out * mods[2]

        # cross-attention (no modulation), then the modulated FFN
        x = x + self.cross_attn(self.norm3(x), context, context_lens)
        ffn_out = self.ffn(self.norm2(x).float() * (1 + mods[4]) + mods[3])
        with amp.autocast(dtype=torch.float32, device_type="cuda"):
            x = x + ffn_out * mods[5]
        return x
|
||||
|
||||
|
||||
class Head(nn.Module):
    """Final projection: modulated LayerNorm plus a linear to patch pixels."""

    def __init__(self, dim, out_dim, patch_size, eps=1e-6):
        super().__init__()
        self.dim = dim
        self.out_dim = out_dim
        self.patch_size = patch_size
        self.eps = eps

        # the head emits one patch worth of output channels per token
        out_dim = math.prod(patch_size) * out_dim
        self.norm = WanLayerNorm(dim, eps)
        self.head = nn.Linear(dim, out_dim)

        # modulation parameters: (shift, scale)
        self.modulation = nn.Parameter(torch.randn(1, 2, dim) / dim**0.5)

    def forward(self, x, e):
        assert e.dtype == torch.float32
        with amp.autocast(dtype=torch.float32, device_type="cuda"):
            shift, scale = (self.modulation.to(dtype=e.dtype, device=e.device) + e.unsqueeze(1)).chunk(2, dim=1)
            x = self.head(self.norm(x) * (1 + scale) + shift)
        return x
|
||||
|
||||
|
||||
class MLPProj(torch.nn.Module):
    """Projects CLIP image embeddings into the DiT context dimension."""

    def __init__(self, in_dim, out_dim):
        super().__init__()

        # LN -> Linear -> GELU -> Linear -> LN
        self.proj = torch.nn.Sequential(
            torch.nn.LayerNorm(in_dim), torch.nn.Linear(in_dim, in_dim),
            torch.nn.GELU(), torch.nn.Linear(in_dim, out_dim),
            torch.nn.LayerNorm(out_dim))

    def forward(self, image_embeds):
        # extra context tokens prepended to the text context by the caller
        return self.proj(image_embeds)
|
||||
|
||||
|
||||
class WanModel(nn.Module):
    """Wan video diffusion transformer (DiT) backbone for t2v / i2v generation."""

    def __init__(self,
                 model_type='t2v',
                 patch_size=(1, 2, 2),
                 text_len=512,
                 in_dim=16,
                 dim=2048,
                 ffn_dim=8192,
                 freq_dim=256,
                 text_dim=4096,
                 out_dim=16,
                 num_heads=16,
                 num_layers=32,
                 window_size=(-1, -1),
                 qk_norm=True,
                 cross_attn_norm=False,
                 eps=1e-6):
        super().__init__()

        # 't2v' (text-to-video) or 'i2v' (image-to-video, adds CLIP conditioning)
        assert model_type in ['t2v', 'i2v']
        self.model_type = model_type

        self.patch_size = patch_size
        self.text_len = text_len
        self.in_dim = in_dim
        self.dim = dim
        self.ffn_dim = ffn_dim
        self.freq_dim = freq_dim
        self.text_dim = text_dim
        self.out_dim = out_dim
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.window_size = window_size
        self.qk_norm = qk_norm
        self.cross_attn_norm = cross_attn_norm
        self.eps = eps

        # embeddings
        self.patch_embedding = nn.Conv3d(
            in_dim, dim, kernel_size=patch_size, stride=patch_size)
        self.text_embedding = nn.Sequential(
            nn.Linear(text_dim, dim), nn.GELU(approximate='tanh'),
            nn.Linear(dim, dim))

        self.time_embedding = nn.Sequential(
            nn.Linear(freq_dim, dim), nn.SiLU(), nn.Linear(dim, dim))
        self.time_projection = nn.Sequential(nn.SiLU(), nn.Linear(dim, dim * 6))

        # blocks
        cross_attn_type = 't2v_cross_attn' if model_type == 't2v' else 'i2v_cross_attn'
        self.blocks = nn.ModuleList([
            WanAttentionBlock(cross_attn_type, dim, ffn_dim, num_heads,
                              window_size, qk_norm, cross_attn_norm, eps)
            for _ in range(num_layers)
        ])

        # head
        self.head = Head(dim, out_dim, patch_size, eps)

        # buffers (don't use register_buffer otherwise dtype will be changed in to())
        assert (dim % num_heads) == 0 and (dim // num_heads) % 2 == 0
        d = dim // num_heads
        # RoPE tables for the frame / height / width axes, concatenated channel-wise
        self.freqs = torch.cat([
            rope_params(1024, d - 4 * (d // 6)),
            rope_params(1024, 2 * (d // 6)),
            rope_params(1024, 2 * (d // 6))
        ],
                               dim=1)

        if model_type == 'i2v':
            # projects 1280-dim CLIP features into the context dimension
            self.img_emb = MLPProj(1280, dim)

        # initialize weights
        self.init_weights()

    def forward(
        self,
        x,
        timestep,
        context,
        seq_len,
        clip_fea=None,
        y=None,
        use_gradient_checkpointing=False,
        **kwargs,
    ):
        """
        x: A list of videos each with shape [C, T, H, W].
        t: [B].
        context: A list of text embeddings each with shape [L, C].
        """
        if self.model_type == 'i2v':
            assert clip_fea is not None and y is not None
        # params
        device = x[0].device
        if self.freqs.device != device:
            self.freqs = self.freqs.to(device)

        if y is not None:
            # i2v: concatenate conditioning latents along the channel axis
            x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)]

        # embeddings: patchify each video, record its (f, h, w) token grid
        x = [self.patch_embedding(u.unsqueeze(0)) for u in x]
        grid_sizes = torch.stack(
            [torch.tensor(u.shape[2:], dtype=torch.long) for u in x])
        x = [u.flatten(2).transpose(1, 2) for u in x]
        seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long)
        assert seq_lens.max() <= seq_len
        # zero-pad every sample to the shared seq_len and batch them
        x = torch.cat([
            torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))],
                      dim=1) for u in x
        ])

        # time embeddings (kept in float32 for the modulation path)
        with amp.autocast(dtype=torch.float32, device_type="cuda"):
            e = self.time_embedding(
                sinusoidal_embedding_1d(self.freq_dim, timestep).float())
            e0 = self.time_projection(e).unflatten(1, (6, self.dim))
            assert e.dtype == torch.float32 and e0.dtype == torch.float32

        # context: pad text embeddings to text_len, then project
        context_lens = None
        context = self.text_embedding(
            torch.stack([
                torch.cat(
                    [u, u.new_zeros(self.text_len - u.size(0), u.size(1))])
                for u in context
            ]))

        if clip_fea is not None:
            context_clip = self.img_emb(clip_fea)  # bs x 257 x dim
            context = torch.concat([context_clip, context], dim=1)

        # arguments
        # NOTE(review): this rebinds the **kwargs catch-all, so any extra
        # incoming keyword arguments are discarded here.
        kwargs = dict(
            e=e0,
            seq_lens=seq_lens,
            grid_sizes=grid_sizes,
            freqs=self.freqs,
            context=context,
            context_lens=context_lens)

        def create_custom_forward(module):
            # wrapper required by torch.utils.checkpoint to forward kwargs
            def custom_forward(*inputs, **kwargs):
                return module(*inputs, **kwargs)
            return custom_forward

        for block in self.blocks:
            if self.training and use_gradient_checkpointing:
                x = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(block),
                    x, **kwargs,
                    use_reentrant=False,
                )
            else:
                x = block(x, **kwargs)

        # head
        x = self.head(x, e)

        # unpatchify
        x = self.unpatchify(x, grid_sizes)
        x = torch.stack(x).float()
        return x

    def unpatchify(self, x, grid_sizes):
        # Reassemble per-token patch predictions into [C, F, H, W] videos,
        # one per sample, cropping the padded tokens.
        c = self.out_dim
        out = []
        for u, v in zip(x, grid_sizes.tolist()):
            u = u[:math.prod(v)].view(*v, *self.patch_size, c)
            u = torch.einsum('fhwpqrc->cfphqwr', u)
            u = u.reshape(c, *[i * j for i, j in zip(v, self.patch_size)])
            out.append(u)
        return out

    def init_weights(self):
        # basic init
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

        # init embeddings
        nn.init.xavier_uniform_(self.patch_embedding.weight.flatten(1))
        for m in self.text_embedding.modules():
            if isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, std=.02)
        for m in self.time_embedding.modules():
            if isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, std=.02)

        # init output layer to zero so the model starts as an identity residual
        nn.init.zeros_(self.head.head.weight)

    @staticmethod
    def state_dict_converter():
        return WanModelStateDictConverter()
|
||||
|
||||
|
||||
class WanModelStateDictConverter:
    """Maps raw checkpoints to (state_dict, config) pairs for WanModel."""

    # Settings shared by all known Wan checkpoints.
    _BASE_CONFIG = {
        "patch_size": (1, 2, 2),
        "text_len": 512,
        "freq_dim": 256,
        "text_dim": 4096,
        "out_dim": 16,
        "window_size": (-1, -1),
        "qk_norm": True,
        "cross_attn_norm": True,
        "eps": 1e-6,
    }

    # Known checkpoints keyed by state-dict hash; values are the per-model
    # overrides merged on top of _BASE_CONFIG.
    _CIVITAI_CONFIGS = {
        # Wan t2v 1.3B
        "9269f8db9040a9d860eaca435be61814": {
            "model_type": "t2v", "in_dim": 16, "dim": 1536,
            "ffn_dim": 8960, "num_heads": 12, "num_layers": 30,
        },
        # Wan t2v 14B
        "aafcfd9672c3a2456dc46e1cb6e52c70": {
            "model_type": "t2v", "in_dim": 16, "dim": 5120,
            "ffn_dim": 13824, "num_heads": 40, "num_layers": 40,
        },
        # Wan i2v 14B (in_dim doubled+4 for the conditioning latents)
        "6bfcfb3b342cb286ce886889d519a77e": {
            "model_type": "i2v", "in_dim": 36, "dim": 5120,
            "ffn_dim": 13824, "num_heads": 40, "num_layers": 40,
        },
    }

    def __init__(self):
        pass

    def from_diffusers(self, state_dict):
        """Diffusers-format checkpoints already use the native key names."""
        return state_dict

    def from_civitai(self, state_dict):
        """Return (state_dict, config); config is {} for unknown checkpoints.

        Fix: the hash is computed once instead of once per candidate hash,
        and the three near-identical config literals are table-driven.
        """
        key = hash_state_dict_keys(state_dict)
        override = self._CIVITAI_CONFIGS.get(key)
        config = {**self._BASE_CONFIG, **override} if override is not None else {}
        return state_dict, config
|
||||
904
diffsynth/models/wan_video_image_encoder.py
Normal file
904
diffsynth/models/wan_video_image_encoder.py
Normal file
@@ -0,0 +1,904 @@
|
||||
"""
|
||||
Concise re-implementation of
|
||||
``https://github.com/openai/CLIP'' and
|
||||
``https://github.com/mlfoundations/open_clip''.
|
||||
"""
|
||||
import math
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torchvision.transforms as T
|
||||
from .wan_video_dit import flash_attention
|
||||
|
||||
|
||||
class SelfAttention(nn.Module):
    """Standard multi-head self-attention with dropout, backed by SDPA."""

    def __init__(self, dim, num_heads, dropout=0.1, eps=1e-5):
        assert dim % num_heads == 0
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.eps = eps

        # projections
        self.q = nn.Linear(dim, dim)
        self.k = nn.Linear(dim, dim)
        self.v = nn.Linear(dim, dim)
        self.o = nn.Linear(dim, dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask):
        """
        x: [B, L, C].
        """
        b, s, c = x.size()
        n, d = self.num_heads, self.head_dim

        # [B, L, C] -> [B, N, L, D]
        q = self.q(x).reshape(b, s, n, d).permute(0, 2, 1, 3)
        k = self.k(x).reshape(b, s, n, d).permute(0, 2, 1, 3)
        v = self.v(x).reshape(b, s, n, d).permute(0, 2, 1, 3)

        # attention dropout only while training
        p = self.dropout.p if self.training else 0.0
        out = F.scaled_dot_product_attention(q, k, v, mask, p)
        out = out.permute(0, 2, 1, 3).reshape(b, s, c)

        # output projection + residual dropout
        return self.dropout(self.o(out))
|
||||
|
||||
|
||||
class AttentionBlock(nn.Module):
    """Transformer encoder block supporting both pre-norm and post-norm layouts."""

    def __init__(self, dim, num_heads, post_norm, dropout=0.1, eps=1e-5):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.post_norm = post_norm
        self.eps = eps

        # layers
        self.attn = SelfAttention(dim, num_heads, dropout, eps)
        self.norm1 = nn.LayerNorm(dim, eps=eps)
        self.ffn = nn.Sequential(
            nn.Linear(dim, dim * 4), nn.GELU(), nn.Linear(dim * 4, dim),
            nn.Dropout(dropout))
        self.norm2 = nn.LayerNorm(dim, eps=eps)

    def forward(self, x, mask):
        if self.post_norm:
            # BERT-style: residual add first, then normalize
            x = self.norm1(x + self.attn(x, mask))
            return self.norm2(x + self.ffn(x))
        # pre-norm: normalize inside the residual branch
        x = x + self.attn(self.norm1(x), mask)
        return x + self.ffn(self.norm2(x))
|
||||
|
||||
|
||||
class XLMRoberta(nn.Module):
    """
    XLMRobertaModel with no pooler and no LM head.

    Token + type + position embeddings followed by a stack of transformer
    encoder blocks; padding tokens (``pad_id``) are masked out of attention.
    """

    def __init__(self,
                 vocab_size=250002,
                 max_seq_len=514,
                 type_size=1,
                 pad_id=1,
                 dim=1024,
                 num_heads=16,
                 num_layers=24,
                 post_norm=True,
                 dropout=0.1,
                 eps=1e-5):
        super().__init__()
        self.vocab_size = vocab_size
        self.max_seq_len = max_seq_len
        self.type_size = type_size
        self.pad_id = pad_id
        self.dim = dim
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.post_norm = post_norm
        self.eps = eps

        # embeddings
        self.token_embedding = nn.Embedding(vocab_size, dim, padding_idx=pad_id)
        self.type_embedding = nn.Embedding(type_size, dim)
        self.pos_embedding = nn.Embedding(max_seq_len, dim, padding_idx=pad_id)
        self.dropout = nn.Dropout(dropout)

        # encoder blocks
        self.blocks = nn.ModuleList([
            AttentionBlock(dim, num_heads, post_norm, dropout, eps)
            for _ in range(num_layers)
        ])

        # final (or embedding, when post_norm) norm layer
        self.norm = nn.LayerNorm(dim, eps=eps)

    def forward(self, ids):
        """
        ids: [B, L] of torch.LongTensor.
        """
        b, s = ids.shape
        # 1 for real tokens, 0 for padding
        mask = ids.ne(self.pad_id).long()

        # embeddings; Roberta-style position ids start at pad_id + 1 for
        # real tokens, while padding positions map to pad_id (the position
        # embedding's padding_idx, i.e. a zero vector)
        x = self.token_embedding(ids) + \
            self.type_embedding(torch.zeros_like(ids)) + \
            self.pos_embedding(self.pad_id + torch.cumsum(mask, dim=1) * mask)
        if self.post_norm:
            x = self.norm(x)
        x = self.dropout(x)

        # additive attention bias: 0 for visible tokens, -inf for padding
        mask = torch.where(
            mask.view(b, 1, 1, s).gt(0), 0.0,
            torch.finfo(x.dtype).min)
        for block in self.blocks:
            x = block(x, mask)

        # output (pre-norm models normalize once at the end)
        if not self.post_norm:
            x = self.norm(x)
        return x
|
||||
|
||||
|
||||
def xlm_roberta_large(pretrained=False,
                      return_tokenizer=False,
                      device='cpu',
                      **kwargs):
    """
    XLMRobertaLarge adapted from Huggingface.

    Builds an ``XLMRoberta`` encoder (optionally loading cached pretrained
    weights) and, on request, the matching Huggingface tokenizer.

    Args:
        pretrained: load cached checkpoint weights instead of random init.
        return_tokenizer: also return a tokenizer configured for the model.
        device: target device for (non-meta) parameter initialization.
        **kwargs: overrides for the default model configuration.

    Returns:
        ``model`` or ``(model, tokenizer)`` when ``return_tokenizer``.
    """
    # default XLM-Roberta-Large hyper-parameters
    cfg = dict(
        vocab_size=250002,
        max_seq_len=514,
        type_size=1,
        pad_id=1,
        dim=1024,
        num_heads=16,
        num_layers=24,
        post_norm=True,
        dropout=0.1,
        eps=1e-5)
    cfg.update(**kwargs)

    # init model
    if pretrained:
        from sora import DOWNLOAD_TO_CACHE

        # init a meta model (no memory allocated until weights are assigned)
        with torch.device('meta'):
            model = XLMRoberta(**cfg)

        # load checkpoint
        model.load_state_dict(
            torch.load(
                DOWNLOAD_TO_CACHE('models/xlm_roberta/xlm_roberta_large.pth'),
                map_location=device),
            assign=True)
    else:
        # init a model on device
        with torch.device(device):
            model = XLMRoberta(**cfg)

    # init tokenizer
    if return_tokenizer:
        from sora.data import HuggingfaceTokenizer

        # BUGFIX: XLMRoberta has no ``text_len`` attribute (the original
        # ``seq_len=model.text_len`` raised AttributeError). Derive the
        # tokenizer length from ``max_seq_len`` minus the two special
        # tokens, matching the xlm branch of ``_clip``.
        tokenizer = HuggingfaceTokenizer(
            name='xlm-roberta-large',
            seq_len=model.max_seq_len - 2,
            clean='whitespace')
        return model, tokenizer
    else:
        return model
|
||||
|
||||
|
||||
|
||||
def pos_interpolate(pos, seq_len):
    """Resize a square-grid positional-embedding table to ``seq_len``.

    ``pos`` is [1, L, C]; any leading non-grid tokens (e.g. a CLS slot)
    are kept untouched and only the square grid part is resampled with
    bicubic interpolation.
    """
    if pos.size(1) == seq_len:
        return pos
    src_grid = int(math.sqrt(pos.size(1)))
    tar_grid = int(math.sqrt(seq_len))
    # tokens that are not part of the square grid (class/extra tokens)
    num_extra = pos.size(1) - src_grid * src_grid
    grid = pos[:, num_extra:].float()
    grid = grid.reshape(1, src_grid, src_grid, -1).permute(0, 3, 1, 2)
    grid = F.interpolate(
        grid, size=(tar_grid, tar_grid), mode='bicubic', align_corners=False)
    grid = grid.flatten(2).transpose(1, 2)
    return torch.cat([pos[:, :num_extra], grid], dim=1)
|
||||
|
||||
|
||||
class QuickGELU(nn.Module):
    """Sigmoid-based GELU approximation: ``x * sigmoid(1.702 * x)``."""

    def forward(self, x):
        gate = torch.sigmoid(1.702 * x)
        return x * gate
|
||||
|
||||
|
||||
class LayerNorm(nn.LayerNorm):
    """LayerNorm computed in fp32 for stability, cast back to the input
    dtype afterwards."""

    def forward(self, x):
        out = super().forward(x.float())
        return out.type_as(x)
|
||||
|
||||
|
||||
class SelfAttention(nn.Module):
    """Multi-head self-attention for the CLIP towers.

    Uses a fused qkv projection and ``flash_attention`` as the attention
    kernel; supports optional causal masking and attention/projection
    dropout.
    """

    def __init__(self,
                 dim,
                 num_heads,
                 causal=False,
                 attn_dropout=0.0,
                 proj_dropout=0.0):
        assert dim % num_heads == 0
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.causal = causal
        self.attn_dropout = attn_dropout
        self.proj_dropout = proj_dropout

        # fused qkv + output projection (checkpoint attribute names)
        self.to_qkv = nn.Linear(dim, dim * 3)
        self.proj = nn.Linear(dim, dim)

    def forward(self, x):
        """
        x: [B, L, C].
        """
        batch, seq_len, channels = x.size()

        # fused projection, then split into q/k/v of shape [B, L, H, D]
        qkv = self.to_qkv(x).view(
            batch, seq_len, 3, self.num_heads, self.head_dim)
        query, key, value = qkv.unbind(2)

        # attention (dropout only while training)
        drop_p = self.attn_dropout if self.training else 0.0
        out = flash_attention(
            query, key, value, dropout_p=drop_p, causal=self.causal,
            version=2)
        out = out.reshape(batch, seq_len, channels)

        # output projection + dropout
        out = self.proj(out)
        return F.dropout(out, self.proj_dropout, self.training)
|
||||
|
||||
|
||||
class SwiGLU(nn.Module):
    """SwiGLU feed-forward: ``fc3(silu(fc1(x)) * fc2(x))``."""

    def __init__(self, dim, mid_dim):
        super().__init__()
        self.dim = dim
        self.mid_dim = mid_dim

        # projection layers (attribute names fixed by the checkpoint layout)
        self.fc1 = nn.Linear(dim, mid_dim)
        self.fc2 = nn.Linear(dim, mid_dim)
        self.fc3 = nn.Linear(mid_dim, dim)

    def forward(self, x):
        gated = F.silu(self.fc1(x)) * self.fc2(x)
        return self.fc3(gated)
|
||||
|
||||
|
||||
class AttentionBlock(nn.Module):
    """CLIP transformer block: self-attention + MLP with pre- or post-norm
    residuals and a configurable activation (quick_gelu / gelu / swi_glu)."""

    def __init__(self,
                 dim,
                 mlp_ratio,
                 num_heads,
                 post_norm=False,
                 causal=False,
                 activation='quick_gelu',
                 attn_dropout=0.0,
                 proj_dropout=0.0,
                 norm_eps=1e-5):
        assert activation in ['quick_gelu', 'gelu', 'swi_glu']
        super().__init__()
        self.dim = dim
        self.mlp_ratio = mlp_ratio
        self.num_heads = num_heads
        self.post_norm = post_norm
        self.causal = causal
        self.norm_eps = norm_eps

        # sub-layers (attribute names are part of the checkpoint layout)
        mid_dim = int(dim * mlp_ratio)
        self.norm1 = LayerNorm(dim, eps=norm_eps)
        self.attn = SelfAttention(dim, num_heads, causal, attn_dropout,
                                  proj_dropout)
        self.norm2 = LayerNorm(dim, eps=norm_eps)
        if activation == 'swi_glu':
            self.mlp = SwiGLU(dim, mid_dim)
        else:
            act = QuickGELU() if activation == 'quick_gelu' else nn.GELU()
            self.mlp = nn.Sequential(
                nn.Linear(dim, mid_dim), act, nn.Linear(mid_dim, dim),
                nn.Dropout(proj_dropout))

    def forward(self, x):
        if self.post_norm:
            # normalize each sub-layer output before the residual add
            x = x + self.norm1(self.attn(x))
            return x + self.norm2(self.mlp(x))
        # pre-norm residuals
        x = x + self.attn(self.norm1(x))
        return x + self.mlp(self.norm2(x))
|
||||
|
||||
|
||||
class AttentionPool(nn.Module):
    """Single-query attention pooling: a learned CLS query attends over the
    whole sequence, followed by a residual MLP; returns one vector per
    sample."""

    def __init__(self,
                 dim,
                 mlp_ratio,
                 num_heads,
                 activation='gelu',
                 proj_dropout=0.0,
                 norm_eps=1e-5):
        assert dim % num_heads == 0
        super().__init__()
        self.dim = dim
        self.mlp_ratio = mlp_ratio
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.proj_dropout = proj_dropout
        self.norm_eps = norm_eps

        # layers
        gain = 1.0 / math.sqrt(dim)
        self.cls_embedding = nn.Parameter(gain * torch.randn(1, 1, dim))
        self.to_q = nn.Linear(dim, dim)
        self.to_kv = nn.Linear(dim, dim * 2)
        self.proj = nn.Linear(dim, dim)
        self.norm = LayerNorm(dim, eps=norm_eps)
        self.mlp = nn.Sequential(
            nn.Linear(dim, int(dim * mlp_ratio)),
            QuickGELU() if activation == 'quick_gelu' else nn.GELU(),
            nn.Linear(int(dim * mlp_ratio), dim), nn.Dropout(proj_dropout))

    def forward(self, x):
        """
        x: [B, L, C].
        """
        b, s, c, n, d = *x.size(), self.num_heads, self.head_dim

        # query comes from the learned CLS token; key/value from the input
        q = self.to_q(self.cls_embedding).view(1, 1, n, d).expand(b, -1, -1, -1)
        k, v = self.to_kv(x).view(b, s, 2, n, d).unbind(2)

        # single-query attention over the whole sequence
        x = flash_attention(q, k, v, version=2)
        x = x.reshape(b, 1, c)

        # output projection
        x = self.proj(x)
        x = F.dropout(x, self.proj_dropout, self.training)

        # residual MLP, then drop the length-1 sequence axis
        x = x + self.mlp(self.norm(x))
        return x[:, 0]
|
||||
|
||||
|
||||
class VisionTransformer(nn.Module):
    """CLIP-style ViT image encoder.

    Patchify -> optional CLS token -> position embeddings -> transformer
    stack. The pooling ``head`` and final ``post_norm`` are constructed
    (for checkpoint compatibility) but NOT applied by ``forward``.
    """

    def __init__(self,
                 image_size=224,
                 patch_size=16,
                 dim=768,
                 mlp_ratio=4,
                 out_dim=512,
                 num_heads=12,
                 num_layers=12,
                 pool_type='token',
                 pre_norm=True,
                 post_norm=False,
                 activation='quick_gelu',
                 attn_dropout=0.0,
                 proj_dropout=0.0,
                 embedding_dropout=0.0,
                 norm_eps=1e-5):
        if image_size % patch_size != 0:
            print(
                '[WARNING] image_size is not divisible by patch_size',
                flush=True)
        assert pool_type in ('token', 'token_fc', 'attn_pool')
        out_dim = out_dim or dim
        super().__init__()
        self.image_size = image_size
        self.patch_size = patch_size
        self.num_patches = (image_size // patch_size)**2
        self.dim = dim
        self.mlp_ratio = mlp_ratio
        self.out_dim = out_dim
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.pool_type = pool_type
        self.post_norm = post_norm
        self.norm_eps = norm_eps

        # embeddings (no conv bias when a pre-norm immediately follows)
        gain = 1.0 / math.sqrt(dim)
        self.patch_embedding = nn.Conv2d(
            3,
            dim,
            kernel_size=patch_size,
            stride=patch_size,
            bias=not pre_norm)
        if pool_type in ('token', 'token_fc'):
            self.cls_embedding = nn.Parameter(gain * torch.randn(1, 1, dim))
        self.pos_embedding = nn.Parameter(gain * torch.randn(
            1, self.num_patches +
            (1 if pool_type in ('token', 'token_fc') else 0), dim))
        self.dropout = nn.Dropout(embedding_dropout)

        # transformer
        self.pre_norm = LayerNorm(dim, eps=norm_eps) if pre_norm else None
        self.transformer = nn.Sequential(*[
            AttentionBlock(dim, mlp_ratio, num_heads, post_norm, False,
                           activation, attn_dropout, proj_dropout, norm_eps)
            for _ in range(num_layers)
        ])
        # NOTE(review): this shadows the boolean ``post_norm`` attribute set
        # above with a LayerNorm module, and ``forward`` never applies it —
        # kept as-is for compatibility with released checkpoints.
        self.post_norm = LayerNorm(dim, eps=norm_eps)

        # head (built for checkpoint compatibility; unused in ``forward``)
        if pool_type == 'token':
            self.head = nn.Parameter(gain * torch.randn(dim, out_dim))
        elif pool_type == 'token_fc':
            self.head = nn.Linear(dim, out_dim)
        elif pool_type == 'attn_pool':
            self.head = AttentionPool(dim, mlp_ratio, num_heads, activation,
                                      proj_dropout, norm_eps)

    def forward(self, x, interpolation=False, use_31_block=False):
        """
        x: [B, 3, H, W]; returns the token-level features (no pooling head).
        interpolation: resample the position table for non-default sizes.
        use_31_block: stop before the last transformer block (penultimate
        features, as consumed by the Wan image encoder).
        """
        b = x.size(0)

        # patchify to a token sequence [B, N, C], prepend CLS if configured
        x = self.patch_embedding(x).flatten(2).permute(0, 2, 1)
        if self.pool_type in ('token', 'token_fc'):
            x = torch.cat([self.cls_embedding.expand(b, -1, -1).to(dtype=x.dtype, device=x.device), x], dim=1)
        if interpolation:
            e = pos_interpolate(self.pos_embedding, x.size(1))
        else:
            e = self.pos_embedding
        e = e.to(dtype=x.dtype, device=x.device)
        x = self.dropout(x + e)
        if self.pre_norm is not None:
            x = self.pre_norm(x)

        # transformer
        if use_31_block:
            # all but the final block
            x = self.transformer[:-1](x)
            return x
        else:
            x = self.transformer(x)
            return x
|
||||
|
||||
|
||||
class CLIP(nn.Module):
    """Standard CLIP: ViT image tower + text transformer with a learned
    logit scale.

    NOTE(review): ``TextTransformer`` is not defined in this view — it is
    expected to be in scope elsewhere in the module; confirm before
    instantiating this class.
    """

    def __init__(self,
                 embed_dim=512,
                 image_size=224,
                 patch_size=16,
                 vision_dim=768,
                 vision_mlp_ratio=4,
                 vision_heads=12,
                 vision_layers=12,
                 vision_pool='token',
                 vision_pre_norm=True,
                 vision_post_norm=False,
                 vocab_size=49408,
                 text_len=77,
                 text_dim=512,
                 text_mlp_ratio=4,
                 text_heads=8,
                 text_layers=12,
                 text_causal=True,
                 text_pool='argmax',
                 text_head_bias=False,
                 logit_bias=None,
                 activation='quick_gelu',
                 attn_dropout=0.0,
                 proj_dropout=0.0,
                 embedding_dropout=0.0,
                 norm_eps=1e-5):
        super().__init__()
        self.embed_dim = embed_dim
        self.image_size = image_size
        self.patch_size = patch_size
        self.vision_dim = vision_dim
        self.vision_mlp_ratio = vision_mlp_ratio
        self.vision_heads = vision_heads
        self.vision_layers = vision_layers
        self.vision_pool = vision_pool
        self.vision_pre_norm = vision_pre_norm
        self.vision_post_norm = vision_post_norm
        self.vocab_size = vocab_size
        self.text_len = text_len
        self.text_dim = text_dim
        self.text_mlp_ratio = text_mlp_ratio
        self.text_heads = text_heads
        self.text_layers = text_layers
        self.text_causal = text_causal
        self.text_pool = text_pool
        self.text_head_bias = text_head_bias
        self.norm_eps = norm_eps

        # towers
        self.visual = VisionTransformer(
            image_size=image_size,
            patch_size=patch_size,
            dim=vision_dim,
            mlp_ratio=vision_mlp_ratio,
            out_dim=embed_dim,
            num_heads=vision_heads,
            num_layers=vision_layers,
            pool_type=vision_pool,
            pre_norm=vision_pre_norm,
            post_norm=vision_post_norm,
            activation=activation,
            attn_dropout=attn_dropout,
            proj_dropout=proj_dropout,
            embedding_dropout=embedding_dropout,
            norm_eps=norm_eps)
        self.textual = TextTransformer(
            vocab_size=vocab_size,
            text_len=text_len,
            dim=text_dim,
            mlp_ratio=text_mlp_ratio,
            out_dim=embed_dim,
            num_heads=text_heads,
            num_layers=text_layers,
            causal=text_causal,
            pool_type=text_pool,
            head_bias=text_head_bias,
            activation=activation,
            attn_dropout=attn_dropout,
            proj_dropout=proj_dropout,
            embedding_dropout=embedding_dropout,
            norm_eps=norm_eps)
        # contrastive temperature, initialized to log(1/0.07)
        self.log_scale = nn.Parameter(math.log(1 / 0.07) * torch.ones([]))
        if logit_bias is not None:
            self.logit_bias = nn.Parameter(logit_bias * torch.ones([]))

        # initialize weights
        self.init_weights()

    def forward(self, imgs, txt_ids):
        """
        imgs: [B, 3, H, W] of torch.float32.
        - mean: [0.48145466, 0.4578275, 0.40821073]
        - std: [0.26862954, 0.26130258, 0.27577711]
        txt_ids: [B, L] of torch.long. Encoded by data.CLIPTokenizer.
        """
        xi = self.visual(imgs)
        xt = self.textual(txt_ids)
        return xi, xt

    def init_weights(self):
        # embeddings
        nn.init.normal_(self.textual.token_embedding.weight, std=0.02)
        nn.init.normal_(self.visual.patch_embedding.weight, std=0.1)

        # attention/MLP weights: depth-scaled normal init
        for modality in ['visual', 'textual']:
            dim = self.vision_dim if modality == 'visual' else self.text_dim
            transformer = getattr(self, modality).transformer
            proj_gain = (1.0 / math.sqrt(dim)) * (
                1.0 / math.sqrt(2 * len(transformer)))
            attn_gain = 1.0 / math.sqrt(dim)
            mlp_gain = 1.0 / math.sqrt(2.0 * dim)
            for block in transformer:
                nn.init.normal_(block.attn.to_qkv.weight, std=attn_gain)
                nn.init.normal_(block.attn.proj.weight, std=proj_gain)
                nn.init.normal_(block.mlp[0].weight, std=mlp_gain)
                nn.init.normal_(block.mlp[2].weight, std=proj_gain)

    def param_groups(self):
        # optimizer groups: no weight decay for norm params and biases
        groups = [{
            'params': [
                p for n, p in self.named_parameters()
                if 'norm' in n or n.endswith('bias')
            ],
            'weight_decay': 0.0
        }, {
            'params': [
                p for n, p in self.named_parameters()
                if not ('norm' in n or n.endswith('bias'))
            ]
        }]
        return groups
|
||||
|
||||
|
||||
class XLMRobertaWithHead(XLMRoberta):
    """XLM-Roberta encoder with a masked-mean pooling head projecting the
    sentence representation to ``out_dim``."""

    def __init__(self, **kwargs):
        self.out_dim = kwargs.pop('out_dim')
        super().__init__(**kwargs)

        # projection head: dim -> mid -> out_dim (no biases)
        mid_dim = (self.dim + self.out_dim) // 2
        self.head = nn.Sequential(
            nn.Linear(self.dim, mid_dim, bias=False), nn.GELU(),
            nn.Linear(mid_dim, self.out_dim, bias=False))

    def forward(self, ids):
        # encode tokens with the base XLM-Roberta model
        features = super().forward(ids)

        # masked average pooling over non-pad positions
        valid = ids.ne(self.pad_id).unsqueeze(-1).to(features)
        pooled = (features * valid).sum(dim=1) / valid.sum(dim=1)

        # project to the output dimension
        return self.head(pooled)
|
||||
|
||||
|
||||
class XLMRobertaCLIP(nn.Module):
    """open-clip XLM-Roberta / ViT CLIP variant.

    Only the vision tower is constructed here; ``self.textual`` is left as
    ``None`` (the text tower is stripped for the Wan image encoder), so
    ``forward`` cannot run as-is — use ``self.visual`` directly.
    """

    def __init__(self,
                 embed_dim=1024,
                 image_size=224,
                 patch_size=14,
                 vision_dim=1280,
                 vision_mlp_ratio=4,
                 vision_heads=16,
                 vision_layers=32,
                 vision_pool='token',
                 vision_pre_norm=True,
                 vision_post_norm=False,
                 activation='gelu',
                 vocab_size=250002,
                 max_text_len=514,
                 type_size=1,
                 pad_id=1,
                 text_dim=1024,
                 text_heads=16,
                 text_layers=24,
                 text_post_norm=True,
                 text_dropout=0.1,
                 attn_dropout=0.0,
                 proj_dropout=0.0,
                 embedding_dropout=0.0,
                 norm_eps=1e-5):
        super().__init__()
        self.embed_dim = embed_dim
        self.image_size = image_size
        self.patch_size = patch_size
        self.vision_dim = vision_dim
        self.vision_mlp_ratio = vision_mlp_ratio
        self.vision_heads = vision_heads
        self.vision_layers = vision_layers
        self.vision_pre_norm = vision_pre_norm
        self.vision_post_norm = vision_post_norm
        self.activation = activation
        self.vocab_size = vocab_size
        self.max_text_len = max_text_len
        self.type_size = type_size
        self.pad_id = pad_id
        self.text_dim = text_dim
        self.text_heads = text_heads
        self.text_layers = text_layers
        self.text_post_norm = text_post_norm
        self.norm_eps = norm_eps

        # vision tower only (text tower intentionally omitted)
        self.visual = VisionTransformer(
            image_size=image_size,
            patch_size=patch_size,
            dim=vision_dim,
            mlp_ratio=vision_mlp_ratio,
            out_dim=embed_dim,
            num_heads=vision_heads,
            num_layers=vision_layers,
            pool_type=vision_pool,
            pre_norm=vision_pre_norm,
            post_norm=vision_post_norm,
            activation=activation,
            attn_dropout=attn_dropout,
            proj_dropout=proj_dropout,
            embedding_dropout=embedding_dropout,
            norm_eps=norm_eps)
        self.textual = None
        # contrastive temperature, initialized to log(1/0.07)
        self.log_scale = nn.Parameter(math.log(1 / 0.07) * torch.ones([]))

    def forward(self, imgs, txt_ids):
        """
        imgs: [B, 3, H, W] of torch.float32.
        - mean: [0.48145466, 0.4578275, 0.40821073]
        - std: [0.26862954, 0.26130258, 0.27577711]
        txt_ids: [B, L] of torch.long.
        Encoded by data.CLIPTokenizer.

        NOTE(review): ``self.textual`` is None here, so this raises if
        called; kept for interface parity with ``CLIP``.
        """
        xi = self.visual(imgs)
        xt = self.textual(txt_ids)
        return xi, xt

    def param_groups(self):
        # optimizer groups: no weight decay for norm params and biases
        groups = [{
            'params': [
                p for n, p in self.named_parameters()
                if 'norm' in n or n.endswith('bias')
            ],
            'weight_decay': 0.0
        }, {
            'params': [
                p for n, p in self.named_parameters()
                if not ('norm' in n or n.endswith('bias'))
            ]
        }]
        return groups
|
||||
|
||||
|
||||
def _clip(pretrained=False,
          pretrained_name=None,
          model_cls=CLIP,
          return_transforms=False,
          return_tokenizer=False,
          tokenizer_padding='eos',
          dtype=torch.float32,
          device='cpu',
          **kwargs):
    """Generic CLIP factory.

    Builds ``model_cls(**kwargs)`` (optionally loading cached pretrained
    weights) and optionally appends the matching image transforms and/or
    tokenizer to the returned tuple.

    NOTE(review): the transforms/tokenizer branches call
    ``pretrained_name.lower()`` — they assume ``pretrained_name`` is set.

    Returns:
        ``model`` alone, or a tuple ``(model[, transforms][, tokenizer])``.
    """
    # init model
    if pretrained and pretrained_name:
        from sora import BUCKET, DOWNLOAD_TO_CACHE

        # init a meta model (no allocation until weights are assigned)
        with torch.device('meta'):
            model = model_cls(**kwargs)

        # checkpoint path; prefer an fp16/bf16 variant when one exists
        checkpoint = f'models/clip/{pretrained_name}'
        if dtype in (torch.float16, torch.bfloat16):
            suffix = '-' + {
                torch.float16: 'fp16',
                torch.bfloat16: 'bf16'
            }[dtype]
            # ``object_exists`` is expected in scope elsewhere in the module
            if object_exists(BUCKET, f'{checkpoint}{suffix}.pth'):
                checkpoint = f'{checkpoint}{suffix}'
        checkpoint += '.pth'

        # load (strict=False: stripped checkpoints may omit some keys)
        model.load_state_dict(
            torch.load(DOWNLOAD_TO_CACHE(checkpoint), map_location=device),
            assign=True,
            strict=False)
    else:
        # init a model on device
        with torch.device(device):
            model = model_cls(**kwargs)

    # accumulate outputs
    output = (model,)

    # init transforms
    if return_transforms:
        # mean and std (SigLIP uses a symmetric 0.5 normalization)
        if 'siglip' in pretrained_name.lower():
            mean, std = [0.5, 0.5, 0.5], [0.5, 0.5, 0.5]
        else:
            mean = [0.48145466, 0.4578275, 0.40821073]
            std = [0.26862954, 0.26130258, 0.27577711]

        # transforms
        transforms = T.Compose([
            T.Resize((model.image_size, model.image_size),
                     interpolation=T.InterpolationMode.BICUBIC),
            T.ToTensor(),
            T.Normalize(mean=mean, std=std)
        ])
        output += (transforms,)

    # init tokenizer (family inferred from the pretrained name)
    if return_tokenizer:
        from sora import data
        if 'siglip' in pretrained_name.lower():
            tokenizer = data.HuggingfaceTokenizer(
                name=f'timm/{pretrained_name}',
                seq_len=model.text_len,
                clean='canonicalize')
        elif 'xlm' in pretrained_name.lower():
            # minus 2 for the <s>/</s> special tokens
            tokenizer = data.HuggingfaceTokenizer(
                name='xlm-roberta-large',
                seq_len=model.max_text_len - 2,
                clean='whitespace')
        elif 'mba' in pretrained_name.lower():
            tokenizer = data.HuggingfaceTokenizer(
                name='facebook/xlm-roberta-xl',
                seq_len=model.max_text_len - 2,
                clean='whitespace')
        else:
            tokenizer = data.CLIPTokenizer(
                seq_len=model.text_len, padding=tokenizer_padding)
        output += (tokenizer,)
    return output[0] if len(output) == 1 else output
|
||||
|
||||
|
||||
def clip_xlm_roberta_vit_h_14(
        pretrained=False,
        pretrained_name='open-clip-xlm-roberta-large-vit-huge-14',
        **kwargs):
    """Factory for the open-clip XLM-Roberta-Large + ViT-Huge/14 CLIP.

    Keyword arguments override the corresponding defaults below; the heavy
    lifting is delegated to ``_clip``.
    """
    defaults = {
        'embed_dim': 1024,
        'image_size': 224,
        'patch_size': 14,
        'vision_dim': 1280,
        'vision_mlp_ratio': 4,
        'vision_heads': 16,
        'vision_layers': 32,
        'vision_pool': 'token',
        'activation': 'gelu',
        'vocab_size': 250002,
        'max_text_len': 514,
        'type_size': 1,
        'pad_id': 1,
        'text_dim': 1024,
        'text_heads': 16,
        'text_layers': 24,
        'text_post_norm': True,
        'text_dropout': 0.1,
        'attn_dropout': 0.0,
        'proj_dropout': 0.0,
        'embedding_dropout': 0.0,
    }
    defaults.update(kwargs)
    return _clip(pretrained, pretrained_name, XLMRobertaCLIP, **defaults)
|
||||
|
||||
|
||||
class WanImageEncoder(torch.nn.Module):
    """CLIP (XLM-Roberta / ViT-Huge-14) image encoder wrapper for Wan-Video
    image-to-video; only the vision tower is exercised."""

    def __init__(self):
        # init model with random weights; real weights arrive later via
        # load_state_dict (see WanImageEncoderStateDictConverter)
        super().__init__()
        self.model, self.transforms = clip_xlm_roberta_vit_h_14(
            pretrained=False,
            return_transforms=True,
            return_tokenizer=False,
            dtype=torch.float32,
            device="cpu")

    def encode_image(self, videos):
        """Encode frames to penultimate-block CLIP features.

        videos: iterable of frame tensors, each resizable by
        ``F.interpolate`` to the CLIP input size; values assumed in
        [-1, 1] (rescaled to [0, 1] below) — confirm at call sites.
        """
        # preprocess: resize to the CLIP input resolution, map [-1, 1] ->
        # [0, 1], then apply the CLIP Normalize transform (last in Compose)
        size = (self.model.image_size,) * 2
        videos = torch.cat([
            F.interpolate(
                u,
                size=size,
                mode='bicubic',
                align_corners=False) for u in videos
        ])
        videos = self.transforms.transforms[-1](videos.mul_(0.5).add_(0.5))

        # forward through all but the last ViT block (penultimate features)
        out = self.model.visual(videos, use_31_block=True)
        return out

    @staticmethod
    def state_dict_converter():
        # converter used by the model loader to adapt checkpoint key names
        return WanImageEncoderStateDictConverter()
|
||||
|
||||
|
||||
class WanImageEncoderStateDictConverter:
    """Maps raw CLIP checkpoints onto ``WanImageEncoder``'s layout."""

    def __init__(self):
        pass

    def from_diffusers(self, state_dict):
        # diffusers checkpoints already match; pass through unchanged
        return state_dict

    def from_civitai(self, state_dict):
        # drop the (unused) text tower and prefix everything with "model."
        return {
            "model." + name: param
            for name, param in state_dict.items()
            if not name.startswith("textual.")
        }
|
||||
|
||||
269
diffsynth/models/wan_video_text_encoder.py
Normal file
269
diffsynth/models/wan_video_text_encoder.py
Normal file
@@ -0,0 +1,269 @@
|
||||
import math
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
def fp16_clamp(x):
|
||||
if x.dtype == torch.float16 and torch.isinf(x).any():
|
||||
clamp = torch.finfo(x.dtype).max - 1000
|
||||
x = torch.clamp(x, min=-clamp, max=clamp)
|
||||
return x
|
||||
|
||||
|
||||
class GELU(nn.Module):
    """Tanh approximation of GELU (as used by T5-style gated FFNs)."""

    def forward(self, x):
        inner = math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))
        return 0.5 * x * (1.0 + torch.tanh(inner))
|
||||
|
||||
|
||||
class T5LayerNorm(nn.Module):
    """T5-style RMSNorm: scale by the root-mean-square (no mean centering,
    no bias); statistics are computed in fp32."""

    def __init__(self, dim, eps=1e-6):
        super(T5LayerNorm, self).__init__()
        self.dim = dim
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        # RMS statistics in fp32 for numerical stability
        mean_sq = x.float().pow(2).mean(dim=-1, keepdim=True)
        x = x * torch.rsqrt(mean_sq + self.eps)
        # cast back when the weight is in reduced precision
        if self.weight.dtype in [torch.float16, torch.bfloat16]:
            x = x.type_as(self.weight)
        return self.weight * x
|
||||
|
||||
|
||||
class T5Attention(nn.Module):
    """T5 multi-head attention (self- or cross-) with an additive
    relative-position bias and, per the T5 design, no 1/sqrt(d) scaling."""

    def __init__(self, dim, dim_attn, num_heads, dropout=0.1):
        assert dim_attn % num_heads == 0
        super(T5Attention, self).__init__()
        self.dim = dim
        self.dim_attn = dim_attn
        self.num_heads = num_heads
        self.head_dim = dim_attn // num_heads

        # projection layers (T5 uses no biases)
        self.q = nn.Linear(dim, dim_attn, bias=False)
        self.k = nn.Linear(dim, dim_attn, bias=False)
        self.v = nn.Linear(dim, dim_attn, bias=False)
        self.o = nn.Linear(dim_attn, dim, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, context=None, mask=None, pos_bias=None):
        """
        x: [B, L1, C].
        context: [B, L2, C] or None (self-attention when None).
        mask: [B, L2] or [B, L1, L2] or None; zero entries are masked out.
        pos_bias: relative-position bias broadcastable to
        [B, N, L1, L2], or None.
        """
        # check inputs
        context = x if context is None else context
        b, n, c = x.size(0), self.num_heads, self.head_dim

        # compute query, key, value: [B, L, N, C_head]
        q = self.q(x).view(b, -1, n, c)
        k = self.k(context).view(b, -1, n, c)
        v = self.v(context).view(b, -1, n, c)

        # attention bias: position bias plus -inf on masked key positions
        attn_bias = x.new_zeros(b, n, q.size(1), k.size(1))
        if pos_bias is not None:
            attn_bias += pos_bias
        if mask is not None:
            assert mask.ndim in [2, 3]
            mask = mask.view(b, 1, 1,
                             -1) if mask.ndim == 2 else mask.unsqueeze(1)
            attn_bias.masked_fill_(mask == 0, torch.finfo(x.dtype).min)

        # compute attention (T5 does not use scaling); softmax in fp32
        attn = torch.einsum('binc,bjnc->bnij', q, k) + attn_bias
        attn = F.softmax(attn.float(), dim=-1).type_as(attn)
        x = torch.einsum('bnij,bjnc->binc', attn, v)

        # merge heads, project, dropout
        x = x.reshape(b, -1, n * c)
        x = self.o(x)
        x = self.dropout(x)
        return x
|
||||
|
||||
|
||||
class T5FeedForward(nn.Module):
    """Gated-GELU T5 feed-forward: ``fc2(dropout(fc1(x) * gelu_gate(x)))``
    followed by a trailing dropout."""

    def __init__(self, dim, dim_ffn, dropout=0.1):
        super(T5FeedForward, self).__init__()
        self.dim = dim
        self.dim_ffn = dim_ffn

        # layers (attribute names are fixed by the checkpoint layout)
        self.gate = nn.Sequential(nn.Linear(dim, dim_ffn, bias=False), GELU())
        self.fc1 = nn.Linear(dim, dim_ffn, bias=False)
        self.fc2 = nn.Linear(dim_ffn, dim, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        hidden = self.fc1(x) * self.gate(x)
        hidden = self.dropout(hidden)
        out = self.fc2(hidden)
        return self.dropout(out)
|
||||
|
||||
|
||||
class T5SelfAttention(nn.Module):
    """T5 encoder layer: pre-norm self-attention followed by a pre-norm
    gated feed-forward, each with an fp16-safe residual add. When
    ``shared_pos`` is False the layer owns its own relative-position
    embedding table."""

    def __init__(self,
                 dim,
                 dim_attn,
                 dim_ffn,
                 num_heads,
                 num_buckets,
                 shared_pos=True,
                 dropout=0.1):
        super(T5SelfAttention, self).__init__()
        self.dim = dim
        self.dim_attn = dim_attn
        self.dim_ffn = dim_ffn
        self.num_heads = num_heads
        self.num_buckets = num_buckets
        self.shared_pos = shared_pos

        # sub-layers (attribute names fixed by the checkpoint layout)
        self.norm1 = T5LayerNorm(dim)
        self.attn = T5Attention(dim, dim_attn, num_heads, dropout)
        self.norm2 = T5LayerNorm(dim)
        self.ffn = T5FeedForward(dim, dim_ffn, dropout)
        if shared_pos:
            self.pos_embedding = None
        else:
            self.pos_embedding = T5RelativeEmbedding(
                num_buckets, num_heads, bidirectional=True)

    def forward(self, x, mask=None, pos_bias=None):
        # use the shared bias when configured, otherwise compute our own
        if self.shared_pos:
            bias = pos_bias
        else:
            bias = self.pos_embedding(x.size(1), x.size(1))
        x = fp16_clamp(x + self.attn(self.norm1(x), mask=mask, pos_bias=bias))
        x = fp16_clamp(x + self.ffn(self.norm2(x)))
        return x
|
||||
|
||||
|
||||
class T5RelativeEmbedding(nn.Module):
    """T5 relative-position bias: maps bucketed query-key offsets to a
    per-head additive attention bias."""

    def __init__(self, num_buckets, num_heads, bidirectional, max_dist=128):
        super(T5RelativeEmbedding, self).__init__()
        self.num_buckets = num_buckets
        self.num_heads = num_heads
        self.bidirectional = bidirectional
        self.max_dist = max_dist

        # one learned scalar per (bucket, head)
        self.embedding = nn.Embedding(num_buckets, num_heads)

    def forward(self, lq, lk):
        # returns the bias for lq queries over lk keys: [1, N, Lq, Lk]
        device = self.embedding.weight.device
        # signed offset of every key relative to every query
        rel_pos = torch.arange(lk, device=device).unsqueeze(0) - \
            torch.arange(lq, device=device).unsqueeze(1)
        rel_pos = self._relative_position_bucket(rel_pos)
        rel_pos_embeds = self.embedding(rel_pos)
        rel_pos_embeds = rel_pos_embeds.permute(2, 0, 1).unsqueeze(
            0)  # [1, N, Lq, Lk]
        return rel_pos_embeds.contiguous()

    def _relative_position_bucket(self, rel_pos):
        # preprocess: bidirectional splits the buckets between the two
        # directions; unidirectional keeps only non-positive (past) offsets
        if self.bidirectional:
            num_buckets = self.num_buckets // 2
            rel_buckets = (rel_pos > 0).long() * num_buckets
            rel_pos = torch.abs(rel_pos)
        else:
            num_buckets = self.num_buckets
            rel_buckets = 0
            rel_pos = -torch.min(rel_pos, torch.zeros_like(rel_pos))

        # small offsets get one exact bucket each; larger offsets are
        # binned logarithmically up to max_dist (then clamped to the last
        # bucket)
        max_exact = num_buckets // 2
        rel_pos_large = max_exact + (torch.log(rel_pos.float() / max_exact) /
                                     math.log(self.max_dist / max_exact) *
                                     (num_buckets - max_exact)).long()
        rel_pos_large = torch.min(
            rel_pos_large, torch.full_like(rel_pos_large, num_buckets - 1))
        rel_buckets += torch.where(rel_pos < max_exact, rel_pos, rel_pos_large)
        return rel_buckets
|
||||
|
||||
def init_weights(m):
    """T5 weight init, applied recursively via ``nn.Module.apply``.

    Standard deviations follow fan-in style scaling per layer type,
    matching the original T5 initialization scheme.
    """
    if isinstance(m, T5LayerNorm):
        nn.init.ones_(m.weight)
    elif isinstance(m, T5FeedForward):
        nn.init.normal_(m.gate[0].weight, std=m.dim**-0.5)
        nn.init.normal_(m.fc1.weight, std=m.dim**-0.5)
        nn.init.normal_(m.fc2.weight, std=m.dim_ffn**-0.5)
    elif isinstance(m, T5Attention):
        nn.init.normal_(m.q.weight, std=(m.dim * m.dim_attn)**-0.5)
        nn.init.normal_(m.k.weight, std=m.dim**-0.5)
        nn.init.normal_(m.v.weight, std=m.dim**-0.5)
        nn.init.normal_(m.o.weight, std=(m.num_heads * m.dim_attn)**-0.5)
    elif isinstance(m, T5RelativeEmbedding):
        nn.init.normal_(
            m.embedding.weight, std=(2 * m.num_buckets * m.num_heads)**-0.5)
|
||||
|
||||
|
||||
class WanTextEncoder(torch.nn.Module):
    """T5-style text encoder used by Wan-Video (encoder-only stack with
    per-layer relative position embeddings by default)."""

    def __init__(self,
                 vocab=256384,
                 dim=4096,
                 dim_attn=4096,
                 dim_ffn=10240,
                 num_heads=64,
                 num_layers=24,
                 num_buckets=32,
                 shared_pos=False,
                 dropout=0.1):
        super(WanTextEncoder, self).__init__()
        self.dim = dim
        self.dim_attn = dim_attn
        self.dim_ffn = dim_ffn
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.num_buckets = num_buckets
        self.shared_pos = shared_pos

        # ``vocab`` may be an int (build a fresh embedding) or an existing
        # nn.Embedding to share weights with another module
        self.token_embedding = vocab if isinstance(vocab, nn.Embedding) \
            else nn.Embedding(vocab, dim)
        # shared_pos=True: one relative-position table reused by all layers;
        # otherwise each T5SelfAttention block owns its own table
        self.pos_embedding = T5RelativeEmbedding(
            num_buckets, num_heads, bidirectional=True) if shared_pos else None
        self.dropout = nn.Dropout(dropout)
        self.blocks = nn.ModuleList([
            T5SelfAttention(dim, dim_attn, dim_ffn, num_heads, num_buckets,
                            shared_pos, dropout) for _ in range(num_layers)
        ])
        self.norm = T5LayerNorm(dim)

        # initialize weights
        self.apply(init_weights)

    def forward(self, ids, mask=None):
        # ids: [B, L] token ids; mask: [B, L], 0 marks padding (optional)
        x = self.token_embedding(ids)
        x = self.dropout(x)
        # shared position bias is computed once and passed to every block
        e = self.pos_embedding(x.size(1),
                               x.size(1)) if self.shared_pos else None
        for block in self.blocks:
            x = block(x, mask, pos_bias=e)
        x = self.norm(x)
        x = self.dropout(x)
        return x

    @staticmethod
    def state_dict_converter():
        # converter used by the model loader to adapt checkpoint key names
        return WanTextEncoderStateDictConverter()
|
||||
|
||||
|
||||
class WanTextEncoderStateDictConverter:
    """Adapts external checkpoint key layouts for ``WanTextEncoder``.

    Both supported formats already match the module's own naming, so
    each converter is an identity pass-through.
    """

    def __init__(self):
        pass

    def from_diffusers(self, state_dict):
        # Keys already match; nothing to rename.
        return state_dict

    def from_civitai(self, state_dict):
        # Keys already match; nothing to rename.
        return state_dict
|
||||
807
diffsynth/models/wan_video_vae.py
Normal file
807
diffsynth/models/wan_video_vae.py
Normal file
@@ -0,0 +1,807 @@
|
||||
from einops import rearrange, repeat
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from tqdm import tqdm
|
||||
|
||||
CACHE_T = 2
|
||||
|
||||
|
||||
def check_is_instance(model, module_class):
    """isinstance() that also looks through a ``.module`` wrapper.

    Returns True when *model* itself, or the object stored in its
    ``module`` attribute (e.g. an offload/parallel wrapper — TODO
    confirm which wrappers produce this), is an instance of
    *module_class*.
    """
    direct_hit = isinstance(model, module_class)
    wrapped_hit = hasattr(model, "module") and isinstance(
        model.module, module_class)
    return direct_hit or wrapped_hit
|
||||
|
||||
|
||||
def block_causal_mask(x, block_size):
    """Boolean block-causal attention mask for a [B, N, S, S'] query tensor.

    Queries in block ``i`` may attend to every key in blocks ``0..i``
    (inclusive). ``S`` must be divisible by *block_size*.

    Returns:
        Bool tensor of shape [B, N, S, S] on ``x``'s device.
    """
    batch, heads, seq_len, _ = x.size()
    assert seq_len % block_size == 0

    # Row i attends to column j iff j's block index <= i's block index.
    block_ids = torch.arange(seq_len, device=x.device) // block_size
    pairwise = block_ids.unsqueeze(1) >= block_ids.unsqueeze(0)
    return pairwise.view(1, 1, seq_len, seq_len).expand(
        batch, heads, seq_len, seq_len).contiguous()
|
||||
|
||||
|
||||
class CausalConv3d(nn.Conv3d):
    """Causal 3d convolution.

    The temporal padding is applied entirely on the *front* of the clip,
    so an output frame never sees future input frames. Spatial padding
    stays symmetric. An optional ``cache_x`` of trailing frames from the
    previous chunk can be prepended in place of the zero padding, which
    makes chunked streaming equivalent to one full-clip pass.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # F.pad ordering: (w_left, w_right, h_top, h_bottom, t_front, t_back).
        # All temporal padding moves to the front; none at the back.
        self._padding = (self.padding[2], self.padding[2], self.padding[1],
                         self.padding[1], 2 * self.padding[0], 0)
        # Disable Conv3d's own padding; we pad manually in forward().
        self.padding = (0, 0, 0)

    def forward(self, x, cache_x=None):
        pad_amounts = list(self._padding)
        if cache_x is not None and self._padding[4] > 0:
            # Replace (part of) the leading zero padding with real frames
            # cached from the previous chunk.
            cache_x = cache_x.to(x.device)
            x = torch.cat([cache_x, x], dim=2)
            pad_amounts[4] -= cache_x.shape[2]
        return super().forward(F.pad(x, pad_amounts))
|
||||
|
||||
|
||||
class RMS_norm(nn.Module):
    """RMS normalization with a learnable per-channel gain (and optional bias).

    Normalizes along the channel axis (dim 1 when ``channel_first``,
    last dim otherwise), then rescales by ``sqrt(dim)`` times ``gamma``.

    Args:
        dim: number of channels being normalized.
        channel_first: whether channels sit at dim 1 (NCHW/NCTHW layouts).
        images: True for 4-D (image) inputs, False for 5-D (video) inputs;
            only affects the broadcast shape of the parameters.
        bias: whether to learn an additive bias (otherwise 0).
    """

    def __init__(self, dim, channel_first=True, images=True, bias=False):
        super().__init__()
        # Trailing singleton dims so parameters broadcast over (T,)H,W.
        trailing = (1, 1) if images else (1, 1, 1)
        param_shape = (dim, *trailing) if channel_first else (dim,)

        self.channel_first = channel_first
        self.scale = dim ** 0.5
        self.gamma = nn.Parameter(torch.ones(param_shape))
        self.bias = nn.Parameter(torch.zeros(param_shape)) if bias else 0.

    def forward(self, x):
        norm_axis = 1 if self.channel_first else -1
        unit = F.normalize(x, dim=norm_axis)
        return unit * self.scale * self.gamma + self.bias
|
||||
|
||||
|
||||
class Upsample(nn.Upsample):
    """``nn.Upsample`` that runs the interpolation in float32.

    Fix bfloat16 support for nearest neighbor interpolation: the op is
    computed in fp32 and the result cast back to the input dtype.
    """

    def forward(self, x):
        upsampled = super().forward(x.to(torch.float32))
        return upsampled.type_as(x)
|
||||
|
||||
|
||||
class Resample(nn.Module):
    """Spatial (and optionally temporal) 2x up/down-sampling for video features.

    Modes:
        'none'          identity.
        'upsample2d'    2x nearest-exact spatial upsample + channel-halving conv.
        'upsample3d'    as above, plus a causal temporal conv that doubles the
                        frame count (channels are split into two frames).
        'downsample2d'  strided conv halving H and W.
        'downsample3d'  as above, plus a stride-2 causal temporal conv.

    ``forward`` participates in the chunked-streaming protocol: ``feat_cache``
    is a list of per-conv caches indexed by the shared counter ``feat_idx[0]``,
    both mutated in place.
    """

    def __init__(self, dim, mode):
        assert mode in ('none', 'upsample2d', 'upsample3d', 'downsample2d',
                        'downsample3d')
        super().__init__()
        self.dim = dim
        self.mode = mode

        # layers
        if mode == 'upsample2d':
            self.resample = nn.Sequential(
                Upsample(scale_factor=(2., 2.), mode='nearest-exact'),
                nn.Conv2d(dim, dim // 2, 3, padding=1))
        elif mode == 'upsample3d':
            self.resample = nn.Sequential(
                Upsample(scale_factor=(2., 2.), mode='nearest-exact'),
                nn.Conv2d(dim, dim // 2, 3, padding=1))
            # Doubles channels; forward() reinterprets them as 2 frames.
            self.time_conv = CausalConv3d(dim,
                                          dim * 2, (3, 1, 1),
                                          padding=(1, 0, 0))

        elif mode == 'downsample2d':
            self.resample = nn.Sequential(
                nn.ZeroPad2d((0, 1, 0, 1)),
                nn.Conv2d(dim, dim, 3, stride=(2, 2)))
        elif mode == 'downsample3d':
            self.resample = nn.Sequential(
                nn.ZeroPad2d((0, 1, 0, 1)),
                nn.Conv2d(dim, dim, 3, stride=(2, 2)))
            self.time_conv = CausalConv3d(dim,
                                          dim, (3, 1, 1),
                                          stride=(2, 1, 1),
                                          padding=(0, 0, 0))

        else:
            self.resample = nn.Identity()

    def forward(self, x, feat_cache=None, feat_idx=[0]):
        # NOTE: the mutable default feat_idx=[0] is an intentional in-place
        # counter; callers that use caching always pass their own list.
        b, c, t, h, w = x.size()
        if self.mode == 'upsample3d':
            if feat_cache is not None:
                idx = feat_idx[0]
                if feat_cache[idx] is None:
                    # First chunk: mark with the 'Rep' sentinel so the next
                    # chunk knows to pad with zeros instead of real frames.
                    feat_cache[idx] = 'Rep'
                    feat_idx[0] += 1
                else:

                    cache_x = x[:, :, -CACHE_T:, :, :].clone()
                    if cache_x.shape[2] < 2 and feat_cache[
                            idx] is not None and feat_cache[idx] != 'Rep':
                        # cache last frame of last two chunk
                        cache_x = torch.cat([
                            feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
                                cache_x.device), cache_x
                        ],
                                            dim=2)
                    if cache_x.shape[2] < 2 and feat_cache[
                            idx] is not None and feat_cache[idx] == 'Rep':
                        # No real history yet: zero-pad the cache instead.
                        cache_x = torch.cat([
                            torch.zeros_like(cache_x).to(cache_x.device),
                            cache_x
                        ],
                                            dim=2)
                    if feat_cache[idx] == 'Rep':
                        x = self.time_conv(x)
                    else:
                        x = self.time_conv(x, feat_cache[idx])
                    feat_cache[idx] = cache_x
                    feat_idx[0] += 1

                    # The doubled channels are interleaved as 2 new frames:
                    # [B, 2C, T, H, W] -> [B, C, 2T, H, W].
                    x = x.reshape(b, 2, c, t, h, w)
                    x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]),
                                    3)
                    x = x.reshape(b, c, t * 2, h, w)
        t = x.shape[2]
        # Spatial resampling runs per-frame via a (b t) batch fold.
        x = rearrange(x, 'b c t h w -> (b t) c h w')
        x = self.resample(x)
        x = rearrange(x, '(b t) c h w -> b c t h w', t=t)

        if self.mode == 'downsample3d':
            if feat_cache is not None:
                idx = feat_idx[0]
                if feat_cache[idx] is None:
                    feat_cache[idx] = x.clone()
                    feat_idx[0] += 1
                else:
                    cache_x = x[:, :, -1:, :, :].clone()
                    # Prepend the last cached frame so the stride-2 temporal
                    # conv stays aligned across chunk boundaries.
                    x = self.time_conv(
                        torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2))
                    feat_cache[idx] = cache_x
                    feat_idx[0] += 1
        return x

    def init_weight(self, conv):
        # Identity-style init: the middle temporal tap copies channels 1:1.
        conv_weight = conv.weight
        nn.init.zeros_(conv_weight)
        c1, c2, t, h, w = conv_weight.size()
        one_matrix = torch.eye(c1, c2)
        init_matrix = one_matrix
        nn.init.zeros_(conv_weight)
        conv_weight.data[:, :, 1, 0, 0] = init_matrix
        conv.weight.data.copy_(conv_weight)
        nn.init.zeros_(conv.bias.data)

    def init_weight2(self, conv):
        # Identity init for the frame-doubling conv: both output halves
        # copy the input, so the two generated frames start identical.
        conv_weight = conv.weight.data
        nn.init.zeros_(conv_weight)
        c1, c2, t, h, w = conv_weight.size()
        init_matrix = torch.eye(c1 // 2, c2)
        conv_weight[:c1 // 2, :, -1, 0, 0] = init_matrix
        conv_weight[c1 // 2:, :, -1, 0, 0] = init_matrix
        conv.weight.data.copy_(conv_weight)
        nn.init.zeros_(conv.bias.data)
|
||||
|
||||
|
||||
class ResidualBlock(nn.Module):
    """Pre-norm residual block of two causal 3-D convs with SiLU + dropout.

    A 1x1x1 causal conv projects the shortcut when ``in_dim != out_dim``.
    Supports the chunked-streaming cache protocol: ``feat_cache`` holds one
    entry per CausalConv3d, addressed by the shared counter ``feat_idx[0]``;
    both are mutated in place.
    """

    def __init__(self, in_dim, out_dim, dropout=0.0):
        super().__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim

        # layers
        self.residual = nn.Sequential(
            RMS_norm(in_dim, images=False), nn.SiLU(),
            CausalConv3d(in_dim, out_dim, 3, padding=1),
            RMS_norm(out_dim, images=False), nn.SiLU(), nn.Dropout(dropout),
            CausalConv3d(out_dim, out_dim, 3, padding=1))
        self.shortcut = CausalConv3d(in_dim, out_dim, 1) \
            if in_dim != out_dim else nn.Identity()

    def forward(self, x, feat_cache=None, feat_idx=[0]):
        # NOTE: the mutable default feat_idx=[0] is an intentional in-place
        # counter; caching callers always pass their own list.
        h = self.shortcut(x)
        for layer in self.residual:
            if check_is_instance(layer, CausalConv3d) and feat_cache is not None:
                idx = feat_idx[0]
                # Snapshot the trailing CACHE_T frames for the next chunk.
                cache_x = x[:, :, -CACHE_T:, :, :].clone()
                if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
                    # cache last frame of last two chunk
                    cache_x = torch.cat([
                        feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
                            cache_x.device), cache_x
                    ],
                                        dim=2)
                x = layer(x, feat_cache[idx])
                feat_cache[idx] = cache_x
                feat_idx[0] += 1
            else:
                x = layer(x)
        return x + h
|
||||
|
||||
|
||||
class AttentionBlock(nn.Module):
    """
    Causal self-attention with a single head.

    NOTE(review): as written the attention is *per-frame spatial* attention —
    frames are folded into the batch dim, and the block-causal temporal mask
    is commented out below — so no cross-frame (causal) attention actually
    happens; confirm this is intended.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

        # layers
        self.norm = RMS_norm(dim)
        self.to_qkv = nn.Conv2d(dim, dim * 3, 1)
        self.proj = nn.Conv2d(dim, dim, 1)

        # zero out the last layer params
        nn.init.zeros_(self.proj.weight)

    def forward(self, x):
        identity = x
        b, c, t, h, w = x.size()
        # Fold frames into batch: each frame attends over its own h*w tokens.
        x = rearrange(x, 'b c t h w -> (b t) c h w')
        x = self.norm(x)
        # compute query, key, value: [(b t), 1 head, h*w tokens, c] each
        q, k, v = self.to_qkv(x).reshape(b * t, 1, c * 3, -1).permute(
            0, 1, 3, 2).contiguous().chunk(3, dim=-1)

        # apply attention
        x = F.scaled_dot_product_attention(
            q,
            k,
            v,
            #attn_mask=block_causal_mask(q, block_size=h * w)
        )
        x = x.squeeze(1).permute(0, 2, 1).reshape(b * t, c, h, w)

        # output
        x = self.proj(x)
        x = rearrange(x, '(b t) c h w-> b c t h w', t=t)
        # proj.weight starts at zero, so the block is initially an identity
        # mapping plus proj's bias.
        return x + identity
|
||||
|
||||
|
||||
class Encoder3d(nn.Module):
    """Causal 3-D convolutional encoder: video -> latent features.

    Structure: stem conv -> per-scale residual (+ optional attention) blocks
    with 2x downsampling between scales -> middle res/attn/res -> norm+conv
    head producing ``z_dim`` channels. Supports the chunked-streaming cache
    protocol (``feat_cache`` / shared counter ``feat_idx[0]``, both mutated
    in place).
    """

    def __init__(self,
                 dim=128,
                 z_dim=4,
                 dim_mult=[1, 2, 4, 4],
                 num_res_blocks=2,
                 attn_scales=[],
                 temperal_downsample=[True, True, False],
                 dropout=0.0):
        super().__init__()
        self.dim = dim
        self.z_dim = z_dim
        self.dim_mult = dim_mult
        self.num_res_blocks = num_res_blocks
        self.attn_scales = attn_scales
        self.temperal_downsample = temperal_downsample

        # dimensions
        dims = [dim * u for u in [1] + dim_mult]
        scale = 1.0

        # init block
        self.conv1 = CausalConv3d(3, dims[0], 3, padding=1)

        # downsample blocks
        downsamples = []
        for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
            # residual (+attention) blocks
            for _ in range(num_res_blocks):
                downsamples.append(ResidualBlock(in_dim, out_dim, dropout))
                if scale in attn_scales:
                    downsamples.append(AttentionBlock(out_dim))
                in_dim = out_dim

            # downsample block
            if i != len(dim_mult) - 1:
                # temperal_downsample[i] selects whether time is halved too.
                mode = 'downsample3d' if temperal_downsample[
                    i] else 'downsample2d'
                downsamples.append(Resample(out_dim, mode=mode))
                scale /= 2.0
        self.downsamples = nn.Sequential(*downsamples)

        # middle blocks
        self.middle = nn.Sequential(ResidualBlock(out_dim, out_dim, dropout),
                                    AttentionBlock(out_dim),
                                    ResidualBlock(out_dim, out_dim, dropout))

        # output blocks
        self.head = nn.Sequential(RMS_norm(out_dim, images=False), nn.SiLU(),
                                  CausalConv3d(out_dim, z_dim, 3, padding=1))

    def forward(self, x, feat_cache=None, feat_idx=[0]):
        # NOTE: mutable default feat_idx=[0] is an intentional shared counter.
        if feat_cache is not None:
            idx = feat_idx[0]
            cache_x = x[:, :, -CACHE_T:, :, :].clone()
            if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
                # cache last frame of last two chunk
                cache_x = torch.cat([
                    feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
                        cache_x.device), cache_x
                ],
                                    dim=2)
            x = self.conv1(x, feat_cache[idx])
            feat_cache[idx] = cache_x
            feat_idx[0] += 1
        else:
            x = self.conv1(x)

        ## downsamples
        for layer in self.downsamples:
            if feat_cache is not None:
                x = layer(x, feat_cache, feat_idx)
            else:
                x = layer(x)

        ## middle
        for layer in self.middle:
            if check_is_instance(layer, ResidualBlock) and feat_cache is not None:
                x = layer(x, feat_cache, feat_idx)
            else:
                x = layer(x)

        ## head
        for layer in self.head:
            if check_is_instance(layer, CausalConv3d) and feat_cache is not None:
                idx = feat_idx[0]
                cache_x = x[:, :, -CACHE_T:, :, :].clone()
                if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
                    # cache last frame of last two chunk
                    cache_x = torch.cat([
                        feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
                            cache_x.device), cache_x
                    ],
                                        dim=2)
                x = layer(x, feat_cache[idx])
                feat_cache[idx] = cache_x
                feat_idx[0] += 1
            else:
                x = layer(x)
        return x
|
||||
|
||||
|
||||
class Decoder3d(nn.Module):
    """Causal 3-D convolutional decoder: latent features -> RGB video.

    Mirrors ``Encoder3d``: stem conv -> middle res/attn/res -> per-scale
    residual (+ optional attention) blocks with 2x upsampling between
    scales -> norm+conv head producing 3 channels. Supports the
    chunked-streaming cache protocol (``feat_cache`` / ``feat_idx[0]``).
    """

    def __init__(self,
                 dim=128,
                 z_dim=4,
                 dim_mult=[1, 2, 4, 4],
                 num_res_blocks=2,
                 attn_scales=[],
                 temperal_upsample=[False, True, True],
                 dropout=0.0):
        super().__init__()
        self.dim = dim
        self.z_dim = z_dim
        self.dim_mult = dim_mult
        self.num_res_blocks = num_res_blocks
        self.attn_scales = attn_scales
        self.temperal_upsample = temperal_upsample

        # dimensions (reverse of the encoder's ladder)
        dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
        scale = 1.0 / 2**(len(dim_mult) - 2)

        # init block
        self.conv1 = CausalConv3d(z_dim, dims[0], 3, padding=1)

        # middle blocks
        self.middle = nn.Sequential(ResidualBlock(dims[0], dims[0], dropout),
                                    AttentionBlock(dims[0]),
                                    ResidualBlock(dims[0], dims[0], dropout))

        # upsample blocks
        upsamples = []
        for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
            # residual (+attention) blocks
            if i == 1 or i == 2 or i == 3:
                # Upsample convs halve channels, so later stages see half
                # the nominal input width.
                in_dim = in_dim // 2
            for _ in range(num_res_blocks + 1):
                upsamples.append(ResidualBlock(in_dim, out_dim, dropout))
                if scale in attn_scales:
                    upsamples.append(AttentionBlock(out_dim))
                in_dim = out_dim

            # upsample block
            if i != len(dim_mult) - 1:
                mode = 'upsample3d' if temperal_upsample[i] else 'upsample2d'
                upsamples.append(Resample(out_dim, mode=mode))
                scale *= 2.0
        self.upsamples = nn.Sequential(*upsamples)

        # output blocks
        self.head = nn.Sequential(RMS_norm(out_dim, images=False), nn.SiLU(),
                                  CausalConv3d(out_dim, 3, 3, padding=1))

    def forward(self, x, feat_cache=None, feat_idx=[0]):
        # NOTE: mutable default feat_idx=[0] is an intentional shared counter.
        ## conv1
        if feat_cache is not None:
            idx = feat_idx[0]
            cache_x = x[:, :, -CACHE_T:, :, :].clone()
            if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
                # cache last frame of last two chunk
                cache_x = torch.cat([
                    feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
                        cache_x.device), cache_x
                ],
                                    dim=2)
            x = self.conv1(x, feat_cache[idx])
            feat_cache[idx] = cache_x
            feat_idx[0] += 1
        else:
            x = self.conv1(x)

        ## middle
        for layer in self.middle:
            if check_is_instance(layer, ResidualBlock) and feat_cache is not None:
                x = layer(x, feat_cache, feat_idx)
            else:
                x = layer(x)

        ## upsamples
        for layer in self.upsamples:
            if feat_cache is not None:
                x = layer(x, feat_cache, feat_idx)
            else:
                x = layer(x)

        ## head
        for layer in self.head:
            if check_is_instance(layer, CausalConv3d) and feat_cache is not None:
                idx = feat_idx[0]
                cache_x = x[:, :, -CACHE_T:, :, :].clone()
                if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
                    # cache last frame of last two chunk
                    cache_x = torch.cat([
                        feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
                            cache_x.device), cache_x
                    ],
                                        dim=2)
                x = layer(x, feat_cache[idx])
                feat_cache[idx] = cache_x
                feat_idx[0] += 1
            else:
                x = layer(x)
        return x
|
||||
|
||||
|
||||
def count_conv3d(model):
    """Count the ``CausalConv3d`` modules in *model*.

    Used to size the per-conv feature-cache lists for chunked streaming.
    """
    return sum(
        1 for module in model.modules()
        if check_is_instance(module, CausalConv3d))
|
||||
|
||||
|
||||
class VideoVAE_(nn.Module):
    """Inner causal video VAE: paired ``Encoder3d``/``Decoder3d`` plus the
    1x1x1 convs that produce/consume the latent distribution.

    ``encode``/``decode`` process the clip in temporal chunks (1 frame, then
    groups of 4) and thread per-conv feature caches through the submodules
    so chunked results match a single full-clip pass.

    NOTE(review): ``forward`` and ``sample`` call ``self.encode(x)`` without
    the required ``scale`` argument, and ``encode`` returns only ``mu`` while
    ``forward`` unpacks two values — these paths would raise if exercised;
    confirm they are unused.
    """

    def __init__(self,
                 dim=96,
                 z_dim=16,
                 dim_mult=[1, 2, 4, 4],
                 num_res_blocks=2,
                 attn_scales=[],
                 temperal_downsample=[False, True, True],
                 dropout=0.0):
        super().__init__()
        self.dim = dim
        self.z_dim = z_dim
        self.dim_mult = dim_mult
        self.num_res_blocks = num_res_blocks
        self.attn_scales = attn_scales
        self.temperal_downsample = temperal_downsample
        self.temperal_upsample = temperal_downsample[::-1]

        # modules
        # Encoder emits 2*z_dim channels: mean and log-variance.
        self.encoder = Encoder3d(dim, z_dim * 2, dim_mult, num_res_blocks,
                                 attn_scales, self.temperal_downsample, dropout)
        self.conv1 = CausalConv3d(z_dim * 2, z_dim * 2, 1)
        self.conv2 = CausalConv3d(z_dim, z_dim, 1)
        self.decoder = Decoder3d(dim, z_dim, dim_mult, num_res_blocks,
                                 attn_scales, self.temperal_upsample, dropout)

    def forward(self, x):
        # NOTE(review): see class docstring — encode() signature mismatch.
        mu, log_var = self.encode(x)
        z = self.reparameterize(mu, log_var)
        x_recon = self.decode(z)
        return x_recon, mu, log_var

    def encode(self, x, scale):
        """Encode video ``x`` to normalized latent means using ``scale``
        = (mean, 1/std) per-channel latent statistics."""
        self.clear_cache()
        ## cache
        t = x.shape[2]
        # First chunk is 1 frame, remaining chunks are 4 frames each.
        iter_ = 1 + (t - 1) // 4

        for i in range(iter_):
            self._enc_conv_idx = [0]
            if i == 0:
                out = self.encoder(x[:, :, :1, :, :],
                                   feat_cache=self._enc_feat_map,
                                   feat_idx=self._enc_conv_idx)
            else:
                out_ = self.encoder(x[:, :, 1 + 4 * (i - 1):1 + 4 * i, :, :],
                                    feat_cache=self._enc_feat_map,
                                    feat_idx=self._enc_conv_idx)
                out = torch.cat([out, out_], 2)
        mu, log_var = self.conv1(out).chunk(2, dim=1)
        if isinstance(scale[0], torch.Tensor):
            scale = [s.to(dtype=mu.dtype, device=mu.device) for s in scale]
            mu = (mu - scale[0].view(1, self.z_dim, 1, 1, 1)) * scale[1].view(
                1, self.z_dim, 1, 1, 1)
        else:
            scale = scale.to(dtype=mu.dtype, device=mu.device)
            mu = (mu - scale[0]) * scale[1]
        return mu

    def decode(self, z, scale):
        """Decode normalized latents ``z`` back to video, undoing ``scale``."""
        self.clear_cache()
        # z: [b,c,t,h,w]
        if isinstance(scale[0], torch.Tensor):
            scale = [s.to(dtype=z.dtype, device=z.device) for s in scale]
            z = z / scale[1].view(1, self.z_dim, 1, 1, 1) + scale[0].view(
                1, self.z_dim, 1, 1, 1)
        else:
            scale = scale.to(dtype=z.dtype, device=z.device)
            z = z / scale[1] + scale[0]
        iter_ = z.shape[2]
        x = self.conv2(z)
        # Decode one latent frame at a time; the decoder's feature caches
        # carry the temporal context between iterations.
        for i in range(iter_):
            self._conv_idx = [0]
            if i == 0:
                out = self.decoder(x[:, :, i:i + 1, :, :],
                                   feat_cache=self._feat_map,
                                   feat_idx=self._conv_idx)
            else:
                out_ = self.decoder(x[:, :, i:i + 1, :, :],
                                    feat_cache=self._feat_map,
                                    feat_idx=self._conv_idx)
                out = torch.cat([out, out_], 2)  # may add tensor offload
        return out

    def reparameterize(self, mu, log_var):
        # Standard VAE reparameterization trick.
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        return eps * std + mu

    def sample(self, imgs, deterministic=False):
        # NOTE(review): see class docstring — encode() signature mismatch.
        mu, log_var = self.encode(imgs)
        if deterministic:
            return mu
        std = torch.exp(0.5 * log_var.clamp(-30.0, 20.0))
        return mu + std * torch.randn_like(std)

    def clear_cache(self):
        """Reset the per-conv streaming caches for decoder and encoder."""
        self._conv_num = count_conv3d(self.decoder)
        self._conv_idx = [0]
        self._feat_map = [None] * self._conv_num
        # cache encode
        self._enc_conv_num = count_conv3d(self.encoder)
        self._enc_conv_idx = [0]
        self._enc_feat_map = [None] * self._enc_conv_num
|
||||
|
||||
|
||||
class WanVideoVAE(nn.Module):
    """Public wrapper around ``VideoVAE_`` with latent normalization and
    optional spatially-tiled encode/decode for large resolutions.

    Latent channels are standardized with fixed per-channel (mean, std)
    statistics; ``self.scale = [mean, 1/std]`` is passed to the inner model.
    Spatial down/up-sampling factor between pixels and latents is 8.
    """

    def __init__(self, z_dim=16):
        super().__init__()

        # Fixed per-channel latent statistics of the pretrained model.
        mean = [
            -0.7571, -0.7089, -0.9113, 0.1075, -0.1745, 0.9653, -0.1517, 1.5508,
            0.4134, -0.0715, 0.5517, -0.3632, -0.1922, -0.9497, 0.2503, -0.2921
        ]
        std = [
            2.8184, 1.4541, 2.3275, 2.6558, 1.2196, 1.7708, 2.6052, 2.0743,
            3.2687, 2.1526, 2.8652, 1.5579, 1.6382, 1.1253, 2.8251, 1.9160
        ]
        self.mean = torch.tensor(mean)
        self.std = torch.tensor(std)
        self.scale = [self.mean, 1.0 / self.std]

        # init model (inference-only: frozen and in eval mode)
        self.model = VideoVAE_(z_dim=z_dim).eval().requires_grad_(False)
        self.upsampling_factor = 8

    def build_1d_mask(self, length, left_bound, right_bound, border_width):
        """1-D blending weights for one tile axis: 1 inside, linear ramps on
        edges that border a neighboring tile (i.e. are not at the image
        boundary)."""
        x = torch.ones((length,))
        if not left_bound:
            x[:border_width] = (torch.arange(border_width) + 1) / border_width
        if not right_bound:
            x[-border_width:] = torch.flip(
                (torch.arange(border_width) + 1) / border_width, dims=(0,))
        return x

    def build_mask(self, data, is_bound, border_width):
        """2-D blending mask for one tile, shaped [1, 1, 1, H, W].

        Args:
            data: tile tensor [B, C, T, H, W]; only H, W are used.
            is_bound: (top, bottom, left, right) — True where the tile edge
                coincides with the full-frame boundary (no blending there).
            border_width: (vertical, horizontal) ramp widths in pixels/latents.
        """
        _, _, _, H, W = data.shape
        h = self.build_1d_mask(H, is_bound[0], is_bound[1], border_width[0])
        w = self.build_1d_mask(W, is_bound[2], is_bound[3], border_width[1])

        h = repeat(h, "H -> H W", H=H, W=W)
        w = repeat(w, "W -> H W", H=H, W=W)

        mask = torch.stack([h, w]).min(dim=0).values
        mask = rearrange(mask, "H W -> 1 1 1 H W")
        return mask

    def _build_tile_tasks(self, H, W, size_h, size_w, stride_h, stride_w):
        """Enumerate (h, h_, w, w_) tile windows covering an H x W plane.

        A window is skipped when the previous window on that axis already
        reached the boundary, so no fully-redundant tiles are produced.
        """
        tasks = []
        for h in range(0, H, stride_h):
            if h - stride_h >= 0 and h - stride_h + size_h >= H:
                continue
            for w in range(0, W, stride_w):
                if w - stride_w >= 0 and w - stride_w + size_w >= W:
                    continue
                tasks.append((h, h + size_h, w, w + size_w))
        return tasks

    def tiled_decode(self, hidden_states, device, tile_size, tile_stride):
        """Decode latents tile-by-tile (tile geometry in latent units),
        blending overlapping tiles with linear border ramps. Accumulation
        happens on CPU; per-tile compute runs on ``device``."""
        _, _, T, H, W = hidden_states.shape
        size_h, size_w = tile_size
        stride_h, stride_w = tile_stride

        # Split tasks
        tasks = self._build_tile_tasks(H, W, size_h, size_w, stride_h, stride_w)

        data_device = "cpu"
        computation_device = device

        # Temporal mapping of the causal VAE: T latents -> 4*T - 3 frames.
        out_T = T * 4 - 3
        weight = torch.zeros((1, 1, out_T, H * self.upsampling_factor, W * self.upsampling_factor), dtype=hidden_states.dtype, device=data_device)
        values = torch.zeros((1, 3, out_T, H * self.upsampling_factor, W * self.upsampling_factor), dtype=hidden_states.dtype, device=data_device)

        for h, h_, w, w_ in tqdm(tasks, desc="VAE decoding"):
            hidden_states_batch = hidden_states[:, :, :, h:h_, w:w_].to(computation_device)
            hidden_states_batch = self.model.decode(hidden_states_batch, self.scale).to(data_device)

            mask = self.build_mask(
                hidden_states_batch,
                is_bound=(h == 0, h_ >= H, w == 0, w_ >= W),
                border_width=((size_h - stride_h) * self.upsampling_factor, (size_w - stride_w) * self.upsampling_factor)
            ).to(dtype=hidden_states.dtype, device=data_device)

            target_h = h * self.upsampling_factor
            target_w = w * self.upsampling_factor
            values[
                :,
                :,
                :,
                target_h:target_h + hidden_states_batch.shape[3],
                target_w:target_w + hidden_states_batch.shape[4],
            ] += hidden_states_batch * mask
            weight[
                :,
                :,
                :,
                target_h:target_h + hidden_states_batch.shape[3],
                target_w:target_w + hidden_states_batch.shape[4],
            ] += mask
        # Normalize by accumulated blend weight, then clamp to valid range.
        values = values / weight
        values = values.float().clamp_(-1, 1)
        return values

    def tiled_encode(self, video, device, tile_size, tile_stride):
        """Encode video tile-by-tile (tile geometry in pixel units),
        blending overlapping tiles with linear border ramps. Accumulation
        happens on CPU; per-tile compute runs on ``device``."""
        _, _, T, H, W = video.shape
        size_h, size_w = tile_size
        stride_h, stride_w = tile_stride

        # Split tasks
        tasks = self._build_tile_tasks(H, W, size_h, size_w, stride_h, stride_w)

        data_device = "cpu"
        computation_device = device

        # Temporal mapping of the causal VAE: T frames -> ceil(T / 4) latents.
        out_T = (T + 3) // 4
        weight = torch.zeros((1, 1, out_T, H // self.upsampling_factor, W // self.upsampling_factor), dtype=video.dtype, device=data_device)
        values = torch.zeros((1, 16, out_T, H // self.upsampling_factor, W // self.upsampling_factor), dtype=video.dtype, device=data_device)

        for h, h_, w, w_ in tqdm(tasks, desc="VAE encoding"):
            hidden_states_batch = video[:, :, :, h:h_, w:w_].to(computation_device)
            hidden_states_batch = self.model.encode(hidden_states_batch, self.scale).to(data_device)

            mask = self.build_mask(
                hidden_states_batch,
                is_bound=(h == 0, h_ >= H, w == 0, w_ >= W),
                border_width=((size_h - stride_h) // self.upsampling_factor, (size_w - stride_w) // self.upsampling_factor)
            ).to(dtype=video.dtype, device=data_device)

            target_h = h // self.upsampling_factor
            target_w = w // self.upsampling_factor
            values[
                :,
                :,
                :,
                target_h:target_h + hidden_states_batch.shape[3],
                target_w:target_w + hidden_states_batch.shape[4],
            ] += hidden_states_batch * mask
            weight[
                :,
                :,
                :,
                target_h:target_h + hidden_states_batch.shape[3],
                target_w:target_w + hidden_states_batch.shape[4],
            ] += mask
        # Normalize by accumulated blend weight.
        values = values / weight
        values = values.float()
        return values

    def single_encode(self, video, device):
        """Encode one whole video [1, C, T, H, W] without tiling."""
        video = video.to(device)
        x = self.model.encode(video, self.scale)
        return x.float()

    def single_decode(self, hidden_state, device):
        """Decode one whole latent clip [1, C, T, H, W] without tiling."""
        hidden_state = hidden_state.to(device)
        video = self.model.decode(hidden_state, self.scale)
        return video.float().clamp_(-1, 1)

    def encode(self, videos, device, tiled=False, tile_size=(34, 34), tile_stride=(18, 16)):
        """Encode a batch of videos to latents.

        Args:
            videos: iterable of [C, T, H, W] tensors.
            device: computation device for the VAE passes.
            tiled: process each frame plane in overlapping spatial tiles.
            tile_size / tile_stride: tile geometry in *latent* units.

        Returns:
            Stacked latent tensor, one entry per input video (on CPU).
        """
        videos = [video.to("cpu") for video in videos]
        hidden_states = []
        if tiled:
            # Convert the latent-space tile geometry to pixel space ONCE,
            # outside the loop. (Previously this ran per-video, so every
            # video after the first was encoded with tiles another 8x
            # larger than requested.)
            tile_size = (tile_size[0] * 8, tile_size[1] * 8)
            tile_stride = (tile_stride[0] * 8, tile_stride[1] * 8)
        for video in videos:
            video = video.unsqueeze(0)
            if tiled:
                hidden_state = self.tiled_encode(video, device, tile_size, tile_stride)
            else:
                hidden_state = self.single_encode(video, device)
            hidden_state = hidden_state.squeeze(0)
            hidden_states.append(hidden_state)
        hidden_states = torch.stack(hidden_states)
        return hidden_states

    def decode(self, hidden_states, device, tiled=False, tile_size=(34, 34), tile_stride=(18, 16)):
        """Decode a batch of latents back to videos.

        Args:
            hidden_states: iterable of [C, T, H, W] latent tensors.
            device: computation device for the VAE passes.
            tiled: decode in overlapping spatial tiles (latent units).

        Returns:
            List of decoded video tensors (on CPU), one per input.
        """
        hidden_states = [hidden_state.to("cpu") for hidden_state in hidden_states]
        videos = []
        for hidden_state in hidden_states:
            hidden_state = hidden_state.unsqueeze(0)
            if tiled:
                video = self.tiled_decode(hidden_state, device, tile_size, tile_stride)
            else:
                video = self.single_decode(hidden_state, device)
            video = video.squeeze(0)
            videos.append(video)
        return videos

    @staticmethod
    def state_dict_converter():
        # Hook used by ModelManager to adapt external checkpoint layouts.
        return WanVideoVAEStateDictConverter()
|
||||
|
||||
|
||||
class WanVideoVAEStateDictConverter:
    """Renames raw VAE checkpoint keys into the ``model.``-prefixed
    namespace that ``WanVideoVAE`` expects (the inner VAE lives under
    ``self.model``)."""

    def __init__(self):
        pass

    def from_civitai(self, state_dict):
        # Some checkpoints nest the actual weights under 'model_state'.
        weights = state_dict.get('model_state', state_dict)
        return {'model.' + name: tensor for name, tensor in weights.items()}
|
||||
@@ -11,4 +11,5 @@ from .omnigen_image import OmnigenImagePipeline
|
||||
from .pipeline_runner import SDVideoPipelineRunner
|
||||
from .hunyuan_video import HunyuanVideoPipeline
|
||||
from .step_video import StepVideoPipeline
|
||||
from .wan_video import WanVideoPipeline
|
||||
KolorsImagePipeline = SDXLImagePipeline
|
||||
|
||||
267
diffsynth/pipelines/wan_video.py
Normal file
267
diffsynth/pipelines/wan_video.py
Normal file
@@ -0,0 +1,267 @@
|
||||
from ..models import ModelManager
|
||||
from ..models.wan_video_dit import WanModel
|
||||
from ..models.wan_video_text_encoder import WanTextEncoder
|
||||
from ..models.wan_video_vae import WanVideoVAE
|
||||
from ..models.wan_video_image_encoder import WanImageEncoder
|
||||
from ..schedulers.flow_match import FlowMatchScheduler
|
||||
from .base import BasePipeline
|
||||
from ..prompters import WanPrompter
|
||||
import torch, os
|
||||
from einops import rearrange
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
from tqdm import tqdm
|
||||
|
||||
from ..vram_management import enable_vram_management, AutoWrappedModule, AutoWrappedLinear
|
||||
from ..models.wan_video_text_encoder import T5RelativeEmbedding, T5LayerNorm
|
||||
from ..models.wan_video_dit import WanLayerNorm, WanRMSNorm
|
||||
from ..models.wan_video_vae import RMS_norm, CausalConv3d, Upsample
|
||||
|
||||
|
||||
|
||||
class WanVideoPipeline(BasePipeline):
    """Sampling pipeline for the Wan video-generation models.

    Wires together a umt5-based text encoder, a flow-matching DiT denoiser,
    a causal video VAE and (for image-to-video checkpoints) an image encoder.
    Supports text-to-video, video-to-video (via ``input_video``) and
    image-to-video (via ``input_image``).
    """

    def __init__(self, device="cuda", torch_dtype=torch.float16, tokenizer_path=None):
        super().__init__(device=device, torch_dtype=torch_dtype)
        # shift=5 follows the official Wan flow-matching sampler setting.
        self.scheduler = FlowMatchScheduler(shift=5, sigma_min=0.0, extra_one_step=True)
        self.prompter = WanPrompter(tokenizer_path=tokenizer_path)
        # Model components are attached later by fetch_models().
        self.text_encoder: WanTextEncoder = None
        self.image_encoder: WanImageEncoder = None
        self.dit: WanModel = None
        self.vae: WanVideoVAE = None
        # NOTE(review): "image_encoder" is passed to load_models_to_device()
        # below but is absent from this list — confirm BasePipeline copes
        # with model names outside model_names.
        self.model_names = ['text_encoder', 'dit', 'vae']


    def enable_vram_management(self, num_persistent_param_in_dit=None):
        """Wrap submodules so weights are offloaded to CPU when idle.

        num_persistent_param_in_dit: maximum number of DiT parameters kept
        resident on the GPU between steps; None keeps all of them resident.
        """
        dtype = next(iter(self.text_encoder.parameters())).dtype
        enable_vram_management(
            self.text_encoder,
            module_map = {
                torch.nn.Linear: AutoWrappedLinear,
                torch.nn.Embedding: AutoWrappedModule,
                T5RelativeEmbedding: AutoWrappedModule,
                T5LayerNorm: AutoWrappedModule,
            },
            module_config = dict(
                offload_dtype=dtype,
                offload_device="cpu",
                onload_dtype=dtype,
                onload_device="cpu",
                computation_dtype=self.torch_dtype,
                computation_device=self.device,
            ),
        )
        dtype = next(iter(self.dit.parameters())).dtype
        enable_vram_management(
            self.dit,
            module_map = {
                torch.nn.Linear: AutoWrappedLinear,
                torch.nn.Conv3d: AutoWrappedModule,
                torch.nn.LayerNorm: AutoWrappedModule,
                WanLayerNorm: AutoWrappedModule,
                WanRMSNorm: AutoWrappedModule,
            },
            # Up to max_num_param DiT parameters stay resident on the GPU...
            module_config = dict(
                offload_dtype=dtype,
                offload_device="cpu",
                onload_dtype=dtype,
                onload_device=self.device,
                computation_dtype=self.torch_dtype,
                computation_device=self.device,
            ),
            max_num_param=num_persistent_param_in_dit,
            # ...overflow modules stay on CPU and are moved per computation.
            overflow_module_config = dict(
                offload_dtype=dtype,
                offload_device="cpu",
                onload_dtype=dtype,
                onload_device="cpu",
                computation_dtype=self.torch_dtype,
                computation_device=self.device,
            ),
        )
        dtype = next(iter(self.vae.parameters())).dtype
        enable_vram_management(
            self.vae,
            module_map = {
                torch.nn.Linear: AutoWrappedLinear,
                torch.nn.Conv2d: AutoWrappedModule,
                RMS_norm: AutoWrappedModule,
                CausalConv3d: AutoWrappedModule,
                Upsample: AutoWrappedModule,
                torch.nn.SiLU: AutoWrappedModule,
                torch.nn.Dropout: AutoWrappedModule,
            },
            module_config = dict(
                offload_dtype=dtype,
                offload_device="cpu",
                onload_dtype=dtype,
                onload_device="cpu",
                computation_dtype=self.torch_dtype,
                computation_device=self.device,
            ),
        )
        if self.image_encoder is not None:
            dtype = next(iter(self.image_encoder.parameters())).dtype
            enable_vram_management(
                self.image_encoder,
                module_map = {
                    torch.nn.Linear: AutoWrappedLinear,
                    torch.nn.Conv2d: AutoWrappedModule,
                    torch.nn.LayerNorm: AutoWrappedModule,
                },
                module_config = dict(
                    offload_dtype=dtype,
                    offload_device="cpu",
                    onload_dtype=dtype,
                    onload_device="cpu",
                    computation_dtype=self.torch_dtype,
                    computation_device=self.device,
                ),
            )
        self.enable_cpu_offload()


    def fetch_models(self, model_manager: ModelManager):
        """Attach Wan components previously loaded into ``model_manager``."""
        text_encoder_model_and_path = model_manager.fetch_model("wan_video_text_encoder", require_model_path=True)
        if text_encoder_model_and_path is not None:
            self.text_encoder, tokenizer_path = text_encoder_model_and_path
            self.prompter.fetch_models(self.text_encoder)
            # The umt5-xxl tokenizer files are expected next to the text
            # encoder weights, under "google/umt5-xxl".
            self.prompter.fetch_tokenizer(os.path.join(os.path.dirname(tokenizer_path), "google/umt5-xxl"))
        self.dit = model_manager.fetch_model("wan_video_dit")
        self.vae = model_manager.fetch_model("wan_video_vae")
        # Only image-to-video checkpoints ship an image encoder; None otherwise.
        self.image_encoder = model_manager.fetch_model("wan_video_image_encoder")


    @staticmethod
    def from_model_manager(model_manager: ModelManager, torch_dtype=None, device=None):
        """Build a pipeline from a ModelManager, inheriting dtype/device unless overridden."""
        if device is None: device = model_manager.device
        if torch_dtype is None: torch_dtype = model_manager.torch_dtype
        pipe = WanVideoPipeline(device=device, torch_dtype=torch_dtype)
        pipe.fetch_models(model_manager)
        return pipe


    def denoising_model(self):
        """Return the trainable denoiser (the DiT)."""
        return self.dit


    def encode_prompt(self, prompt, positive=True):
        """Encode ``prompt`` into text-encoder context embeddings."""
        # Fix: pass the pipeline device explicitly; WanPrompter.encode_prompt
        # otherwise defaults to "cuda" and breaks non-CUDA pipelines.
        prompt_emb = self.prompter.encode_prompt(prompt, positive=positive, device=self.device)
        return {"context": prompt_emb}


    def encode_image(self, image, height, width, num_frames=81):
        """Build image-to-video conditioning from a first-frame image.

        Returns the image-encoder feature of ``image`` plus ``y``: the VAE
        latent of (image + zero frames) concatenated with a first-frame mask.
        ``num_frames`` generalizes the previously hard-coded 81; it must
        satisfy num_frames % 4 == 1 so the mask reshape below is exact.
        """
        with torch.amp.autocast(dtype=torch.bfloat16, device_type=torch.device(self.device).type):
            image = self.preprocess_image(image.resize((width, height))).to(self.device)
            clip_context = self.image_encoder.encode_image([image])
            # Mask is 1 for the supplied first frame, 0 for frames to generate.
            msk = torch.ones(1, num_frames, height//8, width//8, device=self.device)
            msk[:, 1:] = 0
            # Repeat the first-frame mask 4x so (num_frames + 3) is divisible
            # by the temporal compression factor of 4.
            msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1)
            msk = msk.view(1, msk.shape[1] // 4, 4, height//8, width//8)
            msk = msk.transpose(1, 2)[0]
            # Encode the image followed by (num_frames - 1) black frames.
            y = self.vae.encode([torch.concat([image.transpose(0, 1), torch.zeros(3, num_frames - 1, height, width).to(image.device)], dim=1)], device=self.device)[0]
            y = torch.concat([msk, y])
            return {"clip_fea": clip_context, "y": [y]}


    def tensor2video(self, frames):
        """Convert a (C, T, H, W) tensor in [-1, 1] to a list of PIL images."""
        frames = rearrange(frames, "C T H W -> T H W C")
        frames = ((frames.float() + 1) * 127.5).clip(0, 255).cpu().numpy().astype(np.uint8)
        frames = [Image.fromarray(frame) for frame in frames]
        return frames


    def prepare_extra_input(self, latents=None):
        """DiT sequence length: latent tokens after patchification (//4 spatially)."""
        return {"seq_len": latents.shape[2] * latents.shape[3] * latents.shape[4] // 4}


    def encode_video(self, input_video, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)):
        """VAE-encode a pixel video tensor into latents (optionally tiled)."""
        with torch.amp.autocast(dtype=torch.bfloat16, device_type=torch.device(self.device).type):
            latents = self.vae.encode(input_video, device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
        return latents


    def decode_video(self, latents, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)):
        """VAE-decode latents back to pixel space (optionally tiled)."""
        with torch.amp.autocast(dtype=torch.bfloat16, device_type=torch.device(self.device).type):
            frames = self.vae.decode(latents, device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
        return frames


    @torch.no_grad()
    def __call__(
        self,
        prompt,
        negative_prompt="",
        input_image=None,
        input_video=None,
        denoising_strength=1.0,
        seed=None,
        rand_device="cpu",
        height=480,
        width=832,
        num_frames=81,
        cfg_scale=5.0,
        num_inference_steps=50,
        tiled=True,
        tile_size=(34, 34),
        tile_stride=(18, 16),
        progress_bar_cmd=tqdm,
        progress_bar_st=None,
    ):
        """Generate a video and return it as a list of PIL images.

        prompt / negative_prompt: text conditioning (negative used when cfg_scale != 1).
        input_image: optional first frame for image-to-video (needs an image encoder).
        input_video: optional source clip for video-to-video; ``denoising_strength``
            controls how much of it is kept.
        seed / rand_device: noise initialization.
        progress_bar_st: accepted for interface parity; not used in this pipeline.
        """
        # Tiler parameters
        tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride}

        # Scheduler
        self.scheduler.set_timesteps(num_inference_steps, denoising_strength)

        # Initialize noise: 16 latent channels, 4x temporal / 8x spatial compression.
        noise = self.generate_noise((1, 16, (num_frames - 1) // 4 + 1, height//8, width//8), seed=seed, device=rand_device, dtype=torch.float32).to(self.device)
        if input_video is not None:
            # Video-to-video: start from the noised latents of the input clip.
            self.load_models_to_device(['vae'])
            input_video = self.preprocess_images(input_video)
            input_video = torch.stack(input_video, dim=2)
            latents = self.encode_video(input_video, **tiler_kwargs).to(dtype=noise.dtype, device=noise.device)
            latents = self.scheduler.add_noise(latents, noise, timestep=self.scheduler.timesteps[0])
        else:
            latents = noise

        # Encode prompts
        self.load_models_to_device(["text_encoder"])
        prompt_emb_posi = self.encode_prompt(prompt, positive=True)
        if cfg_scale != 1.0:
            prompt_emb_nega = self.encode_prompt(negative_prompt, positive=False)

        # Encode image (image-to-video conditioning)
        if input_image is not None and self.image_encoder is not None:
            self.load_models_to_device(["image_encoder", "vae"])
            image_emb = self.encode_image(input_image, height, width, num_frames=num_frames)
        else:
            image_emb = {}

        # Extra input
        extra_input = self.prepare_extra_input(latents)

        # Denoise
        self.load_models_to_device(["dit"])
        with torch.amp.autocast(dtype=torch.bfloat16, device_type=torch.device(self.device).type):
            for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
                timestep = timestep.unsqueeze(0).to(dtype=torch.float32, device=self.device)

                # Inference with classifier-free guidance when cfg_scale != 1.
                noise_pred_posi = self.dit(latents, timestep=timestep, **prompt_emb_posi, **image_emb, **extra_input)
                if cfg_scale != 1.0:
                    noise_pred_nega = self.dit(latents, timestep=timestep, **prompt_emb_nega, **image_emb, **extra_input)
                    noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
                else:
                    noise_pred = noise_pred_posi

                # Scheduler
                latents = self.scheduler.step(noise_pred, self.scheduler.timesteps[progress_id], latents)

        # Decode
        self.load_models_to_device(['vae'])
        frames = self.decode_video(latents, **tiler_kwargs)
        self.load_models_to_device([])
        frames = self.tensor2video(frames[0])

        return frames
|
||||
@@ -9,3 +9,4 @@ from .omost import OmostPromter
|
||||
from .cog_prompter import CogPrompter
|
||||
from .hunyuan_video_prompter import HunyuanVideoPrompter
|
||||
from .stepvideo_prompter import StepVideoPrompter
|
||||
from .wan_prompter import WanPrompter
|
||||
|
||||
103
diffsynth/prompters/wan_prompter.py
Normal file
103
diffsynth/prompters/wan_prompter.py
Normal file
@@ -0,0 +1,103 @@
|
||||
from .base_prompter import BasePrompter
|
||||
from ..models.wan_video_text_encoder import WanTextEncoder
|
||||
from transformers import AutoTokenizer
|
||||
import os, torch
|
||||
import html
|
||||
import string
|
||||
import regex as re
|
||||
|
||||
|
||||
def basic_clean(text):
    """Unescape HTML entities twice (handles double-escaped input) and trim."""
    return html.unescape(html.unescape(text)).strip()
|
||||
|
||||
def whitespace_clean(text):
    """Collapse every whitespace run into a single space and trim the ends."""
    return re.sub(r'\s+', ' ', text).strip()
|
||||
|
||||
def canonicalize(text, keep_punctuation_exact_string=None):
    """Lowercase *text*, strip punctuation and collapse whitespace.

    When `keep_punctuation_exact_string` is given, that exact substring is
    preserved while punctuation is stripped from the pieces around it.
    """
    text = text.replace('_', ' ')
    strip_punct = str.maketrans('', '', string.punctuation)
    if keep_punctuation_exact_string:
        pieces = text.split(keep_punctuation_exact_string)
        text = keep_punctuation_exact_string.join(
            piece.translate(strip_punct) for piece in pieces)
    else:
        text = text.translate(strip_punct)
    return re.sub(r'\s+', ' ', text.lower()).strip()
|
||||
|
||||
class HuggingfaceTokenizer:
    """Thin wrapper around ``transformers.AutoTokenizer`` that adds optional
    text cleaning and fixed-length padding/truncation."""

    def __init__(self, name, seq_len=None, clean=None, **kwargs):
        assert clean in (None, 'whitespace', 'lower', 'canonicalize')
        self.name = name
        self.seq_len = seq_len
        self.clean = clean

        # init tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(name, **kwargs)
        self.vocab_size = self.tokenizer.vocab_size

    def __call__(self, sequence, **kwargs):
        """Tokenize a string (or list of strings); returns input ids, plus the
        attention mask when ``return_mask=True`` is passed."""
        return_mask = kwargs.pop('return_mask', False)

        # Base tokenizer arguments; pad/truncate to seq_len when configured.
        tokenize_kwargs = {'return_tensors': 'pt'}
        if self.seq_len is not None:
            tokenize_kwargs['padding'] = 'max_length'
            tokenize_kwargs['truncation'] = True
            tokenize_kwargs['max_length'] = self.seq_len
        tokenize_kwargs.update(**kwargs)

        # Normalize input to a list and optionally clean each entry.
        if isinstance(sequence, str):
            sequence = [sequence]
        if self.clean:
            sequence = [self._clean(entry) for entry in sequence]
        encoded = self.tokenizer(sequence, **tokenize_kwargs)

        if return_mask:
            return encoded.input_ids, encoded.attention_mask
        return encoded.input_ids

    def _clean(self, text):
        """Apply the configured cleaning mode to one string."""
        if self.clean == 'whitespace':
            return whitespace_clean(basic_clean(text))
        if self.clean == 'lower':
            return whitespace_clean(basic_clean(text)).lower()
        if self.clean == 'canonicalize':
            return canonicalize(basic_clean(text))
        return text
|
||||
|
||||
|
||||
class WanPrompter(BasePrompter):
    """Prompt processor for Wan video models.

    Tokenizes prompts with a umt5-xxl tokenizer and encodes them with a
    WanTextEncoder, returning one variable-length embedding per prompt.
    """

    def __init__(self, tokenizer_path=None, text_len=512):
        super().__init__()
        self.text_len = text_len
        self.text_encoder = None
        self.fetch_tokenizer(tokenizer_path)

    def fetch_tokenizer(self, tokenizer_path=None):
        """Create the tokenizer once a path is known; no-op while it is None."""
        if tokenizer_path is None:
            return
        self.tokenizer = HuggingfaceTokenizer(name=tokenizer_path, seq_len=self.text_len, clean='whitespace')

    def fetch_models(self, text_encoder: WanTextEncoder = None):
        """Attach the text encoder used by encode_prompt()."""
        self.text_encoder = text_encoder

    def encode_prompt(self, prompt, positive=True, device="cuda"):
        """Encode `prompt` and trim each embedding to its unpadded length."""
        prompt = self.process_prompt(prompt, positive=positive)

        ids, mask = self.tokenizer(prompt, return_mask=True, add_special_tokens=True)
        ids, mask = ids.to(device), mask.to(device)
        valid_lengths = mask.gt(0).sum(dim=1).long()
        embeddings = self.text_encoder(ids, mask)
        # Drop the padded tail of every sequence.
        return [emb[:length] for emb, length in zip(embeddings, valid_lengths)]
|
||||
144
examples/wanvideo/README.md
Normal file
144
examples/wanvideo/README.md
Normal file
@@ -0,0 +1,144 @@
|
||||
# Wan-Video
|
||||
|
||||
Wan-Video is a collection of video synthesis models open-sourced by Alibaba.
|
||||
|
||||
## Inference
|
||||
|
||||
### Wan-Video-1.3B-T2V
|
||||
|
||||
Wan-Video-1.3B-T2V supports text-to-video and video-to-video. See [`./wan_1.3b_text_to_video.py`](./wan_1.3b_text_to_video.py).
|
||||
|
||||
Required VRAM: 6G
|
||||
|
||||
https://github.com/user-attachments/assets/124397be-cd6a-4f29-a87c-e4c695aaabb8
|
||||
|
||||
Put sunglasses on the dog.
|
||||
|
||||
https://github.com/user-attachments/assets/272808d7-fbeb-4747-a6df-14a0860c75fb
|
||||
|
||||
### Wan-Video-14B-T2V
|
||||
|
||||
Wan-Video-14B-T2V is an enhanced version of Wan-Video-1.3B-T2V, offering greater size and power. To utilize this model, you need additional VRAM. We recommend that users adjust the `torch_dtype` and `num_persistent_param_in_dit` settings to find an optimal balance between speed and VRAM requirements. See [`./wan_14b_text_to_video.py`](./wan_14b_text_to_video.py).
|
||||
|
||||
We present a detailed table here. The model is tested on a single A100.
|
||||
|
||||
|`torch_dtype`|`num_persistent_param_in_dit`|Speed|Required VRAM|Default Setting|
|
||||
|-|-|-|-|-|
|
||||
|torch.bfloat16|None (unlimited)|18.5s/it|40G||
|
||||
|torch.bfloat16|7*10**9 (7B)|20.8s/it|24G||
|
||||
|torch.bfloat16|0|23.4s/it|10G||
|
||||
|torch.float8_e4m3fn|None (unlimited)|18.3s/it|24G|yes|
|
||||
|torch.float8_e4m3fn|0|24.0s/it|10G||
|
||||
|
||||
https://github.com/user-attachments/assets/3908bc64-d451-485a-8b61-28f6d32dd92f
|
||||
|
||||
### Wan-Video-14B-I2V
|
||||
|
||||
Wan-Video-14B-I2V adds the functionality of image-to-video based on Wan-Video-14B-T2V. The model size remains the same, therefore the speed and VRAM requirements are also consistent. See [`./wan_14b_image_to_video.py`](./wan_14b_image_to_video.py).
|
||||
|
||||

|
||||
|
||||
https://github.com/user-attachments/assets/c0bdd5ca-292f-45ed-b9bc-afe193156e75
|
||||
|
||||
## Train
|
||||
|
||||
We support Wan-Video LoRA training. Here is a tutorial.
|
||||
|
||||
Step 1: Install additional packages
|
||||
|
||||
```
|
||||
pip install peft lightning pandas
|
||||
```
|
||||
|
||||
Step 2: Prepare your dataset
|
||||
|
||||
You need to manage the training videos as follows:
|
||||
|
||||
```
|
||||
data/example_dataset/
|
||||
├── metadata.csv
|
||||
└── train
|
||||
├── video_00001.mp4
|
||||
└── video_00002.mp4
|
||||
```
|
||||
|
||||
`metadata.csv`:
|
||||
|
||||
```
|
||||
file_name,text
|
||||
video_00001.mp4,"video description"
|
||||
video_00002.mp4,"video description"
|
||||
```
|
||||
|
||||
Step 3: Data process
|
||||
|
||||
```shell
|
||||
CUDA_VISIBLE_DEVICES="0" python examples/wanvideo/train_wan_t2v.py \
|
||||
--task data_process \
|
||||
--dataset_path data/example_dataset \
|
||||
--output_path ./models \
|
||||
--text_encoder_path "models/Wan-AI/Wan2.1-T2V-1.3B/models_t5_umt5-xxl-enc-bf16.pth" \
|
||||
--vae_path "models/Wan-AI/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth" \
|
||||
--tiled \
|
||||
--num_frames 81 \
|
||||
--height 480 \
|
||||
--width 832
|
||||
```
|
||||
|
||||
After that, some cached files will be stored in the dataset folder.
|
||||
|
||||
```
|
||||
data/example_dataset/
|
||||
├── metadata.csv
|
||||
└── train
|
||||
├── video_00001.mp4
|
||||
├── video_00001.mp4.tensors.pth
|
||||
├── video_00002.mp4
|
||||
└── video_00002.mp4.tensors.pth
|
||||
```
|
||||
|
||||
Step 4: Train
|
||||
|
||||
```shell
|
||||
CUDA_VISIBLE_DEVICES="0" python examples/wanvideo/train_wan_t2v.py \
|
||||
--task train \
|
||||
--dataset_path data/example_dataset \
|
||||
--output_path ./models \
|
||||
--dit_path "models/Wan-AI/Wan2.1-T2V-1.3B/diffusion_pytorch_model.safetensors" \
|
||||
--steps_per_epoch 500 \
|
||||
--max_epochs 10 \
|
||||
--learning_rate 1e-4 \
|
||||
--lora_rank 4 \
|
||||
--lora_alpha 4 \
|
||||
--lora_target_modules "q,k,v,o,ffn.0,ffn.2" \
|
||||
--accumulate_grad_batches 1 \
|
||||
--use_gradient_checkpointing
|
||||
```
|
||||
|
||||
Step 5: Test
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffsynth import ModelManager, WanVideoPipeline, save_video, VideoData
|
||||
|
||||
|
||||
model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu")
|
||||
model_manager.load_models([
|
||||
"models/Wan-AI/Wan2.1-T2V-1.3B/diffusion_pytorch_model.safetensors",
|
||||
"models/Wan-AI/Wan2.1-T2V-1.3B/models_t5_umt5-xxl-enc-bf16.pth",
|
||||
"models/Wan-AI/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth",
|
||||
])
|
||||
model_manager.load_lora("models/lightning_logs/version_1/checkpoints/epoch=0-step=500.ckpt", lora_alpha=1.0)
|
||||
|
||||
pipe = WanVideoPipeline.from_model_manager(model_manager, device="cuda")
|
||||
pipe.enable_vram_management(num_persistent_param_in_dit=None)
|
||||
|
||||
# Text-to-video
|
||||
video = pipe(
|
||||
prompt="...",
|
||||
negative_prompt="...",
|
||||
num_inference_steps=50,
|
||||
seed=0, tiled=True
|
||||
)
|
||||
save_video(video, "video_with_lora.mp4", fps=30, quality=5)
|
||||
```
|
||||
460
examples/wanvideo/train_wan_t2v.py
Normal file
460
examples/wanvideo/train_wan_t2v.py
Normal file
@@ -0,0 +1,460 @@
|
||||
import torch, os, imageio, argparse
|
||||
from torchvision.transforms import v2
|
||||
from einops import rearrange
|
||||
import lightning as pl
|
||||
import pandas as pd
|
||||
from diffsynth import WanVideoPipeline, ModelManager
|
||||
from peft import LoraConfig, inject_adapter_in_model
|
||||
import torchvision
|
||||
from PIL import Image
|
||||
|
||||
|
||||
|
||||
class TextVideoDataset(torch.utils.data.Dataset):
    """Loads (caption, video clip) pairs listed in a metadata CSV.

    Each sample is a fixed-length clip of `num_frames` frames starting at a
    random position, scaled and center-cropped to (height, width) and
    normalized to [-1, 1].
    """
    def __init__(self, base_path, metadata_path, max_num_frames=81, frame_interval=1, num_frames=81, height=480, width=832):
        metadata = pd.read_csv(metadata_path)
        # Videos live under <base_path>/train/<file_name>.
        self.path = [os.path.join(base_path, "train", file_name) for file_name in metadata["file_name"]]
        self.text = metadata["text"].to_list()

        self.max_num_frames = max_num_frames
        self.frame_interval = frame_interval
        self.num_frames = num_frames
        self.height = height
        self.width = width

        self.frame_process = v2.Compose([
            v2.Resize(size=(height, width), antialias=True),
            v2.CenterCrop(size=(height, width)),
            v2.ToTensor(),
            v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
        ])


    def crop_and_resize(self, image):
        """Scale `image` up just enough that a (height, width) crop fits."""
        width, height = image.size
        scale = max(self.width / width, self.height / height)
        image = torchvision.transforms.functional.resize(
            image,
            (round(height*scale), round(width*scale)),
            interpolation=torchvision.transforms.InterpolationMode.BILINEAR
        )
        return image


    def load_frames_using_imageio(self, file_path, max_num_frames, start_frame_id, interval, num_frames, frame_process):
        """Read `num_frames` frames starting at `start_frame_id`.

        Returns a (C, T, H, W) float tensor, or None when the video is too
        short for the requested clip.
        """
        reader = imageio.get_reader(file_path)
        if reader.count_frames() < max_num_frames or reader.count_frames() - 1 < start_frame_id + (num_frames - 1) * interval:
            reader.close()
            return None

        frames = []
        for frame_id in range(num_frames):
            frame = reader.get_data(start_frame_id + frame_id * interval)
            frame = Image.fromarray(frame)
            frame = self.crop_and_resize(frame)
            frame = frame_process(frame)
            frames.append(frame)
        reader.close()

        frames = torch.stack(frames, dim=0)
        frames = rearrange(frames, "T C H W -> C T H W")

        return frames


    def load_video(self, file_path):
        """Sample a random start frame and load one clip from `file_path`."""
        start_frame_id = torch.randint(0, self.max_num_frames - (self.num_frames - 1) * self.frame_interval, (1,))[0]
        frames = self.load_frames_using_imageio(file_path, self.max_num_frames, start_frame_id, self.frame_interval, self.num_frames, self.frame_process)
        return frames


    def load_text_video_raw_data(self, data_id):
        # Bug fix: the caption must come from self.text, not self.path
        # (previously the file path was returned as the caption).
        text = self.text[data_id]
        video = self.load_video(self.path[data_id])
        data = {"text": text, "video": video}
        return data


    def __getitem__(self, data_id):
        # Bug fix: the caption must come from self.text, not self.path.
        text = self.text[data_id]
        path = self.path[data_id]
        video = self.load_video(path)
        data = {"text": text, "video": video, "path": path}
        return data


    def __len__(self):
        return len(self.path)
|
||||
|
||||
|
||||
|
||||
class LightningModelForDataProcess(pl.LightningModule):
    """Lightning wrapper that caches encoded tensors for every dataset entry.

    Runs as a "test" loop: each batch's caption and video are encoded with
    the Wan text encoder / VAE and stored next to the source video as
    ``<video>.tensors.pth``.
    """
    def __init__(self, text_encoder_path, vae_path, tiled=False, tile_size=(34, 34), tile_stride=(18, 16)):
        super().__init__()
        manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu")
        manager.load_models([text_encoder_path, vae_path])
        self.pipe = WanVideoPipeline.from_model_manager(manager)
        self.tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride}

    def test_step(self, batch, batch_idx):
        caption = batch["text"][0]
        video = batch["video"]
        video_path = batch["path"][0]
        self.pipe.device = self.device
        # A None video means the source clip was too short; skip caching it.
        if video is not None:
            prompt_emb = self.pipe.encode_prompt(caption)
            latents = self.pipe.encode_video(video, **self.tiler_kwargs)[0]
            torch.save({"latents": latents, "prompt_emb": prompt_emb}, video_path + ".tensors.pth")
|
||||
|
||||
|
||||
|
||||
class TensorDataset(torch.utils.data.Dataset):
    """Dataset of precomputed latent/prompt tensors produced by the
    ``data_process`` task; only entries with a cached .tensors.pth file are kept."""
    def __init__(self, base_path, metadata_path, steps_per_epoch):
        metadata = pd.read_csv(metadata_path)
        video_paths = [os.path.join(base_path, "train", file_name) for file_name in metadata["file_name"]]
        print(len(video_paths), "videos in metadata.")
        # Keep only the videos whose tensors were cached by data_process.
        self.path = [p + ".tensors.pth" for p in video_paths if os.path.exists(p + ".tensors.pth")]
        print(len(self.path), "tensors cached in metadata.")
        assert len(self.path) > 0

        self.steps_per_epoch = steps_per_epoch


    def __getitem__(self, index):
        base = torch.randint(0, len(self.path), (1,))[0]
        tensor_path = self.path[(base + index) % len(self.path)]  # For fixed seed.
        return torch.load(tensor_path, weights_only=True, map_location="cpu")


    def __len__(self):
        # An "epoch" is a fixed number of steps, independent of dataset size.
        return self.steps_per_epoch
|
||||
|
||||
|
||||
|
||||
class LightningModelForTrain(pl.LightningModule):
    """LoRA fine-tuning module for the Wan DiT on precomputed tensors.

    Loads only the DiT weights, freezes everything, injects LoRA adapters
    into the requested target modules, and trains with a flow-matching
    objective on cached (latents, prompt_emb) pairs.
    """
    def __init__(self, dit_path, learning_rate=1e-5, lora_rank=4, lora_alpha=4, lora_target_modules="q,k,v,o,ffn.0,ffn.2", init_lora_weights="kaiming", use_gradient_checkpointing=True):
        super().__init__()
        model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu")
        model_manager.load_models([dit_path])

        self.pipe = WanVideoPipeline.from_model_manager(model_manager)
        # 1000 training timesteps for the flow-matching scheduler.
        self.pipe.scheduler.set_timesteps(1000, training=True)
        # Freeze the base model before injecting trainable LoRA weights.
        self.freeze_parameters()
        self.add_lora_to_model(
            self.pipe.denoising_model(),
            lora_rank=lora_rank,
            lora_alpha=lora_alpha,
            lora_target_modules=lora_target_modules,
            init_lora_weights=init_lora_weights,
        )

        self.learning_rate = learning_rate
        self.use_gradient_checkpointing = use_gradient_checkpointing

    def freeze_parameters(self):
        # Freeze parameters
        self.pipe.requires_grad_(False)
        self.pipe.eval()
        # Keep the denoiser in train mode (e.g. for gradient checkpointing).
        self.pipe.denoising_model().train()

    def add_lora_to_model(self, model, lora_rank=4, lora_alpha=4, lora_target_modules="q,k,v,o,ffn.0,ffn.2", init_lora_weights="kaiming"):
        """Inject LoRA adapters into `model` via peft; only LoRA params stay trainable."""
        # Add LoRA to UNet
        self.lora_alpha = lora_alpha
        # peft expects True for its default (Kaiming) initialization.
        if init_lora_weights == "kaiming":
            init_lora_weights = True

        lora_config = LoraConfig(
            r=lora_rank,
            lora_alpha=lora_alpha,
            init_lora_weights=init_lora_weights,
            target_modules=lora_target_modules.split(","),
        )
        model = inject_adapter_in_model(lora_config, model)
        for param in model.parameters():
            # Upcast LoRA parameters into fp32
            if param.requires_grad:
                param.data = param.to(torch.float32)

    def training_step(self, batch, batch_idx):
        """One flow-matching step on cached latents; returns the weighted MSE loss."""
        # Data
        latents = batch["latents"].to(self.device)
        prompt_emb = batch["prompt_emb"]
        # Unwrap the batched/cached context tensor back into a one-element list.
        prompt_emb["context"] = [prompt_emb["context"][0][0].to(self.device)]

        # Loss
        noise = torch.randn_like(latents)
        timestep_id = torch.randint(0, self.pipe.scheduler.num_train_timesteps, (1,))
        timestep = self.pipe.scheduler.timesteps[timestep_id].to(self.device)
        extra_input = self.pipe.prepare_extra_input(latents)
        noisy_latents = self.pipe.scheduler.add_noise(latents, noise, timestep)
        training_target = self.pipe.scheduler.training_target(latents, noise, timestep)

        # Compute loss
        with torch.amp.autocast(dtype=torch.bfloat16, device_type=torch.device(self.device).type):
            noise_pred = self.pipe.denoising_model()(
                noisy_latents, timestep=timestep, **prompt_emb, **extra_input,
                use_gradient_checkpointing=self.use_gradient_checkpointing
            )
            loss = torch.nn.functional.mse_loss(noise_pred.float(), training_target.float())
            loss = loss * self.pipe.scheduler.training_weight(timestep)

        # Record log
        self.log("train_loss", loss, prog_bar=True)
        return loss

    def configure_optimizers(self):
        # Optimize only the trainable (LoRA) parameters.
        trainable_modules = filter(lambda p: p.requires_grad, self.pipe.denoising_model().parameters())
        optimizer = torch.optim.AdamW(trainable_modules, lr=self.learning_rate)
        return optimizer

    def on_save_checkpoint(self, checkpoint):
        """Strip the checkpoint down to just the LoRA weights."""
        checkpoint.clear()
        trainable_param_names = list(filter(lambda named_param: named_param[1].requires_grad, self.pipe.denoising_model().named_parameters()))
        trainable_param_names = set([named_param[0] for named_param in trainable_param_names])
        state_dict = self.pipe.denoising_model().state_dict()
        lora_state_dict = {}
        for name, param in state_dict.items():
            if name in trainable_param_names:
                lora_state_dict[name] = param
        checkpoint.update(lora_state_dict)
|
||||
|
||||
|
||||
|
||||
def parse_args():
    """Parse command-line arguments for the Wan T2V training script.

    Two tasks are supported: ``data_process`` (cache text/video tensors,
    requires --text_encoder_path and --vae_path) and ``train`` (LoRA
    fine-tuning on the cached tensors, requires --dit_path).
    """
    parser = argparse.ArgumentParser(description="Simple example of a training script.")
    # Task selection and I/O paths.
    parser.add_argument(
        "--task",
        type=str,
        default="data_process",
        required=True,
        choices=["data_process", "train"],
        help="Task. `data_process` or `train`.",
    )
    parser.add_argument(
        "--dataset_path",
        type=str,
        default=None,
        required=True,
        help="The path of the Dataset.",
    )
    parser.add_argument(
        "--output_path",
        type=str,
        default="./",
        help="Path to save the model.",
    )
    # Model weight paths (text encoder / VAE for data_process, DiT for train).
    parser.add_argument(
        "--text_encoder_path",
        type=str,
        default=None,
        help="Path of text encoder.",
    )
    parser.add_argument(
        "--vae_path",
        type=str,
        default=None,
        help="Path of VAE.",
    )
    parser.add_argument(
        "--dit_path",
        type=str,
        default=None,
        help="Path of DiT.",
    )
    # VAE tiling options (reduce VRAM during encoding).
    parser.add_argument(
        "--tiled",
        default=False,
        action="store_true",
        help="Whether enable tile encode in VAE. This option can reduce VRAM required.",
    )
    parser.add_argument(
        "--tile_size_height",
        type=int,
        default=34,
        help="Tile size (height) in VAE.",
    )
    parser.add_argument(
        "--tile_size_width",
        type=int,
        default=34,
        help="Tile size (width) in VAE.",
    )
    parser.add_argument(
        "--tile_stride_height",
        type=int,
        default=18,
        help="Tile stride (height) in VAE.",
    )
    parser.add_argument(
        "--tile_stride_width",
        type=int,
        default=16,
        help="Tile stride (width) in VAE.",
    )
    # Dataset / clip geometry.
    parser.add_argument(
        "--steps_per_epoch",
        type=int,
        default=500,
        help="Number of steps per epoch.",
    )
    parser.add_argument(
        "--num_frames",
        type=int,
        default=81,
        help="Number of frames.",
    )
    parser.add_argument(
        "--height",
        type=int,
        default=480,
        help="Image height.",
    )
    parser.add_argument(
        "--width",
        type=int,
        default=832,
        help="Image width.",
    )
    parser.add_argument(
        "--dataloader_num_workers",
        type=int,
        default=1,
        help="Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process.",
    )
    # Optimization hyperparameters.
    parser.add_argument(
        "--learning_rate",
        type=float,
        default=1e-5,
        help="Learning rate.",
    )
    parser.add_argument(
        "--accumulate_grad_batches",
        type=int,
        default=1,
        help="The number of batches in gradient accumulation.",
    )
    parser.add_argument(
        "--max_epochs",
        type=int,
        default=1,
        help="Number of epochs.",
    )
    # LoRA configuration.
    parser.add_argument(
        "--lora_target_modules",
        type=str,
        default="q,k,v,o,ffn.0,ffn.2",
        help="Layers with LoRA modules.",
    )
    parser.add_argument(
        "--init_lora_weights",
        type=str,
        default="kaiming",
        choices=["gaussian", "kaiming"],
        help="The initializing method of LoRA weight.",
    )
    parser.add_argument(
        "--training_strategy",
        type=str,
        default="auto",
        choices=["auto", "deepspeed_stage_1", "deepspeed_stage_2", "deepspeed_stage_3"],
        help="Training strategy",
    )
    parser.add_argument(
        "--lora_rank",
        type=int,
        default=4,
        help="The dimension of the LoRA update matrices.",
    )
    parser.add_argument(
        "--lora_alpha",
        type=float,
        default=4.0,
        help="The weight of the LoRA update matrices.",
    )
    parser.add_argument(
        "--use_gradient_checkpointing",
        default=False,
        action="store_true",
        help="Whether to use gradient checkpointing.",
    )
    args = parser.parse_args()
    return args
|
||||
|
||||
|
||||
def data_process(args):
    """Encode every video/caption pair in the dataset and cache the tensors to disk."""
    dataset = TextVideoDataset(
        args.dataset_path,
        os.path.join(args.dataset_path, "metadata.csv"),
        max_num_frames=args.num_frames,
        frame_interval=1,
        num_frames=args.num_frames,
        height=args.height,
        width=args.width,
    )
    loader = torch.utils.data.DataLoader(
        dataset,
        shuffle=False,
        batch_size=1,
        num_workers=args.dataloader_num_workers,
    )
    model = LightningModelForDataProcess(
        text_encoder_path=args.text_encoder_path,
        vae_path=args.vae_path,
        tiled=args.tiled,
        tile_size=(args.tile_size_height, args.tile_size_width),
        tile_stride=(args.tile_stride_height, args.tile_stride_width),
    )
    # The "test" loop is used purely for its side effect of caching tensors.
    trainer = pl.Trainer(
        accelerator="gpu",
        devices="auto",
        default_root_dir=args.output_path,
    )
    trainer.test(model, loader)
|
||||
|
||||
|
||||
def train(args):
    """Run LoRA fine-tuning of the DiT on the pre-encoded tensor dataset.

    Expects `data_process` to have been run first so that `TensorDataset`
    can load cached latents/embeddings from `args.dataset_path`.
    """
    tensor_dataset = TensorDataset(
        args.dataset_path,
        os.path.join(args.dataset_path, "metadata.csv"),
        steps_per_epoch=args.steps_per_epoch,
    )
    train_loader = torch.utils.data.DataLoader(
        tensor_dataset,
        shuffle=True,
        batch_size=1,
        num_workers=args.dataloader_num_workers,
    )
    lora_model = LightningModelForTrain(
        dit_path=args.dit_path,
        learning_rate=args.learning_rate,
        lora_rank=args.lora_rank,
        lora_alpha=args.lora_alpha,
        lora_target_modules=args.lora_target_modules,
        init_lora_weights=args.init_lora_weights,
        use_gradient_checkpointing=args.use_gradient_checkpointing,
    )
    runner = pl.Trainer(
        max_epochs=args.max_epochs,
        accelerator="gpu",
        devices="auto",
        strategy=args.training_strategy,
        default_root_dir=args.output_path,
        accumulate_grad_batches=args.accumulate_grad_batches,
        # save_top_k=-1 keeps every epoch checkpoint instead of only the best.
        callbacks=[pl.pytorch.callbacks.ModelCheckpoint(save_top_k=-1)],
    )
    runner.fit(lora_model, train_loader)
|
||||
|
||||
|
||||
if __name__ == '__main__':
    # Dispatch to the requested sub-task; unknown tasks are silently ignored,
    # matching the original if/elif behavior.
    args = parse_args()
    task_handlers = {
        "data_process": data_process,
        "train": train,
    }
    handler = task_handlers.get(args.task)
    if handler is not None:
        handler(args)
|
||||
40
examples/wanvideo/wan_1.3b_text_to_video.py
Normal file
40
examples/wanvideo/wan_1.3b_text_to_video.py
Normal file
@@ -0,0 +1,40 @@
|
||||
import torch
from diffsynth import ModelManager, WanVideoPipeline, save_video, VideoData
from modelscope import snapshot_download


# Fetch the Wan2.1-T2V-1.3B weights from ModelScope into ./models.
snapshot_download("Wan-AI/Wan2.1-T2V-1.3B", cache_dir="models")

# Load the DiT, T5 text encoder and VAE checkpoints.
checkpoint_paths = [
    "models/Wan-AI/Wan2.1-T2V-1.3B/diffusion_pytorch_model.safetensors",
    "models/Wan-AI/Wan2.1-T2V-1.3B/models_t5_umt5-xxl-enc-bf16.pth",
    "models/Wan-AI/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth",
]
model_manager = ModelManager(device="cpu")
model_manager.load_models(
    checkpoint_paths,
    torch_dtype=torch.bfloat16,  # You can set `torch_dtype=torch.float8_e4m3fn` to enable FP8 quantization.
)
pipe = WanVideoPipeline.from_model_manager(model_manager, torch_dtype=torch.bfloat16, device="cuda")
pipe.enable_vram_management(num_persistent_param_in_dit=None)

# Stage 1: text-to-video generation.
video = pipe(
    prompt="纪实摄影风格画面,一只活泼的小狗在绿茵茵的草地上迅速奔跑。小狗毛色棕黄,两只耳朵立起,神情专注而欢快。阳光洒在它身上,使得毛发看上去格外柔软而闪亮。背景是一片开阔的草地,偶尔点缀着几朵野花,远处隐约可见蓝天和几片白云。透视感鲜明,捕捉小狗奔跑时的动感和四周草地的生机。中景侧面移动视角。",
    negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
    num_inference_steps=50,
    seed=0, tiled=True
)
save_video(video, "video1.mp4", fps=15, quality=5)

# Stage 2: video-to-video — refine the first clip with a modified prompt.
video = VideoData("video1.mp4", height=480, width=832)
video = pipe(
    prompt="纪实摄影风格画面,一只活泼的小狗戴着黑色墨镜在绿茵茵的草地上迅速奔跑。小狗毛色棕黄,戴着黑色墨镜,两只耳朵立起,神情专注而欢快。阳光洒在它身上,使得毛发看上去格外柔软而闪亮。背景是一片开阔的草地,偶尔点缀着几朵野花,远处隐约可见蓝天和几片白云。透视感鲜明,捕捉小狗奔跑时的动感和四周草地的生机。中景侧面移动视角。",
    negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
    input_video=video, denoising_strength=0.7,
    num_inference_steps=50,
    seed=1, tiled=True
)
save_video(video, "video2.mp4", fps=15, quality=5)
|
||||
48
examples/wanvideo/wan_14b_image_to_video.py
Normal file
48
examples/wanvideo/wan_14b_image_to_video.py
Normal file
@@ -0,0 +1,48 @@
|
||||
import torch
from diffsynth import ModelManager, WanVideoPipeline, save_video, VideoData
from modelscope import snapshot_download, dataset_snapshot_download
from PIL import Image


# Fetch the Wan2.1-I2V-14B-480P weights from ModelScope into ./models.
snapshot_download("Wan-AI/Wan2.1-I2V-14B-480P", cache_dir="models")

# Load the checkpoints. The sharded DiT safetensors are grouped in a nested
# list so the manager treats the seven shards as a single model.
dit_shards = [
    "models/Wan-AI/Wan2.1-I2V-14B-480P/diffusion_pytorch_model-00001-of-00007.safetensors",
    "models/Wan-AI/Wan2.1-I2V-14B-480P/diffusion_pytorch_model-00002-of-00007.safetensors",
    "models/Wan-AI/Wan2.1-I2V-14B-480P/diffusion_pytorch_model-00003-of-00007.safetensors",
    "models/Wan-AI/Wan2.1-I2V-14B-480P/diffusion_pytorch_model-00004-of-00007.safetensors",
    "models/Wan-AI/Wan2.1-I2V-14B-480P/diffusion_pytorch_model-00005-of-00007.safetensors",
    "models/Wan-AI/Wan2.1-I2V-14B-480P/diffusion_pytorch_model-00006-of-00007.safetensors",
    "models/Wan-AI/Wan2.1-I2V-14B-480P/diffusion_pytorch_model-00007-of-00007.safetensors",
]
model_manager = ModelManager(device="cpu")
model_manager.load_models(
    [
        dit_shards,
        "models/Wan-AI/Wan2.1-I2V-14B-480P/models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth",
        "models/Wan-AI/Wan2.1-I2V-14B-480P/models_t5_umt5-xxl-enc-bf16.pth",
        "models/Wan-AI/Wan2.1-I2V-14B-480P/Wan2.1_VAE.pth",
    ],
    torch_dtype=torch.float8_e4m3fn,  # You can set `torch_dtype=torch.bfloat16` to disable FP8 quantization.
)
pipe = WanVideoPipeline.from_model_manager(model_manager, torch_dtype=torch.bfloat16, device="cuda")
pipe.enable_vram_management(num_persistent_param_in_dit=None)  # You can set `num_persistent_param_in_dit` to a small number to reduce VRAM required.

# Fetch the example conditioning image from the demo dataset.
dataset_snapshot_download(
    dataset_id="DiffSynth-Studio/examples_in_diffsynth",
    local_dir="./",
    allow_file_pattern="data/examples/wan/input_image.jpg"
)
image = Image.open("data/examples/wan/input_image.jpg")

# Image-to-video generation conditioned on the downloaded image.
video = pipe(
    prompt="一艘小船正勇敢地乘风破浪前行。蔚蓝的大海波涛汹涌,白色的浪花拍打着船身,但小船毫不畏惧,坚定地驶向远方。阳光洒在水面上,闪烁着金色的光芒,为这壮丽的场景增添了一抹温暖。镜头拉近,可以看到船上的旗帜迎风飘扬,象征着不屈的精神与冒险的勇气。这段画面充满力量,激励人心,展现了面对挑战时的无畏与执着。",
    negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
    input_image=image,
    num_inference_steps=50,
    seed=0, tiled=True
)
save_video(video, "video.mp4", fps=15, quality=5)
|
||||
38
examples/wanvideo/wan_14b_text_to_video.py
Normal file
38
examples/wanvideo/wan_14b_text_to_video.py
Normal file
@@ -0,0 +1,38 @@
|
||||
import torch
from diffsynth import ModelManager, WanVideoPipeline, save_video, VideoData
from modelscope import snapshot_download


# Fetch the Wan2.1-T2V-14B weights from ModelScope into ./models.
snapshot_download("Wan-AI/Wan2.1-T2V-14B", cache_dir="models")

# Load the checkpoints. The sharded DiT safetensors are grouped in a nested
# list so the manager treats the seven shards as a single model.
dit_shards = [
    "models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00001-of-00007.safetensors",
    "models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00002-of-00007.safetensors",
    "models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00003-of-00007.safetensors",
    "models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00004-of-00007.safetensors",
    "models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00005-of-00007.safetensors",
    "models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00006-of-00007.safetensors",
    "models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00007-of-00007.safetensors",
]
model_manager = ModelManager(device="cpu")
model_manager.load_models(
    [
        dit_shards,
        "models/Wan-AI/Wan2.1-T2V-14B/models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth",
        "models/Wan-AI/Wan2.1-T2V-14B/models_t5_umt5-xxl-enc-bf16.pth",
        "models/Wan-AI/Wan2.1-T2V-14B/Wan2.1_VAE.pth",
    ],
    torch_dtype=torch.float8_e4m3fn,  # You can set `torch_dtype=torch.bfloat16` to disable FP8 quantization.
)
pipe = WanVideoPipeline.from_model_manager(model_manager, torch_dtype=torch.bfloat16, device="cuda")
pipe.enable_vram_management(num_persistent_param_in_dit=None)  # You can set `num_persistent_param_in_dit` to a small number to reduce VRAM required.

# Text-to-video generation.
video = pipe(
    prompt="一名宇航员身穿太空服,面朝镜头骑着一匹机械马在火星表面驰骋。红色的荒凉地表延伸至远方,点缀着巨大的陨石坑和奇特的岩石结构。机械马的步伐稳健,扬起微弱的尘埃,展现出未来科技与原始探索的完美结合。宇航员手持操控装置,目光坚定,仿佛正在开辟人类的新疆域。背景是深邃的宇宙和蔚蓝的地球,画面既科幻又充满希望,让人不禁畅想未来的星际生活。",
    negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
    num_inference_steps=50,
    seed=0, tiled=True
)
save_video(video, "video1.mp4", fps=25, quality=5)
|
||||
Reference in New Issue
Block a user