mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-20 07:18:14 +00:00
Compare commits
9 Commits
v1.1.7
...
wan-models
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
05094710e3 | ||
|
|
105eaf0f49 | ||
|
|
6cd032e846 | ||
|
|
9d8130b48d | ||
|
|
ce848a3d1a | ||
|
|
a8ce9fef33 | ||
|
|
8da0d183a2 | ||
|
|
4b2b3dda94 | ||
|
|
b1fabbc6b0 |
2
.github/workflows/publish.yaml
vendored
2
.github/workflows/publish.yaml
vendored
@@ -20,7 +20,7 @@ jobs:
|
|||||||
with:
|
with:
|
||||||
python-version: '3.10'
|
python-version: '3.10'
|
||||||
- name: Install wheel
|
- name: Install wheel
|
||||||
run: pip install wheel==0.44.0 && pip install -r requirements.txt
|
run: pip install wheel && pip install -r requirements.txt
|
||||||
- name: Build DiffSynth
|
- name: Build DiffSynth
|
||||||
run: python setup.py sdist bdist_wheel
|
run: python setup.py sdist bdist_wheel
|
||||||
- name: Publish package to PyPI
|
- name: Publish package to PyPI
|
||||||
|
|||||||
18
README.md
18
README.md
@@ -13,15 +13,9 @@ Document: https://diffsynth-studio.readthedocs.io/zh-cn/latest/index.html
|
|||||||
|
|
||||||
## Introduction
|
## Introduction
|
||||||
|
|
||||||
Welcome to the magic world of Diffusion models!
|
DiffSynth Studio is a Diffusion engine. We have restructured architectures including Text Encoder, UNet, VAE, among others, maintaining compatibility with models from the open-source community while enhancing computational performance. We provide many interesting features. Enjoy the magic of Diffusion models!
|
||||||
|
|
||||||
DiffSynth consists of two open-source projects:
|
Until now, DiffSynth Studio has supported the following models:
|
||||||
* [DiffSynth-Studio](https://github.com/modelscope/DiffSynth-Studio): Focused on aggressive technological exploration. Targeted at academia. Provides more cutting-edge technical support and novel inference capabilities.
|
|
||||||
* [DiffSynth-Engine](https://github.com/modelscope/DiffSynth-Engine): Focused on stable model deployment. Geared towards industry. Offers better engineering support, higher computational performance, and more stable functionality.
|
|
||||||
|
|
||||||
DiffSynth-Studio is an open-source project aimed at exploring innovations in AIGC technology. We have integrated numerous open-source Diffusion models, including FLUX and Wan, among others. Through this open-source project, we hope to connect models within the open-source community and explore new technologies based on diffusion models.
|
|
||||||
|
|
||||||
Until now, DiffSynth-Studio has supported the following models:
|
|
||||||
|
|
||||||
* [Wan-Video](https://github.com/Wan-Video/Wan2.1)
|
* [Wan-Video](https://github.com/Wan-Video/Wan2.1)
|
||||||
* [StepVideo](https://github.com/stepfun-ai/Step-Video-T2V)
|
* [StepVideo](https://github.com/stepfun-ai/Step-Video-T2V)
|
||||||
@@ -42,11 +36,7 @@ Until now, DiffSynth-Studio has supported the following models:
|
|||||||
* [Stable Diffusion](https://huggingface.co/runwayml/stable-diffusion-v1-5)
|
* [Stable Diffusion](https://huggingface.co/runwayml/stable-diffusion-v1-5)
|
||||||
|
|
||||||
## News
|
## News
|
||||||
- **March 31, 2025** We support InfiniteYou, an identity preserving method for FLUX. Please refer to [./examples/InfiniteYou/](./examples/InfiniteYou/) for more details.
|
- **March 25, 2025** We support HunyuanVideo-I2V, the image-to-video generation version of HunyuanVideo open-sourced by Tencent. Please refer to [./examples/HunyuanVideo/](./examples/HunyuanVideo/) for more details.
|
||||||
|
|
||||||
- **March 25, 2025** 🔥🔥🔥 Our new open-source project, [DiffSynth-Engine](https://github.com/modelscope/DiffSynth-Engine), is now open-sourced! Focused on stable model deployment. Geared towards industry. Offers better engineering support, higher computational performance, and more stable functionality.
|
|
||||||
|
|
||||||
- **March 13, 2025** We support HunyuanVideo-I2V, the image-to-video generation version of HunyuanVideo open-sourced by Tencent. Please refer to [./examples/HunyuanVideo/](./examples/HunyuanVideo/) for more details.
|
|
||||||
|
|
||||||
- **February 25, 2025** We support Wan-Video, a collection of SOTA video synthesis models open-sourced by Alibaba. See [./examples/wanvideo/](./examples/wanvideo/).
|
- **February 25, 2025** We support Wan-Video, a collection of SOTA video synthesis models open-sourced by Alibaba. See [./examples/wanvideo/](./examples/wanvideo/).
|
||||||
|
|
||||||
@@ -83,7 +73,7 @@ Until now, DiffSynth-Studio has supported the following models:
|
|||||||
- Enable CFG and highres-fix to improve visual quality. See [here](/examples/image_synthesis/README.md)
|
- Enable CFG and highres-fix to improve visual quality. See [here](/examples/image_synthesis/README.md)
|
||||||
- LoRA, ControlNet, and additional models will be available soon.
|
- LoRA, ControlNet, and additional models will be available soon.
|
||||||
|
|
||||||
- **June 21, 2024.** We propose ExVideo, a post-tuning technique aimed at enhancing the capability of video generation models. We have extended Stable Video Diffusion to achieve the generation of long videos up to 128 frames.
|
- **June 21, 2024.** 🔥🔥🔥 We propose ExVideo, a post-tuning technique aimed at enhancing the capability of video generation models. We have extended Stable Video Diffusion to achieve the generation of long videos up to 128 frames.
|
||||||
- [Project Page](https://ecnu-cilab.github.io/ExVideoProjectPage/)
|
- [Project Page](https://ecnu-cilab.github.io/ExVideoProjectPage/)
|
||||||
- Source code is released in this repo. See [`examples/ExVideo`](./examples/ExVideo/).
|
- Source code is released in this repo. See [`examples/ExVideo`](./examples/ExVideo/).
|
||||||
- Models are released on [HuggingFace](https://huggingface.co/ECNU-CILab/ExVideo-SVD-128f-v1) and [ModelScope](https://modelscope.cn/models/ECNU-CILab/ExVideo-SVD-128f-v1).
|
- Models are released on [HuggingFace](https://huggingface.co/ECNU-CILab/ExVideo-SVD-128f-v1) and [ModelScope](https://modelscope.cn/models/ECNU-CILab/ExVideo-SVD-128f-v1).
|
||||||
|
|||||||
@@ -37,7 +37,6 @@ from ..models.flux_text_encoder import FluxTextEncoder2
|
|||||||
from ..models.flux_vae import FluxVAEEncoder, FluxVAEDecoder
|
from ..models.flux_vae import FluxVAEEncoder, FluxVAEDecoder
|
||||||
from ..models.flux_controlnet import FluxControlNet
|
from ..models.flux_controlnet import FluxControlNet
|
||||||
from ..models.flux_ipadapter import FluxIpAdapter
|
from ..models.flux_ipadapter import FluxIpAdapter
|
||||||
from ..models.flux_infiniteyou import InfiniteYouImageProjector
|
|
||||||
|
|
||||||
from ..models.cog_vae import CogVAEEncoder, CogVAEDecoder
|
from ..models.cog_vae import CogVAEEncoder, CogVAEDecoder
|
||||||
from ..models.cog_dit import CogDiT
|
from ..models.cog_dit import CogDiT
|
||||||
@@ -59,7 +58,6 @@ from ..models.wan_video_dit import WanModel
|
|||||||
from ..models.wan_video_text_encoder import WanTextEncoder
|
from ..models.wan_video_text_encoder import WanTextEncoder
|
||||||
from ..models.wan_video_image_encoder import WanImageEncoder
|
from ..models.wan_video_image_encoder import WanImageEncoder
|
||||||
from ..models.wan_video_vae import WanVideoVAE
|
from ..models.wan_video_vae import WanVideoVAE
|
||||||
from ..models.wan_video_motion_controller import WanMotionControllerModel
|
|
||||||
|
|
||||||
|
|
||||||
model_loader_configs = [
|
model_loader_configs = [
|
||||||
@@ -106,8 +104,6 @@ model_loader_configs = [
|
|||||||
(None, "b001c89139b5f053c715fe772362dd2a", ["flux_controlnet"], [FluxControlNet], "diffusers"),
|
(None, "b001c89139b5f053c715fe772362dd2a", ["flux_controlnet"], [FluxControlNet], "diffusers"),
|
||||||
(None, "52357cb26250681367488a8954c271e8", ["flux_controlnet"], [FluxControlNet], "diffusers"),
|
(None, "52357cb26250681367488a8954c271e8", ["flux_controlnet"], [FluxControlNet], "diffusers"),
|
||||||
(None, "0cfd1740758423a2a854d67c136d1e8c", ["flux_controlnet"], [FluxControlNet], "diffusers"),
|
(None, "0cfd1740758423a2a854d67c136d1e8c", ["flux_controlnet"], [FluxControlNet], "diffusers"),
|
||||||
(None, "7f9583eb8ba86642abb9a21a4b2c9e16", ["flux_controlnet"], [FluxControlNet], "diffusers"),
|
|
||||||
(None, "c07c0f04f5ff55e86b4e937c7a40d481", ["infiniteyou_image_projector"], [InfiniteYouImageProjector], "diffusers"),
|
|
||||||
(None, "4daaa66cc656a8fe369908693dad0a35", ["flux_ipadapter"], [FluxIpAdapter], "diffusers"),
|
(None, "4daaa66cc656a8fe369908693dad0a35", ["flux_ipadapter"], [FluxIpAdapter], "diffusers"),
|
||||||
(None, "51aed3d27d482fceb5e0739b03060e8f", ["sd3_dit", "sd3_vae_encoder", "sd3_vae_decoder"], [SD3DiT, SD3VAEEncoder, SD3VAEDecoder], "civitai"),
|
(None, "51aed3d27d482fceb5e0739b03060e8f", ["sd3_dit", "sd3_vae_encoder", "sd3_vae_decoder"], [SD3DiT, SD3VAEEncoder, SD3VAEDecoder], "civitai"),
|
||||||
(None, "98cc34ccc5b54ae0e56bdea8688dcd5a", ["sd3_text_encoder_2"], [SD3TextEncoder2], "civitai"),
|
(None, "98cc34ccc5b54ae0e56bdea8688dcd5a", ["sd3_text_encoder_2"], [SD3TextEncoder2], "civitai"),
|
||||||
@@ -121,16 +117,11 @@ model_loader_configs = [
|
|||||||
(None, "9269f8db9040a9d860eaca435be61814", ["wan_video_dit"], [WanModel], "civitai"),
|
(None, "9269f8db9040a9d860eaca435be61814", ["wan_video_dit"], [WanModel], "civitai"),
|
||||||
(None, "aafcfd9672c3a2456dc46e1cb6e52c70", ["wan_video_dit"], [WanModel], "civitai"),
|
(None, "aafcfd9672c3a2456dc46e1cb6e52c70", ["wan_video_dit"], [WanModel], "civitai"),
|
||||||
(None, "6bfcfb3b342cb286ce886889d519a77e", ["wan_video_dit"], [WanModel], "civitai"),
|
(None, "6bfcfb3b342cb286ce886889d519a77e", ["wan_video_dit"], [WanModel], "civitai"),
|
||||||
(None, "6d6ccde6845b95ad9114ab993d917893", ["wan_video_dit"], [WanModel], "civitai"),
|
|
||||||
(None, "6bfcfb3b342cb286ce886889d519a77e", ["wan_video_dit"], [WanModel], "civitai"),
|
|
||||||
(None, "349723183fc063b2bfc10bb2835cf677", ["wan_video_dit"], [WanModel], "civitai"),
|
|
||||||
(None, "efa44cddf936c70abd0ea28b6cbe946c", ["wan_video_dit"], [WanModel], "civitai"),
|
|
||||||
(None, "cb104773c6c2cb6df4f9529ad5c60d0b", ["wan_video_dit"], [WanModel], "diffusers"),
|
(None, "cb104773c6c2cb6df4f9529ad5c60d0b", ["wan_video_dit"], [WanModel], "diffusers"),
|
||||||
(None, "9c8818c2cbea55eca56c7b447df170da", ["wan_video_text_encoder"], [WanTextEncoder], "civitai"),
|
(None, "9c8818c2cbea55eca56c7b447df170da", ["wan_video_text_encoder"], [WanTextEncoder], "civitai"),
|
||||||
(None, "5941c53e207d62f20f9025686193c40b", ["wan_video_image_encoder"], [WanImageEncoder], "civitai"),
|
(None, "5941c53e207d62f20f9025686193c40b", ["wan_video_image_encoder"], [WanImageEncoder], "civitai"),
|
||||||
(None, "1378ea763357eea97acdef78e65d6d96", ["wan_video_vae"], [WanVideoVAE], "civitai"),
|
(None, "1378ea763357eea97acdef78e65d6d96", ["wan_video_vae"], [WanVideoVAE], "civitai"),
|
||||||
(None, "ccc42284ea13e1ad04693284c7a09be6", ["wan_video_vae"], [WanVideoVAE], "civitai"),
|
(None, "ccc42284ea13e1ad04693284c7a09be6", ["wan_video_vae"], [WanVideoVAE], "civitai"),
|
||||||
(None, "dbd5ec76bbf977983f972c151d545389", ["wan_video_motion_controller"], [WanMotionControllerModel], "civitai"),
|
|
||||||
]
|
]
|
||||||
huggingface_model_loader_configs = [
|
huggingface_model_loader_configs = [
|
||||||
# These configs are provided for detecting model type automatically.
|
# These configs are provided for detecting model type automatically.
|
||||||
@@ -607,25 +598,6 @@ preset_models_on_modelscope = {
|
|||||||
"models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/image_encoder",
|
"models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/image_encoder",
|
||||||
],
|
],
|
||||||
},
|
},
|
||||||
"InfiniteYou":{
|
|
||||||
"file_list":[
|
|
||||||
("ByteDance/InfiniteYou", "infu_flux_v1.0/aes_stage2/InfuseNetModel/diffusion_pytorch_model-00001-of-00002.safetensors", "models/InfiniteYou/InfuseNetModel"),
|
|
||||||
("ByteDance/InfiniteYou", "infu_flux_v1.0/aes_stage2/InfuseNetModel/diffusion_pytorch_model-00002-of-00002.safetensors", "models/InfiniteYou/InfuseNetModel"),
|
|
||||||
("ByteDance/InfiniteYou", "infu_flux_v1.0/aes_stage2/image_proj_model.bin", "models/InfiniteYou"),
|
|
||||||
("ByteDance/InfiniteYou", "supports/insightface/models/antelopev2/1k3d68.onnx", "models/InfiniteYou/insightface/models/antelopev2"),
|
|
||||||
("ByteDance/InfiniteYou", "supports/insightface/models/antelopev2/2d106det.onnx", "models/InfiniteYou/insightface/models/antelopev2"),
|
|
||||||
("ByteDance/InfiniteYou", "supports/insightface/models/antelopev2/genderage.onnx", "models/InfiniteYou/insightface/models/antelopev2"),
|
|
||||||
("ByteDance/InfiniteYou", "supports/insightface/models/antelopev2/glintr100.onnx", "models/InfiniteYou/insightface/models/antelopev2"),
|
|
||||||
("ByteDance/InfiniteYou", "supports/insightface/models/antelopev2/scrfd_10g_bnkps.onnx", "models/InfiniteYou/insightface/models/antelopev2"),
|
|
||||||
],
|
|
||||||
"load_path":[
|
|
||||||
[
|
|
||||||
"models/InfiniteYou/InfuseNetModel/diffusion_pytorch_model-00001-of-00002.safetensors",
|
|
||||||
"models/InfiniteYou/InfuseNetModel/diffusion_pytorch_model-00002-of-00002.safetensors"
|
|
||||||
],
|
|
||||||
"models/InfiniteYou/image_proj_model.bin",
|
|
||||||
],
|
|
||||||
},
|
|
||||||
# ESRGAN
|
# ESRGAN
|
||||||
"ESRGAN_x4": [
|
"ESRGAN_x4": [
|
||||||
("AI-ModelScope/Real-ESRGAN", "RealESRGAN_x4.pth", "models/ESRGAN"),
|
("AI-ModelScope/Real-ESRGAN", "RealESRGAN_x4.pth", "models/ESRGAN"),
|
||||||
@@ -785,7 +757,6 @@ Preset_model_id: TypeAlias = Literal[
|
|||||||
"Shakker-Labs/FLUX.1-dev-ControlNet-Depth",
|
"Shakker-Labs/FLUX.1-dev-ControlNet-Depth",
|
||||||
"Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro",
|
"Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro",
|
||||||
"InstantX/FLUX.1-dev-IP-Adapter",
|
"InstantX/FLUX.1-dev-IP-Adapter",
|
||||||
"InfiniteYou",
|
|
||||||
"SDXL_lora_zyd232_ChineseInkStyle_SDXL_v1_0",
|
"SDXL_lora_zyd232_ChineseInkStyle_SDXL_v1_0",
|
||||||
"QwenPrompt",
|
"QwenPrompt",
|
||||||
"OmostPrompt",
|
"OmostPrompt",
|
||||||
|
|||||||
@@ -1,129 +0,0 @@
|
|||||||
import torch
|
|
||||||
from typing import Optional
|
|
||||||
from einops import rearrange
|
|
||||||
from xfuser.core.distributed import (get_sequence_parallel_rank,
|
|
||||||
get_sequence_parallel_world_size,
|
|
||||||
get_sp_group)
|
|
||||||
from xfuser.core.long_ctx_attention import xFuserLongContextAttention
|
|
||||||
|
|
||||||
def sinusoidal_embedding_1d(dim, position):
|
|
||||||
sinusoid = torch.outer(position.type(torch.float64), torch.pow(
|
|
||||||
10000, -torch.arange(dim//2, dtype=torch.float64, device=position.device).div(dim//2)))
|
|
||||||
x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1)
|
|
||||||
return x.to(position.dtype)
|
|
||||||
|
|
||||||
def pad_freqs(original_tensor, target_len):
|
|
||||||
seq_len, s1, s2 = original_tensor.shape
|
|
||||||
pad_size = target_len - seq_len
|
|
||||||
padding_tensor = torch.ones(
|
|
||||||
pad_size,
|
|
||||||
s1,
|
|
||||||
s2,
|
|
||||||
dtype=original_tensor.dtype,
|
|
||||||
device=original_tensor.device)
|
|
||||||
padded_tensor = torch.cat([original_tensor, padding_tensor], dim=0)
|
|
||||||
return padded_tensor
|
|
||||||
|
|
||||||
def rope_apply(x, freqs, num_heads):
|
|
||||||
x = rearrange(x, "b s (n d) -> b s n d", n=num_heads)
|
|
||||||
s_per_rank = x.shape[1]
|
|
||||||
|
|
||||||
x_out = torch.view_as_complex(x.to(torch.float64).reshape(
|
|
||||||
x.shape[0], x.shape[1], x.shape[2], -1, 2))
|
|
||||||
|
|
||||||
sp_size = get_sequence_parallel_world_size()
|
|
||||||
sp_rank = get_sequence_parallel_rank()
|
|
||||||
freqs = pad_freqs(freqs, s_per_rank * sp_size)
|
|
||||||
freqs_rank = freqs[(sp_rank * s_per_rank):((sp_rank + 1) * s_per_rank), :, :]
|
|
||||||
|
|
||||||
x_out = torch.view_as_real(x_out * freqs_rank).flatten(2)
|
|
||||||
return x_out.to(x.dtype)
|
|
||||||
|
|
||||||
def usp_dit_forward(self,
|
|
||||||
x: torch.Tensor,
|
|
||||||
timestep: torch.Tensor,
|
|
||||||
context: torch.Tensor,
|
|
||||||
clip_feature: Optional[torch.Tensor] = None,
|
|
||||||
y: Optional[torch.Tensor] = None,
|
|
||||||
use_gradient_checkpointing: bool = False,
|
|
||||||
use_gradient_checkpointing_offload: bool = False,
|
|
||||||
**kwargs,
|
|
||||||
):
|
|
||||||
t = self.time_embedding(
|
|
||||||
sinusoidal_embedding_1d(self.freq_dim, timestep))
|
|
||||||
t_mod = self.time_projection(t).unflatten(1, (6, self.dim))
|
|
||||||
context = self.text_embedding(context)
|
|
||||||
|
|
||||||
if self.has_image_input:
|
|
||||||
x = torch.cat([x, y], dim=1) # (b, c_x + c_y, f, h, w)
|
|
||||||
clip_embdding = self.img_emb(clip_feature)
|
|
||||||
context = torch.cat([clip_embdding, context], dim=1)
|
|
||||||
|
|
||||||
x, (f, h, w) = self.patchify(x)
|
|
||||||
|
|
||||||
freqs = torch.cat([
|
|
||||||
self.freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
|
|
||||||
self.freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
|
|
||||||
self.freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
|
|
||||||
], dim=-1).reshape(f * h * w, 1, -1).to(x.device)
|
|
||||||
|
|
||||||
def create_custom_forward(module):
|
|
||||||
def custom_forward(*inputs):
|
|
||||||
return module(*inputs)
|
|
||||||
return custom_forward
|
|
||||||
|
|
||||||
# Context Parallel
|
|
||||||
x = torch.chunk(
|
|
||||||
x, get_sequence_parallel_world_size(),
|
|
||||||
dim=1)[get_sequence_parallel_rank()]
|
|
||||||
|
|
||||||
for block in self.blocks:
|
|
||||||
if self.training and use_gradient_checkpointing:
|
|
||||||
if use_gradient_checkpointing_offload:
|
|
||||||
with torch.autograd.graph.save_on_cpu():
|
|
||||||
x = torch.utils.checkpoint.checkpoint(
|
|
||||||
create_custom_forward(block),
|
|
||||||
x, context, t_mod, freqs,
|
|
||||||
use_reentrant=False,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
x = torch.utils.checkpoint.checkpoint(
|
|
||||||
create_custom_forward(block),
|
|
||||||
x, context, t_mod, freqs,
|
|
||||||
use_reentrant=False,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
x = block(x, context, t_mod, freqs)
|
|
||||||
|
|
||||||
x = self.head(x, t)
|
|
||||||
|
|
||||||
# Context Parallel
|
|
||||||
x = get_sp_group().all_gather(x, dim=1)
|
|
||||||
|
|
||||||
# unpatchify
|
|
||||||
x = self.unpatchify(x, (f, h, w))
|
|
||||||
return x
|
|
||||||
|
|
||||||
|
|
||||||
def usp_attn_forward(self, x, freqs):
|
|
||||||
q = self.norm_q(self.q(x))
|
|
||||||
k = self.norm_k(self.k(x))
|
|
||||||
v = self.v(x)
|
|
||||||
|
|
||||||
q = rope_apply(q, freqs, self.num_heads)
|
|
||||||
k = rope_apply(k, freqs, self.num_heads)
|
|
||||||
q = rearrange(q, "b s (n d) -> b s n d", n=self.num_heads)
|
|
||||||
k = rearrange(k, "b s (n d) -> b s n d", n=self.num_heads)
|
|
||||||
v = rearrange(v, "b s (n d) -> b s n d", n=self.num_heads)
|
|
||||||
|
|
||||||
x = xFuserLongContextAttention()(
|
|
||||||
None,
|
|
||||||
query=q,
|
|
||||||
key=k,
|
|
||||||
value=v,
|
|
||||||
)
|
|
||||||
x = x.flatten(2)
|
|
||||||
|
|
||||||
del q, k, v
|
|
||||||
torch.cuda.empty_cache()
|
|
||||||
return self.o(x)
|
|
||||||
@@ -5,7 +5,7 @@ import pathlib
|
|||||||
import re
|
import re
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
# from turtle import forward
|
from turtle import forward
|
||||||
from typing import Any, Dict, Optional, Tuple, Union
|
from typing import Any, Dict, Optional, Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|||||||
@@ -318,8 +318,6 @@ class FluxControlNetStateDictConverter:
|
|||||||
extra_kwargs = {"num_joint_blocks": 6, "num_single_blocks": 0, "additional_input_dim": 4}
|
extra_kwargs = {"num_joint_blocks": 6, "num_single_blocks": 0, "additional_input_dim": 4}
|
||||||
elif hash_value == "0cfd1740758423a2a854d67c136d1e8c":
|
elif hash_value == "0cfd1740758423a2a854d67c136d1e8c":
|
||||||
extra_kwargs = {"num_joint_blocks": 4, "num_single_blocks": 1}
|
extra_kwargs = {"num_joint_blocks": 4, "num_single_blocks": 1}
|
||||||
elif hash_value == "7f9583eb8ba86642abb9a21a4b2c9e16":
|
|
||||||
extra_kwargs = {"num_joint_blocks": 4, "num_single_blocks": 10}
|
|
||||||
else:
|
else:
|
||||||
extra_kwargs = {}
|
extra_kwargs = {}
|
||||||
return state_dict_, extra_kwargs
|
return state_dict_, extra_kwargs
|
||||||
|
|||||||
@@ -1,128 +0,0 @@
|
|||||||
import math
|
|
||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
|
|
||||||
|
|
||||||
# FFN
|
|
||||||
def FeedForward(dim, mult=4):
|
|
||||||
inner_dim = int(dim * mult)
|
|
||||||
return nn.Sequential(
|
|
||||||
nn.LayerNorm(dim),
|
|
||||||
nn.Linear(dim, inner_dim, bias=False),
|
|
||||||
nn.GELU(),
|
|
||||||
nn.Linear(inner_dim, dim, bias=False),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def reshape_tensor(x, heads):
|
|
||||||
bs, length, width = x.shape
|
|
||||||
#(bs, length, width) --> (bs, length, n_heads, dim_per_head)
|
|
||||||
x = x.view(bs, length, heads, -1)
|
|
||||||
# (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head)
|
|
||||||
x = x.transpose(1, 2)
|
|
||||||
# (bs, n_heads, length, dim_per_head) --> (bs*n_heads, length, dim_per_head)
|
|
||||||
x = x.reshape(bs, heads, length, -1)
|
|
||||||
return x
|
|
||||||
|
|
||||||
|
|
||||||
class PerceiverAttention(nn.Module):
|
|
||||||
|
|
||||||
def __init__(self, *, dim, dim_head=64, heads=8):
|
|
||||||
super().__init__()
|
|
||||||
self.scale = dim_head**-0.5
|
|
||||||
self.dim_head = dim_head
|
|
||||||
self.heads = heads
|
|
||||||
inner_dim = dim_head * heads
|
|
||||||
|
|
||||||
self.norm1 = nn.LayerNorm(dim)
|
|
||||||
self.norm2 = nn.LayerNorm(dim)
|
|
||||||
|
|
||||||
self.to_q = nn.Linear(dim, inner_dim, bias=False)
|
|
||||||
self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
|
|
||||||
self.to_out = nn.Linear(inner_dim, dim, bias=False)
|
|
||||||
|
|
||||||
def forward(self, x, latents):
|
|
||||||
"""
|
|
||||||
Args:
|
|
||||||
x (torch.Tensor): image features
|
|
||||||
shape (b, n1, D)
|
|
||||||
latent (torch.Tensor): latent features
|
|
||||||
shape (b, n2, D)
|
|
||||||
"""
|
|
||||||
x = self.norm1(x)
|
|
||||||
latents = self.norm2(latents)
|
|
||||||
|
|
||||||
b, l, _ = latents.shape
|
|
||||||
|
|
||||||
q = self.to_q(latents)
|
|
||||||
kv_input = torch.cat((x, latents), dim=-2)
|
|
||||||
k, v = self.to_kv(kv_input).chunk(2, dim=-1)
|
|
||||||
|
|
||||||
q = reshape_tensor(q, self.heads)
|
|
||||||
k = reshape_tensor(k, self.heads)
|
|
||||||
v = reshape_tensor(v, self.heads)
|
|
||||||
|
|
||||||
# attention
|
|
||||||
scale = 1 / math.sqrt(math.sqrt(self.dim_head))
|
|
||||||
weight = (q * scale) @ (k * scale).transpose(-2, -1) # More stable with f16 than dividing afterwards
|
|
||||||
weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
|
|
||||||
out = weight @ v
|
|
||||||
|
|
||||||
out = out.permute(0, 2, 1, 3).reshape(b, l, -1)
|
|
||||||
|
|
||||||
return self.to_out(out)
|
|
||||||
|
|
||||||
|
|
||||||
class InfiniteYouImageProjector(nn.Module):
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
dim=1280,
|
|
||||||
depth=4,
|
|
||||||
dim_head=64,
|
|
||||||
heads=20,
|
|
||||||
num_queries=8,
|
|
||||||
embedding_dim=512,
|
|
||||||
output_dim=4096,
|
|
||||||
ff_mult=4,
|
|
||||||
):
|
|
||||||
super().__init__()
|
|
||||||
self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim**0.5)
|
|
||||||
self.proj_in = nn.Linear(embedding_dim, dim)
|
|
||||||
|
|
||||||
self.proj_out = nn.Linear(dim, output_dim)
|
|
||||||
self.norm_out = nn.LayerNorm(output_dim)
|
|
||||||
|
|
||||||
self.layers = nn.ModuleList([])
|
|
||||||
for _ in range(depth):
|
|
||||||
self.layers.append(
|
|
||||||
nn.ModuleList([
|
|
||||||
PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
|
|
||||||
FeedForward(dim=dim, mult=ff_mult),
|
|
||||||
]))
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
|
|
||||||
latents = self.latents.repeat(x.size(0), 1, 1)
|
|
||||||
|
|
||||||
x = self.proj_in(x)
|
|
||||||
|
|
||||||
for attn, ff in self.layers:
|
|
||||||
latents = attn(x, latents) + latents
|
|
||||||
latents = ff(latents) + latents
|
|
||||||
|
|
||||||
latents = self.proj_out(latents)
|
|
||||||
return self.norm_out(latents)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def state_dict_converter():
|
|
||||||
return FluxInfiniteYouImageProjectorStateDictConverter()
|
|
||||||
|
|
||||||
|
|
||||||
class FluxInfiniteYouImageProjectorStateDictConverter:
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
pass
|
|
||||||
|
|
||||||
def from_diffusers(self, state_dict):
|
|
||||||
return state_dict['image_proj']
|
|
||||||
@@ -367,20 +367,5 @@ class FluxLoRAConverter:
|
|||||||
return state_dict_
|
return state_dict_
|
||||||
|
|
||||||
|
|
||||||
class WanLoRAConverter:
|
|
||||||
def __init__(self):
|
|
||||||
pass
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def align_to_opensource_format(state_dict, **kwargs):
|
|
||||||
state_dict = {"diffusion_model." + name.replace(".default.", "."): param for name, param in state_dict.items()}
|
|
||||||
return state_dict
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def align_to_diffsynth_format(state_dict, **kwargs):
|
|
||||||
state_dict = {name.replace("diffusion_model.", "").replace(".lora_A.weight", ".lora_A.default.weight").replace(".lora_B.weight", ".lora_B.default.weight"): param for name, param in state_dict.items()}
|
|
||||||
return state_dict
|
|
||||||
|
|
||||||
|
|
||||||
def get_lora_loaders():
|
def get_lora_loaders():
|
||||||
return [SDLoRAFromCivitai(), SDXLLoRAFromCivitai(), FluxLoRAFromCivitai(), HunyuanVideoLoRAFromCivitai(), GeneralLoRAFromPeft()]
|
return [SDLoRAFromCivitai(), SDXLLoRAFromCivitai(), FluxLoRAFromCivitai(), HunyuanVideoLoRAFromCivitai(), GeneralLoRAFromPeft()]
|
||||||
|
|||||||
204
diffsynth/models/wan_video_controlnet.py
Normal file
204
diffsynth/models/wan_video_controlnet.py
Normal file
@@ -0,0 +1,204 @@
|
|||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from typing import Tuple, Optional
|
||||||
|
from einops import rearrange
|
||||||
|
from .wan_video_dit import DiTBlock, precompute_freqs_cis_3d, MLP, sinusoidal_embedding_1d
|
||||||
|
from .utils import hash_state_dict_keys
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class WanControlNetModel(torch.nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
dim: int,
|
||||||
|
in_dim: int,
|
||||||
|
ffn_dim: int,
|
||||||
|
out_dim: int,
|
||||||
|
text_dim: int,
|
||||||
|
freq_dim: int,
|
||||||
|
eps: float,
|
||||||
|
patch_size: Tuple[int, int, int],
|
||||||
|
num_heads: int,
|
||||||
|
num_layers: int,
|
||||||
|
has_image_input: bool,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.dim = dim
|
||||||
|
self.freq_dim = freq_dim
|
||||||
|
self.has_image_input = has_image_input
|
||||||
|
self.patch_size = patch_size
|
||||||
|
|
||||||
|
self.patch_embedding = nn.Conv3d(
|
||||||
|
in_dim, dim, kernel_size=patch_size, stride=patch_size)
|
||||||
|
self.text_embedding = nn.Sequential(
|
||||||
|
nn.Linear(text_dim, dim),
|
||||||
|
nn.GELU(approximate='tanh'),
|
||||||
|
nn.Linear(dim, dim)
|
||||||
|
)
|
||||||
|
self.time_embedding = nn.Sequential(
|
||||||
|
nn.Linear(freq_dim, dim),
|
||||||
|
nn.SiLU(),
|
||||||
|
nn.Linear(dim, dim)
|
||||||
|
)
|
||||||
|
self.time_projection = nn.Sequential(
|
||||||
|
nn.SiLU(), nn.Linear(dim, dim * 6))
|
||||||
|
self.blocks = nn.ModuleList([
|
||||||
|
DiTBlock(has_image_input, dim, num_heads, ffn_dim, eps)
|
||||||
|
for _ in range(num_layers)
|
||||||
|
])
|
||||||
|
head_dim = dim // num_heads
|
||||||
|
self.freqs = precompute_freqs_cis_3d(head_dim)
|
||||||
|
|
||||||
|
if has_image_input:
|
||||||
|
self.img_emb = MLP(1280, dim) # clip_feature_dim = 1280
|
||||||
|
|
||||||
|
self.controlnet_conv_in = torch.nn.Conv3d(in_channels=in_dim, out_channels=in_dim, kernel_size=1)
|
||||||
|
self.controlnet_blocks = torch.nn.ModuleList([
|
||||||
|
torch.nn.Linear(dim, dim, bias=False)
|
||||||
|
for _ in range(num_layers)
|
||||||
|
])
|
||||||
|
|
||||||
|
def patchify(self, x: torch.Tensor):
|
||||||
|
x = self.patch_embedding(x)
|
||||||
|
grid_size = x.shape[2:]
|
||||||
|
x = rearrange(x, 'b c f h w -> b (f h w) c').contiguous()
|
||||||
|
return x, grid_size # x, grid_size: (f, h, w)
|
||||||
|
|
||||||
|
def unpatchify(self, x: torch.Tensor, grid_size: torch.Tensor):
|
||||||
|
return rearrange(
|
||||||
|
x, 'b (f h w) (x y z c) -> b c (f x) (h y) (w z)',
|
||||||
|
f=grid_size[0], h=grid_size[1], w=grid_size[2],
|
||||||
|
x=self.patch_size[0], y=self.patch_size[1], z=self.patch_size[2]
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self,
|
||||||
|
x: torch.Tensor,
|
||||||
|
timestep: torch.Tensor,
|
||||||
|
context: torch.Tensor,
|
||||||
|
clip_feature: Optional[torch.Tensor] = None,
|
||||||
|
y: Optional[torch.Tensor] = None,
|
||||||
|
controlnet_conditioning: Optional[torch.Tensor] = None,
|
||||||
|
use_gradient_checkpointing: bool = False,
|
||||||
|
use_gradient_checkpointing_offload: bool = False,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
t = self.time_embedding(
|
||||||
|
sinusoidal_embedding_1d(self.freq_dim, timestep))
|
||||||
|
t_mod = self.time_projection(t).unflatten(1, (6, self.dim))
|
||||||
|
context = self.text_embedding(context)
|
||||||
|
|
||||||
|
if self.has_image_input:
|
||||||
|
x = torch.cat([x, y], dim=1) # (b, c_x + c_y, f, h, w)
|
||||||
|
clip_embdding = self.img_emb(clip_feature)
|
||||||
|
context = torch.cat([clip_embdding, context], dim=1)
|
||||||
|
|
||||||
|
x = x + self.controlnet_conv_in(controlnet_conditioning)
|
||||||
|
x, (f, h, w) = self.patchify(x)
|
||||||
|
|
||||||
|
freqs = torch.cat([
|
||||||
|
self.freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
|
||||||
|
self.freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
|
||||||
|
self.freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
|
||||||
|
], dim=-1).reshape(f * h * w, 1, -1).to(x.device)
|
||||||
|
|
||||||
|
def create_custom_forward(module):
|
||||||
|
def custom_forward(*inputs):
|
||||||
|
return module(*inputs)
|
||||||
|
return custom_forward
|
||||||
|
|
||||||
|
res_stack = []
|
||||||
|
for block in self.blocks:
|
||||||
|
if self.training and use_gradient_checkpointing:
|
||||||
|
if use_gradient_checkpointing_offload:
|
||||||
|
with torch.autograd.graph.save_on_cpu():
|
||||||
|
x = torch.utils.checkpoint.checkpoint(
|
||||||
|
create_custom_forward(block),
|
||||||
|
x, context, t_mod, freqs,
|
||||||
|
use_reentrant=False,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
x = torch.utils.checkpoint.checkpoint(
|
||||||
|
create_custom_forward(block),
|
||||||
|
x, context, t_mod, freqs,
|
||||||
|
use_reentrant=False,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
x = block(x, context, t_mod, freqs)
|
||||||
|
res_stack.append(x)
|
||||||
|
|
||||||
|
controlnet_res_stack = [block(res) for block, res in zip(self.controlnet_blocks, res_stack)]
|
||||||
|
return controlnet_res_stack
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def state_dict_converter():
|
||||||
|
return WanControlNetModelStateDictConverter()
|
||||||
|
|
||||||
|
|
||||||
|
class WanControlNetModelStateDictConverter:
|
||||||
|
def __init__(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def from_diffusers(self, state_dict):
|
||||||
|
return state_dict
|
||||||
|
|
||||||
|
def from_civitai(self, state_dict):
|
||||||
|
return state_dict
|
||||||
|
|
||||||
|
def from_base_model(self, state_dict):
|
||||||
|
if hash_state_dict_keys(state_dict) == "9269f8db9040a9d860eaca435be61814":
|
||||||
|
config = {
|
||||||
|
"has_image_input": False,
|
||||||
|
"patch_size": [1, 2, 2],
|
||||||
|
"in_dim": 16,
|
||||||
|
"dim": 1536,
|
||||||
|
"ffn_dim": 8960,
|
||||||
|
"freq_dim": 256,
|
||||||
|
"text_dim": 4096,
|
||||||
|
"out_dim": 16,
|
||||||
|
"num_heads": 12,
|
||||||
|
"num_layers": 30,
|
||||||
|
"eps": 1e-6
|
||||||
|
}
|
||||||
|
elif hash_state_dict_keys(state_dict) == "aafcfd9672c3a2456dc46e1cb6e52c70":
|
||||||
|
config = {
|
||||||
|
"has_image_input": False,
|
||||||
|
"patch_size": [1, 2, 2],
|
||||||
|
"in_dim": 16,
|
||||||
|
"dim": 5120,
|
||||||
|
"ffn_dim": 13824,
|
||||||
|
"freq_dim": 256,
|
||||||
|
"text_dim": 4096,
|
||||||
|
"out_dim": 16,
|
||||||
|
"num_heads": 40,
|
||||||
|
"num_layers": 40,
|
||||||
|
"eps": 1e-6
|
||||||
|
}
|
||||||
|
elif hash_state_dict_keys(state_dict) == "6bfcfb3b342cb286ce886889d519a77e":
|
||||||
|
config = {
|
||||||
|
"has_image_input": True,
|
||||||
|
"patch_size": [1, 2, 2],
|
||||||
|
"in_dim": 36,
|
||||||
|
"dim": 5120,
|
||||||
|
"ffn_dim": 13824,
|
||||||
|
"freq_dim": 256,
|
||||||
|
"text_dim": 4096,
|
||||||
|
"out_dim": 16,
|
||||||
|
"num_heads": 40,
|
||||||
|
"num_layers": 40,
|
||||||
|
"eps": 1e-6
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
config = {}
|
||||||
|
state_dict_ = {}
|
||||||
|
dtype, device = None, None
|
||||||
|
for name, param in state_dict.items():
|
||||||
|
if name.startswith("head."):
|
||||||
|
continue
|
||||||
|
state_dict_[name] = param
|
||||||
|
dtype, device = param.dtype, param.device
|
||||||
|
for block_id in range(config["num_layers"]):
|
||||||
|
zeros = torch.zeros((config["dim"], config["dim"]), dtype=dtype, device=device)
|
||||||
|
state_dict_[f"controlnet_blocks.{block_id}.weight"] = zeros.clone()
|
||||||
|
state_dict_["controlnet_conv_in.weight"] = torch.zeros((config["in_dim"], config["in_dim"], 1, 1, 1), dtype=dtype, device=device)
|
||||||
|
state_dict_["controlnet_conv_in.bias"] = torch.zeros((config["in_dim"],), dtype=dtype, device=device)
|
||||||
|
return state_dict_, config
|
||||||
@@ -183,13 +183,6 @@ class CrossAttention(nn.Module):
|
|||||||
return self.o(x)
|
return self.o(x)
|
||||||
|
|
||||||
|
|
||||||
class GateModule(nn.Module):
|
|
||||||
def __init__(self,):
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
def forward(self, x, gate, residual):
|
|
||||||
return x + gate * residual
|
|
||||||
|
|
||||||
class DiTBlock(nn.Module):
|
class DiTBlock(nn.Module):
|
||||||
def __init__(self, has_image_input: bool, dim: int, num_heads: int, ffn_dim: int, eps: float = 1e-6):
|
def __init__(self, has_image_input: bool, dim: int, num_heads: int, ffn_dim: int, eps: float = 1e-6):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@@ -206,17 +199,16 @@ class DiTBlock(nn.Module):
|
|||||||
self.ffn = nn.Sequential(nn.Linear(dim, ffn_dim), nn.GELU(
|
self.ffn = nn.Sequential(nn.Linear(dim, ffn_dim), nn.GELU(
|
||||||
approximate='tanh'), nn.Linear(ffn_dim, dim))
|
approximate='tanh'), nn.Linear(ffn_dim, dim))
|
||||||
self.modulation = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5)
|
self.modulation = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5)
|
||||||
self.gate = GateModule()
|
|
||||||
|
|
||||||
def forward(self, x, context, t_mod, freqs):
|
def forward(self, x, context, t_mod, freqs):
|
||||||
# msa: multi-head self-attention mlp: multi-layer perceptron
|
# msa: multi-head self-attention mlp: multi-layer perceptron
|
||||||
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
|
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
|
||||||
self.modulation.to(dtype=t_mod.dtype, device=t_mod.device) + t_mod).chunk(6, dim=1)
|
self.modulation.to(dtype=t_mod.dtype, device=t_mod.device) + t_mod).chunk(6, dim=1)
|
||||||
input_x = modulate(self.norm1(x), shift_msa, scale_msa)
|
input_x = modulate(self.norm1(x), shift_msa, scale_msa)
|
||||||
x = self.gate(x, gate_msa, self.self_attn(input_x, freqs))
|
x = x + gate_msa * self.self_attn(input_x, freqs)
|
||||||
x = x + self.cross_attn(self.norm3(x), context)
|
x = x + self.cross_attn(self.norm3(x), context)
|
||||||
input_x = modulate(self.norm2(x), shift_mlp, scale_mlp)
|
input_x = modulate(self.norm2(x), shift_mlp, scale_mlp)
|
||||||
x = self.gate(x, gate_mlp, self.ffn(input_x))
|
x = x + gate_mlp * self.ffn(input_x)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
@@ -493,62 +485,6 @@ class WanModelStateDictConverter:
|
|||||||
"num_layers": 40,
|
"num_layers": 40,
|
||||||
"eps": 1e-6
|
"eps": 1e-6
|
||||||
}
|
}
|
||||||
elif hash_state_dict_keys(state_dict) == "6d6ccde6845b95ad9114ab993d917893":
|
|
||||||
config = {
|
|
||||||
"has_image_input": True,
|
|
||||||
"patch_size": [1, 2, 2],
|
|
||||||
"in_dim": 36,
|
|
||||||
"dim": 1536,
|
|
||||||
"ffn_dim": 8960,
|
|
||||||
"freq_dim": 256,
|
|
||||||
"text_dim": 4096,
|
|
||||||
"out_dim": 16,
|
|
||||||
"num_heads": 12,
|
|
||||||
"num_layers": 30,
|
|
||||||
"eps": 1e-6
|
|
||||||
}
|
|
||||||
elif hash_state_dict_keys(state_dict) == "6bfcfb3b342cb286ce886889d519a77e":
|
|
||||||
config = {
|
|
||||||
"has_image_input": True,
|
|
||||||
"patch_size": [1, 2, 2],
|
|
||||||
"in_dim": 36,
|
|
||||||
"dim": 5120,
|
|
||||||
"ffn_dim": 13824,
|
|
||||||
"freq_dim": 256,
|
|
||||||
"text_dim": 4096,
|
|
||||||
"out_dim": 16,
|
|
||||||
"num_heads": 40,
|
|
||||||
"num_layers": 40,
|
|
||||||
"eps": 1e-6
|
|
||||||
}
|
|
||||||
elif hash_state_dict_keys(state_dict) == "349723183fc063b2bfc10bb2835cf677":
|
|
||||||
config = {
|
|
||||||
"has_image_input": True,
|
|
||||||
"patch_size": [1, 2, 2],
|
|
||||||
"in_dim": 48,
|
|
||||||
"dim": 1536,
|
|
||||||
"ffn_dim": 8960,
|
|
||||||
"freq_dim": 256,
|
|
||||||
"text_dim": 4096,
|
|
||||||
"out_dim": 16,
|
|
||||||
"num_heads": 12,
|
|
||||||
"num_layers": 30,
|
|
||||||
"eps": 1e-6
|
|
||||||
}
|
|
||||||
elif hash_state_dict_keys(state_dict) == "efa44cddf936c70abd0ea28b6cbe946c":
|
|
||||||
config = {
|
|
||||||
"has_image_input": True,
|
|
||||||
"patch_size": [1, 2, 2],
|
|
||||||
"in_dim": 48,
|
|
||||||
"dim": 5120,
|
|
||||||
"ffn_dim": 13824,
|
|
||||||
"freq_dim": 256,
|
|
||||||
"text_dim": 4096,
|
|
||||||
"out_dim": 16,
|
|
||||||
"num_heads": 40,
|
|
||||||
"num_layers": 40,
|
|
||||||
"eps": 1e-6
|
|
||||||
}
|
|
||||||
else:
|
else:
|
||||||
config = {}
|
config = {}
|
||||||
return state_dict, config
|
return state_dict, config
|
||||||
|
|||||||
@@ -25,20 +25,3 @@ class WanMotionControllerModel(torch.nn.Module):
|
|||||||
state_dict = self.linear[-1].state_dict()
|
state_dict = self.linear[-1].state_dict()
|
||||||
state_dict = {i: state_dict[i] * 0 for i in state_dict}
|
state_dict = {i: state_dict[i] * 0 for i in state_dict}
|
||||||
self.linear[-1].load_state_dict(state_dict)
|
self.linear[-1].load_state_dict(state_dict)
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def state_dict_converter():
|
|
||||||
return WanMotionControllerModelDictConverter()
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class WanMotionControllerModelDictConverter:
|
|
||||||
def __init__(self):
|
|
||||||
pass
|
|
||||||
|
|
||||||
def from_diffusers(self, state_dict):
|
|
||||||
return state_dict
|
|
||||||
|
|
||||||
def from_civitai(self, state_dict):
|
|
||||||
return state_dict
|
|
||||||
|
|
||||||
|
|||||||
@@ -31,7 +31,6 @@ class FluxImagePipeline(BasePipeline):
|
|||||||
self.controlnet: FluxMultiControlNetManager = None
|
self.controlnet: FluxMultiControlNetManager = None
|
||||||
self.ipadapter: FluxIpAdapter = None
|
self.ipadapter: FluxIpAdapter = None
|
||||||
self.ipadapter_image_encoder: SiglipVisionModel = None
|
self.ipadapter_image_encoder: SiglipVisionModel = None
|
||||||
self.infinityou_processor: InfinitYou = None
|
|
||||||
self.model_names = ['text_encoder_1', 'text_encoder_2', 'dit', 'vae_decoder', 'vae_encoder', 'controlnet', 'ipadapter', 'ipadapter_image_encoder']
|
self.model_names = ['text_encoder_1', 'text_encoder_2', 'dit', 'vae_decoder', 'vae_encoder', 'controlnet', 'ipadapter', 'ipadapter_image_encoder']
|
||||||
|
|
||||||
|
|
||||||
@@ -163,11 +162,6 @@ class FluxImagePipeline(BasePipeline):
|
|||||||
self.ipadapter = model_manager.fetch_model("flux_ipadapter")
|
self.ipadapter = model_manager.fetch_model("flux_ipadapter")
|
||||||
self.ipadapter_image_encoder = model_manager.fetch_model("siglip_vision_model")
|
self.ipadapter_image_encoder = model_manager.fetch_model("siglip_vision_model")
|
||||||
|
|
||||||
# InfiniteYou
|
|
||||||
self.image_proj_model = model_manager.fetch_model("infiniteyou_image_projector")
|
|
||||||
if self.image_proj_model is not None:
|
|
||||||
self.infinityou_processor = InfinitYou(device=self.device)
|
|
||||||
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def from_model_manager(model_manager: ModelManager, controlnet_config_units: List[ControlNetConfigUnit]=[], prompt_refiner_classes=[], prompt_extender_classes=[], device=None, torch_dtype=None):
|
def from_model_manager(model_manager: ModelManager, controlnet_config_units: List[ControlNetConfigUnit]=[], prompt_refiner_classes=[], prompt_extender_classes=[], device=None, torch_dtype=None):
|
||||||
@@ -355,13 +349,6 @@ class FluxImagePipeline(BasePipeline):
|
|||||||
return prompt_emb_posi, prompt_emb_nega, prompt_emb_locals
|
return prompt_emb_posi, prompt_emb_nega, prompt_emb_locals
|
||||||
|
|
||||||
|
|
||||||
def prepare_infinite_you(self, id_image, controlnet_image, infinityou_guidance, height, width):
|
|
||||||
if self.infinityou_processor is not None and id_image is not None:
|
|
||||||
return self.infinityou_processor.prepare_infinite_you(self.image_proj_model, id_image, controlnet_image, infinityou_guidance, height, width)
|
|
||||||
else:
|
|
||||||
return {}, controlnet_image
|
|
||||||
|
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
@@ -395,9 +382,6 @@ class FluxImagePipeline(BasePipeline):
|
|||||||
eligen_entity_masks=None,
|
eligen_entity_masks=None,
|
||||||
enable_eligen_on_negative=False,
|
enable_eligen_on_negative=False,
|
||||||
enable_eligen_inpaint=False,
|
enable_eligen_inpaint=False,
|
||||||
# InfiniteYou
|
|
||||||
infinityou_id_image=None,
|
|
||||||
infinityou_guidance=1.0,
|
|
||||||
# TeaCache
|
# TeaCache
|
||||||
tea_cache_l1_thresh=None,
|
tea_cache_l1_thresh=None,
|
||||||
# Tile
|
# Tile
|
||||||
@@ -425,9 +409,6 @@ class FluxImagePipeline(BasePipeline):
|
|||||||
# Extra input
|
# Extra input
|
||||||
extra_input = self.prepare_extra_input(latents, guidance=embedded_guidance)
|
extra_input = self.prepare_extra_input(latents, guidance=embedded_guidance)
|
||||||
|
|
||||||
# InfiniteYou
|
|
||||||
infiniteyou_kwargs, controlnet_image = self.prepare_infinite_you(infinityou_id_image, controlnet_image, infinityou_guidance, height, width)
|
|
||||||
|
|
||||||
# Entity control
|
# Entity control
|
||||||
eligen_kwargs_posi, eligen_kwargs_nega, fg_mask, bg_mask = self.prepare_eligen(prompt_emb_nega, eligen_entity_prompts, eligen_entity_masks, width, height, t5_sequence_length, enable_eligen_inpaint, enable_eligen_on_negative, cfg_scale)
|
eligen_kwargs_posi, eligen_kwargs_nega, fg_mask, bg_mask = self.prepare_eligen(prompt_emb_nega, eligen_entity_prompts, eligen_entity_masks, width, height, t5_sequence_length, enable_eligen_inpaint, enable_eligen_on_negative, cfg_scale)
|
||||||
|
|
||||||
@@ -449,7 +430,7 @@ class FluxImagePipeline(BasePipeline):
|
|||||||
inference_callback = lambda prompt_emb_posi, controlnet_kwargs: lets_dance_flux(
|
inference_callback = lambda prompt_emb_posi, controlnet_kwargs: lets_dance_flux(
|
||||||
dit=self.dit, controlnet=self.controlnet,
|
dit=self.dit, controlnet=self.controlnet,
|
||||||
hidden_states=latents, timestep=timestep,
|
hidden_states=latents, timestep=timestep,
|
||||||
**prompt_emb_posi, **tiler_kwargs, **extra_input, **controlnet_kwargs, **ipadapter_kwargs_list_posi, **eligen_kwargs_posi, **tea_cache_kwargs, **infiniteyou_kwargs
|
**prompt_emb_posi, **tiler_kwargs, **extra_input, **controlnet_kwargs, **ipadapter_kwargs_list_posi, **eligen_kwargs_posi, **tea_cache_kwargs,
|
||||||
)
|
)
|
||||||
noise_pred_posi = self.control_noise_via_local_prompts(
|
noise_pred_posi = self.control_noise_via_local_prompts(
|
||||||
prompt_emb_posi, prompt_emb_locals, masks, mask_scales, inference_callback,
|
prompt_emb_posi, prompt_emb_locals, masks, mask_scales, inference_callback,
|
||||||
@@ -466,7 +447,7 @@ class FluxImagePipeline(BasePipeline):
|
|||||||
noise_pred_nega = lets_dance_flux(
|
noise_pred_nega = lets_dance_flux(
|
||||||
dit=self.dit, controlnet=self.controlnet,
|
dit=self.dit, controlnet=self.controlnet,
|
||||||
hidden_states=latents, timestep=timestep,
|
hidden_states=latents, timestep=timestep,
|
||||||
**prompt_emb_nega, **tiler_kwargs, **extra_input, **controlnet_kwargs_nega, **ipadapter_kwargs_list_nega, **eligen_kwargs_nega, **infiniteyou_kwargs,
|
**prompt_emb_nega, **tiler_kwargs, **extra_input, **controlnet_kwargs_nega, **ipadapter_kwargs_list_nega, **eligen_kwargs_nega,
|
||||||
)
|
)
|
||||||
noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
|
noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
|
||||||
else:
|
else:
|
||||||
@@ -488,58 +469,6 @@ class FluxImagePipeline(BasePipeline):
|
|||||||
return image
|
return image
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class InfinitYou:
|
|
||||||
def __init__(self, device="cuda", torch_dtype=torch.bfloat16):
|
|
||||||
from facexlib.recognition import init_recognition_model
|
|
||||||
from insightface.app import FaceAnalysis
|
|
||||||
self.device = device
|
|
||||||
self.torch_dtype = torch_dtype
|
|
||||||
insightface_root_path = 'models/InfiniteYou/insightface'
|
|
||||||
self.app_640 = FaceAnalysis(name='antelopev2', root=insightface_root_path, providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
|
|
||||||
self.app_640.prepare(ctx_id=0, det_size=(640, 640))
|
|
||||||
self.app_320 = FaceAnalysis(name='antelopev2', root=insightface_root_path, providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
|
|
||||||
self.app_320.prepare(ctx_id=0, det_size=(320, 320))
|
|
||||||
self.app_160 = FaceAnalysis(name='antelopev2', root=insightface_root_path, providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
|
|
||||||
self.app_160.prepare(ctx_id=0, det_size=(160, 160))
|
|
||||||
self.arcface_model = init_recognition_model('arcface', device=self.device)
|
|
||||||
|
|
||||||
def _detect_face(self, id_image_cv2):
|
|
||||||
face_info = self.app_640.get(id_image_cv2)
|
|
||||||
if len(face_info) > 0:
|
|
||||||
return face_info
|
|
||||||
face_info = self.app_320.get(id_image_cv2)
|
|
||||||
if len(face_info) > 0:
|
|
||||||
return face_info
|
|
||||||
face_info = self.app_160.get(id_image_cv2)
|
|
||||||
return face_info
|
|
||||||
|
|
||||||
def extract_arcface_bgr_embedding(self, in_image, landmark):
|
|
||||||
from insightface.utils import face_align
|
|
||||||
arc_face_image = face_align.norm_crop(in_image, landmark=np.array(landmark), image_size=112)
|
|
||||||
arc_face_image = torch.from_numpy(arc_face_image).unsqueeze(0).permute(0, 3, 1, 2) / 255.
|
|
||||||
arc_face_image = 2 * arc_face_image - 1
|
|
||||||
arc_face_image = arc_face_image.contiguous().to(self.device)
|
|
||||||
face_emb = self.arcface_model(arc_face_image)[0] # [512], normalized
|
|
||||||
return face_emb
|
|
||||||
|
|
||||||
def prepare_infinite_you(self, model, id_image, controlnet_image, infinityou_guidance, height, width):
|
|
||||||
import cv2
|
|
||||||
if id_image is None:
|
|
||||||
return {'id_emb': None}, controlnet_image
|
|
||||||
id_image_cv2 = cv2.cvtColor(np.array(id_image), cv2.COLOR_RGB2BGR)
|
|
||||||
face_info = self._detect_face(id_image_cv2)
|
|
||||||
if len(face_info) == 0:
|
|
||||||
raise ValueError('No face detected in the input ID image')
|
|
||||||
landmark = sorted(face_info, key=lambda x:(x['bbox'][2]-x['bbox'][0])*(x['bbox'][3]-x['bbox'][1]))[-1]['kps'] # only use the maximum face
|
|
||||||
id_emb = self.extract_arcface_bgr_embedding(id_image_cv2, landmark)
|
|
||||||
id_emb = model(id_emb.unsqueeze(0).reshape([1, -1, 512]).to(dtype=self.torch_dtype))
|
|
||||||
if controlnet_image is None:
|
|
||||||
controlnet_image = Image.fromarray(np.zeros([height, width, 3]).astype(np.uint8))
|
|
||||||
infinityou_guidance = torch.Tensor([infinityou_guidance]).to(device=self.device, dtype=self.torch_dtype)
|
|
||||||
return {'id_emb': id_emb, 'infinityou_guidance': infinityou_guidance}, controlnet_image
|
|
||||||
|
|
||||||
|
|
||||||
class TeaCache:
|
class TeaCache:
|
||||||
def __init__(self, num_inference_steps, rel_l1_thresh):
|
def __init__(self, num_inference_steps, rel_l1_thresh):
|
||||||
self.num_inference_steps = num_inference_steps
|
self.num_inference_steps = num_inference_steps
|
||||||
@@ -600,8 +529,6 @@ def lets_dance_flux(
|
|||||||
entity_prompt_emb=None,
|
entity_prompt_emb=None,
|
||||||
entity_masks=None,
|
entity_masks=None,
|
||||||
ipadapter_kwargs_list={},
|
ipadapter_kwargs_list={},
|
||||||
id_emb=None,
|
|
||||||
infinityou_guidance=None,
|
|
||||||
tea_cache: TeaCache = None,
|
tea_cache: TeaCache = None,
|
||||||
**kwargs
|
**kwargs
|
||||||
):
|
):
|
||||||
@@ -646,9 +573,6 @@ def lets_dance_flux(
|
|||||||
"tile_size": tile_size,
|
"tile_size": tile_size,
|
||||||
"tile_stride": tile_stride,
|
"tile_stride": tile_stride,
|
||||||
}
|
}
|
||||||
if id_emb is not None:
|
|
||||||
controlnet_text_ids = torch.zeros(id_emb.shape[0], id_emb.shape[1], 3).to(device=hidden_states.device, dtype=hidden_states.dtype)
|
|
||||||
controlnet_extra_kwargs.update({"prompt_emb": id_emb, 'text_ids': controlnet_text_ids, 'guidance': infinityou_guidance})
|
|
||||||
controlnet_res_stack, controlnet_single_res_stack = controlnet(
|
controlnet_res_stack, controlnet_single_res_stack = controlnet(
|
||||||
controlnet_frames, **controlnet_extra_kwargs
|
controlnet_frames, **controlnet_extra_kwargs
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -1,4 +1,3 @@
|
|||||||
import types
|
|
||||||
from ..models import ModelManager
|
from ..models import ModelManager
|
||||||
from ..models.wan_video_dit import WanModel
|
from ..models.wan_video_dit import WanModel
|
||||||
from ..models.wan_video_text_encoder import WanTextEncoder
|
from ..models.wan_video_text_encoder import WanTextEncoder
|
||||||
@@ -18,6 +17,7 @@ from ..vram_management import enable_vram_management, AutoWrappedModule, AutoWra
|
|||||||
from ..models.wan_video_text_encoder import T5RelativeEmbedding, T5LayerNorm
|
from ..models.wan_video_text_encoder import T5RelativeEmbedding, T5LayerNorm
|
||||||
from ..models.wan_video_dit import RMSNorm, sinusoidal_embedding_1d
|
from ..models.wan_video_dit import RMSNorm, sinusoidal_embedding_1d
|
||||||
from ..models.wan_video_vae import RMS_norm, CausalConv3d, Upsample
|
from ..models.wan_video_vae import RMS_norm, CausalConv3d, Upsample
|
||||||
|
from ..models.wan_video_controlnet import WanControlNetModel
|
||||||
from ..models.wan_video_motion_controller import WanMotionControllerModel
|
from ..models.wan_video_motion_controller import WanMotionControllerModel
|
||||||
|
|
||||||
|
|
||||||
@@ -32,11 +32,11 @@ class WanVideoPipeline(BasePipeline):
|
|||||||
self.image_encoder: WanImageEncoder = None
|
self.image_encoder: WanImageEncoder = None
|
||||||
self.dit: WanModel = None
|
self.dit: WanModel = None
|
||||||
self.vae: WanVideoVAE = None
|
self.vae: WanVideoVAE = None
|
||||||
|
self.controlnet: WanControlNetModel = None
|
||||||
self.motion_controller: WanMotionControllerModel = None
|
self.motion_controller: WanMotionControllerModel = None
|
||||||
self.model_names = ['text_encoder', 'dit', 'vae', 'image_encoder', 'motion_controller']
|
self.model_names = ['text_encoder', 'dit', 'vae', 'image_encoder', 'controlnet', 'motion_controller']
|
||||||
self.height_division_factor = 16
|
self.height_division_factor = 16
|
||||||
self.width_division_factor = 16
|
self.width_division_factor = 16
|
||||||
self.use_unified_sequence_parallel = False
|
|
||||||
|
|
||||||
|
|
||||||
def enable_vram_management(self, num_persistent_param_in_dit=None):
|
def enable_vram_management(self, num_persistent_param_in_dit=None):
|
||||||
@@ -124,22 +124,6 @@ class WanVideoPipeline(BasePipeline):
|
|||||||
computation_device=self.device,
|
computation_device=self.device,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
if self.motion_controller is not None:
|
|
||||||
dtype = next(iter(self.motion_controller.parameters())).dtype
|
|
||||||
enable_vram_management(
|
|
||||||
self.motion_controller,
|
|
||||||
module_map = {
|
|
||||||
torch.nn.Linear: AutoWrappedLinear,
|
|
||||||
},
|
|
||||||
module_config = dict(
|
|
||||||
offload_dtype=dtype,
|
|
||||||
offload_device="cpu",
|
|
||||||
onload_dtype=dtype,
|
|
||||||
onload_device="cpu",
|
|
||||||
computation_dtype=dtype,
|
|
||||||
computation_device=self.device,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
self.enable_cpu_offload()
|
self.enable_cpu_offload()
|
||||||
|
|
||||||
|
|
||||||
@@ -152,24 +136,14 @@ class WanVideoPipeline(BasePipeline):
|
|||||||
self.dit = model_manager.fetch_model("wan_video_dit")
|
self.dit = model_manager.fetch_model("wan_video_dit")
|
||||||
self.vae = model_manager.fetch_model("wan_video_vae")
|
self.vae = model_manager.fetch_model("wan_video_vae")
|
||||||
self.image_encoder = model_manager.fetch_model("wan_video_image_encoder")
|
self.image_encoder = model_manager.fetch_model("wan_video_image_encoder")
|
||||||
self.motion_controller = model_manager.fetch_model("wan_video_motion_controller")
|
|
||||||
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def from_model_manager(model_manager: ModelManager, torch_dtype=None, device=None, use_usp=False):
|
def from_model_manager(model_manager: ModelManager, torch_dtype=None, device=None):
|
||||||
if device is None: device = model_manager.device
|
if device is None: device = model_manager.device
|
||||||
if torch_dtype is None: torch_dtype = model_manager.torch_dtype
|
if torch_dtype is None: torch_dtype = model_manager.torch_dtype
|
||||||
pipe = WanVideoPipeline(device=device, torch_dtype=torch_dtype)
|
pipe = WanVideoPipeline(device=device, torch_dtype=torch_dtype)
|
||||||
pipe.fetch_models(model_manager)
|
pipe.fetch_models(model_manager)
|
||||||
if use_usp:
|
|
||||||
from xfuser.core.distributed import get_sequence_parallel_world_size
|
|
||||||
from ..distributed.xdit_context_parallel import usp_attn_forward, usp_dit_forward
|
|
||||||
|
|
||||||
for block in pipe.dit.blocks:
|
|
||||||
block.self_attn.forward = types.MethodType(usp_attn_forward, block.self_attn)
|
|
||||||
pipe.dit.forward = types.MethodType(usp_dit_forward, pipe.dit)
|
|
||||||
pipe.sp_size = get_sequence_parallel_world_size()
|
|
||||||
pipe.use_unified_sequence_parallel = True
|
|
||||||
return pipe
|
return pipe
|
||||||
|
|
||||||
|
|
||||||
@@ -178,26 +152,20 @@ class WanVideoPipeline(BasePipeline):
|
|||||||
|
|
||||||
|
|
||||||
def encode_prompt(self, prompt, positive=True):
|
def encode_prompt(self, prompt, positive=True):
|
||||||
prompt_emb = self.prompter.encode_prompt(prompt, positive=positive, device=self.device)
|
prompt_emb = self.prompter.encode_prompt(prompt, positive=positive)
|
||||||
return {"context": prompt_emb}
|
return {"context": prompt_emb}
|
||||||
|
|
||||||
|
|
||||||
def encode_image(self, image, end_image, num_frames, height, width):
|
def encode_image(self, image, num_frames, height, width):
|
||||||
image = self.preprocess_image(image.resize((width, height))).to(self.device)
|
image = self.preprocess_image(image.resize((width, height))).to(self.device)
|
||||||
clip_context = self.image_encoder.encode_image([image])
|
clip_context = self.image_encoder.encode_image([image])
|
||||||
msk = torch.ones(1, num_frames, height//8, width//8, device=self.device)
|
msk = torch.ones(1, num_frames, height//8, width//8, device=self.device)
|
||||||
msk[:, 1:] = 0
|
msk[:, 1:] = 0
|
||||||
if end_image is not None:
|
|
||||||
end_image = self.preprocess_image(end_image.resize((width, height))).to(self.device)
|
|
||||||
vae_input = torch.concat([image.transpose(0,1), torch.zeros(3, num_frames-2, height, width).to(image.device), end_image.transpose(0,1)],dim=1)
|
|
||||||
msk[:, -1:] = 1
|
|
||||||
else:
|
|
||||||
vae_input = torch.concat([image.transpose(0, 1), torch.zeros(3, num_frames-1, height, width).to(image.device)], dim=1)
|
|
||||||
|
|
||||||
msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1)
|
msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1)
|
||||||
msk = msk.view(1, msk.shape[1] // 4, 4, height//8, width//8)
|
msk = msk.view(1, msk.shape[1] // 4, 4, height//8, width//8)
|
||||||
msk = msk.transpose(1, 2)[0]
|
msk = msk.transpose(1, 2)[0]
|
||||||
|
|
||||||
|
vae_input = torch.concat([image.transpose(0, 1), torch.zeros(3, num_frames-1, height, width).to(image.device)], dim=1)
|
||||||
y = self.vae.encode([vae_input.to(dtype=self.torch_dtype, device=self.device)], device=self.device)[0]
|
y = self.vae.encode([vae_input.to(dtype=self.torch_dtype, device=self.device)], device=self.device)[0]
|
||||||
y = torch.concat([msk, y])
|
y = torch.concat([msk, y])
|
||||||
y = y.unsqueeze(0)
|
y = y.unsqueeze(0)
|
||||||
@@ -206,25 +174,6 @@ class WanVideoPipeline(BasePipeline):
|
|||||||
return {"clip_feature": clip_context, "y": y}
|
return {"clip_feature": clip_context, "y": y}
|
||||||
|
|
||||||
|
|
||||||
def encode_control_video(self, control_video, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)):
|
|
||||||
control_video = self.preprocess_images(control_video)
|
|
||||||
control_video = torch.stack(control_video, dim=2).to(dtype=self.torch_dtype, device=self.device)
|
|
||||||
latents = self.encode_video(control_video, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=self.torch_dtype, device=self.device)
|
|
||||||
return latents
|
|
||||||
|
|
||||||
|
|
||||||
def prepare_controlnet_kwargs(self, control_video, num_frames, height, width, clip_feature=None, y=None, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)):
|
|
||||||
if control_video is not None:
|
|
||||||
control_latents = self.encode_control_video(control_video, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
|
|
||||||
if clip_feature is None or y is None:
|
|
||||||
clip_feature = torch.zeros((1, 257, 1280), dtype=self.torch_dtype, device=self.device)
|
|
||||||
y = torch.zeros((1, 16, (num_frames - 1) // 4 + 1, height//8, width//8), dtype=self.torch_dtype, device=self.device)
|
|
||||||
else:
|
|
||||||
y = y[:, -16:]
|
|
||||||
y = torch.concat([control_latents, y], dim=1)
|
|
||||||
return {"clip_feature": clip_feature, "y": y}
|
|
||||||
|
|
||||||
|
|
||||||
def tensor2video(self, frames):
|
def tensor2video(self, frames):
|
||||||
frames = rearrange(frames, "C T H W -> T H W C")
|
frames = rearrange(frames, "C T H W -> T H W C")
|
||||||
frames = ((frames.float() + 1) * 127.5).clip(0, 255).cpu().numpy().astype(np.uint8)
|
frames = ((frames.float() + 1) * 127.5).clip(0, 255).cpu().numpy().astype(np.uint8)
|
||||||
@@ -246,8 +195,9 @@ class WanVideoPipeline(BasePipeline):
|
|||||||
return frames
|
return frames
|
||||||
|
|
||||||
|
|
||||||
def prepare_unified_sequence_parallel(self):
|
def prepare_controlnet(self, controlnet_frames, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)):
|
||||||
return {"use_unified_sequence_parallel": self.use_unified_sequence_parallel}
|
controlnet_conditioning = self.encode_video(controlnet_frames, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=self.torch_dtype, device=self.device)
|
||||||
|
return {"controlnet_conditioning": controlnet_conditioning}
|
||||||
|
|
||||||
|
|
||||||
def prepare_motion_bucket_id(self, motion_bucket_id):
|
def prepare_motion_bucket_id(self, motion_bucket_id):
|
||||||
@@ -261,9 +211,7 @@ class WanVideoPipeline(BasePipeline):
|
|||||||
prompt,
|
prompt,
|
||||||
negative_prompt="",
|
negative_prompt="",
|
||||||
input_image=None,
|
input_image=None,
|
||||||
end_image=None,
|
|
||||||
input_video=None,
|
input_video=None,
|
||||||
control_video=None,
|
|
||||||
denoising_strength=1.0,
|
denoising_strength=1.0,
|
||||||
seed=None,
|
seed=None,
|
||||||
rand_device="cpu",
|
rand_device="cpu",
|
||||||
@@ -279,6 +227,7 @@ class WanVideoPipeline(BasePipeline):
|
|||||||
tile_stride=(15, 26),
|
tile_stride=(15, 26),
|
||||||
tea_cache_l1_thresh=None,
|
tea_cache_l1_thresh=None,
|
||||||
tea_cache_model_id="",
|
tea_cache_model_id="",
|
||||||
|
controlnet_frames=None,
|
||||||
progress_bar_cmd=tqdm,
|
progress_bar_cmd=tqdm,
|
||||||
progress_bar_st=None,
|
progress_bar_st=None,
|
||||||
):
|
):
|
||||||
@@ -315,14 +264,18 @@ class WanVideoPipeline(BasePipeline):
|
|||||||
# Encode image
|
# Encode image
|
||||||
if input_image is not None and self.image_encoder is not None:
|
if input_image is not None and self.image_encoder is not None:
|
||||||
self.load_models_to_device(["image_encoder", "vae"])
|
self.load_models_to_device(["image_encoder", "vae"])
|
||||||
image_emb = self.encode_image(input_image, end_image, num_frames, height, width)
|
image_emb = self.encode_image(input_image, num_frames, height, width)
|
||||||
else:
|
else:
|
||||||
image_emb = {}
|
image_emb = {}
|
||||||
|
|
||||||
# ControlNet
|
# ControlNet
|
||||||
if control_video is not None:
|
if self.controlnet is not None and controlnet_frames is not None:
|
||||||
self.load_models_to_device(["image_encoder", "vae"])
|
self.load_models_to_device(['vae', 'controlnet'])
|
||||||
image_emb = self.prepare_controlnet_kwargs(control_video, num_frames, height, width, **image_emb, **tiler_kwargs)
|
controlnet_frames = self.preprocess_images(controlnet_frames)
|
||||||
|
controlnet_frames = torch.stack(controlnet_frames, dim=2).to(dtype=self.torch_dtype, device=self.device)
|
||||||
|
controlnet_kwargs = self.prepare_controlnet(controlnet_frames)
|
||||||
|
else:
|
||||||
|
controlnet_kwargs = {}
|
||||||
|
|
||||||
# Motion Controller
|
# Motion Controller
|
||||||
if self.motion_controller is not None and motion_bucket_id is not None:
|
if self.motion_controller is not None and motion_bucket_id is not None:
|
||||||
@@ -337,27 +290,24 @@ class WanVideoPipeline(BasePipeline):
|
|||||||
tea_cache_posi = {"tea_cache": TeaCache(num_inference_steps, rel_l1_thresh=tea_cache_l1_thresh, model_id=tea_cache_model_id) if tea_cache_l1_thresh is not None else None}
|
tea_cache_posi = {"tea_cache": TeaCache(num_inference_steps, rel_l1_thresh=tea_cache_l1_thresh, model_id=tea_cache_model_id) if tea_cache_l1_thresh is not None else None}
|
||||||
tea_cache_nega = {"tea_cache": TeaCache(num_inference_steps, rel_l1_thresh=tea_cache_l1_thresh, model_id=tea_cache_model_id) if tea_cache_l1_thresh is not None else None}
|
tea_cache_nega = {"tea_cache": TeaCache(num_inference_steps, rel_l1_thresh=tea_cache_l1_thresh, model_id=tea_cache_model_id) if tea_cache_l1_thresh is not None else None}
|
||||||
|
|
||||||
# Unified Sequence Parallel
|
|
||||||
usp_kwargs = self.prepare_unified_sequence_parallel()
|
|
||||||
|
|
||||||
# Denoise
|
# Denoise
|
||||||
self.load_models_to_device(["dit", "motion_controller"])
|
self.load_models_to_device(["dit", "controlnet", "motion_controller"])
|
||||||
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
|
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
|
||||||
timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device)
|
timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device)
|
||||||
|
|
||||||
# Inference
|
# Inference
|
||||||
noise_pred_posi = model_fn_wan_video(
|
noise_pred_posi = model_fn_wan_video(
|
||||||
self.dit, motion_controller=self.motion_controller,
|
self.dit, controlnet=self.controlnet, motion_controller=self.motion_controller,
|
||||||
x=latents, timestep=timestep,
|
x=latents, timestep=timestep,
|
||||||
**prompt_emb_posi, **image_emb, **extra_input,
|
**prompt_emb_posi, **image_emb, **extra_input,
|
||||||
**tea_cache_posi, **usp_kwargs, **motion_kwargs
|
**tea_cache_posi, **controlnet_kwargs, **motion_kwargs,
|
||||||
)
|
)
|
||||||
if cfg_scale != 1.0:
|
if cfg_scale != 1.0:
|
||||||
noise_pred_nega = model_fn_wan_video(
|
noise_pred_nega = model_fn_wan_video(
|
||||||
self.dit, motion_controller=self.motion_controller,
|
self.dit, controlnet=self.controlnet, motion_controller=self.motion_controller,
|
||||||
x=latents, timestep=timestep,
|
x=latents, timestep=timestep,
|
||||||
**prompt_emb_nega, **image_emb, **extra_input,
|
**prompt_emb_nega, **image_emb, **extra_input,
|
||||||
**tea_cache_nega, **usp_kwargs, **motion_kwargs
|
**tea_cache_nega, **controlnet_kwargs, **motion_kwargs,
|
||||||
)
|
)
|
||||||
noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
|
noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
|
||||||
else:
|
else:
|
||||||
@@ -431,6 +381,7 @@ class TeaCache:
|
|||||||
|
|
||||||
def model_fn_wan_video(
|
def model_fn_wan_video(
|
||||||
dit: WanModel,
|
dit: WanModel,
|
||||||
|
controlnet: WanControlNetModel = None,
|
||||||
motion_controller: WanMotionControllerModel = None,
|
motion_controller: WanMotionControllerModel = None,
|
||||||
x: torch.Tensor = None,
|
x: torch.Tensor = None,
|
||||||
timestep: torch.Tensor = None,
|
timestep: torch.Tensor = None,
|
||||||
@@ -438,15 +389,22 @@ def model_fn_wan_video(
|
|||||||
clip_feature: Optional[torch.Tensor] = None,
|
clip_feature: Optional[torch.Tensor] = None,
|
||||||
y: Optional[torch.Tensor] = None,
|
y: Optional[torch.Tensor] = None,
|
||||||
tea_cache: TeaCache = None,
|
tea_cache: TeaCache = None,
|
||||||
use_unified_sequence_parallel: bool = False,
|
controlnet_conditioning: Optional[torch.Tensor] = None,
|
||||||
motion_bucket_id: Optional[torch.Tensor] = None,
|
motion_bucket_id: Optional[torch.Tensor] = None,
|
||||||
|
use_gradient_checkpointing: bool = False,
|
||||||
|
use_gradient_checkpointing_offload: bool = False,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
if use_unified_sequence_parallel:
|
# ControlNet
|
||||||
import torch.distributed as dist
|
if controlnet is not None and controlnet_conditioning is not None:
|
||||||
from xfuser.core.distributed import (get_sequence_parallel_rank,
|
controlnet_res_stack = controlnet(
|
||||||
get_sequence_parallel_world_size,
|
x, timestep=timestep, context=context, clip_feature=clip_feature, y=y,
|
||||||
get_sp_group)
|
controlnet_conditioning=controlnet_conditioning,
|
||||||
|
use_gradient_checkpointing=use_gradient_checkpointing,
|
||||||
|
use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
controlnet_res_stack = None
|
||||||
|
|
||||||
t = dit.time_embedding(sinusoidal_embedding_1d(dit.freq_dim, timestep))
|
t = dit.time_embedding(sinusoidal_embedding_1d(dit.freq_dim, timestep))
|
||||||
t_mod = dit.time_projection(t).unflatten(1, (6, dit.dim))
|
t_mod = dit.time_projection(t).unflatten(1, (6, dit.dim))
|
||||||
@@ -473,21 +431,37 @@ def model_fn_wan_video(
|
|||||||
else:
|
else:
|
||||||
tea_cache_update = False
|
tea_cache_update = False
|
||||||
|
|
||||||
# blocks
|
def create_custom_forward(module):
|
||||||
if use_unified_sequence_parallel:
|
def custom_forward(*inputs):
|
||||||
if dist.is_initialized() and dist.get_world_size() > 1:
|
return module(*inputs)
|
||||||
x = torch.chunk(x, get_sequence_parallel_world_size(), dim=1)[get_sequence_parallel_rank()]
|
return custom_forward
|
||||||
|
|
||||||
if tea_cache_update:
|
if tea_cache_update:
|
||||||
x = tea_cache.update(x)
|
x = tea_cache.update(x)
|
||||||
else:
|
else:
|
||||||
for block in dit.blocks:
|
# blocks
|
||||||
x = block(x, context, t_mod, freqs)
|
for block_id, block in enumerate(dit.blocks):
|
||||||
|
if dit.training and use_gradient_checkpointing:
|
||||||
|
if use_gradient_checkpointing_offload:
|
||||||
|
with torch.autograd.graph.save_on_cpu():
|
||||||
|
x = torch.utils.checkpoint.checkpoint(
|
||||||
|
create_custom_forward(block),
|
||||||
|
x, context, t_mod, freqs,
|
||||||
|
use_reentrant=False,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
x = torch.utils.checkpoint.checkpoint(
|
||||||
|
create_custom_forward(block),
|
||||||
|
x, context, t_mod, freqs,
|
||||||
|
use_reentrant=False,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
x = block(x, context, t_mod, freqs)
|
||||||
|
if controlnet_res_stack is not None:
|
||||||
|
x = x + controlnet_res_stack[block_id]
|
||||||
if tea_cache is not None:
|
if tea_cache is not None:
|
||||||
tea_cache.store(x)
|
tea_cache.store(x)
|
||||||
|
|
||||||
x = dit.head(x, t)
|
x = dit.head(x, t)
|
||||||
if use_unified_sequence_parallel:
|
|
||||||
if dist.is_initialized() and dist.get_world_size() > 1:
|
|
||||||
x = get_sp_group().all_gather(x, dim=1)
|
|
||||||
x = dit.unpatchify(x, (f, h, w))
|
x = dit.unpatchify(x, (f, h, w))
|
||||||
return x
|
return x
|
||||||
|
|||||||
@@ -1,7 +0,0 @@
|
|||||||
# InfiniteYou: Flexible Photo Recrafting While Preserving Your Identity
|
|
||||||
We support the identity preserving feature of InfiniteYou. See [./infiniteyou.py](./infiniteyou.py) for example. The visualization of the result is shown below.
|
|
||||||
|
|
||||||
|Identity Image|Generated Image|
|
|
||||||
|-|-|
|
|
||||||
|||
|
|
||||||
|||
|
|
||||||
@@ -1,58 +0,0 @@
|
|||||||
import importlib
|
|
||||||
import torch
|
|
||||||
from diffsynth import ModelManager, FluxImagePipeline, download_models, ControlNetConfigUnit
|
|
||||||
from modelscope import dataset_snapshot_download
|
|
||||||
from PIL import Image
|
|
||||||
|
|
||||||
if importlib.util.find_spec("facexlib") is None:
|
|
||||||
raise ImportError("You are using InifiniteYou. It depends on facexlib, which is not installed. Please install it with `pip install facexlib`.")
|
|
||||||
if importlib.util.find_spec("insightface") is None:
|
|
||||||
raise ImportError("You are using InifiniteYou. It depends on insightface, which is not installed. Please install it with `pip install insightface`.")
|
|
||||||
|
|
||||||
download_models(["InfiniteYou"])
|
|
||||||
model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cuda", model_id_list=["FLUX.1-dev"])
|
|
||||||
model_manager.load_models([
|
|
||||||
[
|
|
||||||
"models/InfiniteYou/InfuseNetModel/diffusion_pytorch_model-00001-of-00002.safetensors",
|
|
||||||
"models/InfiniteYou/InfuseNetModel/diffusion_pytorch_model-00002-of-00002.safetensors"
|
|
||||||
],
|
|
||||||
"models/InfiniteYou/image_proj_model.bin",
|
|
||||||
])
|
|
||||||
|
|
||||||
|
|
||||||
pipe = FluxImagePipeline.from_model_manager(
|
|
||||||
model_manager,
|
|
||||||
controlnet_config_units=[
|
|
||||||
ControlNetConfigUnit(
|
|
||||||
processor_id="none",
|
|
||||||
model_path=[
|
|
||||||
'models/InfiniteYou/InfuseNetModel/diffusion_pytorch_model-00001-of-00002.safetensors',
|
|
||||||
'models/InfiniteYou/InfuseNetModel/diffusion_pytorch_model-00002-of-00002.safetensors'
|
|
||||||
],
|
|
||||||
scale=1.0
|
|
||||||
)
|
|
||||||
]
|
|
||||||
)
|
|
||||||
dataset_snapshot_download(dataset_id="DiffSynth-Studio/examples_in_diffsynth", local_dir="./", allow_file_pattern=f"data/examples/infiniteyou/*")
|
|
||||||
|
|
||||||
prompt = "A man, portrait, cinematic"
|
|
||||||
id_image = "data/examples/infiniteyou/man.jpg"
|
|
||||||
id_image = Image.open(id_image).convert('RGB')
|
|
||||||
image = pipe(
|
|
||||||
prompt=prompt, seed=1,
|
|
||||||
infinityou_id_image=id_image, infinityou_guidance=1.0,
|
|
||||||
num_inference_steps=50, embedded_guidance=3.5,
|
|
||||||
height=1024, width=1024,
|
|
||||||
)
|
|
||||||
image.save("man.jpg")
|
|
||||||
|
|
||||||
prompt = "A woman, portrait, cinematic"
|
|
||||||
id_image = "data/examples/infiniteyou/woman.jpg"
|
|
||||||
id_image = Image.open(id_image).convert('RGB')
|
|
||||||
image = pipe(
|
|
||||||
prompt=prompt, seed=1,
|
|
||||||
infinityou_id_image=id_image, infinityou_guidance=1.0,
|
|
||||||
num_inference_steps=50, embedded_guidance=3.5,
|
|
||||||
height=1024, width=1024,
|
|
||||||
)
|
|
||||||
image.save("woman.jpg")
|
|
||||||
@@ -10,52 +10,34 @@ cd DiffSynth-Studio
|
|||||||
pip install -e .
|
pip install -e .
|
||||||
```
|
```
|
||||||
|
|
||||||
## Model Zoo
|
Wan-Video supports multiple Attention implementations. If you have installed any of the following Attention implementations, they will be enabled based on priority.
|
||||||
|
|
||||||
|Developer|Name|Link|Scripts|
|
* [Flash Attention 3](https://github.com/Dao-AILab/flash-attention)
|
||||||
|-|-|-|-|
|
* [Flash Attention 2](https://github.com/Dao-AILab/flash-attention)
|
||||||
|Wan Team|1.3B text-to-video|[Link](https://modelscope.cn/models/Wan-AI/Wan2.1-T2V-1.3B)|[wan_1.3b_text_to_video.py](./wan_1.3b_text_to_video.py)|
|
* [Sage Attention](https://github.com/thu-ml/SageAttention)
|
||||||
|Wan Team|14B text-to-video|[Link](https://modelscope.cn/models/Wan-AI/Wan2.1-T2V-14B)|[wan_14b_text_to_video.py](./wan_14b_text_to_video.py)|
|
* [torch SDPA](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html) (default. `torch>=2.5.0` is recommended.)
|
||||||
|Wan Team|14B image-to-video 480P|[Link](https://modelscope.cn/models/Wan-AI/Wan2.1-I2V-14B-480P)|[wan_14b_image_to_video.py](./wan_14b_image_to_video.py)|
|
|
||||||
|Wan Team|14B image-to-video 720P|[Link](https://modelscope.cn/models/Wan-AI/Wan2.1-I2V-14B-720P)|[wan_14b_image_to_video.py](./wan_14b_image_to_video.py)|
|
|
||||||
|DiffSynth-Studio Team|1.3B aesthetics LoRA|[Link](https://modelscope.cn/models/DiffSynth-Studio/Wan2.1-1.3b-lora-aesthetics-v1)|Please see the [model card](https://modelscope.cn/models/DiffSynth-Studio/Wan2.1-1.3b-lora-aesthetics-v1).|
|
|
||||||
|DiffSynth-Studio Team|1.3B Highres-fix LoRA|[Link](https://modelscope.cn/models/DiffSynth-Studio/Wan2.1-1.3b-lora-highresfix-v1)|Please see the [model card](https://modelscope.cn/models/DiffSynth-Studio/Wan2.1-1.3b-lora-highresfix-v1).|
|
|
||||||
|DiffSynth-Studio Team|1.3B ExVideo LoRA|[Link](https://modelscope.cn/models/DiffSynth-Studio/Wan2.1-1.3b-lora-exvideo-v1)|Please see the [model card](https://modelscope.cn/models/DiffSynth-Studio/Wan2.1-1.3b-lora-exvideo-v1).|
|
|
||||||
|DiffSynth-Studio Team|1.3B Speed Control adapter|[Link](https://modelscope.cn/models/DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1)|[wan_1.3b_motion_controller.py](./wan_1.3b_motion_controller.py)|
|
|
||||||
|PAI Team|1.3B InP|[Link](https://modelscope.cn/models/PAI/Wan2.1-Fun-1.3B-InP)|[wan_fun_InP.py](./wan_fun_InP.py)|
|
|
||||||
|PAI Team|14B InP|[Link](https://modelscope.cn/models/PAI/Wan2.1-Fun-14B-InP)|[wan_fun_InP.py](./wan_fun_InP.py)|
|
|
||||||
|PAI Team|1.3B Control|[Link](https://modelscope.cn/models/PAI/Wan2.1-Fun-1.3B-Control)|[wan_fun_control.py](./wan_fun_control.py)|
|
|
||||||
|PAI Team|14B Control|[Link](https://modelscope.cn/models/PAI/Wan2.1-Fun-14B-Control)|[wan_fun_control.py](./wan_fun_control.py)|
|
|
||||||
|
|
||||||
Base model features
|
## Inference
|
||||||
|
|
||||||
||Text-to-video|Image-to-video|End frame|Control|
|
### Wan-Video-1.3B-T2V
|
||||||
|-|-|-|-|-|
|
|
||||||
|1.3B text-to-video|✅||||
|
|
||||||
|14B text-to-video|✅||||
|
|
||||||
|14B image-to-video 480P||✅|||
|
|
||||||
|14B image-to-video 720P||✅|||
|
|
||||||
|1.3B InP||✅|✅||
|
|
||||||
|14B InP||✅|✅||
|
|
||||||
|1.3B Control||||✅|
|
|
||||||
|14B Control||||✅|
|
|
||||||
|
|
||||||
Adapter model compatibility
|
Wan-Video-1.3B-T2V supports text-to-video and video-to-video. See [`./wan_1.3b_text_to_video.py`](./wan_1.3b_text_to_video.py).
|
||||||
|
|
||||||
||1.3B text-to-video|1.3B InP|
|
Required VRAM: 6G
|
||||||
|-|-|-|
|
|
||||||
|1.3B aesthetics LoRA|✅||
|
|
||||||
|1.3B Highres-fix LoRA|✅||
|
|
||||||
|1.3B ExVideo LoRA|✅||
|
|
||||||
|1.3B Speed Control adapter|✅|✅|
|
|
||||||
|
|
||||||
## VRAM Usage
|
https://github.com/user-attachments/assets/124397be-cd6a-4f29-a87c-e4c695aaabb8
|
||||||
|
|
||||||
* Fine-grained offload: We recommend that users adjust the `num_persistent_param_in_dit` settings to find an optimal balance between speed and VRAM requirements. See [`./wan_14b_text_to_video.py`](./wan_14b_text_to_video.py).
|
Put sunglasses on the dog.
|
||||||
|
|
||||||
* FP8 Quantization: You only need to adjust the `torch_dtype` in the `ModelManager` (not the pipeline!).
|
https://github.com/user-attachments/assets/272808d7-fbeb-4747-a6df-14a0860c75fb
|
||||||
|
|
||||||
We present a detailed table here. The model (14B text-to-video) is tested on a single A100.
|
[TeaCache](https://github.com/ali-vilab/TeaCache) is supported in both T2V and I2V models. It can significantly improve the efficiency. See [`./wan_1.3b_text_to_video_accelerate.py`](./wan_1.3b_text_to_video_accelerate.py).
|
||||||
|
|
||||||
|
### Wan-Video-14B-T2V
|
||||||
|
|
||||||
|
Wan-Video-14B-T2V is an enhanced version of Wan-Video-1.3B-T2V, offering greater size and power. To utilize this model, you need additional VRAM. We recommend that users adjust the `torch_dtype` and `num_persistent_param_in_dit` settings to find an optimal balance between speed and VRAM requirements. See [`./wan_14b_text_to_video.py`](./wan_14b_text_to_video.py).
|
||||||
|
|
||||||
|
We present a detailed table here. The model is tested on a single A100.
|
||||||
|
|
||||||
|`torch_dtype`|`num_persistent_param_in_dit`|Speed|Required VRAM|Default Setting|
|
|`torch_dtype`|`num_persistent_param_in_dit`|Speed|Required VRAM|Default Setting|
|
||||||
|-|-|-|-|-|
|
|-|-|-|-|-|
|
||||||
@@ -65,46 +47,17 @@ We present a detailed table here. The model (14B text-to-video) is tested on a s
|
|||||||
|torch.float8_e4m3fn|None (unlimited)|18.3s/it|24G|yes|
|
|torch.float8_e4m3fn|None (unlimited)|18.3s/it|24G|yes|
|
||||||
|torch.float8_e4m3fn|0|24.0s/it|10G||
|
|torch.float8_e4m3fn|0|24.0s/it|10G||
|
||||||
|
|
||||||
**We found that 14B image-to-video model is more sensitive to precision, so when the generated video content experiences issues such as artifacts, please switch to bfloat16 precision and use the `num_persistent_param_in_dit` parameter to control VRAM usage.**
|
|
||||||
|
|
||||||
## Efficient Attention Implementation
|
|
||||||
|
|
||||||
DiffSynth-Studio supports multiple Attention implementations. If you have installed any of the following Attention implementations, they will be enabled based on priority. However, we recommend to use the default torch SDPA.
|
|
||||||
|
|
||||||
* [Flash Attention 3](https://github.com/Dao-AILab/flash-attention)
|
|
||||||
* [Flash Attention 2](https://github.com/Dao-AILab/flash-attention)
|
|
||||||
* [Sage Attention](https://github.com/thu-ml/SageAttention)
|
|
||||||
* [torch SDPA](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html) (default. `torch>=2.5.0` is recommended.)
|
|
||||||
|
|
||||||
## Acceleration
|
|
||||||
|
|
||||||
We support multiple acceleration solutions:
|
|
||||||
* [TeaCache](https://github.com/ali-vilab/TeaCache): See [wan_1.3b_text_to_video_accelerate.py](./wan_1.3b_text_to_video_accelerate.py).
|
|
||||||
|
|
||||||
* [Unified Sequence Parallel](https://github.com/xdit-project/xDiT): See [wan_14b_text_to_video_usp.py](./wan_14b_text_to_video_usp.py)
|
|
||||||
|
|
||||||
```bash
|
|
||||||
pip install xfuser>=0.4.3
|
|
||||||
torchrun --standalone --nproc_per_node=8 examples/wanvideo/wan_14b_text_to_video_usp.py
|
|
||||||
```
|
|
||||||
|
|
||||||
* Tensor Parallel: See [wan_14b_text_to_video_tensor_parallel.py](./wan_14b_text_to_video_tensor_parallel.py).
|
|
||||||
|
|
||||||
## Gallery
|
|
||||||
|
|
||||||
1.3B text-to-video.
|
|
||||||
|
|
||||||
https://github.com/user-attachments/assets/124397be-cd6a-4f29-a87c-e4c695aaabb8
|
|
||||||
|
|
||||||
Put sunglasses on the dog.
|
|
||||||
|
|
||||||
https://github.com/user-attachments/assets/272808d7-fbeb-4747-a6df-14a0860c75fb
|
|
||||||
|
|
||||||
14B text-to-video.
|
|
||||||
|
|
||||||
https://github.com/user-attachments/assets/3908bc64-d451-485a-8b61-28f6d32dd92f
|
https://github.com/user-attachments/assets/3908bc64-d451-485a-8b61-28f6d32dd92f
|
||||||
|
|
||||||
14B image-to-video.
|
Tensor parallel module of Wan-Video-14B-T2V is still under development. An example script is provided in [`./wan_14b_text_to_video_tensor_parallel.py`](./wan_14b_text_to_video_tensor_parallel.py).
|
||||||
|
|
||||||
|
### Wan-Video-14B-I2V
|
||||||
|
|
||||||
|
Wan-Video-14B-I2V adds the functionality of image-to-video based on Wan-Video-14B-T2V. The model size remains the same, therefore the speed and VRAM requirements are also consistent. See [`./wan_14b_image_to_video.py`](./wan_14b_image_to_video.py).
|
||||||
|
|
||||||
|
**In the sample code, we use the same settings as the T2V 14B model, with FP8 quantization enabled by default. However, we found that this model is more sensitive to precision, so when the generated video content experiences issues such as artifacts, please switch to bfloat16 precision and use the `num_persistent_param_in_dit` parameter to control VRAM usage.**
|
||||||
|
|
||||||
|

|
||||||
|
|
||||||
https://github.com/user-attachments/assets/c0bdd5ca-292f-45ed-b9bc-afe193156e75
|
https://github.com/user-attachments/assets/c0bdd5ca-292f-45ed-b9bc-afe193156e75
|
||||||
|
|
||||||
|
|||||||
@@ -12,9 +12,12 @@ import numpy as np
|
|||||||
|
|
||||||
|
|
||||||
class TextVideoDataset(torch.utils.data.Dataset):
|
class TextVideoDataset(torch.utils.data.Dataset):
|
||||||
def __init__(self, base_path, metadata_path, max_num_frames=81, frame_interval=1, num_frames=81, height=480, width=832, is_i2v=False):
|
def __init__(self, base_path, metadata_path, max_num_frames=81, frame_interval=1, num_frames=81, height=480, width=832, is_i2v=False, target_fps=None):
|
||||||
metadata = pd.read_csv(metadata_path)
|
metadata = pd.read_csv(metadata_path)
|
||||||
self.path = [os.path.join(base_path, "train", file_name) for file_name in metadata["file_name"]]
|
if os.path.exists(os.path.join(base_path, "train")):
|
||||||
|
self.path = [os.path.join(base_path, "train", file_name) for file_name in metadata["file_name"]]
|
||||||
|
else:
|
||||||
|
self.path = [os.path.join(base_path, file_name) for file_name in metadata["file_name"]]
|
||||||
self.text = metadata["text"].to_list()
|
self.text = metadata["text"].to_list()
|
||||||
|
|
||||||
self.max_num_frames = max_num_frames
|
self.max_num_frames = max_num_frames
|
||||||
@@ -23,6 +26,7 @@ class TextVideoDataset(torch.utils.data.Dataset):
|
|||||||
self.height = height
|
self.height = height
|
||||||
self.width = width
|
self.width = width
|
||||||
self.is_i2v = is_i2v
|
self.is_i2v = is_i2v
|
||||||
|
self.target_fps = target_fps
|
||||||
|
|
||||||
self.frame_process = v2.Compose([
|
self.frame_process = v2.Compose([
|
||||||
v2.CenterCrop(size=(height, width)),
|
v2.CenterCrop(size=(height, width)),
|
||||||
@@ -71,8 +75,15 @@ class TextVideoDataset(torch.utils.data.Dataset):
|
|||||||
|
|
||||||
|
|
||||||
def load_video(self, file_path):
|
def load_video(self, file_path):
|
||||||
start_frame_id = torch.randint(0, self.max_num_frames - (self.num_frames - 1) * self.frame_interval, (1,))[0]
|
start_frame_id = 0
|
||||||
frames = self.load_frames_using_imageio(file_path, self.max_num_frames, start_frame_id, self.frame_interval, self.num_frames, self.frame_process)
|
if self.target_fps is None:
|
||||||
|
frame_interval = self.frame_interval
|
||||||
|
else:
|
||||||
|
reader = imageio.get_reader(file_path)
|
||||||
|
fps = reader.get_meta_data()["fps"]
|
||||||
|
reader.close()
|
||||||
|
frame_interval = max(round(fps / self.target_fps), 1)
|
||||||
|
frames = self.load_frames_using_imageio(file_path, self.max_num_frames, start_frame_id, frame_interval, self.num_frames, self.frame_process)
|
||||||
return frames
|
return frames
|
||||||
|
|
||||||
|
|
||||||
@@ -95,17 +106,20 @@ class TextVideoDataset(torch.utils.data.Dataset):
|
|||||||
def __getitem__(self, data_id):
|
def __getitem__(self, data_id):
|
||||||
text = self.text[data_id]
|
text = self.text[data_id]
|
||||||
path = self.path[data_id]
|
path = self.path[data_id]
|
||||||
if self.is_image(path):
|
try:
|
||||||
|
if self.is_image(path):
|
||||||
|
if self.is_i2v:
|
||||||
|
raise ValueError(f"{path} is not a video. I2V model doesn't support image-to-image training.")
|
||||||
|
video = self.load_image(path)
|
||||||
|
else:
|
||||||
|
video = self.load_video(path)
|
||||||
if self.is_i2v:
|
if self.is_i2v:
|
||||||
raise ValueError(f"{path} is not a video. I2V model doesn't support image-to-image training.")
|
video, first_frame = video
|
||||||
video = self.load_image(path)
|
data = {"text": text, "video": video, "path": path, "first_frame": first_frame}
|
||||||
else:
|
else:
|
||||||
video = self.load_video(path)
|
data = {"text": text, "video": video, "path": path}
|
||||||
if self.is_i2v:
|
except:
|
||||||
video, first_frame = video
|
data = None
|
||||||
data = {"text": text, "video": video, "path": path, "first_frame": first_frame}
|
|
||||||
else:
|
|
||||||
data = {"text": text, "video": video, "path": path}
|
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
|
||||||
@@ -115,7 +129,7 @@ class TextVideoDataset(torch.utils.data.Dataset):
|
|||||||
|
|
||||||
|
|
||||||
class LightningModelForDataProcess(pl.LightningModule):
|
class LightningModelForDataProcess(pl.LightningModule):
|
||||||
def __init__(self, text_encoder_path, vae_path, image_encoder_path=None, tiled=False, tile_size=(34, 34), tile_stride=(18, 16)):
|
def __init__(self, text_encoder_path, vae_path, image_encoder_path=None, tiled=False, tile_size=(34, 34), tile_stride=(18, 16), redirected_tensor_path=None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
model_path = [text_encoder_path, vae_path]
|
model_path = [text_encoder_path, vae_path]
|
||||||
if image_encoder_path is not None:
|
if image_encoder_path is not None:
|
||||||
@@ -125,9 +139,13 @@ class LightningModelForDataProcess(pl.LightningModule):
|
|||||||
self.pipe = WanVideoPipeline.from_model_manager(model_manager)
|
self.pipe = WanVideoPipeline.from_model_manager(model_manager)
|
||||||
|
|
||||||
self.tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride}
|
self.tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride}
|
||||||
|
self.redirected_tensor_path = redirected_tensor_path
|
||||||
|
|
||||||
def test_step(self, batch, batch_idx):
|
def test_step(self, batch, batch_idx):
|
||||||
text, video, path = batch["text"][0], batch["video"], batch["path"][0]
|
data = batch[0]
|
||||||
|
if data is None or data["video"] is None:
|
||||||
|
return
|
||||||
|
text, video, path = data["text"], data["video"].unsqueeze(0), data["path"]
|
||||||
|
|
||||||
self.pipe.device = self.device
|
self.pipe.device = self.device
|
||||||
if video is not None:
|
if video is not None:
|
||||||
@@ -144,28 +162,49 @@ class LightningModelForDataProcess(pl.LightningModule):
|
|||||||
else:
|
else:
|
||||||
image_emb = {}
|
image_emb = {}
|
||||||
data = {"latents": latents, "prompt_emb": prompt_emb, "image_emb": image_emb}
|
data = {"latents": latents, "prompt_emb": prompt_emb, "image_emb": image_emb}
|
||||||
|
if self.redirected_tensor_path is not None:
|
||||||
|
path = path.replace("/", "_").replace("\\", "_")
|
||||||
|
path = os.path.join(self.redirected_tensor_path, path)
|
||||||
torch.save(data, path + ".tensors.pth")
|
torch.save(data, path + ".tensors.pth")
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class TensorDataset(torch.utils.data.Dataset):
|
class TensorDataset(torch.utils.data.Dataset):
|
||||||
def __init__(self, base_path, metadata_path, steps_per_epoch):
|
def __init__(self, base_path, metadata_path=None, steps_per_epoch=1000, redirected_tensor_path=None):
|
||||||
metadata = pd.read_csv(metadata_path)
|
if os.path.exists(metadata_path):
|
||||||
self.path = [os.path.join(base_path, "train", file_name) for file_name in metadata["file_name"]]
|
metadata = pd.read_csv(metadata_path)
|
||||||
print(len(self.path), "videos in metadata.")
|
self.path = [os.path.join(base_path, "train", file_name) for file_name in metadata["file_name"]]
|
||||||
self.path = [i + ".tensors.pth" for i in self.path if os.path.exists(i + ".tensors.pth")]
|
print(len(self.path), "videos in metadata.")
|
||||||
|
if redirected_tensor_path is None:
|
||||||
|
self.path = [i + ".tensors.pth" for i in self.path if os.path.exists(i + ".tensors.pth")]
|
||||||
|
else:
|
||||||
|
cached_path = []
|
||||||
|
for path in self.path:
|
||||||
|
path = path.replace("/", "_").replace("\\", "_")
|
||||||
|
path = os.path.join(redirected_tensor_path, path)
|
||||||
|
if os.path.exists(path + ".tensors.pth"):
|
||||||
|
cached_path.append(path + ".tensors.pth")
|
||||||
|
self.path = cached_path
|
||||||
|
else:
|
||||||
|
print("Cannot find metadata.csv. Trying to search for tensor files.")
|
||||||
|
self.path = [os.path.join(base_path, i) for i in os.listdir(base_path) if i.endswith(".tensors.pth")]
|
||||||
print(len(self.path), "tensors cached in metadata.")
|
print(len(self.path), "tensors cached in metadata.")
|
||||||
assert len(self.path) > 0
|
assert len(self.path) > 0
|
||||||
|
|
||||||
self.steps_per_epoch = steps_per_epoch
|
self.steps_per_epoch = steps_per_epoch
|
||||||
|
self.redirected_tensor_path = redirected_tensor_path
|
||||||
|
|
||||||
|
|
||||||
def __getitem__(self, index):
|
def __getitem__(self, index):
|
||||||
data_id = torch.randint(0, len(self.path), (1,))[0]
|
while True:
|
||||||
data_id = (data_id + index) % len(self.path) # For fixed seed.
|
try:
|
||||||
path = self.path[data_id]
|
data_id = torch.randint(0, len(self.path), (1,))[0]
|
||||||
data = torch.load(path, weights_only=True, map_location="cpu")
|
data_id = (data_id + index) % len(self.path) # For fixed seed.
|
||||||
return data
|
path = self.path[data_id]
|
||||||
|
data = torch.load(path, weights_only=True, map_location="cpu")
|
||||||
|
return data
|
||||||
|
except:
|
||||||
|
continue
|
||||||
|
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
@@ -323,6 +362,18 @@ def parse_args():
|
|||||||
default="./",
|
default="./",
|
||||||
help="Path to save the model.",
|
help="Path to save the model.",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--metadata_path",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="Path to metadata.csv.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--redirected_tensor_path",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="Path to save cached tensors.",
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--text_encoder_path",
|
"--text_encoder_path",
|
||||||
type=str,
|
type=str,
|
||||||
@@ -389,6 +440,12 @@ def parse_args():
|
|||||||
default=81,
|
default=81,
|
||||||
help="Number of frames.",
|
help="Number of frames.",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--target_fps",
|
||||||
|
type=int,
|
||||||
|
default=None,
|
||||||
|
help="Expected FPS for sampling frames.",
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--height",
|
"--height",
|
||||||
type=int,
|
type=int,
|
||||||
@@ -500,19 +557,21 @@ def parse_args():
|
|||||||
def data_process(args):
|
def data_process(args):
|
||||||
dataset = TextVideoDataset(
|
dataset = TextVideoDataset(
|
||||||
args.dataset_path,
|
args.dataset_path,
|
||||||
os.path.join(args.dataset_path, "metadata.csv"),
|
os.path.join(args.dataset_path, "metadata.csv") if args.metadata_path is None else args.metadata_path,
|
||||||
max_num_frames=args.num_frames,
|
max_num_frames=args.num_frames,
|
||||||
frame_interval=1,
|
frame_interval=1,
|
||||||
num_frames=args.num_frames,
|
num_frames=args.num_frames,
|
||||||
height=args.height,
|
height=args.height,
|
||||||
width=args.width,
|
width=args.width,
|
||||||
is_i2v=args.image_encoder_path is not None
|
is_i2v=args.image_encoder_path is not None,
|
||||||
|
target_fps=args.target_fps,
|
||||||
)
|
)
|
||||||
dataloader = torch.utils.data.DataLoader(
|
dataloader = torch.utils.data.DataLoader(
|
||||||
dataset,
|
dataset,
|
||||||
shuffle=False,
|
shuffle=False,
|
||||||
batch_size=1,
|
batch_size=1,
|
||||||
num_workers=args.dataloader_num_workers
|
num_workers=args.dataloader_num_workers,
|
||||||
|
collate_fn=lambda x: x,
|
||||||
)
|
)
|
||||||
model = LightningModelForDataProcess(
|
model = LightningModelForDataProcess(
|
||||||
text_encoder_path=args.text_encoder_path,
|
text_encoder_path=args.text_encoder_path,
|
||||||
@@ -521,6 +580,7 @@ def data_process(args):
|
|||||||
tiled=args.tiled,
|
tiled=args.tiled,
|
||||||
tile_size=(args.tile_size_height, args.tile_size_width),
|
tile_size=(args.tile_size_height, args.tile_size_width),
|
||||||
tile_stride=(args.tile_stride_height, args.tile_stride_width),
|
tile_stride=(args.tile_stride_height, args.tile_stride_width),
|
||||||
|
redirected_tensor_path=args.redirected_tensor_path,
|
||||||
)
|
)
|
||||||
trainer = pl.Trainer(
|
trainer = pl.Trainer(
|
||||||
accelerator="gpu",
|
accelerator="gpu",
|
||||||
@@ -533,8 +593,9 @@ def data_process(args):
|
|||||||
def train(args):
|
def train(args):
|
||||||
dataset = TensorDataset(
|
dataset = TensorDataset(
|
||||||
args.dataset_path,
|
args.dataset_path,
|
||||||
os.path.join(args.dataset_path, "metadata.csv"),
|
os.path.join(args.dataset_path, "metadata.csv") if args.metadata_path is None else args.metadata_path,
|
||||||
steps_per_epoch=args.steps_per_epoch,
|
steps_per_epoch=args.steps_per_epoch,
|
||||||
|
redirected_tensor_path=args.redirected_tensor_path,
|
||||||
)
|
)
|
||||||
dataloader = torch.utils.data.DataLoader(
|
dataloader = torch.utils.data.DataLoader(
|
||||||
dataset,
|
dataset,
|
||||||
|
|||||||
626
examples/wanvideo/train_wan_t2v_controlnet.py
Normal file
626
examples/wanvideo/train_wan_t2v_controlnet.py
Normal file
@@ -0,0 +1,626 @@
|
|||||||
|
import torch, os, imageio, argparse
|
||||||
|
from torchvision.transforms import v2
|
||||||
|
from einops import rearrange
|
||||||
|
import lightning as pl
|
||||||
|
import pandas as pd
|
||||||
|
from diffsynth import WanVideoPipeline, ModelManager, load_state_dict
|
||||||
|
from peft import LoraConfig, inject_adapter_in_model
|
||||||
|
import torchvision
|
||||||
|
from PIL import Image
|
||||||
|
import numpy as np
|
||||||
|
from diffsynth.models.wan_video_controlnet import WanControlNetModel
|
||||||
|
from diffsynth.pipelines.wan_video import model_fn_wan_video
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class TextVideoDataset(torch.utils.data.Dataset):
|
||||||
|
def __init__(self, base_path, metadata_path, max_num_frames=81, frame_interval=1, num_frames=81, height=480, width=832, is_i2v=False, target_fps=None):
|
||||||
|
metadata = pd.read_csv(metadata_path)
|
||||||
|
self.path = [os.path.join(base_path, "train", file_name) for file_name in metadata["file_name"]]
|
||||||
|
self.controlnet_path = [os.path.join(base_path, file_name) for file_name in metadata["controlnet_file_name"]]
|
||||||
|
self.text = metadata["text"].to_list()
|
||||||
|
|
||||||
|
self.max_num_frames = max_num_frames
|
||||||
|
self.frame_interval = frame_interval
|
||||||
|
self.num_frames = num_frames
|
||||||
|
self.height = height
|
||||||
|
self.width = width
|
||||||
|
self.is_i2v = is_i2v
|
||||||
|
self.target_fps = target_fps
|
||||||
|
|
||||||
|
self.frame_process = v2.Compose([
|
||||||
|
v2.CenterCrop(size=(height, width)),
|
||||||
|
v2.Resize(size=(height, width), antialias=True),
|
||||||
|
v2.ToTensor(),
|
||||||
|
v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
|
||||||
|
])
|
||||||
|
|
||||||
|
|
||||||
|
def crop_and_resize(self, image):
|
||||||
|
width, height = image.size
|
||||||
|
scale = max(self.width / width, self.height / height)
|
||||||
|
image = torchvision.transforms.functional.resize(
|
||||||
|
image,
|
||||||
|
(round(height*scale), round(width*scale)),
|
||||||
|
interpolation=torchvision.transforms.InterpolationMode.BILINEAR
|
||||||
|
)
|
||||||
|
return image
|
||||||
|
|
||||||
|
|
||||||
|
def load_frames_using_imageio(self, file_path, max_num_frames, start_frame_id, interval, num_frames, frame_process):
|
||||||
|
reader = imageio.get_reader(file_path)
|
||||||
|
if reader.count_frames() < max_num_frames or reader.count_frames() - 1 < start_frame_id + (num_frames - 1) * interval:
|
||||||
|
reader.close()
|
||||||
|
return None
|
||||||
|
|
||||||
|
frames = []
|
||||||
|
first_frame = None
|
||||||
|
for frame_id in range(num_frames):
|
||||||
|
frame = reader.get_data(start_frame_id + frame_id * interval)
|
||||||
|
frame = Image.fromarray(frame)
|
||||||
|
frame = self.crop_and_resize(frame)
|
||||||
|
if first_frame is None:
|
||||||
|
first_frame = np.array(frame)
|
||||||
|
frame = frame_process(frame)
|
||||||
|
frames.append(frame)
|
||||||
|
reader.close()
|
||||||
|
|
||||||
|
frames = torch.stack(frames, dim=0)
|
||||||
|
frames = rearrange(frames, "T C H W -> C T H W")
|
||||||
|
|
||||||
|
if self.is_i2v:
|
||||||
|
return frames, first_frame
|
||||||
|
else:
|
||||||
|
return frames
|
||||||
|
|
||||||
|
|
||||||
|
def load_video(self, file_path):
|
||||||
|
start_frame_id = 0
|
||||||
|
if self.target_fps is None:
|
||||||
|
frame_interval = self.frame_interval
|
||||||
|
else:
|
||||||
|
reader = imageio.get_reader(file_path)
|
||||||
|
fps = reader.get_meta_data()["fps"]
|
||||||
|
reader.close()
|
||||||
|
frame_interval = max(round(fps / self.target_fps), 1)
|
||||||
|
frames = self.load_frames_using_imageio(file_path, self.max_num_frames, start_frame_id, frame_interval, self.num_frames, self.frame_process)
|
||||||
|
return frames
|
||||||
|
|
||||||
|
|
||||||
|
def is_image(self, file_path):
|
||||||
|
file_ext_name = file_path.split(".")[-1]
|
||||||
|
if file_ext_name.lower() in ["jpg", "jpeg", "png", "webp"]:
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def load_image(self, file_path):
|
||||||
|
frame = Image.open(file_path).convert("RGB")
|
||||||
|
frame = self.crop_and_resize(frame)
|
||||||
|
frame = self.frame_process(frame)
|
||||||
|
frame = rearrange(frame, "C H W -> C 1 H W")
|
||||||
|
return frame
|
||||||
|
|
||||||
|
|
||||||
|
def __getitem__(self, data_id):
|
||||||
|
text = self.text[data_id]
|
||||||
|
path = self.path[data_id]
|
||||||
|
controlnet_path = self.controlnet_path[data_id]
|
||||||
|
try:
|
||||||
|
if self.is_image(path):
|
||||||
|
if self.is_i2v:
|
||||||
|
raise ValueError(f"{path} is not a video. I2V model doesn't support image-to-image training.")
|
||||||
|
video = self.load_image(path)
|
||||||
|
else:
|
||||||
|
video = self.load_video(path)
|
||||||
|
controlnet_frames = self.load_video(controlnet_path)
|
||||||
|
if self.is_i2v:
|
||||||
|
video, first_frame = video
|
||||||
|
data = {"text": text, "video": video, "path": path, "first_frame": first_frame}
|
||||||
|
else:
|
||||||
|
data = {"text": text, "video": video, "path": path, "controlnet_frames": controlnet_frames}
|
||||||
|
except:
|
||||||
|
data = None
|
||||||
|
return data
|
||||||
|
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(self.path)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class LightningModelForDataProcess(pl.LightningModule):
|
||||||
|
def __init__(self, text_encoder_path, vae_path, image_encoder_path=None, tiled=False, tile_size=(34, 34), tile_stride=(18, 16), redirected_tensor_path=None):
|
||||||
|
super().__init__()
|
||||||
|
model_path = [text_encoder_path, vae_path]
|
||||||
|
if image_encoder_path is not None:
|
||||||
|
model_path.append(image_encoder_path)
|
||||||
|
model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu")
|
||||||
|
model_manager.load_models(model_path)
|
||||||
|
self.pipe = WanVideoPipeline.from_model_manager(model_manager)
|
||||||
|
|
||||||
|
self.tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride}
|
||||||
|
self.redirected_tensor_path = redirected_tensor_path
|
||||||
|
|
||||||
|
def test_step(self, batch, batch_idx):
|
||||||
|
data = batch[0]
|
||||||
|
if data is None or data["video"] is None:
|
||||||
|
return
|
||||||
|
text, video, path = data["text"], data["video"].unsqueeze(0), data["path"]
|
||||||
|
controlnet_frames = data["controlnet_frames"].unsqueeze(0)
|
||||||
|
|
||||||
|
self.pipe.device = self.device
|
||||||
|
if video is not None:
|
||||||
|
# prompt
|
||||||
|
prompt_emb = self.pipe.encode_prompt(text)
|
||||||
|
# video
|
||||||
|
video = video.to(dtype=self.pipe.torch_dtype, device=self.pipe.device)
|
||||||
|
latents = self.pipe.encode_video(video, **self.tiler_kwargs)[0]
|
||||||
|
# ControlNet video
|
||||||
|
controlnet_frames = controlnet_frames.to(dtype=self.pipe.torch_dtype, device=self.pipe.device)
|
||||||
|
controlnet_kwargs = self.pipe.prepare_controlnet(controlnet_frames, **self.tiler_kwargs)
|
||||||
|
controlnet_kwargs["controlnet_conditioning"] = controlnet_kwargs["controlnet_conditioning"][0]
|
||||||
|
# image
|
||||||
|
if "first_frame" in batch:
|
||||||
|
first_frame = Image.fromarray(batch["first_frame"][0].cpu().numpy())
|
||||||
|
_, _, num_frames, height, width = video.shape
|
||||||
|
image_emb = self.pipe.encode_image(first_frame, num_frames, height, width)
|
||||||
|
else:
|
||||||
|
image_emb = {}
|
||||||
|
data = {"latents": latents, "prompt_emb": prompt_emb, "image_emb": image_emb, "controlnet_kwargs": controlnet_kwargs}
|
||||||
|
if self.redirected_tensor_path is not None:
|
||||||
|
path = path.replace("/", "_").replace("\\", "_")
|
||||||
|
path = os.path.join(self.redirected_tensor_path, path)
|
||||||
|
torch.save(data, path + ".tensors.pth")
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class TensorDataset(torch.utils.data.Dataset):
|
||||||
|
def __init__(self, base_path, metadata_path=None, steps_per_epoch=1000, redirected_tensor_path=None):
|
||||||
|
if os.path.exists(metadata_path):
|
||||||
|
metadata = pd.read_csv(metadata_path)
|
||||||
|
self.path = [os.path.join(base_path, "train", file_name) for file_name in metadata["file_name"]]
|
||||||
|
print(len(self.path), "videos in metadata.")
|
||||||
|
if redirected_tensor_path is None:
|
||||||
|
self.path = [i + ".tensors.pth" for i in self.path if os.path.exists(i + ".tensors.pth")]
|
||||||
|
else:
|
||||||
|
cached_path = []
|
||||||
|
for path in self.path:
|
||||||
|
path = path.replace("/", "_").replace("\\", "_")
|
||||||
|
path = os.path.join(redirected_tensor_path, path)
|
||||||
|
if os.path.exists(path + ".tensors.pth"):
|
||||||
|
cached_path.append(path + ".tensors.pth")
|
||||||
|
self.path = cached_path
|
||||||
|
else:
|
||||||
|
print("Cannot find metadata.csv. Trying to search for tensor files.")
|
||||||
|
self.path = [os.path.join(base_path, i) for i in os.listdir(base_path) if i.endswith(".tensors.pth")]
|
||||||
|
print(len(self.path), "tensors cached in metadata.")
|
||||||
|
assert len(self.path) > 0
|
||||||
|
|
||||||
|
self.steps_per_epoch = steps_per_epoch
|
||||||
|
self.redirected_tensor_path = redirected_tensor_path
|
||||||
|
|
||||||
|
|
||||||
|
def __getitem__(self, index):
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
data_id = torch.randint(0, len(self.path), (1,))[0]
|
||||||
|
data_id = (data_id + index) % len(self.path) # For fixed seed.
|
||||||
|
path = self.path[data_id]
|
||||||
|
data = torch.load(path, weights_only=True, map_location="cpu")
|
||||||
|
return data
|
||||||
|
except:
|
||||||
|
continue
|
||||||
|
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return self.steps_per_epoch
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class LightningModelForTrain(pl.LightningModule):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
dit_path,
|
||||||
|
learning_rate=1e-5,
|
||||||
|
lora_rank=4, lora_alpha=4, train_architecture="lora", lora_target_modules="q,k,v,o,ffn.0,ffn.2", init_lora_weights="kaiming",
|
||||||
|
use_gradient_checkpointing=True, use_gradient_checkpointing_offload=False,
|
||||||
|
pretrained_lora_path=None
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu")
|
||||||
|
if os.path.isfile(dit_path):
|
||||||
|
model_manager.load_models([dit_path])
|
||||||
|
else:
|
||||||
|
dit_path = dit_path.split(",")
|
||||||
|
model_manager.load_models([dit_path])
|
||||||
|
|
||||||
|
self.pipe = WanVideoPipeline.from_model_manager(model_manager)
|
||||||
|
self.pipe.scheduler.set_timesteps(1000, training=True)
|
||||||
|
self.freeze_parameters()
|
||||||
|
|
||||||
|
state_dict = load_state_dict(dit_path, torch_dtype=torch.bfloat16)
|
||||||
|
state_dict, config = WanControlNetModel.state_dict_converter().from_base_model(state_dict)
|
||||||
|
self.pipe.controlnet = WanControlNetModel(**config).to(torch.bfloat16)
|
||||||
|
self.pipe.controlnet.load_state_dict(state_dict)
|
||||||
|
self.pipe.controlnet.train()
|
||||||
|
self.pipe.controlnet.requires_grad_(True)
|
||||||
|
|
||||||
|
self.learning_rate = learning_rate
|
||||||
|
self.use_gradient_checkpointing = use_gradient_checkpointing
|
||||||
|
self.use_gradient_checkpointing_offload = use_gradient_checkpointing_offload
|
||||||
|
|
||||||
|
|
||||||
|
def freeze_parameters(self):
|
||||||
|
# Freeze parameters
|
||||||
|
self.pipe.requires_grad_(False)
|
||||||
|
self.pipe.eval()
|
||||||
|
self.pipe.denoising_model().train()
|
||||||
|
|
||||||
|
|
||||||
|
def training_step(self, batch, batch_idx):
|
||||||
|
# Data
|
||||||
|
latents = batch["latents"].to(self.device)
|
||||||
|
controlnet_kwargs = batch["controlnet_kwargs"]
|
||||||
|
controlnet_kwargs["controlnet_conditioning"] = controlnet_kwargs["controlnet_conditioning"].to(self.device)
|
||||||
|
prompt_emb = batch["prompt_emb"]
|
||||||
|
prompt_emb["context"] = prompt_emb["context"][0].to(self.device)
|
||||||
|
image_emb = batch["image_emb"]
|
||||||
|
if "clip_feature" in image_emb:
|
||||||
|
image_emb["clip_feature"] = image_emb["clip_feature"][0].to(self.device)
|
||||||
|
if "y" in image_emb:
|
||||||
|
image_emb["y"] = image_emb["y"][0].to(self.device)
|
||||||
|
|
||||||
|
# Loss
|
||||||
|
self.pipe.device = self.device
|
||||||
|
noise = torch.randn_like(latents)
|
||||||
|
timestep_id = torch.randint(0, self.pipe.scheduler.num_train_timesteps, (1,))
|
||||||
|
timestep = self.pipe.scheduler.timesteps[timestep_id].to(dtype=self.pipe.torch_dtype, device=self.pipe.device)
|
||||||
|
extra_input = self.pipe.prepare_extra_input(latents)
|
||||||
|
noisy_latents = self.pipe.scheduler.add_noise(latents, noise, timestep)
|
||||||
|
training_target = self.pipe.scheduler.training_target(latents, noise, timestep)
|
||||||
|
|
||||||
|
# Compute loss
|
||||||
|
noise_pred = model_fn_wan_video(
|
||||||
|
dit=self.pipe.dit, controlnet=self.pipe.controlnet,
|
||||||
|
x=noisy_latents, timestep=timestep, **prompt_emb, **extra_input, **image_emb, **controlnet_kwargs,
|
||||||
|
use_gradient_checkpointing=self.use_gradient_checkpointing,
|
||||||
|
use_gradient_checkpointing_offload=self.use_gradient_checkpointing_offload
|
||||||
|
)
|
||||||
|
loss = torch.nn.functional.mse_loss(noise_pred.float(), training_target.float())
|
||||||
|
loss = loss * self.pipe.scheduler.training_weight(timestep)
|
||||||
|
|
||||||
|
# Record log
|
||||||
|
self.log("train_loss", loss, prog_bar=True)
|
||||||
|
return loss
|
||||||
|
|
||||||
|
|
||||||
|
def configure_optimizers(self):
|
||||||
|
trainable_modules = filter(lambda p: p.requires_grad, self.pipe.controlnet.parameters())
|
||||||
|
optimizer = torch.optim.AdamW(trainable_modules, lr=self.learning_rate)
|
||||||
|
return optimizer
|
||||||
|
|
||||||
|
|
||||||
|
def on_save_checkpoint(self, checkpoint):
|
||||||
|
checkpoint.clear()
|
||||||
|
trainable_param_names = list(filter(lambda named_param: named_param[1].requires_grad, self.pipe.controlnet.named_parameters()))
|
||||||
|
trainable_param_names = set([named_param[0] for named_param in trainable_param_names])
|
||||||
|
state_dict = self.pipe.controlnet.state_dict()
|
||||||
|
lora_state_dict = {}
|
||||||
|
for name, param in state_dict.items():
|
||||||
|
if name in trainable_param_names:
|
||||||
|
lora_state_dict[name] = param
|
||||||
|
checkpoint.update(lora_state_dict)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def parse_args():
|
||||||
|
parser = argparse.ArgumentParser(description="Simple example of a training script.")
|
||||||
|
parser.add_argument(
|
||||||
|
"--task",
|
||||||
|
type=str,
|
||||||
|
default="data_process",
|
||||||
|
required=True,
|
||||||
|
choices=["data_process", "train"],
|
||||||
|
help="Task. `data_process` or `train`.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--dataset_path",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
required=True,
|
||||||
|
help="The path of the Dataset.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--output_path",
|
||||||
|
type=str,
|
||||||
|
default="./",
|
||||||
|
help="Path to save the model.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--metadata_path",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="Path to metadata.csv.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--redirected_tensor_path",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="Path to save cached tensors.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--text_encoder_path",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="Path of text encoder.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--image_encoder_path",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="Path of image encoder.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--vae_path",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="Path of VAE.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--dit_path",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="Path of DiT.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--tiled",
|
||||||
|
default=False,
|
||||||
|
action="store_true",
|
||||||
|
help="Whether enable tile encode in VAE. This option can reduce VRAM required.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--tile_size_height",
|
||||||
|
type=int,
|
||||||
|
default=34,
|
||||||
|
help="Tile size (height) in VAE.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--tile_size_width",
|
||||||
|
type=int,
|
||||||
|
default=34,
|
||||||
|
help="Tile size (width) in VAE.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--tile_stride_height",
|
||||||
|
type=int,
|
||||||
|
default=18,
|
||||||
|
help="Tile stride (height) in VAE.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--tile_stride_width",
|
||||||
|
type=int,
|
||||||
|
default=16,
|
||||||
|
help="Tile stride (width) in VAE.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--steps_per_epoch",
|
||||||
|
type=int,
|
||||||
|
default=500,
|
||||||
|
help="Number of steps per epoch.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--num_frames",
|
||||||
|
type=int,
|
||||||
|
default=81,
|
||||||
|
help="Number of frames.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--target_fps",
|
||||||
|
type=int,
|
||||||
|
default=None,
|
||||||
|
help="Expected FPS for sampling frames.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--height",
|
||||||
|
type=int,
|
||||||
|
default=480,
|
||||||
|
help="Image height.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--width",
|
||||||
|
type=int,
|
||||||
|
default=832,
|
||||||
|
help="Image width.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--dataloader_num_workers",
|
||||||
|
type=int,
|
||||||
|
default=1,
|
||||||
|
help="Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--learning_rate",
|
||||||
|
type=float,
|
||||||
|
default=1e-5,
|
||||||
|
help="Learning rate.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--accumulate_grad_batches",
|
||||||
|
type=int,
|
||||||
|
default=1,
|
||||||
|
help="The number of batches in gradient accumulation.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--max_epochs",
|
||||||
|
type=int,
|
||||||
|
default=1,
|
||||||
|
help="Number of epochs.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--lora_target_modules",
|
||||||
|
type=str,
|
||||||
|
default="q,k,v,o,ffn.0,ffn.2",
|
||||||
|
help="Layers with LoRA modules.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--init_lora_weights",
|
||||||
|
type=str,
|
||||||
|
default="kaiming",
|
||||||
|
choices=["gaussian", "kaiming"],
|
||||||
|
help="The initializing method of LoRA weight.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--training_strategy",
|
||||||
|
type=str,
|
||||||
|
default="auto",
|
||||||
|
choices=["auto", "deepspeed_stage_1", "deepspeed_stage_2", "deepspeed_stage_3"],
|
||||||
|
help="Training strategy",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--lora_rank",
|
||||||
|
type=int,
|
||||||
|
default=4,
|
||||||
|
help="The dimension of the LoRA update matrices.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--lora_alpha",
|
||||||
|
type=float,
|
||||||
|
default=4.0,
|
||||||
|
help="The weight of the LoRA update matrices.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--use_gradient_checkpointing",
|
||||||
|
default=False,
|
||||||
|
action="store_true",
|
||||||
|
help="Whether to use gradient checkpointing.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--use_gradient_checkpointing_offload",
|
||||||
|
default=False,
|
||||||
|
action="store_true",
|
||||||
|
help="Whether to use gradient checkpointing offload.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--train_architecture",
|
||||||
|
type=str,
|
||||||
|
default="lora",
|
||||||
|
choices=["lora", "full"],
|
||||||
|
help="Model structure to train. LoRA training or full training.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--pretrained_lora_path",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="Pretrained LoRA path. Required if the training is resumed.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--use_swanlab",
|
||||||
|
default=False,
|
||||||
|
action="store_true",
|
||||||
|
help="Whether to use SwanLab logger.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--swanlab_mode",
|
||||||
|
default=None,
|
||||||
|
help="SwanLab mode (cloud or local).",
|
||||||
|
)
|
||||||
|
args = parser.parse_args()
|
||||||
|
return args
|
||||||
|
|
||||||
|
|
||||||
|
def data_process(args):
|
||||||
|
dataset = TextVideoDataset(
|
||||||
|
args.dataset_path,
|
||||||
|
os.path.join(args.dataset_path, "metadata.csv") if args.metadata_path is None else args.metadata_path,
|
||||||
|
max_num_frames=args.num_frames,
|
||||||
|
frame_interval=1,
|
||||||
|
num_frames=args.num_frames,
|
||||||
|
height=args.height,
|
||||||
|
width=args.width,
|
||||||
|
is_i2v=args.image_encoder_path is not None,
|
||||||
|
target_fps=args.target_fps,
|
||||||
|
)
|
||||||
|
dataloader = torch.utils.data.DataLoader(
|
||||||
|
dataset,
|
||||||
|
shuffle=False,
|
||||||
|
batch_size=1,
|
||||||
|
num_workers=args.dataloader_num_workers,
|
||||||
|
collate_fn=lambda x: x,
|
||||||
|
)
|
||||||
|
model = LightningModelForDataProcess(
|
||||||
|
text_encoder_path=args.text_encoder_path,
|
||||||
|
image_encoder_path=args.image_encoder_path,
|
||||||
|
vae_path=args.vae_path,
|
||||||
|
tiled=args.tiled,
|
||||||
|
tile_size=(args.tile_size_height, args.tile_size_width),
|
||||||
|
tile_stride=(args.tile_stride_height, args.tile_stride_width),
|
||||||
|
redirected_tensor_path=args.redirected_tensor_path,
|
||||||
|
)
|
||||||
|
trainer = pl.Trainer(
|
||||||
|
accelerator="gpu",
|
||||||
|
devices="auto",
|
||||||
|
default_root_dir=args.output_path,
|
||||||
|
)
|
||||||
|
trainer.test(model, dataloader)
|
||||||
|
|
||||||
|
|
||||||
|
def train(args):
|
||||||
|
dataset = TensorDataset(
|
||||||
|
args.dataset_path,
|
||||||
|
os.path.join(args.dataset_path, "metadata.csv") if args.metadata_path is None else args.metadata_path,
|
||||||
|
steps_per_epoch=args.steps_per_epoch,
|
||||||
|
redirected_tensor_path=args.redirected_tensor_path,
|
||||||
|
)
|
||||||
|
dataloader = torch.utils.data.DataLoader(
|
||||||
|
dataset,
|
||||||
|
shuffle=True,
|
||||||
|
batch_size=1,
|
||||||
|
num_workers=args.dataloader_num_workers
|
||||||
|
)
|
||||||
|
model = LightningModelForTrain(
|
||||||
|
dit_path=args.dit_path,
|
||||||
|
learning_rate=args.learning_rate,
|
||||||
|
train_architecture=args.train_architecture,
|
||||||
|
lora_rank=args.lora_rank,
|
||||||
|
lora_alpha=args.lora_alpha,
|
||||||
|
lora_target_modules=args.lora_target_modules,
|
||||||
|
init_lora_weights=args.init_lora_weights,
|
||||||
|
use_gradient_checkpointing=args.use_gradient_checkpointing,
|
||||||
|
use_gradient_checkpointing_offload=args.use_gradient_checkpointing_offload,
|
||||||
|
pretrained_lora_path=args.pretrained_lora_path,
|
||||||
|
)
|
||||||
|
if args.use_swanlab:
|
||||||
|
from swanlab.integration.pytorch_lightning import SwanLabLogger
|
||||||
|
swanlab_config = {"UPPERFRAMEWORK": "DiffSynth-Studio"}
|
||||||
|
swanlab_config.update(vars(args))
|
||||||
|
swanlab_logger = SwanLabLogger(
|
||||||
|
project="wan",
|
||||||
|
name="wan",
|
||||||
|
config=swanlab_config,
|
||||||
|
mode=args.swanlab_mode,
|
||||||
|
logdir=os.path.join(args.output_path, "swanlog"),
|
||||||
|
)
|
||||||
|
logger = [swanlab_logger]
|
||||||
|
else:
|
||||||
|
logger = None
|
||||||
|
trainer = pl.Trainer(
|
||||||
|
max_epochs=args.max_epochs,
|
||||||
|
accelerator="gpu",
|
||||||
|
devices="auto",
|
||||||
|
precision="bf16",
|
||||||
|
strategy=args.training_strategy,
|
||||||
|
default_root_dir=args.output_path,
|
||||||
|
accumulate_grad_batches=args.accumulate_grad_batches,
|
||||||
|
callbacks=[pl.pytorch.callbacks.ModelCheckpoint(save_top_k=-1)],
|
||||||
|
logger=logger,
|
||||||
|
)
|
||||||
|
trainer.fit(model, dataloader)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
args = parse_args()
|
||||||
|
if args.task == "data_process":
|
||||||
|
data_process(args)
|
||||||
|
elif args.task == "train":
|
||||||
|
train(args)
|
||||||
691
examples/wanvideo/train_wan_t2v_motion.py
Normal file
691
examples/wanvideo/train_wan_t2v_motion.py
Normal file
@@ -0,0 +1,691 @@
|
|||||||
|
import torch, os, imageio, argparse
|
||||||
|
from torchvision.transforms import v2
|
||||||
|
from einops import rearrange
|
||||||
|
import lightning as pl
|
||||||
|
import pandas as pd
|
||||||
|
from diffsynth import WanVideoPipeline, ModelManager, load_state_dict
|
||||||
|
from diffsynth.models.wan_video_motion_controller import WanMotionControllerModel
|
||||||
|
from diffsynth.pipelines.wan_video import model_fn_wan_video
|
||||||
|
from peft import LoraConfig, inject_adapter_in_model
|
||||||
|
import torchvision
|
||||||
|
from PIL import Image
|
||||||
|
import numpy as np
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class TextVideoDataset(torch.utils.data.Dataset):
|
||||||
|
def __init__(self, base_path, metadata_path, max_num_frames=81, frame_interval=1, num_frames=81, height=480, width=832, is_i2v=False, target_fps=None):
|
||||||
|
metadata = pd.read_csv(metadata_path)
|
||||||
|
self.path = [os.path.join(base_path, file_name) for file_name in metadata["file_name"]]
|
||||||
|
self.text = metadata["text"].to_list()
|
||||||
|
|
||||||
|
self.max_num_frames = max_num_frames
|
||||||
|
self.frame_interval = frame_interval
|
||||||
|
self.num_frames = num_frames
|
||||||
|
self.height = height
|
||||||
|
self.width = width
|
||||||
|
self.is_i2v = is_i2v
|
||||||
|
self.target_fps = target_fps
|
||||||
|
|
||||||
|
self.frame_process = v2.Compose([
|
||||||
|
v2.CenterCrop(size=(height, width)),
|
||||||
|
v2.Resize(size=(height, width), antialias=True),
|
||||||
|
v2.ToTensor(),
|
||||||
|
v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
|
||||||
|
])
|
||||||
|
|
||||||
|
|
||||||
|
def crop_and_resize(self, image):
|
||||||
|
width, height = image.size
|
||||||
|
scale = max(self.width / width, self.height / height)
|
||||||
|
image = torchvision.transforms.functional.resize(
|
||||||
|
image,
|
||||||
|
(round(height*scale), round(width*scale)),
|
||||||
|
interpolation=torchvision.transforms.InterpolationMode.BILINEAR
|
||||||
|
)
|
||||||
|
return image
|
||||||
|
|
||||||
|
|
||||||
|
def load_frames_using_imageio(self, file_path, max_num_frames, start_frame_id, interval, num_frames, frame_process):
|
||||||
|
reader = imageio.get_reader(file_path)
|
||||||
|
if reader.count_frames() < max_num_frames or reader.count_frames() - 1 < start_frame_id + (num_frames - 1) * interval:
|
||||||
|
reader.close()
|
||||||
|
return None
|
||||||
|
|
||||||
|
frames = []
|
||||||
|
first_frame = None
|
||||||
|
for frame_id in range(num_frames):
|
||||||
|
frame = reader.get_data(start_frame_id + frame_id * interval)
|
||||||
|
frame = Image.fromarray(frame)
|
||||||
|
frame = self.crop_and_resize(frame)
|
||||||
|
if first_frame is None:
|
||||||
|
first_frame = np.array(frame)
|
||||||
|
frame = frame_process(frame)
|
||||||
|
frames.append(frame)
|
||||||
|
reader.close()
|
||||||
|
|
||||||
|
frames = torch.stack(frames, dim=0)
|
||||||
|
frames = rearrange(frames, "T C H W -> C T H W")
|
||||||
|
|
||||||
|
if self.is_i2v:
|
||||||
|
return frames, first_frame
|
||||||
|
else:
|
||||||
|
return frames
|
||||||
|
|
||||||
|
|
||||||
|
def load_video(self, file_path):
|
||||||
|
start_frame_id = 0
|
||||||
|
if self.target_fps is None:
|
||||||
|
frame_interval = self.frame_interval
|
||||||
|
else:
|
||||||
|
reader = imageio.get_reader(file_path)
|
||||||
|
fps = reader.get_meta_data()["fps"]
|
||||||
|
reader.close()
|
||||||
|
frame_interval = max(round(fps / self.target_fps), 1)
|
||||||
|
frames = self.load_frames_using_imageio(file_path, self.max_num_frames, start_frame_id, frame_interval, self.num_frames, self.frame_process)
|
||||||
|
return frames
|
||||||
|
|
||||||
|
|
||||||
|
def is_image(self, file_path):
|
||||||
|
file_ext_name = file_path.split(".")[-1]
|
||||||
|
if file_ext_name.lower() in ["jpg", "jpeg", "png", "webp"]:
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def load_image(self, file_path):
|
||||||
|
frame = Image.open(file_path).convert("RGB")
|
||||||
|
frame = self.crop_and_resize(frame)
|
||||||
|
first_frame = frame
|
||||||
|
frame = self.frame_process(frame)
|
||||||
|
frame = rearrange(frame, "C H W -> C 1 H W")
|
||||||
|
return frame
|
||||||
|
|
||||||
|
|
||||||
|
def __getitem__(self, data_id):
|
||||||
|
text = self.text[data_id]
|
||||||
|
path = self.path[data_id]
|
||||||
|
try:
|
||||||
|
if self.is_image(path):
|
||||||
|
if self.is_i2v:
|
||||||
|
raise ValueError(f"{path} is not a video. I2V model doesn't support image-to-image training.")
|
||||||
|
video = self.load_image(path)
|
||||||
|
else:
|
||||||
|
video = self.load_video(path)
|
||||||
|
if self.is_i2v:
|
||||||
|
video, first_frame = video
|
||||||
|
data = {"text": text, "video": video, "path": path, "first_frame": first_frame}
|
||||||
|
else:
|
||||||
|
data = {"text": text, "video": video, "path": path}
|
||||||
|
except:
|
||||||
|
data = None
|
||||||
|
return data
|
||||||
|
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(self.path)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class LightningModelForDataProcess(pl.LightningModule):
|
||||||
|
def __init__(self, text_encoder_path, vae_path, image_encoder_path=None, tiled=False, tile_size=(34, 34), tile_stride=(18, 16), redirected_tensor_path=None):
|
||||||
|
super().__init__()
|
||||||
|
model_path = [text_encoder_path, vae_path]
|
||||||
|
if image_encoder_path is not None:
|
||||||
|
model_path.append(image_encoder_path)
|
||||||
|
model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu")
|
||||||
|
model_manager.load_models(model_path)
|
||||||
|
self.pipe = WanVideoPipeline.from_model_manager(model_manager)
|
||||||
|
|
||||||
|
self.tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride}
|
||||||
|
self.redirected_tensor_path = redirected_tensor_path
|
||||||
|
|
||||||
|
def test_step(self, batch, batch_idx):
|
||||||
|
data = batch[0]
|
||||||
|
if data is None or data["video"] is None:
|
||||||
|
return
|
||||||
|
text, video, path = data["text"], data["video"].unsqueeze(0), data["path"]
|
||||||
|
|
||||||
|
self.pipe.device = self.device
|
||||||
|
if video is not None:
|
||||||
|
# prompt
|
||||||
|
prompt_emb = self.pipe.encode_prompt(text)
|
||||||
|
# video
|
||||||
|
video = video.to(dtype=self.pipe.torch_dtype, device=self.pipe.device)
|
||||||
|
latents = self.pipe.encode_video(video, **self.tiler_kwargs)[0]
|
||||||
|
# image
|
||||||
|
if "first_frame" in batch:
|
||||||
|
first_frame = Image.fromarray(batch["first_frame"][0].cpu().numpy())
|
||||||
|
_, _, num_frames, height, width = video.shape
|
||||||
|
image_emb = self.pipe.encode_image(first_frame, num_frames, height, width)
|
||||||
|
else:
|
||||||
|
image_emb = {}
|
||||||
|
data = {"latents": latents, "prompt_emb": prompt_emb, "image_emb": image_emb}
|
||||||
|
if self.redirected_tensor_path is not None:
|
||||||
|
path = path.replace("/", "_").replace("\\", "_")
|
||||||
|
path = os.path.join(self.redirected_tensor_path, path)
|
||||||
|
torch.save(data, path + ".tensors.pth")
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class TensorDataset(torch.utils.data.Dataset):
|
||||||
|
def __init__(self, base_path, metadata_path=None, steps_per_epoch=1000, redirected_tensor_path=None):
|
||||||
|
if os.path.exists(metadata_path):
|
||||||
|
metadata = pd.read_csv(metadata_path)
|
||||||
|
self.path = [os.path.join(base_path, "train", file_name) for file_name in metadata["file_name"]]
|
||||||
|
print(len(self.path), "videos in metadata.")
|
||||||
|
if redirected_tensor_path is None:
|
||||||
|
self.path = [i + ".tensors.pth" for i in self.path if os.path.exists(i + ".tensors.pth")]
|
||||||
|
else:
|
||||||
|
cached_path = []
|
||||||
|
for path in self.path:
|
||||||
|
path = path.replace("/", "_").replace("\\", "_")
|
||||||
|
path = os.path.join(redirected_tensor_path, path)
|
||||||
|
if os.path.exists(path + ".tensors.pth"):
|
||||||
|
cached_path.append(path + ".tensors.pth")
|
||||||
|
self.path = cached_path
|
||||||
|
else:
|
||||||
|
print("Cannot find metadata.csv. Trying to search for tensor files.")
|
||||||
|
self.path = [os.path.join(base_path, i) for i in os.listdir(base_path) if i.endswith(".tensors.pth")]
|
||||||
|
print(len(self.path), "tensors cached in metadata.")
|
||||||
|
assert len(self.path) > 0
|
||||||
|
|
||||||
|
self.steps_per_epoch = steps_per_epoch
|
||||||
|
self.redirected_tensor_path = redirected_tensor_path
|
||||||
|
|
||||||
|
|
||||||
|
def __getitem__(self, index):
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
data_id = torch.randint(0, len(self.path), (1,))[0]
|
||||||
|
data_id = (data_id + index) % len(self.path) # For fixed seed.
|
||||||
|
path = self.path[data_id]
|
||||||
|
data = torch.load(path, weights_only=True, map_location="cpu")
|
||||||
|
return data
|
||||||
|
except:
|
||||||
|
continue
|
||||||
|
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return self.steps_per_epoch
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class LightningModelForTrain(pl.LightningModule):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
dit_path,
|
||||||
|
learning_rate=1e-5,
|
||||||
|
lora_rank=4, lora_alpha=4, train_architecture="lora", lora_target_modules="q,k,v,o,ffn.0,ffn.2", init_lora_weights="kaiming",
|
||||||
|
use_gradient_checkpointing=True, use_gradient_checkpointing_offload=False,
|
||||||
|
pretrained_lora_path=None
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu")
|
||||||
|
if os.path.isfile(dit_path):
|
||||||
|
model_manager.load_models([dit_path])
|
||||||
|
else:
|
||||||
|
dit_path = dit_path.split(",")
|
||||||
|
model_manager.load_models([dit_path])
|
||||||
|
|
||||||
|
self.pipe = WanVideoPipeline.from_model_manager(model_manager)
|
||||||
|
self.pipe.scheduler.set_timesteps(1000, training=True)
|
||||||
|
self.freeze_parameters()
|
||||||
|
|
||||||
|
self.pipe.motion_controller = WanMotionControllerModel().to(torch.bfloat16)
|
||||||
|
self.pipe.motion_controller.init()
|
||||||
|
self.pipe.motion_controller.requires_grad_(True)
|
||||||
|
self.pipe.motion_controller.train()
|
||||||
|
self.motion_bucket_manager = MotionBucketManager()
|
||||||
|
|
||||||
|
self.learning_rate = learning_rate
|
||||||
|
self.use_gradient_checkpointing = use_gradient_checkpointing
|
||||||
|
self.use_gradient_checkpointing_offload = use_gradient_checkpointing_offload
|
||||||
|
|
||||||
|
|
||||||
|
def freeze_parameters(self):
|
||||||
|
# Freeze parameters
|
||||||
|
self.pipe.requires_grad_(False)
|
||||||
|
self.pipe.eval()
|
||||||
|
self.pipe.dit.train()
|
||||||
|
|
||||||
|
|
||||||
|
def add_lora_to_model(self, model, lora_rank=4, lora_alpha=4, lora_target_modules="q,k,v,o,ffn.0,ffn.2", init_lora_weights="kaiming", pretrained_lora_path=None, state_dict_converter=None):
|
||||||
|
# Add LoRA to UNet
|
||||||
|
self.lora_alpha = lora_alpha
|
||||||
|
if init_lora_weights == "kaiming":
|
||||||
|
init_lora_weights = True
|
||||||
|
|
||||||
|
lora_config = LoraConfig(
|
||||||
|
r=lora_rank,
|
||||||
|
lora_alpha=lora_alpha,
|
||||||
|
init_lora_weights=init_lora_weights,
|
||||||
|
target_modules=lora_target_modules.split(","),
|
||||||
|
)
|
||||||
|
model = inject_adapter_in_model(lora_config, model)
|
||||||
|
for param in model.parameters():
|
||||||
|
# Upcast LoRA parameters into fp32
|
||||||
|
if param.requires_grad:
|
||||||
|
param.data = param.to(torch.float32)
|
||||||
|
|
||||||
|
# Lora pretrained lora weights
|
||||||
|
if pretrained_lora_path is not None:
|
||||||
|
state_dict = load_state_dict(pretrained_lora_path)
|
||||||
|
if state_dict_converter is not None:
|
||||||
|
state_dict = state_dict_converter(state_dict)
|
||||||
|
missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
|
||||||
|
all_keys = [i for i, _ in model.named_parameters()]
|
||||||
|
num_updated_keys = len(all_keys) - len(missing_keys)
|
||||||
|
num_unexpected_keys = len(unexpected_keys)
|
||||||
|
print(f"{num_updated_keys} parameters are loaded from {pretrained_lora_path}. {num_unexpected_keys} parameters are unexpected.")
|
||||||
|
|
||||||
|
|
||||||
|
def training_step(self, batch, batch_idx):
|
||||||
|
# Data
|
||||||
|
latents = batch["latents"].to(self.device)
|
||||||
|
prompt_emb = batch["prompt_emb"]
|
||||||
|
prompt_emb["context"] = prompt_emb["context"][0].to(self.device)
|
||||||
|
image_emb = batch["image_emb"]
|
||||||
|
if "clip_feature" in image_emb:
|
||||||
|
image_emb["clip_feature"] = image_emb["clip_feature"][0].to(self.device)
|
||||||
|
if "y" in image_emb:
|
||||||
|
image_emb["y"] = image_emb["y"][0].to(self.device)
|
||||||
|
|
||||||
|
# Loss
|
||||||
|
self.pipe.device = self.device
|
||||||
|
noise = torch.randn_like(latents)
|
||||||
|
timestep_id = torch.randint(0, self.pipe.scheduler.num_train_timesteps, (1,))
|
||||||
|
timestep = self.pipe.scheduler.timesteps[timestep_id].to(dtype=self.pipe.torch_dtype, device=self.pipe.device)
|
||||||
|
extra_input = self.pipe.prepare_extra_input(latents)
|
||||||
|
noisy_latents = self.pipe.scheduler.add_noise(latents, noise, timestep)
|
||||||
|
training_target = self.pipe.scheduler.training_target(latents, noise, timestep)
|
||||||
|
motion_bucket_id = self.motion_bucket_manager(latents)
|
||||||
|
motion_bucket_kwargs = self.pipe.prepare_motion_bucket_id(motion_bucket_id)
|
||||||
|
|
||||||
|
# Compute loss
|
||||||
|
noise_pred = model_fn_wan_video(
|
||||||
|
dit=self.pipe.dit, motion_controller=self.pipe.motion_controller,
|
||||||
|
x=noisy_latents, timestep=timestep, **prompt_emb, **extra_input, **image_emb, **motion_bucket_kwargs,
|
||||||
|
use_gradient_checkpointing=self.use_gradient_checkpointing,
|
||||||
|
use_gradient_checkpointing_offload=self.use_gradient_checkpointing_offload
|
||||||
|
)
|
||||||
|
loss = torch.nn.functional.mse_loss(noise_pred.float(), training_target.float())
|
||||||
|
loss = loss * self.pipe.scheduler.training_weight(timestep)
|
||||||
|
|
||||||
|
# Record log
|
||||||
|
self.log("train_loss", loss, prog_bar=True)
|
||||||
|
return loss
|
||||||
|
|
||||||
|
|
||||||
|
def configure_optimizers(self):
|
||||||
|
trainable_modules = filter(lambda p: p.requires_grad, self.pipe.motion_controller.parameters())
|
||||||
|
optimizer = torch.optim.AdamW(trainable_modules, lr=self.learning_rate)
|
||||||
|
return optimizer
|
||||||
|
|
||||||
|
|
||||||
|
def on_save_checkpoint(self, checkpoint):
|
||||||
|
checkpoint.clear()
|
||||||
|
trainable_param_names = list(filter(lambda named_param: named_param[1].requires_grad, self.pipe.motion_controller.named_parameters()))
|
||||||
|
trainable_param_names = set([named_param[0] for named_param in trainable_param_names])
|
||||||
|
state_dict = self.pipe.motion_controller.state_dict()
|
||||||
|
lora_state_dict = {}
|
||||||
|
for name, param in state_dict.items():
|
||||||
|
if name in trainable_param_names:
|
||||||
|
lora_state_dict[name] = param
|
||||||
|
checkpoint.update(lora_state_dict)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class MotionBucketManager:
|
||||||
|
def __init__(self):
|
||||||
|
self.thresholds = [
|
||||||
|
0.093750000, 0.094726562, 0.100585938, 0.100585938, 0.108886719, 0.109375000, 0.118652344, 0.127929688, 0.127929688, 0.130859375,
|
||||||
|
0.133789062, 0.137695312, 0.138671875, 0.138671875, 0.139648438, 0.143554688, 0.143554688, 0.147460938, 0.149414062, 0.149414062,
|
||||||
|
0.152343750, 0.153320312, 0.154296875, 0.154296875, 0.157226562, 0.163085938, 0.163085938, 0.164062500, 0.165039062, 0.166992188,
|
||||||
|
0.173828125, 0.179687500, 0.180664062, 0.184570312, 0.187500000, 0.188476562, 0.188476562, 0.189453125, 0.189453125, 0.202148438,
|
||||||
|
0.206054688, 0.210937500, 0.210937500, 0.211914062, 0.214843750, 0.214843750, 0.216796875, 0.216796875, 0.216796875, 0.218750000,
|
||||||
|
0.218750000, 0.221679688, 0.222656250, 0.227539062, 0.229492188, 0.230468750, 0.236328125, 0.243164062, 0.243164062, 0.245117188,
|
||||||
|
0.253906250, 0.253906250, 0.255859375, 0.259765625, 0.275390625, 0.275390625, 0.277343750, 0.279296875, 0.279296875, 0.279296875,
|
||||||
|
0.292968750, 0.292968750, 0.302734375, 0.306640625, 0.312500000, 0.312500000, 0.326171875, 0.330078125, 0.332031250, 0.332031250,
|
||||||
|
0.337890625, 0.343750000, 0.343750000, 0.351562500, 0.355468750, 0.357421875, 0.361328125, 0.367187500, 0.382812500, 0.388671875,
|
||||||
|
0.392578125, 0.392578125, 0.392578125, 0.404296875, 0.404296875, 0.425781250, 0.433593750, 0.507812500, 0.519531250, 0.539062500,
|
||||||
|
]
|
||||||
|
|
||||||
|
def get_motion_score(self, frames):
|
||||||
|
score = frames[:, :, 1:, :, :].std(dim=2).mean().tolist()
|
||||||
|
return score
|
||||||
|
|
||||||
|
def get_bucket_id(self, motion_score):
|
||||||
|
for bucket_id in range(len(self.thresholds) - 1):
|
||||||
|
if self.thresholds[bucket_id + 1] > motion_score:
|
||||||
|
return bucket_id
|
||||||
|
return len(self.thresholds)
|
||||||
|
|
||||||
|
def __call__(self, frames):
|
||||||
|
score = self.get_motion_score(frames)
|
||||||
|
bucket_id = self.get_bucket_id(score)
|
||||||
|
return bucket_id
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def parse_args():
|
||||||
|
parser = argparse.ArgumentParser(description="Simple example of a training script.")
|
||||||
|
parser.add_argument(
|
||||||
|
"--task",
|
||||||
|
type=str,
|
||||||
|
default="data_process",
|
||||||
|
required=True,
|
||||||
|
choices=["data_process", "train"],
|
||||||
|
help="Task. `data_process` or `train`.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--dataset_path",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
required=True,
|
||||||
|
help="The path of the Dataset.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--output_path",
|
||||||
|
type=str,
|
||||||
|
default="./",
|
||||||
|
help="Path to save the model.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--metadata_path",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="Path to metadata.csv.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--redirected_tensor_path",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="Path to save cached tensors.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--text_encoder_path",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="Path of text encoder.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--image_encoder_path",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="Path of image encoder.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--vae_path",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="Path of VAE.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--dit_path",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="Path of DiT.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--tiled",
|
||||||
|
default=False,
|
||||||
|
action="store_true",
|
||||||
|
help="Whether enable tile encode in VAE. This option can reduce VRAM required.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--tile_size_height",
|
||||||
|
type=int,
|
||||||
|
default=34,
|
||||||
|
help="Tile size (height) in VAE.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--tile_size_width",
|
||||||
|
type=int,
|
||||||
|
default=34,
|
||||||
|
help="Tile size (width) in VAE.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--tile_stride_height",
|
||||||
|
type=int,
|
||||||
|
default=18,
|
||||||
|
help="Tile stride (height) in VAE.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--tile_stride_width",
|
||||||
|
type=int,
|
||||||
|
default=16,
|
||||||
|
help="Tile stride (width) in VAE.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--steps_per_epoch",
|
||||||
|
type=int,
|
||||||
|
default=500,
|
||||||
|
help="Number of steps per epoch.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--num_frames",
|
||||||
|
type=int,
|
||||||
|
default=81,
|
||||||
|
help="Number of frames.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--target_fps",
|
||||||
|
type=int,
|
||||||
|
default=None,
|
||||||
|
help="Expected FPS for sampling frames.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--height",
|
||||||
|
type=int,
|
||||||
|
default=480,
|
||||||
|
help="Image height.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--width",
|
||||||
|
type=int,
|
||||||
|
default=832,
|
||||||
|
help="Image width.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--dataloader_num_workers",
|
||||||
|
type=int,
|
||||||
|
default=1,
|
||||||
|
help="Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--learning_rate",
|
||||||
|
type=float,
|
||||||
|
default=1e-5,
|
||||||
|
help="Learning rate.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--accumulate_grad_batches",
|
||||||
|
type=int,
|
||||||
|
default=1,
|
||||||
|
help="The number of batches in gradient accumulation.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--max_epochs",
|
||||||
|
type=int,
|
||||||
|
default=1,
|
||||||
|
help="Number of epochs.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--lora_target_modules",
|
||||||
|
type=str,
|
||||||
|
default="q,k,v,o,ffn.0,ffn.2",
|
||||||
|
help="Layers with LoRA modules.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--init_lora_weights",
|
||||||
|
type=str,
|
||||||
|
default="kaiming",
|
||||||
|
choices=["gaussian", "kaiming"],
|
||||||
|
help="The initializing method of LoRA weight.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--training_strategy",
|
||||||
|
type=str,
|
||||||
|
default="auto",
|
||||||
|
choices=["auto", "deepspeed_stage_1", "deepspeed_stage_2", "deepspeed_stage_3"],
|
||||||
|
help="Training strategy",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--lora_rank",
|
||||||
|
type=int,
|
||||||
|
default=4,
|
||||||
|
help="The dimension of the LoRA update matrices.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--lora_alpha",
|
||||||
|
type=float,
|
||||||
|
default=4.0,
|
||||||
|
help="The weight of the LoRA update matrices.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--use_gradient_checkpointing",
|
||||||
|
default=False,
|
||||||
|
action="store_true",
|
||||||
|
help="Whether to use gradient checkpointing.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--use_gradient_checkpointing_offload",
|
||||||
|
default=False,
|
||||||
|
action="store_true",
|
||||||
|
help="Whether to use gradient checkpointing offload.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--train_architecture",
|
||||||
|
type=str,
|
||||||
|
default="lora",
|
||||||
|
choices=["lora", "full"],
|
||||||
|
help="Model structure to train. LoRA training or full training.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--pretrained_lora_path",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="Pretrained LoRA path. Required if the training is resumed.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--use_swanlab",
|
||||||
|
default=False,
|
||||||
|
action="store_true",
|
||||||
|
help="Whether to use SwanLab logger.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--swanlab_mode",
|
||||||
|
default=None,
|
||||||
|
help="SwanLab mode (cloud or local).",
|
||||||
|
)
|
||||||
|
args = parser.parse_args()
|
||||||
|
return args
|
||||||
|
|
||||||
|
|
||||||
|
def data_process(args):
|
||||||
|
dataset = TextVideoDataset(
|
||||||
|
args.dataset_path,
|
||||||
|
os.path.join(args.dataset_path, "metadata.csv") if args.metadata_path is None else args.metadata_path,
|
||||||
|
max_num_frames=args.num_frames,
|
||||||
|
frame_interval=1,
|
||||||
|
num_frames=args.num_frames,
|
||||||
|
height=args.height,
|
||||||
|
width=args.width,
|
||||||
|
is_i2v=args.image_encoder_path is not None,
|
||||||
|
target_fps=args.target_fps,
|
||||||
|
)
|
||||||
|
dataloader = torch.utils.data.DataLoader(
|
||||||
|
dataset,
|
||||||
|
shuffle=False,
|
||||||
|
batch_size=1,
|
||||||
|
num_workers=args.dataloader_num_workers,
|
||||||
|
collate_fn=lambda x: x,
|
||||||
|
)
|
||||||
|
model = LightningModelForDataProcess(
|
||||||
|
text_encoder_path=args.text_encoder_path,
|
||||||
|
image_encoder_path=args.image_encoder_path,
|
||||||
|
vae_path=args.vae_path,
|
||||||
|
tiled=args.tiled,
|
||||||
|
tile_size=(args.tile_size_height, args.tile_size_width),
|
||||||
|
tile_stride=(args.tile_stride_height, args.tile_stride_width),
|
||||||
|
redirected_tensor_path=args.redirected_tensor_path,
|
||||||
|
)
|
||||||
|
trainer = pl.Trainer(
|
||||||
|
accelerator="gpu",
|
||||||
|
devices="auto",
|
||||||
|
default_root_dir=args.output_path,
|
||||||
|
)
|
||||||
|
trainer.test(model, dataloader)
|
||||||
|
|
||||||
|
|
||||||
|
def get_motion_thresholds(dataloader):
|
||||||
|
scores = []
|
||||||
|
for data in tqdm(dataloader):
|
||||||
|
scores.append(data["latents"][:, :, 1:, :, :].std(dim=2).mean().tolist())
|
||||||
|
scores = sorted(scores)
|
||||||
|
for i in range(100):
|
||||||
|
s = scores[int(i/100 * len(scores))]
|
||||||
|
print("%.9f" % s, end=", ")
|
||||||
|
|
||||||
|
|
||||||
|
def train(args):
|
||||||
|
dataset = TensorDataset(
|
||||||
|
args.dataset_path,
|
||||||
|
os.path.join(args.dataset_path, "metadata.csv") if args.metadata_path is None else args.metadata_path,
|
||||||
|
steps_per_epoch=args.steps_per_epoch,
|
||||||
|
redirected_tensor_path=args.redirected_tensor_path,
|
||||||
|
)
|
||||||
|
dataloader = torch.utils.data.DataLoader(
|
||||||
|
dataset,
|
||||||
|
shuffle=True,
|
||||||
|
batch_size=1,
|
||||||
|
num_workers=args.dataloader_num_workers
|
||||||
|
)
|
||||||
|
model = LightningModelForTrain(
|
||||||
|
dit_path=args.dit_path,
|
||||||
|
learning_rate=args.learning_rate,
|
||||||
|
train_architecture=args.train_architecture,
|
||||||
|
lora_rank=args.lora_rank,
|
||||||
|
lora_alpha=args.lora_alpha,
|
||||||
|
lora_target_modules=args.lora_target_modules,
|
||||||
|
init_lora_weights=args.init_lora_weights,
|
||||||
|
use_gradient_checkpointing=args.use_gradient_checkpointing,
|
||||||
|
use_gradient_checkpointing_offload=args.use_gradient_checkpointing_offload,
|
||||||
|
pretrained_lora_path=args.pretrained_lora_path,
|
||||||
|
)
|
||||||
|
if args.use_swanlab:
|
||||||
|
from swanlab.integration.pytorch_lightning import SwanLabLogger
|
||||||
|
swanlab_config = {"UPPERFRAMEWORK": "DiffSynth-Studio"}
|
||||||
|
swanlab_config.update(vars(args))
|
||||||
|
swanlab_logger = SwanLabLogger(
|
||||||
|
project="wan",
|
||||||
|
name="wan",
|
||||||
|
config=swanlab_config,
|
||||||
|
mode=args.swanlab_mode,
|
||||||
|
logdir=os.path.join(args.output_path, "swanlog"),
|
||||||
|
)
|
||||||
|
logger = [swanlab_logger]
|
||||||
|
else:
|
||||||
|
logger = None
|
||||||
|
trainer = pl.Trainer(
|
||||||
|
max_epochs=args.max_epochs,
|
||||||
|
accelerator="gpu",
|
||||||
|
devices="auto",
|
||||||
|
precision="bf16",
|
||||||
|
strategy=args.training_strategy,
|
||||||
|
default_root_dir=args.output_path,
|
||||||
|
accumulate_grad_batches=args.accumulate_grad_batches,
|
||||||
|
callbacks=[pl.pytorch.callbacks.ModelCheckpoint(save_top_k=-1)],
|
||||||
|
logger=logger,
|
||||||
|
)
|
||||||
|
trainer.fit(model, dataloader)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
args = parse_args()
|
||||||
|
if args.task == "data_process":
|
||||||
|
data_process(args)
|
||||||
|
elif args.task == "train":
|
||||||
|
train(args)
|
||||||
@@ -1,41 +0,0 @@
|
|||||||
import torch
|
|
||||||
from diffsynth import ModelManager, WanVideoPipeline, save_video, VideoData
|
|
||||||
from modelscope import snapshot_download
|
|
||||||
|
|
||||||
|
|
||||||
# Download models
|
|
||||||
snapshot_download("Wan-AI/Wan2.1-T2V-1.3B", local_dir="models/Wan-AI/Wan2.1-T2V-1.3B")
|
|
||||||
snapshot_download("DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1", local_dir="models/DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1")
|
|
||||||
|
|
||||||
# Load models
|
|
||||||
model_manager = ModelManager(device="cpu")
|
|
||||||
model_manager.load_models(
|
|
||||||
[
|
|
||||||
"models/Wan-AI/Wan2.1-T2V-1.3B/diffusion_pytorch_model.safetensors",
|
|
||||||
"models/Wan-AI/Wan2.1-T2V-1.3B/models_t5_umt5-xxl-enc-bf16.pth",
|
|
||||||
"models/Wan-AI/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth",
|
|
||||||
"models/DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1/model.safetensors",
|
|
||||||
],
|
|
||||||
torch_dtype=torch.bfloat16, # You can set `torch_dtype=torch.float8_e4m3fn` to enable FP8 quantization.
|
|
||||||
)
|
|
||||||
pipe = WanVideoPipeline.from_model_manager(model_manager, torch_dtype=torch.bfloat16, device="cuda")
|
|
||||||
pipe.enable_vram_management(num_persistent_param_in_dit=None)
|
|
||||||
|
|
||||||
# Text-to-video
|
|
||||||
video = pipe(
|
|
||||||
prompt="纪实摄影风格画面,一只活泼的小狗在绿茵茵的草地上迅速奔跑。小狗毛色棕黄,两只耳朵立起,神情专注而欢快。阳光洒在它身上,使得毛发看上去格外柔软而闪亮。背景是一片开阔的草地,偶尔点缀着几朵野花,远处隐约可见蓝天和几片白云。透视感鲜明,捕捉小狗奔跑时的动感和四周草地的生机。中景侧面移动视角。",
|
|
||||||
negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
|
|
||||||
num_inference_steps=50,
|
|
||||||
seed=1, tiled=True,
|
|
||||||
motion_bucket_id=0
|
|
||||||
)
|
|
||||||
save_video(video, "video_slow.mp4", fps=15, quality=5)
|
|
||||||
|
|
||||||
video = pipe(
|
|
||||||
prompt="纪实摄影风格画面,一只活泼的小狗在绿茵茵的草地上迅速奔跑。小狗毛色棕黄,两只耳朵立起,神情专注而欢快。阳光洒在它身上,使得毛发看上去格外柔软而闪亮。背景是一片开阔的草地,偶尔点缀着几朵野花,远处隐约可见蓝天和几片白云。透视感鲜明,捕捉小狗奔跑时的动感和四周草地的生机。中景侧面移动视角。",
|
|
||||||
negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
|
|
||||||
num_inference_steps=50,
|
|
||||||
seed=1, tiled=True,
|
|
||||||
motion_bucket_id=100
|
|
||||||
)
|
|
||||||
save_video(video, "video_fast.mp4", fps=15, quality=5)
|
|
||||||
@@ -44,28 +44,11 @@ class LitModel(pl.LightningModule):
|
|||||||
|
|
||||||
def configure_model(self):
|
def configure_model(self):
|
||||||
tp_mesh = self.device_mesh["tensor_parallel"]
|
tp_mesh = self.device_mesh["tensor_parallel"]
|
||||||
plan = {
|
|
||||||
"text_embedding.0": ColwiseParallel(),
|
|
||||||
"text_embedding.2": RowwiseParallel(),
|
|
||||||
"time_projection.1": ColwiseParallel(output_layouts=Replicate()),
|
|
||||||
"text_embedding.0": ColwiseParallel(),
|
|
||||||
"text_embedding.2": RowwiseParallel(),
|
|
||||||
"blocks.0": PrepareModuleInput(
|
|
||||||
input_layouts=(Replicate(), None, None, None),
|
|
||||||
desired_input_layouts=(Replicate(), None, None, None),
|
|
||||||
),
|
|
||||||
"head": PrepareModuleInput(
|
|
||||||
input_layouts=(Replicate(), None),
|
|
||||||
desired_input_layouts=(Replicate(), None),
|
|
||||||
use_local_output=True,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
self.pipe.dit = parallelize_module(self.pipe.dit, tp_mesh, plan)
|
|
||||||
for block_id, block in enumerate(self.pipe.dit.blocks):
|
for block_id, block in enumerate(self.pipe.dit.blocks):
|
||||||
layer_tp_plan = {
|
layer_tp_plan = {
|
||||||
"self_attn": PrepareModuleInput(
|
"self_attn": PrepareModuleInput(
|
||||||
input_layouts=(Shard(1), Replicate()),
|
input_layouts=(Replicate(), Replicate()),
|
||||||
desired_input_layouts=(Shard(1), Shard(0)),
|
desired_input_layouts=(Replicate(), Shard(0)),
|
||||||
),
|
),
|
||||||
"self_attn.q": SequenceParallel(),
|
"self_attn.q": SequenceParallel(),
|
||||||
"self_attn.k": SequenceParallel(),
|
"self_attn.k": SequenceParallel(),
|
||||||
@@ -76,11 +59,11 @@ class LitModel(pl.LightningModule):
|
|||||||
input_layouts=(Shard(1), Shard(1), Shard(1)),
|
input_layouts=(Shard(1), Shard(1), Shard(1)),
|
||||||
desired_input_layouts=(Shard(2), Shard(2), Shard(2)),
|
desired_input_layouts=(Shard(2), Shard(2), Shard(2)),
|
||||||
),
|
),
|
||||||
"self_attn.o": RowwiseParallel(input_layouts=Shard(2), output_layouts=Replicate()),
|
"self_attn.o": ColwiseParallel(output_layouts=Replicate()),
|
||||||
|
|
||||||
"cross_attn": PrepareModuleInput(
|
"cross_attn": PrepareModuleInput(
|
||||||
input_layouts=(Shard(1), Replicate()),
|
input_layouts=(Replicate(), Replicate()),
|
||||||
desired_input_layouts=(Shard(1), Replicate()),
|
desired_input_layouts=(Replicate(), Replicate()),
|
||||||
),
|
),
|
||||||
"cross_attn.q": SequenceParallel(),
|
"cross_attn.q": SequenceParallel(),
|
||||||
"cross_attn.k": SequenceParallel(),
|
"cross_attn.k": SequenceParallel(),
|
||||||
@@ -91,18 +74,10 @@ class LitModel(pl.LightningModule):
|
|||||||
input_layouts=(Shard(1), Shard(1), Shard(1)),
|
input_layouts=(Shard(1), Shard(1), Shard(1)),
|
||||||
desired_input_layouts=(Shard(2), Shard(2), Shard(2)),
|
desired_input_layouts=(Shard(2), Shard(2), Shard(2)),
|
||||||
),
|
),
|
||||||
"cross_attn.o": RowwiseParallel(input_layouts=Shard(2), output_layouts=Replicate(), use_local_output=False),
|
"cross_attn.o": ColwiseParallel(output_layouts=Replicate()),
|
||||||
|
|
||||||
"ffn.0": ColwiseParallel(input_layouts=Shard(1)),
|
"ffn.0": ColwiseParallel(),
|
||||||
"ffn.2": RowwiseParallel(output_layouts=Replicate()),
|
"ffn.2": RowwiseParallel(),
|
||||||
|
|
||||||
"norm1": SequenceParallel(use_local_output=True),
|
|
||||||
"norm2": SequenceParallel(use_local_output=True),
|
|
||||||
"norm3": SequenceParallel(use_local_output=True),
|
|
||||||
"gate": PrepareModuleInput(
|
|
||||||
input_layouts=(Shard(1), Replicate(), Replicate()),
|
|
||||||
desired_input_layouts=(Replicate(), Replicate(), Replicate()),
|
|
||||||
)
|
|
||||||
}
|
}
|
||||||
parallelize_module(
|
parallelize_module(
|
||||||
module=block,
|
module=block,
|
||||||
@@ -121,6 +96,7 @@ class LitModel(pl.LightningModule):
|
|||||||
save_video(video, output_path, fps=15, quality=5)
|
save_video(video, output_path, fps=15, quality=5)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
snapshot_download("Wan-AI/Wan2.1-T2V-14B", local_dir="models/Wan-AI/Wan2.1-T2V-14B")
|
snapshot_download("Wan-AI/Wan2.1-T2V-14B", local_dir="models/Wan-AI/Wan2.1-T2V-14B")
|
||||||
dataloader = torch.utils.data.DataLoader(
|
dataloader = torch.utils.data.DataLoader(
|
||||||
|
|||||||
@@ -1,58 +0,0 @@
|
|||||||
import torch
|
|
||||||
from diffsynth import ModelManager, WanVideoPipeline, save_video, VideoData
|
|
||||||
from modelscope import snapshot_download
|
|
||||||
import torch.distributed as dist
|
|
||||||
|
|
||||||
|
|
||||||
# Download models
|
|
||||||
snapshot_download("Wan-AI/Wan2.1-T2V-14B", local_dir="models/Wan-AI/Wan2.1-T2V-14B")
|
|
||||||
|
|
||||||
# Load models
|
|
||||||
model_manager = ModelManager(device="cpu")
|
|
||||||
model_manager.load_models(
|
|
||||||
[
|
|
||||||
[
|
|
||||||
"models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00001-of-00006.safetensors",
|
|
||||||
"models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00002-of-00006.safetensors",
|
|
||||||
"models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00003-of-00006.safetensors",
|
|
||||||
"models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00004-of-00006.safetensors",
|
|
||||||
"models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00005-of-00006.safetensors",
|
|
||||||
"models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00006-of-00006.safetensors",
|
|
||||||
],
|
|
||||||
"models/Wan-AI/Wan2.1-T2V-14B/models_t5_umt5-xxl-enc-bf16.pth",
|
|
||||||
"models/Wan-AI/Wan2.1-T2V-14B/Wan2.1_VAE.pth",
|
|
||||||
],
|
|
||||||
torch_dtype=torch.float8_e4m3fn, # You can set `torch_dtype=torch.bfloat16` to disable FP8 quantization.
|
|
||||||
)
|
|
||||||
|
|
||||||
dist.init_process_group(
|
|
||||||
backend="nccl",
|
|
||||||
init_method="env://",
|
|
||||||
)
|
|
||||||
from xfuser.core.distributed import (initialize_model_parallel,
|
|
||||||
init_distributed_environment)
|
|
||||||
init_distributed_environment(
|
|
||||||
rank=dist.get_rank(), world_size=dist.get_world_size())
|
|
||||||
|
|
||||||
initialize_model_parallel(
|
|
||||||
sequence_parallel_degree=dist.get_world_size(),
|
|
||||||
ring_degree=1,
|
|
||||||
ulysses_degree=dist.get_world_size(),
|
|
||||||
)
|
|
||||||
torch.cuda.set_device(dist.get_rank())
|
|
||||||
|
|
||||||
pipe = WanVideoPipeline.from_model_manager(model_manager,
|
|
||||||
torch_dtype=torch.bfloat16,
|
|
||||||
device=f"cuda:{dist.get_rank()}",
|
|
||||||
use_usp=True if dist.get_world_size() > 1 else False)
|
|
||||||
pipe.enable_vram_management(num_persistent_param_in_dit=None) # You can set `num_persistent_param_in_dit` to a small number to reduce VRAM required.
|
|
||||||
|
|
||||||
# Text-to-video
|
|
||||||
video = pipe(
|
|
||||||
prompt="一名宇航员身穿太空服,面朝镜头骑着一匹机械马在火星表面驰骋。红色的荒凉地表延伸至远方,点缀着巨大的陨石坑和奇特的岩石结构。机械马的步伐稳健,扬起微弱的尘埃,展现出未来科技与原始探索的完美结合。宇航员手持操控装置,目光坚定,仿佛正在开辟人类的新疆域。背景是深邃的宇宙和蔚蓝的地球,画面既科幻又充满希望,让人不禁畅想未来的星际生活。",
|
|
||||||
negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
|
|
||||||
num_inference_steps=50,
|
|
||||||
seed=0, tiled=True
|
|
||||||
)
|
|
||||||
if dist.get_rank() == 0:
|
|
||||||
save_video(video, "video1.mp4", fps=25, quality=5)
|
|
||||||
@@ -1,42 +0,0 @@
|
|||||||
import torch
|
|
||||||
from diffsynth import ModelManager, WanVideoPipeline, save_video, VideoData
|
|
||||||
from modelscope import snapshot_download, dataset_snapshot_download
|
|
||||||
from PIL import Image
|
|
||||||
|
|
||||||
|
|
||||||
# Download models
|
|
||||||
snapshot_download("PAI/Wan2.1-Fun-1.3B-InP", local_dir="models/PAI/Wan2.1-Fun-1.3B-InP")
|
|
||||||
|
|
||||||
# Load models
|
|
||||||
model_manager = ModelManager(device="cpu")
|
|
||||||
model_manager.load_models(
|
|
||||||
[
|
|
||||||
"models/PAI/Wan2.1-Fun-1.3B-InP/diffusion_pytorch_model.safetensors",
|
|
||||||
"models/PAI/Wan2.1-Fun-1.3B-InP/models_t5_umt5-xxl-enc-bf16.pth",
|
|
||||||
"models/PAI/Wan2.1-Fun-1.3B-InP/Wan2.1_VAE.pth",
|
|
||||||
"models/PAI/Wan2.1-Fun-1.3B-InP/models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth",
|
|
||||||
],
|
|
||||||
torch_dtype=torch.bfloat16, # You can set `torch_dtype=torch.float8_e4m3fn` to enable FP8 quantization.
|
|
||||||
)
|
|
||||||
pipe = WanVideoPipeline.from_model_manager(model_manager, torch_dtype=torch.bfloat16, device="cuda")
|
|
||||||
pipe.enable_vram_management(num_persistent_param_in_dit=None)
|
|
||||||
|
|
||||||
# Download example image
|
|
||||||
dataset_snapshot_download(
|
|
||||||
dataset_id="DiffSynth-Studio/examples_in_diffsynth",
|
|
||||||
local_dir="./",
|
|
||||||
allow_file_pattern=f"data/examples/wan/input_image.jpg"
|
|
||||||
)
|
|
||||||
image = Image.open("data/examples/wan/input_image.jpg")
|
|
||||||
|
|
||||||
# Image-to-video
|
|
||||||
video = pipe(
|
|
||||||
prompt="一艘小船正勇敢地乘风破浪前行。蔚蓝的大海波涛汹涌,白色的浪花拍打着船身,但小船毫不畏惧,坚定地驶向远方。阳光洒在水面上,闪烁着金色的光芒,为这壮丽的场景增添了一抹温暖。镜头拉近,可以看到船上的旗帜迎风飘扬,象征着不屈的精神与冒险的勇气。这段画面充满力量,激励人心,展现了面对挑战时的无畏与执着。",
|
|
||||||
negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
|
|
||||||
num_inference_steps=50,
|
|
||||||
input_image=image,
|
|
||||||
# You can input `end_image=xxx` to control the last frame of the video.
|
|
||||||
# The model will automatically generate the dynamic content between `input_image` and `end_image`.
|
|
||||||
seed=1, tiled=True
|
|
||||||
)
|
|
||||||
save_video(video, "video1.mp4", fps=15, quality=5)
|
|
||||||
@@ -1,40 +0,0 @@
|
|||||||
import torch
|
|
||||||
from diffsynth import ModelManager, WanVideoPipeline, save_video, VideoData
|
|
||||||
from modelscope import snapshot_download, dataset_snapshot_download
|
|
||||||
from PIL import Image
|
|
||||||
|
|
||||||
|
|
||||||
# Download models
|
|
||||||
snapshot_download("PAI/Wan2.1-Fun-1.3B-Control", local_dir="models/PAI/Wan2.1-Fun-1.3B-Control")
|
|
||||||
|
|
||||||
# Load models
|
|
||||||
model_manager = ModelManager(device="cpu")
|
|
||||||
model_manager.load_models(
|
|
||||||
[
|
|
||||||
"models/PAI/Wan2.1-Fun-1.3B-Control/diffusion_pytorch_model.safetensors",
|
|
||||||
"models/PAI/Wan2.1-Fun-1.3B-Control/models_t5_umt5-xxl-enc-bf16.pth",
|
|
||||||
"models/PAI/Wan2.1-Fun-1.3B-Control/Wan2.1_VAE.pth",
|
|
||||||
"models/PAI/Wan2.1-Fun-1.3B-Control/models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth",
|
|
||||||
],
|
|
||||||
torch_dtype=torch.bfloat16, # You can set `torch_dtype=torch.float8_e4m3fn` to enable FP8 quantization.
|
|
||||||
)
|
|
||||||
pipe = WanVideoPipeline.from_model_manager(model_manager, torch_dtype=torch.bfloat16, device="cuda")
|
|
||||||
pipe.enable_vram_management(num_persistent_param_in_dit=None)
|
|
||||||
|
|
||||||
# Download example video
|
|
||||||
dataset_snapshot_download(
|
|
||||||
dataset_id="DiffSynth-Studio/examples_in_diffsynth",
|
|
||||||
local_dir="./",
|
|
||||||
allow_file_pattern=f"data/examples/wan/control_video.mp4"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Control-to-video
|
|
||||||
control_video = VideoData("data/examples/wan/control_video.mp4", height=832, width=576)
|
|
||||||
video = pipe(
|
|
||||||
prompt="扁平风格动漫,一位长发少女优雅起舞。她五官精致,大眼睛明亮有神,黑色长发柔顺光泽。身穿淡蓝色T恤和深蓝色牛仔短裤。背景是粉色。",
|
|
||||||
negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
|
|
||||||
num_inference_steps=50,
|
|
||||||
control_video=control_video, height=832, width=576, num_frames=49,
|
|
||||||
seed=1, tiled=True
|
|
||||||
)
|
|
||||||
save_video(video, "video1.mp4", fps=15, quality=5)
|
|
||||||
2
setup.py
2
setup.py
@@ -14,7 +14,7 @@ else:
|
|||||||
|
|
||||||
setup(
|
setup(
|
||||||
name="diffsynth",
|
name="diffsynth",
|
||||||
version="1.1.7",
|
version="1.1.2",
|
||||||
description="Enjoy the magic of Diffusion models!",
|
description="Enjoy the magic of Diffusion models!",
|
||||||
author="Artiprocher",
|
author="Artiprocher",
|
||||||
packages=find_packages(),
|
packages=find_packages(),
|
||||||
|
|||||||
Reference in New Issue
Block a user