From b548d7caf2227cb0d6f6f5d7d14fa673ffc066ce Mon Sep 17 00:00:00 2001 From: Artiprocher Date: Fri, 7 Mar 2025 16:35:26 +0800 Subject: [PATCH 1/4] refactor wan dit --- diffsynth/models/wan_video_dit.py | 883 ++++++++-------------------- diffsynth/pipelines/wan_video.py | 32 +- diffsynth/prompters/wan_prompter.py | 3 +- 3 files changed, 254 insertions(+), 664 deletions(-) diff --git a/diffsynth/models/wan_video_dit.py b/diffsynth/models/wan_video_dit.py index fa59dc6..ff9ce50 100644 --- a/diffsynth/models/wan_video_dit.py +++ b/diffsynth/models/wan_video_dit.py @@ -1,11 +1,10 @@ -import math - import torch -import torch.amp as amp import torch.nn as nn -from tqdm import tqdm +import torch.nn.functional as F +import math +from typing import Tuple, Optional +from einops import rearrange from .utils import hash_state_dict_keys - try: import flash_attn_interface FLASH_ATTN_3_AVAILABLE = True @@ -23,710 +22,311 @@ try: SAGE_ATTN_AVAILABLE = True except ModuleNotFoundError: SAGE_ATTN_AVAILABLE = False - -import warnings - - -__all__ = ['WanModel'] - - -def flash_attention( - q, - k, - v, - q_lens=None, - k_lens=None, - dropout_p=0., - softmax_scale=None, - q_scale=None, - causal=False, - window_size=(-1, -1), - deterministic=False, - dtype=torch.bfloat16, - version=None, -): - """ - q: [B, Lq, Nq, C1]. - k: [B, Lk, Nk, C1]. - v: [B, Lk, Nk, C2]. Nq must be divisible by Nk. - q_lens: [B]. - k_lens: [B]. - dropout_p: float. Dropout probability. - softmax_scale: float. The scaling of QK^T before applying softmax. - causal: bool. Whether to apply causal attention mask. - window_size: (left right). If not (-1, -1), apply sliding window local attention. - deterministic: bool. If True, slightly slower and uses more memory. - dtype: torch.dtype. Apply when dtype of q/k/v is not float16/bfloat16. 
- """ - half_dtypes = (torch.float16, torch.bfloat16) - assert dtype in half_dtypes - assert q.device.type == 'cuda' and q.size(-1) <= 256 - - # params - b, lq, lk, out_dtype = q.size(0), q.size(1), k.size(1), q.dtype - - def half(x): - return x if x.dtype in half_dtypes else x.to(dtype) - - # preprocess query - if q_lens is None: - q = half(q.flatten(0, 1)) - q_lens = torch.tensor( - [lq] * b, dtype=torch.int32).to( - device=q.device, non_blocking=True) - else: - q = half(torch.cat([u[:v] for u, v in zip(q, q_lens)])) - - # preprocess key, value - if k_lens is None: - k = half(k.flatten(0, 1)) - v = half(v.flatten(0, 1)) - k_lens = torch.tensor( - [lk] * b, dtype=torch.int32).to( - device=k.device, non_blocking=True) - else: - k = half(torch.cat([u[:v] for u, v in zip(k, k_lens)])) - v = half(torch.cat([u[:v] for u, v in zip(v, k_lens)])) - - q = q.to(v.dtype) - k = k.to(v.dtype) - - if q_scale is not None: - q = q * q_scale - - if version is not None and version == 3 and not FLASH_ATTN_3_AVAILABLE: - warnings.warn( - 'Flash attention 3 is not available, use flash attention 2 instead.' - ) - - # apply attention - if (version is None or version == 3) and FLASH_ATTN_3_AVAILABLE: - # Note: dropout_p, window_size are not supported in FA3 now. 
- x = flash_attn_interface.flash_attn_varlen_func( - q=q, - k=k, - v=v, - cu_seqlens_q=torch.cat([q_lens.new_zeros([1]), q_lens]).cumsum( - 0, dtype=torch.int32).to(q.device, non_blocking=True), - cu_seqlens_k=torch.cat([k_lens.new_zeros([1]), k_lens]).cumsum( - 0, dtype=torch.int32).to(q.device, non_blocking=True), - seqused_q=None, - seqused_k=None, - max_seqlen_q=lq, - max_seqlen_k=lk, - softmax_scale=softmax_scale, - causal=causal, - deterministic=deterministic)[0].unflatten(0, (b, lq)) + + +def flash_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, num_heads: int): + if FLASH_ATTN_3_AVAILABLE: + q = rearrange(q, "b s (n d) -> b s n d", n=num_heads) + k = rearrange(k, "b s (n d) -> b s n d", n=num_heads) + v = rearrange(v, "b s (n d) -> b s n d", n=num_heads) + x = flash_attn_interface.flash_attn_func(q, k, v) + x = rearrange(x, "b s n d -> b s (n d)", n=num_heads) elif FLASH_ATTN_2_AVAILABLE: - x = flash_attn.flash_attn_varlen_func( - q=q, - k=k, - v=v, - cu_seqlens_q=torch.cat([q_lens.new_zeros([1]), q_lens]).cumsum( - 0, dtype=torch.int32).to(q.device, non_blocking=True), - cu_seqlens_k=torch.cat([k_lens.new_zeros([1]), k_lens]).cumsum( - 0, dtype=torch.int32).to(q.device, non_blocking=True), - max_seqlen_q=lq, - max_seqlen_k=lk, - dropout_p=dropout_p, - softmax_scale=softmax_scale, - causal=causal, - window_size=window_size, - deterministic=deterministic).unflatten(0, (b, lq)) + q = rearrange(q, "b s (n d) -> b s n d", n=num_heads) + k = rearrange(k, "b s (n d) -> b s n d", n=num_heads) + v = rearrange(v, "b s (n d) -> b s n d", n=num_heads) + x = flash_attn.flash_attn_func(q, k, v) + x = rearrange(x, "b s n d -> b s (n d)", n=num_heads) elif SAGE_ATTN_AVAILABLE: - q = q.unsqueeze(0).transpose(1, 2).to(dtype) - k = k.unsqueeze(0).transpose(1, 2).to(dtype) - v = v.unsqueeze(0).transpose(1, 2).to(dtype) - x = sageattn(q, k, v, dropout_p=dropout_p, is_causal=causal) - x = x.transpose(1, 2).contiguous() + q = rearrange(q, "b s (n d) -> b n s d", 
n=num_heads) + k = rearrange(k, "b s (n d) -> b n s d", n=num_heads) + v = rearrange(v, "b s (n d) -> b n s d", n=num_heads) + x = sageattn(q, k, v) + x = rearrange(x, "b n s d -> b s (n d)", n=num_heads) else: - q = q.unsqueeze(0).transpose(1, 2).to(dtype) - k = k.unsqueeze(0).transpose(1, 2).to(dtype) - v = v.unsqueeze(0).transpose(1, 2).to(dtype) - x = torch.nn.functional.scaled_dot_product_attention(q, k, v) - x = x.transpose(1, 2).contiguous() - - # output - return x.type(out_dtype) - - -def create_sdpa_mask(q, k, q_lens, k_lens, causal=False): - b, lq, lk = q.size(0), q.size(1), k.size(1) - if q_lens is None: - q_lens = torch.tensor([lq] * b, dtype=torch.int32) - if k_lens is None: - k_lens = torch.tensor([lk] * b, dtype=torch.int32) - attn_mask = torch.zeros((b, lq, lk), dtype=torch.bool) - for i in range(b): - q_len, k_len = q_lens[i], k_lens[i] - attn_mask[i, q_len:, :] = True - attn_mask[i, :, k_len:] = True - - if causal: - causal_mask = torch.triu(torch.ones((lq, lk), dtype=torch.bool), diagonal=1) - attn_mask[i, :, :] = torch.logical_or(attn_mask[i, :, :], causal_mask) - - attn_mask = attn_mask.logical_not().to(q.device, non_blocking=True) - return attn_mask - - -def attention( - q, - k, - v, - q_lens=None, - k_lens=None, - dropout_p=0., - softmax_scale=None, - q_scale=None, - causal=False, - window_size=(-1, -1), - deterministic=False, - dtype=torch.bfloat16, - fa_version=None, -): - if FLASH_ATTN_2_AVAILABLE or FLASH_ATTN_3_AVAILABLE: - return flash_attention( - q=q, - k=k, - v=v, - q_lens=q_lens, - k_lens=k_lens, - dropout_p=dropout_p, - softmax_scale=softmax_scale, - q_scale=q_scale, - causal=causal, - window_size=window_size, - deterministic=deterministic, - dtype=dtype, - version=fa_version, - ) - else: - if q_lens is not None or k_lens is not None: - warnings.warn('Padding mask is disabled when using scaled_dot_product_attention. 
It can have a significant impact on performance.') - attn_mask = None - - q = q.transpose(1, 2).to(dtype) - k = k.transpose(1, 2).to(dtype) - v = v.transpose(1, 2).to(dtype) - - out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, is_causal=causal, dropout_p=dropout_p) - - out = out.transpose(1, 2).contiguous() - return out - - - -def sinusoidal_embedding_1d(dim, position): - # preprocess - assert dim % 2 == 0 - half = dim // 2 - position = position.type(torch.float64) - - # calculation - sinusoid = torch.outer( - position, torch.pow(10000, -torch.arange(half).to(position).div(half))) - x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1) + q = rearrange(q, "b s (n d) -> b n s d", n=num_heads) + k = rearrange(k, "b s (n d) -> b n s d", n=num_heads) + v = rearrange(v, "b s (n d) -> b n s d", n=num_heads) + x = F.scaled_dot_product_attention(q, k, v) + x = rearrange(x, "b n s d -> b s (n d)", n=num_heads) return x -@amp.autocast(enabled=False, device_type="cuda") -def rope_params(max_seq_len, dim, theta=10000): - assert dim % 2 == 0 - freqs = torch.outer( - torch.arange(max_seq_len), - 1.0 / torch.pow(theta, - torch.arange(0, dim, 2).to(torch.float64).div(dim))) - freqs = torch.polar(torch.ones_like(freqs), freqs) - return freqs +def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor): + return (x * (1 + scale) + shift) -@amp.autocast(enabled=False, device_type="cuda") -def rope_apply(x, grid_sizes, freqs): - n, c = x.size(2), x.size(3) // 2 - - # split freqs - freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1) - - # loop over samples - output = [] - for i, (f, h, w) in enumerate(grid_sizes.tolist()): - seq_len = f * h * w - - # precompute multipliers - x_i = torch.view_as_complex(x[i, :seq_len].to(torch.float64).reshape( - seq_len, n, -1, 2)) - freqs_i = torch.cat([ - freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1), - freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1), - freqs[2][:w].view(1, 
1, w, -1).expand(f, h, w, -1) - ], - dim=-1).reshape(seq_len, 1, -1) - - # apply rotary embedding - x_i = torch.view_as_real(x_i * freqs_i).flatten(2) - x_i = torch.cat([x_i, x[i, seq_len:]]) - - # append to collection - output.append(x_i) - return torch.stack(output).float() +def sinusoidal_embedding_1d(dim, position): + sinusoid = torch.outer(position.type(torch.float64), torch.pow( + 10000, -torch.arange(dim//2, dtype=torch.float64, device=position.device).div(dim//2))) + x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1) + return x.to(position.dtype) -class WanRMSNorm(nn.Module): +def precompute_freqs_cis_3d(dim: int, end: int = 1024, theta: float = 10000.0): + # 3d rope precompute + f_freqs_cis = precompute_freqs_cis(dim - 2 * (dim // 3), end, theta) + h_freqs_cis = precompute_freqs_cis(dim // 3, end, theta) + w_freqs_cis = precompute_freqs_cis(dim // 3, end, theta) + return f_freqs_cis, h_freqs_cis, w_freqs_cis + +def precompute_freqs_cis(dim: int, end: int = 1024, theta: float = 10000.0): + # 1d rope precompute + freqs = 1.0 / (theta ** (torch.arange(0, dim, 2) + [: (dim // 2)].double() / dim)) + freqs = torch.outer(torch.arange(end, device=freqs.device), freqs) + freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 + return freqs_cis + + +def rope_apply(x, freqs, num_heads): + x = rearrange(x, "b s (n d) -> b s n d", n=num_heads) + x_out = torch.view_as_complex(x.to(torch.float64).reshape( + x.shape[0], x.shape[1], x.shape[2], -1, 2)) + x_out = torch.view_as_real(x_out * freqs).flatten(2) + return x_out.to(x.dtype) + + +class RMSNorm(nn.Module): def __init__(self, dim, eps=1e-5): super().__init__() - self.dim = dim self.eps = eps self.weight = nn.Parameter(torch.ones(dim)) - def forward(self, x): - return self._norm(x.float()).type_as(x) * self.weight - - def _norm(self, x): + def norm(self, x): return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps) - -class WanLayerNorm(nn.LayerNorm): - - def __init__(self, dim, 
eps=1e-6, elementwise_affine=False): - super().__init__(dim, elementwise_affine=elementwise_affine, eps=eps) - def forward(self, x): - return super().forward(x.float()).type_as(x) + dtype = x.dtype + return self.norm(x.float()).to(dtype) * self.weight -class WanSelfAttention(nn.Module): - - def __init__(self, - dim, - num_heads, - window_size=(-1, -1), - qk_norm=True, - eps=1e-6): - assert dim % num_heads == 0 +class SelfAttention(nn.Module): + def __init__(self, dim: int, num_heads: int, eps: float = 1e-6): super().__init__() self.dim = dim self.num_heads = num_heads self.head_dim = dim // num_heads - self.window_size = window_size - self.qk_norm = qk_norm - self.eps = eps - # layers self.q = nn.Linear(dim, dim) self.k = nn.Linear(dim, dim) self.v = nn.Linear(dim, dim) self.o = nn.Linear(dim, dim) - self.norm_q = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity() - self.norm_k = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity() - - def forward(self, x, seq_lens, grid_sizes, freqs): - b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim - - # query, key, value function - def qkv_fn(x): - q = self.norm_q(self.q(x)).view(b, s, n, d) - k = self.norm_k(self.k(x)).view(b, s, n, d) - v = self.v(x).view(b, s, n, d) - return q, k, v - - q, k, v = qkv_fn(x) + self.norm_q = RMSNorm(dim, eps=eps) + self.norm_k = RMSNorm(dim, eps=eps) + def forward(self, x, freqs): + q = self.norm_q(self.q(x)) + k = self.norm_k(self.k(x)) + v = self.v(x) x = flash_attention( - q=rope_apply(q, grid_sizes, freqs), - k=rope_apply(k, grid_sizes, freqs), + q=rope_apply(q, freqs, self.num_heads), + k=rope_apply(k, freqs, self.num_heads), v=v, - k_lens=seq_lens, - window_size=self.window_size) - - # output - x = x.flatten(2) - x = self.o(x) - return x + num_heads=self.num_heads + ) + return self.o(x) -class WanT2VCrossAttention(WanSelfAttention): - - def forward(self, x, context, context_lens): - """ - x: [B, L1, C]. - context: [B, L2, C]. - context_lens: [B]. 
- """ - b, n, d = x.size(0), self.num_heads, self.head_dim - - # compute query, key, value - q = self.norm_q(self.q(x)).view(b, -1, n, d) - k = self.norm_k(self.k(context)).view(b, -1, n, d) - v = self.v(context).view(b, -1, n, d) - - # compute attention - x = flash_attention(q, k, v, k_lens=context_lens) - - # output - x = x.flatten(2) - x = self.o(x) - return x - - -class WanI2VCrossAttention(WanSelfAttention): - - def __init__(self, - dim, - num_heads, - window_size=(-1, -1), - qk_norm=True, - eps=1e-6): - super().__init__(dim, num_heads, window_size, qk_norm, eps) - - self.k_img = nn.Linear(dim, dim) - self.v_img = nn.Linear(dim, dim) - # self.alpha = nn.Parameter(torch.zeros((1, ))) - self.norm_k_img = WanRMSNorm( - dim, eps=eps) if qk_norm else nn.Identity() - - def forward(self, x, context, context_lens): - """ - x: [B, L1, C]. - context: [B, L2, C]. - context_lens: [B]. - """ - context_img = context[:, :257] - context = context[:, 257:] - b, n, d = x.size(0), self.num_heads, self.head_dim - - # compute query, key, value - q = self.norm_q(self.q(x)).view(b, -1, n, d) - k = self.norm_k(self.k(context)).view(b, -1, n, d) - v = self.v(context).view(b, -1, n, d) - k_img = self.norm_k_img(self.k_img(context_img)).view(b, -1, n, d) - v_img = self.v_img(context_img).view(b, -1, n, d) - img_x = flash_attention(q, k_img, v_img, k_lens=None) - # compute attention - x = flash_attention(q, k, v, k_lens=context_lens) - - # output - x = x.flatten(2) - img_x = img_x.flatten(2) - x = x + img_x - x = self.o(x) - return x - - -WANX_CROSSATTENTION_CLASSES = { - 't2v_cross_attn': WanT2VCrossAttention, - 'i2v_cross_attn': WanI2VCrossAttention, -} - - -class WanAttentionBlock(nn.Module): - - def __init__(self, - cross_attn_type, - dim, - ffn_dim, - num_heads, - window_size=(-1, -1), - qk_norm=True, - cross_attn_norm=False, - eps=1e-6): +class CrossAttention(nn.Module): + def __init__(self, dim: int, num_heads: int, eps: float = 1e-6, has_image_input: bool = False): 
super().__init__() self.dim = dim - self.ffn_dim = ffn_dim self.num_heads = num_heads - self.window_size = window_size - self.qk_norm = qk_norm - self.cross_attn_norm = cross_attn_norm - self.eps = eps + self.head_dim = dim // num_heads - # layers - self.norm1 = WanLayerNorm(dim, eps) - self.self_attn = WanSelfAttention(dim, num_heads, window_size, qk_norm, - eps) - self.norm3 = WanLayerNorm( - dim, eps, - elementwise_affine=True) if cross_attn_norm else nn.Identity() - self.cross_attn = WANX_CROSSATTENTION_CLASSES[cross_attn_type]( - dim, num_heads, (-1, -1), qk_norm, eps) - self.norm2 = WanLayerNorm(dim, eps) - self.ffn = nn.Sequential( - nn.Linear(dim, ffn_dim), nn.GELU(approximate='tanh'), - nn.Linear(ffn_dim, dim)) + self.q = nn.Linear(dim, dim) + self.k = nn.Linear(dim, dim) + self.v = nn.Linear(dim, dim) + self.o = nn.Linear(dim, dim) + self.norm_q = RMSNorm(dim, eps=eps) + self.norm_k = RMSNorm(dim, eps=eps) + self.has_image_input = has_image_input + if has_image_input: + self.k_img = nn.Linear(dim, dim) + self.v_img = nn.Linear(dim, dim) + self.norm_k_img = RMSNorm(dim, eps=eps) - # modulation + def forward(self, x: torch.Tensor, y: torch.Tensor): + if self.has_image_input: + img = y[:, :257] + ctx = y[:, 257:] + else: + ctx = y + q = self.norm_q(self.q(x)) + k = self.norm_k(self.k(ctx)) + v = self.v(ctx) + x = flash_attention(q, k, v, num_heads=self.num_heads) + if self.has_image_input: + k_img = self.norm_k_img(self.k_img(img)) + v_img = self.v_img(img) + y = flash_attention(q, k_img, v_img, num_heads=self.num_heads) + x = x + y + return self.o(x) + + +class DiTBlock(nn.Module): + def __init__(self, has_image_input: bool, dim: int, num_heads: int, ffn_dim: int, eps: float = 1e-6): + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.ffn_dim = ffn_dim + + self.self_attn = SelfAttention(dim, num_heads, eps) + self.cross_attn = CrossAttention( + dim, num_heads, eps, has_image_input=has_image_input) + self.norm1 = nn.LayerNorm(dim, 
eps=eps, elementwise_affine=False) + self.norm2 = nn.LayerNorm(dim, eps=eps, elementwise_affine=False) + self.norm3 = nn.LayerNorm(dim, eps=eps) + self.ffn = nn.Sequential(nn.Linear(dim, ffn_dim), nn.GELU( + approximate='tanh'), nn.Linear(ffn_dim, dim)) self.modulation = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5) - def forward( - self, - x, - e, - seq_lens, - grid_sizes, - freqs, - context, - context_lens, - ): - assert e.dtype == torch.float32 - with amp.autocast(dtype=torch.float32, device_type="cuda"): - e = (self.modulation.to(dtype=e.dtype, device=e.device) + e).chunk(6, dim=1) - assert e[0].dtype == torch.float32 - - # self-attention - y = self.self_attn( - self.norm1(x).float() * (1 + e[1]) + e[0], seq_lens, grid_sizes, - freqs) - with amp.autocast(dtype=torch.float32, device_type="cuda"): - x = x + y * e[2] - - # cross-attention & ffn function - def cross_attn_ffn(x, context, context_lens, e): - x = x + self.cross_attn(self.norm3(x), context, context_lens) - y = self.ffn(self.norm2(x).float() * (1 + e[4]) + e[3]) - with amp.autocast(dtype=torch.float32, device_type="cuda"): - x = x + y * e[5] - return x - - x = cross_attn_ffn(x, context, context_lens, e) + def forward(self, x, context, t_mod, freqs): + # msa: multi-head self-attention mlp: multi-layer perceptron + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ( + self.modulation.to(dtype=t_mod.dtype, device=t_mod.device) + t_mod).chunk(6, dim=1) + input_x = modulate(self.norm1(x), shift_msa, scale_msa) + x = x + gate_msa * self.self_attn(input_x, freqs) + x = x + self.cross_attn(self.norm3(x), context) + input_x = modulate(self.norm2(x), shift_mlp, scale_mlp) + x = x + gate_mlp * self.ffn(input_x) return x +class MLP(torch.nn.Module): + def __init__(self, in_dim, out_dim): + super().__init__() + self.proj = torch.nn.Sequential( + nn.LayerNorm(in_dim), + nn.Linear(in_dim, in_dim), + nn.GELU(), + nn.Linear(in_dim, out_dim), + nn.LayerNorm(out_dim) + ) + + def forward(self, x): + return 
self.proj(x) + + class Head(nn.Module): - - def __init__(self, dim, out_dim, patch_size, eps=1e-6): + def __init__(self, dim: int, out_dim: int, patch_size: Tuple[int, int, int], eps: float): super().__init__() self.dim = dim - self.out_dim = out_dim self.patch_size = patch_size - self.eps = eps - - # layers - out_dim = math.prod(patch_size) * out_dim - self.norm = WanLayerNorm(dim, eps) - self.head = nn.Linear(dim, out_dim) - - # modulation + self.norm = nn.LayerNorm(dim, eps=eps, elementwise_affine=False) + self.head = nn.Linear(dim, out_dim * math.prod(patch_size)) self.modulation = nn.Parameter(torch.randn(1, 2, dim) / dim**0.5) - def forward(self, x, e): - assert e.dtype == torch.float32 - with amp.autocast(dtype=torch.float32, device_type="cuda"): - e = (self.modulation.to(dtype=e.dtype, device=e.device) + e.unsqueeze(1)).chunk(2, dim=1) - x = (self.head(self.norm(x) * (1 + e[1]) + e[0])) + def forward(self, x, t_mod): + shift, scale = (self.modulation.to(dtype=t_mod.dtype, device=t_mod.device) + t_mod).chunk(2, dim=1) + x = (self.head(self.norm(x) * (1 + scale) + shift)) return x -class MLPProj(torch.nn.Module): - - def __init__(self, in_dim, out_dim): +class WanModel(torch.nn.Module): + def __init__( + self, + dim: int, + in_dim: int, + ffn_dim: int, + out_dim: int, + text_dim: int, + freq_dim: int, + eps: float, + patch_size: Tuple[int, int, int], + num_heads: int, + num_layers: int, + has_image_input: bool, + ): super().__init__() - - self.proj = torch.nn.Sequential( - torch.nn.LayerNorm(in_dim), torch.nn.Linear(in_dim, in_dim), - torch.nn.GELU(), torch.nn.Linear(in_dim, out_dim), - torch.nn.LayerNorm(out_dim)) - - def forward(self, image_embeds): - clip_extra_context_tokens = self.proj(image_embeds) - return clip_extra_context_tokens - - -class WanModel(nn.Module): - - def __init__(self, - model_type='t2v', - patch_size=(1, 2, 2), - text_len=512, - in_dim=16, - dim=2048, - ffn_dim=8192, - freq_dim=256, - text_dim=4096, - out_dim=16, - num_heads=16, - 
num_layers=32, - window_size=(-1, -1), - qk_norm=True, - cross_attn_norm=False, - eps=1e-6): - super().__init__() - - assert model_type in ['t2v', 'i2v'] - self.model_type = model_type - - self.patch_size = patch_size - self.text_len = text_len - self.in_dim = in_dim self.dim = dim - self.ffn_dim = ffn_dim self.freq_dim = freq_dim - self.text_dim = text_dim - self.out_dim = out_dim - self.num_heads = num_heads - self.num_layers = num_layers - self.window_size = window_size - self.qk_norm = qk_norm - self.cross_attn_norm = cross_attn_norm - self.eps = eps + self.has_image_input = has_image_input + self.patch_size = patch_size - # embeddings self.patch_embedding = nn.Conv3d( in_dim, dim, kernel_size=patch_size, stride=patch_size) self.text_embedding = nn.Sequential( - nn.Linear(text_dim, dim), nn.GELU(approximate='tanh'), - nn.Linear(dim, dim)) - + nn.Linear(text_dim, dim), + nn.GELU(approximate='tanh'), + nn.Linear(dim, dim) + ) self.time_embedding = nn.Sequential( - nn.Linear(freq_dim, dim), nn.SiLU(), nn.Linear(dim, dim)) - self.time_projection = nn.Sequential(nn.SiLU(), nn.Linear(dim, dim * 6)) - - # blocks - cross_attn_type = 't2v_cross_attn' if model_type == 't2v' else 'i2v_cross_attn' + nn.Linear(freq_dim, dim), + nn.SiLU(), + nn.Linear(dim, dim) + ) + self.time_projection = nn.Sequential( + nn.SiLU(), nn.Linear(dim, dim * 6)) self.blocks = nn.ModuleList([ - WanAttentionBlock(cross_attn_type, dim, ffn_dim, num_heads, - window_size, qk_norm, cross_attn_norm, eps) + DiTBlock(has_image_input, dim, num_heads, ffn_dim, eps) for _ in range(num_layers) ]) - - # head self.head = Head(dim, out_dim, patch_size, eps) + head_dim = dim // num_heads + self.freqs = precompute_freqs_cis_3d(head_dim) - # buffers (don't use register_buffer otherwise dtype will be changed in to()) - assert (dim % num_heads) == 0 and (dim // num_heads) % 2 == 0 - d = dim // num_heads - self.freqs = torch.cat([ - rope_params(1024, d - 4 * (d // 6)), - rope_params(1024, 2 * (d // 6)), - 
rope_params(1024, 2 * (d // 6)) - ], - dim=1) + if has_image_input: + self.img_emb = MLP(1280, dim) # clip_feature_dim = 1280 - if model_type == 'i2v': - self.img_emb = MLPProj(1280, dim) + def patchify(self, x: torch.Tensor): + x = self.patch_embedding(x) + grid_size = x.shape[2:] + x = rearrange(x, 'b c f h w -> b (f h w) c').contiguous() + return x, grid_size # x, grid_size: (f, h, w) - # initialize weights - self.init_weights() + def unpatchify(self, x: torch.Tensor, grid_size: torch.Tensor): + return rearrange( + x, 'b (f h w) (x y z c) -> b c (f x) (h y) (w z)', + f=grid_size[0], h=grid_size[1], w=grid_size[2], + x=self.patch_size[0], y=self.patch_size[1], z=self.patch_size[2] + ) - def forward( - self, - x, - timestep, - context, - seq_len, - clip_fea=None, - y=None, - use_gradient_checkpointing=False, - **kwargs, - ): - """ - x: A list of videos each with shape [C, T, H, W]. - t: [B]. - context: A list of text embeddings each with shape [L, C]. - """ - if self.model_type == 'i2v': - assert clip_fea is not None and y is not None - # params - device = x[0].device - if self.freqs.device != device: - self.freqs = self.freqs.to(device) - - if y is not None: - x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)] - - # embeddings - x = [self.patch_embedding(u.unsqueeze(0)) for u in x] - grid_sizes = torch.stack( - [torch.tensor(u.shape[2:], dtype=torch.long) for u in x]) - x = [u.flatten(2).transpose(1, 2) for u in x] - seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long) - assert seq_lens.max() <= seq_len - x = torch.cat([ - torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))], - dim=1) for u in x - ]) - - # time embeddings - with amp.autocast(dtype=torch.float32, device_type="cuda"): - e = self.time_embedding( - sinusoidal_embedding_1d(self.freq_dim, timestep).float()) - e0 = self.time_projection(e).unflatten(1, (6, self.dim)) - assert e.dtype == torch.float32 and e0.dtype == torch.float32 - - # context - context_lens = None - context = 
self.text_embedding( - torch.stack([ - torch.cat( - [u, u.new_zeros(self.text_len - u.size(0), u.size(1))]) - for u in context - ])) - - if clip_fea is not None: - context_clip = self.img_emb(clip_fea) # bs x 257 x dim - context = torch.concat([context_clip, context], dim=1) - - # arguments - kwargs = dict( - e=e0, - seq_lens=seq_lens, - grid_sizes=grid_sizes, - freqs=self.freqs, - context=context, - context_lens=context_lens) + def forward(self, + x: torch.Tensor, + timestep: torch.Tensor, + context: torch.Tensor, + clip_feature: Optional[torch.Tensor] = None, + y: Optional[torch.Tensor] = None, + use_gradient_checkpointing: bool = False, + **kwargs, + ): + t = self.time_embedding( + sinusoidal_embedding_1d(self.freq_dim, timestep)) + t_mod = self.time_projection(t).unflatten(1, (6, self.dim)) + context = self.text_embedding(context) + if self.has_image_input: + x = torch.cat([x, y], dim=1) # (b, c_x + c_y, f, h, w) + clip_embdding = self.img_emb(clip_feature) + context = torch.cat([clip_embdding, context], dim=1) + x, (f, h, w) = self.patchify(x) + freqs = torch.cat([ + self.freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1), + self.freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1), + self.freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1) + ], dim=-1).reshape(f * h * w, 1, -1).to(x.device) def create_custom_forward(module): - def custom_forward(*inputs, **kwargs): - return module(*inputs, **kwargs) + def custom_forward(*inputs): + return module(*inputs) return custom_forward for block in self.blocks: if self.training and use_gradient_checkpointing: x = torch.utils.checkpoint.checkpoint( create_custom_forward(block), - x, **kwargs, + x, context, t_mod, freqs, use_reentrant=False, ) else: - x = block(x, **kwargs) + x = block(x, context, t_mod, freqs) - # head - x = self.head(x, e) - - # unpatchify - x = self.unpatchify(x, grid_sizes) - x = torch.stack(x).float() + x = self.head(x, t) + x = self.unpatchify(x, (f, h, w)) return x - def unpatchify(self, x, 
grid_sizes): - c = self.out_dim - out = [] - for u, v in zip(x, grid_sizes.tolist()): - u = u[:math.prod(v)].view(*v, *self.patch_size, c) - u = torch.einsum('fhwpqrc->cfphqwr', u) - u = u.reshape(c, *[i * j for i, j in zip(v, self.patch_size)]) - out.append(u) - return out - - def init_weights(self): - # basic init - for m in self.modules(): - if isinstance(m, nn.Linear): - nn.init.xavier_uniform_(m.weight) - if m.bias is not None: - nn.init.zeros_(m.bias) - - # init embeddings - nn.init.xavier_uniform_(self.patch_embedding.weight.flatten(1)) - for m in self.text_embedding.modules(): - if isinstance(m, nn.Linear): - nn.init.normal_(m.weight, std=.02) - for m in self.time_embedding.modules(): - if isinstance(m, nn.Linear): - nn.init.normal_(m.weight, std=.02) - - # init output layer - nn.init.zeros_(self.head.head.weight) - @staticmethod def state_dict_converter(): return WanModelStateDictConverter() @@ -737,7 +337,8 @@ class WanModelStateDictConverter: pass def from_diffusers(self, state_dict): - rename_dict = {"blocks.0.attn1.norm_k.weight": "blocks.0.self_attn.norm_k.weight", + rename_dict = { + "blocks.0.attn1.norm_k.weight": "blocks.0.self_attn.norm_k.weight", "blocks.0.attn1.norm_q.weight": "blocks.0.self_attn.norm_q.weight", "blocks.0.attn1.to_k.bias": "blocks.0.self_attn.k.bias", "blocks.0.attn1.to_k.weight": "blocks.0.self_attn.k.weight", @@ -815,9 +416,8 @@ class WanModelStateDictConverter: def from_civitai(self, state_dict): if hash_state_dict_keys(state_dict) == "9269f8db9040a9d860eaca435be61814": config = { - "model_type": "t2v", - "patch_size": (1, 2, 2), - "text_len": 512, + "has_image_input": False, + "patch_size": [1, 2, 2], "in_dim": 16, "dim": 1536, "ffn_dim": 8960, @@ -826,16 +426,12 @@ class WanModelStateDictConverter: "out_dim": 16, "num_heads": 12, "num_layers": 30, - "window_size": (-1, -1), - "qk_norm": True, - "cross_attn_norm": True, - "eps": 1e-6, + "eps": 1e-6 } elif hash_state_dict_keys(state_dict) == 
"aafcfd9672c3a2456dc46e1cb6e52c70": config = { - "model_type": "t2v", - "patch_size": (1, 2, 2), - "text_len": 512, + "has_image_input": False, + "patch_size": [1, 2, 2], "in_dim": 16, "dim": 5120, "ffn_dim": 13824, @@ -844,16 +440,12 @@ class WanModelStateDictConverter: "out_dim": 16, "num_heads": 40, "num_layers": 40, - "window_size": (-1, -1), - "qk_norm": True, - "cross_attn_norm": True, - "eps": 1e-6, + "eps": 1e-6 } elif hash_state_dict_keys(state_dict) == "6bfcfb3b342cb286ce886889d519a77e": config = { - "model_type": "i2v", - "patch_size": (1, 2, 2), - "text_len": 512, + "has_image_input": True, + "patch_size": [1, 2, 2], "in_dim": 36, "dim": 5120, "ffn_dim": 13824, @@ -862,10 +454,7 @@ class WanModelStateDictConverter: "out_dim": 16, "num_heads": 40, "num_layers": 40, - "window_size": (-1, -1), - "qk_norm": True, - "cross_attn_norm": True, - "eps": 1e-6, + "eps": 1e-6 } else: config = {} diff --git a/diffsynth/pipelines/wan_video.py b/diffsynth/pipelines/wan_video.py index 7a864e3..45ef3b3 100644 --- a/diffsynth/pipelines/wan_video.py +++ b/diffsynth/pipelines/wan_video.py @@ -14,7 +14,7 @@ from tqdm import tqdm from ..vram_management import enable_vram_management, AutoWrappedModule, AutoWrappedLinear from ..models.wan_video_text_encoder import T5RelativeEmbedding, T5LayerNorm -from ..models.wan_video_dit import WanLayerNorm, WanRMSNorm +from ..models.wan_video_dit import RMSNorm from ..models.wan_video_vae import RMS_norm, CausalConv3d, Upsample @@ -60,8 +60,8 @@ class WanVideoPipeline(BasePipeline): torch.nn.Linear: AutoWrappedLinear, torch.nn.Conv3d: AutoWrappedModule, torch.nn.LayerNorm: AutoWrappedModule, - WanLayerNorm: AutoWrappedModule, - WanRMSNorm: AutoWrappedModule, + torch.nn.LayerNorm: AutoWrappedModule, + RMSNorm: AutoWrappedModule, }, module_config = dict( offload_dtype=dtype, @@ -224,7 +224,8 @@ class WanVideoPipeline(BasePipeline): self.scheduler.set_timesteps(num_inference_steps, denoising_strength, shift=sigma_shift) # Initialize noise - 
noise = self.generate_noise((1, 16, (num_frames - 1) // 4 + 1, height//8, width//8), seed=seed, device=rand_device, dtype=torch.float32).to(self.device) + noise = self.generate_noise((1, 16, (num_frames - 1) // 4 + 1, height//8, width//8), seed=seed, device=rand_device, dtype=torch.float32) + noise = noise.to(dtype=self.torch_dtype, device=self.device) if input_video is not None: self.load_models_to_device(['vae']) input_video = self.preprocess_images(input_video) @@ -252,20 +253,19 @@ class WanVideoPipeline(BasePipeline): # Denoise self.load_models_to_device(["dit"]) - with torch.amp.autocast(dtype=torch.bfloat16, device_type=torch.device(self.device).type): - for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)): - timestep = timestep.unsqueeze(0).to(dtype=torch.float32, device=self.device) + for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)): + timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device) - # Inference - noise_pred_posi = self.dit(latents, timestep=timestep, **prompt_emb_posi, **image_emb, **extra_input) - if cfg_scale != 1.0: - noise_pred_nega = self.dit(latents, timestep=timestep, **prompt_emb_nega, **image_emb, **extra_input) - noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega) - else: - noise_pred = noise_pred_posi + # Inference + noise_pred_posi = self.dit(latents, timestep=timestep, **prompt_emb_posi, **image_emb, **extra_input) + if cfg_scale != 1.0: + noise_pred_nega = self.dit(latents, timestep=timestep, **prompt_emb_nega, **image_emb, **extra_input) + noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega) + else: + noise_pred = noise_pred_posi - # Scheduler - latents = self.scheduler.step(noise_pred, self.scheduler.timesteps[progress_id], latents) + # Scheduler + latents = self.scheduler.step(noise_pred, self.scheduler.timesteps[progress_id], latents) # Decode self.load_models_to_device(['vae']) diff --git 
a/diffsynth/prompters/wan_prompter.py b/diffsynth/prompters/wan_prompter.py index f8c924a..01a765d 100644 --- a/diffsynth/prompters/wan_prompter.py +++ b/diffsynth/prompters/wan_prompter.py @@ -104,5 +104,6 @@ class WanPrompter(BasePrompter): mask = mask.to(device) seq_lens = mask.gt(0).sum(dim=1).long() prompt_emb = self.text_encoder(ids, mask) - prompt_emb = [u[:v] for u, v in zip(prompt_emb, seq_lens)] + for i, v in enumerate(seq_lens): + prompt_emb[:, v:] = 0 return prompt_emb From a05f6476331bfc072b1638ffc3f22a65b91bbf9f Mon Sep 17 00:00:00 2001 From: Artiprocher Date: Mon, 10 Mar 2025 17:11:11 +0800 Subject: [PATCH 2/4] vram optimization --- diffsynth/models/wan_video_dit.py | 22 ++++++++--- diffsynth/models/wan_video_image_encoder.py | 18 ++++----- diffsynth/models/wan_video_vae.py | 7 ++-- diffsynth/pipelines/wan_video.py | 41 +++++++++---------- examples/wanvideo/train_wan_t2v.py | 44 +++++++++++++++------ examples/wanvideo/wan_14b_image_to_video.py | 2 +- 6 files changed, 83 insertions(+), 51 deletions(-) diff --git a/diffsynth/models/wan_video_dit.py b/diffsynth/models/wan_video_dit.py index ff9ce50..f1e5e47 100644 --- a/diffsynth/models/wan_video_dit.py +++ b/diffsynth/models/wan_video_dit.py @@ -291,17 +291,21 @@ class WanModel(torch.nn.Module): clip_feature: Optional[torch.Tensor] = None, y: Optional[torch.Tensor] = None, use_gradient_checkpointing: bool = False, + use_gradient_checkpointing_offload: bool = False, **kwargs, ): t = self.time_embedding( sinusoidal_embedding_1d(self.freq_dim, timestep)) t_mod = self.time_projection(t).unflatten(1, (6, self.dim)) context = self.text_embedding(context) + if self.has_image_input: x = torch.cat([x, y], dim=1) # (b, c_x + c_y, f, h, w) clip_embdding = self.img_emb(clip_feature) context = torch.cat([clip_embdding, context], dim=1) + x, (f, h, w) = self.patchify(x) + freqs = torch.cat([ self.freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1), self.freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1), @@ -315,11 
+319,19 @@ class WanModel(torch.nn.Module): for block in self.blocks: if self.training and use_gradient_checkpointing: - x = torch.utils.checkpoint.checkpoint( - create_custom_forward(block), - x, context, t_mod, freqs, - use_reentrant=False, - ) + if use_gradient_checkpointing_offload: + with torch.autograd.graph.save_on_cpu(): + x = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + x, context, t_mod, freqs, + use_reentrant=False, + ) + else: + x = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + x, context, t_mod, freqs, + use_reentrant=False, + ) else: x = block(x, context, t_mod, freqs) diff --git a/diffsynth/models/wan_video_image_encoder.py b/diffsynth/models/wan_video_image_encoder.py index 35f5ea3..b49235b 100644 --- a/diffsynth/models/wan_video_image_encoder.py +++ b/diffsynth/models/wan_video_image_encoder.py @@ -228,7 +228,7 @@ class QuickGELU(nn.Module): class LayerNorm(nn.LayerNorm): def forward(self, x): - return super().forward(x.float()).type_as(x) + return super().forward(x).type_as(x) class SelfAttention(nn.Module): @@ -256,15 +256,11 @@ class SelfAttention(nn.Module): """ x: [B, L, C]. 
""" - b, s, c, n, d = *x.size(), self.num_heads, self.head_dim - # compute query, key, value - q, k, v = self.to_qkv(x).view(b, s, 3, n, d).unbind(2) + q, k, v = self.to_qkv(x).chunk(3, dim=-1) # compute attention - p = self.attn_dropout if self.training else 0.0 - x = flash_attention(q, k, v, dropout_p=p, causal=self.causal, version=2) - x = x.reshape(b, s, c) + x = flash_attention(q, k, v, num_heads=self.num_heads) # output x = self.proj(x) @@ -371,11 +367,11 @@ class AttentionPool(nn.Module): b, s, c, n, d = *x.size(), self.num_heads, self.head_dim # compute query, key, value - q = self.to_q(self.cls_embedding).view(1, 1, n, d).expand(b, -1, -1, -1) - k, v = self.to_kv(x).view(b, s, 2, n, d).unbind(2) + q = self.to_q(self.cls_embedding).view(1, 1, n*d).expand(b, -1, -1) + k, v = self.to_kv(x).chunk(2, dim=-1) # compute attention - x = flash_attention(q, k, v, version=2) + x = flash_attention(q, k, v, num_heads=self.num_heads) x = x.reshape(b, 1, c) # output @@ -878,6 +874,8 @@ class WanImageEncoder(torch.nn.Module): videos = self.transforms.transforms[-1](videos.mul_(0.5).add_(0.5)) # forward + dtype = next(iter(self.model.visual.parameters())).dtype + videos = videos.to(dtype) out = self.model.visual(videos, use_31_block=True) return out diff --git a/diffsynth/models/wan_video_vae.py b/diffsynth/models/wan_video_vae.py index 01b5484..df23076 100644 --- a/diffsynth/models/wan_video_vae.py +++ b/diffsynth/models/wan_video_vae.py @@ -688,7 +688,7 @@ class WanVideoVAE(nn.Module): target_w: target_w + hidden_states_batch.shape[4], ] += mask values = values / weight - values = values.float().clamp_(-1, 1) + values = values.clamp_(-1, 1) return values @@ -740,20 +740,19 @@ class WanVideoVAE(nn.Module): target_w: target_w + hidden_states_batch.shape[4], ] += mask values = values / weight - values = values.float() return values def single_encode(self, video, device): video = video.to(device) x = self.model.encode(video, self.scale) - return x.float() + return x def 
single_decode(self, hidden_state, device): hidden_state = hidden_state.to(device) video = self.model.decode(hidden_state, self.scale) - return video.float().clamp_(-1, 1) + return video.clamp_(-1, 1) def encode(self, videos, device, tiled=False, tile_size=(34, 34), tile_stride=(18, 16)): diff --git a/diffsynth/pipelines/wan_video.py b/diffsynth/pipelines/wan_video.py index 45ef3b3..2f19d42 100644 --- a/diffsynth/pipelines/wan_video.py +++ b/diffsynth/pipelines/wan_video.py @@ -60,7 +60,6 @@ class WanVideoPipeline(BasePipeline): torch.nn.Linear: AutoWrappedLinear, torch.nn.Conv3d: AutoWrappedModule, torch.nn.LayerNorm: AutoWrappedModule, - torch.nn.LayerNorm: AutoWrappedModule, RMSNorm: AutoWrappedModule, }, module_config = dict( @@ -116,7 +115,7 @@ class WanVideoPipeline(BasePipeline): offload_device="cpu", onload_dtype=dtype, onload_device="cpu", - computation_dtype=self.torch_dtype, + computation_dtype=dtype, computation_device=self.device, ), ) @@ -153,17 +152,21 @@ class WanVideoPipeline(BasePipeline): def encode_image(self, image, num_frames, height, width): - with torch.amp.autocast(dtype=torch.bfloat16, device_type=torch.device(self.device).type): - image = self.preprocess_image(image.resize((width, height))).to(self.device) - clip_context = self.image_encoder.encode_image([image]) - msk = torch.ones(1, num_frames, height//8, width//8, device=self.device) - msk[:, 1:] = 0 - msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1) - msk = msk.view(1, msk.shape[1] // 4, 4, height//8, width//8) - msk = msk.transpose(1, 2)[0] - y = self.vae.encode([torch.concat([image.transpose(0, 1), torch.zeros(3, num_frames-1, height, width).to(image.device)], dim=1)], device=self.device)[0] - y = torch.concat([msk, y]) - return {"clip_fea": clip_context, "y": [y]} + image = self.preprocess_image(image.resize((width, height))).to(self.device) + clip_context = self.image_encoder.encode_image([image]) + msk = torch.ones(1, num_frames, 
height//8, width//8, device=self.device) + msk[:, 1:] = 0 + msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1) + msk = msk.view(1, msk.shape[1] // 4, 4, height//8, width//8) + msk = msk.transpose(1, 2)[0] + + vae_input = torch.concat([image.transpose(0, 1), torch.zeros(3, num_frames-1, height, width).to(image.device)], dim=1) + y = self.vae.encode([vae_input.to(dtype=self.torch_dtype, device=self.device)], device=self.device)[0] + y = torch.concat([msk, y]) + y = y.unsqueeze(0) + clip_context = clip_context.to(dtype=self.torch_dtype, device=self.device) + y = y.to(dtype=self.torch_dtype, device=self.device) + return {"clip_feature": clip_context, "y": y} def tensor2video(self, frames): @@ -174,18 +177,16 @@ class WanVideoPipeline(BasePipeline): def prepare_extra_input(self, latents=None): - return {"seq_len": latents.shape[2] * latents.shape[3] * latents.shape[4] // 4} + return {} def encode_video(self, input_video, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)): - with torch.amp.autocast(dtype=torch.bfloat16, device_type=torch.device(self.device).type): - latents = self.vae.encode(input_video, device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride) + latents = self.vae.encode(input_video, device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride) return latents def decode_video(self, latents, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)): - with torch.amp.autocast(dtype=torch.bfloat16, device_type=torch.device(self.device).type): - frames = self.vae.decode(latents, device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride) + frames = self.vae.decode(latents, device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride) return frames @@ -229,8 +230,8 @@ class WanVideoPipeline(BasePipeline): if input_video is not None: self.load_models_to_device(['vae']) input_video = self.preprocess_images(input_video) - input_video = 
torch.stack(input_video, dim=2) - latents = self.encode_video(input_video, **tiler_kwargs).to(dtype=noise.dtype, device=noise.device) + input_video = torch.stack(input_video, dim=2).to(dtype=self.torch_dtype, device=self.device) + latents = self.encode_video(input_video, **tiler_kwargs).to(dtype=self.torch_dtype, device=self.device) latents = self.scheduler.add_noise(latents, noise, timestep=self.scheduler.timesteps[0]) else: latents = noise diff --git a/examples/wanvideo/train_wan_t2v.py b/examples/wanvideo/train_wan_t2v.py index c695622..f83e85e 100644 --- a/examples/wanvideo/train_wan_t2v.py +++ b/examples/wanvideo/train_wan_t2v.py @@ -113,6 +113,7 @@ class LightningModelForDataProcess(pl.LightningModule): self.pipe.device = self.device if video is not None: prompt_emb = self.pipe.encode_prompt(text) + video = video.to(dtype=self.pipe.torch_dtype, device=self.pipe.device) latents = self.pipe.encode_video(video, **self.tiler_kwargs)[0] data = {"latents": latents, "prompt_emb": prompt_emb} torch.save(data, path + ".tensors.pth") @@ -145,10 +146,21 @@ class TensorDataset(torch.utils.data.Dataset): class LightningModelForTrain(pl.LightningModule): - def __init__(self, dit_path, learning_rate=1e-5, lora_rank=4, lora_alpha=4, train_architecture="lora", lora_target_modules="q,k,v,o,ffn.0,ffn.2", init_lora_weights="kaiming", use_gradient_checkpointing=True, pretrained_lora_path=None): + def __init__( + self, + dit_path, + learning_rate=1e-5, + lora_rank=4, lora_alpha=4, train_architecture="lora", lora_target_modules="q,k,v,o,ffn.0,ffn.2", init_lora_weights="kaiming", + use_gradient_checkpointing=True, use_gradient_checkpointing_offload=False, + pretrained_lora_path=None + ): super().__init__() model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu") - model_manager.load_models([dit_path]) + if os.path.isfile(dit_path): + model_manager.load_models([dit_path]) + else: + dit_path = dit_path.split(",") + model_manager.load_models([dit_path]) self.pipe = 
WanVideoPipeline.from_model_manager(model_manager) self.pipe.scheduler.set_timesteps(1000, training=True) @@ -167,6 +179,7 @@ class LightningModelForTrain(pl.LightningModule): self.learning_rate = learning_rate self.use_gradient_checkpointing = use_gradient_checkpointing + self.use_gradient_checkpointing_offload = use_gradient_checkpointing_offload def freeze_parameters(self): @@ -210,24 +223,25 @@ class LightningModelForTrain(pl.LightningModule): # Data latents = batch["latents"].to(self.device) prompt_emb = batch["prompt_emb"] - prompt_emb["context"] = [prompt_emb["context"][0][0].to(self.device)] + prompt_emb["context"] = prompt_emb["context"][0].to(self.device) # Loss + self.pipe.device = self.device noise = torch.randn_like(latents) timestep_id = torch.randint(0, self.pipe.scheduler.num_train_timesteps, (1,)) - timestep = self.pipe.scheduler.timesteps[timestep_id].to(self.device) + timestep = self.pipe.scheduler.timesteps[timestep_id].to(dtype=self.pipe.torch_dtype, device=self.pipe.device) extra_input = self.pipe.prepare_extra_input(latents) noisy_latents = self.pipe.scheduler.add_noise(latents, noise, timestep) training_target = self.pipe.scheduler.training_target(latents, noise, timestep) # Compute loss - with torch.amp.autocast(dtype=torch.bfloat16, device_type=torch.device(self.device).type): - noise_pred = self.pipe.denoising_model()( - noisy_latents, timestep=timestep, **prompt_emb, **extra_input, - use_gradient_checkpointing=self.use_gradient_checkpointing - ) - loss = torch.nn.functional.mse_loss(noise_pred.float(), training_target.float()) - loss = loss * self.pipe.scheduler.training_weight(timestep) + noise_pred = self.pipe.denoising_model()( + noisy_latents, timestep=timestep, **prompt_emb, **extra_input, + use_gradient_checkpointing=self.use_gradient_checkpointing, + use_gradient_checkpointing_offload=self.use_gradient_checkpointing_offload + ) + loss = torch.nn.functional.mse_loss(noise_pred.float(), training_target.float()) + loss = loss * 
self.pipe.scheduler.training_weight(timestep) # Record log self.log("train_loss", loss, prog_bar=True) @@ -410,6 +424,12 @@ def parse_args(): action="store_true", help="Whether to use gradient checkpointing.", ) + parser.add_argument( + "--use_gradient_checkpointing_offload", + default=False, + action="store_true", + help="Whether to use gradient checkpointing offload.", + ) parser.add_argument( "--train_architecture", type=str, @@ -490,6 +510,7 @@ def train(args): lora_target_modules=args.lora_target_modules, init_lora_weights=args.init_lora_weights, use_gradient_checkpointing=args.use_gradient_checkpointing, + use_gradient_checkpointing_offload=args.use_gradient_checkpointing_offload, pretrained_lora_path=args.pretrained_lora_path, ) if args.use_swanlab: @@ -510,6 +531,7 @@ def train(args): max_epochs=args.max_epochs, accelerator="gpu", devices="auto", + precision="bf16", strategy=args.training_strategy, default_root_dir=args.output_path, accumulate_grad_batches=args.accumulate_grad_batches, diff --git a/examples/wanvideo/wan_14b_image_to_video.py b/examples/wanvideo/wan_14b_image_to_video.py index db4d6da..91894ae 100644 --- a/examples/wanvideo/wan_14b_image_to_video.py +++ b/examples/wanvideo/wan_14b_image_to_video.py @@ -11,7 +11,7 @@ snapshot_download("Wan-AI/Wan2.1-I2V-14B-480P", local_dir="models/Wan-AI/Wan2.1- model_manager = ModelManager(device="cpu") model_manager.load_models( ["models/Wan-AI/Wan2.1-I2V-14B-480P/models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth"], - torch_dtype=torch.float16, # Image Encoder is loaded with float16 + torch_dtype=torch.float32, # Image Encoder is loaded with float32 ) model_manager.load_models( [ From e757013a142fe1a80172c57053da7cb6c26851d3 Mon Sep 17 00:00:00 2001 From: Artiprocher Date: Mon, 10 Mar 2025 17:47:14 +0800 Subject: [PATCH 3/4] vram optimization --- examples/wanvideo/README.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/examples/wanvideo/README.md b/examples/wanvideo/README.md index 
51ceb3f..de3be03 100644 --- a/examples/wanvideo/README.md +++ b/examples/wanvideo/README.md @@ -155,6 +155,10 @@ CUDA_VISIBLE_DEVICES="0" python examples/wanvideo/train_wan_t2v.py \ --use_gradient_checkpointing ``` +If you wish to train the 14B model, whose weights are split across several safetensors files, pass all of the file paths as a single comma-separated string. For example: `models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00001-of-00006.safetensors,models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00002-of-00006.safetensors,models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00003-of-00006.safetensors,models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00004-of-00006.safetensors,models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00005-of-00006.safetensors,models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00006-of-00006.safetensors`. + +For LoRA training, the Wan-1.3B-T2V model requires 16 GB of VRAM to process 81 frames at 480P, while the Wan-14B-T2V model requires 60 GB of VRAM for the same configuration. To reduce VRAM requirements by a further 20%-30%, add the `--use_gradient_checkpointing_offload` flag. 
+ Step 5: Test Test LoRA: From 718b45f2af8d37c5ce1775db9515867234ce7975 Mon Sep 17 00:00:00 2001 From: Artiprocher Date: Mon, 10 Mar 2025 18:25:23 +0800 Subject: [PATCH 4/4] bugfix --- diffsynth/models/wan_video_dit.py | 10 ++++++++-- diffsynth/models/wan_video_image_encoder.py | 4 ++-- 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/diffsynth/models/wan_video_dit.py b/diffsynth/models/wan_video_dit.py index f1e5e47..32a79e3 100644 --- a/diffsynth/models/wan_video_dit.py +++ b/diffsynth/models/wan_video_dit.py @@ -24,8 +24,14 @@ except ModuleNotFoundError: SAGE_ATTN_AVAILABLE = False -def flash_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, num_heads: int): - if FLASH_ATTN_3_AVAILABLE: +def flash_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, num_heads: int, compatibility_mode=False): + if compatibility_mode: + q = rearrange(q, "b s (n d) -> b n s d", n=num_heads) + k = rearrange(k, "b s (n d) -> b n s d", n=num_heads) + v = rearrange(v, "b s (n d) -> b n s d", n=num_heads) + x = F.scaled_dot_product_attention(q, k, v) + x = rearrange(x, "b n s d -> b s (n d)", n=num_heads) + elif FLASH_ATTN_3_AVAILABLE: q = rearrange(q, "b s (n d) -> b s n d", n=num_heads) k = rearrange(k, "b s (n d) -> b s n d", n=num_heads) v = rearrange(v, "b s (n d) -> b s n d", n=num_heads) diff --git a/diffsynth/models/wan_video_image_encoder.py b/diffsynth/models/wan_video_image_encoder.py index b49235b..5ca878b 100644 --- a/diffsynth/models/wan_video_image_encoder.py +++ b/diffsynth/models/wan_video_image_encoder.py @@ -260,7 +260,7 @@ class SelfAttention(nn.Module): q, k, v = self.to_qkv(x).chunk(3, dim=-1) # compute attention - x = flash_attention(q, k, v, num_heads=self.num_heads) + x = flash_attention(q, k, v, num_heads=self.num_heads, compatibility_mode=True) # output x = self.proj(x) @@ -371,7 +371,7 @@ class AttentionPool(nn.Module): k, v = self.to_kv(x).chunk(2, dim=-1) # compute attention - x = flash_attention(q, k, v, 
num_heads=self.num_heads) + x = flash_attention(q, k, v, num_heads=self.num_heads, compatibility_mode=True) x = x.reshape(b, 1, c) # output