diff --git a/README.md b/README.md index 11403a5..87bca13 100644 --- a/README.md +++ b/README.md @@ -167,6 +167,9 @@ save_video(video, "video1.mp4", fps=15, quality=5) |[Wan-AI/Wan2.1-VACE-1.3B](https://modelscope.cn/models/Wan-AI/Wan2.1-VACE-1.3B)|`vace_control_video`, `vace_reference_image`|[code](./examples/wanvideo/model_inference/Wan2.1-VACE-1.3B.py)|[code](./examples/wanvideo/model_training/full/Wan2.1-VACE-1.3B.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.1-VACE-1.3B.py)|[code](./examples/wanvideo/model_training/lora/Wan2.1-VACE-1.3B.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.1-VACE-1.3B.py)| |[Wan-AI/Wan2.1-VACE-14B](https://modelscope.cn/models/Wan-AI/Wan2.1-VACE-14B)|`vace_control_video`, `vace_reference_image`|[code](./examples/wanvideo/model_inference/Wan2.1-VACE-14B.py)|[code](./examples/wanvideo/model_training/full/Wan2.1-VACE-14B.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.1-VACE-14B.py)|[code](./examples/wanvideo/model_training/lora/Wan2.1-VACE-14B.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.1-VACE-14B.py)| |[DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1](https://modelscope.cn/models/DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1)|`motion_bucket_id`|[code](./examples/wanvideo/model_inference/Wan2.1-1.3b-speedcontrol-v1.py)|[code](./examples/wanvideo/model_training/full/Wan2.1-1.3b-speedcontrol-v1.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.1-1.3b-speedcontrol-v1.py)|[code](./examples/wanvideo/model_training/lora/Wan2.1-1.3b-speedcontrol-v1.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.1-1.3b-speedcontrol-v1.py)| +|[Wan-AI/Wan2.2-I2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-I2V-A14B)|`input_image`|[code](./examples/wanvideo/model_inference/Wan2.2-I2V-A14B.py)|[code](./examples/wanvideo/model_training/full/Wan2.2-I2V-A14B.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.2-I2V-A14B.py)|[code](./examples/wanvideo/model_training/lora/Wan2.2-I2V-A14B.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.2-I2V-A14B.py)| +|[Wan-AI/Wan2.2-T2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-T2V-A14B)||[code](./examples/wanvideo/model_inference/Wan2.2-T2V-A14B.py)|[code](./examples/wanvideo/model_training/full/Wan2.2-T2V-A14B.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.2-T2V-A14B.py)|[code](./examples/wanvideo/model_training/lora/Wan2.2-T2V-A14B.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.2-T2V-A14B.py)| +|[Wan-AI/Wan2.2-TI2V-5B](https://modelscope.cn/models/Wan-AI/Wan2.2-TI2V-5B)|`input_image`|[code](./examples/wanvideo/model_inference/Wan2.2-TI2V-5B.py)|[code](./examples/wanvideo/model_training/full/Wan2.2-TI2V-5B.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.2-TI2V-5B.py)|[code](./examples/wanvideo/model_training/lora/Wan2.2-TI2V-5B.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.2-TI2V-5B.py)| @@ -317,7 +320,9 @@ https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/59fb2f7b-8de0-44 ## Update History -- **July 11, 2025** 🔥🔥🔥 We propose Nexus-Gen, a unified model that synergizes the language reasoning capabilities of LLMs with the image synthesis power of diffusion models. This framework enables seamless image understanding, generation, and editing tasks. +- **July 28, 2025** 🔥🔥🔥 With the open-sourcing of Wan 2.2, we immediately provided comprehensive support, including low-GPU-memory layer-by-layer offload, FP8 quantization, sequence parallelism, LoRA training, full training. See [./examples/wanvideo/](./examples/wanvideo/). + +- **July 11, 2025** We propose Nexus-Gen, a unified model that synergizes the language reasoning capabilities of LLMs with the image synthesis power of diffusion models. This framework enables seamless image understanding, generation, and editing tasks. - Paper: [Nexus-Gen: Unified Image Understanding, Generation, and Editing via Prefilled Autoregression in Shared Embedding Space](https://arxiv.org/pdf/2504.21356) - Github Repo: https://github.com/modelscope/Nexus-Gen - Model: [ModelScope](https://www.modelscope.cn/models/DiffSynth-Studio/Nexus-GenV2), [HuggingFace](https://huggingface.co/modelscope/Nexus-GenV2) diff --git a/README_zh.md b/README_zh.md index 650d2ec..bfcaa46 100644 --- a/README_zh.md +++ b/README_zh.md @@ -167,6 +167,9 @@ save_video(video, "video1.mp4", fps=15, quality=5) |[Wan-AI/Wan2.1-VACE-1.3B](https://modelscope.cn/models/Wan-AI/Wan2.1-VACE-1.3B)|`vace_control_video`, `vace_reference_image`|[code](./examples/wanvideo/model_inference/Wan2.1-VACE-1.3B.py)|[code](./examples/wanvideo/model_training/full/Wan2.1-VACE-1.3B.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.1-VACE-1.3B.py)|[code](./examples/wanvideo/model_training/lora/Wan2.1-VACE-1.3B.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.1-VACE-1.3B.py)| |[Wan-AI/Wan2.1-VACE-14B](https://modelscope.cn/models/Wan-AI/Wan2.1-VACE-14B)|`vace_control_video`, `vace_reference_image`|[code](./examples/wanvideo/model_inference/Wan2.1-VACE-14B.py)|[code](./examples/wanvideo/model_training/full/Wan2.1-VACE-14B.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.1-VACE-14B.py)|[code](./examples/wanvideo/model_training/lora/Wan2.1-VACE-14B.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.1-VACE-14B.py)| |[DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1](https://modelscope.cn/models/DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1)|`motion_bucket_id`|[code](./examples/wanvideo/model_inference/Wan2.1-1.3b-speedcontrol-v1.py)|[code](./examples/wanvideo/model_training/full/Wan2.1-1.3b-speedcontrol-v1.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.1-1.3b-speedcontrol-v1.py)|[code](./examples/wanvideo/model_training/lora/Wan2.1-1.3b-speedcontrol-v1.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.1-1.3b-speedcontrol-v1.py)| +|[Wan-AI/Wan2.2-I2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-I2V-A14B)|`input_image`|[code](./examples/wanvideo/model_inference/Wan2.2-I2V-A14B.py)|[code](./examples/wanvideo/model_training/full/Wan2.2-I2V-A14B.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.2-I2V-A14B.py)|[code](./examples/wanvideo/model_training/lora/Wan2.2-I2V-A14B.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.2-I2V-A14B.py)| +|[Wan-AI/Wan2.2-T2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-T2V-A14B)||[code](./examples/wanvideo/model_inference/Wan2.2-T2V-A14B.py)|[code](./examples/wanvideo/model_training/full/Wan2.2-T2V-A14B.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.2-T2V-A14B.py)|[code](./examples/wanvideo/model_training/lora/Wan2.2-T2V-A14B.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.2-T2V-A14B.py)| +|[Wan-AI/Wan2.2-TI2V-5B](https://modelscope.cn/models/Wan-AI/Wan2.2-TI2V-5B)|`input_image`|[code](./examples/wanvideo/model_inference/Wan2.2-TI2V-5B.py)|[code](./examples/wanvideo/model_training/full/Wan2.2-TI2V-5B.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.2-TI2V-5B.py)|[code](./examples/wanvideo/model_training/lora/Wan2.2-TI2V-5B.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.2-TI2V-5B.py)| @@ -333,7 +336,9 @@ https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/59fb2f7b-8de0-44 ## 更新历史 -- **2025年7月11日** 🔥🔥🔥 我们提出 Nexus-Gen,一个将大语言模型(LLM)的语言推理能力与扩散模型的图像生成能力相结合的统一框架。该框架支持无缝的图像理解、生成和编辑任务。 +- **2025年7月28日** 🔥🔥🔥 Wan 2.2 开源,我们第一时间提供了全方位支持,包括低显存逐层 offload、FP8 量化、序列并行、LoRA 训练、全量训练。详细信息请参考 [./examples/wanvideo/](./examples/wanvideo/)。 + +- **2025年7月11日** 我们提出 Nexus-Gen,一个将大语言模型(LLM)的语言推理能力与扩散模型的图像生成能力相结合的统一框架。该框架支持无缝的图像理解、生成和编辑任务。 - 论文: [Nexus-Gen: Unified Image Understanding, Generation, and Editing via Prefilled Autoregression in Shared Embedding Space](https://arxiv.org/pdf/2504.21356) - Github 仓库: https://github.com/modelscope/Nexus-Gen - 模型: [ModelScope](https://www.modelscope.cn/models/DiffSynth-Studio/Nexus-GenV2), [HuggingFace](https://huggingface.co/modelscope/Nexus-GenV2) diff --git a/diffsynth/configs/model_config.py b/diffsynth/configs/model_config.py index b60c200..6a1c32f 100644 --- a/diffsynth/configs/model_config.py +++ b/diffsynth/configs/model_config.py @@ -58,7 +58,7 @@ from ..models.stepvideo_dit import StepVideoModel from ..models.wan_video_dit import WanModel from ..models.wan_video_text_encoder import WanTextEncoder from ..models.wan_video_image_encoder import WanImageEncoder -from ..models.wan_video_vae import WanVideoVAE +from ..models.wan_video_vae import WanVideoVAE, WanVideoVAE38 from ..models.wan_video_motion_controller import WanMotionControllerModel from ..models.wan_video_vace import VaceWanModel @@ -141,6 +141,8 @@ model_loader_configs = [ (None, "26bde73488a92e64cc20b0a7485b9e5b", ["wan_video_dit"], [WanModel], "civitai"), (None, "ac6a5aa74f4a0aab6f64eb9a72f19901", ["wan_video_dit"], [WanModel], "civitai"), (None, "b61c605c2adbd23124d152ed28e049ae", ["wan_video_dit"], [WanModel], "civitai"), + (None, "1f5ab7703c6fc803fdded85ff040c316", ["wan_video_dit"], [WanModel], "civitai"), + (None, "5b013604280dd715f8457c6ed6d6a626", ["wan_video_dit"], [WanModel], "civitai"), (None, "a61453409b67cd3246cf0c3bebad47ba", ["wan_video_dit", "wan_video_vace"], [WanModel, VaceWanModel], "civitai"), (None, "7a513e1f257a861512b1afd387a8ecd9", ["wan_video_dit", "wan_video_vace"], [WanModel, VaceWanModel], "civitai"), (None, "cb104773c6c2cb6df4f9529ad5c60d0b", ["wan_video_dit"], [WanModel], "diffusers"), @@ -148,6 +150,7 @@ model_loader_configs = [ (None, "5941c53e207d62f20f9025686193c40b", ["wan_video_image_encoder"], [WanImageEncoder], "civitai"), (None, "1378ea763357eea97acdef78e65d6d96", ["wan_video_vae"], [WanVideoVAE], "civitai"), (None, "ccc42284ea13e1ad04693284c7a09be6", ["wan_video_vae"], [WanVideoVAE], "civitai"), + (None, "e1de6c02cdac79f8b739f4d3698cd216", ["wan_video_vae"], [WanVideoVAE38], "civitai"), (None, "dbd5ec76bbf977983f972c151d545389", ["wan_video_motion_controller"], [WanMotionControllerModel], "civitai"), (None, "d30fb9e02b1dbf4e509142f05cf7dd50", ["flux_dit", "step1x_connector"], [FluxDiT, Qwen2Connector], "civitai"), (None, "30143afb2dea73d1ac580e0787628f8c", ["flux_lora_patcher"], [FluxLoraPatcher], "civitai"), diff --git a/diffsynth/models/model_manager.py b/diffsynth/models/model_manager.py index 7ae3c50..d46eedf 100644 --- a/diffsynth/models/model_manager.py +++ b/diffsynth/models/model_manager.py @@ -426,7 +426,7 @@ class ModelManager: self.load_model(file_path, model_names, device=device, torch_dtype=torch_dtype) - def fetch_model(self, model_name, file_path=None, require_model_path=False): + def fetch_model(self, model_name, file_path=None, require_model_path=False, index=None): fetched_models = [] fetched_model_paths = [] for model, model_path, model_name_ in zip(self.model, self.model_path, self.model_name): @@ -440,12 +440,25 @@ class ModelManager: return None if len(fetched_models) == 1: print(f"Using {model_name} from {fetched_model_paths[0]}.") + model = fetched_models[0] + path = fetched_model_paths[0] else: - print(f"More than one {model_name} models are loaded in model manager: {fetched_model_paths}. Using {model_name} from {fetched_model_paths[0]}.") + if index is None: + model = fetched_models[0] + path = fetched_model_paths[0] + print(f"More than one {model_name} models are loaded in model manager: {fetched_model_paths}. Using {model_name} from {fetched_model_paths[0]}.") + elif isinstance(index, int): + model = fetched_models[:index] + path = fetched_model_paths[:index] + print(f"More than one {model_name} models are loaded in model manager: {fetched_model_paths}. Using {model_name} from {fetched_model_paths[:index]}.") + else: + model = fetched_models + path = fetched_model_paths + print(f"More than one {model_name} models are loaded in model manager: {fetched_model_paths}. Using {model_name} from {fetched_model_paths}.") if require_model_path: - return fetched_models[0], fetched_model_paths[0] + return model, path else: - return fetched_models[0] + return model def to(self, device): diff --git a/diffsynth/models/wan_video_dit.py b/diffsynth/models/wan_video_dit.py index 50c06bf..ea473b0 100644 --- a/diffsynth/models/wan_video_dit.py +++ b/diffsynth/models/wan_video_dit.py @@ -212,9 +212,16 @@ class DiTBlock(nn.Module): self.gate = GateModule() def forward(self, x, context, t_mod, freqs): + has_seq = len(t_mod.shape) == 4 + chunk_dim = 2 if has_seq else 1 # msa: multi-head self-attention mlp: multi-layer perceptron shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ( - self.modulation.to(dtype=t_mod.dtype, device=t_mod.device) + t_mod).chunk(6, dim=1) + self.modulation.to(dtype=t_mod.dtype, device=t_mod.device) + t_mod).chunk(6, dim=chunk_dim) + if has_seq: + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ( + shift_msa.squeeze(2), scale_msa.squeeze(2), gate_msa.squeeze(2), + shift_mlp.squeeze(2), scale_mlp.squeeze(2), gate_mlp.squeeze(2), + ) input_x = modulate(self.norm1(x), shift_msa, scale_msa) x = self.gate(x, gate_msa, self.self_attn(input_x, freqs)) x = x + self.cross_attn(self.norm3(x), context) @@ -253,8 +260,12 @@ class Head(nn.Module): self.modulation = nn.Parameter(torch.randn(1, 2, dim) / dim**0.5) def forward(self, x, t_mod): - shift, scale = (self.modulation.to(dtype=t_mod.dtype, device=t_mod.device) + t_mod).chunk(2, dim=1) - x = (self.head(self.norm(x) * (1 + scale) + shift)) + if len(t_mod.shape) == 3: + shift, scale = (self.modulation.unsqueeze(0).to(dtype=t_mod.dtype, device=t_mod.device) + t_mod.unsqueeze(2)).chunk(2, dim=2) + x = (self.head(self.norm(x) * (1 + scale.squeeze(2)) + shift.squeeze(2))) + else: + shift, scale = (self.modulation.to(dtype=t_mod.dtype, device=t_mod.device) + t_mod).chunk(2, dim=1) + x = (self.head(self.norm(x) * (1 + scale) + shift)) return x @@ -276,12 +287,20 @@ class WanModel(torch.nn.Module): has_ref_conv: bool = False, add_control_adapter: bool = False, in_dim_control_adapter: int = 24, + seperated_timestep: bool = False, + require_vae_embedding: bool = True, + require_clip_embedding: bool = True, + fuse_vae_embedding_in_latents: bool = False, ): super().__init__() self.dim = dim self.freq_dim = freq_dim self.has_image_input = has_image_input self.patch_size = patch_size + self.seperated_timestep = seperated_timestep + self.require_vae_embedding = require_vae_embedding + self.require_clip_embedding = require_clip_embedding + self.fuse_vae_embedding_in_latents = fuse_vae_embedding_in_latents self.patch_embedding = nn.Conv3d( in_dim, dim, kernel_size=patch_size, stride=patch_size) @@ -659,6 +678,41 @@ class WanModelStateDictConverter: "add_control_adapter": True, "in_dim_control_adapter": 24, } + elif hash_state_dict_keys(state_dict) == "1f5ab7703c6fc803fdded85ff040c316": + # Wan-AI/Wan2.2-TI2V-5B + config = { + "has_image_input": False, + "patch_size": [1, 2, 2], + "in_dim": 48, + "dim": 3072, + "ffn_dim": 14336, + "freq_dim": 256, + "text_dim": 4096, + "out_dim": 48, + "num_heads": 24, + "num_layers": 30, + "eps": 1e-6, + "seperated_timestep": True, + "require_clip_embedding": False, + "require_vae_embedding": False, + "fuse_vae_embedding_in_latents": True, + } + elif hash_state_dict_keys(state_dict) == "5b013604280dd715f8457c6ed6d6a626": + # Wan-AI/Wan2.2-I2V-A14B + config = { + "has_image_input": False, + "patch_size": [1, 2, 2], + "in_dim": 36, + "dim": 5120, + "ffn_dim": 13824, + "freq_dim": 256, + "text_dim": 4096, + "out_dim": 16, + "num_heads": 40, + "num_layers": 40, + "eps": 1e-6, + "require_clip_embedding": False, + } else: config = {} return state_dict, config diff --git a/diffsynth/models/wan_video_vae.py b/diffsynth/models/wan_video_vae.py index 137fd28..397a2e7 100644 --- a/diffsynth/models/wan_video_vae.py +++ b/diffsynth/models/wan_video_vae.py @@ -195,6 +195,75 @@ class Resample(nn.Module): nn.init.zeros_(conv.bias.data) + +def patchify(x, patch_size): + if patch_size == 1: + return x + if x.dim() == 4: + x = rearrange(x, "b c (h q) (w r) -> b (c r q) h w", q=patch_size, r=patch_size) + elif x.dim() == 5: + x = rearrange(x, + "b c f (h q) (w r) -> b (c r q) f h w", + q=patch_size, + r=patch_size) + else: + raise ValueError(f"Invalid input shape: {x.shape}") + return x + + +def unpatchify(x, patch_size): + if patch_size == 1: + return x + if x.dim() == 4: + x = rearrange(x, "b (c r q) h w -> b c (h q) (w r)", q=patch_size, r=patch_size) + elif x.dim() == 5: + x = rearrange(x, + "b (c r q) f h w -> b c f (h q) (w r)", + q=patch_size, + r=patch_size) + return x + + +class Resample38(Resample): + + def __init__(self, dim, mode): + assert mode in ( + "none", + "upsample2d", + "upsample3d", + "downsample2d", + "downsample3d", + ) + super(Resample, self).__init__() + self.dim = dim + self.mode = mode + + # layers + if mode == "upsample2d": + self.resample = nn.Sequential( + Upsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), + nn.Conv2d(dim, dim, 3, padding=1), + ) + elif mode == "upsample3d": + self.resample = nn.Sequential( + Upsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), + nn.Conv2d(dim, dim, 3, padding=1), + ) + self.time_conv = CausalConv3d(dim, dim * 2, (3, 1, 1), padding=(1, 0, 0)) + elif mode == "downsample2d": + self.resample = nn.Sequential( + nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2)) + ) + elif mode == "downsample3d": + self.resample = nn.Sequential( + nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2)) + ) + self.time_conv = CausalConv3d( + dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0) + ) + else: + self.resample = nn.Identity() + class ResidualBlock(nn.Module): def __init__(self, in_dim, out_dim, dropout=0.0): @@ -273,6 +342,178 @@ class AttentionBlock(nn.Module): return x + identity +class AvgDown3D(nn.Module): + def __init__( + self, + in_channels, + out_channels, + factor_t, + factor_s=1, + ): + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.factor_t = factor_t + self.factor_s = factor_s + self.factor = self.factor_t * self.factor_s * self.factor_s + + assert in_channels * self.factor % out_channels == 0 + self.group_size = in_channels * self.factor // out_channels + + def forward(self, x: torch.Tensor) -> torch.Tensor: + pad_t = (self.factor_t - x.shape[2] % self.factor_t) % self.factor_t + pad = (0, 0, 0, 0, pad_t, 0) + x = F.pad(x, pad) + B, C, T, H, W = x.shape + x = x.view( + B, + C, + T // self.factor_t, + self.factor_t, + H // self.factor_s, + self.factor_s, + W // self.factor_s, + self.factor_s, + ) + x = x.permute(0, 1, 3, 5, 7, 2, 4, 6).contiguous() + x = x.view( + B, + C * self.factor, + T // self.factor_t, + H // self.factor_s, + W // self.factor_s, + ) + x = x.view( + B, + self.out_channels, + self.group_size, + T // self.factor_t, + H // self.factor_s, + W // self.factor_s, + ) + x = x.mean(dim=2) + return x + + +class DupUp3D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + factor_t, + factor_s=1, + ): + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + + self.factor_t = factor_t + self.factor_s = factor_s + self.factor = self.factor_t * self.factor_s * self.factor_s + + assert out_channels * self.factor % in_channels == 0 + self.repeats = out_channels * self.factor // in_channels + + def forward(self, x: torch.Tensor, first_chunk=False) -> torch.Tensor: + x = x.repeat_interleave(self.repeats, dim=1) + x = x.view( + x.size(0), + self.out_channels, + self.factor_t, + self.factor_s, + self.factor_s, + x.size(2), + x.size(3), + x.size(4), + ) + x = x.permute(0, 1, 5, 2, 6, 3, 7, 4).contiguous() + x = x.view( + x.size(0), + self.out_channels, + x.size(2) * self.factor_t, + x.size(4) * self.factor_s, + x.size(6) * self.factor_s, + ) + if first_chunk: + x = x[:, :, self.factor_t - 1 :, :, :] + return x + + +class Down_ResidualBlock(nn.Module): + def __init__( + self, in_dim, out_dim, dropout, mult, temperal_downsample=False, down_flag=False + ): + super().__init__() + + # Shortcut path with downsample + self.avg_shortcut = AvgDown3D( + in_dim, + out_dim, + factor_t=2 if temperal_downsample else 1, + factor_s=2 if down_flag else 1, + ) + + # Main path with residual blocks and downsample + downsamples = [] + for _ in range(mult): + downsamples.append(ResidualBlock(in_dim, out_dim, dropout)) + in_dim = out_dim + + # Add the final downsample block + if down_flag: + mode = "downsample3d" if temperal_downsample else "downsample2d" + downsamples.append(Resample38(out_dim, mode=mode)) + + self.downsamples = nn.Sequential(*downsamples) + + def forward(self, x, feat_cache=None, feat_idx=[0]): + x_copy = x.clone() + for module in self.downsamples: + x = module(x, feat_cache, feat_idx) + + return x + self.avg_shortcut(x_copy) + + +class Up_ResidualBlock(nn.Module): + def __init__( + self, in_dim, out_dim, dropout, mult, temperal_upsample=False, up_flag=False + ): + super().__init__() + # Shortcut path with upsample + if up_flag: + self.avg_shortcut = DupUp3D( + in_dim, + out_dim, + factor_t=2 if temperal_upsample else 1, + factor_s=2 if up_flag else 1, + ) + else: + self.avg_shortcut = None + + # Main path with residual blocks and upsample + upsamples = [] + for _ in range(mult): + upsamples.append(ResidualBlock(in_dim, out_dim, dropout)) + in_dim = out_dim + + # Add the final upsample block + if up_flag: + mode = "upsample3d" if temperal_upsample else "upsample2d" + upsamples.append(Resample38(out_dim, mode=mode)) + + self.upsamples = nn.Sequential(*upsamples) + + def forward(self, x, feat_cache=None, feat_idx=[0], first_chunk=False): + x_main = x.clone() + for module in self.upsamples: + x_main = module(x_main, feat_cache, feat_idx) + if self.avg_shortcut is not None: + x_shortcut = self.avg_shortcut(x, first_chunk) + return x_main + x_shortcut + else: + return x_main + + class Encoder3d(nn.Module): def __init__(self, @@ -376,6 +617,122 @@ class Encoder3d(nn.Module): return x +class Encoder3d_38(nn.Module): + + def __init__(self, + dim=128, + z_dim=4, + dim_mult=[1, 2, 4, 4], + num_res_blocks=2, + attn_scales=[], + temperal_downsample=[False, True, True], + dropout=0.0): + super().__init__() + self.dim = dim + self.z_dim = z_dim + self.dim_mult = dim_mult + self.num_res_blocks = num_res_blocks + self.attn_scales = attn_scales + self.temperal_downsample = temperal_downsample + + # dimensions + dims = [dim * u for u in [1] + dim_mult] + scale = 1.0 + + # init block + self.conv1 = CausalConv3d(12, dims[0], 3, padding=1) + + # downsample blocks + downsamples = [] + for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])): + t_down_flag = ( + temperal_downsample[i] if i < len(temperal_downsample) else False + ) + downsamples.append( + Down_ResidualBlock( + in_dim=in_dim, + out_dim=out_dim, + dropout=dropout, + mult=num_res_blocks, + temperal_downsample=t_down_flag, + down_flag=i != len(dim_mult) - 1, + ) + ) + scale /= 2.0 + self.downsamples = nn.Sequential(*downsamples) + + # middle blocks + self.middle = nn.Sequential( + ResidualBlock(out_dim, out_dim, dropout), + AttentionBlock(out_dim), + ResidualBlock(out_dim, out_dim, dropout), + ) + + # # output blocks + self.head = nn.Sequential( + RMS_norm(out_dim, images=False), + nn.SiLU(), + CausalConv3d(out_dim, z_dim, 3, padding=1), + ) + + + def forward(self, x, feat_cache=None, feat_idx=[0]): + + if feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + cache_x = torch.cat( + [ + feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), + cache_x, + ], + dim=2, + ) + x = self.conv1(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = self.conv1(x) + + ## downsamples + for layer in self.downsamples: + if feat_cache is not None: + x = layer(x, feat_cache, feat_idx) + else: + x = layer(x) + + ## middle + for layer in self.middle: + if isinstance(layer, ResidualBlock) and feat_cache is not None: + x = layer(x, feat_cache, feat_idx) + else: + x = layer(x) + + ## head + for layer in self.head: + if isinstance(layer, CausalConv3d) and feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + cache_x = torch.cat( + [ + feat_cache[idx][:, :, -1, :, :] + .unsqueeze(2) + .to(cache_x.device), + cache_x, + ], + dim=2, + ) + x = layer(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = layer(x) + + return x + + class Decoder3d(nn.Module): def __init__(self, @@ -481,10 +838,112 @@ class Decoder3d(nn.Module): return x + +class Decoder3d_38(nn.Module): + + def __init__(self, + dim=128, + z_dim=4, + dim_mult=[1, 2, 4, 4], + num_res_blocks=2, + attn_scales=[], + temperal_upsample=[False, True, True], + dropout=0.0): + super().__init__() + self.dim = dim + self.z_dim = z_dim + self.dim_mult = dim_mult + self.num_res_blocks = num_res_blocks + self.attn_scales = attn_scales + self.temperal_upsample = temperal_upsample + + # dimensions + dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]] + scale = 1.0 / 2 ** (len(dim_mult) - 2) + # init block + self.conv1 = CausalConv3d(z_dim, dims[0], 3, padding=1) + + # middle blocks + self.middle = nn.Sequential(ResidualBlock(dims[0], dims[0], dropout), + AttentionBlock(dims[0]), + ResidualBlock(dims[0], dims[0], dropout)) + + # upsample blocks + upsamples = [] + for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])): + t_up_flag = temperal_upsample[i] if i < len(temperal_upsample) else False + upsamples.append( + Up_ResidualBlock(in_dim=in_dim, + out_dim=out_dim, + dropout=dropout, + mult=num_res_blocks + 1, + temperal_upsample=t_up_flag, + up_flag=i != len(dim_mult) - 1)) + self.upsamples = nn.Sequential(*upsamples) + + # output blocks + self.head = nn.Sequential(RMS_norm(out_dim, images=False), nn.SiLU(), + CausalConv3d(out_dim, 12, 3, padding=1)) + + + def forward(self, x, feat_cache=None, feat_idx=[0], first_chunk=False): + if feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + cache_x = torch.cat( + [ + feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), + cache_x, + ], + dim=2, + ) + x = self.conv1(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = self.conv1(x) + + for layer in self.middle: + if check_is_instance(layer, ResidualBlock) and feat_cache is not None: + x = layer(x, feat_cache, feat_idx) + else: + x = layer(x) + + ## upsamples + for layer in self.upsamples: + if feat_cache is not None: + x = layer(x, feat_cache, feat_idx, first_chunk) + else: + x = layer(x) + + ## head + for layer in self.head: + if check_is_instance(layer, CausalConv3d) and feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + cache_x = torch.cat( + [ + feat_cache[idx][:, :, -1, :, :] + .unsqueeze(2) + .to(cache_x.device), + cache_x, + ], + dim=2, + ) + x = layer(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = layer(x) + return x + + def count_conv3d(model): count = 0 for m in model.modules(): - if check_is_instance(m, CausalConv3d): + if isinstance(m, CausalConv3d): count += 1 return count @@ -616,6 +1075,7 @@ class WanVideoVAE(nn.Module): # init model self.model = VideoVAE_(z_dim=z_dim).eval().requires_grad_(False) self.upsampling_factor = 8 + self.z_dim = z_dim def build_1d_mask(self, length, left_bound, right_bound, border_width): @@ -711,7 +1171,7 @@ class WanVideoVAE(nn.Module): out_T = (T + 3) // 4 weight = torch.zeros((1, 1, out_T, H // self.upsampling_factor, W // self.upsampling_factor), dtype=video.dtype, device=data_device) - values = torch.zeros((1, 16, out_T, H // self.upsampling_factor, W // self.upsampling_factor), dtype=video.dtype, device=data_device) + values = torch.zeros((1, self.z_dim, out_T, H // self.upsampling_factor, W // self.upsampling_factor), dtype=video.dtype, device=data_device) for h, h_, w, w_ in tqdm(tasks, desc="VAE encoding"): hidden_states_batch = video[:, :, :, h:h_, w:w_].to(computation_device) @@ -762,8 +1222,8 @@ class WanVideoVAE(nn.Module): for video in videos: video = video.unsqueeze(0) if tiled: - tile_size = (tile_size[0] * 8, tile_size[1] * 8) - tile_stride = (tile_stride[0] * 8, tile_stride[1] * 8) + tile_size = (tile_size[0] * self.upsampling_factor, tile_size[1] * self.upsampling_factor) + tile_stride = (tile_stride[0] * self.upsampling_factor, tile_stride[1] * self.upsampling_factor) hidden_state = self.tiled_encode(video, device, tile_size, tile_stride) else: hidden_state = self.single_encode(video, device) @@ -798,3 +1258,119 @@ class WanVideoVAEStateDictConverter: for name in state_dict: state_dict_['model.' + name] = state_dict[name] return state_dict_ + + +class VideoVAE38_(VideoVAE_): + + def __init__(self, + dim=160, + z_dim=48, + dec_dim=256, + dim_mult=[1, 2, 4, 4], + num_res_blocks=2, + attn_scales=[], + temperal_downsample=[False, True, True], + dropout=0.0): + super(VideoVAE_, self).__init__() + self.dim = dim + self.z_dim = z_dim + self.dim_mult = dim_mult + self.num_res_blocks = num_res_blocks + self.attn_scales = attn_scales + self.temperal_downsample = temperal_downsample + self.temperal_upsample = temperal_downsample[::-1] + + # modules + self.encoder = Encoder3d_38(dim, z_dim * 2, dim_mult, num_res_blocks, + attn_scales, self.temperal_downsample, dropout) + self.conv1 = CausalConv3d(z_dim * 2, z_dim * 2, 1) + self.conv2 = CausalConv3d(z_dim, z_dim, 1) + self.decoder = Decoder3d_38(dec_dim, z_dim, dim_mult, num_res_blocks, + attn_scales, self.temperal_upsample, dropout) + + + def encode(self, x, scale): + self.clear_cache() + x = patchify(x, patch_size=2) + t = x.shape[2] + iter_ = 1 + (t - 1) // 4 + for i in range(iter_): + self._enc_conv_idx = [0] + if i == 0: + out = self.encoder(x[:, :, :1, :, :], + feat_cache=self._enc_feat_map, + feat_idx=self._enc_conv_idx) + else: + out_ = self.encoder(x[:, :, 1 + 4 * (i - 1):1 + 4 * i, :, :], + feat_cache=self._enc_feat_map, + feat_idx=self._enc_conv_idx) + out = torch.cat([out, out_], 2) + mu, log_var = self.conv1(out).chunk(2, dim=1) + if isinstance(scale[0], torch.Tensor): + scale = [s.to(dtype=mu.dtype, device=mu.device) for s in scale] + mu = (mu - scale[0].view(1, self.z_dim, 1, 1, 1)) * scale[1].view( + 1, self.z_dim, 1, 1, 1) + else: + scale = scale.to(dtype=mu.dtype, device=mu.device) + mu = (mu - scale[0]) * scale[1] + self.clear_cache() + return mu + + + def decode(self, z, scale): + self.clear_cache() + if isinstance(scale[0], torch.Tensor): + scale = [s.to(dtype=z.dtype, device=z.device) for s in scale] + z = z / scale[1].view(1, self.z_dim, 1, 1, 1) + scale[0].view( + 1, self.z_dim, 1, 1, 1) + else: + scale = scale.to(dtype=z.dtype, device=z.device) + z = z / scale[1] + scale[0] + iter_ = z.shape[2] + x = self.conv2(z) + for i in range(iter_): + self._conv_idx = [0] + if i == 0: + out = self.decoder(x[:, :, i:i + 1, :, :], + feat_cache=self._feat_map, + feat_idx=self._conv_idx, + first_chunk=True) + else: + out_ = self.decoder(x[:, :, i:i + 1, :, :], + feat_cache=self._feat_map, + feat_idx=self._conv_idx) + out = torch.cat([out, out_], 2) + out = unpatchify(out, patch_size=2) + self.clear_cache() + return out + + +class WanVideoVAE38(WanVideoVAE): + + def __init__(self, z_dim=48, dim=160): + super(WanVideoVAE, self).__init__() + + mean = [ + -0.2289, -0.0052, -0.1323, -0.2339, -0.2799, 0.0174, 0.1838, 0.1557, + -0.1382, 0.0542, 0.2813, 0.0891, 0.1570, -0.0098, 0.0375, -0.1825, + -0.2246, -0.1207, -0.0698, 0.5109, 0.2665, -0.2108, -0.2158, 0.2502, + -0.2055, -0.0322, 0.1109, 0.1567, -0.0729, 0.0899, -0.2799, -0.1230, + -0.0313, -0.1649, 0.0117, 0.0723, -0.2839, -0.2083, -0.0520, 0.3748, + 0.0152, 0.1957, 0.1433, -0.2944, 0.3573, -0.0548, -0.1681, -0.0667 + ] + std = [ + 0.4765, 1.0364, 0.4514, 1.1677, 0.5313, 0.4990, 0.4818, 0.5013, + 0.8158, 1.0344, 0.5894, 1.0901, 0.6885, 0.6165, 0.8454, 0.4978, + 0.5759, 0.3523, 0.7135, 0.6804, 0.5833, 1.4146, 0.8986, 0.5659, + 0.7069, 0.5338, 0.4889, 0.4917, 0.4069, 0.4999, 0.6866, 0.4093, + 0.5709, 0.6065, 0.6415, 0.4944, 0.5726, 1.2042, 0.5458, 1.6887, + 0.3971, 1.0600, 0.3943, 0.5537, 0.5444, 0.4089, 0.7468, 0.7744 + ] + self.mean = torch.tensor(mean) + self.std = torch.tensor(std) + self.scale = [self.mean, 1.0 / self.std] + + # init model + self.model = VideoVAE38_(z_dim=z_dim, dim=dim).eval().requires_grad_(False) + self.upsampling_factor = 16 + self.z_dim = z_dim diff --git a/diffsynth/pipelines/wan_video_new.py b/diffsynth/pipelines/wan_video_new.py index 21b8ba0..a425da2 100644 --- a/diffsynth/pipelines/wan_video_new.py +++ b/diffsynth/pipelines/wan_video_new.py @@ -39,17 +39,21 @@ class WanVideoPipeline(BasePipeline): self.text_encoder: WanTextEncoder = None self.image_encoder: WanImageEncoder = None self.dit: WanModel = None + self.dit2: WanModel = None self.vae: WanVideoVAE = None self.motion_controller: WanMotionControllerModel = None self.vace: VaceWanModel = None self.in_iteration_models = ("dit", "motion_controller", "vace") + self.in_iteration_models_2 = ("dit2", "motion_controller", "vace") self.unit_runner = PipelineUnitRunner() self.units = [ WanVideoUnit_ShapeChecker(), WanVideoUnit_NoiseInitializer(), WanVideoUnit_InputVideoEmbedder(), WanVideoUnit_PromptEmbedder(), - WanVideoUnit_ImageEmbedder(), + WanVideoUnit_ImageEmbedderVAE(), + WanVideoUnit_ImageEmbedderCLIP(), + WanVideoUnit_ImageEmbedderFused(), WanVideoUnit_FunControl(), WanVideoUnit_FunReference(), WanVideoUnit_FunCameraControl(), @@ -69,7 +73,9 @@ class WanVideoPipeline(BasePipeline): def training_loss(self, **inputs): - timestep_id = torch.randint(0, self.scheduler.num_train_timesteps, (1,)) + max_timestep_boundary = int(inputs.get("max_timestep_boundary", 1) * self.scheduler.num_train_timesteps) + min_timestep_boundary = int(inputs.get("min_timestep_boundary", 0) * self.scheduler.num_train_timesteps) + timestep_id = torch.randint(min_timestep_boundary, max_timestep_boundary, (1,)) timestep = self.scheduler.timesteps[timestep_id].to(dtype=self.torch_dtype, device=self.device) inputs["latents"] = self.scheduler.add_noise(inputs["input_latents"], inputs["noise"], timestep) @@ -141,6 +147,37 @@ class WanVideoPipeline(BasePipeline): ), vram_limit=vram_limit, ) + if self.dit2 is not None: + dtype = next(iter(self.dit2.parameters())).dtype + device = "cpu" if vram_limit is not None else self.device + enable_vram_management( + self.dit2, + module_map = { + torch.nn.Linear: AutoWrappedLinear, + torch.nn.Conv3d: AutoWrappedModule, + torch.nn.LayerNorm: WanAutoCastLayerNorm, + RMSNorm: AutoWrappedModule, + torch.nn.Conv2d: AutoWrappedModule, + }, + module_config = dict( + offload_dtype=dtype, + offload_device="cpu", + onload_dtype=dtype, + onload_device=device, + computation_dtype=self.torch_dtype, + computation_device=self.device, + ), + max_num_param=num_persistent_param_in_dit, + overflow_module_config = dict( + offload_dtype=dtype, + offload_device="cpu", + onload_dtype=dtype, + onload_device="cpu", + computation_dtype=self.torch_dtype, + computation_device=self.device, + ), + vram_limit=vram_limit, + ) if self.vae is not None: dtype = next(iter(self.vae.parameters())).dtype enable_vram_management( @@ -239,6 +276,10 @@ class WanVideoPipeline(BasePipeline): for block in self.dit.blocks: block.self_attn.forward = types.MethodType(usp_attn_forward, block.self_attn) self.dit.forward = types.MethodType(usp_dit_forward, self.dit) + if self.dit2 is not None: + for block in self.dit2.blocks: + block.self_attn.forward = types.MethodType(usp_attn_forward, block.self_attn) + self.dit2.forward = types.MethodType(usp_dit_forward, self.dit2) self.sp_size = get_sequence_parallel_world_size() self.use_unified_sequence_parallel = True @@ -282,11 +323,20 @@ class WanVideoPipeline(BasePipeline): # Load models pipe.text_encoder = model_manager.fetch_model("wan_video_text_encoder") - pipe.dit = model_manager.fetch_model("wan_video_dit") + dit = model_manager.fetch_model("wan_video_dit", index=2) + if isinstance(dit, list): + pipe.dit, pipe.dit2 = dit + else: + pipe.dit = dit pipe.vae = model_manager.fetch_model("wan_video_vae") pipe.image_encoder = model_manager.fetch_model("wan_video_image_encoder") pipe.motion_controller = model_manager.fetch_model("wan_video_motion_controller") pipe.vace = model_manager.fetch_model("wan_video_vace") + + # Size division factor + if pipe.vae is not None: + pipe.height_division_factor = pipe.vae.upsampling_factor * 2 + pipe.width_division_factor = pipe.vae.upsampling_factor * 2 # Initialize tokenizer tokenizer_config.download_if_necessary(use_usp=use_usp) @@ -333,6 +383,8 @@ class WanVideoPipeline(BasePipeline): # Classifier-free guidance cfg_scale: Optional[float] = 5.0, cfg_merge: Optional[bool] = False, + # Boundary + switch_DiT_boundary: Optional[float] = 0.875, # Scheduler num_inference_steps: Optional[int] = 50, sigma_shift: Optional[float] = 5.0, @@ -385,8 +437,14 @@ class WanVideoPipeline(BasePipeline): self.load_models_to_device(self.in_iteration_models) models = {name: getattr(self, name) for name in self.in_iteration_models} for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)): + # Switch DiT if necessary + if timestep.item() < switch_DiT_boundary * self.scheduler.num_train_timesteps and self.dit2 is not None and not models["dit"] is self.dit2: + self.load_models_to_device(self.in_iteration_models_2) + models["dit"] = self.dit2 + + # Timestep timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device) - + # Inference noise_pred_posi = self.model_fn(**models, **inputs_shared, **inputs_posi, timestep=timestep) if cfg_scale != 1.0: @@ -400,6 +458,8 @@ class WanVideoPipeline(BasePipeline): # Scheduler inputs_shared["latents"] = self.scheduler.step(noise_pred, self.scheduler.timesteps[progress_id], inputs_shared["latents"]) + if "first_frame_latents" in inputs_shared: + inputs_shared["latents"][:, :, 0:1] = inputs_shared["first_frame_latents"] # VACE (TODO: remove it) if vace_reference_image is not None: @@ -433,7 +493,8 @@ class WanVideoUnit_NoiseInitializer(PipelineUnit): length = (num_frames - 1) // 4 + 1 if vace_reference_image is not None: length += 1 - noise = pipe.generate_noise((1, 16, length, height//8, width//8), seed=seed, rand_device=rand_device) + shape = (1, pipe.vae.model.z_dim, length, height // pipe.vae.upsampling_factor, width // pipe.vae.upsampling_factor) + noise = pipe.generate_noise(shape, seed=seed, rand_device=rand_device) if vace_reference_image is not None: noise = torch.concat((noise[:, :, -1:], noise[:, :, :-1]), dim=2) return {"noise": noise} @@ -482,6 +543,9 @@ class WanVideoUnit_PromptEmbedder(PipelineUnit): class WanVideoUnit_ImageEmbedder(PipelineUnit): + """ + Deprecated + """ def __init__(self): super().__init__( input_params=("input_image", "end_image", "num_frames", "height", "width", "tiled", "tile_size", "tile_stride"), @@ -489,7 +553,7 @@ class WanVideoUnit_ImageEmbedder(PipelineUnit): ) def process(self, pipe: WanVideoPipeline, input_image, end_image, num_frames, height, width, tiled, tile_size, tile_stride): - if input_image is None: + if input_image is None or pipe.image_encoder is None: return {} pipe.load_models_to_device(self.onload_model_names) image = pipe.preprocess_image(input_image.resize((width, height))).to(pipe.device) @@ -517,13 +581,90 @@ class WanVideoUnit_ImageEmbedder(PipelineUnit): y = y.to(dtype=pipe.torch_dtype, device=pipe.device) return {"clip_feature": clip_context, "y": y} + + +class WanVideoUnit_ImageEmbedderCLIP(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("input_image", "end_image", "height", "width"), + onload_model_names=("image_encoder",) + ) + + def process(self, pipe: WanVideoPipeline, input_image, end_image, height, width): + if input_image is None or pipe.image_encoder is None or not pipe.dit.require_clip_embedding: + return {} + pipe.load_models_to_device(self.onload_model_names) + image = pipe.preprocess_image(input_image.resize((width, height))).to(pipe.device) + clip_context = pipe.image_encoder.encode_image([image]) + if end_image is not None: + end_image = pipe.preprocess_image(end_image.resize((width, height))).to(pipe.device) + if pipe.dit.has_image_pos_emb: + clip_context = torch.concat([clip_context, pipe.image_encoder.encode_image([end_image])], dim=1) + clip_context = clip_context.to(dtype=pipe.torch_dtype, device=pipe.device) + return {"clip_feature": clip_context} + + + +class WanVideoUnit_ImageEmbedderVAE(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("input_image", "end_image", "num_frames", "height", "width", "tiled", "tile_size", "tile_stride"), + onload_model_names=("vae",) + ) + + def process(self, pipe: WanVideoPipeline, input_image, end_image, num_frames, height, width, tiled, tile_size, tile_stride): + if input_image is None or not pipe.dit.require_vae_embedding: + return {} + pipe.load_models_to_device(self.onload_model_names) + image = pipe.preprocess_image(input_image.resize((width, height))).to(pipe.device) + msk = torch.ones(1, num_frames, height//8, width//8, device=pipe.device) + msk[:, 1:] = 0 + if end_image is not None: + end_image = pipe.preprocess_image(end_image.resize((width, height))).to(pipe.device) + vae_input = torch.concat([image.transpose(0,1), torch.zeros(3, num_frames-2, height, width).to(image.device), end_image.transpose(0,1)],dim=1) + msk[:, -1:] = 1 + else: + vae_input = torch.concat([image.transpose(0, 1), torch.zeros(3, num_frames-1, height, width).to(image.device)], dim=1) + + msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1) + msk = msk.view(1, msk.shape[1] // 4, 4, height//8, width//8) + msk = msk.transpose(1, 2)[0] - + y = pipe.vae.encode([vae_input.to(dtype=pipe.torch_dtype, device=pipe.device)], device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)[0] + y = y.to(dtype=pipe.torch_dtype, device=pipe.device) + y = torch.concat([msk, y]) + y = y.unsqueeze(0) + y = y.to(dtype=pipe.torch_dtype, device=pipe.device) + return {"y": y} + + + +class WanVideoUnit_ImageEmbedderFused(PipelineUnit): + """ + Encode input image to latents using VAE. This unit is for Wan-AI/Wan2.2-TI2V-5B. + """ + def __init__(self): + super().__init__( + input_params=("input_image", "latents", "height", "width", "tiled", "tile_size", "tile_stride"), + onload_model_names=("vae",) + ) + + def process(self, pipe: WanVideoPipeline, input_image, latents, height, width, tiled, tile_size, tile_stride): + if input_image is None or not pipe.dit.fuse_vae_embedding_in_latents: + return {} + pipe.load_models_to_device(self.onload_model_names) + image = pipe.preprocess_image(input_image.resize((width, height))).transpose(0, 1) + z = pipe.vae.encode([image], device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride) + latents[:, :, 0: 1] = z + return {"latents": latents, "fuse_vae_embedding_in_latents": True, "first_frame_latents": z} + + + class WanVideoUnit_FunControl(PipelineUnit): def __init__(self): super().__init__( input_params=("control_video", "num_frames", "height", "width", "tiled", "tile_size", "tile_stride", "clip_feature", "y"), - onload_model_names=("vae") + onload_model_names=("vae",) ) def process(self, pipe: WanVideoPipeline, control_video, num_frames, height, width, tiled, tile_size, tile_stride, clip_feature, y): @@ -547,7 +688,7 @@ class WanVideoUnit_FunReference(PipelineUnit): def __init__(self): super().__init__( input_params=("reference_image", "height", "width", "reference_image"), - onload_model_names=("vae") + onload_model_names=("vae",) ) def process(self, pipe: WanVideoPipeline, reference_image, height, width): @@ -566,7 +707,8 @@ class WanVideoUnit_FunReference(PipelineUnit): class WanVideoUnit_FunCameraControl(PipelineUnit): def __init__(self): super().__init__( - input_params=("height", "width", "num_frames", "camera_control_direction", "camera_control_speed", "camera_control_origin", "latents", "input_image") + input_params=("height", "width", "num_frames", "camera_control_direction", "camera_control_speed", "camera_control_origin", "latents", "input_image"), + onload_model_names=("vae",) ) def process(self, pipe: WanVideoPipeline, height, width, num_frames, camera_control_direction, camera_control_speed, camera_control_origin, latents, input_image): @@ -589,6 +731,7 @@ class WanVideoUnit_FunCameraControl(PipelineUnit): input_image = input_image.resize((width, height)) input_latents = pipe.preprocess_video([input_image]) + pipe.load_models_to_device(self.onload_model_names) input_latents = pipe.vae.encode(input_latents, device=pipe.device) y = torch.zeros_like(latents).to(pipe.device) y[:, :, :1] = input_latents @@ -832,6 +975,7 @@ def model_fn_wan_video( use_gradient_checkpointing: bool = False, use_gradient_checkpointing_offload: bool = False, control_camera_latents_input = None, + fuse_vae_embedding_in_latents: bool = False, **kwargs, ): if sliding_window_size is not None and sliding_window_stride is not None: @@ -865,9 +1009,20 @@ def model_fn_wan_video( from xfuser.core.distributed import (get_sequence_parallel_rank, get_sequence_parallel_world_size, get_sp_group) + + # Timestep + if dit.seperated_timestep and fuse_vae_embedding_in_latents: + timestep = torch.concat([ + torch.zeros((1, latents.shape[3] * latents.shape[4] // 4), dtype=latents.dtype, device=latents.device), + torch.ones((latents.shape[2] - 1, latents.shape[3] * latents.shape[4] // 4), dtype=latents.dtype, device=latents.device) * timestep + ]).flatten() + t = dit.time_embedding(sinusoidal_embedding_1d(dit.freq_dim, timestep).unsqueeze(0)) + t_mod = dit.time_projection(t).unflatten(2, (6, dit.dim)) + else: + t = dit.time_embedding(sinusoidal_embedding_1d(dit.freq_dim, timestep)) + t_mod = dit.time_projection(t).unflatten(1, (6, dit.dim)) - t = dit.time_embedding(sinusoidal_embedding_1d(dit.freq_dim, timestep)) - t_mod = dit.time_projection(t).unflatten(1, (6, dit.dim)) + # Motion Controller if motion_bucket_id is not None and motion_controller is not None: t_mod = t_mod + motion_controller(motion_bucket_id).unflatten(1, (6, dit.dim)) context = dit.text_embedding(context) @@ -878,15 +1033,16 @@ def model_fn_wan_video( x = torch.concat([x] * context.shape[0], dim=0) if timestep.shape[0] != context.shape[0]: timestep = torch.concat([timestep] * context.shape[0], dim=0) - - if dit.has_image_input: - x = torch.cat([x, y], dim=1) # (b, c_x + c_y, f, h, w) + + # Image Embedding + if y is not None and dit.require_vae_embedding: + x = torch.cat([x, y], dim=1) + if clip_feature is not None and dit.require_clip_embedding: clip_embdding = dit.img_emb(clip_feature) context = torch.cat([clip_embdding, context], dim=1) # Add camera control x, (f, h, w) = dit.patchify(x, control_camera_latents_input) - # Reference image if reference_latents is not None: diff --git a/diffsynth/trainers/utils.py b/diffsynth/trainers/utils.py index b171857..b27fc34 100644 --- a/diffsynth/trainers/utils.py +++ b/diffsynth/trainers/utils.py @@ -434,6 +434,8 @@ def wan_parser(): parser.add_argument("--extra_inputs", default=None, help="Additional model inputs, comma-separated.") parser.add_argument("--use_gradient_checkpointing_offload", default=False, action="store_true", help="Whether to offload gradient checkpointing to CPU memory.") parser.add_argument("--gradient_accumulation_steps", type=int, default=1, help="Gradient accumulation steps.") + parser.add_argument("--max_timestep_boundary", type=float, default=1.0, help="Max timestep boundary (for mixed models, e.g., Wan-AI/Wan2.2-I2V-A14B).") + parser.add_argument("--min_timestep_boundary", type=float, default=0.0, help="Min timestep boundary (for mixed models, e.g., Wan-AI/Wan2.2-I2V-A14B).") return parser diff --git a/examples/wanvideo/README.md b/examples/wanvideo/README.md index 2c811f5..8231e29 100644 --- a/examples/wanvideo/README.md +++ b/examples/wanvideo/README.md @@ -1,8 +1,8 @@ -# Wan 2.1 +# Wan [切换到中文](./README_zh.md) -Wan 2.1 is a collection of video synthesis models open-sourced by Alibaba. +Wan is a collection of video synthesis models open-sourced by Alibaba. **DiffSynth-Studio has adopted a new inference and training framework. To use the previous version, please click [here](https://github.com/modelscope/DiffSynth-Studio/tree/3edf3583b1f08944cee837b94d9f84d669c2729c).** @@ -67,6 +67,9 @@ save_video(video, "video1.mp4", fps=15, quality=5) |[Wan-AI/Wan2.1-VACE-1.3B](https://modelscope.cn/models/Wan-AI/Wan2.1-VACE-1.3B)|`vace_control_video`, `vace_reference_image`|[code](./model_inference/Wan2.1-VACE-1.3B.py)|[code](./model_training/full/Wan2.1-VACE-1.3B.sh)|[code](./model_training/validate_full/Wan2.1-VACE-1.3B.py)|[code](./model_training/lora/Wan2.1-VACE-1.3B.sh)|[code](./model_training/validate_lora/Wan2.1-VACE-1.3B.py)| |[Wan-AI/Wan2.1-VACE-14B](https://modelscope.cn/models/Wan-AI/Wan2.1-VACE-14B)|`vace_control_video`, `vace_reference_image`|[code](./model_inference/Wan2.1-VACE-14B.py)|[code](./model_training/full/Wan2.1-VACE-14B.sh)|[code](./model_training/validate_full/Wan2.1-VACE-14B.py)|[code](./model_training/lora/Wan2.1-VACE-14B.sh)|[code](./model_training/validate_lora/Wan2.1-VACE-14B.py)| |[DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1](https://modelscope.cn/models/DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1)|`motion_bucket_id`|[code](./model_inference/Wan2.1-1.3b-speedcontrol-v1.py)|[code](./model_training/full/Wan2.1-1.3b-speedcontrol-v1.sh)|[code](./model_training/validate_full/Wan2.1-1.3b-speedcontrol-v1.py)|[code](./model_training/lora/Wan2.1-1.3b-speedcontrol-v1.sh)|[code](./model_training/validate_lora/Wan2.1-1.3b-speedcontrol-v1.py)| +|[Wan-AI/Wan2.2-I2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-I2V-A14B)|`input_image`|[code](./model_inference/Wan2.2-I2V-A14B.py)|[code](./model_training/full/Wan2.2-I2V-A14B.sh)|[code](./model_training/validate_full/Wan2.2-I2V-A14B.py)|[code](./model_training/lora/Wan2.2-I2V-A14B.sh)|[code](./model_training/validate_lora/Wan2.2-I2V-A14B.py)| +|[Wan-AI/Wan2.2-T2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-T2V-A14B)||[code](./model_inference/Wan2.2-T2V-A14B.py)|[code](./model_training/full/Wan2.2-T2V-A14B.sh)|[code](./model_training/validate_full/Wan2.2-T2V-A14B.py)|[code](./model_training/lora/Wan2.2-T2V-A14B.sh)|[code](./model_training/validate_lora/Wan2.2-T2V-A14B.py)| +|[Wan-AI/Wan2.2-TI2V-5B](https://modelscope.cn/models/Wan-AI/Wan2.2-TI2V-5B)|`input_image`|[code](./model_inference/Wan2.2-TI2V-5B.py)|[code](./model_training/full/Wan2.2-TI2V-5B.sh)|[code](./model_training/validate_full/Wan2.2-TI2V-5B.py)|[code](./model_training/lora/Wan2.2-TI2V-5B.sh)|[code](./model_training/validate_lora/Wan2.2-TI2V-5B.py)| ## Model Inference @@ -244,6 +247,7 @@ The pipeline accepts the following input parameters during inference: * `num_frames`: Number of frames, default is 81. Must be a multiple of 4 plus 1; if not, it will be rounded up, minimum is 1. * `cfg_scale`: Classifier-free guidance scale, default is 5. Higher values increase adherence to the prompt but may cause visual artifacts. * `cfg_merge`: Whether to merge both sides of classifier-free guidance for unified inference. Default is `False`. This parameter currently only works for basic text-to-video and image-to-video models. +* `switch_DiT_boundary`: The time point for switching between DiT models. Default value is 0.875. This parameter only takes effect for mixed models with multiple DiTs, for example, [Wan-AI/Wan2.2-I2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-I2V-A14B). * `num_inference_steps`: Number of inference steps, default is 50. * `sigma_shift`: Parameter from Rectified Flow theory, default is 5. Higher values make the model stay longer at the initial denoising stage. Increasing this may improve video quality but may also cause inconsistency between generated videos and training data due to deviation from training behavior. * `motion_bucket_id`: Motion intensity, range [0, 100], applicable to motion control modules such as [`DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1`](https://modelscope.cn/models/DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1). Larger values indicate more intense motion. @@ -279,6 +283,8 @@ The script includes the following parameters: * Models * `--model_paths`: Paths to load models. In JSON format. * `--model_id_with_origin_paths`: Model ID with origin paths, e.g., Wan-AI/Wan2.1-T2V-1.3B:diffusion_pytorch_model*.safetensors. Comma-separated. + * `--max_timestep_boundary`: Maximum value of the timestep interval, ranging from 0 to 1. Default is 1. This needs to be manually set only when training mixed models with multiple DiTs, for example, [Wan-AI/Wan2.2-I2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-I2V-A14B). + * `--min_timestep_boundary`: Minimum value of the timestep interval, ranging from 0 to 1. Default is 1. This needs to be manually set only when training mixed models with multiple DiTs, for example, [Wan-AI/Wan2.2-I2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-I2V-A14B). * Training * `--learning_rate`: Learning rate. * `--num_epochs`: Number of epochs. diff --git a/examples/wanvideo/README_zh.md b/examples/wanvideo/README_zh.md index 517cd9e..9f26858 100644 --- a/examples/wanvideo/README_zh.md +++ b/examples/wanvideo/README_zh.md @@ -1,8 +1,8 @@ -# 通义万相 2.1(Wan 2.1) +# 通义万相(Wan) [Switch to English](./README.md) -Wan 2.1 是由阿里巴巴通义实验室开源的一系列视频生成模型。 +Wan 是由阿里巴巴通义实验室开源的一系列视频生成模型。 **DiffSynth-Studio 启用了新的推理和训练框架,如需使用旧版本,请点击[这里](https://github.com/modelscope/DiffSynth-Studio/tree/3edf3583b1f08944cee837b94d9f84d669c2729c)。** @@ -67,6 +67,9 @@ save_video(video, "video1.mp4", fps=15, quality=5) |[Wan-AI/Wan2.1-VACE-1.3B](https://modelscope.cn/models/Wan-AI/Wan2.1-VACE-1.3B)|`vace_control_video`, `vace_reference_image`|[code](./model_inference/Wan2.1-VACE-1.3B.py)|[code](./model_training/full/Wan2.1-VACE-1.3B.sh)|[code](./model_training/validate_full/Wan2.1-VACE-1.3B.py)|[code](./model_training/lora/Wan2.1-VACE-1.3B.sh)|[code](./model_training/validate_lora/Wan2.1-VACE-1.3B.py)| |[Wan-AI/Wan2.1-VACE-14B](https://modelscope.cn/models/Wan-AI/Wan2.1-VACE-14B)|`vace_control_video`, `vace_reference_image`|[code](./model_inference/Wan2.1-VACE-14B.py)|[code](./model_training/full/Wan2.1-VACE-14B.sh)|[code](./model_training/validate_full/Wan2.1-VACE-14B.py)|[code](./model_training/lora/Wan2.1-VACE-14B.sh)|[code](./model_training/validate_lora/Wan2.1-VACE-14B.py)| |[DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1](https://modelscope.cn/models/DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1)|`motion_bucket_id`|[code](./model_inference/Wan2.1-1.3b-speedcontrol-v1.py)|[code](./model_training/full/Wan2.1-1.3b-speedcontrol-v1.sh)|[code](./model_training/validate_full/Wan2.1-1.3b-speedcontrol-v1.py)|[code](./model_training/lora/Wan2.1-1.3b-speedcontrol-v1.sh)|[code](./model_training/validate_lora/Wan2.1-1.3b-speedcontrol-v1.py)| +|[Wan-AI/Wan2.2-I2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-I2V-A14B)|`input_image`|[code](./model_inference/Wan2.2-I2V-A14B.py)|[code](./model_training/full/Wan2.2-I2V-A14B.sh)|[code](./model_training/validate_full/Wan2.2-I2V-A14B.py)|[code](./model_training/lora/Wan2.2-I2V-A14B.sh)|[code](./model_training/validate_lora/Wan2.2-I2V-A14B.py)| +|[Wan-AI/Wan2.2-T2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-T2V-A14B)||[code](./model_inference/Wan2.2-T2V-A14B.py)|[code](./model_training/full/Wan2.2-T2V-A14B.sh)|[code](./model_training/validate_full/Wan2.2-T2V-A14B.py)|[code](./model_training/lora/Wan2.2-T2V-A14B.sh)|[code](./model_training/validate_lora/Wan2.2-T2V-A14B.py)| +|[Wan-AI/Wan2.2-TI2V-5B](https://modelscope.cn/models/Wan-AI/Wan2.2-TI2V-5B)|`input_image`|[code](./model_inference/Wan2.2-TI2V-5B.py)|[code](./model_training/full/Wan2.2-TI2V-5B.sh)|[code](./model_training/validate_full/Wan2.2-TI2V-5B.py)|[code](./model_training/lora/Wan2.2-TI2V-5B.sh)|[code](./model_training/validate_lora/Wan2.2-TI2V-5B.py)| ## 模型推理 @@ -245,6 +248,7 @@ Pipeline 在推理阶段能够接收以下输入参数: * `num_frames`: 帧数,默认为 81。需设置为 4 的倍数 + 1,不满足时向上取整,最小值为 1。 * `cfg_scale`: Classifier-free guidance 机制的数值,默认为 5。数值越大,提示词的控制效果越强,但画面崩坏的概率越大。 * `cfg_merge`: 是否合并 Classifier-free guidance 的两侧进行统一推理,默认为 `False`。该参数目前仅在基础的文生视频和图生视频模型上生效。 +* `switch_DiT_boundary`: 切换 DiT 模型的时间点,默认值为 0.875,仅对多 DiT 的混合模型生效,例如 [Wan-AI/Wan2.2-I2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-I2V-A14B)。 * `num_inference_steps`: 推理次数,默认值为 50。 * `sigma_shift`: Rectified Flow 理论中的参数,默认为 5。数值越大,模型在去噪的开始阶段停留的步骤数越多,可适当调大这个参数来提高画面质量,但会因生成过程与训练过程不一致导致生成的视频内容与训练数据存在差异。 * `motion_bucket_id`: 运动幅度,范围为 [0, 100]。适用于速度控制模块,例如 [`DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1`](https://modelscope.cn/models/DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1),数值越大,运动幅度越大。 @@ -281,6 +285,8 @@ Wan 系列模型训练通过统一的 [`./model_training/train.py`](./model_trai * 模型 * `--model_paths`: 要加载的模型路径。JSON 格式。 * `--model_id_with_origin_paths`: 带原始路径的模型 ID,例如 Wan-AI/Wan2.1-T2V-1.3B:diffusion_pytorch_model*.safetensors。用逗号分隔。 + * `--max_timestep_boundary`: Timestep 区间最大值,范围为 0~1,默认为 1,仅在多 DiT 的混合模型训练中需要手动设置,例如 [Wan-AI/Wan2.2-I2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-I2V-A14B)。 + * `--min_timestep_boundary`: Timestep 区间最小值,范围为 0~1,默认为 1,仅在多 DiT 的混合模型训练中需要手动设置,例如 [Wan-AI/Wan2.2-I2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-I2V-A14B)。 * 训练 * `--learning_rate`: 学习率。 * `--num_epochs`: 轮数(Epoch)。 diff --git a/examples/wanvideo/model_inference/Wan2.2-I2V-A14B.py b/examples/wanvideo/model_inference/Wan2.2-I2V-A14B.py new file mode 100644 index 0000000..0c1be54 --- /dev/null +++ b/examples/wanvideo/model_inference/Wan2.2-I2V-A14B.py @@ -0,0 +1,32 @@ +import torch +from PIL import Image +from diffsynth import save_video +from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig +from modelscope import dataset_snapshot_download + +pipe = WanVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Wan-AI/Wan2.2-I2V-A14B", origin_file_pattern="high_noise_model/diffusion_pytorch_model*.safetensors", offload_device="cpu"), + ModelConfig(model_id="Wan-AI/Wan2.2-I2V-A14B", origin_file_pattern="low_noise_model/diffusion_pytorch_model*.safetensors", offload_device="cpu"), + ModelConfig(model_id="Wan-AI/Wan2.2-I2V-A14B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"), + ModelConfig(model_id="Wan-AI/Wan2.2-I2V-A14B", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"), + ], +) +pipe.enable_vram_management() + +dataset_snapshot_download( + dataset_id="DiffSynth-Studio/examples_in_diffsynth", + local_dir="./", + allow_file_pattern=["data/examples/wan/cat_fightning.jpg"] +) +input_image = Image.open("data/examples/wan/cat_fightning.jpg").resize((832, 480)) + +video = pipe( + prompt="Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage.", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + seed=0, tiled=True, + input_image=input_image, +) +save_video(video, "video1.mp4", fps=15, quality=5) diff --git a/examples/wanvideo/model_inference/Wan2.2-T2V-A14B.py b/examples/wanvideo/model_inference/Wan2.2-T2V-A14B.py new file mode 100644 index 0000000..27b10d0 --- /dev/null +++ b/examples/wanvideo/model_inference/Wan2.2-T2V-A14B.py @@ -0,0 +1,24 @@ +import torch +from diffsynth import save_video +from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig + + +pipe = WanVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Wan-AI/Wan2.2-T2V-A14B", origin_file_pattern="high_noise_model/diffusion_pytorch_model*.safetensors", offload_device="cpu"), + ModelConfig(model_id="Wan-AI/Wan2.2-T2V-A14B", origin_file_pattern="low_noise_model/diffusion_pytorch_model*.safetensors", offload_device="cpu"), + ModelConfig(model_id="Wan-AI/Wan2.2-T2V-A14B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"), + ModelConfig(model_id="Wan-AI/Wan2.2-T2V-A14B", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"), + ], +) +pipe.enable_vram_management() + +# Text-to-video +video = pipe( + prompt="Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage.", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + seed=0, tiled=True, +) +save_video(video, "video1.mp4", fps=15, quality=5) diff --git a/examples/wanvideo/model_inference/Wan2.2-TI2V-5B.py b/examples/wanvideo/model_inference/Wan2.2-TI2V-5B.py new file mode 100644 index 0000000..fa91965 --- /dev/null +++ b/examples/wanvideo/model_inference/Wan2.2-TI2V-5B.py @@ -0,0 +1,43 @@ +import torch +from PIL import Image +from diffsynth import save_video +from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig +from modelscope import dataset_snapshot_download + +pipe = WanVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Wan-AI/Wan2.2-TI2V-5B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"), + ModelConfig(model_id="Wan-AI/Wan2.2-TI2V-5B", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"), + ModelConfig(model_id="Wan-AI/Wan2.2-TI2V-5B", origin_file_pattern="Wan2.2_VAE.pth", offload_device="cpu"), + ], +) +pipe.enable_vram_management() + +# Text-to-video +video = pipe( + prompt="两只可爱的橘猫戴上拳击手套,站在一个拳击台上搏斗。", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + seed=0, tiled=True, + height=704, width=1248, + num_frames=121, +) +save_video(video, "video1.mp4", fps=15, quality=5) + +# Image-to-video +dataset_snapshot_download( + dataset_id="DiffSynth-Studio/examples_in_diffsynth", + local_dir="./", + allow_file_pattern=["data/examples/wan/cat_fightning.jpg"] +) +input_image = Image.open("data/examples/wan/cat_fightning.jpg").resize((1248, 704)) +video = pipe( + prompt="两只可爱的橘猫戴上拳击手套,站在一个拳击台上搏斗。", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + seed=0, tiled=True, + height=704, width=1248, + input_image=input_image, + num_frames=121, +) +save_video(video, "video2.mp4", fps=15, quality=5) diff --git a/examples/wanvideo/model_training/full/Wan2.2-I2V-A14B.sh b/examples/wanvideo/model_training/full/Wan2.2-I2V-A14B.sh new file mode 100644 index 0000000..2f531e7 --- /dev/null +++ b/examples/wanvideo/model_training/full/Wan2.2-I2V-A14B.sh @@ -0,0 +1,35 @@ +accelerate launch --config_file examples/wanvideo/model_training/full/accelerate_config_14B.yaml examples/wanvideo/model_training/train.py \ + --dataset_base_path data/example_video_dataset \ + --dataset_metadata_path data/example_video_dataset/metadata.csv \ + --height 480 \ + --width 832 \ + --num_frames 49 \ + --dataset_repeat 100 \ + --model_id_with_origin_paths "Wan-AI/Wan2.2-I2V-A14B:high_noise_model/diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.2-I2V-A14B:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.2-I2V-A14B:Wan2.1_VAE.pth" \ + --learning_rate 1e-5 \ + --num_epochs 2 \ + --remove_prefix_in_ckpt "pipe.dit." \ + --output_path "./models/train/Wan2.2-I2V-A14B_high_noise_full" \ + --trainable_models "dit" \ + --extra_inputs "input_image" \ + --use_gradient_checkpointing_offload \ + --max_timestep_boundary 1 \ + --min_timestep_boundary 0.875 + +accelerate launch --config_file examples/wanvideo/model_training/full/accelerate_config_14B.yaml examples/wanvideo/model_training/train.py \ + --dataset_base_path data/example_video_dataset \ + --dataset_metadata_path data/example_video_dataset/metadata.csv \ + --height 480 \ + --width 832 \ + --num_frames 49 \ + --dataset_repeat 100 \ + --model_id_with_origin_paths "Wan-AI/Wan2.2-I2V-A14B:low_noise_model/diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.2-I2V-A14B:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.2-I2V-A14B:Wan2.1_VAE.pth" \ + --learning_rate 1e-5 \ + --num_epochs 2 \ + --remove_prefix_in_ckpt "pipe.dit." \ + --output_path "./models/train/Wan2.2-I2V-A14B_low_noise_full" \ + --trainable_models "dit" \ + --extra_inputs "input_image" \ + --use_gradient_checkpointing_offload \ + --max_timestep_boundary 0.875 \ + --min_timestep_boundary 0 diff --git a/examples/wanvideo/model_training/full/Wan2.2-T2V-A14B.sh b/examples/wanvideo/model_training/full/Wan2.2-T2V-A14B.sh new file mode 100644 index 0000000..f634117 --- /dev/null +++ b/examples/wanvideo/model_training/full/Wan2.2-T2V-A14B.sh @@ -0,0 +1,31 @@ +accelerate launch --config_file examples/wanvideo/model_training/full/accelerate_config_14B.yaml examples/wanvideo/model_training/train.py \ + --dataset_base_path data/example_video_dataset \ + --dataset_metadata_path data/example_video_dataset/metadata.csv \ + --height 480 \ + --width 832 \ + --num_frames 49 \ + --dataset_repeat 100 \ + --model_id_with_origin_paths "Wan-AI/Wan2.2-T2V-A14B:high_noise_model/diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.2-T2V-A14B:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.2-T2V-A14B:Wan2.1_VAE.pth" \ + --learning_rate 1e-5 \ + --num_epochs 2 \ + --remove_prefix_in_ckpt "pipe.dit." \ + --output_path "./models/train/Wan2.2-T2V-A14B_high_noise_full" \ + --trainable_models "dit" \ + --max_timestep_boundary 1 \ + --min_timestep_boundary 0.875 + +accelerate launch --config_file examples/wanvideo/model_training/full/accelerate_config_14B.yaml examples/wanvideo/model_training/train.py \ + --dataset_base_path data/example_video_dataset \ + --dataset_metadata_path data/example_video_dataset/metadata.csv \ + --height 480 \ + --width 832 \ + --num_frames 49 \ + --dataset_repeat 100 \ + --model_id_with_origin_paths "Wan-AI/Wan2.2-T2V-A14B:low_noise_model/diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.2-T2V-A14B:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.2-T2V-A14B:Wan2.1_VAE.pth" \ + --learning_rate 1e-5 \ + --num_epochs 2 \ + --remove_prefix_in_ckpt "pipe.dit." \ + --output_path "./models/train/Wan2.2-T2V-A14B_low_noise_full" \ + --trainable_models "dit" \ + --max_timestep_boundary 0.875 \ + --min_timestep_boundary 0 diff --git a/examples/wanvideo/model_training/full/Wan2.2-TI2V-5B.sh b/examples/wanvideo/model_training/full/Wan2.2-TI2V-5B.sh new file mode 100644 index 0000000..def9f89 --- /dev/null +++ b/examples/wanvideo/model_training/full/Wan2.2-TI2V-5B.sh @@ -0,0 +1,14 @@ +accelerate launch examples/wanvideo/model_training/train.py \ + --dataset_base_path data/example_video_dataset \ + --dataset_metadata_path data/example_video_dataset/metadata.csv \ + --height 480 \ + --width 832 \ + --num_frames 49 \ + --dataset_repeat 100 \ + --model_id_with_origin_paths "Wan-AI/Wan2.2-TI2V-5B:diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.2-TI2V-5B:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.2-TI2V-5B:Wan2.2_VAE.pth" \ + --learning_rate 1e-5 \ + --num_epochs 2 \ + --remove_prefix_in_ckpt "pipe.dit." \ + --output_path "./models/train/Wan2.2-TI2V-5B_full" \ + --trainable_models "dit" \ + --extra_inputs "input_image" \ No newline at end of file diff --git a/examples/wanvideo/model_training/lora/Wan2.2-I2V-A14B.sh b/examples/wanvideo/model_training/lora/Wan2.2-I2V-A14B.sh new file mode 100644 index 0000000..4201b47 --- /dev/null +++ b/examples/wanvideo/model_training/lora/Wan2.2-I2V-A14B.sh @@ -0,0 +1,37 @@ +accelerate launch examples/wanvideo/model_training/train.py \ + --dataset_base_path data/example_video_dataset \ + --dataset_metadata_path data/example_video_dataset/metadata.csv \ + --height 480 \ + --width 832 \ + --num_frames 49 \ + --dataset_repeat 100 \ + --model_id_with_origin_paths "Wan-AI/Wan2.2-I2V-A14B:high_noise_model/diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.2-I2V-A14B:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.2-I2V-A14B:Wan2.1_VAE.pth" \ + --learning_rate 1e-4 \ + --num_epochs 5 \ + --remove_prefix_in_ckpt "pipe.dit." \ + --output_path "./models/train/Wan2.2-I2V-A14B_high_noise_lora" \ + --lora_base_model "dit" \ + --lora_target_modules "q,k,v,o,ffn.0,ffn.2" \ + --lora_rank 32 \ + --extra_inputs "input_image" \ + --max_timestep_boundary 1 \ + --min_timestep_boundary 0.875 + +accelerate launch examples/wanvideo/model_training/train.py \ + --dataset_base_path data/example_video_dataset \ + --dataset_metadata_path data/example_video_dataset/metadata.csv \ + --height 480 \ + --width 832 \ + --num_frames 49 \ + --dataset_repeat 100 \ + --model_id_with_origin_paths "Wan-AI/Wan2.2-I2V-A14B:low_noise_model/diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.2-I2V-A14B:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.2-I2V-A14B:Wan2.1_VAE.pth" \ + --learning_rate 1e-4 \ + --num_epochs 5 \ + --remove_prefix_in_ckpt "pipe.dit." \ + --output_path "./models/train/Wan2.2-I2V-A14B_low_noise_lora" \ + --lora_base_model "dit" \ + --lora_target_modules "q,k,v,o,ffn.0,ffn.2" \ + --lora_rank 32 \ + --extra_inputs "input_image" \ + --max_timestep_boundary 0.875 \ + --min_timestep_boundary 0 diff --git a/examples/wanvideo/model_training/lora/Wan2.2-T2V-A14B.sh b/examples/wanvideo/model_training/lora/Wan2.2-T2V-A14B.sh new file mode 100644 index 0000000..737896c --- /dev/null +++ b/examples/wanvideo/model_training/lora/Wan2.2-T2V-A14B.sh @@ -0,0 +1,36 @@ +accelerate launch examples/wanvideo/model_training/train.py \ + --dataset_base_path data/example_video_dataset \ + --dataset_metadata_path data/example_video_dataset/metadata.csv \ + --height 480 \ + --width 832 \ + --num_frames 49 \ + --dataset_repeat 100 \ + --model_id_with_origin_paths "Wan-AI/Wan2.2-T2V-A14B:high_noise_model/diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.2-T2V-A14B:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.2-T2V-A14B:Wan2.1_VAE.pth" \ + --learning_rate 1e-4 \ + --num_epochs 5 \ + --remove_prefix_in_ckpt "pipe.dit." \ + --output_path "./models/train/Wan2.2-T2V-A14B_high_noise_lora" \ + --lora_base_model "dit" \ + --lora_target_modules "q,k,v,o,ffn.0,ffn.2" \ + --lora_rank 32 \ + --max_timestep_boundary 1 \ + --min_timestep_boundary 0.875 + + +accelerate launch examples/wanvideo/model_training/train.py \ + --dataset_base_path data/example_video_dataset \ + --dataset_metadata_path data/example_video_dataset/metadata.csv \ + --height 480 \ + --width 832 \ + --num_frames 49 \ + --dataset_repeat 100 \ + --model_id_with_origin_paths "Wan-AI/Wan2.2-T2V-A14B:low_noise_model/diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.2-T2V-A14B:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.2-T2V-A14B:Wan2.1_VAE.pth" \ + --learning_rate 1e-4 \ + --num_epochs 5 \ + --remove_prefix_in_ckpt "pipe.dit." \ + --output_path "./models/train/Wan2.2-T2V-A14B_low_noise_lora" \ + --lora_base_model "dit" \ + --lora_target_modules "q,k,v,o,ffn.0,ffn.2" \ + --lora_rank 32 \ + --max_timestep_boundary 0.875 \ + --min_timestep_boundary 0 diff --git a/examples/wanvideo/model_training/lora/Wan2.2-TI2V-5B.sh b/examples/wanvideo/model_training/lora/Wan2.2-TI2V-5B.sh new file mode 100644 index 0000000..6a33b57 --- /dev/null +++ b/examples/wanvideo/model_training/lora/Wan2.2-TI2V-5B.sh @@ -0,0 +1,16 @@ +accelerate launch examples/wanvideo/model_training/train.py \ + --dataset_base_path data/example_video_dataset \ + --dataset_metadata_path data/example_video_dataset/metadata.csv \ + --height 480 \ + --width 832 \ + --num_frames 49 \ + --dataset_repeat 100 \ + --model_id_with_origin_paths "Wan-AI/Wan2.2-TI2V-5B:diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.2-TI2V-5B:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.2-TI2V-5B:Wan2.2_VAE.pth" \ + --learning_rate 1e-4 \ + --num_epochs 5 \ + --remove_prefix_in_ckpt "pipe.dit." \ + --output_path "./models/train/Wan2.2-TI2V-5B_lora" \ + --lora_base_model "dit" \ + --lora_target_modules "q,k,v,o,ffn.0,ffn.2" \ + --lora_rank 32 \ + --extra_inputs "input_image" \ No newline at end of file diff --git a/examples/wanvideo/model_training/train.py b/examples/wanvideo/model_training/train.py index 877c5b8..93ec0bd 100644 --- a/examples/wanvideo/model_training/train.py +++ b/examples/wanvideo/model_training/train.py @@ -14,6 +14,8 @@ class WanTrainingModule(DiffusionTrainingModule): use_gradient_checkpointing=True, use_gradient_checkpointing_offload=False, extra_inputs=None, + max_timestep_boundary=1.0, + min_timestep_boundary=0.0, ): super().__init__() # Load models @@ -45,6 +47,8 @@ class WanTrainingModule(DiffusionTrainingModule): self.use_gradient_checkpointing = use_gradient_checkpointing self.use_gradient_checkpointing_offload = use_gradient_checkpointing_offload self.extra_inputs = extra_inputs.split(",") if extra_inputs is not None else [] + self.max_timestep_boundary = max_timestep_boundary + self.min_timestep_boundary = min_timestep_boundary def forward_preprocess(self, data): @@ -69,6 +73,8 @@ class WanTrainingModule(DiffusionTrainingModule): "use_gradient_checkpointing_offload": self.use_gradient_checkpointing_offload, "cfg_merge": False, "vace_scale": 1, + "max_timestep_boundary": self.max_timestep_boundary, + "min_timestep_boundary": self.min_timestep_boundary, } # Extra inputs @@ -106,6 +112,8 @@ if __name__ == "__main__": lora_rank=args.lora_rank, use_gradient_checkpointing_offload=args.use_gradient_checkpointing_offload, extra_inputs=args.extra_inputs, + max_timestep_boundary=args.max_timestep_boundary, + min_timestep_boundary=args.min_timestep_boundary, ) model_logger = ModelLogger( args.output_path, diff --git a/examples/wanvideo/model_training/validate_full/Wan2.2-I2V-A14B.py b/examples/wanvideo/model_training/validate_full/Wan2.2-I2V-A14B.py new file mode 100644 index 0000000..3f6d253 --- /dev/null +++ b/examples/wanvideo/model_training/validate_full/Wan2.2-I2V-A14B.py @@ -0,0 +1,33 @@ +import torch +from PIL import Image +from diffsynth import save_video, VideoData, load_state_dict +from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig +from modelscope import dataset_snapshot_download + + +pipe = WanVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Wan-AI/Wan2.2-I2V-A14B", origin_file_pattern="high_noise_model/diffusion_pytorch_model*.safetensors", offload_device="cpu"), + ModelConfig(model_id="Wan-AI/Wan2.2-I2V-A14B", origin_file_pattern="low_noise_model/diffusion_pytorch_model*.safetensors", offload_device="cpu"), + ModelConfig(model_id="Wan-AI/Wan2.2-I2V-A14B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"), + ModelConfig(model_id="Wan-AI/Wan2.2-I2V-A14B", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"), + ], +) +state_dict = load_state_dict("models/train/Wan2.2-I2V-A14B_high_noise_full/epoch-1.safetensors") +pipe.dit.load_state_dict(state_dict) +state_dict = load_state_dict("models/train/Wan2.2-I2V-A14B_low_noise_full/epoch-1.safetensors") +pipe.dit2.load_state_dict(state_dict) +pipe.enable_vram_management() + +input_image = VideoData("data/example_video_dataset/video1.mp4", height=480, width=832)[0] + +video = pipe( + prompt="from sunset to night, a small town, light, house, river", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + input_image=input_image, + num_frames=49, + seed=1, tiled=True, +) +save_video(video, "video_Wan2.2-I2V-A14B.mp4", fps=15, quality=5) diff --git a/examples/wanvideo/model_training/validate_full/Wan2.2-T2V-A14B.py b/examples/wanvideo/model_training/validate_full/Wan2.2-T2V-A14B.py new file mode 100644 index 0000000..be0e000 --- /dev/null +++ b/examples/wanvideo/model_training/validate_full/Wan2.2-T2V-A14B.py @@ -0,0 +1,28 @@ +import torch +from PIL import Image +from diffsynth import save_video, VideoData, load_state_dict +from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig + + +pipe = WanVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Wan-AI/Wan2.2-T2V-A14B", origin_file_pattern="high_noise_model/diffusion_pytorch_model*.safetensors", offload_device="cpu"), + ModelConfig(model_id="Wan-AI/Wan2.2-T2V-A14B", origin_file_pattern="low_noise_model/diffusion_pytorch_model*.safetensors", offload_device="cpu"), + ModelConfig(model_id="Wan-AI/Wan2.2-T2V-A14B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"), + ModelConfig(model_id="Wan-AI/Wan2.2-T2V-A14B", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"), + ], +) +state_dict = load_state_dict("models/train/Wan2.2-T2V-A14B_high_noise_full/epoch-1.safetensors") +pipe.dit.load_state_dict(state_dict) +state_dict = load_state_dict("models/train/Wan2.2-T2V-A14B_low_noise_full/epoch-1.safetensors") +pipe.dit2.load_state_dict(state_dict) +pipe.enable_vram_management() + +video = pipe( + prompt="from sunset to night, a small town, light, house, river", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + seed=1, tiled=True +) +save_video(video, "video_Wan2.2-T2V-A14B.mp4", fps=15, quality=5) diff --git a/examples/wanvideo/model_training/validate_full/Wan2.2-TI2V-5B.py b/examples/wanvideo/model_training/validate_full/Wan2.2-TI2V-5B.py new file mode 100644 index 0000000..8a3d36c --- /dev/null +++ b/examples/wanvideo/model_training/validate_full/Wan2.2-TI2V-5B.py @@ -0,0 +1,30 @@ +import torch +from PIL import Image +from diffsynth import save_video, VideoData, load_state_dict +from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig +from modelscope import dataset_snapshot_download + + +pipe = WanVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Wan-AI/Wan2.2-TI2V-5B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"), + ModelConfig(model_id="Wan-AI/Wan2.2-TI2V-5B", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"), + ModelConfig(model_id="Wan-AI/Wan2.2-TI2V-5B", origin_file_pattern="Wan2.2_VAE.pth", offload_device="cpu"), + ], +) +state_dict = load_state_dict("models/train/Wan2.2-TI2V-5B_full/epoch-1.safetensors") +pipe.dit.load_state_dict(state_dict) +pipe.enable_vram_management() + +input_image = VideoData("data/example_video_dataset/video1.mp4", height=480, width=832)[0] + +video = pipe( + prompt="from sunset to night, a small town, light, house, river", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + input_image=input_image, + num_frames=49, + seed=1, tiled=True, +) +save_video(video, "video_Wan2.2-TI2V-5B.mp4", fps=15, quality=5) diff --git a/examples/wanvideo/model_training/validate_lora/Wan2.2-I2V-A14B.py b/examples/wanvideo/model_training/validate_lora/Wan2.2-I2V-A14B.py new file mode 100644 index 0000000..f221ef7 --- /dev/null +++ b/examples/wanvideo/model_training/validate_lora/Wan2.2-I2V-A14B.py @@ -0,0 +1,31 @@ +import torch +from PIL import Image +from diffsynth import save_video, VideoData +from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig +from modelscope import dataset_snapshot_download + + +pipe = WanVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Wan-AI/Wan2.2-I2V-A14B", origin_file_pattern="high_noise_model/diffusion_pytorch_model*.safetensors", offload_device="cpu"), + ModelConfig(model_id="Wan-AI/Wan2.2-I2V-A14B", origin_file_pattern="low_noise_model/diffusion_pytorch_model*.safetensors", offload_device="cpu"), + ModelConfig(model_id="Wan-AI/Wan2.2-I2V-A14B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"), + ModelConfig(model_id="Wan-AI/Wan2.2-I2V-A14B", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"), + ], +) +pipe.load_lora(pipe.dit, "models/train/Wan2.2-I2V-A14B_high_noise_lora/epoch-4.safetensors", alpha=1) +pipe.load_lora(pipe.dit2, "models/train/Wan2.2-I2V-A14B_low_noise_lora/epoch-4.safetensors", alpha=1) +pipe.enable_vram_management() + +input_image = VideoData("data/example_video_dataset/video1.mp4", height=480, width=832)[0] + +video = pipe( + prompt="from sunset to night, a small town, light, house, river", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + input_image=input_image, + num_frames=49, + seed=1, tiled=True, +) +save_video(video, "video_Wan2.2-I2V-A14B.mp4", fps=15, quality=5) diff --git a/examples/wanvideo/model_training/validate_lora/Wan2.2-T2V-A14B.py b/examples/wanvideo/model_training/validate_lora/Wan2.2-T2V-A14B.py new file mode 100644 index 0000000..ab43927 --- /dev/null +++ b/examples/wanvideo/model_training/validate_lora/Wan2.2-T2V-A14B.py @@ -0,0 +1,28 @@ +import torch +from PIL import Image +from diffsynth import save_video, VideoData +from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig +from modelscope import dataset_snapshot_download + + +pipe = WanVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Wan-AI/Wan2.2-T2V-A14B", origin_file_pattern="high_noise_model/diffusion_pytorch_model*.safetensors", offload_device="cpu"), + ModelConfig(model_id="Wan-AI/Wan2.2-T2V-A14B", origin_file_pattern="low_noise_model/diffusion_pytorch_model*.safetensors", offload_device="cpu"), + ModelConfig(model_id="Wan-AI/Wan2.2-T2V-A14B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"), + ModelConfig(model_id="Wan-AI/Wan2.2-T2V-A14B", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"), + ], +) +pipe.load_lora(pipe.dit, "models/train/Wan2.2-T2V-A14B_high_noise_lora/epoch-4.safetensors", alpha=1) +pipe.load_lora(pipe.dit2, "models/train/Wan2.2-T2V-A14B_low_noise_lora/epoch-4.safetensors", alpha=1) +pipe.enable_vram_management() + +video = pipe( + prompt="from sunset to night, a small town, light, house, river", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + num_frames=49, + seed=1, tiled=True +) +save_video(video, "video_Wan2.2-T2V-A14B.mp4", fps=15, quality=5) diff --git a/examples/wanvideo/model_training/validate_lora/Wan2.2-TI2V-5B.py b/examples/wanvideo/model_training/validate_lora/Wan2.2-TI2V-5B.py new file mode 100644 index 0000000..e5b16c8 --- /dev/null +++ b/examples/wanvideo/model_training/validate_lora/Wan2.2-TI2V-5B.py @@ -0,0 +1,29 @@ +import torch +from PIL import Image +from diffsynth import save_video, VideoData, load_state_dict +from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig +from modelscope import dataset_snapshot_download + + +pipe = WanVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Wan-AI/Wan2.2-TI2V-5B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"), + ModelConfig(model_id="Wan-AI/Wan2.2-TI2V-5B", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"), + ModelConfig(model_id="Wan-AI/Wan2.2-TI2V-5B", origin_file_pattern="Wan2.2_VAE.pth", offload_device="cpu"), + ], +) +pipe.load_lora(pipe.dit, "models/train/Wan2.2-TI2V-5B_lora/epoch-4.safetensors", alpha=1) +pipe.enable_vram_management() + +input_image = VideoData("data/example_video_dataset/video1.mp4", height=480, width=832)[0] + +video = pipe( + prompt="from sunset to night, a small town, light, house, river", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + input_image=input_image, + num_frames=49, + seed=1, tiled=True, +) +save_video(video, "video_Wan2.2-TI2V-5B.mp4", fps=15, quality=5)