diff --git a/README.md b/README.md index 0fd7212..89035f0 100644 --- a/README.md +++ b/README.md @@ -36,7 +36,9 @@ Until now, DiffSynth Studio has supported the following models: ## News -- **February 17, 2024** We support [StepVideo](https://modelscope.cn/models/stepfun-ai/stepvideo-t2v/summary)! State-of-the-art video synthesis model! See [./examples/stepvideo](./examples/stepvideo/). +- **February 25, 2025** We support Wan-Video, a collection of video synthesis models open-sourced by Alibaba. See [./examples/wanvideo/](./examples/wanvideo/). + +- **February 17, 2025** We support [StepVideo](https://modelscope.cn/models/stepfun-ai/stepvideo-t2v/summary)! State-of-the-art video synthesis model! See [./examples/stepvideo](./examples/stepvideo/). - **December 31, 2024** We propose EliGen, a novel framework for precise entity-level controlled text-to-image generation, complemented by an inpainting fusion pipeline to extend its capabilities to image inpainting tasks. EliGen seamlessly integrates with existing community models, such as IP-Adapter and In-Context LoRA, enhancing its versatility. For more details, see [./examples/EntityControl](./examples/EntityControl/). - Paper: [EliGen: Entity-Level Controlled Image Generation with Regional Attention](https://arxiv.org/abs/2501.01097) @@ -118,7 +120,7 @@ cd DiffSynth-Studio pip install -e . ``` -Or install from pypi: +Or install from pypi (There is a delay in the update. If you want to experience the latest features, please do not use this installation method.): ``` pip install diffsynth diff --git a/diffsynth/configs/model_config.py b/diffsynth/configs/model_config.py index 9ac7e8e..9f8eb27 100644 --- a/diffsynth/configs/model_config.py +++ b/diffsynth/configs/model_config.py @@ -54,6 +54,11 @@ from ..models.hunyuan_video_dit import HunyuanVideoDiT from ..models.stepvideo_vae import StepVideoVAE from ..models.stepvideo_dit import StepVideoModel +from ..models.wan_video_dit import WanModel +from ..models.wan_video_text_encoder import WanTextEncoder +from ..models.wan_video_image_encoder import WanImageEncoder +from ..models.wan_video_vae import WanVideoVAE + model_loader_configs = [ # These configs are provided for detecting model type automatically. @@ -108,6 +113,13 @@ model_loader_configs = [ (None, "84ef4bd4757f60e906b54aa6a7815dc6", ["hunyuan_video_dit"], [HunyuanVideoDiT], "civitai"), (None, "68beaf8429b7c11aa8ca05b1bd0058bd", ["stepvideo_vae"], [StepVideoVAE], "civitai"), (None, "5c0216a2132b082c10cb7a0e0377e681", ["stepvideo_dit"], [StepVideoModel], "civitai"), + (None, "9269f8db9040a9d860eaca435be61814", ["wan_video_dit"], [WanModel], "civitai"), + (None, "aafcfd9672c3a2456dc46e1cb6e52c70", ["wan_video_dit"], [WanModel], "civitai"), + (None, "6bfcfb3b342cb286ce886889d519a77e", ["wan_video_dit"], [WanModel], "civitai"), + (None, "9c8818c2cbea55eca56c7b447df170da", ["wan_video_text_encoder"], [WanTextEncoder], "civitai"), + (None, "5941c53e207d62f20f9025686193c40b", ["wan_video_image_encoder"], [WanImageEncoder], "civitai"), + (None, "1378ea763357eea97acdef78e65d6d96", ["wan_video_vae"], [WanVideoVAE], "civitai"), + (None, "ccc42284ea13e1ad04693284c7a09be6", ["wan_video_vae"], [WanVideoVAE], "civitai"), ] huggingface_model_loader_configs = [ # These configs are provided for detecting model type automatically. diff --git a/diffsynth/models/kolors_text_encoder.py b/diffsynth/models/kolors_text_encoder.py index 55b24d2..693f72e 100644 --- a/diffsynth/models/kolors_text_encoder.py +++ b/diffsynth/models/kolors_text_encoder.py @@ -73,7 +73,6 @@ try: ) except Exception as exception: kernels = None - logger.warning("Failed to load cpm_kernels:" + str(exception)) class W8A16Linear(torch.autograd.Function): diff --git a/diffsynth/models/lora.py b/diffsynth/models/lora.py index 81a9a8a..784ec7a 100644 --- a/diffsynth/models/lora.py +++ b/diffsynth/models/lora.py @@ -8,6 +8,7 @@ from .flux_dit import FluxDiT from .hunyuan_dit import HunyuanDiT from .cog_dit import CogDiT from .hunyuan_video_dit import HunyuanVideoDiT +from .wan_video_dit import WanModel @@ -197,7 +198,7 @@ class FluxLoRAFromCivitai(LoRAFromCivitai): class GeneralLoRAFromPeft: def __init__(self): - self.supported_model_classes = [SDUNet, SDXLUNet, SD3DiT, HunyuanDiT, FluxDiT, CogDiT] + self.supported_model_classes = [SDUNet, SDXLUNet, SD3DiT, HunyuanDiT, FluxDiT, CogDiT, WanModel] def fetch_device_dtype_from_state_dict(self, state_dict): diff --git a/diffsynth/models/model_manager.py b/diffsynth/models/model_manager.py index 8216af0..775fafe 100644 --- a/diffsynth/models/model_manager.py +++ b/diffsynth/models/model_manager.py @@ -69,7 +69,9 @@ def load_model_from_single_file(state_dict, model_names, model_classes, model_re model_state_dict, extra_kwargs = state_dict_results, {} torch_dtype = torch.float32 if extra_kwargs.get("upcast_to_float32", False) else torch_dtype with init_weights_on_device(): - model= model_class(**extra_kwargs) + model = model_class(**extra_kwargs) + if hasattr(model, "eval"): + model = model.eval() model.load_state_dict(model_state_dict, assign=True) model = model.to(dtype=torch_dtype, device=device) loaded_model_names.append(model_name) diff --git a/diffsynth/models/wan_video_dit.py b/diffsynth/models/wan_video_dit.py new file mode 100644 index 0000000..7c31b55 --- /dev/null +++ b/diffsynth/models/wan_video_dit.py @@ -0,0 +1,789 @@ +import math + +import torch +import torch.amp as amp +import torch.nn as nn +from tqdm import tqdm +from .utils import hash_state_dict_keys + +try: + import flash_attn_interface + FLASH_ATTN_3_AVAILABLE = True +except ModuleNotFoundError: + FLASH_ATTN_3_AVAILABLE = False + +try: + import flash_attn + FLASH_ATTN_2_AVAILABLE = True +except ModuleNotFoundError: + FLASH_ATTN_2_AVAILABLE = False + +import warnings + + +__all__ = ['WanModel'] + + +def flash_attention( + q, + k, + v, + q_lens=None, + k_lens=None, + dropout_p=0., + softmax_scale=None, + q_scale=None, + causal=False, + window_size=(-1, -1), + deterministic=False, + dtype=torch.bfloat16, + version=None, +): + """ + q: [B, Lq, Nq, C1]. + k: [B, Lk, Nk, C1]. + v: [B, Lk, Nk, C2]. Nq must be divisible by Nk. + q_lens: [B]. + k_lens: [B]. + dropout_p: float. Dropout probability. + softmax_scale: float. The scaling of QK^T before applying softmax. + causal: bool. Whether to apply causal attention mask. + window_size: (left right). If not (-1, -1), apply sliding window local attention. + deterministic: bool. If True, slightly slower and uses more memory. + dtype: torch.dtype. Apply when dtype of q/k/v is not float16/bfloat16. + """ + half_dtypes = (torch.float16, torch.bfloat16) + assert dtype in half_dtypes + assert q.device.type == 'cuda' and q.size(-1) <= 256 + + # params + b, lq, lk, out_dtype = q.size(0), q.size(1), k.size(1), q.dtype + + def half(x): + return x if x.dtype in half_dtypes else x.to(dtype) + + # preprocess query + if q_lens is None: + q = half(q.flatten(0, 1)) + q_lens = torch.tensor( + [lq] * b, dtype=torch.int32).to( + device=q.device, non_blocking=True) + else: + q = half(torch.cat([u[:v] for u, v in zip(q, q_lens)])) + + # preprocess key, value + if k_lens is None: + k = half(k.flatten(0, 1)) + v = half(v.flatten(0, 1)) + k_lens = torch.tensor( + [lk] * b, dtype=torch.int32).to( + device=k.device, non_blocking=True) + else: + k = half(torch.cat([u[:v] for u, v in zip(k, k_lens)])) + v = half(torch.cat([u[:v] for u, v in zip(v, k_lens)])) + + q = q.to(v.dtype) + k = k.to(v.dtype) + + if q_scale is not None: + q = q * q_scale + + if version is not None and version == 3 and not FLASH_ATTN_3_AVAILABLE: + warnings.warn( + 'Flash attention 3 is not available, use flash attention 2 instead.' + ) + + # apply attention + if (version is None or version == 3) and FLASH_ATTN_3_AVAILABLE: + # Note: dropout_p, window_size are not supported in FA3 now. + x = flash_attn_interface.flash_attn_varlen_func( + q=q, + k=k, + v=v, + cu_seqlens_q=torch.cat([q_lens.new_zeros([1]), q_lens]).cumsum( + 0, dtype=torch.int32).to(q.device, non_blocking=True), + cu_seqlens_k=torch.cat([k_lens.new_zeros([1]), k_lens]).cumsum( + 0, dtype=torch.int32).to(q.device, non_blocking=True), + seqused_q=None, + seqused_k=None, + max_seqlen_q=lq, + max_seqlen_k=lk, + softmax_scale=softmax_scale, + causal=causal, + deterministic=deterministic)[0].unflatten(0, (b, lq)) + elif FLASH_ATTN_2_AVAILABLE: + print(q_lens, lq, k_lens, lk, causal, window_size) + x = flash_attn.flash_attn_varlen_func( + q=q, + k=k, + v=v, + cu_seqlens_q=torch.cat([q_lens.new_zeros([1]), q_lens]).cumsum( + 0, dtype=torch.int32).to(q.device, non_blocking=True), + cu_seqlens_k=torch.cat([k_lens.new_zeros([1]), k_lens]).cumsum( + 0, dtype=torch.int32).to(q.device, non_blocking=True), + max_seqlen_q=lq, + max_seqlen_k=lk, + dropout_p=dropout_p, + softmax_scale=softmax_scale, + causal=causal, + window_size=window_size, + deterministic=deterministic).unflatten(0, (b, lq)) + print(x.shape) + else: + q = q.unsqueeze(0).transpose(1, 2).to(dtype) + k = k.unsqueeze(0).transpose(1, 2).to(dtype) + v = v.unsqueeze(0).transpose(1, 2).to(dtype) + x = torch.nn.functional.scaled_dot_product_attention(q, k, v) + x = x.transpose(1, 2).contiguous() + + # output + return x.type(out_dtype) + + +def create_sdpa_mask(q, k, q_lens, k_lens, causal=False): + b, lq, lk = q.size(0), q.size(1), k.size(1) + if q_lens is None: + q_lens = torch.tensor([lq] * b, dtype=torch.int32) + if k_lens is None: + k_lens = torch.tensor([lk] * b, dtype=torch.int32) + attn_mask = torch.zeros((b, lq, lk), dtype=torch.bool) + for i in range(b): + q_len, k_len = q_lens[i], k_lens[i] + attn_mask[i, q_len:, :] = True + attn_mask[i, :, k_len:] = True + + if causal: + causal_mask = torch.triu(torch.ones((lq, lk), dtype=torch.bool), diagonal=1) + attn_mask[i, :, :] = torch.logical_or(attn_mask[i, :, :], causal_mask) + + attn_mask = attn_mask.logical_not().to(q.device, non_blocking=True) + return attn_mask + + +def attention( + q, + k, + v, + q_lens=None, + k_lens=None, + dropout_p=0., + softmax_scale=None, + q_scale=None, + causal=False, + window_size=(-1, -1), + deterministic=False, + dtype=torch.bfloat16, + fa_version=None, +): + if FLASH_ATTN_2_AVAILABLE or FLASH_ATTN_3_AVAILABLE: + return flash_attention( + q=q, + k=k, + v=v, + q_lens=q_lens, + k_lens=k_lens, + dropout_p=dropout_p, + softmax_scale=softmax_scale, + q_scale=q_scale, + causal=causal, + window_size=window_size, + deterministic=deterministic, + dtype=dtype, + version=fa_version, + ) + else: + if q_lens is not None or k_lens is not None: + warnings.warn('Padding mask is disabled when using scaled_dot_product_attention. It can have a significant impact on performance.') + attn_mask = None + + q = q.transpose(1, 2).to(dtype) + k = k.transpose(1, 2).to(dtype) + v = v.transpose(1, 2).to(dtype) + + out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, is_causal=causal, dropout_p=dropout_p) + + out = out.transpose(1, 2).contiguous() + return out + + + +def sinusoidal_embedding_1d(dim, position): + # preprocess + assert dim % 2 == 0 + half = dim // 2 + position = position.type(torch.float64) + + # calculation + sinusoid = torch.outer( + position, torch.pow(10000, -torch.arange(half).to(position).div(half))) + x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1) + return x + + +@amp.autocast(enabled=False, device_type="cuda") +def rope_params(max_seq_len, dim, theta=10000): + assert dim % 2 == 0 + freqs = torch.outer( + torch.arange(max_seq_len), + 1.0 / torch.pow(theta, + torch.arange(0, dim, 2).to(torch.float64).div(dim))) + freqs = torch.polar(torch.ones_like(freqs), freqs) + return freqs + + +@amp.autocast(enabled=False, device_type="cuda") +def rope_apply(x, grid_sizes, freqs): + n, c = x.size(2), x.size(3) // 2 + + # split freqs + freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1) + + # loop over samples + output = [] + for i, (f, h, w) in enumerate(grid_sizes.tolist()): + seq_len = f * h * w + + # precompute multipliers + x_i = torch.view_as_complex(x[i, :seq_len].to(torch.float64).reshape( + seq_len, n, -1, 2)) + freqs_i = torch.cat([ + freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1), + freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1), + freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1) + ], + dim=-1).reshape(seq_len, 1, -1) + + # apply rotary embedding + x_i = torch.view_as_real(x_i * freqs_i).flatten(2) + x_i = torch.cat([x_i, x[i, seq_len:]]) + + # append to collection + output.append(x_i) + return torch.stack(output).float() + + +class WanRMSNorm(nn.Module): + + def __init__(self, dim, eps=1e-5): + super().__init__() + self.dim = dim + self.eps = eps + self.weight = nn.Parameter(torch.ones(dim)) + + def forward(self, x): + return self._norm(x.float()).type_as(x) * self.weight + + def _norm(self, x): + return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps) + + +class WanLayerNorm(nn.LayerNorm): + + def __init__(self, dim, eps=1e-6, elementwise_affine=False): + super().__init__(dim, elementwise_affine=elementwise_affine, eps=eps) + + def forward(self, x): + return super().forward(x.float()).type_as(x) + + +class WanSelfAttention(nn.Module): + + def __init__(self, + dim, + num_heads, + window_size=(-1, -1), + qk_norm=True, + eps=1e-6): + assert dim % num_heads == 0 + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.window_size = window_size + self.qk_norm = qk_norm + self.eps = eps + + # layers + self.q = nn.Linear(dim, dim) + self.k = nn.Linear(dim, dim) + self.v = nn.Linear(dim, dim) + self.o = nn.Linear(dim, dim) + self.norm_q = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity() + self.norm_k = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity() + + def forward(self, x, seq_lens, grid_sizes, freqs): + b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim + + # query, key, value function + def qkv_fn(x): + q = self.norm_q(self.q(x)).view(b, s, n, d) + k = self.norm_k(self.k(x)).view(b, s, n, d) + v = self.v(x).view(b, s, n, d) + return q, k, v + + q, k, v = qkv_fn(x) + + x = flash_attention( + q=rope_apply(q, grid_sizes, freqs), + k=rope_apply(k, grid_sizes, freqs), + v=v, + k_lens=seq_lens, + window_size=self.window_size) + + # output + x = x.flatten(2) + x = self.o(x) + return x + + +class WanT2VCrossAttention(WanSelfAttention): + + def forward(self, x, context, context_lens): + """ + x: [B, L1, C]. + context: [B, L2, C]. + context_lens: [B]. + """ + b, n, d = x.size(0), self.num_heads, self.head_dim + + # compute query, key, value + q = self.norm_q(self.q(x)).view(b, -1, n, d) + k = self.norm_k(self.k(context)).view(b, -1, n, d) + v = self.v(context).view(b, -1, n, d) + + # compute attention + x = flash_attention(q, k, v, k_lens=context_lens) + + # output + x = x.flatten(2) + x = self.o(x) + return x + + +class WanI2VCrossAttention(WanSelfAttention): + + def __init__(self, + dim, + num_heads, + window_size=(-1, -1), + qk_norm=True, + eps=1e-6): + super().__init__(dim, num_heads, window_size, qk_norm, eps) + + self.k_img = nn.Linear(dim, dim) + self.v_img = nn.Linear(dim, dim) + # self.alpha = nn.Parameter(torch.zeros((1, ))) + self.norm_k_img = WanRMSNorm( + dim, eps=eps) if qk_norm else nn.Identity() + + def forward(self, x, context, context_lens): + """ + x: [B, L1, C]. + context: [B, L2, C]. + context_lens: [B]. + """ + context_img = context[:, :257] + context = context[:, 257:] + b, n, d = x.size(0), self.num_heads, self.head_dim + + # compute query, key, value + q = self.norm_q(self.q(x)).view(b, -1, n, d) + k = self.norm_k(self.k(context)).view(b, -1, n, d) + v = self.v(context).view(b, -1, n, d) + k_img = self.norm_k_img(self.k_img(context_img)).view(b, -1, n, d) + v_img = self.v_img(context_img).view(b, -1, n, d) + img_x = flash_attention(q, k_img, v_img, k_lens=None) + # compute attention + x = flash_attention(q, k, v, k_lens=context_lens) + + # output + x = x.flatten(2) + img_x = img_x.flatten(2) + x = x + img_x + x = self.o(x) + return x + + +WANX_CROSSATTENTION_CLASSES = { + 't2v_cross_attn': WanT2VCrossAttention, + 'i2v_cross_attn': WanI2VCrossAttention, +} + + +class WanAttentionBlock(nn.Module): + + def __init__(self, + cross_attn_type, + dim, + ffn_dim, + num_heads, + window_size=(-1, -1), + qk_norm=True, + cross_attn_norm=False, + eps=1e-6): + super().__init__() + self.dim = dim + self.ffn_dim = ffn_dim + self.num_heads = num_heads + self.window_size = window_size + self.qk_norm = qk_norm + self.cross_attn_norm = cross_attn_norm + self.eps = eps + + # layers + self.norm1 = WanLayerNorm(dim, eps) + self.self_attn = WanSelfAttention(dim, num_heads, window_size, qk_norm, + eps) + self.norm3 = WanLayerNorm( + dim, eps, + elementwise_affine=True) if cross_attn_norm else nn.Identity() + self.cross_attn = WANX_CROSSATTENTION_CLASSES[cross_attn_type]( + dim, num_heads, (-1, -1), qk_norm, eps) + self.norm2 = WanLayerNorm(dim, eps) + self.ffn = nn.Sequential( + nn.Linear(dim, ffn_dim), nn.GELU(approximate='tanh'), + nn.Linear(ffn_dim, dim)) + + # modulation + self.modulation = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5) + + def forward( + self, + x, + e, + seq_lens, + grid_sizes, + freqs, + context, + context_lens, + ): + assert e.dtype == torch.float32 + with amp.autocast(dtype=torch.float32, device_type="cuda"): + e = (self.modulation.to(dtype=e.dtype, device=e.device) + e).chunk(6, dim=1) + assert e[0].dtype == torch.float32 + + # self-attention + y = self.self_attn( + self.norm1(x).float() * (1 + e[1]) + e[0], seq_lens, grid_sizes, + freqs) + with amp.autocast(dtype=torch.float32, device_type="cuda"): + x = x + y * e[2] + + # cross-attention & ffn function + def cross_attn_ffn(x, context, context_lens, e): + x = x + self.cross_attn(self.norm3(x), context, context_lens) + y = self.ffn(self.norm2(x).float() * (1 + e[4]) + e[3]) + with amp.autocast(dtype=torch.float32, device_type="cuda"): + x = x + y * e[5] + return x + + x = cross_attn_ffn(x, context, context_lens, e) + return x + + +class Head(nn.Module): + + def __init__(self, dim, out_dim, patch_size, eps=1e-6): + super().__init__() + self.dim = dim + self.out_dim = out_dim + self.patch_size = patch_size + self.eps = eps + + # layers + out_dim = math.prod(patch_size) * out_dim + self.norm = WanLayerNorm(dim, eps) + self.head = nn.Linear(dim, out_dim) + + # modulation + self.modulation = nn.Parameter(torch.randn(1, 2, dim) / dim**0.5) + + def forward(self, x, e): + assert e.dtype == torch.float32 + with amp.autocast(dtype=torch.float32, device_type="cuda"): + e = (self.modulation.to(dtype=e.dtype, device=e.device) + e.unsqueeze(1)).chunk(2, dim=1) + x = (self.head(self.norm(x) * (1 + e[1]) + e[0])) + return x + + +class MLPProj(torch.nn.Module): + + def __init__(self, in_dim, out_dim): + super().__init__() + + self.proj = torch.nn.Sequential( + torch.nn.LayerNorm(in_dim), torch.nn.Linear(in_dim, in_dim), + torch.nn.GELU(), torch.nn.Linear(in_dim, out_dim), + torch.nn.LayerNorm(out_dim)) + + def forward(self, image_embeds): + clip_extra_context_tokens = self.proj(image_embeds) + return clip_extra_context_tokens + + +class WanModel(nn.Module): + + def __init__(self, + model_type='t2v', + patch_size=(1, 2, 2), + text_len=512, + in_dim=16, + dim=2048, + ffn_dim=8192, + freq_dim=256, + text_dim=4096, + out_dim=16, + num_heads=16, + num_layers=32, + window_size=(-1, -1), + qk_norm=True, + cross_attn_norm=False, + eps=1e-6): + super().__init__() + + assert model_type in ['t2v', 'i2v'] + self.model_type = model_type + + self.patch_size = patch_size + self.text_len = text_len + self.in_dim = in_dim + self.dim = dim + self.ffn_dim = ffn_dim + self.freq_dim = freq_dim + self.text_dim = text_dim + self.out_dim = out_dim + self.num_heads = num_heads + self.num_layers = num_layers + self.window_size = window_size + self.qk_norm = qk_norm + self.cross_attn_norm = cross_attn_norm + self.eps = eps + + # embeddings + self.patch_embedding = nn.Conv3d( + in_dim, dim, kernel_size=patch_size, stride=patch_size) + self.text_embedding = nn.Sequential( + nn.Linear(text_dim, dim), nn.GELU(approximate='tanh'), + nn.Linear(dim, dim)) + + self.time_embedding = nn.Sequential( + nn.Linear(freq_dim, dim), nn.SiLU(), nn.Linear(dim, dim)) + self.time_projection = nn.Sequential(nn.SiLU(), nn.Linear(dim, dim * 6)) + + # blocks + cross_attn_type = 't2v_cross_attn' if model_type == 't2v' else 'i2v_cross_attn' + self.blocks = nn.ModuleList([ + WanAttentionBlock(cross_attn_type, dim, ffn_dim, num_heads, + window_size, qk_norm, cross_attn_norm, eps) + for _ in range(num_layers) + ]) + + # head + self.head = Head(dim, out_dim, patch_size, eps) + + # buffers (don't use register_buffer otherwise dtype will be changed in to()) + assert (dim % num_heads) == 0 and (dim // num_heads) % 2 == 0 + d = dim // num_heads + self.freqs = torch.cat([ + rope_params(1024, d - 4 * (d // 6)), + rope_params(1024, 2 * (d // 6)), + rope_params(1024, 2 * (d // 6)) + ], + dim=1) + + if model_type == 'i2v': + self.img_emb = MLPProj(1280, dim) + + # initialize weights + self.init_weights() + + def forward( + self, + x, + timestep, + context, + seq_len, + clip_fea=None, + y=None, + use_gradient_checkpointing=False, + **kwargs, + ): + """ + x: A list of videos each with shape [C, T, H, W]. + t: [B]. + context: A list of text embeddings each with shape [L, C]. + """ + if self.model_type == 'i2v': + assert clip_fea is not None and y is not None + # params + device = x[0].device + if self.freqs.device != device: + self.freqs = self.freqs.to(device) + + if y is not None: + x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)] + + # embeddings + x = [self.patch_embedding(u.unsqueeze(0)) for u in x] + grid_sizes = torch.stack( + [torch.tensor(u.shape[2:], dtype=torch.long) for u in x]) + x = [u.flatten(2).transpose(1, 2) for u in x] + seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long) + assert seq_lens.max() <= seq_len + x = torch.cat([ + torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))], + dim=1) for u in x + ]) + + # time embeddings + with amp.autocast(dtype=torch.float32, device_type="cuda"): + e = self.time_embedding( + sinusoidal_embedding_1d(self.freq_dim, timestep).float()) + e0 = self.time_projection(e).unflatten(1, (6, self.dim)) + assert e.dtype == torch.float32 and e0.dtype == torch.float32 + + # context + context_lens = None + context = self.text_embedding( + torch.stack([ + torch.cat( + [u, u.new_zeros(self.text_len - u.size(0), u.size(1))]) + for u in context + ])) + + if clip_fea is not None: + context_clip = self.img_emb(clip_fea) # bs x 257 x dim + context = torch.concat([context_clip, context], dim=1) + + # arguments + kwargs = dict( + e=e0, + seq_lens=seq_lens, + grid_sizes=grid_sizes, + freqs=self.freqs, + context=context, + context_lens=context_lens) + + def create_custom_forward(module): + def custom_forward(*inputs, **kwargs): + return module(*inputs, **kwargs) + return custom_forward + + for block in self.blocks: + if self.training and use_gradient_checkpointing: + x = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + x, **kwargs, + use_reentrant=False, + ) + else: + x = block(x, **kwargs) + + # head + x = self.head(x, e) + + # unpatchify + x = self.unpatchify(x, grid_sizes) + x = torch.stack(x).float() + return x + + def unpatchify(self, x, grid_sizes): + c = self.out_dim + out = [] + for u, v in zip(x, grid_sizes.tolist()): + u = u[:math.prod(v)].view(*v, *self.patch_size, c) + u = torch.einsum('fhwpqrc->cfphqwr', u) + u = u.reshape(c, *[i * j for i, j in zip(v, self.patch_size)]) + out.append(u) + return out + + def init_weights(self): + # basic init + for m in self.modules(): + if isinstance(m, nn.Linear): + nn.init.xavier_uniform_(m.weight) + if m.bias is not None: + nn.init.zeros_(m.bias) + + # init embeddings + nn.init.xavier_uniform_(self.patch_embedding.weight.flatten(1)) + for m in self.text_embedding.modules(): + if isinstance(m, nn.Linear): + nn.init.normal_(m.weight, std=.02) + for m in self.time_embedding.modules(): + if isinstance(m, nn.Linear): + nn.init.normal_(m.weight, std=.02) + + # init output layer + nn.init.zeros_(self.head.head.weight) + + @staticmethod + def state_dict_converter(): + return WanModelStateDictConverter() + + +class WanModelStateDictConverter: + def __init__(self): + pass + + def from_diffusers(self, state_dict): + return state_dict + + def from_civitai(self, state_dict): + if hash_state_dict_keys(state_dict) == "9269f8db9040a9d860eaca435be61814": + config = { + "model_type": "t2v", + "patch_size": (1, 2, 2), + "text_len": 512, + "in_dim": 16, + "dim": 1536, + "ffn_dim": 8960, + "freq_dim": 256, + "text_dim": 4096, + "out_dim": 16, + "num_heads": 12, + "num_layers": 30, + "window_size": (-1, -1), + "qk_norm": True, + "cross_attn_norm": True, + "eps": 1e-6, + } + elif hash_state_dict_keys(state_dict) == "aafcfd9672c3a2456dc46e1cb6e52c70": + config = { + "model_type": "t2v", + "patch_size": (1, 2, 2), + "text_len": 512, + "in_dim": 16, + "dim": 5120, + "ffn_dim": 13824, + "freq_dim": 256, + "text_dim": 4096, + "out_dim": 16, + "num_heads": 40, + "num_layers": 40, + "window_size": (-1, -1), + "qk_norm": True, + "cross_attn_norm": True, + "eps": 1e-6, + } + elif hash_state_dict_keys(state_dict) == "6bfcfb3b342cb286ce886889d519a77e": + config = { + "model_type": "i2v", + "patch_size": (1, 2, 2), + "text_len": 512, + "in_dim": 36, + "dim": 5120, + "ffn_dim": 13824, + "freq_dim": 256, + "text_dim": 4096, + "out_dim": 16, + "num_heads": 40, + "num_layers": 40, + "window_size": (-1, -1), + "qk_norm": True, + "cross_attn_norm": True, + "eps": 1e-6, + } + else: + config = {} + return state_dict, config diff --git a/diffsynth/models/wan_video_image_encoder.py b/diffsynth/models/wan_video_image_encoder.py new file mode 100644 index 0000000..35f5ea3 --- /dev/null +++ b/diffsynth/models/wan_video_image_encoder.py @@ -0,0 +1,904 @@ +""" +Concise re-implementation of +``https://github.com/openai/CLIP'' and +``https://github.com/mlfoundations/open_clip''. +""" +import math +import torch +import torch.nn as nn +import torch.nn.functional as F +import torchvision.transforms as T +from .wan_video_dit import flash_attention + + +class SelfAttention(nn.Module): + + def __init__(self, dim, num_heads, dropout=0.1, eps=1e-5): + assert dim % num_heads == 0 + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.eps = eps + + # layers + self.q = nn.Linear(dim, dim) + self.k = nn.Linear(dim, dim) + self.v = nn.Linear(dim, dim) + self.o = nn.Linear(dim, dim) + self.dropout = nn.Dropout(dropout) + + def forward(self, x, mask): + """ + x: [B, L, C]. + """ + b, s, c, n, d = *x.size(), self.num_heads, self.head_dim + + # compute query, key, value + q = self.q(x).reshape(b, s, n, d).permute(0, 2, 1, 3) + k = self.k(x).reshape(b, s, n, d).permute(0, 2, 1, 3) + v = self.v(x).reshape(b, s, n, d).permute(0, 2, 1, 3) + + # compute attention + p = self.dropout.p if self.training else 0.0 + x = F.scaled_dot_product_attention(q, k, v, mask, p) + x = x.permute(0, 2, 1, 3).reshape(b, s, c) + + # output + x = self.o(x) + x = self.dropout(x) + return x + + +class AttentionBlock(nn.Module): + + def __init__(self, dim, num_heads, post_norm, dropout=0.1, eps=1e-5): + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.post_norm = post_norm + self.eps = eps + + # layers + self.attn = SelfAttention(dim, num_heads, dropout, eps) + self.norm1 = nn.LayerNorm(dim, eps=eps) + self.ffn = nn.Sequential( + nn.Linear(dim, dim * 4), nn.GELU(), nn.Linear(dim * 4, dim), + nn.Dropout(dropout)) + self.norm2 = nn.LayerNorm(dim, eps=eps) + + def forward(self, x, mask): + if self.post_norm: + x = self.norm1(x + self.attn(x, mask)) + x = self.norm2(x + self.ffn(x)) + else: + x = x + self.attn(self.norm1(x), mask) + x = x + self.ffn(self.norm2(x)) + return x + + +class XLMRoberta(nn.Module): + """ + XLMRobertaModel with no pooler and no LM head. + """ + + def __init__(self, + vocab_size=250002, + max_seq_len=514, + type_size=1, + pad_id=1, + dim=1024, + num_heads=16, + num_layers=24, + post_norm=True, + dropout=0.1, + eps=1e-5): + super().__init__() + self.vocab_size = vocab_size + self.max_seq_len = max_seq_len + self.type_size = type_size + self.pad_id = pad_id + self.dim = dim + self.num_heads = num_heads + self.num_layers = num_layers + self.post_norm = post_norm + self.eps = eps + + # embeddings + self.token_embedding = nn.Embedding(vocab_size, dim, padding_idx=pad_id) + self.type_embedding = nn.Embedding(type_size, dim) + self.pos_embedding = nn.Embedding(max_seq_len, dim, padding_idx=pad_id) + self.dropout = nn.Dropout(dropout) + + # blocks + self.blocks = nn.ModuleList([ + AttentionBlock(dim, num_heads, post_norm, dropout, eps) + for _ in range(num_layers) + ]) + + # norm layer + self.norm = nn.LayerNorm(dim, eps=eps) + + def forward(self, ids): + """ + ids: [B, L] of torch.LongTensor. + """ + b, s = ids.shape + mask = ids.ne(self.pad_id).long() + + # embeddings + x = self.token_embedding(ids) + \ + self.type_embedding(torch.zeros_like(ids)) + \ + self.pos_embedding(self.pad_id + torch.cumsum(mask, dim=1) * mask) + if self.post_norm: + x = self.norm(x) + x = self.dropout(x) + + # blocks + mask = torch.where( + mask.view(b, 1, 1, s).gt(0), 0.0, + torch.finfo(x.dtype).min) + for block in self.blocks: + x = block(x, mask) + + # output + if not self.post_norm: + x = self.norm(x) + return x + + +def xlm_roberta_large(pretrained=False, + return_tokenizer=False, + device='cpu', + **kwargs): + """ + XLMRobertaLarge adapted from Huggingface. + """ + # params + cfg = dict( + vocab_size=250002, + max_seq_len=514, + type_size=1, + pad_id=1, + dim=1024, + num_heads=16, + num_layers=24, + post_norm=True, + dropout=0.1, + eps=1e-5) + cfg.update(**kwargs) + + # init model + if pretrained: + from sora import DOWNLOAD_TO_CACHE + + # init a meta model + with torch.device('meta'): + model = XLMRoberta(**cfg) + + # load checkpoint + model.load_state_dict( + torch.load( + DOWNLOAD_TO_CACHE('models/xlm_roberta/xlm_roberta_large.pth'), + map_location=device), + assign=True) + else: + # init a model on device + with torch.device(device): + model = XLMRoberta(**cfg) + + # init tokenizer + if return_tokenizer: + from sora.data import HuggingfaceTokenizer + tokenizer = HuggingfaceTokenizer( + name='xlm-roberta-large', + seq_len=model.text_len, + clean='whitespace') + return model, tokenizer + else: + return model + + + +def pos_interpolate(pos, seq_len): + if pos.size(1) == seq_len: + return pos + else: + src_grid = int(math.sqrt(pos.size(1))) + tar_grid = int(math.sqrt(seq_len)) + n = pos.size(1) - src_grid * src_grid + return torch.cat([ + pos[:, :n], + F.interpolate( + pos[:, n:].float().reshape(1, src_grid, src_grid, -1).permute( + 0, 3, 1, 2), + size=(tar_grid, tar_grid), + mode='bicubic', + align_corners=False).flatten(2).transpose(1, 2) + ], + dim=1) + + +class QuickGELU(nn.Module): + + def forward(self, x): + return x * torch.sigmoid(1.702 * x) + + +class LayerNorm(nn.LayerNorm): + + def forward(self, x): + return super().forward(x.float()).type_as(x) + + +class SelfAttention(nn.Module): + + def __init__(self, + dim, + num_heads, + causal=False, + attn_dropout=0.0, + proj_dropout=0.0): + assert dim % num_heads == 0 + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.causal = causal + self.attn_dropout = attn_dropout + self.proj_dropout = proj_dropout + + # layers + self.to_qkv = nn.Linear(dim, dim * 3) + self.proj = nn.Linear(dim, dim) + + def forward(self, x): + """ + x: [B, L, C]. + """ + b, s, c, n, d = *x.size(), self.num_heads, self.head_dim + + # compute query, key, value + q, k, v = self.to_qkv(x).view(b, s, 3, n, d).unbind(2) + + # compute attention + p = self.attn_dropout if self.training else 0.0 + x = flash_attention(q, k, v, dropout_p=p, causal=self.causal, version=2) + x = x.reshape(b, s, c) + + # output + x = self.proj(x) + x = F.dropout(x, self.proj_dropout, self.training) + return x + + +class SwiGLU(nn.Module): + + def __init__(self, dim, mid_dim): + super().__init__() + self.dim = dim + self.mid_dim = mid_dim + + # layers + self.fc1 = nn.Linear(dim, mid_dim) + self.fc2 = nn.Linear(dim, mid_dim) + self.fc3 = nn.Linear(mid_dim, dim) + + def forward(self, x): + x = F.silu(self.fc1(x)) * self.fc2(x) + x = self.fc3(x) + return x + + +class AttentionBlock(nn.Module): + + def __init__(self, + dim, + mlp_ratio, + num_heads, + post_norm=False, + causal=False, + activation='quick_gelu', + attn_dropout=0.0, + proj_dropout=0.0, + norm_eps=1e-5): + assert activation in ['quick_gelu', 'gelu', 'swi_glu'] + super().__init__() + self.dim = dim + self.mlp_ratio = mlp_ratio + self.num_heads = num_heads + self.post_norm = post_norm + self.causal = causal + self.norm_eps = norm_eps + + # layers + self.norm1 = LayerNorm(dim, eps=norm_eps) + self.attn = SelfAttention(dim, num_heads, causal, attn_dropout, + proj_dropout) + self.norm2 = LayerNorm(dim, eps=norm_eps) + if activation == 'swi_glu': + self.mlp = SwiGLU(dim, int(dim * mlp_ratio)) + else: + self.mlp = nn.Sequential( + nn.Linear(dim, int(dim * mlp_ratio)), + QuickGELU() if activation == 'quick_gelu' else nn.GELU(), + nn.Linear(int(dim * mlp_ratio), dim), nn.Dropout(proj_dropout)) + + def forward(self, x): + if self.post_norm: + x = x + self.norm1(self.attn(x)) + x = x + self.norm2(self.mlp(x)) + else: + x = x + self.attn(self.norm1(x)) + x = x + self.mlp(self.norm2(x)) + return x + + +class AttentionPool(nn.Module): + + def __init__(self, + dim, + mlp_ratio, + num_heads, + activation='gelu', + proj_dropout=0.0, + norm_eps=1e-5): + assert dim % num_heads == 0 + super().__init__() + self.dim = dim + self.mlp_ratio = mlp_ratio + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.proj_dropout = proj_dropout + self.norm_eps = norm_eps + + # layers + gain = 1.0 / math.sqrt(dim) + self.cls_embedding = nn.Parameter(gain * torch.randn(1, 1, dim)) + self.to_q = nn.Linear(dim, dim) + self.to_kv = nn.Linear(dim, dim * 2) + self.proj = nn.Linear(dim, dim) + self.norm = LayerNorm(dim, eps=norm_eps) + self.mlp = nn.Sequential( + nn.Linear(dim, int(dim * mlp_ratio)), + QuickGELU() if activation == 'quick_gelu' else nn.GELU(), + nn.Linear(int(dim * mlp_ratio), dim), nn.Dropout(proj_dropout)) + + def forward(self, x): + """ + x: [B, L, C]. + """ + b, s, c, n, d = *x.size(), self.num_heads, self.head_dim + + # compute query, key, value + q = self.to_q(self.cls_embedding).view(1, 1, n, d).expand(b, -1, -1, -1) + k, v = self.to_kv(x).view(b, s, 2, n, d).unbind(2) + + # compute attention + x = flash_attention(q, k, v, version=2) + x = x.reshape(b, 1, c) + + # output + x = self.proj(x) + x = F.dropout(x, self.proj_dropout, self.training) + + # mlp + x = x + self.mlp(self.norm(x)) + return x[:, 0] + + +class VisionTransformer(nn.Module): + + def __init__(self, + image_size=224, + patch_size=16, + dim=768, + mlp_ratio=4, + out_dim=512, + num_heads=12, + num_layers=12, + pool_type='token', + pre_norm=True, + post_norm=False, + activation='quick_gelu', + attn_dropout=0.0, + proj_dropout=0.0, + embedding_dropout=0.0, + norm_eps=1e-5): + if image_size % patch_size != 0: + print( + '[WARNING] image_size is not divisible by patch_size', + flush=True) + assert pool_type in ('token', 'token_fc', 'attn_pool') + out_dim = out_dim or dim + super().__init__() + self.image_size = image_size + self.patch_size = patch_size + self.num_patches = (image_size // patch_size)**2 + self.dim = dim + self.mlp_ratio = mlp_ratio + self.out_dim = out_dim + self.num_heads = num_heads + self.num_layers = num_layers + self.pool_type = pool_type + self.post_norm = post_norm + self.norm_eps = norm_eps + + # embeddings + gain = 1.0 / math.sqrt(dim) + self.patch_embedding = nn.Conv2d( + 3, + dim, + kernel_size=patch_size, + stride=patch_size, + bias=not pre_norm) + if pool_type in ('token', 'token_fc'): + self.cls_embedding = nn.Parameter(gain * torch.randn(1, 1, dim)) + self.pos_embedding = nn.Parameter(gain * torch.randn( + 1, self.num_patches + + (1 if pool_type in ('token', 'token_fc') else 0), dim)) + self.dropout = nn.Dropout(embedding_dropout) + + # transformer + self.pre_norm = LayerNorm(dim, eps=norm_eps) if pre_norm else None + self.transformer = nn.Sequential(*[ + AttentionBlock(dim, mlp_ratio, num_heads, post_norm, False, + activation, attn_dropout, proj_dropout, norm_eps) + for _ in range(num_layers) + ]) + self.post_norm = LayerNorm(dim, eps=norm_eps) + + # head + if pool_type == 'token': + self.head = nn.Parameter(gain * torch.randn(dim, out_dim)) + elif pool_type == 'token_fc': + self.head = nn.Linear(dim, out_dim) + elif pool_type == 'attn_pool': + self.head = AttentionPool(dim, mlp_ratio, num_heads, activation, + proj_dropout, norm_eps) + + def forward(self, x, interpolation=False, use_31_block=False): + b = x.size(0) + + # embeddings + x = self.patch_embedding(x).flatten(2).permute(0, 2, 1) + if self.pool_type in ('token', 'token_fc'): + x = torch.cat([self.cls_embedding.expand(b, -1, -1).to(dtype=x.dtype, device=x.device), x], dim=1) + if interpolation: + e = pos_interpolate(self.pos_embedding, x.size(1)) + else: + e = self.pos_embedding + e = e.to(dtype=x.dtype, device=x.device) + x = self.dropout(x + e) + if self.pre_norm is not None: + x = self.pre_norm(x) + + # transformer + if use_31_block: + x = self.transformer[:-1](x) + return x + else: + x = self.transformer(x) + return x + + +class CLIP(nn.Module): + + def __init__(self, + embed_dim=512, + image_size=224, + patch_size=16, + vision_dim=768, + vision_mlp_ratio=4, + vision_heads=12, + vision_layers=12, + vision_pool='token', + vision_pre_norm=True, + vision_post_norm=False, + vocab_size=49408, + text_len=77, + text_dim=512, + text_mlp_ratio=4, + text_heads=8, + text_layers=12, + text_causal=True, + text_pool='argmax', + text_head_bias=False, + logit_bias=None, + activation='quick_gelu', + attn_dropout=0.0, + proj_dropout=0.0, + embedding_dropout=0.0, + norm_eps=1e-5): + super().__init__() + self.embed_dim = embed_dim + self.image_size = image_size + self.patch_size = patch_size + self.vision_dim = vision_dim + self.vision_mlp_ratio = vision_mlp_ratio + self.vision_heads = vision_heads + self.vision_layers = vision_layers + self.vision_pool = vision_pool + self.vision_pre_norm = vision_pre_norm + self.vision_post_norm = vision_post_norm + self.vocab_size = vocab_size + self.text_len = text_len + self.text_dim = text_dim + self.text_mlp_ratio = text_mlp_ratio + self.text_heads = text_heads + self.text_layers = text_layers + self.text_causal = text_causal + self.text_pool = text_pool + self.text_head_bias = text_head_bias + self.norm_eps = norm_eps + + # models + self.visual = VisionTransformer( + image_size=image_size, + patch_size=patch_size, + dim=vision_dim, + mlp_ratio=vision_mlp_ratio, + out_dim=embed_dim, + num_heads=vision_heads, + num_layers=vision_layers, + pool_type=vision_pool, + pre_norm=vision_pre_norm, + post_norm=vision_post_norm, + activation=activation, + attn_dropout=attn_dropout, + proj_dropout=proj_dropout, + embedding_dropout=embedding_dropout, + norm_eps=norm_eps) + self.textual = TextTransformer( + vocab_size=vocab_size, + text_len=text_len, + dim=text_dim, + mlp_ratio=text_mlp_ratio, + out_dim=embed_dim, + num_heads=text_heads, + num_layers=text_layers, + causal=text_causal, + pool_type=text_pool, + head_bias=text_head_bias, + activation=activation, + attn_dropout=attn_dropout, + proj_dropout=proj_dropout, + embedding_dropout=embedding_dropout, + norm_eps=norm_eps) + self.log_scale = nn.Parameter(math.log(1 / 0.07) * torch.ones([])) + if logit_bias is not None: + self.logit_bias = nn.Parameter(logit_bias * torch.ones([])) + + # initialize weights + self.init_weights() + + def forward(self, imgs, txt_ids): + """ + imgs: [B, 3, H, W] of torch.float32. + - mean: [0.48145466, 0.4578275, 0.40821073] + - std: [0.26862954, 0.26130258, 0.27577711] + txt_ids: [B, L] of torch.long. Encoded by data.CLIPTokenizer. + """ + xi = self.visual(imgs) + xt = self.textual(txt_ids) + return xi, xt + + def init_weights(self): + # embeddings + nn.init.normal_(self.textual.token_embedding.weight, std=0.02) + nn.init.normal_(self.visual.patch_embedding.weight, std=0.1) + + # attentions + for modality in ['visual', 'textual']: + dim = self.vision_dim if modality == 'visual' else self.text_dim + transformer = getattr(self, modality).transformer + proj_gain = (1.0 / math.sqrt(dim)) * ( + 1.0 / math.sqrt(2 * len(transformer))) + attn_gain = 1.0 / math.sqrt(dim) + mlp_gain = 1.0 / math.sqrt(2.0 * dim) + for block in transformer: + nn.init.normal_(block.attn.to_qkv.weight, std=attn_gain) + nn.init.normal_(block.attn.proj.weight, std=proj_gain) + nn.init.normal_(block.mlp[0].weight, std=mlp_gain) + nn.init.normal_(block.mlp[2].weight, std=proj_gain) + + def param_groups(self): + groups = [{ + 'params': [ + p for n, p in self.named_parameters() + if 'norm' in n or n.endswith('bias') + ], + 'weight_decay': 0.0 + }, { + 'params': [ + p for n, p in self.named_parameters() + if not ('norm' in n or n.endswith('bias')) + ] + }] + return groups + + +class XLMRobertaWithHead(XLMRoberta): + + def __init__(self, **kwargs): + self.out_dim = kwargs.pop('out_dim') + super().__init__(**kwargs) + + # head + mid_dim = (self.dim + self.out_dim) // 2 + self.head = nn.Sequential( + nn.Linear(self.dim, mid_dim, bias=False), nn.GELU(), + nn.Linear(mid_dim, self.out_dim, bias=False)) + + def forward(self, ids): + # xlm-roberta + x = super().forward(ids) + + # average pooling + mask = ids.ne(self.pad_id).unsqueeze(-1).to(x) + x = (x * mask).sum(dim=1) / mask.sum(dim=1) + + # head + x = self.head(x) + return x + + +class XLMRobertaCLIP(nn.Module): + + def __init__(self, + embed_dim=1024, + image_size=224, + patch_size=14, + vision_dim=1280, + vision_mlp_ratio=4, + vision_heads=16, + vision_layers=32, + vision_pool='token', + vision_pre_norm=True, + vision_post_norm=False, + activation='gelu', + vocab_size=250002, + max_text_len=514, + type_size=1, + pad_id=1, + text_dim=1024, + text_heads=16, + text_layers=24, + text_post_norm=True, + text_dropout=0.1, + attn_dropout=0.0, + proj_dropout=0.0, + embedding_dropout=0.0, + norm_eps=1e-5): + super().__init__() + self.embed_dim = embed_dim + self.image_size = image_size + self.patch_size = patch_size + self.vision_dim = vision_dim + self.vision_mlp_ratio = vision_mlp_ratio + self.vision_heads = vision_heads + self.vision_layers = vision_layers + self.vision_pre_norm = vision_pre_norm + self.vision_post_norm = vision_post_norm + self.activation = activation + self.vocab_size = vocab_size + self.max_text_len = max_text_len + self.type_size = type_size + self.pad_id = pad_id + self.text_dim = text_dim + self.text_heads = text_heads + self.text_layers = text_layers + self.text_post_norm = text_post_norm + self.norm_eps = norm_eps + + # models + self.visual = VisionTransformer( + image_size=image_size, + patch_size=patch_size, + dim=vision_dim, + mlp_ratio=vision_mlp_ratio, + out_dim=embed_dim, + num_heads=vision_heads, + num_layers=vision_layers, + pool_type=vision_pool, + pre_norm=vision_pre_norm, + post_norm=vision_post_norm, + activation=activation, + attn_dropout=attn_dropout, + proj_dropout=proj_dropout, + embedding_dropout=embedding_dropout, + norm_eps=norm_eps) + self.textual = None + self.log_scale = nn.Parameter(math.log(1 / 0.07) * torch.ones([])) + + def forward(self, imgs, txt_ids): + """ + imgs: [B, 3, H, W] of torch.float32. + - mean: [0.48145466, 0.4578275, 0.40821073] + - std: [0.26862954, 0.26130258, 0.27577711] + txt_ids: [B, L] of torch.long. + Encoded by data.CLIPTokenizer. + """ + xi = self.visual(imgs) + xt = self.textual(txt_ids) + return xi, xt + + def param_groups(self): + groups = [{ + 'params': [ + p for n, p in self.named_parameters() + if 'norm' in n or n.endswith('bias') + ], + 'weight_decay': 0.0 + }, { + 'params': [ + p for n, p in self.named_parameters() + if not ('norm' in n or n.endswith('bias')) + ] + }] + return groups + + +def _clip(pretrained=False, + pretrained_name=None, + model_cls=CLIP, + return_transforms=False, + return_tokenizer=False, + tokenizer_padding='eos', + dtype=torch.float32, + device='cpu', + **kwargs): + # init model + if pretrained and pretrained_name: + from sora import BUCKET, DOWNLOAD_TO_CACHE + + # init a meta model + with torch.device('meta'): + model = model_cls(**kwargs) + + # checkpoint path + checkpoint = f'models/clip/{pretrained_name}' + if dtype in (torch.float16, torch.bfloat16): + suffix = '-' + { + torch.float16: 'fp16', + torch.bfloat16: 'bf16' + }[dtype] + if object_exists(BUCKET, f'{checkpoint}{suffix}.pth'): + checkpoint = f'{checkpoint}{suffix}' + checkpoint += '.pth' + + # load + model.load_state_dict( + torch.load(DOWNLOAD_TO_CACHE(checkpoint), map_location=device), + assign=True, + strict=False) + else: + # init a model on device + with torch.device(device): + model = model_cls(**kwargs) + + # set device + output = (model,) + + # init transforms + if return_transforms: + # mean and std + if 'siglip' in pretrained_name.lower(): + mean, std = [0.5, 0.5, 0.5], [0.5, 0.5, 0.5] + else: + mean = [0.48145466, 0.4578275, 0.40821073] + std = [0.26862954, 0.26130258, 0.27577711] + + # transforms + transforms = T.Compose([ + T.Resize((model.image_size, model.image_size), + interpolation=T.InterpolationMode.BICUBIC), + T.ToTensor(), + T.Normalize(mean=mean, std=std) + ]) + output += (transforms,) + + # init tokenizer + if return_tokenizer: + from sora import data + if 'siglip' in pretrained_name.lower(): + tokenizer = data.HuggingfaceTokenizer( + name=f'timm/{pretrained_name}', + seq_len=model.text_len, + clean='canonicalize') + elif 'xlm' in pretrained_name.lower(): + tokenizer = data.HuggingfaceTokenizer( + name='xlm-roberta-large', + seq_len=model.max_text_len - 2, + clean='whitespace') + elif 'mba' in pretrained_name.lower(): + tokenizer = data.HuggingfaceTokenizer( + name='facebook/xlm-roberta-xl', + seq_len=model.max_text_len - 2, + clean='whitespace') + else: + tokenizer = data.CLIPTokenizer( + seq_len=model.text_len, padding=tokenizer_padding) + output += (tokenizer,) + return output[0] if len(output) == 1 else output + + +def clip_xlm_roberta_vit_h_14( + pretrained=False, + pretrained_name='open-clip-xlm-roberta-large-vit-huge-14', + **kwargs): + cfg = dict( + embed_dim=1024, + image_size=224, + patch_size=14, + vision_dim=1280, + vision_mlp_ratio=4, + vision_heads=16, + vision_layers=32, + vision_pool='token', + activation='gelu', + vocab_size=250002, + max_text_len=514, + type_size=1, + pad_id=1, + text_dim=1024, + text_heads=16, + text_layers=24, + text_post_norm=True, + text_dropout=0.1, + attn_dropout=0.0, + proj_dropout=0.0, + embedding_dropout=0.0) + cfg.update(**kwargs) + return _clip(pretrained, pretrained_name, XLMRobertaCLIP, **cfg) + + +class WanImageEncoder(torch.nn.Module): + + def __init__(self): + super().__init__() + # init model + self.model, self.transforms = clip_xlm_roberta_vit_h_14( + pretrained=False, + return_transforms=True, + return_tokenizer=False, + dtype=torch.float32, + device="cpu") + + def encode_image(self, videos): + # preprocess + size = (self.model.image_size,) * 2 + videos = torch.cat([ + F.interpolate( + u, + size=size, + mode='bicubic', + align_corners=False) for u in videos + ]) + videos = self.transforms.transforms[-1](videos.mul_(0.5).add_(0.5)) + + # forward + out = self.model.visual(videos, use_31_block=True) + return out + + @staticmethod + def state_dict_converter(): + return WanImageEncoderStateDictConverter() + + +class WanImageEncoderStateDictConverter: + def __init__(self): + pass + + def from_diffusers(self, state_dict): + return state_dict + + def from_civitai(self, state_dict): + state_dict_ = {} + for name, param in state_dict.items(): + if name.startswith("textual."): + continue + name = "model." + name + state_dict_[name] = param + return state_dict_ + diff --git a/diffsynth/models/wan_video_text_encoder.py b/diffsynth/models/wan_video_text_encoder.py new file mode 100644 index 0000000..c288737 --- /dev/null +++ b/diffsynth/models/wan_video_text_encoder.py @@ -0,0 +1,269 @@ +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +def fp16_clamp(x): + if x.dtype == torch.float16 and torch.isinf(x).any(): + clamp = torch.finfo(x.dtype).max - 1000 + x = torch.clamp(x, min=-clamp, max=clamp) + return x + + +class GELU(nn.Module): + + def forward(self, x): + return 0.5 * x * (1.0 + torch.tanh( + math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0)))) + + +class T5LayerNorm(nn.Module): + + def __init__(self, dim, eps=1e-6): + super(T5LayerNorm, self).__init__() + self.dim = dim + self.eps = eps + self.weight = nn.Parameter(torch.ones(dim)) + + def forward(self, x): + x = x * torch.rsqrt(x.float().pow(2).mean(dim=-1, keepdim=True) + + self.eps) + if self.weight.dtype in [torch.float16, torch.bfloat16]: + x = x.type_as(self.weight) + return self.weight * x + + +class T5Attention(nn.Module): + + def __init__(self, dim, dim_attn, num_heads, dropout=0.1): + assert dim_attn % num_heads == 0 + super(T5Attention, self).__init__() + self.dim = dim + self.dim_attn = dim_attn + self.num_heads = num_heads + self.head_dim = dim_attn // num_heads + + # layers + self.q = nn.Linear(dim, dim_attn, bias=False) + self.k = nn.Linear(dim, dim_attn, bias=False) + self.v = nn.Linear(dim, dim_attn, bias=False) + self.o = nn.Linear(dim_attn, dim, bias=False) + self.dropout = nn.Dropout(dropout) + + def forward(self, x, context=None, mask=None, pos_bias=None): + """ + x: [B, L1, C]. + context: [B, L2, C] or None. + mask: [B, L2] or [B, L1, L2] or None. + """ + # check inputs + context = x if context is None else context + b, n, c = x.size(0), self.num_heads, self.head_dim + + # compute query, key, value + q = self.q(x).view(b, -1, n, c) + k = self.k(context).view(b, -1, n, c) + v = self.v(context).view(b, -1, n, c) + + # attention bias + attn_bias = x.new_zeros(b, n, q.size(1), k.size(1)) + if pos_bias is not None: + attn_bias += pos_bias + if mask is not None: + assert mask.ndim in [2, 3] + mask = mask.view(b, 1, 1, + -1) if mask.ndim == 2 else mask.unsqueeze(1) + attn_bias.masked_fill_(mask == 0, torch.finfo(x.dtype).min) + + # compute attention (T5 does not use scaling) + attn = torch.einsum('binc,bjnc->bnij', q, k) + attn_bias + attn = F.softmax(attn.float(), dim=-1).type_as(attn) + x = torch.einsum('bnij,bjnc->binc', attn, v) + + # output + x = x.reshape(b, -1, n * c) + x = self.o(x) + x = self.dropout(x) + return x + + +class T5FeedForward(nn.Module): + + def __init__(self, dim, dim_ffn, dropout=0.1): + super(T5FeedForward, self).__init__() + self.dim = dim + self.dim_ffn = dim_ffn + + # layers + self.gate = nn.Sequential(nn.Linear(dim, dim_ffn, bias=False), GELU()) + self.fc1 = nn.Linear(dim, dim_ffn, bias=False) + self.fc2 = nn.Linear(dim_ffn, dim, bias=False) + self.dropout = nn.Dropout(dropout) + + def forward(self, x): + x = self.fc1(x) * self.gate(x) + x = self.dropout(x) + x = self.fc2(x) + x = self.dropout(x) + return x + + +class T5SelfAttention(nn.Module): + + def __init__(self, + dim, + dim_attn, + dim_ffn, + num_heads, + num_buckets, + shared_pos=True, + dropout=0.1): + super(T5SelfAttention, self).__init__() + self.dim = dim + self.dim_attn = dim_attn + self.dim_ffn = dim_ffn + self.num_heads = num_heads + self.num_buckets = num_buckets + self.shared_pos = shared_pos + + # layers + self.norm1 = T5LayerNorm(dim) + self.attn = T5Attention(dim, dim_attn, num_heads, dropout) + self.norm2 = T5LayerNorm(dim) + self.ffn = T5FeedForward(dim, dim_ffn, dropout) + self.pos_embedding = None if shared_pos else T5RelativeEmbedding( + num_buckets, num_heads, bidirectional=True) + + def forward(self, x, mask=None, pos_bias=None): + e = pos_bias if self.shared_pos else self.pos_embedding( + x.size(1), x.size(1)) + x = fp16_clamp(x + self.attn(self.norm1(x), mask=mask, pos_bias=e)) + x = fp16_clamp(x + self.ffn(self.norm2(x))) + return x + + +class T5RelativeEmbedding(nn.Module): + + def __init__(self, num_buckets, num_heads, bidirectional, max_dist=128): + super(T5RelativeEmbedding, self).__init__() + self.num_buckets = num_buckets + self.num_heads = num_heads + self.bidirectional = bidirectional + self.max_dist = max_dist + + # layers + self.embedding = nn.Embedding(num_buckets, num_heads) + + def forward(self, lq, lk): + device = self.embedding.weight.device + # rel_pos = torch.arange(lk).unsqueeze(0).to(device) - \ + # torch.arange(lq).unsqueeze(1).to(device) + rel_pos = torch.arange(lk, device=device).unsqueeze(0) - \ + torch.arange(lq, device=device).unsqueeze(1) + rel_pos = self._relative_position_bucket(rel_pos) + rel_pos_embeds = self.embedding(rel_pos) + rel_pos_embeds = rel_pos_embeds.permute(2, 0, 1).unsqueeze( + 0) # [1, N, Lq, Lk] + return rel_pos_embeds.contiguous() + + def _relative_position_bucket(self, rel_pos): + # preprocess + if self.bidirectional: + num_buckets = self.num_buckets // 2 + rel_buckets = (rel_pos > 0).long() * num_buckets + rel_pos = torch.abs(rel_pos) + else: + num_buckets = self.num_buckets + rel_buckets = 0 + rel_pos = -torch.min(rel_pos, torch.zeros_like(rel_pos)) + + # embeddings for small and large positions + max_exact = num_buckets // 2 + rel_pos_large = max_exact + (torch.log(rel_pos.float() / max_exact) / + math.log(self.max_dist / max_exact) * + (num_buckets - max_exact)).long() + rel_pos_large = torch.min( + rel_pos_large, torch.full_like(rel_pos_large, num_buckets - 1)) + rel_buckets += torch.where(rel_pos < max_exact, rel_pos, rel_pos_large) + return rel_buckets + +def init_weights(m): + if isinstance(m, T5LayerNorm): + nn.init.ones_(m.weight) + elif isinstance(m, T5FeedForward): + nn.init.normal_(m.gate[0].weight, std=m.dim**-0.5) + nn.init.normal_(m.fc1.weight, std=m.dim**-0.5) + nn.init.normal_(m.fc2.weight, std=m.dim_ffn**-0.5) + elif isinstance(m, T5Attention): + nn.init.normal_(m.q.weight, std=(m.dim * m.dim_attn)**-0.5) + nn.init.normal_(m.k.weight, std=m.dim**-0.5) + nn.init.normal_(m.v.weight, std=m.dim**-0.5) + nn.init.normal_(m.o.weight, std=(m.num_heads * m.dim_attn)**-0.5) + elif isinstance(m, T5RelativeEmbedding): + nn.init.normal_( + m.embedding.weight, std=(2 * m.num_buckets * m.num_heads)**-0.5) + + +class WanTextEncoder(torch.nn.Module): + + def __init__(self, + vocab=256384, + dim=4096, + dim_attn=4096, + dim_ffn=10240, + num_heads=64, + num_layers=24, + num_buckets=32, + shared_pos=False, + dropout=0.1): + super(WanTextEncoder, self).__init__() + self.dim = dim + self.dim_attn = dim_attn + self.dim_ffn = dim_ffn + self.num_heads = num_heads + self.num_layers = num_layers + self.num_buckets = num_buckets + self.shared_pos = shared_pos + + # layers + self.token_embedding = vocab if isinstance(vocab, nn.Embedding) \ + else nn.Embedding(vocab, dim) + self.pos_embedding = T5RelativeEmbedding( + num_buckets, num_heads, bidirectional=True) if shared_pos else None + self.dropout = nn.Dropout(dropout) + self.blocks = nn.ModuleList([ + T5SelfAttention(dim, dim_attn, dim_ffn, num_heads, num_buckets, + shared_pos, dropout) for _ in range(num_layers) + ]) + self.norm = T5LayerNorm(dim) + + # initialize weights + self.apply(init_weights) + + def forward(self, ids, mask=None): + x = self.token_embedding(ids) + x = self.dropout(x) + e = self.pos_embedding(x.size(1), + x.size(1)) if self.shared_pos else None + for block in self.blocks: + x = block(x, mask, pos_bias=e) + x = self.norm(x) + x = self.dropout(x) + return x + + @staticmethod + def state_dict_converter(): + return WanTextEncoderStateDictConverter() + + +class WanTextEncoderStateDictConverter: + def __init__(self): + pass + + def from_diffusers(self, state_dict): + return state_dict + + def from_civitai(self, state_dict): + return state_dict diff --git a/diffsynth/models/wan_video_vae.py b/diffsynth/models/wan_video_vae.py new file mode 100644 index 0000000..ebbee9d --- /dev/null +++ b/diffsynth/models/wan_video_vae.py @@ -0,0 +1,807 @@ +from einops import rearrange, repeat + +import torch +import torch.nn as nn +import torch.nn.functional as F +from tqdm import tqdm + +CACHE_T = 2 + + +def check_is_instance(model, module_class): + if isinstance(model, module_class): + return True + if hasattr(model, "module") and isinstance(model.module, module_class): + return True + return False + + +def block_causal_mask(x, block_size): + # params + b, n, s, _, device = *x.size(), x.device + assert s % block_size == 0 + num_blocks = s // block_size + + # build mask + mask = torch.zeros(b, n, s, s, dtype=torch.bool, device=device) + for i in range(num_blocks): + mask[:, :, + i * block_size:(i + 1) * block_size, :(i + 1) * block_size] = 1 + return mask + + +class CausalConv3d(nn.Conv3d): + """ + Causal 3d convolusion. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._padding = (self.padding[2], self.padding[2], self.padding[1], + self.padding[1], 2 * self.padding[0], 0) + self.padding = (0, 0, 0) + + def forward(self, x, cache_x=None): + padding = list(self._padding) + if cache_x is not None and self._padding[4] > 0: + cache_x = cache_x.to(x.device) + x = torch.cat([cache_x, x], dim=2) + padding[4] -= cache_x.shape[2] + x = F.pad(x, padding) + + return super().forward(x) + + +class RMS_norm(nn.Module): + + def __init__(self, dim, channel_first=True, images=True, bias=False): + super().__init__() + broadcastable_dims = (1, 1, 1) if not images else (1, 1) + shape = (dim, *broadcastable_dims) if channel_first else (dim,) + + self.channel_first = channel_first + self.scale = dim**0.5 + self.gamma = nn.Parameter(torch.ones(shape)) + self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0. + + def forward(self, x): + return F.normalize( + x, dim=(1 if self.channel_first else + -1)) * self.scale * self.gamma + self.bias + + +class Upsample(nn.Upsample): + + def forward(self, x): + """ + Fix bfloat16 support for nearest neighbor interpolation. + """ + return super().forward(x.float()).type_as(x) + + +class Resample(nn.Module): + + def __init__(self, dim, mode): + assert mode in ('none', 'upsample2d', 'upsample3d', 'downsample2d', + 'downsample3d') + super().__init__() + self.dim = dim + self.mode = mode + + # layers + if mode == 'upsample2d': + self.resample = nn.Sequential( + Upsample(scale_factor=(2., 2.), mode='nearest-exact'), + nn.Conv2d(dim, dim // 2, 3, padding=1)) + elif mode == 'upsample3d': + self.resample = nn.Sequential( + Upsample(scale_factor=(2., 2.), mode='nearest-exact'), + nn.Conv2d(dim, dim // 2, 3, padding=1)) + self.time_conv = CausalConv3d(dim, + dim * 2, (3, 1, 1), + padding=(1, 0, 0)) + + elif mode == 'downsample2d': + self.resample = nn.Sequential( + nn.ZeroPad2d((0, 1, 0, 1)), + nn.Conv2d(dim, dim, 3, stride=(2, 2))) + elif mode == 'downsample3d': + self.resample = nn.Sequential( + nn.ZeroPad2d((0, 1, 0, 1)), + nn.Conv2d(dim, dim, 3, stride=(2, 2))) + self.time_conv = CausalConv3d(dim, + dim, (3, 1, 1), + stride=(2, 1, 1), + padding=(0, 0, 0)) + + else: + self.resample = nn.Identity() + + def forward(self, x, feat_cache=None, feat_idx=[0]): + b, c, t, h, w = x.size() + if self.mode == 'upsample3d': + if feat_cache is not None: + idx = feat_idx[0] + if feat_cache[idx] is None: + feat_cache[idx] = 'Rep' + feat_idx[0] += 1 + else: + + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[ + idx] is not None and feat_cache[idx] != 'Rep': + # cache last frame of last two chunk + cache_x = torch.cat([ + feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to( + cache_x.device), cache_x + ], + dim=2) + if cache_x.shape[2] < 2 and feat_cache[ + idx] is not None and feat_cache[idx] == 'Rep': + cache_x = torch.cat([ + torch.zeros_like(cache_x).to(cache_x.device), + cache_x + ], + dim=2) + if feat_cache[idx] == 'Rep': + x = self.time_conv(x) + else: + x = self.time_conv(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + + x = x.reshape(b, 2, c, t, h, w) + x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]), + 3) + x = x.reshape(b, c, t * 2, h, w) + t = x.shape[2] + x = rearrange(x, 'b c t h w -> (b t) c h w') + x = self.resample(x) + x = rearrange(x, '(b t) c h w -> b c t h w', t=t) + + if self.mode == 'downsample3d': + if feat_cache is not None: + idx = feat_idx[0] + if feat_cache[idx] is None: + feat_cache[idx] = x.clone() + feat_idx[0] += 1 + else: + cache_x = x[:, :, -1:, :, :].clone() + x = self.time_conv( + torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2)) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + return x + + def init_weight(self, conv): + conv_weight = conv.weight + nn.init.zeros_(conv_weight) + c1, c2, t, h, w = conv_weight.size() + one_matrix = torch.eye(c1, c2) + init_matrix = one_matrix + nn.init.zeros_(conv_weight) + conv_weight.data[:, :, 1, 0, 0] = init_matrix + conv.weight.data.copy_(conv_weight) + nn.init.zeros_(conv.bias.data) + + def init_weight2(self, conv): + conv_weight = conv.weight.data + nn.init.zeros_(conv_weight) + c1, c2, t, h, w = conv_weight.size() + init_matrix = torch.eye(c1 // 2, c2) + conv_weight[:c1 // 2, :, -1, 0, 0] = init_matrix + conv_weight[c1 // 2:, :, -1, 0, 0] = init_matrix + conv.weight.data.copy_(conv_weight) + nn.init.zeros_(conv.bias.data) + + +class ResidualBlock(nn.Module): + + def __init__(self, in_dim, out_dim, dropout=0.0): + super().__init__() + self.in_dim = in_dim + self.out_dim = out_dim + + # layers + self.residual = nn.Sequential( + RMS_norm(in_dim, images=False), nn.SiLU(), + CausalConv3d(in_dim, out_dim, 3, padding=1), + RMS_norm(out_dim, images=False), nn.SiLU(), nn.Dropout(dropout), + CausalConv3d(out_dim, out_dim, 3, padding=1)) + self.shortcut = CausalConv3d(in_dim, out_dim, 1) \ + if in_dim != out_dim else nn.Identity() + + def forward(self, x, feat_cache=None, feat_idx=[0]): + h = self.shortcut(x) + for layer in self.residual: + if check_is_instance(layer, CausalConv3d) and feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + # cache last frame of last two chunk + cache_x = torch.cat([ + feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to( + cache_x.device), cache_x + ], + dim=2) + x = layer(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = layer(x) + return x + h + + +class AttentionBlock(nn.Module): + """ + Causal self-attention with a single head. + """ + + def __init__(self, dim): + super().__init__() + self.dim = dim + + # layers + self.norm = RMS_norm(dim) + self.to_qkv = nn.Conv2d(dim, dim * 3, 1) + self.proj = nn.Conv2d(dim, dim, 1) + + # zero out the last layer params + nn.init.zeros_(self.proj.weight) + + def forward(self, x): + identity = x + b, c, t, h, w = x.size() + x = rearrange(x, 'b c t h w -> (b t) c h w') + x = self.norm(x) + # compute query, key, value + q, k, v = self.to_qkv(x).reshape(b * t, 1, c * 3, -1).permute( + 0, 1, 3, 2).contiguous().chunk(3, dim=-1) + + # apply attention + x = F.scaled_dot_product_attention( + q, + k, + v, + #attn_mask=block_causal_mask(q, block_size=h * w) + ) + x = x.squeeze(1).permute(0, 2, 1).reshape(b * t, c, h, w) + + # output + x = self.proj(x) + x = rearrange(x, '(b t) c h w-> b c t h w', t=t) + return x + identity + + +class Encoder3d(nn.Module): + + def __init__(self, + dim=128, + z_dim=4, + dim_mult=[1, 2, 4, 4], + num_res_blocks=2, + attn_scales=[], + temperal_downsample=[True, True, False], + dropout=0.0): + super().__init__() + self.dim = dim + self.z_dim = z_dim + self.dim_mult = dim_mult + self.num_res_blocks = num_res_blocks + self.attn_scales = attn_scales + self.temperal_downsample = temperal_downsample + + # dimensions + dims = [dim * u for u in [1] + dim_mult] + scale = 1.0 + + # init block + self.conv1 = CausalConv3d(3, dims[0], 3, padding=1) + + # downsample blocks + downsamples = [] + for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])): + # residual (+attention) blocks + for _ in range(num_res_blocks): + downsamples.append(ResidualBlock(in_dim, out_dim, dropout)) + if scale in attn_scales: + downsamples.append(AttentionBlock(out_dim)) + in_dim = out_dim + + # downsample block + if i != len(dim_mult) - 1: + mode = 'downsample3d' if temperal_downsample[ + i] else 'downsample2d' + downsamples.append(Resample(out_dim, mode=mode)) + scale /= 2.0 + self.downsamples = nn.Sequential(*downsamples) + + # middle blocks + self.middle = nn.Sequential(ResidualBlock(out_dim, out_dim, dropout), + AttentionBlock(out_dim), + ResidualBlock(out_dim, out_dim, dropout)) + + # output blocks + self.head = nn.Sequential(RMS_norm(out_dim, images=False), nn.SiLU(), + CausalConv3d(out_dim, z_dim, 3, padding=1)) + + def forward(self, x, feat_cache=None, feat_idx=[0]): + if feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + # cache last frame of last two chunk + cache_x = torch.cat([ + feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to( + cache_x.device), cache_x + ], + dim=2) + x = self.conv1(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = self.conv1(x) + + ## downsamples + for layer in self.downsamples: + if feat_cache is not None: + x = layer(x, feat_cache, feat_idx) + else: + x = layer(x) + + ## middle + for layer in self.middle: + if check_is_instance(layer, ResidualBlock) and feat_cache is not None: + x = layer(x, feat_cache, feat_idx) + else: + x = layer(x) + + ## head + for layer in self.head: + if check_is_instance(layer, CausalConv3d) and feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + # cache last frame of last two chunk + cache_x = torch.cat([ + feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to( + cache_x.device), cache_x + ], + dim=2) + x = layer(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = layer(x) + return x + + +class Decoder3d(nn.Module): + + def __init__(self, + dim=128, + z_dim=4, + dim_mult=[1, 2, 4, 4], + num_res_blocks=2, + attn_scales=[], + temperal_upsample=[False, True, True], + dropout=0.0): + super().__init__() + self.dim = dim + self.z_dim = z_dim + self.dim_mult = dim_mult + self.num_res_blocks = num_res_blocks + self.attn_scales = attn_scales + self.temperal_upsample = temperal_upsample + + # dimensions + dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]] + scale = 1.0 / 2**(len(dim_mult) - 2) + + # init block + self.conv1 = CausalConv3d(z_dim, dims[0], 3, padding=1) + + # middle blocks + self.middle = nn.Sequential(ResidualBlock(dims[0], dims[0], dropout), + AttentionBlock(dims[0]), + ResidualBlock(dims[0], dims[0], dropout)) + + # upsample blocks + upsamples = [] + for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])): + # residual (+attention) blocks + if i == 1 or i == 2 or i == 3: + in_dim = in_dim // 2 + for _ in range(num_res_blocks + 1): + upsamples.append(ResidualBlock(in_dim, out_dim, dropout)) + if scale in attn_scales: + upsamples.append(AttentionBlock(out_dim)) + in_dim = out_dim + + # upsample block + if i != len(dim_mult) - 1: + mode = 'upsample3d' if temperal_upsample[i] else 'upsample2d' + upsamples.append(Resample(out_dim, mode=mode)) + scale *= 2.0 + self.upsamples = nn.Sequential(*upsamples) + + # output blocks + self.head = nn.Sequential(RMS_norm(out_dim, images=False), nn.SiLU(), + CausalConv3d(out_dim, 3, 3, padding=1)) + + def forward(self, x, feat_cache=None, feat_idx=[0]): + ## conv1 + if feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + # cache last frame of last two chunk + cache_x = torch.cat([ + feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to( + cache_x.device), cache_x + ], + dim=2) + x = self.conv1(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = self.conv1(x) + + ## middle + for layer in self.middle: + if check_is_instance(layer, ResidualBlock) and feat_cache is not None: + x = layer(x, feat_cache, feat_idx) + else: + x = layer(x) + + ## upsamples + for layer in self.upsamples: + if feat_cache is not None: + x = layer(x, feat_cache, feat_idx) + else: + x = layer(x) + + ## head + for layer in self.head: + if check_is_instance(layer, CausalConv3d) and feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + # cache last frame of last two chunk + cache_x = torch.cat([ + feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to( + cache_x.device), cache_x + ], + dim=2) + x = layer(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = layer(x) + return x + + +def count_conv3d(model): + count = 0 + for m in model.modules(): + if check_is_instance(m, CausalConv3d): + count += 1 + return count + + +class VideoVAE_(nn.Module): + + def __init__(self, + dim=96, + z_dim=16, + dim_mult=[1, 2, 4, 4], + num_res_blocks=2, + attn_scales=[], + temperal_downsample=[False, True, True], + dropout=0.0): + super().__init__() + self.dim = dim + self.z_dim = z_dim + self.dim_mult = dim_mult + self.num_res_blocks = num_res_blocks + self.attn_scales = attn_scales + self.temperal_downsample = temperal_downsample + self.temperal_upsample = temperal_downsample[::-1] + + # modules + self.encoder = Encoder3d(dim, z_dim * 2, dim_mult, num_res_blocks, + attn_scales, self.temperal_downsample, dropout) + self.conv1 = CausalConv3d(z_dim * 2, z_dim * 2, 1) + self.conv2 = CausalConv3d(z_dim, z_dim, 1) + self.decoder = Decoder3d(dim, z_dim, dim_mult, num_res_blocks, + attn_scales, self.temperal_upsample, dropout) + + def forward(self, x): + mu, log_var = self.encode(x) + z = self.reparameterize(mu, log_var) + x_recon = self.decode(z) + return x_recon, mu, log_var + + def encode(self, x, scale): + self.clear_cache() + ## cache + t = x.shape[2] + iter_ = 1 + (t - 1) // 4 + + for i in range(iter_): + self._enc_conv_idx = [0] + if i == 0: + out = self.encoder(x[:, :, :1, :, :], + feat_cache=self._enc_feat_map, + feat_idx=self._enc_conv_idx) + else: + out_ = self.encoder(x[:, :, 1 + 4 * (i - 1):1 + 4 * i, :, :], + feat_cache=self._enc_feat_map, + feat_idx=self._enc_conv_idx) + out = torch.cat([out, out_], 2) + mu, log_var = self.conv1(out).chunk(2, dim=1) + if isinstance(scale[0], torch.Tensor): + scale = [s.to(dtype=mu.dtype, device=mu.device) for s in scale] + mu = (mu - scale[0].view(1, self.z_dim, 1, 1, 1)) * scale[1].view( + 1, self.z_dim, 1, 1, 1) + else: + scale = scale.to(dtype=mu.dtype, device=mu.device) + mu = (mu - scale[0]) * scale[1] + return mu + + def decode(self, z, scale): + self.clear_cache() + # z: [b,c,t,h,w] + if isinstance(scale[0], torch.Tensor): + scale = [s.to(dtype=z.dtype, device=z.device) for s in scale] + z = z / scale[1].view(1, self.z_dim, 1, 1, 1) + scale[0].view( + 1, self.z_dim, 1, 1, 1) + else: + scale = scale.to(dtype=z.dtype, device=z.device) + z = z / scale[1] + scale[0] + iter_ = z.shape[2] + x = self.conv2(z) + for i in range(iter_): + self._conv_idx = [0] + if i == 0: + out = self.decoder(x[:, :, i:i + 1, :, :], + feat_cache=self._feat_map, + feat_idx=self._conv_idx) + else: + out_ = self.decoder(x[:, :, i:i + 1, :, :], + feat_cache=self._feat_map, + feat_idx=self._conv_idx) + out = torch.cat([out, out_], 2) # may add tensor offload + return out + + def reparameterize(self, mu, log_var): + std = torch.exp(0.5 * log_var) + eps = torch.randn_like(std) + return eps * std + mu + + def sample(self, imgs, deterministic=False): + mu, log_var = self.encode(imgs) + if deterministic: + return mu + std = torch.exp(0.5 * log_var.clamp(-30.0, 20.0)) + return mu + std * torch.randn_like(std) + + def clear_cache(self): + self._conv_num = count_conv3d(self.decoder) + self._conv_idx = [0] + self._feat_map = [None] * self._conv_num + # cache encode + self._enc_conv_num = count_conv3d(self.encoder) + self._enc_conv_idx = [0] + self._enc_feat_map = [None] * self._enc_conv_num + + +class WanVideoVAE(nn.Module): + + def __init__(self, z_dim=16): + super().__init__() + + mean = [ + -0.7571, -0.7089, -0.9113, 0.1075, -0.1745, 0.9653, -0.1517, 1.5508, + 0.4134, -0.0715, 0.5517, -0.3632, -0.1922, -0.9497, 0.2503, -0.2921 + ] + std = [ + 2.8184, 1.4541, 2.3275, 2.6558, 1.2196, 1.7708, 2.6052, 2.0743, + 3.2687, 2.1526, 2.8652, 1.5579, 1.6382, 1.1253, 2.8251, 1.9160 + ] + self.mean = torch.tensor(mean) + self.std = torch.tensor(std) + self.scale = [self.mean, 1.0 / self.std] + + # init model + self.model = VideoVAE_(z_dim=z_dim).eval().requires_grad_(False) + self.upsampling_factor = 8 + + + def build_1d_mask(self, length, left_bound, right_bound, border_width): + x = torch.ones((length,)) + if not left_bound: + x[:border_width] = (torch.arange(border_width) + 1) / border_width + if not right_bound: + x[-border_width:] = torch.flip((torch.arange(border_width) + 1) / border_width, dims=(0,)) + return x + + + def build_mask(self, data, is_bound, border_width): + _, _, _, H, W = data.shape + h = self.build_1d_mask(H, is_bound[0], is_bound[1], border_width[0]) + w = self.build_1d_mask(W, is_bound[2], is_bound[3], border_width[1]) + + h = repeat(h, "H -> H W", H=H, W=W) + w = repeat(w, "W -> H W", H=H, W=W) + + mask = torch.stack([h, w]).min(dim=0).values + mask = rearrange(mask, "H W -> 1 1 1 H W") + return mask + + + def tiled_decode(self, hidden_states, device, tile_size, tile_stride): + _, _, T, H, W = hidden_states.shape + size_h, size_w = tile_size + stride_h, stride_w = tile_stride + + # Split tasks + tasks = [] + for h in range(0, H, stride_h): + if (h-stride_h >= 0 and h-stride_h+size_h >= H): continue + for w in range(0, W, stride_w): + if (w-stride_w >= 0 and w-stride_w+size_w >= W): continue + h_, w_ = h + size_h, w + size_w + tasks.append((h, h_, w, w_)) + + data_device = "cpu" + computation_device = device + + out_T = T * 4 - 3 + weight = torch.zeros((1, 1, out_T, H * self.upsampling_factor, W * self.upsampling_factor), dtype=hidden_states.dtype, device=data_device) + values = torch.zeros((1, 3, out_T, H * self.upsampling_factor, W * self.upsampling_factor), dtype=hidden_states.dtype, device=data_device) + + for h, h_, w, w_ in tqdm(tasks, desc="VAE decoding"): + hidden_states_batch = hidden_states[:, :, :, h:h_, w:w_].to(computation_device) + hidden_states_batch = self.model.decode(hidden_states_batch, self.scale).to(data_device) + + mask = self.build_mask( + hidden_states_batch, + is_bound=(h==0, h_>=H, w==0, w_>=W), + border_width=((size_h - stride_h) * self.upsampling_factor, (size_w - stride_w) * self.upsampling_factor) + ).to(dtype=hidden_states.dtype, device=data_device) + + target_h = h * self.upsampling_factor + target_w = w * self.upsampling_factor + values[ + :, + :, + :, + target_h:target_h + hidden_states_batch.shape[3], + target_w:target_w + hidden_states_batch.shape[4], + ] += hidden_states_batch * mask + weight[ + :, + :, + :, + target_h: target_h + hidden_states_batch.shape[3], + target_w: target_w + hidden_states_batch.shape[4], + ] += mask + values = values / weight + values = values.float().clamp_(-1, 1) + return values + + + def tiled_encode(self, video, device, tile_size, tile_stride): + _, _, T, H, W = video.shape + size_h, size_w = tile_size + stride_h, stride_w = tile_stride + + # Split tasks + tasks = [] + for h in range(0, H, stride_h): + if (h-stride_h >= 0 and h-stride_h+size_h >= H): continue + for w in range(0, W, stride_w): + if (w-stride_w >= 0 and w-stride_w+size_w >= W): continue + h_, w_ = h + size_h, w + size_w + tasks.append((h, h_, w, w_)) + + data_device = "cpu" + computation_device = device + + out_T = (T + 3) // 4 + weight = torch.zeros((1, 1, out_T, H // self.upsampling_factor, W // self.upsampling_factor), dtype=video.dtype, device=data_device) + values = torch.zeros((1, 16, out_T, H // self.upsampling_factor, W // self.upsampling_factor), dtype=video.dtype, device=data_device) + + for h, h_, w, w_ in tqdm(tasks, desc="VAE encoding"): + hidden_states_batch = video[:, :, :, h:h_, w:w_].to(computation_device) + hidden_states_batch = self.model.encode(hidden_states_batch, self.scale).to(data_device) + + mask = self.build_mask( + hidden_states_batch, + is_bound=(h==0, h_>=H, w==0, w_>=W), + border_width=((size_h - stride_h) // self.upsampling_factor, (size_w - stride_w) // self.upsampling_factor) + ).to(dtype=video.dtype, device=data_device) + + target_h = h // self.upsampling_factor + target_w = w // self.upsampling_factor + values[ + :, + :, + :, + target_h:target_h + hidden_states_batch.shape[3], + target_w:target_w + hidden_states_batch.shape[4], + ] += hidden_states_batch * mask + weight[ + :, + :, + :, + target_h: target_h + hidden_states_batch.shape[3], + target_w: target_w + hidden_states_batch.shape[4], + ] += mask + values = values / weight + values = values.float() + return values + + + def single_encode(self, video, device): + video = video.to(device) + x = self.model.encode(video, self.scale) + return x.float() + + + def single_decode(self, hidden_state, device): + hidden_state = hidden_state.to(device) + video = self.model.decode(hidden_state, self.scale) + return video.float().clamp_(-1, 1) + + + def encode(self, videos, device, tiled=False, tile_size=(34, 34), tile_stride=(18, 16)): + + videos = [video.to("cpu") for video in videos] + hidden_states = [] + for video in videos: + video = video.unsqueeze(0) + if tiled: + tile_size = (tile_size[0] * 8, tile_size[1] * 8) + tile_stride = (tile_stride[0] * 8, tile_stride[1] * 8) + hidden_state = self.tiled_encode(video, device, tile_size, tile_stride) + else: + hidden_state = self.single_encode(video, device) + hidden_state = hidden_state.squeeze(0) + hidden_states.append(hidden_state) + hidden_states = torch.stack(hidden_states) + return hidden_states + + + def decode(self, hidden_states, device, tiled=False, tile_size=(34, 34), tile_stride=(18, 16)): + hidden_states = [hidden_state.to("cpu") for hidden_state in hidden_states] + videos = [] + for hidden_state in hidden_states: + hidden_state = hidden_state.unsqueeze(0) + if tiled: + video = self.tiled_decode(hidden_state, device, tile_size, tile_stride) + else: + video = self.single_decode(hidden_state, device) + video = video.squeeze(0) + videos.append(video) + return videos + + + @staticmethod + def state_dict_converter(): + return WanVideoVAEStateDictConverter() + + +class WanVideoVAEStateDictConverter: + + def __init__(self): + pass + + def from_civitai(self, state_dict): + state_dict_ = {} + if 'model_state' in state_dict: + state_dict = state_dict['model_state'] + for name in state_dict: + state_dict_['model.' + name] = state_dict[name] + return state_dict_ diff --git a/diffsynth/pipelines/__init__.py b/diffsynth/pipelines/__init__.py index 9a12372..e2ad551 100644 --- a/diffsynth/pipelines/__init__.py +++ b/diffsynth/pipelines/__init__.py @@ -11,4 +11,5 @@ from .omnigen_image import OmnigenImagePipeline from .pipeline_runner import SDVideoPipelineRunner from .hunyuan_video import HunyuanVideoPipeline from .step_video import StepVideoPipeline +from .wan_video import WanVideoPipeline KolorsImagePipeline = SDXLImagePipeline diff --git a/diffsynth/pipelines/wan_video.py b/diffsynth/pipelines/wan_video.py new file mode 100644 index 0000000..8d50bad --- /dev/null +++ b/diffsynth/pipelines/wan_video.py @@ -0,0 +1,267 @@ +from ..models import ModelManager +from ..models.wan_video_dit import WanModel +from ..models.wan_video_text_encoder import WanTextEncoder +from ..models.wan_video_vae import WanVideoVAE +from ..models.wan_video_image_encoder import WanImageEncoder +from ..schedulers.flow_match import FlowMatchScheduler +from .base import BasePipeline +from ..prompters import WanPrompter +import torch, os +from einops import rearrange +import numpy as np +from PIL import Image +from tqdm import tqdm + +from ..vram_management import enable_vram_management, AutoWrappedModule, AutoWrappedLinear +from ..models.wan_video_text_encoder import T5RelativeEmbedding, T5LayerNorm +from ..models.wan_video_dit import WanLayerNorm, WanRMSNorm +from ..models.wan_video_vae import RMS_norm, CausalConv3d, Upsample + + + +class WanVideoPipeline(BasePipeline): + + def __init__(self, device="cuda", torch_dtype=torch.float16, tokenizer_path=None): + super().__init__(device=device, torch_dtype=torch_dtype) + self.scheduler = FlowMatchScheduler(shift=5, sigma_min=0.0, extra_one_step=True) + self.prompter = WanPrompter(tokenizer_path=tokenizer_path) + self.text_encoder: WanTextEncoder = None + self.image_encoder: WanImageEncoder = None + self.dit: WanModel = None + self.vae: WanVideoVAE = None + self.model_names = ['text_encoder', 'dit', 'vae'] + + + def enable_vram_management(self, num_persistent_param_in_dit=None): + dtype = next(iter(self.text_encoder.parameters())).dtype + enable_vram_management( + self.text_encoder, + module_map = { + torch.nn.Linear: AutoWrappedLinear, + torch.nn.Embedding: AutoWrappedModule, + T5RelativeEmbedding: AutoWrappedModule, + T5LayerNorm: AutoWrappedModule, + }, + module_config = dict( + offload_dtype=dtype, + offload_device="cpu", + onload_dtype=dtype, + onload_device="cpu", + computation_dtype=self.torch_dtype, + computation_device=self.device, + ), + ) + dtype = next(iter(self.dit.parameters())).dtype + enable_vram_management( + self.dit, + module_map = { + torch.nn.Linear: AutoWrappedLinear, + torch.nn.Conv3d: AutoWrappedModule, + torch.nn.LayerNorm: AutoWrappedModule, + WanLayerNorm: AutoWrappedModule, + WanRMSNorm: AutoWrappedModule, + }, + module_config = dict( + offload_dtype=dtype, + offload_device="cpu", + onload_dtype=dtype, + onload_device=self.device, + computation_dtype=self.torch_dtype, + computation_device=self.device, + ), + max_num_param=num_persistent_param_in_dit, + overflow_module_config = dict( + offload_dtype=dtype, + offload_device="cpu", + onload_dtype=dtype, + onload_device="cpu", + computation_dtype=self.torch_dtype, + computation_device=self.device, + ), + ) + dtype = next(iter(self.vae.parameters())).dtype + enable_vram_management( + self.vae, + module_map = { + torch.nn.Linear: AutoWrappedLinear, + torch.nn.Conv2d: AutoWrappedModule, + RMS_norm: AutoWrappedModule, + CausalConv3d: AutoWrappedModule, + Upsample: AutoWrappedModule, + torch.nn.SiLU: AutoWrappedModule, + torch.nn.Dropout: AutoWrappedModule, + }, + module_config = dict( + offload_dtype=dtype, + offload_device="cpu", + onload_dtype=dtype, + onload_device="cpu", + computation_dtype=self.torch_dtype, + computation_device=self.device, + ), + ) + if self.image_encoder is not None: + dtype = next(iter(self.image_encoder.parameters())).dtype + enable_vram_management( + self.image_encoder, + module_map = { + torch.nn.Linear: AutoWrappedLinear, + torch.nn.Conv2d: AutoWrappedModule, + torch.nn.LayerNorm: AutoWrappedModule, + }, + module_config = dict( + offload_dtype=dtype, + offload_device="cpu", + onload_dtype=dtype, + onload_device="cpu", + computation_dtype=self.torch_dtype, + computation_device=self.device, + ), + ) + self.enable_cpu_offload() + + + def fetch_models(self, model_manager: ModelManager): + text_encoder_model_and_path = model_manager.fetch_model("wan_video_text_encoder", require_model_path=True) + if text_encoder_model_and_path is not None: + self.text_encoder, tokenizer_path = text_encoder_model_and_path + self.prompter.fetch_models(self.text_encoder) + self.prompter.fetch_tokenizer(os.path.join(os.path.dirname(tokenizer_path), "google/umt5-xxl")) + self.dit = model_manager.fetch_model("wan_video_dit") + self.vae = model_manager.fetch_model("wan_video_vae") + self.image_encoder = model_manager.fetch_model("wan_video_image_encoder") + + + @staticmethod + def from_model_manager(model_manager: ModelManager, torch_dtype=None, device=None): + if device is None: device = model_manager.device + if torch_dtype is None: torch_dtype = model_manager.torch_dtype + pipe = WanVideoPipeline(device=device, torch_dtype=torch_dtype) + pipe.fetch_models(model_manager) + return pipe + + + def denoising_model(self): + return self.dit + + + def encode_prompt(self, prompt, positive=True): + prompt_emb = self.prompter.encode_prompt(prompt, positive=positive) + return {"context": prompt_emb} + + + def encode_image(self, image, height, width): + with torch.amp.autocast(dtype=torch.bfloat16, device_type=torch.device(self.device).type): + image = self.preprocess_image(image.resize((width, height))).to(self.device) + clip_context = self.image_encoder.encode_image([image]) + msk = torch.ones(1, 81, height//8, width//8, device=self.device) + msk[:, 1:] = 0 + msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1) + msk = msk.view(1, msk.shape[1] // 4, 4, height//8, width//8) + msk = msk.transpose(1, 2)[0] + y = self.vae.encode([torch.concat([image.transpose(0, 1), torch.zeros(3, 80, height, width).to(image.device)], dim=1)], device=self.device)[0] + y = torch.concat([msk, y]) + return {"clip_fea": clip_context, "y": [y]} + + + def tensor2video(self, frames): + frames = rearrange(frames, "C T H W -> T H W C") + frames = ((frames.float() + 1) * 127.5).clip(0, 255).cpu().numpy().astype(np.uint8) + frames = [Image.fromarray(frame) for frame in frames] + return frames + + + def prepare_extra_input(self, latents=None): + return {"seq_len": latents.shape[2] * latents.shape[3] * latents.shape[4] // 4} + + + def encode_video(self, input_video, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)): + with torch.amp.autocast(dtype=torch.bfloat16, device_type=torch.device(self.device).type): + latents = self.vae.encode(input_video, device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride) + return latents + + + def decode_video(self, latents, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)): + with torch.amp.autocast(dtype=torch.bfloat16, device_type=torch.device(self.device).type): + frames = self.vae.decode(latents, device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride) + return frames + + + @torch.no_grad() + def __call__( + self, + prompt, + negative_prompt="", + input_image=None, + input_video=None, + denoising_strength=1.0, + seed=None, + rand_device="cpu", + height=480, + width=832, + num_frames=81, + cfg_scale=5.0, + num_inference_steps=50, + tiled=True, + tile_size=(34, 34), + tile_stride=(18, 16), + progress_bar_cmd=tqdm, + progress_bar_st=None, + ): + # Tiler parameters + tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride} + + # Scheduler + self.scheduler.set_timesteps(num_inference_steps, denoising_strength) + + # Initialize noise + noise = self.generate_noise((1, 16, (num_frames - 1) // 4 + 1, height//8, width//8), seed=seed, device=rand_device, dtype=torch.float32).to(self.device) + if input_video is not None: + self.load_models_to_device(['vae']) + input_video = self.preprocess_images(input_video) + input_video = torch.stack(input_video, dim=2) + latents = self.encode_video(input_video, **tiler_kwargs).to(dtype=noise.dtype, device=noise.device) + latents = self.scheduler.add_noise(latents, noise, timestep=self.scheduler.timesteps[0]) + else: + latents = noise + + # Encode prompts + self.load_models_to_device(["text_encoder"]) + prompt_emb_posi = self.encode_prompt(prompt, positive=True) + if cfg_scale != 1.0: + prompt_emb_nega = self.encode_prompt(negative_prompt, positive=False) + + # Encode image + if input_image is not None and self.image_encoder is not None: + self.load_models_to_device(["image_encoder", "vae"]) + image_emb = self.encode_image(input_image, height, width) + else: + image_emb = {} + + # Extra input + extra_input = self.prepare_extra_input(latents) + + # Denoise + self.load_models_to_device(["dit"]) + with torch.amp.autocast(dtype=torch.bfloat16, device_type=torch.device(self.device).type): + for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)): + timestep = timestep.unsqueeze(0).to(dtype=torch.float32, device=self.device) + + # Inference + noise_pred_posi = self.dit(latents, timestep=timestep, **prompt_emb_posi, **image_emb, **extra_input) + if cfg_scale != 1.0: + noise_pred_nega = self.dit(latents, timestep=timestep, **prompt_emb_nega, **image_emb, **extra_input) + noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega) + else: + noise_pred = noise_pred_posi + + # Scheduler + latents = self.scheduler.step(noise_pred, self.scheduler.timesteps[progress_id], latents) + + # Decode + self.load_models_to_device(['vae']) + frames = self.decode_video(latents, **tiler_kwargs) + self.load_models_to_device([]) + frames = self.tensor2video(frames[0]) + + return frames diff --git a/diffsynth/prompters/__init__.py b/diffsynth/prompters/__init__.py index 3559f1d..f27c6f1 100644 --- a/diffsynth/prompters/__init__.py +++ b/diffsynth/prompters/__init__.py @@ -9,3 +9,4 @@ from .omost import OmostPromter from .cog_prompter import CogPrompter from .hunyuan_video_prompter import HunyuanVideoPrompter from .stepvideo_prompter import StepVideoPrompter +from .wan_prompter import WanPrompter diff --git a/diffsynth/prompters/wan_prompter.py b/diffsynth/prompters/wan_prompter.py new file mode 100644 index 0000000..d2c578d --- /dev/null +++ b/diffsynth/prompters/wan_prompter.py @@ -0,0 +1,103 @@ +from .base_prompter import BasePrompter +from ..models.wan_video_text_encoder import WanTextEncoder +from transformers import AutoTokenizer +import os, torch +import html +import string +import regex as re + + +def basic_clean(text): + text = html.unescape(html.unescape(text)) + return text.strip() + +def whitespace_clean(text): + text = re.sub(r'\s+', ' ', text) + text = text.strip() + return text + +def canonicalize(text, keep_punctuation_exact_string=None): + text = text.replace('_', ' ') + if keep_punctuation_exact_string: + text = keep_punctuation_exact_string.join( + part.translate(str.maketrans('', '', string.punctuation)) + for part in text.split(keep_punctuation_exact_string)) + else: + text = text.translate(str.maketrans('', '', string.punctuation)) + text = text.lower() + text = re.sub(r'\s+', ' ', text) + return text.strip() + +class HuggingfaceTokenizer: + + def __init__(self, name, seq_len=None, clean=None, **kwargs): + assert clean in (None, 'whitespace', 'lower', 'canonicalize') + self.name = name + self.seq_len = seq_len + self.clean = clean + + # init tokenizer + self.tokenizer = AutoTokenizer.from_pretrained(name, **kwargs) + self.vocab_size = self.tokenizer.vocab_size + + def __call__(self, sequence, **kwargs): + return_mask = kwargs.pop('return_mask', False) + + # arguments + _kwargs = {'return_tensors': 'pt'} + if self.seq_len is not None: + _kwargs.update({ + 'padding': 'max_length', + 'truncation': True, + 'max_length': self.seq_len + }) + _kwargs.update(**kwargs) + + # tokenization + if isinstance(sequence, str): + sequence = [sequence] + if self.clean: + sequence = [self._clean(u) for u in sequence] + ids = self.tokenizer(sequence, **_kwargs) + + # output + if return_mask: + return ids.input_ids, ids.attention_mask + else: + return ids.input_ids + + def _clean(self, text): + if self.clean == 'whitespace': + text = whitespace_clean(basic_clean(text)) + elif self.clean == 'lower': + text = whitespace_clean(basic_clean(text)).lower() + elif self.clean == 'canonicalize': + text = canonicalize(basic_clean(text)) + return text + + +class WanPrompter(BasePrompter): + + def __init__(self, tokenizer_path=None, text_len=512): + super().__init__() + self.text_len = text_len + self.text_encoder = None + self.fetch_tokenizer(tokenizer_path) + + def fetch_tokenizer(self, tokenizer_path=None): + if tokenizer_path is not None: + self.tokenizer = HuggingfaceTokenizer(name=tokenizer_path, seq_len=self.text_len, clean='whitespace') + + def fetch_models(self, text_encoder: WanTextEncoder = None): + self.text_encoder = text_encoder + + def encode_prompt(self, prompt, positive=True, device="cuda"): + prompt = self.process_prompt(prompt, positive=positive) + + ids, mask = self.tokenizer(prompt, return_mask=True, add_special_tokens=True) + ids = ids.to(device) + mask = mask.to(device) + seq_lens = mask.gt(0).sum(dim=1).long() + prompt_emb = self.text_encoder(ids, mask) + prompt_emb = [u[:v] for u, v in zip(prompt_emb, seq_lens)] + return prompt_emb diff --git a/examples/wanvideo/README.md b/examples/wanvideo/README.md new file mode 100644 index 0000000..e4afaaa --- /dev/null +++ b/examples/wanvideo/README.md @@ -0,0 +1,144 @@ +# Wan-Video + +Wan-Video is a collection of video synthesis models open-sourced by Alibaba. + +## Inference + +### Wan-Video-1.3B-T2V + +Wan-Video-1.3B-T2V supports text-to-video and video-to-video. See [`./wan_1.3b_text_to_video.py`](./wan_1.3b_text_to_video.py). + +Required VRAM: 6G + +https://github.com/user-attachments/assets/124397be-cd6a-4f29-a87c-e4c695aaabb8 + +Put sunglasses on the dog. + +https://github.com/user-attachments/assets/272808d7-fbeb-4747-a6df-14a0860c75fb + +### Wan-Video-14B-T2V + +Wan-Video-14B-T2V is an enhanced version of Wan-Video-1.3B-T2V, offering greater size and power. To utilize this model, you need additional VRAM. We recommend that users adjust the `torch_dtype` and `num_persistent_param_in_dit` settings to find an optimal balance between speed and VRAM requirements. See [`./wan_14b_text_to_video.py`](./wan_14b_text_to_video.py). + +We present a detailed table here. The model is tested on a single A100. + +|`torch_dtype`|`num_persistent_param_in_dit`|Speed|Required VRAM|Default Setting| +|-|-|-|-|-| +|torch.bfloat16|None (unlimited)|18.5s/it|40G|| +|torch.bfloat16|7*10**9 (7B)|20.8s/it|24G|| +|torch.bfloat16|0|23.4s/it|10G|| +|torch.float8_e4m3fn|None (unlimited)|18.3s/it|24G|yes| +|torch.float8_e4m3fn|0|24.0s/it|10G|| + +https://github.com/user-attachments/assets/3908bc64-d451-485a-8b61-28f6d32dd92f + +### Wan-Video-14B-I2V + +Wan-Video-14B-I2V adds the functionality of image-to-video based on Wan-Video-14B-T2V. The model size remains the same, therefore the speed and VRAM requirements are also consistent. See [`./wan_14b_image_to_video.py`](./wan_14b_image_to_video.py). + +![Image](https://github.com/user-attachments/assets/adf8047f-7943-4aaa-a555-2b32dc415f39) + +https://github.com/user-attachments/assets/c0bdd5ca-292f-45ed-b9bc-afe193156e75 + +## Train + +We support Wan-Video LoRA training. Here is a tutorial. + +Step 1: Install additional packages + +``` +pip install peft lightning pandas +``` + +Step 2: Prepare your dataset + +You need to manage the training videos as follows: + +``` +data/example_dataset/ +├── metadata.csv +└── train + ├── video_00001.mp4 + └── video_00002.mp4 +``` + +`metadata.csv`: + +``` +file_name,text +video_00001.mp4,"video description" +video_00001.mp4,"video description" +``` + +Step 3: Data process + +```shell +CUDA_VISIBLE_DEVICES="0" python examples/wanvideo/train_wan_t2v.py \ + --task data_process \ + --dataset_path data/example_dataset \ + --output_path ./models \ + --text_encoder_path "models/Wan-AI/Wan2.1-T2V-1.3B/models_t5_umt5-xxl-enc-bf16.pth" \ + --vae_path "models/Wan-AI/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth" \ + --tiled \ + --num_frames 81 \ + --height 480 \ + --width 832 +``` + +After that, some cached files will be stored in the dataset folder. + +``` +data/example_dataset/ +├── metadata.csv +└── train + ├── video_00001.mp4 + ├── video_00001.mp4.tensors.pth + ├── video_00002.mp4 + └── video_00002.mp4.tensors.pth +``` + +Step 4: Train + +```shell +CUDA_VISIBLE_DEVICES="0" python examples/wanvideo/train_wan_t2v.py \ + --task train \ + --dataset_path data/example_dataset \ + --output_path ./models \ + --dit_path "models/Wan-AI/Wan2.1-T2V-1.3B/diffusion_pytorch_model.safetensors" \ + --steps_per_epoch 500 \ + --max_epochs 10 \ + --learning_rate 1e-4 \ + --lora_rank 4 \ + --lora_alpha 4 \ + --lora_target_modules "q,k,v,o,ffn.0,ffn.2" \ + --accumulate_grad_batches 1 \ + --use_gradient_checkpointing +``` + +Step 5: Test + +```python +import torch +from diffsynth import ModelManager, WanVideoPipeline, save_video, VideoData + + +model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu") +model_manager.load_models([ + "models/Wan-AI/Wan2.1-T2V-1.3B/diffusion_pytorch_model.safetensors", + "models/Wan-AI/Wan2.1-T2V-1.3B/models_t5_umt5-xxl-enc-bf16.pth", + "models/Wan-AI/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth", +]) +model_manager.load_lora("models/lightning_logs/version_1/checkpoints/epoch=0-step=500.ckpt", lora_alpha=1.0) + +pipe = WanVideoPipeline.from_model_manager(model_manager, device="cuda") +pipe.enable_vram_management(num_persistent_param_in_dit=None) + +# Text-to-video +video = pipe( + prompt="...", + negative_prompt="...", + num_inference_steps=50, + seed=0, tiled=True +) +save_video(video, "video_with_lora.mp4", fps=30, quality=5) +``` diff --git a/examples/wanvideo/train_wan_t2v.py b/examples/wanvideo/train_wan_t2v.py new file mode 100644 index 0000000..8bc2134 --- /dev/null +++ b/examples/wanvideo/train_wan_t2v.py @@ -0,0 +1,460 @@ +import torch, os, imageio, argparse +from torchvision.transforms import v2 +from einops import rearrange +import lightning as pl +import pandas as pd +from diffsynth import WanVideoPipeline, ModelManager +from peft import LoraConfig, inject_adapter_in_model +import torchvision +from PIL import Image + + + +class TextVideoDataset(torch.utils.data.Dataset): + def __init__(self, base_path, metadata_path, max_num_frames=81, frame_interval=1, num_frames=81, height=480, width=832): + metadata = pd.read_csv(metadata_path) + self.path = [os.path.join(base_path, "train", file_name) for file_name in metadata["file_name"]] + self.text = metadata["text"].to_list() + + self.max_num_frames = max_num_frames + self.frame_interval = frame_interval + self.num_frames = num_frames + self.height = height + self.width = width + + self.frame_process = v2.Compose([ + v2.Resize(size=(height, width), antialias=True), + v2.CenterCrop(size=(height, width)), + v2.ToTensor(), + v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), + ]) + + + def crop_and_resize(self, image): + width, height = image.size + scale = max(self.width / width, self.height / height) + image = torchvision.transforms.functional.resize( + image, + (round(height*scale), round(width*scale)), + interpolation=torchvision.transforms.InterpolationMode.BILINEAR + ) + return image + + + def load_frames_using_imageio(self, file_path, max_num_frames, start_frame_id, interval, num_frames, frame_process): + reader = imageio.get_reader(file_path) + if reader.count_frames() < max_num_frames or reader.count_frames() - 1 < start_frame_id + (num_frames - 1) * interval: + reader.close() + return None + + frames = [] + for frame_id in range(num_frames): + frame = reader.get_data(start_frame_id + frame_id * interval) + frame = Image.fromarray(frame) + frame = self.crop_and_resize(frame) + frame = frame_process(frame) + frames.append(frame) + reader.close() + + frames = torch.stack(frames, dim=0) + frames = rearrange(frames, "T C H W -> C T H W") + + return frames + + + def load_video(self, file_path): + start_frame_id = torch.randint(0, self.max_num_frames - (self.num_frames - 1) * self.frame_interval, (1,))[0] + frames = self.load_frames_using_imageio(file_path, self.max_num_frames, start_frame_id, self.frame_interval, self.num_frames, self.frame_process) + return frames + + + def load_text_video_raw_data(self, data_id): + text = self.path[data_id] + video = self.load_video(self.path[data_id]) + data = {"text": text, "video": video} + return data + + + def __getitem__(self, data_id): + text = self.path[data_id] + path = self.path[data_id] + video = self.load_video(path) + data = {"text": text, "video": video, "path": path} + return data + + + def __len__(self): + return len(self.path) + + + +class LightningModelForDataProcess(pl.LightningModule): + def __init__(self, text_encoder_path, vae_path, tiled=False, tile_size=(34, 34), tile_stride=(18, 16)): + super().__init__() + model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu") + model_manager.load_models([text_encoder_path, vae_path]) + self.pipe = WanVideoPipeline.from_model_manager(model_manager) + + self.tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride} + + def test_step(self, batch, batch_idx): + text, video, path = batch["text"][0], batch["video"], batch["path"][0] + self.pipe.device = self.device + if video is not None: + prompt_emb = self.pipe.encode_prompt(text) + latents = self.pipe.encode_video(video, **self.tiler_kwargs)[0] + data = {"latents": latents, "prompt_emb": prompt_emb} + torch.save(data, path + ".tensors.pth") + + + +class TensorDataset(torch.utils.data.Dataset): + def __init__(self, base_path, metadata_path, steps_per_epoch): + metadata = pd.read_csv(metadata_path) + self.path = [os.path.join(base_path, "train", file_name) for file_name in metadata["file_name"]] + print(len(self.path), "videos in metadata.") + self.path = [i + ".tensors.pth" for i in self.path if os.path.exists(i + ".tensors.pth")] + print(len(self.path), "tensors cached in metadata.") + assert len(self.path) > 0 + + self.steps_per_epoch = steps_per_epoch + + + def __getitem__(self, index): + data_id = torch.randint(0, len(self.path), (1,))[0] + data_id = (data_id + index) % len(self.path) # For fixed seed. + path = self.path[data_id] + data = torch.load(path, weights_only=True, map_location="cpu") + return data + + + def __len__(self): + return self.steps_per_epoch + + + +class LightningModelForTrain(pl.LightningModule): + def __init__(self, dit_path, learning_rate=1e-5, lora_rank=4, lora_alpha=4, lora_target_modules="q,k,v,o,ffn.0,ffn.2", init_lora_weights="kaiming", use_gradient_checkpointing=True): + super().__init__() + model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu") + model_manager.load_models([dit_path]) + + self.pipe = WanVideoPipeline.from_model_manager(model_manager) + self.pipe.scheduler.set_timesteps(1000, training=True) + self.freeze_parameters() + self.add_lora_to_model( + self.pipe.denoising_model(), + lora_rank=lora_rank, + lora_alpha=lora_alpha, + lora_target_modules=lora_target_modules, + init_lora_weights=init_lora_weights, + ) + + self.learning_rate = learning_rate + self.use_gradient_checkpointing = use_gradient_checkpointing + + + def freeze_parameters(self): + # Freeze parameters + self.pipe.requires_grad_(False) + self.pipe.eval() + self.pipe.denoising_model().train() + + + def add_lora_to_model(self, model, lora_rank=4, lora_alpha=4, lora_target_modules="q,k,v,o,ffn.0,ffn.2", init_lora_weights="kaiming"): + # Add LoRA to UNet + self.lora_alpha = lora_alpha + if init_lora_weights == "kaiming": + init_lora_weights = True + + lora_config = LoraConfig( + r=lora_rank, + lora_alpha=lora_alpha, + init_lora_weights=init_lora_weights, + target_modules=lora_target_modules.split(","), + ) + model = inject_adapter_in_model(lora_config, model) + for param in model.parameters(): + # Upcast LoRA parameters into fp32 + if param.requires_grad: + param.data = param.to(torch.float32) + + + def training_step(self, batch, batch_idx): + # Data + latents = batch["latents"].to(self.device) + prompt_emb = batch["prompt_emb"] + prompt_emb["context"] = [prompt_emb["context"][0][0].to(self.device)] + + # Loss + noise = torch.randn_like(latents) + timestep_id = torch.randint(0, self.pipe.scheduler.num_train_timesteps, (1,)) + timestep = self.pipe.scheduler.timesteps[timestep_id].to(self.device) + extra_input = self.pipe.prepare_extra_input(latents) + noisy_latents = self.pipe.scheduler.add_noise(latents, noise, timestep) + training_target = self.pipe.scheduler.training_target(latents, noise, timestep) + + # Compute loss + with torch.amp.autocast(dtype=torch.bfloat16, device_type=torch.device(self.device).type): + noise_pred = self.pipe.denoising_model()( + noisy_latents, timestep=timestep, **prompt_emb, **extra_input, + use_gradient_checkpointing=self.use_gradient_checkpointing + ) + loss = torch.nn.functional.mse_loss(noise_pred.float(), training_target.float()) + loss = loss * self.pipe.scheduler.training_weight(timestep) + + # Record log + self.log("train_loss", loss, prog_bar=True) + return loss + + + def configure_optimizers(self): + trainable_modules = filter(lambda p: p.requires_grad, self.pipe.denoising_model().parameters()) + optimizer = torch.optim.AdamW(trainable_modules, lr=self.learning_rate) + return optimizer + + + def on_save_checkpoint(self, checkpoint): + checkpoint.clear() + trainable_param_names = list(filter(lambda named_param: named_param[1].requires_grad, self.pipe.denoising_model().named_parameters())) + trainable_param_names = set([named_param[0] for named_param in trainable_param_names]) + state_dict = self.pipe.denoising_model().state_dict() + lora_state_dict = {} + for name, param in state_dict.items(): + if name in trainable_param_names: + lora_state_dict[name] = param + checkpoint.update(lora_state_dict) + + + +def parse_args(): + parser = argparse.ArgumentParser(description="Simple example of a training script.") + parser.add_argument( + "--task", + type=str, + default="data_process", + required=True, + choices=["data_process", "train"], + help="Task. `data_process` or `train`.", + ) + parser.add_argument( + "--dataset_path", + type=str, + default=None, + required=True, + help="The path of the Dataset.", + ) + parser.add_argument( + "--output_path", + type=str, + default="./", + help="Path to save the model.", + ) + parser.add_argument( + "--text_encoder_path", + type=str, + default=None, + help="Path of text encoder.", + ) + parser.add_argument( + "--vae_path", + type=str, + default=None, + help="Path of VAE.", + ) + parser.add_argument( + "--dit_path", + type=str, + default=None, + help="Path of DiT.", + ) + parser.add_argument( + "--tiled", + default=False, + action="store_true", + help="Whether enable tile encode in VAE. This option can reduce VRAM required.", + ) + parser.add_argument( + "--tile_size_height", + type=int, + default=34, + help="Tile size (height) in VAE.", + ) + parser.add_argument( + "--tile_size_width", + type=int, + default=34, + help="Tile size (width) in VAE.", + ) + parser.add_argument( + "--tile_stride_height", + type=int, + default=18, + help="Tile stride (height) in VAE.", + ) + parser.add_argument( + "--tile_stride_width", + type=int, + default=16, + help="Tile stride (width) in VAE.", + ) + parser.add_argument( + "--steps_per_epoch", + type=int, + default=500, + help="Number of steps per epoch.", + ) + parser.add_argument( + "--num_frames", + type=int, + default=81, + help="Number of frames.", + ) + parser.add_argument( + "--height", + type=int, + default=480, + help="Image height.", + ) + parser.add_argument( + "--width", + type=int, + default=832, + help="Image width.", + ) + parser.add_argument( + "--dataloader_num_workers", + type=int, + default=1, + help="Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process.", + ) + parser.add_argument( + "--learning_rate", + type=float, + default=1e-5, + help="Learning rate.", + ) + parser.add_argument( + "--accumulate_grad_batches", + type=int, + default=1, + help="The number of batches in gradient accumulation.", + ) + parser.add_argument( + "--max_epochs", + type=int, + default=1, + help="Number of epochs.", + ) + parser.add_argument( + "--lora_target_modules", + type=str, + default="q,k,v,o,ffn.0,ffn.2", + help="Layers with LoRA modules.", + ) + parser.add_argument( + "--init_lora_weights", + type=str, + default="kaiming", + choices=["gaussian", "kaiming"], + help="The initializing method of LoRA weight.", + ) + parser.add_argument( + "--training_strategy", + type=str, + default="auto", + choices=["auto", "deepspeed_stage_1", "deepspeed_stage_2", "deepspeed_stage_3"], + help="Training strategy", + ) + parser.add_argument( + "--lora_rank", + type=int, + default=4, + help="The dimension of the LoRA update matrices.", + ) + parser.add_argument( + "--lora_alpha", + type=float, + default=4.0, + help="The weight of the LoRA update matrices.", + ) + parser.add_argument( + "--use_gradient_checkpointing", + default=False, + action="store_true", + help="Whether to use gradient checkpointing.", + ) + args = parser.parse_args() + return args + + +def data_process(args): + dataset = TextVideoDataset( + args.dataset_path, + os.path.join(args.dataset_path, "metadata.csv"), + max_num_frames=args.num_frames, + frame_interval=1, + num_frames=args.num_frames, + height=args.height, + width=args.width + ) + dataloader = torch.utils.data.DataLoader( + dataset, + shuffle=False, + batch_size=1, + num_workers=args.dataloader_num_workers + ) + model = LightningModelForDataProcess( + text_encoder_path=args.text_encoder_path, + vae_path=args.vae_path, + tiled=args.tiled, + tile_size=(args.tile_size_height, args.tile_size_width), + tile_stride=(args.tile_stride_height, args.tile_stride_width), + ) + trainer = pl.Trainer( + accelerator="gpu", + devices="auto", + default_root_dir=args.output_path, + ) + trainer.test(model, dataloader) + + +def train(args): + dataset = TensorDataset( + args.dataset_path, + os.path.join(args.dataset_path, "metadata.csv"), + steps_per_epoch=args.steps_per_epoch, + ) + dataloader = torch.utils.data.DataLoader( + dataset, + shuffle=True, + batch_size=1, + num_workers=args.dataloader_num_workers + ) + model = LightningModelForTrain( + dit_path=args.dit_path, + learning_rate=args.learning_rate, + lora_rank=args.lora_rank, + lora_alpha=args.lora_alpha, + lora_target_modules=args.lora_target_modules, + init_lora_weights=args.init_lora_weights, + use_gradient_checkpointing=args.use_gradient_checkpointing + ) + trainer = pl.Trainer( + max_epochs=args.max_epochs, + accelerator="gpu", + devices="auto", + strategy=args.training_strategy, + default_root_dir=args.output_path, + accumulate_grad_batches=args.accumulate_grad_batches, + callbacks=[pl.pytorch.callbacks.ModelCheckpoint(save_top_k=-1)] + ) + trainer.fit(model, dataloader) + + +if __name__ == '__main__': + args = parse_args() + if args.task == "data_process": + data_process(args) + elif args.task == "train": + train(args) diff --git a/examples/wanvideo/wan_1.3b_text_to_video.py b/examples/wanvideo/wan_1.3b_text_to_video.py new file mode 100644 index 0000000..0f4f1d9 --- /dev/null +++ b/examples/wanvideo/wan_1.3b_text_to_video.py @@ -0,0 +1,40 @@ +import torch +from diffsynth import ModelManager, WanVideoPipeline, save_video, VideoData +from modelscope import snapshot_download + + +# Download models +snapshot_download("Wan-AI/Wan2.1-T2V-1.3B", cache_dir="models") + +# Load models +model_manager = ModelManager(device="cpu") +model_manager.load_models( + [ + "models/Wan-AI/Wan2.1-T2V-1.3B/diffusion_pytorch_model.safetensors", + "models/Wan-AI/Wan2.1-T2V-1.3B/models_t5_umt5-xxl-enc-bf16.pth", + "models/Wan-AI/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth", + ], + torch_dtype=torch.bfloat16, # You can set `torch_dtype=torch.float8_e4m3fn` to enable FP8 quantization. +) +pipe = WanVideoPipeline.from_model_manager(model_manager, torch_dtype=torch.bfloat16, device="cuda") +pipe.enable_vram_management(num_persistent_param_in_dit=None) + +# Text-to-video +video = pipe( + prompt="纪实摄影风格画面,一只活泼的小狗在绿茵茵的草地上迅速奔跑。小狗毛色棕黄,两只耳朵立起,神情专注而欢快。阳光洒在它身上,使得毛发看上去格外柔软而闪亮。背景是一片开阔的草地,偶尔点缀着几朵野花,远处隐约可见蓝天和几片白云。透视感鲜明,捕捉小狗奔跑时的动感和四周草地的生机。中景侧面移动视角。", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + num_inference_steps=50, + seed=0, tiled=True +) +save_video(video, "video1.mp4", fps=15, quality=5) + +# Video-to-video +video = VideoData("video1.mp4", height=480, width=832) +video = pipe( + prompt="纪实摄影风格画面,一只活泼的小狗戴着黑色墨镜在绿茵茵的草地上迅速奔跑。小狗毛色棕黄,戴着黑色墨镜,两只耳朵立起,神情专注而欢快。阳光洒在它身上,使得毛发看上去格外柔软而闪亮。背景是一片开阔的草地,偶尔点缀着几朵野花,远处隐约可见蓝天和几片白云。透视感鲜明,捕捉小狗奔跑时的动感和四周草地的生机。中景侧面移动视角。", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + input_video=video, denoising_strength=0.7, + num_inference_steps=50, + seed=1, tiled=True +) +save_video(video, "video2.mp4", fps=15, quality=5) diff --git a/examples/wanvideo/wan_14b_image_to_video.py b/examples/wanvideo/wan_14b_image_to_video.py new file mode 100644 index 0000000..d3153fd --- /dev/null +++ b/examples/wanvideo/wan_14b_image_to_video.py @@ -0,0 +1,48 @@ +import torch +from diffsynth import ModelManager, WanVideoPipeline, save_video, VideoData +from modelscope import snapshot_download, dataset_snapshot_download +from PIL import Image + + +# Download models +snapshot_download("Wan-AI/Wan2.1-I2V-14B-480P", cache_dir="models") + +# Load models +model_manager = ModelManager(device="cpu") +model_manager.load_models( + [ + [ + "models/Wan-AI/Wan2.1-I2V-14B-480P/diffusion_pytorch_model-00001-of-00007.safetensors", + "models/Wan-AI/Wan2.1-I2V-14B-480P/diffusion_pytorch_model-00002-of-00007.safetensors", + "models/Wan-AI/Wan2.1-I2V-14B-480P/diffusion_pytorch_model-00003-of-00007.safetensors", + "models/Wan-AI/Wan2.1-I2V-14B-480P/diffusion_pytorch_model-00004-of-00007.safetensors", + "models/Wan-AI/Wan2.1-I2V-14B-480P/diffusion_pytorch_model-00005-of-00007.safetensors", + "models/Wan-AI/Wan2.1-I2V-14B-480P/diffusion_pytorch_model-00006-of-00007.safetensors", + "models/Wan-AI/Wan2.1-I2V-14B-480P/diffusion_pytorch_model-00007-of-00007.safetensors", + ], + "models/Wan-AI/Wan2.1-I2V-14B-480P/models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth", + "models/Wan-AI/Wan2.1-I2V-14B-480P/models_t5_umt5-xxl-enc-bf16.pth", + "models/Wan-AI/Wan2.1-I2V-14B-480P/Wan2.1_VAE.pth", + ], + torch_dtype=torch.float8_e4m3fn, # You can set `torch_dtype=torch.bfloat16` to disable FP8 quantization. +) +pipe = WanVideoPipeline.from_model_manager(model_manager, torch_dtype=torch.bfloat16, device="cuda") +pipe.enable_vram_management(num_persistent_param_in_dit=None) # You can set `num_persistent_param_in_dit` to a small number to reduce VRAM required. + +# Download example image +dataset_snapshot_download( + dataset_id="DiffSynth-Studio/examples_in_diffsynth", + local_dir="./", + allow_file_pattern=f"data/examples/wan/input_image.jpg" +) +image = Image.open("data/examples/wan/input_image.jpg") + +# Image-to-video +video = pipe( + prompt="一艘小船正勇敢地乘风破浪前行。蔚蓝的大海波涛汹涌,白色的浪花拍打着船身,但小船毫不畏惧,坚定地驶向远方。阳光洒在水面上,闪烁着金色的光芒,为这壮丽的场景增添了一抹温暖。镜头拉近,可以看到船上的旗帜迎风飘扬,象征着不屈的精神与冒险的勇气。这段画面充满力量,激励人心,展现了面对挑战时的无畏与执着。", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + input_image=image, + num_inference_steps=50, + seed=0, tiled=True +) +save_video(video, "video.mp4", fps=15, quality=5) diff --git a/examples/wanvideo/wan_14b_text_to_video.py b/examples/wanvideo/wan_14b_text_to_video.py new file mode 100644 index 0000000..4c5a281 --- /dev/null +++ b/examples/wanvideo/wan_14b_text_to_video.py @@ -0,0 +1,38 @@ +import torch +from diffsynth import ModelManager, WanVideoPipeline, save_video, VideoData +from modelscope import snapshot_download + + +# Download models +snapshot_download("Wan-AI/Wan2.1-T2V-14B", cache_dir="models") + +# Load models +model_manager = ModelManager(device="cpu") +model_manager.load_models( + [ + [ + "models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00001-of-00007.safetensors", + "models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00002-of-00007.safetensors", + "models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00003-of-00007.safetensors", + "models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00004-of-00007.safetensors", + "models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00005-of-00007.safetensors", + "models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00006-of-00007.safetensors", + "models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00007-of-00007.safetensors", + ], + "models/Wan-AI/Wan2.1-T2V-14B/models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth", + "models/Wan-AI/Wan2.1-T2V-14B/models_t5_umt5-xxl-enc-bf16.pth", + "models/Wan-AI/Wan2.1-T2V-14B/Wan2.1_VAE.pth", + ], + torch_dtype=torch.float8_e4m3fn, # You can set `torch_dtype=torch.bfloat16` to disable FP8 quantization. +) +pipe = WanVideoPipeline.from_model_manager(model_manager, torch_dtype=torch.bfloat16, device="cuda") +pipe.enable_vram_management(num_persistent_param_in_dit=None) # You can set `num_persistent_param_in_dit` to a small number to reduce VRAM required. + +# Text-to-video +video = pipe( + prompt="一名宇航员身穿太空服,面朝镜头骑着一匹机械马在火星表面驰骋。红色的荒凉地表延伸至远方,点缀着巨大的陨石坑和奇特的岩石结构。机械马的步伐稳健,扬起微弱的尘埃,展现出未来科技与原始探索的完美结合。宇航员手持操控装置,目光坚定,仿佛正在开辟人类的新疆域。背景是深邃的宇宙和蔚蓝的地球,画面既科幻又充满希望,让人不禁畅想未来的星际生活。", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + num_inference_steps=50, + seed=0, tiled=True +) +save_video(video, "video1.mp4", fps=25, quality=5)