diff --git a/diffsynth/configs/model_config.py b/diffsynth/configs/model_config.py index de09d16..4f540f9 100644 --- a/diffsynth/configs/model_config.py +++ b/diffsynth/configs/model_config.py @@ -126,6 +126,7 @@ model_loader_configs = [ (None, "6bfcfb3b342cb286ce886889d519a77e", ["wan_video_dit"], [WanModel], "civitai"), (None, "349723183fc063b2bfc10bb2835cf677", ["wan_video_dit"], [WanModel], "civitai"), (None, "efa44cddf936c70abd0ea28b6cbe946c", ["wan_video_dit"], [WanModel], "civitai"), + (None, "3ef3b1f8e1dab83d5b71fd7b617f859f", ["wan_video_dit"], [WanModel], "civitai"), (None, "a61453409b67cd3246cf0c3bebad47ba", ["wan_video_dit", "wan_video_vace"], [WanModel, VaceWanModel], "civitai"), (None, "cb104773c6c2cb6df4f9529ad5c60d0b", ["wan_video_dit"], [WanModel], "diffusers"), (None, "9c8818c2cbea55eca56c7b447df170da", ["wan_video_text_encoder"], [WanTextEncoder], "civitai"), diff --git a/diffsynth/models/wan_video_dit.py b/diffsynth/models/wan_video_dit.py index fa2b9f0..d9be8ab 100644 --- a/diffsynth/models/wan_video_dit.py +++ b/diffsynth/models/wan_video_dit.py @@ -223,7 +223,7 @@ class DiTBlock(nn.Module): class MLP(torch.nn.Module): - def __init__(self, in_dim, out_dim): + def __init__(self, in_dim, out_dim, has_pos_emb=False): super().__init__() self.proj = torch.nn.Sequential( nn.LayerNorm(in_dim), @@ -232,8 +232,13 @@ class MLP(torch.nn.Module): nn.Linear(in_dim, out_dim), nn.LayerNorm(out_dim) ) + self.has_pos_emb = has_pos_emb + if has_pos_emb: + self.emb_pos = torch.nn.Parameter(torch.zeros((1, 514, 1280))) def forward(self, x): + if self.has_pos_emb: + x = x + self.emb_pos.to(dtype=x.dtype, device=x.device) return self.proj(x) @@ -266,6 +271,7 @@ class WanModel(torch.nn.Module): num_heads: int, num_layers: int, has_image_input: bool, + has_image_pos_emb: bool = False, ): super().__init__() self.dim = dim @@ -296,7 +302,8 @@ class WanModel(torch.nn.Module): self.freqs = precompute_freqs_cis_3d(head_dim) if has_image_input: - self.img_emb = MLP(1280, dim) # clip_feature_dim = 1280 + self.img_emb = MLP(1280, dim, has_pos_emb=has_image_pos_emb) # clip_feature_dim = 1280 + self.has_image_pos_emb = has_image_pos_emb def patchify(self, x: torch.Tensor): x = self.patch_embedding(x) @@ -552,6 +559,21 @@ class WanModelStateDictConverter: "num_layers": 40, "eps": 1e-6 } + elif hash_state_dict_keys(state_dict) == "3ef3b1f8e1dab83d5b71fd7b617f859f": + config = { + "has_image_input": True, + "patch_size": [1, 2, 2], + "in_dim": 36, + "dim": 5120, + "ffn_dim": 13824, + "freq_dim": 256, + "text_dim": 4096, + "out_dim": 16, + "num_heads": 40, + "num_layers": 40, + "eps": 1e-6, + "has_image_pos_emb": True + } else: config = {} return state_dict, config diff --git a/diffsynth/pipelines/wan_video.py b/diffsynth/pipelines/wan_video.py index 1431a24..d9e535a 100644 --- a/diffsynth/pipelines/wan_video.py +++ b/diffsynth/pipelines/wan_video.py @@ -211,6 +211,8 @@ class WanVideoPipeline(BasePipeline): if end_image is not None: end_image = self.preprocess_image(end_image.resize((width, height))).to(self.device) vae_input = torch.concat([image.transpose(0,1), torch.zeros(3, num_frames-2, height, width).to(image.device), end_image.transpose(0,1)],dim=1) + if self.dit.has_image_pos_emb: + clip_context = torch.concat([clip_context, self.image_encoder.encode_image([end_image])], dim=1) msk[:, -1:] = 1 else: vae_input = torch.concat([image.transpose(0, 1), torch.zeros(3, num_frames-1, height, width).to(image.device)], dim=1) diff --git a/examples/wanvideo/README.md b/examples/wanvideo/README.md index 3b03f86..f0e77f3 100644 --- a/examples/wanvideo/README.md +++ b/examples/wanvideo/README.md @@ -18,6 +18,7 @@ pip install -e . |Wan Team|14B text-to-video|[Link](https://modelscope.cn/models/Wan-AI/Wan2.1-T2V-14B)|[wan_14b_text_to_video.py](./wan_14b_text_to_video.py)| |Wan Team|14B image-to-video 480P|[Link](https://modelscope.cn/models/Wan-AI/Wan2.1-I2V-14B-480P)|[wan_14b_image_to_video.py](./wan_14b_image_to_video.py)| |Wan Team|14B image-to-video 720P|[Link](https://modelscope.cn/models/Wan-AI/Wan2.1-I2V-14B-720P)|[wan_14b_image_to_video.py](./wan_14b_image_to_video.py)| +|Wan Team|14B first-last-frame-to-video 720P|[Link](https://modelscope.cn/models/Wan-AI/Wan2.1-FLF2V-14B-720P)|[wan_14B_flf2v.py](./wan_14B_flf2v.py)| |DiffSynth-Studio Team|1.3B aesthetics LoRA|[Link](https://modelscope.cn/models/DiffSynth-Studio/Wan2.1-1.3b-lora-aesthetics-v1)|Please see the [model card](https://modelscope.cn/models/DiffSynth-Studio/Wan2.1-1.3b-lora-aesthetics-v1).| |DiffSynth-Studio Team|1.3B Highres-fix LoRA|[Link](https://modelscope.cn/models/DiffSynth-Studio/Wan2.1-1.3b-lora-highresfix-v1)|Please see the [model card](https://modelscope.cn/models/DiffSynth-Studio/Wan2.1-1.3b-lora-highresfix-v1).| |DiffSynth-Studio Team|1.3B ExVideo LoRA|[Link](https://modelscope.cn/models/DiffSynth-Studio/Wan2.1-1.3b-lora-exvideo-v1)|Please see the [model card](https://modelscope.cn/models/DiffSynth-Studio/Wan2.1-1.3b-lora-exvideo-v1).| @@ -110,6 +111,12 @@ https://github.com/user-attachments/assets/3908bc64-d451-485a-8b61-28f6d32dd92f https://github.com/user-attachments/assets/c0bdd5ca-292f-45ed-b9bc-afe193156e75 +14B first-last-frame-to-video + +|First frame|Last frame|Video| +|-|-|-| +|![Image](https://github.com/user-attachments/assets/b0d8225b-aee0-4129-b8e5-58c8523221a6)|![Image](https://github.com/user-attachments/assets/2f0c9bc5-07e2-45fa-8320-53d63a4fd203)|https://github.com/user-attachments/assets/2a6a2681-622c-4512-b852-5f22e73830b1| + ## Train We support Wan-Video LoRA training and full training. Here is a tutorial. This is an experimental feature. Below is a video sample generated from the character Keqing LoRA: diff --git a/examples/wanvideo/wan_14B_flf2v.py b/examples/wanvideo/wan_14B_flf2v.py new file mode 100644 index 0000000..23109df --- /dev/null +++ b/examples/wanvideo/wan_14B_flf2v.py @@ -0,0 +1,52 @@ +import torch +from diffsynth import ModelManager, WanVideoPipeline, save_video, VideoData +from modelscope import snapshot_download, dataset_snapshot_download +from PIL import Image + + +# Download models +snapshot_download("Wan-AI/Wan2.1-FLF2V-14B-720P", local_dir="models/Wan-AI/Wan2.1-FLF2V-14B-720P") + +# Load models +model_manager = ModelManager(device="cpu") +model_manager.load_models( + ["models/Wan-AI/Wan2.1-FLF2V-14B-720P/models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth"], + torch_dtype=torch.float32, # Image Encoder is loaded with float32 +) +model_manager.load_models( + [ + [ + "models/Wan-AI/Wan2.1-FLF2V-14B-720P/diffusion_pytorch_model-00001-of-00007.safetensors", + "models/Wan-AI/Wan2.1-FLF2V-14B-720P/diffusion_pytorch_model-00002-of-00007.safetensors", + "models/Wan-AI/Wan2.1-FLF2V-14B-720P/diffusion_pytorch_model-00003-of-00007.safetensors", + "models/Wan-AI/Wan2.1-FLF2V-14B-720P/diffusion_pytorch_model-00004-of-00007.safetensors", + "models/Wan-AI/Wan2.1-FLF2V-14B-720P/diffusion_pytorch_model-00005-of-00007.safetensors", + "models/Wan-AI/Wan2.1-FLF2V-14B-720P/diffusion_pytorch_model-00006-of-00007.safetensors", + "models/Wan-AI/Wan2.1-FLF2V-14B-720P/diffusion_pytorch_model-00007-of-00007.safetensors", + ], + "models/Wan-AI/Wan2.1-FLF2V-14B-720P/models_t5_umt5-xxl-enc-bf16.pth", + "models/Wan-AI/Wan2.1-FLF2V-14B-720P/Wan2.1_VAE.pth", + ], + torch_dtype=torch.bfloat16, # You can set `torch_dtype=torch.float8_e4m3fn` to enable FP8 quantization. +) +pipe = WanVideoPipeline.from_model_manager(model_manager, torch_dtype=torch.bfloat16, device="cuda") +pipe.enable_vram_management(num_persistent_param_in_dit=None) + +# Download example image +dataset_snapshot_download( + dataset_id="DiffSynth-Studio/examples_in_diffsynth", + local_dir="./", + allow_file_pattern=["data/examples/wan/first_frame.jpeg", "data/examples/wan/last_frame.jpeg"] +) + +# First and last frame to video +video = pipe( + prompt="写实风格,一个女生手持枯萎的花站在花园中,镜头逐渐拉远,记录下花园的全貌。", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + num_inference_steps=30, + input_image=Image.open("data/examples/wan/first_frame.jpeg").resize((960, 960)), + end_image=Image.open("data/examples/wan/last_frame.jpeg").resize((960, 960)), + height=960, width=960, + seed=1, tiled=True +) +save_video(video, "video.mp4", fps=15, quality=5)