Compare commits

9 Commits

Author | SHA1 | Message | Date
Artiprocher | 05094710e3 | support motion controller | 2025-03-24 19:07:58 +08:00
Artiprocher | 105eaf0f49 | controlnet | 2025-03-21 11:09:56 +08:00
Artiprocher | 6cd032e846 | skip bad files | 2025-03-19 14:49:18 +08:00
Artiprocher | 9d8130b48d | ignore metadata | 2025-03-19 11:36:07 +08:00
Artiprocher | ce848a3d1a | bugfix | 2025-03-18 19:36:58 +08:00
Artiprocher | a8ce9fef33 | support redirected tensor path | 2025-03-18 19:24:27 +08:00
Artiprocher | 8da0d183a2 | support target fps | 2025-03-18 17:31:15 +08:00
Artiprocher | 4b2b3dda94 | support target fps | 2025-03-18 17:30:13 +08:00
Artiprocher | b1fabbc6b0 | skip bad videos | 2025-03-18 17:24:39 +08:00
31 changed files with 1753 additions and 1691 deletions

View File

@@ -20,7 +20,7 @@ jobs:
with: with:
python-version: '3.10' python-version: '3.10'
- name: Install wheel - name: Install wheel
run: pip install wheel==0.44.0 && pip install -r requirements.txt run: pip install wheel && pip install -r requirements.txt
- name: Build DiffSynth - name: Build DiffSynth
run: python setup.py sdist bdist_wheel run: python setup.py sdist bdist_wheel
- name: Publish package to PyPI - name: Publish package to PyPI

View File

@@ -13,15 +13,9 @@ Document: https://diffsynth-studio.readthedocs.io/zh-cn/latest/index.html
## Introduction ## Introduction
Welcome to the magic world of Diffusion models! DiffSynth Studio is a Diffusion engine. We have restructured architectures including Text Encoder, UNet, VAE, among others, maintaining compatibility with models from the open-source community while enhancing computational performance. We provide many interesting features. Enjoy the magic of Diffusion models!
DiffSynth consists of two open-source projects: Until now, DiffSynth Studio has supported the following models:
* [DiffSynth-Studio](https://github.com/modelscope/DiffSynth-Studio): Focused on aggressive technological exploration. Targeted at academia. Provides more cutting-edge technical support and novel inference capabilities.
* [DiffSynth-Engine](https://github.com/modelscope/DiffSynth-Engine): Focused on stable model deployment. Geared towards industry. Offers better engineering support, higher computational performance, and more stable functionality.
DiffSynth-Studio is an open-source project aimed at exploring innovations in AIGC technology. We have integrated numerous open-source Diffusion models, including FLUX and Wan, among others. Through this open-source project, we hope to connect models within the open-source community and explore new technologies based on diffusion models.
Until now, DiffSynth-Studio has supported the following models:
* [Wan-Video](https://github.com/Wan-Video/Wan2.1) * [Wan-Video](https://github.com/Wan-Video/Wan2.1)
* [StepVideo](https://github.com/stepfun-ai/Step-Video-T2V) * [StepVideo](https://github.com/stepfun-ai/Step-Video-T2V)
@@ -42,11 +36,7 @@ Until now, DiffSynth-Studio has supported the following models:
* [Stable Diffusion](https://huggingface.co/runwayml/stable-diffusion-v1-5) * [Stable Diffusion](https://huggingface.co/runwayml/stable-diffusion-v1-5)
## News ## News
- **March 31, 2025** We support InfiniteYou, an identity preserving method for FLUX. Please refer to [./examples/InfiniteYou/](./examples/InfiniteYou/) for more details. - **March 25, 2025** We support HunyuanVideo-I2V, the image-to-video generation version of HunyuanVideo open-sourced by Tencent. Please refer to [./examples/HunyuanVideo/](./examples/HunyuanVideo/) for more details.
- **March 25, 2025** 🔥🔥🔥 Our new open-source project, [DiffSynth-Engine](https://github.com/modelscope/DiffSynth-Engine), is now open-sourced! Focused on stable model deployment. Geared towards industry. Offers better engineering support, higher computational performance, and more stable functionality.
- **March 13, 2025** We support HunyuanVideo-I2V, the image-to-video generation version of HunyuanVideo open-sourced by Tencent. Please refer to [./examples/HunyuanVideo/](./examples/HunyuanVideo/) for more details.
- **February 25, 2025** We support Wan-Video, a collection of SOTA video synthesis models open-sourced by Alibaba. See [./examples/wanvideo/](./examples/wanvideo/). - **February 25, 2025** We support Wan-Video, a collection of SOTA video synthesis models open-sourced by Alibaba. See [./examples/wanvideo/](./examples/wanvideo/).
@@ -83,7 +73,7 @@ Until now, DiffSynth-Studio has supported the following models:
- Enable CFG and highres-fix to improve visual quality. See [here](/examples/image_synthesis/README.md) - Enable CFG and highres-fix to improve visual quality. See [here](/examples/image_synthesis/README.md)
- LoRA, ControlNet, and additional models will be available soon. - LoRA, ControlNet, and additional models will be available soon.
- **June 21, 2024.** We propose ExVideo, a post-tuning technique aimed at enhancing the capability of video generation models. We have extended Stable Video Diffusion to achieve the generation of long videos up to 128 frames. - **June 21, 2024.** 🔥🔥🔥 We propose ExVideo, a post-tuning technique aimed at enhancing the capability of video generation models. We have extended Stable Video Diffusion to achieve the generation of long videos up to 128 frames.
- [Project Page](https://ecnu-cilab.github.io/ExVideoProjectPage/) - [Project Page](https://ecnu-cilab.github.io/ExVideoProjectPage/)
- Source code is released in this repo. See [`examples/ExVideo`](./examples/ExVideo/). - Source code is released in this repo. See [`examples/ExVideo`](./examples/ExVideo/).
- Models are released on [HuggingFace](https://huggingface.co/ECNU-CILab/ExVideo-SVD-128f-v1) and [ModelScope](https://modelscope.cn/models/ECNU-CILab/ExVideo-SVD-128f-v1). - Models are released on [HuggingFace](https://huggingface.co/ECNU-CILab/ExVideo-SVD-128f-v1) and [ModelScope](https://modelscope.cn/models/ECNU-CILab/ExVideo-SVD-128f-v1).

View File

@@ -37,7 +37,6 @@ from ..models.flux_text_encoder import FluxTextEncoder2
from ..models.flux_vae import FluxVAEEncoder, FluxVAEDecoder from ..models.flux_vae import FluxVAEEncoder, FluxVAEDecoder
from ..models.flux_controlnet import FluxControlNet from ..models.flux_controlnet import FluxControlNet
from ..models.flux_ipadapter import FluxIpAdapter from ..models.flux_ipadapter import FluxIpAdapter
from ..models.flux_infiniteyou import InfiniteYouImageProjector
from ..models.cog_vae import CogVAEEncoder, CogVAEDecoder from ..models.cog_vae import CogVAEEncoder, CogVAEDecoder
from ..models.cog_dit import CogDiT from ..models.cog_dit import CogDiT
@@ -59,7 +58,6 @@ from ..models.wan_video_dit import WanModel
from ..models.wan_video_text_encoder import WanTextEncoder from ..models.wan_video_text_encoder import WanTextEncoder
from ..models.wan_video_image_encoder import WanImageEncoder from ..models.wan_video_image_encoder import WanImageEncoder
from ..models.wan_video_vae import WanVideoVAE from ..models.wan_video_vae import WanVideoVAE
from ..models.wan_video_motion_controller import WanMotionControllerModel
model_loader_configs = [ model_loader_configs = [
@@ -106,8 +104,6 @@ model_loader_configs = [
(None, "b001c89139b5f053c715fe772362dd2a", ["flux_controlnet"], [FluxControlNet], "diffusers"), (None, "b001c89139b5f053c715fe772362dd2a", ["flux_controlnet"], [FluxControlNet], "diffusers"),
(None, "52357cb26250681367488a8954c271e8", ["flux_controlnet"], [FluxControlNet], "diffusers"), (None, "52357cb26250681367488a8954c271e8", ["flux_controlnet"], [FluxControlNet], "diffusers"),
(None, "0cfd1740758423a2a854d67c136d1e8c", ["flux_controlnet"], [FluxControlNet], "diffusers"), (None, "0cfd1740758423a2a854d67c136d1e8c", ["flux_controlnet"], [FluxControlNet], "diffusers"),
(None, "7f9583eb8ba86642abb9a21a4b2c9e16", ["flux_controlnet"], [FluxControlNet], "diffusers"),
(None, "c07c0f04f5ff55e86b4e937c7a40d481", ["infiniteyou_image_projector"], [InfiniteYouImageProjector], "diffusers"),
(None, "4daaa66cc656a8fe369908693dad0a35", ["flux_ipadapter"], [FluxIpAdapter], "diffusers"), (None, "4daaa66cc656a8fe369908693dad0a35", ["flux_ipadapter"], [FluxIpAdapter], "diffusers"),
(None, "51aed3d27d482fceb5e0739b03060e8f", ["sd3_dit", "sd3_vae_encoder", "sd3_vae_decoder"], [SD3DiT, SD3VAEEncoder, SD3VAEDecoder], "civitai"), (None, "51aed3d27d482fceb5e0739b03060e8f", ["sd3_dit", "sd3_vae_encoder", "sd3_vae_decoder"], [SD3DiT, SD3VAEEncoder, SD3VAEDecoder], "civitai"),
(None, "98cc34ccc5b54ae0e56bdea8688dcd5a", ["sd3_text_encoder_2"], [SD3TextEncoder2], "civitai"), (None, "98cc34ccc5b54ae0e56bdea8688dcd5a", ["sd3_text_encoder_2"], [SD3TextEncoder2], "civitai"),
@@ -121,16 +117,11 @@ model_loader_configs = [
(None, "9269f8db9040a9d860eaca435be61814", ["wan_video_dit"], [WanModel], "civitai"), (None, "9269f8db9040a9d860eaca435be61814", ["wan_video_dit"], [WanModel], "civitai"),
(None, "aafcfd9672c3a2456dc46e1cb6e52c70", ["wan_video_dit"], [WanModel], "civitai"), (None, "aafcfd9672c3a2456dc46e1cb6e52c70", ["wan_video_dit"], [WanModel], "civitai"),
(None, "6bfcfb3b342cb286ce886889d519a77e", ["wan_video_dit"], [WanModel], "civitai"), (None, "6bfcfb3b342cb286ce886889d519a77e", ["wan_video_dit"], [WanModel], "civitai"),
(None, "6d6ccde6845b95ad9114ab993d917893", ["wan_video_dit"], [WanModel], "civitai"),
(None, "6bfcfb3b342cb286ce886889d519a77e", ["wan_video_dit"], [WanModel], "civitai"),
(None, "349723183fc063b2bfc10bb2835cf677", ["wan_video_dit"], [WanModel], "civitai"),
(None, "efa44cddf936c70abd0ea28b6cbe946c", ["wan_video_dit"], [WanModel], "civitai"),
(None, "cb104773c6c2cb6df4f9529ad5c60d0b", ["wan_video_dit"], [WanModel], "diffusers"), (None, "cb104773c6c2cb6df4f9529ad5c60d0b", ["wan_video_dit"], [WanModel], "diffusers"),
(None, "9c8818c2cbea55eca56c7b447df170da", ["wan_video_text_encoder"], [WanTextEncoder], "civitai"), (None, "9c8818c2cbea55eca56c7b447df170da", ["wan_video_text_encoder"], [WanTextEncoder], "civitai"),
(None, "5941c53e207d62f20f9025686193c40b", ["wan_video_image_encoder"], [WanImageEncoder], "civitai"), (None, "5941c53e207d62f20f9025686193c40b", ["wan_video_image_encoder"], [WanImageEncoder], "civitai"),
(None, "1378ea763357eea97acdef78e65d6d96", ["wan_video_vae"], [WanVideoVAE], "civitai"), (None, "1378ea763357eea97acdef78e65d6d96", ["wan_video_vae"], [WanVideoVAE], "civitai"),
(None, "ccc42284ea13e1ad04693284c7a09be6", ["wan_video_vae"], [WanVideoVAE], "civitai"), (None, "ccc42284ea13e1ad04693284c7a09be6", ["wan_video_vae"], [WanVideoVAE], "civitai"),
(None, "dbd5ec76bbf977983f972c151d545389", ["wan_video_motion_controller"], [WanMotionControllerModel], "civitai"),
] ]
huggingface_model_loader_configs = [ huggingface_model_loader_configs = [
# These configs are provided for detecting model type automatically. # These configs are provided for detecting model type automatically.
@@ -607,25 +598,6 @@ preset_models_on_modelscope = {
"models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/image_encoder", "models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/image_encoder",
], ],
}, },
"InfiniteYou":{
"file_list":[
("ByteDance/InfiniteYou", "infu_flux_v1.0/aes_stage2/InfuseNetModel/diffusion_pytorch_model-00001-of-00002.safetensors", "models/InfiniteYou/InfuseNetModel"),
("ByteDance/InfiniteYou", "infu_flux_v1.0/aes_stage2/InfuseNetModel/diffusion_pytorch_model-00002-of-00002.safetensors", "models/InfiniteYou/InfuseNetModel"),
("ByteDance/InfiniteYou", "infu_flux_v1.0/aes_stage2/image_proj_model.bin", "models/InfiniteYou"),
("ByteDance/InfiniteYou", "supports/insightface/models/antelopev2/1k3d68.onnx", "models/InfiniteYou/insightface/models/antelopev2"),
("ByteDance/InfiniteYou", "supports/insightface/models/antelopev2/2d106det.onnx", "models/InfiniteYou/insightface/models/antelopev2"),
("ByteDance/InfiniteYou", "supports/insightface/models/antelopev2/genderage.onnx", "models/InfiniteYou/insightface/models/antelopev2"),
("ByteDance/InfiniteYou", "supports/insightface/models/antelopev2/glintr100.onnx", "models/InfiniteYou/insightface/models/antelopev2"),
("ByteDance/InfiniteYou", "supports/insightface/models/antelopev2/scrfd_10g_bnkps.onnx", "models/InfiniteYou/insightface/models/antelopev2"),
],
"load_path":[
[
"models/InfiniteYou/InfuseNetModel/diffusion_pytorch_model-00001-of-00002.safetensors",
"models/InfiniteYou/InfuseNetModel/diffusion_pytorch_model-00002-of-00002.safetensors"
],
"models/InfiniteYou/image_proj_model.bin",
],
},
# ESRGAN # ESRGAN
"ESRGAN_x4": [ "ESRGAN_x4": [
("AI-ModelScope/Real-ESRGAN", "RealESRGAN_x4.pth", "models/ESRGAN"), ("AI-ModelScope/Real-ESRGAN", "RealESRGAN_x4.pth", "models/ESRGAN"),
@@ -785,7 +757,6 @@ Preset_model_id: TypeAlias = Literal[
"Shakker-Labs/FLUX.1-dev-ControlNet-Depth", "Shakker-Labs/FLUX.1-dev-ControlNet-Depth",
"Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro", "Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro",
"InstantX/FLUX.1-dev-IP-Adapter", "InstantX/FLUX.1-dev-IP-Adapter",
"InfiniteYou",
"SDXL_lora_zyd232_ChineseInkStyle_SDXL_v1_0", "SDXL_lora_zyd232_ChineseInkStyle_SDXL_v1_0",
"QwenPrompt", "QwenPrompt",
"OmostPrompt", "OmostPrompt",

View File

@@ -1,125 +0,0 @@
import torch, os, json, torchvision
from PIL import Image
from torchvision.transforms import v2
class SingleTaskDataset(torch.utils.data.Dataset):
def __init__(self, base_path, keys=(("image_1", "image_2", "editing_instruction"), ("image_2", "image_1", "reverse_editing_instruction")), height=1024, width=1024, random=True, steps_per_epoch=1000, metadata_path=None):
self.base_path = base_path
self.keys = keys
self.metadata = []
self.bad_data = []
self.height = height
self.width = width
self.random = random
self.steps_per_epoch = steps_per_epoch
self.image_process = v2.Compose([
v2.CenterCrop(size=(height, width)),
v2.ToImage(),
v2.ToDtype(torch.float32, scale=True),
v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])
if metadata_path is None:
self.search_for_data("", report_data_log=True)
self.report_data_log()
else:
with open(metadata_path, "r", encoding="utf-8-sig") as f:
self.metadata = json.load(f)
def report_data_log(self):
print(f"{len(self.metadata)} valid data, {len(self.bad_data)} invalid data.")
def dump_metadata(self, path):
with open(path, "w", encoding="utf-8") as f:
json.dump(self.metadata, f, ensure_ascii=False, indent=4)
def parse_json_file(self, absolute_path, relative_path):
data_list = []
with open(absolute_path, "r") as f:
metadata = json.load(f)
for image_1, image_2, instruction in self.keys:
image_1 = os.path.join(relative_path, metadata[image_1])
image_2 = os.path.join(relative_path, metadata[image_2])
instruction = metadata[instruction]
data_list.append((image_1, image_2, instruction))
return data_list
def search_for_data(self, path, report_data_log=False):
now_path = os.path.join(self.base_path, path)
if os.path.isfile(now_path) and path.endswith(".json"):
try:
data_list = self.parse_json_file(now_path, os.path.dirname(path))
self.metadata.extend(data_list)
except:
self.bad_data.append(now_path)
elif os.path.isdir(now_path):
for sub_path in os.listdir(now_path):
self.search_for_data(os.path.join(path, sub_path))
if report_data_log and os.path.isdir(os.path.join(self.base_path, path, sub_path)):
self.report_data_log()
def load_image(self, image_path):
image_path = os.path.join(self.base_path, image_path)
image = Image.open(image_path).convert("RGB")
width, height = image.size
scale = max(self.width / width, self.height / height)
image = torchvision.transforms.functional.resize(
image,
(round(height*scale), round(width*scale)),
interpolation=torchvision.transforms.InterpolationMode.BILINEAR
)
image = self.image_process(image)
return image
def load_data(self, data_id):
image_1, image_2, instruction = self.metadata[data_id]
image_1 = self.load_image(image_1)
image_2 = self.load_image(image_2)
return {"image_1": image_1, "image_2": image_2, "instruction": instruction}
def __getitem__(self, data_id):
if self.random:
while True:
try:
data_id = (torch.randint(0, len(self.metadata), (1,))[0] + data_id) % len(self.metadata)
data = self.load_data(data_id)
return data
except:
continue
else:
return self.load_data(data_id)
def __len__(self):
return self.steps_per_epoch if self.random else len(self.metadata)
class MultiTaskDataset(torch.utils.data.Dataset):
def __init__(self, dataset_list, dataset_weight, steps_per_epoch=1000):
self.dataset_list = dataset_list
self.dataset_weight = torch.tensor(dataset_weight, dtype=torch.float)
self.steps_per_epoch = steps_per_epoch
def __getitem__(self, data_id):
while True:
try:
dataset_id = torch.multinomial(self.dataset_weight, 1).tolist()[0]
data_id = torch.randint(0, len(self.dataset_list[dataset_id]), (1,))[0]
data = self.dataset_list[dataset_id][data_id]
return data
except:
continue
def __len__(self):
return self.steps_per_epoch
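
Note on `load_image` in the hunk above: the image is first resized so that it fully covers the target box, and `CenterCrop` then trims the excess. The size computation can be sketched in plain Python as follows. The helper name `cover_resize_size` is hypothetical and does not exist in the repository; it only mirrors the `scale = max(...)` arithmetic shown in the diff.

```python
def cover_resize_size(width, height, target_width, target_height):
    """Return (new_height, new_width) so the resized image covers the target box."""
    # Take the larger ratio so that neither side ends up smaller than the target.
    scale = max(target_width / width, target_height / height)
    return round(height * scale), round(width * scale)


# A 2048x1024 landscape image already matches the target height, so it keeps
# its size here; the later center crop removes the extra width.
print(cover_resize_size(2048, 1024, 1024, 1024))
```

Because the larger ratio is used, the crop never needs padding, at the cost of discarding content along the longer axis.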

View File

@@ -1,129 +0,0 @@
import torch
from typing import Optional
from einops import rearrange
from xfuser.core.distributed import (get_sequence_parallel_rank,
get_sequence_parallel_world_size,
get_sp_group)
from xfuser.core.long_ctx_attention import xFuserLongContextAttention
def sinusoidal_embedding_1d(dim, position):
sinusoid = torch.outer(position.type(torch.float64), torch.pow(
10000, -torch.arange(dim//2, dtype=torch.float64, device=position.device).div(dim//2)))
x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1)
return x.to(position.dtype)
def pad_freqs(original_tensor, target_len):
seq_len, s1, s2 = original_tensor.shape
pad_size = target_len - seq_len
padding_tensor = torch.ones(
pad_size,
s1,
s2,
dtype=original_tensor.dtype,
device=original_tensor.device)
padded_tensor = torch.cat([original_tensor, padding_tensor], dim=0)
return padded_tensor
def rope_apply(x, freqs, num_heads):
x = rearrange(x, "b s (n d) -> b s n d", n=num_heads)
s_per_rank = x.shape[1]
x_out = torch.view_as_complex(x.to(torch.float64).reshape(
x.shape[0], x.shape[1], x.shape[2], -1, 2))
sp_size = get_sequence_parallel_world_size()
sp_rank = get_sequence_parallel_rank()
freqs = pad_freqs(freqs, s_per_rank * sp_size)
freqs_rank = freqs[(sp_rank * s_per_rank):((sp_rank + 1) * s_per_rank), :, :]
x_out = torch.view_as_real(x_out * freqs_rank).flatten(2)
return x_out.to(x.dtype)
def usp_dit_forward(self,
x: torch.Tensor,
timestep: torch.Tensor,
context: torch.Tensor,
clip_feature: Optional[torch.Tensor] = None,
y: Optional[torch.Tensor] = None,
use_gradient_checkpointing: bool = False,
use_gradient_checkpointing_offload: bool = False,
**kwargs,
):
t = self.time_embedding(
sinusoidal_embedding_1d(self.freq_dim, timestep))
t_mod = self.time_projection(t).unflatten(1, (6, self.dim))
context = self.text_embedding(context)
if self.has_image_input:
x = torch.cat([x, y], dim=1) # (b, c_x + c_y, f, h, w)
clip_embdding = self.img_emb(clip_feature)
context = torch.cat([clip_embdding, context], dim=1)
x, (f, h, w) = self.patchify(x)
freqs = torch.cat([
self.freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
self.freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
self.freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
], dim=-1).reshape(f * h * w, 1, -1).to(x.device)
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs)
return custom_forward
# Context Parallel
x = torch.chunk(
x, get_sequence_parallel_world_size(),
dim=1)[get_sequence_parallel_rank()]
for block in self.blocks:
if self.training and use_gradient_checkpointing:
if use_gradient_checkpointing_offload:
with torch.autograd.graph.save_on_cpu():
x = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
x, context, t_mod, freqs,
use_reentrant=False,
)
else:
x = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
x, context, t_mod, freqs,
use_reentrant=False,
)
else:
x = block(x, context, t_mod, freqs)
x = self.head(x, t)
# Context Parallel
x = get_sp_group().all_gather(x, dim=1)
# unpatchify
x = self.unpatchify(x, (f, h, w))
return x
def usp_attn_forward(self, x, freqs):
q = self.norm_q(self.q(x))
k = self.norm_k(self.k(x))
v = self.v(x)
q = rope_apply(q, freqs, self.num_heads)
k = rope_apply(k, freqs, self.num_heads)
q = rearrange(q, "b s (n d) -> b s n d", n=self.num_heads)
k = rearrange(k, "b s (n d) -> b s n d", n=self.num_heads)
v = rearrange(v, "b s (n d) -> b s n d", n=self.num_heads)
x = xFuserLongContextAttention()(
None,
query=q,
key=k,
value=v,
)
x = x.flatten(2)
del q, k, v
torch.cuda.empty_cache()
return self.o(x)
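
Note on the "Context Parallel" steps in `usp_dit_forward` above: the token sequence is split along the sequence axis, each rank runs the blocks on its own slice, and an all-gather restores the full sequence. The following is a rough single-process sketch using plain lists. The helpers `chunk_sequence` and `gather_in_rank_order` are hypothetical stand-ins for `torch.chunk` and the xfuser group's `all_gather`, and the per-token "work" is a placeholder.

```python
import math


def chunk_sequence(tokens, world_size):
    """Split tokens into contiguous, ceil-sized slices (the torch.chunk convention)."""
    size = max(1, math.ceil(len(tokens) / world_size))
    return [tokens[i:i + size] for i in range(0, len(tokens), size)]


def gather_in_rank_order(per_rank_outputs):
    """Concatenate per-rank outputs along the sequence axis, in rank order."""
    return [tok for part in per_rank_outputs for tok in part]


tokens = list(range(8))
shards = chunk_sequence(tokens, world_size=4)

# Placeholder for the per-rank transformer blocks: each rank only sees its shard.
processed = [[t * 10 for t in shard] for shard in shards]

restored = gather_in_rank_order(processed)
print(shards)
print(restored)
```

The sketch assumes the sequence length divides evenly across ranks; in the diff, `pad_freqs` pads the rotary frequencies up to `s_per_rank * sp_size` before each rank slices out its own range.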

View File

@@ -5,7 +5,7 @@ import pathlib
import re import re
from copy import deepcopy from copy import deepcopy
from pathlib import Path from pathlib import Path
# from turtle import forward from turtle import forward
from typing import Any, Dict, Optional, Tuple, Union from typing import Any, Dict, Optional, Tuple, Union
import torch import torch

View File

@@ -318,8 +318,6 @@ class FluxControlNetStateDictConverter:
extra_kwargs = {"num_joint_blocks": 6, "num_single_blocks": 0, "additional_input_dim": 4} extra_kwargs = {"num_joint_blocks": 6, "num_single_blocks": 0, "additional_input_dim": 4}
elif hash_value == "0cfd1740758423a2a854d67c136d1e8c": elif hash_value == "0cfd1740758423a2a854d67c136d1e8c":
extra_kwargs = {"num_joint_blocks": 4, "num_single_blocks": 1} extra_kwargs = {"num_joint_blocks": 4, "num_single_blocks": 1}
elif hash_value == "7f9583eb8ba86642abb9a21a4b2c9e16":
extra_kwargs = {"num_joint_blocks": 4, "num_single_blocks": 10}
else: else:
extra_kwargs = {} extra_kwargs = {}
return state_dict_, extra_kwargs return state_dict_, extra_kwargs

View File

@@ -20,11 +20,10 @@ class RoPEEmbedding(torch.nn.Module):
self.axes_dim = axes_dim self.axes_dim = axes_dim
def rope(self, pos: torch.Tensor, dim: int, theta: int, device="cpu") -> torch.Tensor: def rope(self, pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor:
assert dim % 2 == 0, "The dimension must be even." assert dim % 2 == 0, "The dimension must be even."
scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
scale = scale.to(device)
omega = 1.0 / (theta**scale) omega = 1.0 / (theta**scale)
batch_size, seq_length = pos.shape batch_size, seq_length = pos.shape
@@ -37,9 +36,9 @@ class RoPEEmbedding(torch.nn.Module):
return out.float() return out.float()
def forward(self, ids, device="cpu"): def forward(self, ids):
n_axes = ids.shape[-1] n_axes = ids.shape[-1]
emb = torch.cat([self.rope(ids[..., i], self.axes_dim[i], self.theta, device) for i in range(n_axes)], dim=-3) emb = torch.cat([self.rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)], dim=-3)
return emb.unsqueeze(1) return emb.unsqueeze(1)
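
Note on `rope` in the hunk above: the frequencies come from `omega = 1.0 / (theta**scale)` with `scale = arange(0, dim, 2) / dim`. A minimal plain-Python sketch of that schedule follows. The function names are hypothetical, and the sketch returns (cos, sin) pairs rather than the tensor layout the module produces.

```python
import math


def rope_frequencies(dim, theta):
    """One rotation frequency per pair of channels: 1 / theta**(i / dim), i = 0, 2, 4, ..."""
    assert dim % 2 == 0, "The dimension must be even."
    return [1.0 / (theta ** (i / dim)) for i in range(0, dim, 2)]


def rope_rotation(pos, dim, theta):
    """(cos, sin) of the rotation angle for each channel pair at a given position."""
    return [(math.cos(pos * w), math.sin(pos * w)) for w in rope_frequencies(dim, theta)]


freqs = rope_frequencies(4, 10000)
print(freqs)
print(rope_rotation(0, 4, 10000))
```

Position 0 gives the identity rotation for every pair, and higher channel pairs rotate more slowly as the position grows.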

View File

@@ -1,128 +0,0 @@
import math
import torch
import torch.nn as nn
# FFN
def FeedForward(dim, mult=4):
inner_dim = int(dim * mult)
return nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, inner_dim, bias=False),
nn.GELU(),
nn.Linear(inner_dim, dim, bias=False),
)
def reshape_tensor(x, heads):
bs, length, width = x.shape
#(bs, length, width) --> (bs, length, n_heads, dim_per_head)
x = x.view(bs, length, heads, -1)
# (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head)
x = x.transpose(1, 2)
# (bs, n_heads, length, dim_per_head) --> same shape (bs, n_heads, length, dim_per_head); the reshape does not merge the batch and head axes
x = x.reshape(bs, heads, length, -1)
return x
class PerceiverAttention(nn.Module):
def __init__(self, *, dim, dim_head=64, heads=8):
super().__init__()
self.scale = dim_head**-0.5
self.dim_head = dim_head
self.heads = heads
inner_dim = dim_head * heads
self.norm1 = nn.LayerNorm(dim)
self.norm2 = nn.LayerNorm(dim)
self.to_q = nn.Linear(dim, inner_dim, bias=False)
self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
self.to_out = nn.Linear(inner_dim, dim, bias=False)
def forward(self, x, latents):
"""
Args:
x (torch.Tensor): image features
shape (b, n1, D)
latent (torch.Tensor): latent features
shape (b, n2, D)
"""
x = self.norm1(x)
latents = self.norm2(latents)
b, l, _ = latents.shape
q = self.to_q(latents)
kv_input = torch.cat((x, latents), dim=-2)
k, v = self.to_kv(kv_input).chunk(2, dim=-1)
q = reshape_tensor(q, self.heads)
k = reshape_tensor(k, self.heads)
v = reshape_tensor(v, self.heads)
# attention
scale = 1 / math.sqrt(math.sqrt(self.dim_head))
weight = (q * scale) @ (k * scale).transpose(-2, -1) # More stable with f16 than dividing afterwards
weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
out = weight @ v
out = out.permute(0, 2, 1, 3).reshape(b, l, -1)
return self.to_out(out)
class InfiniteYouImageProjector(nn.Module):
def __init__(
self,
dim=1280,
depth=4,
dim_head=64,
heads=20,
num_queries=8,
embedding_dim=512,
output_dim=4096,
ff_mult=4,
):
super().__init__()
self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim**0.5)
self.proj_in = nn.Linear(embedding_dim, dim)
self.proj_out = nn.Linear(dim, output_dim)
self.norm_out = nn.LayerNorm(output_dim)
self.layers = nn.ModuleList([])
for _ in range(depth):
self.layers.append(
nn.ModuleList([
PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
FeedForward(dim=dim, mult=ff_mult),
]))
def forward(self, x):
latents = self.latents.repeat(x.size(0), 1, 1)
x = self.proj_in(x)
for attn, ff in self.layers:
latents = attn(x, latents) + latents
latents = ff(latents) + latents
latents = self.proj_out(latents)
return self.norm_out(latents)
@staticmethod
def state_dict_converter():
return FluxInfiniteYouImageProjectorStateDictConverter()
class FluxInfiniteYouImageProjectorStateDictConverter:
def __init__(self):
pass
def from_diffusers(self, state_dict):
return state_dict['image_proj']
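
Note on the attention scaling in `PerceiverAttention.forward` above: `q` and `k` are each multiplied by `dim_head ** -0.25` before the matrix product, which is mathematically the same as dividing the product by `sqrt(dim_head)` afterwards but keeps the intermediate values smaller. A small numeric check in plain Python, with hypothetical function names and made-up vectors:

```python
import math


def scaled_dot_pre(q, k, dim_head):
    """Scale q and k separately before the dot product, as in the diff."""
    scale = 1 / math.sqrt(math.sqrt(dim_head))
    return sum((a * scale) * (b * scale) for a, b in zip(q, k))


def scaled_dot_post(q, k, dim_head):
    """Textbook form: dot product first, then divide by sqrt(dim_head)."""
    return sum(a * b for a, b in zip(q, k)) / math.sqrt(dim_head)


q = [1.0, 2.0, 3.0, 4.0]
k = [0.5, -1.0, 2.0, 0.25]
pre = scaled_dot_pre(q, k, dim_head=4)
post = scaled_dot_post(q, k, dim_head=4)
print(pre, post)
```

The two forms agree up to floating-point rounding; the difference only matters in low-precision dtypes, where the unscaled product can overflow.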

View File

@@ -1,31 +0,0 @@
from .sd3_dit import TimestepEmbeddings
from .flux_dit import RoPEEmbedding
import torch
from einops import repeat
class FluxReferenceEmbedder(torch.nn.Module):
def __init__(self):
super().__init__()
self.pos_embedder = RoPEEmbedding(3072, 10000, [16, 56, 56])
self.idx_embedder = TimestepEmbeddings(256, 256)
self.proj = torch.nn.Linear(3072, 3072)
def forward(self, image_ids, idx, dtype, device):
pos_emb = self.pos_embedder(image_ids, device=device)
idx_emb = self.idx_embedder(idx, dtype=dtype).to(device)
length = pos_emb.shape[2]
pos_emb = repeat(pos_emb, "B N L C H W -> 1 N (B L) C H W")
idx_emb = repeat(idx_emb, "B (C H W) -> 1 1 (B L) C H W", C=64, H=2, W=2, L=length)
image_rotary_emb = pos_emb + idx_emb
return image_rotary_emb
def init(self):
self.idx_embedder.timestep_embedder[-1].load_state_dict({
"weight": torch.zeros((256, 256)),
"bias": torch.zeros((256,))
}),
self.proj.load_state_dict({
"weight": torch.eye(3072),
"bias": torch.zeros((3072,))
})

View File

@@ -367,20 +367,5 @@ class FluxLoRAConverter:
return state_dict_ return state_dict_
class WanLoRAConverter:
def __init__(self):
pass
@staticmethod
def align_to_opensource_format(state_dict, **kwargs):
state_dict = {"diffusion_model." + name.replace(".default.", "."): param for name, param in state_dict.items()}
return state_dict
@staticmethod
def align_to_diffsynth_format(state_dict, **kwargs):
state_dict = {name.replace("diffusion_model.", "").replace(".lora_A.weight", ".lora_A.default.weight").replace(".lora_B.weight", ".lora_B.default.weight"): param for name, param in state_dict.items()}
return state_dict
def get_lora_loaders(): def get_lora_loaders():
return [SDLoRAFromCivitai(), SDXLLoRAFromCivitai(), FluxLoRAFromCivitai(), HunyuanVideoLoRAFromCivitai(), GeneralLoRAFromPeft()] return [SDLoRAFromCivitai(), SDXLLoRAFromCivitai(), FluxLoRAFromCivitai(), HunyuanVideoLoRAFromCivitai(), GeneralLoRAFromPeft()]
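
Note on `WanLoRAConverter` in the hunk above: the two static methods only rename state-dict keys. The sketch below applies the same string replacements to a single key to show that the two directions round-trip. The function names and the example key are hypothetical; only the replacement rules are taken from the diff.

```python
def to_opensource_key(name):
    """DiffSynth-style key -> community-style key (adds prefix, drops '.default.')."""
    return "diffusion_model." + name.replace(".default.", ".")


def to_diffsynth_key(name):
    """Community-style key -> DiffSynth-style key (drops prefix, restores '.default.')."""
    return (
        name.replace("diffusion_model.", "")
        .replace(".lora_A.weight", ".lora_A.default.weight")
        .replace(".lora_B.weight", ".lora_B.default.weight")
    )


# Hypothetical parameter name, chosen only to show the renaming.
key = "blocks.0.self_attn.q.lora_A.default.weight"
exported = to_opensource_key(key)
restored = to_diffsynth_key(exported)
print(exported)
print(restored == key)
```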

View File

@@ -0,0 +1,204 @@
import torch
import torch.nn as nn
from typing import Tuple, Optional
from einops import rearrange
from .wan_video_dit import DiTBlock, precompute_freqs_cis_3d, MLP, sinusoidal_embedding_1d
from .utils import hash_state_dict_keys
class WanControlNetModel(torch.nn.Module):
def __init__(
self,
dim: int,
in_dim: int,
ffn_dim: int,
out_dim: int,
text_dim: int,
freq_dim: int,
eps: float,
patch_size: Tuple[int, int, int],
num_heads: int,
num_layers: int,
has_image_input: bool,
):
super().__init__()
self.dim = dim
self.freq_dim = freq_dim
self.has_image_input = has_image_input
self.patch_size = patch_size
self.patch_embedding = nn.Conv3d(
in_dim, dim, kernel_size=patch_size, stride=patch_size)
self.text_embedding = nn.Sequential(
nn.Linear(text_dim, dim),
nn.GELU(approximate='tanh'),
nn.Linear(dim, dim)
)
self.time_embedding = nn.Sequential(
nn.Linear(freq_dim, dim),
nn.SiLU(),
nn.Linear(dim, dim)
)
self.time_projection = nn.Sequential(
nn.SiLU(), nn.Linear(dim, dim * 6))
self.blocks = nn.ModuleList([
DiTBlock(has_image_input, dim, num_heads, ffn_dim, eps)
for _ in range(num_layers)
])
head_dim = dim // num_heads
self.freqs = precompute_freqs_cis_3d(head_dim)
if has_image_input:
self.img_emb = MLP(1280, dim) # clip_feature_dim = 1280
self.controlnet_conv_in = torch.nn.Conv3d(in_channels=in_dim, out_channels=in_dim, kernel_size=1)
self.controlnet_blocks = torch.nn.ModuleList([
torch.nn.Linear(dim, dim, bias=False)
for _ in range(num_layers)
])
def patchify(self, x: torch.Tensor):
x = self.patch_embedding(x)
grid_size = x.shape[2:]
x = rearrange(x, 'b c f h w -> b (f h w) c').contiguous()
return x, grid_size # x, grid_size: (f, h, w)
def unpatchify(self, x: torch.Tensor, grid_size: torch.Tensor):
return rearrange(
x, 'b (f h w) (x y z c) -> b c (f x) (h y) (w z)',
f=grid_size[0], h=grid_size[1], w=grid_size[2],
x=self.patch_size[0], y=self.patch_size[1], z=self.patch_size[2]
)
def forward(self,
x: torch.Tensor,
timestep: torch.Tensor,
context: torch.Tensor,
clip_feature: Optional[torch.Tensor] = None,
y: Optional[torch.Tensor] = None,
controlnet_conditioning: Optional[torch.Tensor] = None,
use_gradient_checkpointing: bool = False,
use_gradient_checkpointing_offload: bool = False,
**kwargs,
):
t = self.time_embedding(
sinusoidal_embedding_1d(self.freq_dim, timestep))
t_mod = self.time_projection(t).unflatten(1, (6, self.dim))
context = self.text_embedding(context)
if self.has_image_input:
x = torch.cat([x, y], dim=1) # (b, c_x + c_y, f, h, w)
clip_embdding = self.img_emb(clip_feature)
context = torch.cat([clip_embdding, context], dim=1)
x = x + self.controlnet_conv_in(controlnet_conditioning)
x, (f, h, w) = self.patchify(x)
freqs = torch.cat([
self.freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
self.freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
self.freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
], dim=-1).reshape(f * h * w, 1, -1).to(x.device)
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs)
return custom_forward
res_stack = []
for block in self.blocks:
if self.training and use_gradient_checkpointing:
if use_gradient_checkpointing_offload:
with torch.autograd.graph.save_on_cpu():
x = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
x, context, t_mod, freqs,
use_reentrant=False,
)
else:
x = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
x, context, t_mod, freqs,
use_reentrant=False,
)
else:
x = block(x, context, t_mod, freqs)
res_stack.append(x)
controlnet_res_stack = [block(res) for block, res in zip(self.controlnet_blocks, res_stack)]
return controlnet_res_stack
@staticmethod
def state_dict_converter():
return WanControlNetModelStateDictConverter()
class WanControlNetModelStateDictConverter:
def __init__(self):
pass
def from_diffusers(self, state_dict):
return state_dict
def from_civitai(self, state_dict):
return state_dict
def from_base_model(self, state_dict):
if hash_state_dict_keys(state_dict) == "9269f8db9040a9d860eaca435be61814":
config = {
"has_image_input": False,
"patch_size": [1, 2, 2],
"in_dim": 16,
"dim": 1536,
"ffn_dim": 8960,
"freq_dim": 256,
"text_dim": 4096,
"out_dim": 16,
"num_heads": 12,
"num_layers": 30,
"eps": 1e-6
}
elif hash_state_dict_keys(state_dict) == "aafcfd9672c3a2456dc46e1cb6e52c70":
config = {
"has_image_input": False,
"patch_size": [1, 2, 2],
"in_dim": 16,
"dim": 5120,
"ffn_dim": 13824,
"freq_dim": 256,
"text_dim": 4096,
"out_dim": 16,
"num_heads": 40,
"num_layers": 40,
"eps": 1e-6
}
elif hash_state_dict_keys(state_dict) == "6bfcfb3b342cb286ce886889d519a77e":
config = {
"has_image_input": True,
"patch_size": [1, 2, 2],
"in_dim": 36,
"dim": 5120,
"ffn_dim": 13824,
"freq_dim": 256,
"text_dim": 4096,
"out_dim": 16,
"num_heads": 40,
"num_layers": 40,
"eps": 1e-6
}
else:
config = {}
state_dict_ = {}
dtype, device = None, None
for name, param in state_dict.items():
if name.startswith("head."):
continue
state_dict_[name] = param
dtype, device = param.dtype, param.device
for block_id in range(config["num_layers"]):
zeros = torch.zeros((config["dim"], config["dim"]), dtype=dtype, device=device)
state_dict_[f"controlnet_blocks.{block_id}.weight"] = zeros.clone()
state_dict_["controlnet_conv_in.weight"] = torch.zeros((config["in_dim"], config["in_dim"], 1, 1, 1), dtype=dtype, device=device)
state_dict_["controlnet_conv_in.bias"] = torch.zeros((config["in_dim"],), dtype=dtype, device=device)
return state_dict_, config
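
Note on `from_base_model` above: when the key hash matches none of the three known checkpoints, `config` is left as `{}`, so the later `config["num_layers"]` lookup raises a `KeyError`. The following is a minimal, torch-free sketch of the same initialization plan with an explicit guard. `plan_controlnet_init` is a hypothetical helper that is not part of the repository, and it records parameter shapes only instead of allocating zero tensors.

```python
def plan_controlnet_init(config, base_names):
    """Hypothetical helper: list which base weights are kept and which
    zero-initialized ControlNet parameters (name -> shape) are added."""
    required = ("num_layers", "dim", "in_dim")
    missing = [key for key in required if key not in config]
    if missing:
        # Covers the unmatched-hash case, but fails with a clear message.
        raise ValueError(f"unrecognized base checkpoint, missing config keys: {missing}")

    # The output head is dropped; every other base weight is copied as-is.
    kept = [name for name in base_names if not name.startswith("head.")]

    # One zero (dim x dim) weight per transformer block.
    added = {
        f"controlnet_blocks.{i}.weight": (config["dim"], config["dim"])
        for i in range(config["num_layers"])
    }
    # Zero-initialized input convolution with a 1x1x1 kernel, plus its bias.
    added["controlnet_conv_in.weight"] = (config["in_dim"], config["in_dim"], 1, 1, 1)
    added["controlnet_conv_in.bias"] = (config["in_dim"],)
    return kept, added
```

Raising early keeps the failure next to its cause (an unknown checkpoint) rather than surfacing as a missing dictionary key a few lines later.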


@@ -183,13 +183,6 @@ class CrossAttention(nn.Module):
return self.o(x) return self.o(x)
class GateModule(nn.Module):
def __init__(self,):
super().__init__()
def forward(self, x, gate, residual):
return x + gate * residual
class DiTBlock(nn.Module): class DiTBlock(nn.Module):
def __init__(self, has_image_input: bool, dim: int, num_heads: int, ffn_dim: int, eps: float = 1e-6): def __init__(self, has_image_input: bool, dim: int, num_heads: int, ffn_dim: int, eps: float = 1e-6):
super().__init__() super().__init__()
@@ -206,17 +199,16 @@ class DiTBlock(nn.Module):
self.ffn = nn.Sequential(nn.Linear(dim, ffn_dim), nn.GELU( self.ffn = nn.Sequential(nn.Linear(dim, ffn_dim), nn.GELU(
approximate='tanh'), nn.Linear(ffn_dim, dim)) approximate='tanh'), nn.Linear(ffn_dim, dim))
self.modulation = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5) self.modulation = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5)
self.gate = GateModule()
def forward(self, x, context, t_mod, freqs): def forward(self, x, context, t_mod, freqs):
# msa: multi-head self-attention mlp: multi-layer perceptron # msa: multi-head self-attention mlp: multi-layer perceptron
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ( shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
self.modulation.to(dtype=t_mod.dtype, device=t_mod.device) + t_mod).chunk(6, dim=1) self.modulation.to(dtype=t_mod.dtype, device=t_mod.device) + t_mod).chunk(6, dim=1)
input_x = modulate(self.norm1(x), shift_msa, scale_msa) input_x = modulate(self.norm1(x), shift_msa, scale_msa)
x = self.gate(x, gate_msa, self.self_attn(input_x, freqs)) x = x + gate_msa * self.self_attn(input_x, freqs)
x = x + self.cross_attn(self.norm3(x), context) x = x + self.cross_attn(self.norm3(x), context)
input_x = modulate(self.norm2(x), shift_mlp, scale_mlp) input_x = modulate(self.norm2(x), shift_mlp, scale_mlp)
x = self.gate(x, gate_mlp, self.ffn(input_x)) x = x + gate_mlp * self.ffn(input_x)
return x return x
@@ -493,62 +485,6 @@ class WanModelStateDictConverter:
"num_layers": 40, "num_layers": 40,
"eps": 1e-6 "eps": 1e-6
} }
elif hash_state_dict_keys(state_dict) == "6d6ccde6845b95ad9114ab993d917893":
config = {
"has_image_input": True,
"patch_size": [1, 2, 2],
"in_dim": 36,
"dim": 1536,
"ffn_dim": 8960,
"freq_dim": 256,
"text_dim": 4096,
"out_dim": 16,
"num_heads": 12,
"num_layers": 30,
"eps": 1e-6
}
elif hash_state_dict_keys(state_dict) == "6bfcfb3b342cb286ce886889d519a77e":
config = {
"has_image_input": True,
"patch_size": [1, 2, 2],
"in_dim": 36,
"dim": 5120,
"ffn_dim": 13824,
"freq_dim": 256,
"text_dim": 4096,
"out_dim": 16,
"num_heads": 40,
"num_layers": 40,
"eps": 1e-6
}
elif hash_state_dict_keys(state_dict) == "349723183fc063b2bfc10bb2835cf677":
config = {
"has_image_input": True,
"patch_size": [1, 2, 2],
"in_dim": 48,
"dim": 1536,
"ffn_dim": 8960,
"freq_dim": 256,
"text_dim": 4096,
"out_dim": 16,
"num_heads": 12,
"num_layers": 30,
"eps": 1e-6
}
elif hash_state_dict_keys(state_dict) == "efa44cddf936c70abd0ea28b6cbe946c":
config = {
"has_image_input": True,
"patch_size": [1, 2, 2],
"in_dim": 48,
"dim": 5120,
"ffn_dim": 13824,
"freq_dim": 256,
"text_dim": 4096,
"out_dim": 16,
"num_heads": 40,
"num_layers": 40,
"eps": 1e-6
}
else: else:
config = {} config = {}
return state_dict, config return state_dict, config
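
The `if`/`elif` chains in these converters map a hash of the checkpoint's keys to an architecture config. The same mapping can be sketched as a lookup table, shown here for the three hashes that appear in the ControlNet converter. This is an illustrative alternative rather than the repository's code: `lookup_config`, `COMMON`, `SMALL`, and `LARGE` are hypothetical names, and the hash string is assumed to come from the project's `hash_state_dict_keys`.

```python
import copy

# Fields shared by every config shown in the diff.
COMMON = {"patch_size": [1, 2, 2], "freq_dim": 256, "text_dim": 4096, "out_dim": 16, "eps": 1e-6}
# The two transformer sizes that appear in the diff (labels are hypothetical).
SMALL = {"dim": 1536, "ffn_dim": 8960, "num_heads": 12, "num_layers": 30}
LARGE = {"dim": 5120, "ffn_dim": 13824, "num_heads": 40, "num_layers": 40}

KNOWN_CONFIGS = {
    "9269f8db9040a9d860eaca435be61814": {**COMMON, **SMALL, "has_image_input": False, "in_dim": 16},
    "aafcfd9672c3a2456dc46e1cb6e52c70": {**COMMON, **LARGE, "has_image_input": False, "in_dim": 16},
    "6bfcfb3b342cb286ce886889d519a77e": {**COMMON, **LARGE, "has_image_input": True, "in_dim": 36},
}


def lookup_config(key_hash):
    """Return a private copy of the config for a known hash, or {} if unknown."""
    # Deep copy so that callers cannot mutate the shared table (patch_size is a list).
    return copy.deepcopy(KNOWN_CONFIGS.get(key_hash, {}))
```

A table keeps each supported checkpoint to one line, which makes additions and removals like the ones in this hunk easier to review.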


@@ -25,20 +25,3 @@ class WanMotionControllerModel(torch.nn.Module):
state_dict = self.linear[-1].state_dict() state_dict = self.linear[-1].state_dict()
state_dict = {i: state_dict[i] * 0 for i in state_dict} state_dict = {i: state_dict[i] * 0 for i in state_dict}
self.linear[-1].load_state_dict(state_dict) self.linear[-1].load_state_dict(state_dict)
@staticmethod
def state_dict_converter():
return WanMotionControllerModelDictConverter()
class WanMotionControllerModelDictConverter:
def __init__(self):
pass
def from_diffusers(self, state_dict):
return state_dict
def from_civitai(self, state_dict):
return state_dict


@@ -1,12 +1,10 @@
from ..models import ModelManager, FluxDiT, SD3TextEncoder1, FluxTextEncoder2, FluxVAEDecoder, FluxVAEEncoder, FluxIpAdapter from ..models import ModelManager, FluxDiT, SD3TextEncoder1, FluxTextEncoder2, FluxVAEDecoder, FluxVAEEncoder, FluxIpAdapter
from ..controlnets import FluxMultiControlNetManager, ControlNetUnit, ControlNetConfigUnit, Annotator from ..controlnets import FluxMultiControlNetManager, ControlNetUnit, ControlNetConfigUnit, Annotator
from ..models.flux_reference_embedder import FluxReferenceEmbedder
from ..prompters import FluxPrompter from ..prompters import FluxPrompter
from ..schedulers import FlowMatchScheduler from ..schedulers import FlowMatchScheduler
from .base import BasePipeline from .base import BasePipeline
from typing import List from typing import List
import torch import torch
from einops import rearrange
from tqdm import tqdm from tqdm import tqdm
import numpy as np import numpy as np
from PIL import Image from PIL import Image
@@ -33,8 +31,6 @@ class FluxImagePipeline(BasePipeline):
self.controlnet: FluxMultiControlNetManager = None self.controlnet: FluxMultiControlNetManager = None
self.ipadapter: FluxIpAdapter = None self.ipadapter: FluxIpAdapter = None
self.ipadapter_image_encoder: SiglipVisionModel = None self.ipadapter_image_encoder: SiglipVisionModel = None
self.infinityou_processor: InfinitYou = None
self.reference_embedder: FluxReferenceEmbedder = None
self.model_names = ['text_encoder_1', 'text_encoder_2', 'dit', 'vae_decoder', 'vae_encoder', 'controlnet', 'ipadapter', 'ipadapter_image_encoder'] self.model_names = ['text_encoder_1', 'text_encoder_2', 'dit', 'vae_decoder', 'vae_encoder', 'controlnet', 'ipadapter', 'ipadapter_image_encoder']
@@ -166,11 +162,6 @@ class FluxImagePipeline(BasePipeline):
self.ipadapter = model_manager.fetch_model("flux_ipadapter") self.ipadapter = model_manager.fetch_model("flux_ipadapter")
self.ipadapter_image_encoder = model_manager.fetch_model("siglip_vision_model") self.ipadapter_image_encoder = model_manager.fetch_model("siglip_vision_model")
# InfiniteYou
self.image_proj_model = model_manager.fetch_model("infiniteyou_image_projector")
if self.image_proj_model is not None:
self.infinityou_processor = InfinitYou(device=self.device)
@staticmethod @staticmethod
def from_model_manager(model_manager: ModelManager, controlnet_config_units: List[ControlNetConfigUnit]=[], prompt_refiner_classes=[], prompt_extender_classes=[], device=None, torch_dtype=None): def from_model_manager(model_manager: ModelManager, controlnet_config_units: List[ControlNetConfigUnit]=[], prompt_refiner_classes=[], prompt_extender_classes=[], device=None, torch_dtype=None):
@@ -358,27 +349,6 @@ class FluxImagePipeline(BasePipeline):
return prompt_emb_posi, prompt_emb_nega, prompt_emb_locals return prompt_emb_posi, prompt_emb_nega, prompt_emb_locals
def prepare_infinite_you(self, id_image, controlnet_image, infinityou_guidance, height, width):
if self.infinityou_processor is not None and id_image is not None:
return self.infinityou_processor.prepare_infinite_you(self.image_proj_model, id_image, controlnet_image, infinityou_guidance, height, width)
else:
return {}, controlnet_image
def prepare_reference_images(self, reference_images, tiled=False, tile_size=64, tile_stride=32):
if reference_images is not None:
hidden_states_ref = []
for reference_image in reference_images:
self.load_models_to_device(['vae_encoder'])
reference_image = self.preprocess_image(reference_image).to(device=self.device, dtype=self.torch_dtype)
latents = self.encode_image(reference_image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
hidden_states_ref.append(latents)
hidden_states_ref = torch.concat(hidden_states_ref, dim=0)
return {"hidden_states_ref": hidden_states_ref}
else:
return {"hidden_states_ref": None}
@torch.no_grad() @torch.no_grad()
def __call__( def __call__(
self, self,
@@ -412,11 +382,6 @@ class FluxImagePipeline(BasePipeline):
eligen_entity_masks=None, eligen_entity_masks=None,
enable_eligen_on_negative=False, enable_eligen_on_negative=False,
enable_eligen_inpaint=False, enable_eligen_inpaint=False,
# InfiniteYou
infinityou_id_image=None,
infinityou_guidance=1.0,
# Reference images
reference_images=None,
# TeaCache # TeaCache
tea_cache_l1_thresh=None, tea_cache_l1_thresh=None,
# Tile # Tile
@@ -444,9 +409,6 @@ class FluxImagePipeline(BasePipeline):
# Extra input # Extra input
extra_input = self.prepare_extra_input(latents, guidance=embedded_guidance) extra_input = self.prepare_extra_input(latents, guidance=embedded_guidance)
# InfiniteYou
infiniteyou_kwargs, controlnet_image = self.prepare_infinite_you(infinityou_id_image, controlnet_image, infinityou_guidance, height, width)
# Entity control # Entity control
eligen_kwargs_posi, eligen_kwargs_nega, fg_mask, bg_mask = self.prepare_eligen(prompt_emb_nega, eligen_entity_prompts, eligen_entity_masks, width, height, t5_sequence_length, enable_eligen_inpaint, enable_eligen_on_negative, cfg_scale) eligen_kwargs_posi, eligen_kwargs_nega, fg_mask, bg_mask = self.prepare_eligen(prompt_emb_nega, eligen_entity_prompts, eligen_entity_masks, width, height, t5_sequence_length, enable_eligen_inpaint, enable_eligen_on_negative, cfg_scale)
@@ -456,9 +418,6 @@ class FluxImagePipeline(BasePipeline):
# ControlNets # ControlNets
controlnet_kwargs_posi, controlnet_kwargs_nega, local_controlnet_kwargs = self.prepare_controlnet(controlnet_image, masks, controlnet_inpaint_mask, tiler_kwargs, enable_controlnet_on_negative) controlnet_kwargs_posi, controlnet_kwargs_nega, local_controlnet_kwargs = self.prepare_controlnet(controlnet_image, masks, controlnet_inpaint_mask, tiler_kwargs, enable_controlnet_on_negative)
# Reference images
reference_kwargs = self.prepare_reference_images(reference_images, **tiler_kwargs)
# TeaCache # TeaCache
tea_cache_kwargs = {"tea_cache": TeaCache(num_inference_steps, rel_l1_thresh=tea_cache_l1_thresh) if tea_cache_l1_thresh is not None else None} tea_cache_kwargs = {"tea_cache": TeaCache(num_inference_steps, rel_l1_thresh=tea_cache_l1_thresh) if tea_cache_l1_thresh is not None else None}
@@ -469,9 +428,9 @@ class FluxImagePipeline(BasePipeline):
# Positive side # Positive side
inference_callback = lambda prompt_emb_posi, controlnet_kwargs: lets_dance_flux( inference_callback = lambda prompt_emb_posi, controlnet_kwargs: lets_dance_flux(
dit=self.dit, controlnet=self.controlnet, reference_embedder=self.reference_embedder, dit=self.dit, controlnet=self.controlnet,
hidden_states=latents, timestep=timestep, hidden_states=latents, timestep=timestep,
**prompt_emb_posi, **tiler_kwargs, **extra_input, **controlnet_kwargs, **ipadapter_kwargs_list_posi, **eligen_kwargs_posi, **tea_cache_kwargs, **infiniteyou_kwargs, **reference_kwargs, **prompt_emb_posi, **tiler_kwargs, **extra_input, **controlnet_kwargs, **ipadapter_kwargs_list_posi, **eligen_kwargs_posi, **tea_cache_kwargs,
) )
noise_pred_posi = self.control_noise_via_local_prompts( noise_pred_posi = self.control_noise_via_local_prompts(
prompt_emb_posi, prompt_emb_locals, masks, mask_scales, inference_callback, prompt_emb_posi, prompt_emb_locals, masks, mask_scales, inference_callback,
@@ -486,9 +445,9 @@ class FluxImagePipeline(BasePipeline):
if cfg_scale != 1.0: if cfg_scale != 1.0:
# Negative side # Negative side
noise_pred_nega = lets_dance_flux( noise_pred_nega = lets_dance_flux(
dit=self.dit, controlnet=self.controlnet, reference_embedder=self.reference_embedder, dit=self.dit, controlnet=self.controlnet,
hidden_states=latents, timestep=timestep, hidden_states=latents, timestep=timestep,
**prompt_emb_nega, **tiler_kwargs, **extra_input, **controlnet_kwargs_nega, **ipadapter_kwargs_list_nega, **eligen_kwargs_nega, **infiniteyou_kwargs, **reference_kwargs, **prompt_emb_nega, **tiler_kwargs, **extra_input, **controlnet_kwargs_nega, **ipadapter_kwargs_list_nega, **eligen_kwargs_nega,
) )
noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega) noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
else: else:
@@ -510,58 +469,6 @@ class FluxImagePipeline(BasePipeline):
return image return image
class InfinitYou:
def __init__(self, device="cuda", torch_dtype=torch.bfloat16):
from facexlib.recognition import init_recognition_model
from insightface.app import FaceAnalysis
self.device = device
self.torch_dtype = torch_dtype
insightface_root_path = 'models/InfiniteYou/insightface'
self.app_640 = FaceAnalysis(name='antelopev2', root=insightface_root_path, providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
self.app_640.prepare(ctx_id=0, det_size=(640, 640))
self.app_320 = FaceAnalysis(name='antelopev2', root=insightface_root_path, providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
self.app_320.prepare(ctx_id=0, det_size=(320, 320))
self.app_160 = FaceAnalysis(name='antelopev2', root=insightface_root_path, providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
self.app_160.prepare(ctx_id=0, det_size=(160, 160))
self.arcface_model = init_recognition_model('arcface', device=self.device)
def _detect_face(self, id_image_cv2):
face_info = self.app_640.get(id_image_cv2)
if len(face_info) > 0:
return face_info
face_info = self.app_320.get(id_image_cv2)
if len(face_info) > 0:
return face_info
face_info = self.app_160.get(id_image_cv2)
return face_info
def extract_arcface_bgr_embedding(self, in_image, landmark):
from insightface.utils import face_align
arc_face_image = face_align.norm_crop(in_image, landmark=np.array(landmark), image_size=112)
arc_face_image = torch.from_numpy(arc_face_image).unsqueeze(0).permute(0, 3, 1, 2) / 255.
arc_face_image = 2 * arc_face_image - 1
arc_face_image = arc_face_image.contiguous().to(self.device)
face_emb = self.arcface_model(arc_face_image)[0] # [512], normalized
return face_emb
def prepare_infinite_you(self, model, id_image, controlnet_image, infinityou_guidance, height, width):
import cv2
if id_image is None:
return {'id_emb': None}, controlnet_image
id_image_cv2 = cv2.cvtColor(np.array(id_image), cv2.COLOR_RGB2BGR)
face_info = self._detect_face(id_image_cv2)
if len(face_info) == 0:
raise ValueError('No face detected in the input ID image')
landmark = sorted(face_info, key=lambda x:(x['bbox'][2]-x['bbox'][0])*(x['bbox'][3]-x['bbox'][1]))[-1]['kps'] # only use the maximum face
id_emb = self.extract_arcface_bgr_embedding(id_image_cv2, landmark)
id_emb = model(id_emb.unsqueeze(0).reshape([1, -1, 512]).to(dtype=self.torch_dtype))
if controlnet_image is None:
controlnet_image = Image.fromarray(np.zeros([height, width, 3]).astype(np.uint8))
infinityou_guidance = torch.Tensor([infinityou_guidance]).to(device=self.device, dtype=self.torch_dtype)
return {'id_emb': id_emb, 'infinityou_guidance': infinityou_guidance}, controlnet_image
class TeaCache: class TeaCache:
def __init__(self, num_inference_steps, rel_l1_thresh): def __init__(self, num_inference_steps, rel_l1_thresh):
self.num_inference_steps = num_inference_steps self.num_inference_steps = num_inference_steps
@@ -608,7 +515,6 @@ class TeaCache:
def lets_dance_flux( def lets_dance_flux(
dit: FluxDiT, dit: FluxDiT,
controlnet: FluxMultiControlNetManager = None, controlnet: FluxMultiControlNetManager = None,
reference_embedder: FluxReferenceEmbedder = None,
hidden_states=None, hidden_states=None,
timestep=None, timestep=None,
prompt_emb=None, prompt_emb=None,
@@ -617,17 +523,13 @@ def lets_dance_flux(
text_ids=None, text_ids=None,
image_ids=None, image_ids=None,
controlnet_frames=None, controlnet_frames=None,
hidden_states_ref=None,
tiled=False, tiled=False,
tile_size=128, tile_size=128,
tile_stride=64, tile_stride=64,
entity_prompt_emb=None, entity_prompt_emb=None,
entity_masks=None, entity_masks=None,
ipadapter_kwargs_list={}, ipadapter_kwargs_list={},
id_emb=None,
infinityou_guidance=None,
tea_cache: TeaCache = None, tea_cache: TeaCache = None,
use_gradient_checkpointing=False,
**kwargs **kwargs
): ):
if tiled: if tiled:
@@ -671,9 +573,6 @@ def lets_dance_flux(
"tile_size": tile_size, "tile_size": tile_size,
"tile_stride": tile_stride, "tile_stride": tile_stride,
} }
if id_emb is not None:
controlnet_text_ids = torch.zeros(id_emb.shape[0], id_emb.shape[1], 3).to(device=hidden_states.device, dtype=hidden_states.dtype)
controlnet_extra_kwargs.update({"prompt_emb": id_emb, 'text_ids': controlnet_text_ids, 'guidance': infinityou_guidance})
controlnet_res_stack, controlnet_single_res_stack = controlnet( controlnet_res_stack, controlnet_single_res_stack = controlnet(
controlnet_frames, **controlnet_extra_kwargs controlnet_frames, **controlnet_extra_kwargs
) )
@@ -694,55 +593,28 @@ def lets_dance_flux(
prompt_emb, image_rotary_emb, attention_mask = dit.process_entity_masks(hidden_states, prompt_emb, entity_prompt_emb, entity_masks, text_ids, image_ids) prompt_emb, image_rotary_emb, attention_mask = dit.process_entity_masks(hidden_states, prompt_emb, entity_prompt_emb, entity_masks, text_ids, image_ids)
else: else:
prompt_emb = dit.context_embedder(prompt_emb) prompt_emb = dit.context_embedder(prompt_emb)
image_rotary_emb = dit.pos_embedder(torch.cat((text_ids, image_ids), dim=1), device=hidden_states.device) image_rotary_emb = dit.pos_embedder(torch.cat((text_ids, image_ids), dim=1))
attention_mask = None attention_mask = None
# Reference images
if hidden_states_ref is not None:
# RoPE
image_ids_ref = dit.prepare_image_ids(hidden_states_ref)
idx = torch.arange(0, image_ids_ref.shape[0]).to(dtype=hidden_states.dtype, device=hidden_states.device) * 100
image_rotary_emb_ref = reference_embedder(image_ids_ref, idx, dtype=hidden_states.dtype, device=hidden_states.device)
image_rotary_emb = torch.cat((image_rotary_emb, image_rotary_emb_ref), dim=2)
# hidden_states
original_hidden_states_length = hidden_states.shape[1]
hidden_states_ref = dit.patchify(hidden_states_ref)
hidden_states_ref = dit.x_embedder(hidden_states_ref)
hidden_states_ref = rearrange(hidden_states_ref, "B L C -> 1 (B L) C")
hidden_states_ref = reference_embedder.proj(hidden_states_ref)
hidden_states = torch.cat((hidden_states, hidden_states_ref), dim=1)
# TeaCache # TeaCache
if tea_cache is not None: if tea_cache is not None:
tea_cache_update = tea_cache.check(dit, hidden_states, conditioning) tea_cache_update = tea_cache.check(dit, hidden_states, conditioning)
else: else:
tea_cache_update = False tea_cache_update = False
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs)
return custom_forward
if tea_cache_update: if tea_cache_update:
hidden_states = tea_cache.update(hidden_states) hidden_states = tea_cache.update(hidden_states)
else: else:
# Joint Blocks # Joint Blocks
for block_id, block in enumerate(dit.blocks): for block_id, block in enumerate(dit.blocks):
if use_gradient_checkpointing: hidden_states, prompt_emb = block(
hidden_states, prompt_emb = torch.utils.checkpoint.checkpoint( hidden_states,
create_custom_forward(block), prompt_emb,
hidden_states, prompt_emb, conditioning, image_rotary_emb, attention_mask, ipadapter_kwargs_list.get(block_id, None), conditioning,
use_reentrant=False, image_rotary_emb,
) attention_mask,
else: ipadapter_kwargs_list=ipadapter_kwargs_list.get(block_id, None)
hidden_states, prompt_emb = block( )
hidden_states,
prompt_emb,
conditioning,
image_rotary_emb,
attention_mask,
ipadapter_kwargs_list=ipadapter_kwargs_list.get(block_id, None)
)
# ControlNet # ControlNet
if controlnet is not None and controlnet_frames is not None: if controlnet is not None and controlnet_frames is not None:
hidden_states = hidden_states + controlnet_res_stack[block_id] hidden_states = hidden_states + controlnet_res_stack[block_id]
@@ -751,21 +623,14 @@ def lets_dance_flux(
hidden_states = torch.cat([prompt_emb, hidden_states], dim=1) hidden_states = torch.cat([prompt_emb, hidden_states], dim=1)
num_joint_blocks = len(dit.blocks) num_joint_blocks = len(dit.blocks)
for block_id, block in enumerate(dit.single_blocks): for block_id, block in enumerate(dit.single_blocks):
if use_gradient_checkpointing: hidden_states, prompt_emb = block(
hidden_states, prompt_emb = torch.utils.checkpoint.checkpoint( hidden_states,
create_custom_forward(block), prompt_emb,
hidden_states, prompt_emb, conditioning, image_rotary_emb, attention_mask, ipadapter_kwargs_list.get(block_id + num_joint_blocks, None), conditioning,
use_reentrant=False, image_rotary_emb,
) attention_mask,
else: ipadapter_kwargs_list=ipadapter_kwargs_list.get(block_id + num_joint_blocks, None)
hidden_states, prompt_emb = block( )
hidden_states,
prompt_emb,
conditioning,
image_rotary_emb,
attention_mask,
ipadapter_kwargs_list=ipadapter_kwargs_list.get(block_id + num_joint_blocks, None)
)
# ControlNet # ControlNet
if controlnet is not None and controlnet_frames is not None: if controlnet is not None and controlnet_frames is not None:
hidden_states[:, prompt_emb.shape[1]:] = hidden_states[:, prompt_emb.shape[1]:] + controlnet_single_res_stack[block_id] hidden_states[:, prompt_emb.shape[1]:] = hidden_states[:, prompt_emb.shape[1]:] + controlnet_single_res_stack[block_id]
@@ -774,8 +639,6 @@ def lets_dance_flux(
if tea_cache is not None: if tea_cache is not None:
tea_cache.store(hidden_states) tea_cache.store(hidden_states)
if hidden_states_ref is not None:
hidden_states = hidden_states[:, :original_hidden_states_length]
hidden_states = dit.final_norm_out(hidden_states, conditioning) hidden_states = dit.final_norm_out(hidden_states, conditioning)
hidden_states = dit.final_proj_out(hidden_states) hidden_states = dit.final_proj_out(hidden_states)
hidden_states = dit.unpatchify(hidden_states, height, width) hidden_states = dit.unpatchify(hidden_states, height, width)
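
The guidance step in `FluxImagePipeline.__call__` above combines the two predictions as `noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)`. A minimal sketch of that arithmetic follows, using flat lists of floats as stand-ins for tensors. `apply_cfg` is a hypothetical name used only for illustration.

```python
def apply_cfg(noise_pred_posi, noise_pred_nega, cfg_scale):
    """Classifier-free guidance on flat lists of floats (stand-ins for tensors)."""
    if cfg_scale == 1.0:
        # At scale 1.0 the formula reduces to the positive prediction,
        # so the negative pass can be skipped entirely.
        return list(noise_pred_posi)
    # Start from the negative prediction and move toward (or past) the positive one.
    return [n + cfg_scale * (p - n) for p, n in zip(noise_pred_posi, noise_pred_nega)]


print(apply_cfg([1.0, 2.0], [0.0, 1.0], 3.0))  # [3.0, 4.0]
```

Scales above 1.0 extrapolate past the positive prediction, which is why the pipeline only runs the negative branch when `cfg_scale != 1.0`.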


@@ -1,4 +1,3 @@
import types
from ..models import ModelManager from ..models import ModelManager
from ..models.wan_video_dit import WanModel from ..models.wan_video_dit import WanModel
from ..models.wan_video_text_encoder import WanTextEncoder from ..models.wan_video_text_encoder import WanTextEncoder
@@ -18,6 +17,7 @@ from ..vram_management import enable_vram_management, AutoWrappedModule, AutoWra
from ..models.wan_video_text_encoder import T5RelativeEmbedding, T5LayerNorm from ..models.wan_video_text_encoder import T5RelativeEmbedding, T5LayerNorm
from ..models.wan_video_dit import RMSNorm, sinusoidal_embedding_1d from ..models.wan_video_dit import RMSNorm, sinusoidal_embedding_1d
from ..models.wan_video_vae import RMS_norm, CausalConv3d, Upsample from ..models.wan_video_vae import RMS_norm, CausalConv3d, Upsample
from ..models.wan_video_controlnet import WanControlNetModel
from ..models.wan_video_motion_controller import WanMotionControllerModel from ..models.wan_video_motion_controller import WanMotionControllerModel
@@ -32,11 +32,11 @@ class WanVideoPipeline(BasePipeline):
self.image_encoder: WanImageEncoder = None self.image_encoder: WanImageEncoder = None
self.dit: WanModel = None self.dit: WanModel = None
self.vae: WanVideoVAE = None self.vae: WanVideoVAE = None
self.controlnet: WanControlNetModel = None
self.motion_controller: WanMotionControllerModel = None self.motion_controller: WanMotionControllerModel = None
self.model_names = ['text_encoder', 'dit', 'vae', 'image_encoder', 'motion_controller'] self.model_names = ['text_encoder', 'dit', 'vae', 'image_encoder', 'controlnet', 'motion_controller']
self.height_division_factor = 16 self.height_division_factor = 16
self.width_division_factor = 16 self.width_division_factor = 16
self.use_unified_sequence_parallel = False
def enable_vram_management(self, num_persistent_param_in_dit=None): def enable_vram_management(self, num_persistent_param_in_dit=None):
@@ -124,22 +124,6 @@ class WanVideoPipeline(BasePipeline):
computation_device=self.device, computation_device=self.device,
), ),
) )
if self.motion_controller is not None:
dtype = next(iter(self.motion_controller.parameters())).dtype
enable_vram_management(
self.motion_controller,
module_map = {
torch.nn.Linear: AutoWrappedLinear,
},
module_config = dict(
offload_dtype=dtype,
offload_device="cpu",
onload_dtype=dtype,
onload_device="cpu",
computation_dtype=dtype,
computation_device=self.device,
),
)
self.enable_cpu_offload() self.enable_cpu_offload()
@@ -152,24 +136,14 @@ class WanVideoPipeline(BasePipeline):
self.dit = model_manager.fetch_model("wan_video_dit") self.dit = model_manager.fetch_model("wan_video_dit")
self.vae = model_manager.fetch_model("wan_video_vae") self.vae = model_manager.fetch_model("wan_video_vae")
self.image_encoder = model_manager.fetch_model("wan_video_image_encoder") self.image_encoder = model_manager.fetch_model("wan_video_image_encoder")
self.motion_controller = model_manager.fetch_model("wan_video_motion_controller")
@staticmethod @staticmethod
def from_model_manager(model_manager: ModelManager, torch_dtype=None, device=None, use_usp=False): def from_model_manager(model_manager: ModelManager, torch_dtype=None, device=None):
if device is None: device = model_manager.device if device is None: device = model_manager.device
if torch_dtype is None: torch_dtype = model_manager.torch_dtype if torch_dtype is None: torch_dtype = model_manager.torch_dtype
pipe = WanVideoPipeline(device=device, torch_dtype=torch_dtype) pipe = WanVideoPipeline(device=device, torch_dtype=torch_dtype)
pipe.fetch_models(model_manager) pipe.fetch_models(model_manager)
if use_usp:
from xfuser.core.distributed import get_sequence_parallel_world_size
from ..distributed.xdit_context_parallel import usp_attn_forward, usp_dit_forward
for block in pipe.dit.blocks:
block.self_attn.forward = types.MethodType(usp_attn_forward, block.self_attn)
pipe.dit.forward = types.MethodType(usp_dit_forward, pipe.dit)
pipe.sp_size = get_sequence_parallel_world_size()
pipe.use_unified_sequence_parallel = True
return pipe return pipe
@@ -178,26 +152,20 @@ class WanVideoPipeline(BasePipeline):
def encode_prompt(self, prompt, positive=True): def encode_prompt(self, prompt, positive=True):
prompt_emb = self.prompter.encode_prompt(prompt, positive=positive, device=self.device) prompt_emb = self.prompter.encode_prompt(prompt, positive=positive)
return {"context": prompt_emb} return {"context": prompt_emb}
def encode_image(self, image, end_image, num_frames, height, width): def encode_image(self, image, num_frames, height, width):
image = self.preprocess_image(image.resize((width, height))).to(self.device) image = self.preprocess_image(image.resize((width, height))).to(self.device)
clip_context = self.image_encoder.encode_image([image]) clip_context = self.image_encoder.encode_image([image])
msk = torch.ones(1, num_frames, height//8, width//8, device=self.device) msk = torch.ones(1, num_frames, height//8, width//8, device=self.device)
msk[:, 1:] = 0 msk[:, 1:] = 0
if end_image is not None:
end_image = self.preprocess_image(end_image.resize((width, height))).to(self.device)
vae_input = torch.concat([image.transpose(0,1), torch.zeros(3, num_frames-2, height, width).to(image.device), end_image.transpose(0,1)],dim=1)
msk[:, -1:] = 1
else:
vae_input = torch.concat([image.transpose(0, 1), torch.zeros(3, num_frames-1, height, width).to(image.device)], dim=1)
msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1) msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1)
msk = msk.view(1, msk.shape[1] // 4, 4, height//8, width//8) msk = msk.view(1, msk.shape[1] // 4, 4, height//8, width//8)
msk = msk.transpose(1, 2)[0] msk = msk.transpose(1, 2)[0]
vae_input = torch.concat([image.transpose(0, 1), torch.zeros(3, num_frames-1, height, width).to(image.device)], dim=1)
y = self.vae.encode([vae_input.to(dtype=self.torch_dtype, device=self.device)], device=self.device)[0] y = self.vae.encode([vae_input.to(dtype=self.torch_dtype, device=self.device)], device=self.device)[0]
y = torch.concat([msk, y]) y = torch.concat([msk, y])
y = y.unsqueeze(0) y = y.unsqueeze(0)
@@ -206,25 +174,6 @@ class WanVideoPipeline(BasePipeline):
return {"clip_feature": clip_context, "y": y} return {"clip_feature": clip_context, "y": y}
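
The mask handling in `encode_image` above is easier to follow along the time axis alone. The sketch below reproduces only the temporal bookkeeping with plain lists and omits the spatial dimensions. The first frame's flag is repeated four times, the flags are grouped in fours, and the groups are transposed into four mask channels per latent frame. `temporal_mask_channels` is a hypothetical helper, and the `has_end_image` switch corresponds to the variant of the method that takes `end_image`.

```python
def temporal_mask_channels(num_frames, has_end_image=False):
    """Return 4 mask channels, each with (num_frames + 3) // 4 entries."""
    if (num_frames - 1) % 4 != 0:
        # The reshape into groups of four only works for 4k + 1 frames.
        raise ValueError("num_frames must be of the form 4k + 1")

    # 1 marks a frame that is given as conditioning, 0 a frame to generate.
    flags = [1] + [0] * (num_frames - 1)
    if has_end_image:
        flags[-1] = 1

    # Repeat the first frame's flag 4 times, as repeat_interleave does.
    expanded = [flags[0]] * 4 + flags[1:]
    # Group consecutive flags in fours (the view), then transpose.
    groups = [expanded[i:i + 4] for i in range(0, len(expanded), 4)]
    return [list(channel) for channel in zip(*groups)]


print(temporal_mask_channels(5))                      # [[1, 0], [1, 0], [1, 0], [1, 0]]
print(temporal_mask_channels(5, has_end_image=True))  # [[1, 0], [1, 0], [1, 0], [1, 1]]
```

With an end image, only the last channel of the last latent frame is set, because the final input frame lands in the fourth slot of its group.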
def encode_control_video(self, control_video, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)):
control_video = self.preprocess_images(control_video)
control_video = torch.stack(control_video, dim=2).to(dtype=self.torch_dtype, device=self.device)
latents = self.encode_video(control_video, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=self.torch_dtype, device=self.device)
return latents
def prepare_controlnet_kwargs(self, control_video, num_frames, height, width, clip_feature=None, y=None, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)):
if control_video is not None:
control_latents = self.encode_control_video(control_video, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
if clip_feature is None or y is None:
clip_feature = torch.zeros((1, 257, 1280), dtype=self.torch_dtype, device=self.device)
y = torch.zeros((1, 16, (num_frames - 1) // 4 + 1, height//8, width//8), dtype=self.torch_dtype, device=self.device)
else:
y = y[:, -16:]
y = torch.concat([control_latents, y], dim=1)
return {"clip_feature": clip_feature, "y": y}
def tensor2video(self, frames): def tensor2video(self, frames):
frames = rearrange(frames, "C T H W -> T H W C") frames = rearrange(frames, "C T H W -> T H W C")
frames = ((frames.float() + 1) * 127.5).clip(0, 255).cpu().numpy().astype(np.uint8) frames = ((frames.float() + 1) * 127.5).clip(0, 255).cpu().numpy().astype(np.uint8)
@@ -246,8 +195,9 @@ class WanVideoPipeline(BasePipeline):
return frames return frames
def prepare_unified_sequence_parallel(self): def prepare_controlnet(self, controlnet_frames, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)):
return {"use_unified_sequence_parallel": self.use_unified_sequence_parallel} controlnet_conditioning = self.encode_video(controlnet_frames, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=self.torch_dtype, device=self.device)
return {"controlnet_conditioning": controlnet_conditioning}
def prepare_motion_bucket_id(self, motion_bucket_id): def prepare_motion_bucket_id(self, motion_bucket_id):
@@ -261,9 +211,7 @@ class WanVideoPipeline(BasePipeline):
prompt, prompt,
negative_prompt="", negative_prompt="",
input_image=None, input_image=None,
end_image=None,
input_video=None, input_video=None,
control_video=None,
denoising_strength=1.0, denoising_strength=1.0,
seed=None, seed=None,
rand_device="cpu", rand_device="cpu",
@@ -279,6 +227,7 @@ class WanVideoPipeline(BasePipeline):
tile_stride=(15, 26), tile_stride=(15, 26),
tea_cache_l1_thresh=None, tea_cache_l1_thresh=None,
tea_cache_model_id="", tea_cache_model_id="",
controlnet_frames=None,
progress_bar_cmd=tqdm, progress_bar_cmd=tqdm,
progress_bar_st=None, progress_bar_st=None,
): ):
@@ -315,14 +264,18 @@ class WanVideoPipeline(BasePipeline):
# Encode image # Encode image
if input_image is not None and self.image_encoder is not None: if input_image is not None and self.image_encoder is not None:
self.load_models_to_device(["image_encoder", "vae"]) self.load_models_to_device(["image_encoder", "vae"])
image_emb = self.encode_image(input_image, end_image, num_frames, height, width) image_emb = self.encode_image(input_image, num_frames, height, width)
else: else:
image_emb = {} image_emb = {}
# ControlNet # ControlNet
if control_video is not None: if self.controlnet is not None and controlnet_frames is not None:
self.load_models_to_device(["image_encoder", "vae"]) self.load_models_to_device(['vae', 'controlnet'])
image_emb = self.prepare_controlnet_kwargs(control_video, num_frames, height, width, **image_emb, **tiler_kwargs) controlnet_frames = self.preprocess_images(controlnet_frames)
controlnet_frames = torch.stack(controlnet_frames, dim=2).to(dtype=self.torch_dtype, device=self.device)
controlnet_kwargs = self.prepare_controlnet(controlnet_frames)
else:
controlnet_kwargs = {}
# Motion Controller # Motion Controller
if self.motion_controller is not None and motion_bucket_id is not None: if self.motion_controller is not None and motion_bucket_id is not None:
@@ -337,27 +290,24 @@ class WanVideoPipeline(BasePipeline):
tea_cache_posi = {"tea_cache": TeaCache(num_inference_steps, rel_l1_thresh=tea_cache_l1_thresh, model_id=tea_cache_model_id) if tea_cache_l1_thresh is not None else None} tea_cache_posi = {"tea_cache": TeaCache(num_inference_steps, rel_l1_thresh=tea_cache_l1_thresh, model_id=tea_cache_model_id) if tea_cache_l1_thresh is not None else None}
tea_cache_nega = {"tea_cache": TeaCache(num_inference_steps, rel_l1_thresh=tea_cache_l1_thresh, model_id=tea_cache_model_id) if tea_cache_l1_thresh is not None else None} tea_cache_nega = {"tea_cache": TeaCache(num_inference_steps, rel_l1_thresh=tea_cache_l1_thresh, model_id=tea_cache_model_id) if tea_cache_l1_thresh is not None else None}
# Unified Sequence Parallel
usp_kwargs = self.prepare_unified_sequence_parallel()
# Denoise # Denoise
self.load_models_to_device(["dit", "motion_controller"]) self.load_models_to_device(["dit", "controlnet", "motion_controller"])
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)): for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device) timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device)
# Inference # Inference
noise_pred_posi = model_fn_wan_video( noise_pred_posi = model_fn_wan_video(
self.dit, motion_controller=self.motion_controller, self.dit, controlnet=self.controlnet, motion_controller=self.motion_controller,
x=latents, timestep=timestep, x=latents, timestep=timestep,
**prompt_emb_posi, **image_emb, **extra_input, **prompt_emb_posi, **image_emb, **extra_input,
**tea_cache_posi, **usp_kwargs, **motion_kwargs **tea_cache_posi, **controlnet_kwargs, **motion_kwargs,
) )
if cfg_scale != 1.0: if cfg_scale != 1.0:
noise_pred_nega = model_fn_wan_video( noise_pred_nega = model_fn_wan_video(
self.dit, motion_controller=self.motion_controller, self.dit, controlnet=self.controlnet, motion_controller=self.motion_controller,
x=latents, timestep=timestep, x=latents, timestep=timestep,
**prompt_emb_nega, **image_emb, **extra_input, **prompt_emb_nega, **image_emb, **extra_input,
**tea_cache_nega, **usp_kwargs, **motion_kwargs **tea_cache_nega, **controlnet_kwargs, **motion_kwargs,
) )
noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega) noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
else: else:
@@ -431,6 +381,7 @@ class TeaCache:
def model_fn_wan_video( def model_fn_wan_video(
dit: WanModel, dit: WanModel,
controlnet: WanControlNetModel = None,
motion_controller: WanMotionControllerModel = None, motion_controller: WanMotionControllerModel = None,
x: torch.Tensor = None, x: torch.Tensor = None,
timestep: torch.Tensor = None, timestep: torch.Tensor = None,
@@ -438,15 +389,22 @@ def model_fn_wan_video(
clip_feature: Optional[torch.Tensor] = None, clip_feature: Optional[torch.Tensor] = None,
y: Optional[torch.Tensor] = None, y: Optional[torch.Tensor] = None,
tea_cache: TeaCache = None, tea_cache: TeaCache = None,
use_unified_sequence_parallel: bool = False, controlnet_conditioning: Optional[torch.Tensor] = None,
motion_bucket_id: Optional[torch.Tensor] = None, motion_bucket_id: Optional[torch.Tensor] = None,
use_gradient_checkpointing: bool = False,
use_gradient_checkpointing_offload: bool = False,
**kwargs, **kwargs,
): ):
if use_unified_sequence_parallel: # ControlNet
import torch.distributed as dist if controlnet is not None and controlnet_conditioning is not None:
from xfuser.core.distributed import (get_sequence_parallel_rank, controlnet_res_stack = controlnet(
get_sequence_parallel_world_size, x, timestep=timestep, context=context, clip_feature=clip_feature, y=y,
get_sp_group) controlnet_conditioning=controlnet_conditioning,
use_gradient_checkpointing=use_gradient_checkpointing,
use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
)
else:
controlnet_res_stack = None
t = dit.time_embedding(sinusoidal_embedding_1d(dit.freq_dim, timestep)) t = dit.time_embedding(sinusoidal_embedding_1d(dit.freq_dim, timestep))
t_mod = dit.time_projection(t).unflatten(1, (6, dit.dim)) t_mod = dit.time_projection(t).unflatten(1, (6, dit.dim))
@@ -473,21 +431,37 @@ def model_fn_wan_video(
else: else:
tea_cache_update = False tea_cache_update = False
# blocks def create_custom_forward(module):
if use_unified_sequence_parallel: def custom_forward(*inputs):
if dist.is_initialized() and dist.get_world_size() > 1: return module(*inputs)
x = torch.chunk(x, get_sequence_parallel_world_size(), dim=1)[get_sequence_parallel_rank()] return custom_forward
if tea_cache_update: if tea_cache_update:
x = tea_cache.update(x) x = tea_cache.update(x)
else: else:
for block in dit.blocks: # blocks
x = block(x, context, t_mod, freqs) for block_id, block in enumerate(dit.blocks):
if dit.training and use_gradient_checkpointing:
if use_gradient_checkpointing_offload:
with torch.autograd.graph.save_on_cpu():
x = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
x, context, t_mod, freqs,
use_reentrant=False,
)
else:
x = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
x, context, t_mod, freqs,
use_reentrant=False,
)
else:
x = block(x, context, t_mod, freqs)
if controlnet_res_stack is not None:
x = x + controlnet_res_stack[block_id]
if tea_cache is not None: if tea_cache is not None:
tea_cache.store(x) tea_cache.store(x)
x = dit.head(x, t) x = dit.head(x, t)
if use_unified_sequence_parallel:
if dist.is_initialized() and dist.get_world_size() > 1:
x = get_sp_group().all_gather(x, dim=1)
x = dit.unpatchify(x, (f, h, w)) x = dit.unpatchify(x, (f, h, w))
return x return x
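The two pieces of arithmetic this hunk relies on can be sketched in isolation. This is a minimal illustration, not the pipeline code: plain floats stand in for tensors, and `run_blocks` and `combine_cfg` are hypothetical names introduced here. It assumes, as the diff shows, that the ControlNet yields one residual per DiT block and that the residual is added after the block runs.

```python
def run_blocks(x, blocks, res_stack=None):
    """Apply each block in order, adding the matching residual if one is given."""
    for block_id, block in enumerate(blocks):
        x = block(x)
        if res_stack is not None:
            # Mirrors `x = x + controlnet_res_stack[block_id]` in the diff.
            x = x + res_stack[block_id]
    return x


def combine_cfg(noise_pred_posi, noise_pred_nega, cfg_scale):
    """Classifier-free guidance: extrapolate away from the negative prediction."""
    return noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)


# Toy stand-ins for DiT blocks (assumptions for illustration only).
blocks = [lambda v: v * 2.0, lambda v: v + 1.0]

plain = run_blocks(1.0, blocks)                          # no ControlNet residuals
guided = run_blocks(1.0, blocks, res_stack=[0.5, 0.25])  # one residual per block
```

With `cfg_scale == 1.0` the combination reduces to the positive prediction, which is presumably why the pipeline only runs the negative pass when `cfg_scale != 1.0`.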


@@ -1,7 +0,0 @@
# InfiniteYou: Flexible Photo Recrafting While Preserving Your Identity
We support the identity-preserving feature of InfiniteYou. See [./infiniteyou.py](./infiniteyou.py) for an example. Sample results are shown below.
|Identity Image|Generated Image|
|-|-|
|![man_id](https://github.com/user-attachments/assets/bbc38a91-966e-49e8-a0d7-c5467582ad1f)|![man](https://github.com/user-attachments/assets/0decd5e1-5f65-437c-98fa-90991b6f23c1)|
|![woman_id](https://github.com/user-attachments/assets/b2894695-690e-465b-929c-61e5dc57feeb)|![woman](https://github.com/user-attachments/assets/67cc7496-c4d3-4de1-a8f1-9eb4991d95e8)|


@@ -1,58 +0,0 @@
import importlib
import torch
from diffsynth import ModelManager, FluxImagePipeline, download_models, ControlNetConfigUnit
from modelscope import dataset_snapshot_download
from PIL import Image
if importlib.util.find_spec("facexlib") is None:
raise ImportError("You are using InfiniteYou. It depends on facexlib, which is not installed. Please install it with `pip install facexlib`.")
if importlib.util.find_spec("insightface") is None:
raise ImportError("You are using InfiniteYou. It depends on insightface, which is not installed. Please install it with `pip install insightface`.")
download_models(["InfiniteYou"])
model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cuda", model_id_list=["FLUX.1-dev"])
model_manager.load_models([
[
"models/InfiniteYou/InfuseNetModel/diffusion_pytorch_model-00001-of-00002.safetensors",
"models/InfiniteYou/InfuseNetModel/diffusion_pytorch_model-00002-of-00002.safetensors"
],
"models/InfiniteYou/image_proj_model.bin",
])
pipe = FluxImagePipeline.from_model_manager(
model_manager,
controlnet_config_units=[
ControlNetConfigUnit(
processor_id="none",
model_path=[
'models/InfiniteYou/InfuseNetModel/diffusion_pytorch_model-00001-of-00002.safetensors',
'models/InfiniteYou/InfuseNetModel/diffusion_pytorch_model-00002-of-00002.safetensors'
],
scale=1.0
)
]
)
dataset_snapshot_download(dataset_id="DiffSynth-Studio/examples_in_diffsynth", local_dir="./", allow_file_pattern=f"data/examples/infiniteyou/*")
prompt = "A man, portrait, cinematic"
id_image = "data/examples/infiniteyou/man.jpg"
id_image = Image.open(id_image).convert('RGB')
image = pipe(
prompt=prompt, seed=1,
infinityou_id_image=id_image, infinityou_guidance=1.0,
num_inference_steps=50, embedded_guidance=3.5,
height=1024, width=1024,
)
image.save("man.jpg")
prompt = "A woman, portrait, cinematic"
id_image = "data/examples/infiniteyou/woman.jpg"
id_image = Image.open(id_image).convert('RGB')
image = pipe(
prompt=prompt, seed=1,
infinityou_id_image=id_image, infinityou_guidance=1.0,
num_inference_steps=50, embedded_guidance=3.5,
height=1024, width=1024,
)
image.save("woman.jpg")


@@ -10,52 +10,34 @@ cd DiffSynth-Studio
pip install -e . pip install -e .
``` ```
## Model Zoo Wan-Video supports multiple Attention implementations. If you have installed any of the following Attention implementations, they will be enabled based on priority.
|Developer|Name|Link|Scripts| * [Flash Attention 3](https://github.com/Dao-AILab/flash-attention)
|-|-|-|-| * [Flash Attention 2](https://github.com/Dao-AILab/flash-attention)
|Wan Team|1.3B text-to-video|[Link](https://modelscope.cn/models/Wan-AI/Wan2.1-T2V-1.3B)|[wan_1.3b_text_to_video.py](./wan_1.3b_text_to_video.py)| * [Sage Attention](https://github.com/thu-ml/SageAttention)
|Wan Team|14B text-to-video|[Link](https://modelscope.cn/models/Wan-AI/Wan2.1-T2V-14B)|[wan_14b_text_to_video.py](./wan_14b_text_to_video.py)| * [torch SDPA](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html) (default. `torch>=2.5.0` is recommended.)
|Wan Team|14B image-to-video 480P|[Link](https://modelscope.cn/models/Wan-AI/Wan2.1-I2V-14B-480P)|[wan_14b_image_to_video.py](./wan_14b_image_to_video.py)|
|Wan Team|14B image-to-video 720P|[Link](https://modelscope.cn/models/Wan-AI/Wan2.1-I2V-14B-720P)|[wan_14b_image_to_video.py](./wan_14b_image_to_video.py)|
|DiffSynth-Studio Team|1.3B aesthetics LoRA|[Link](https://modelscope.cn/models/DiffSynth-Studio/Wan2.1-1.3b-lora-aesthetics-v1)|Please see the [model card](https://modelscope.cn/models/DiffSynth-Studio/Wan2.1-1.3b-lora-aesthetics-v1).|
|DiffSynth-Studio Team|1.3B Highres-fix LoRA|[Link](https://modelscope.cn/models/DiffSynth-Studio/Wan2.1-1.3b-lora-highresfix-v1)|Please see the [model card](https://modelscope.cn/models/DiffSynth-Studio/Wan2.1-1.3b-lora-highresfix-v1).|
|DiffSynth-Studio Team|1.3B ExVideo LoRA|[Link](https://modelscope.cn/models/DiffSynth-Studio/Wan2.1-1.3b-lora-exvideo-v1)|Please see the [model card](https://modelscope.cn/models/DiffSynth-Studio/Wan2.1-1.3b-lora-exvideo-v1).|
|DiffSynth-Studio Team|1.3B Speed Control adapter|[Link](https://modelscope.cn/models/DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1)|[wan_1.3b_motion_controller.py](./wan_1.3b_motion_controller.py)|
|PAI Team|1.3B InP|[Link](https://modelscope.cn/models/PAI/Wan2.1-Fun-1.3B-InP)|[wan_fun_InP.py](./wan_fun_InP.py)|
|PAI Team|14B InP|[Link](https://modelscope.cn/models/PAI/Wan2.1-Fun-14B-InP)|[wan_fun_InP.py](./wan_fun_InP.py)|
|PAI Team|1.3B Control|[Link](https://modelscope.cn/models/PAI/Wan2.1-Fun-1.3B-Control)|[wan_fun_control.py](./wan_fun_control.py)|
|PAI Team|14B Control|[Link](https://modelscope.cn/models/PAI/Wan2.1-Fun-14B-Control)|[wan_fun_control.py](./wan_fun_control.py)|
Base model features ## Inference
||Text-to-video|Image-to-video|End frame|Control| ### Wan-Video-1.3B-T2V
|-|-|-|-|-|
|1.3B text-to-video|✅||||
|14B text-to-video|✅||||
|14B image-to-video 480P||✅|||
|14B image-to-video 720P||✅|||
|1.3B InP||✅|✅||
|14B InP||✅|✅||
|1.3B Control||||✅|
|14B Control||||✅|
Adapter model compatibility Wan-Video-1.3B-T2V supports text-to-video and video-to-video. See [`./wan_1.3b_text_to_video.py`](./wan_1.3b_text_to_video.py).
||1.3B text-to-video|1.3B InP| Required VRAM: 6G
|-|-|-|
|1.3B aesthetics LoRA|✅||
|1.3B Highres-fix LoRA|✅||
|1.3B ExVideo LoRA|✅||
|1.3B Speed Control adapter|✅|✅|
## VRAM Usage https://github.com/user-attachments/assets/124397be-cd6a-4f29-a87c-e4c695aaabb8
* Fine-grained offload: We recommend that users adjust the `num_persistent_param_in_dit` settings to find an optimal balance between speed and VRAM requirements. See [`./wan_14b_text_to_video.py`](./wan_14b_text_to_video.py). Put sunglasses on the dog.
* FP8 Quantization: You only need to adjust the `torch_dtype` in the `ModelManager` (not the pipeline!). https://github.com/user-attachments/assets/272808d7-fbeb-4747-a6df-14a0860c75fb
We present a detailed table here. The model (14B text-to-video) is tested on a single A100. [TeaCache](https://github.com/ali-vilab/TeaCache) is supported in both T2V and I2V models. It can significantly improve the efficiency. See [`./wan_1.3b_text_to_video_accelerate.py`](./wan_1.3b_text_to_video_accelerate.py).
### Wan-Video-14B-T2V
Wan-Video-14B-T2V is an enhanced version of Wan-Video-1.3B-T2V, offering greater size and power. To utilize this model, you need additional VRAM. We recommend that users adjust the `torch_dtype` and `num_persistent_param_in_dit` settings to find an optimal balance between speed and VRAM requirements. See [`./wan_14b_text_to_video.py`](./wan_14b_text_to_video.py).
We present a detailed table here. The model is tested on a single A100.
|`torch_dtype`|`num_persistent_param_in_dit`|Speed|Required VRAM|Default Setting| |`torch_dtype`|`num_persistent_param_in_dit`|Speed|Required VRAM|Default Setting|
|-|-|-|-|-| |-|-|-|-|-|
@@ -65,46 +47,17 @@ We present a detailed table here. The model (14B text-to-video) is tested on a s
|torch.float8_e4m3fn|None (unlimited)|18.3s/it|24G|yes| |torch.float8_e4m3fn|None (unlimited)|18.3s/it|24G|yes|
|torch.float8_e4m3fn|0|24.0s/it|10G|| |torch.float8_e4m3fn|0|24.0s/it|10G||
**We found that the 14B image-to-video model is more sensitive to precision. If the generated video shows issues such as artifacts, switch to bfloat16 precision and use the `num_persistent_param_in_dit` parameter to control VRAM usage.**
## Efficient Attention Implementation
DiffSynth-Studio supports multiple Attention implementations. If any of the following implementations are installed, they are enabled in priority order. However, we recommend using the default torch SDPA.
* [Flash Attention 3](https://github.com/Dao-AILab/flash-attention)
* [Flash Attention 2](https://github.com/Dao-AILab/flash-attention)
* [Sage Attention](https://github.com/thu-ml/SageAttention)
* [torch SDPA](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html) (default. `torch>=2.5.0` is recommended.)
## Acceleration
We support multiple acceleration solutions:
* [TeaCache](https://github.com/ali-vilab/TeaCache): See [wan_1.3b_text_to_video_accelerate.py](./wan_1.3b_text_to_video_accelerate.py).
* [Unified Sequence Parallel](https://github.com/xdit-project/xDiT): See [wan_14b_text_to_video_usp.py](./wan_14b_text_to_video_usp.py)
```bash
pip install xfuser>=0.4.3
torchrun --standalone --nproc_per_node=8 examples/wanvideo/wan_14b_text_to_video_usp.py
```
* Tensor Parallel: See [wan_14b_text_to_video_tensor_parallel.py](./wan_14b_text_to_video_tensor_parallel.py).
## Gallery
1.3B text-to-video.
https://github.com/user-attachments/assets/124397be-cd6a-4f29-a87c-e4c695aaabb8
Put sunglasses on the dog.
https://github.com/user-attachments/assets/272808d7-fbeb-4747-a6df-14a0860c75fb
14B text-to-video.
https://github.com/user-attachments/assets/3908bc64-d451-485a-8b61-28f6d32dd92f https://github.com/user-attachments/assets/3908bc64-d451-485a-8b61-28f6d32dd92f
14B image-to-video. Tensor parallel module of Wan-Video-14B-T2V is still under development. An example script is provided in [`./wan_14b_text_to_video_tensor_parallel.py`](./wan_14b_text_to_video_tensor_parallel.py).
### Wan-Video-14B-I2V
Wan-Video-14B-I2V adds the functionality of image-to-video based on Wan-Video-14B-T2V. The model size remains the same, therefore the speed and VRAM requirements are also consistent. See [`./wan_14b_image_to_video.py`](./wan_14b_image_to_video.py).
**In the sample code, we use the same settings as the T2V 14B model, with FP8 quantization enabled by default. However, we found that this model is more sensitive to precision, so when the generated video content experiences issues such as artifacts, please switch to bfloat16 precision and use the `num_persistent_param_in_dit` parameter to control VRAM usage.**
![Image](https://github.com/user-attachments/assets/adf8047f-7943-4aaa-a555-2b32dc415f39)
https://github.com/user-attachments/assets/c0bdd5ca-292f-45ed-b9bc-afe193156e75 https://github.com/user-attachments/assets/c0bdd5ca-292f-45ed-b9bc-afe193156e75


@@ -12,9 +12,12 @@ import numpy as np
class TextVideoDataset(torch.utils.data.Dataset): class TextVideoDataset(torch.utils.data.Dataset):
def __init__(self, base_path, metadata_path, max_num_frames=81, frame_interval=1, num_frames=81, height=480, width=832, is_i2v=False): def __init__(self, base_path, metadata_path, max_num_frames=81, frame_interval=1, num_frames=81, height=480, width=832, is_i2v=False, target_fps=None):
metadata = pd.read_csv(metadata_path) metadata = pd.read_csv(metadata_path)
self.path = [os.path.join(base_path, "train", file_name) for file_name in metadata["file_name"]] if os.path.exists(os.path.join(base_path, "train")):
self.path = [os.path.join(base_path, "train", file_name) for file_name in metadata["file_name"]]
else:
self.path = [os.path.join(base_path, file_name) for file_name in metadata["file_name"]]
self.text = metadata["text"].to_list() self.text = metadata["text"].to_list()
self.max_num_frames = max_num_frames self.max_num_frames = max_num_frames
@@ -23,6 +26,7 @@ class TextVideoDataset(torch.utils.data.Dataset):
self.height = height self.height = height
self.width = width self.width = width
self.is_i2v = is_i2v self.is_i2v = is_i2v
self.target_fps = target_fps
self.frame_process = v2.Compose([ self.frame_process = v2.Compose([
v2.CenterCrop(size=(height, width)), v2.CenterCrop(size=(height, width)),
@@ -71,8 +75,15 @@ class TextVideoDataset(torch.utils.data.Dataset):
def load_video(self, file_path): def load_video(self, file_path):
start_frame_id = torch.randint(0, self.max_num_frames - (self.num_frames - 1) * self.frame_interval, (1,))[0] start_frame_id = 0
frames = self.load_frames_using_imageio(file_path, self.max_num_frames, start_frame_id, self.frame_interval, self.num_frames, self.frame_process) if self.target_fps is None:
frame_interval = self.frame_interval
else:
reader = imageio.get_reader(file_path)
fps = reader.get_meta_data()["fps"]
reader.close()
frame_interval = max(round(fps / self.target_fps), 1)
frames = self.load_frames_using_imageio(file_path, self.max_num_frames, start_frame_id, frame_interval, self.num_frames, self.frame_process)
return frames return frames
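The `target_fps` logic added in this hunk reduces to a small calculation, sketched below without `imageio` so it can be run on its own. The function names are hypothetical; the formulas are taken from `load_video` and from the length check in `load_frames_using_imageio`.

```python
def frame_interval_for(source_fps, target_fps, default_interval=1):
    """Sampling stride that approximates target_fps from a source_fps video."""
    if target_fps is None:
        return default_interval
    # Same expression as the diff; the stride never drops below 1.
    return max(round(source_fps / target_fps), 1)


def has_enough_frames(total_frames, max_num_frames, start_frame_id, num_frames, interval):
    """True if the video is long enough for the requested sampling pattern."""
    last_index = start_frame_id + (num_frames - 1) * interval
    return total_frames >= max_num_frames and total_frames - 1 >= last_index


interval = frame_interval_for(30, 15)  # keep every 2nd frame of a 30 fps video
long_enough = has_enough_frames(161, 81, 0, 81, interval)
```

Two consequences are worth keeping in mind. A source video slower than `target_fps` is not upsampled, since the stride is clamped to 1. And sampling 81 frames at stride 2 touches frame index 160, so clips shorter than 161 frames are rejected and come back as `None`.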
@@ -95,17 +106,20 @@ class TextVideoDataset(torch.utils.data.Dataset):
def __getitem__(self, data_id): def __getitem__(self, data_id):
text = self.text[data_id] text = self.text[data_id]
path = self.path[data_id] path = self.path[data_id]
if self.is_image(path): try:
if self.is_image(path):
if self.is_i2v:
raise ValueError(f"{path} is not a video. I2V model doesn't support image-to-image training.")
video = self.load_image(path)
else:
video = self.load_video(path)
if self.is_i2v: if self.is_i2v:
raise ValueError(f"{path} is not a video. I2V model doesn't support image-to-image training.") video, first_frame = video
video = self.load_image(path) data = {"text": text, "video": video, "path": path, "first_frame": first_frame}
else: else:
video = self.load_video(path) data = {"text": text, "video": video, "path": path}
if self.is_i2v: except:
video, first_frame = video data = None
data = {"text": text, "video": video, "path": path, "first_frame": first_frame}
else:
data = {"text": text, "video": video, "path": path}
return data return data
@@ -115,7 +129,7 @@ class TextVideoDataset(torch.utils.data.Dataset):
class LightningModelForDataProcess(pl.LightningModule): class LightningModelForDataProcess(pl.LightningModule):
def __init__(self, text_encoder_path, vae_path, image_encoder_path=None, tiled=False, tile_size=(34, 34), tile_stride=(18, 16)): def __init__(self, text_encoder_path, vae_path, image_encoder_path=None, tiled=False, tile_size=(34, 34), tile_stride=(18, 16), redirected_tensor_path=None):
super().__init__() super().__init__()
model_path = [text_encoder_path, vae_path] model_path = [text_encoder_path, vae_path]
if image_encoder_path is not None: if image_encoder_path is not None:
@@ -125,9 +139,13 @@ class LightningModelForDataProcess(pl.LightningModule):
self.pipe = WanVideoPipeline.from_model_manager(model_manager) self.pipe = WanVideoPipeline.from_model_manager(model_manager)
self.tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride} self.tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride}
self.redirected_tensor_path = redirected_tensor_path
def test_step(self, batch, batch_idx): def test_step(self, batch, batch_idx):
text, video, path = batch["text"][0], batch["video"], batch["path"][0] data = batch[0]
if data is None or data["video"] is None:
return
text, video, path = data["text"], data["video"].unsqueeze(0), data["path"]
self.pipe.device = self.device self.pipe.device = self.device
if video is not None: if video is not None:
@@ -144,28 +162,49 @@ class LightningModelForDataProcess(pl.LightningModule):
else: else:
image_emb = {} image_emb = {}
data = {"latents": latents, "prompt_emb": prompt_emb, "image_emb": image_emb} data = {"latents": latents, "prompt_emb": prompt_emb, "image_emb": image_emb}
if self.redirected_tensor_path is not None:
path = path.replace("/", "_").replace("\\", "_")
path = os.path.join(self.redirected_tensor_path, path)
torch.save(data, path + ".tensors.pth") torch.save(data, path + ".tensors.pth")
class TensorDataset(torch.utils.data.Dataset): class TensorDataset(torch.utils.data.Dataset):
def __init__(self, base_path, metadata_path, steps_per_epoch): def __init__(self, base_path, metadata_path=None, steps_per_epoch=1000, redirected_tensor_path=None):
metadata = pd.read_csv(metadata_path) if os.path.exists(metadata_path):
self.path = [os.path.join(base_path, "train", file_name) for file_name in metadata["file_name"]] metadata = pd.read_csv(metadata_path)
print(len(self.path), "videos in metadata.") self.path = [os.path.join(base_path, "train", file_name) for file_name in metadata["file_name"]]
self.path = [i + ".tensors.pth" for i in self.path if os.path.exists(i + ".tensors.pth")] print(len(self.path), "videos in metadata.")
if redirected_tensor_path is None:
self.path = [i + ".tensors.pth" for i in self.path if os.path.exists(i + ".tensors.pth")]
else:
cached_path = []
for path in self.path:
path = path.replace("/", "_").replace("\\", "_")
path = os.path.join(redirected_tensor_path, path)
if os.path.exists(path + ".tensors.pth"):
cached_path.append(path + ".tensors.pth")
self.path = cached_path
else:
print("Cannot find metadata.csv. Trying to search for tensor files.")
self.path = [os.path.join(base_path, i) for i in os.listdir(base_path) if i.endswith(".tensors.pth")]
print(len(self.path), "tensors cached in metadata.") print(len(self.path), "tensors cached in metadata.")
assert len(self.path) > 0 assert len(self.path) > 0
self.steps_per_epoch = steps_per_epoch self.steps_per_epoch = steps_per_epoch
self.redirected_tensor_path = redirected_tensor_path
def __getitem__(self, index): def __getitem__(self, index):
data_id = torch.randint(0, len(self.path), (1,))[0] while True:
data_id = (data_id + index) % len(self.path) # For fixed seed. try:
path = self.path[data_id] data_id = torch.randint(0, len(self.path), (1,))[0]
data = torch.load(path, weights_only=True, map_location="cpu") data_id = (data_id + index) % len(self.path) # For fixed seed.
return data path = self.path[data_id]
data = torch.load(path, weights_only=True, map_location="cpu")
return data
except:
continue
def __len__(self): def __len__(self):
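The redirected cache naming used by both `test_step` and `TensorDataset` can be sketched as a single helper. `redirected_tensor_file` is a hypothetical name introduced here; the string operations are the ones shown in the diff.

```python
import os


def redirected_tensor_file(video_path, redirected_tensor_path):
    """Map a video path to its cached tensor file inside one flat directory."""
    # Path separators are replaced so the whole source path becomes one file name.
    flat_name = video_path.replace("/", "_").replace("\\", "_")
    return os.path.join(redirected_tensor_path, flat_name) + ".tensors.pth"


cache_file = redirected_tensor_file("data/train/a.mp4", "cache")

# Caveat: distinct paths can collapse to the same flat name.
first = redirected_tensor_file("a/b_c.mp4", "cache")
second = redirected_tensor_file("a_b/c.mp4", "cache")
```

Because both sides apply the same transformation to the same `base_path`-joined path, the lookup during training only succeeds when `--dataset_path` matches the value used during data processing. The collision shown by `first` and `second` is unlikely for typical dataset layouts but possible when directory names contain underscores.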
@@ -323,6 +362,18 @@ def parse_args():
default="./", default="./",
help="Path to save the model.", help="Path to save the model.",
) )
parser.add_argument(
"--metadata_path",
type=str,
default=None,
help="Path to metadata.csv.",
)
parser.add_argument(
"--redirected_tensor_path",
type=str,
default=None,
help="Path to save cached tensors.",
)
parser.add_argument( parser.add_argument(
"--text_encoder_path", "--text_encoder_path",
type=str, type=str,
@@ -389,6 +440,12 @@ def parse_args():
default=81, default=81,
help="Number of frames.", help="Number of frames.",
) )
parser.add_argument(
"--target_fps",
type=int,
default=None,
help="Expected FPS for sampling frames.",
)
parser.add_argument( parser.add_argument(
"--height", "--height",
type=int, type=int,
@@ -500,19 +557,21 @@ def parse_args():
def data_process(args): def data_process(args):
dataset = TextVideoDataset( dataset = TextVideoDataset(
args.dataset_path, args.dataset_path,
os.path.join(args.dataset_path, "metadata.csv"), os.path.join(args.dataset_path, "metadata.csv") if args.metadata_path is None else args.metadata_path,
max_num_frames=args.num_frames, max_num_frames=args.num_frames,
frame_interval=1, frame_interval=1,
num_frames=args.num_frames, num_frames=args.num_frames,
height=args.height, height=args.height,
width=args.width, width=args.width,
is_i2v=args.image_encoder_path is not None is_i2v=args.image_encoder_path is not None,
target_fps=args.target_fps,
) )
dataloader = torch.utils.data.DataLoader( dataloader = torch.utils.data.DataLoader(
dataset, dataset,
shuffle=False, shuffle=False,
batch_size=1, batch_size=1,
num_workers=args.dataloader_num_workers num_workers=args.dataloader_num_workers,
collate_fn=lambda x: x,
) )
model = LightningModelForDataProcess( model = LightningModelForDataProcess(
text_encoder_path=args.text_encoder_path, text_encoder_path=args.text_encoder_path,
@@ -521,6 +580,7 @@ def data_process(args):
tiled=args.tiled, tiled=args.tiled,
tile_size=(args.tile_size_height, args.tile_size_width), tile_size=(args.tile_size_height, args.tile_size_width),
tile_stride=(args.tile_stride_height, args.tile_stride_width), tile_stride=(args.tile_stride_height, args.tile_stride_width),
redirected_tensor_path=args.redirected_tensor_path,
) )
trainer = pl.Trainer( trainer = pl.Trainer(
accelerator="gpu", accelerator="gpu",
@@ -533,8 +593,9 @@ def data_process(args):
def train(args): def train(args):
dataset = TensorDataset( dataset = TensorDataset(
args.dataset_path, args.dataset_path,
os.path.join(args.dataset_path, "metadata.csv"), os.path.join(args.dataset_path, "metadata.csv") if args.metadata_path is None else args.metadata_path,
steps_per_epoch=args.steps_per_epoch, steps_per_epoch=args.steps_per_epoch,
redirected_tensor_path=args.redirected_tensor_path,
) )
dataloader = torch.utils.data.DataLoader( dataloader = torch.utils.data.DataLoader(
dataset, dataset,
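The "skip bad files" behaviour spread across this diff (a `try`/`except` in `__getitem__`, `collate_fn=lambda x: x`, and the `None` check in `test_step`) can be summarised in a torch-free sketch. All names below are hypothetical, and `fake_loader` stands in for real video decoding.

```python
def load_or_none(loader, path):
    """Return the loaded sample, or None if loading fails for any reason."""
    try:
        return loader(path)
    except Exception:
        # The diff uses a bare `except:`; `Exception` is used here so that
        # KeyboardInterrupt is not swallowed.
        return None


def identity_collate(samples):
    """Keep the batch as a plain list so that None entries survive collation."""
    return samples


def iterate_valid(paths, loader):
    """Yield only samples that loaded and actually contain video frames."""
    for path in paths:
        batch = identity_collate([load_or_none(loader, path)])
        data = batch[0]
        if data is None or data.get("video") is None:
            continue  # corrupt file, or a clip too short to sample
        yield data


def fake_loader(path):
    # Stand-in for video decoding (assumption for illustration only).
    if path == "bad.mp4":
        raise IOError("corrupt file")
    if path == "short.mp4":
        return {"video": None, "path": path}
    return {"video": [0.0], "path": path}


kept = [d["path"] for d in iterate_valid(["bad.mp4", "short.mp4", "ok.mp4"], fake_loader)]
```

The identity collate function matters because PyTorch's default collation cannot batch `None` values; passing the raw list through lets the processing step decide what to skip.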


@@ -0,0 +1,626 @@
import torch, os, imageio, argparse
from torchvision.transforms import v2
from einops import rearrange
import lightning as pl
import pandas as pd
from diffsynth import WanVideoPipeline, ModelManager, load_state_dict
from peft import LoraConfig, inject_adapter_in_model
import torchvision
from PIL import Image
import numpy as np
from diffsynth.models.wan_video_controlnet import WanControlNetModel
from diffsynth.pipelines.wan_video import model_fn_wan_video
class TextVideoDataset(torch.utils.data.Dataset):
def __init__(self, base_path, metadata_path, max_num_frames=81, frame_interval=1, num_frames=81, height=480, width=832, is_i2v=False, target_fps=None):
metadata = pd.read_csv(metadata_path)
self.path = [os.path.join(base_path, "train", file_name) for file_name in metadata["file_name"]]
self.controlnet_path = [os.path.join(base_path, file_name) for file_name in metadata["controlnet_file_name"]]
self.text = metadata["text"].to_list()
self.max_num_frames = max_num_frames
self.frame_interval = frame_interval
self.num_frames = num_frames
self.height = height
self.width = width
self.is_i2v = is_i2v
self.target_fps = target_fps
self.frame_process = v2.Compose([
v2.CenterCrop(size=(height, width)),
v2.Resize(size=(height, width), antialias=True),
v2.ToTensor(),
v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])
def crop_and_resize(self, image):
width, height = image.size
scale = max(self.width / width, self.height / height)
image = torchvision.transforms.functional.resize(
image,
(round(height*scale), round(width*scale)),
interpolation=torchvision.transforms.InterpolationMode.BILINEAR
)
return image
def load_frames_using_imageio(self, file_path, max_num_frames, start_frame_id, interval, num_frames, frame_process):
reader = imageio.get_reader(file_path)
if reader.count_frames() < max_num_frames or reader.count_frames() - 1 < start_frame_id + (num_frames - 1) * interval:
reader.close()
return None
frames = []
first_frame = None
for frame_id in range(num_frames):
frame = reader.get_data(start_frame_id + frame_id * interval)
frame = Image.fromarray(frame)
frame = self.crop_and_resize(frame)
if first_frame is None:
first_frame = np.array(frame)
frame = frame_process(frame)
frames.append(frame)
reader.close()
frames = torch.stack(frames, dim=0)
frames = rearrange(frames, "T C H W -> C T H W")
if self.is_i2v:
return frames, first_frame
else:
return frames
def load_video(self, file_path):
start_frame_id = 0
if self.target_fps is None:
frame_interval = self.frame_interval
else:
reader = imageio.get_reader(file_path)
fps = reader.get_meta_data()["fps"]
reader.close()
frame_interval = max(round(fps / self.target_fps), 1)
frames = self.load_frames_using_imageio(file_path, self.max_num_frames, start_frame_id, frame_interval, self.num_frames, self.frame_process)
return frames
def is_image(self, file_path):
file_ext_name = file_path.split(".")[-1]
if file_ext_name.lower() in ["jpg", "jpeg", "png", "webp"]:
return True
return False
def load_image(self, file_path):
frame = Image.open(file_path).convert("RGB")
frame = self.crop_and_resize(frame)
frame = self.frame_process(frame)
frame = rearrange(frame, "C H W -> C 1 H W")
return frame
def __getitem__(self, data_id):
text = self.text[data_id]
path = self.path[data_id]
controlnet_path = self.controlnet_path[data_id]
try:
if self.is_image(path):
if self.is_i2v:
raise ValueError(f"{path} is not a video. I2V model doesn't support image-to-image training.")
video = self.load_image(path)
else:
video = self.load_video(path)
controlnet_frames = self.load_video(controlnet_path)
if self.is_i2v:
video, first_frame = video
data = {"text": text, "video": video, "path": path, "first_frame": first_frame}
else:
data = {"text": text, "video": video, "path": path, "controlnet_frames": controlnet_frames}
except:
data = None
return data
def __len__(self):
return len(self.path)
class LightningModelForDataProcess(pl.LightningModule):
def __init__(self, text_encoder_path, vae_path, image_encoder_path=None, tiled=False, tile_size=(34, 34), tile_stride=(18, 16), redirected_tensor_path=None):
super().__init__()
model_path = [text_encoder_path, vae_path]
if image_encoder_path is not None:
model_path.append(image_encoder_path)
model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu")
model_manager.load_models(model_path)
self.pipe = WanVideoPipeline.from_model_manager(model_manager)
self.tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride}
self.redirected_tensor_path = redirected_tensor_path
def test_step(self, batch, batch_idx):
data = batch[0]
if data is None or data["video"] is None:
return
text, video, path = data["text"], data["video"].unsqueeze(0), data["path"]
controlnet_frames = data["controlnet_frames"].unsqueeze(0)
self.pipe.device = self.device
if video is not None:
# prompt
prompt_emb = self.pipe.encode_prompt(text)
# video
video = video.to(dtype=self.pipe.torch_dtype, device=self.pipe.device)
latents = self.pipe.encode_video(video, **self.tiler_kwargs)[0]
# ControlNet video
controlnet_frames = controlnet_frames.to(dtype=self.pipe.torch_dtype, device=self.pipe.device)
controlnet_kwargs = self.pipe.prepare_controlnet(controlnet_frames, **self.tiler_kwargs)
controlnet_kwargs["controlnet_conditioning"] = controlnet_kwargs["controlnet_conditioning"][0]
# image
if "first_frame" in data:
first_frame = Image.fromarray(data["first_frame"])
_, _, num_frames, height, width = video.shape
image_emb = self.pipe.encode_image(first_frame, num_frames, height, width)
else:
image_emb = {}
data = {"latents": latents, "prompt_emb": prompt_emb, "image_emb": image_emb, "controlnet_kwargs": controlnet_kwargs}
if self.redirected_tensor_path is not None:
path = path.replace("/", "_").replace("\\", "_")
path = os.path.join(self.redirected_tensor_path, path)
torch.save(data, path + ".tensors.pth")
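Reviewer note: when `redirected_tensor_path` is set, `test_step` above flattens the source path into a single filename before joining it with the redirect directory, and `TensorDataset` repeats the same mapping to find the cache. A minimal sketch of that mapping follows; `cache_path` is a hypothetical helper written for illustration, not a function in this change, and the paths are made up.

```python
import os

def cache_path(path, redirected_tensor_path=None):
    """Return the tensor cache filename for a source file (illustrative helper)."""
    if redirected_tensor_path is not None:
        # Flatten both POSIX and Windows separators so the whole source path
        # becomes one filename inside the redirect directory.
        path = path.replace("/", "_").replace("\\", "_")
        path = os.path.join(redirected_tensor_path, path)
    return path + ".tensors.pth"

# Without a redirect, the cache sits next to the source file.
print(cache_path("data/train/clip_001.mp4"))
# With a redirect, the directory structure is folded into the filename.
print(cache_path("data/train/clip_001.mp4", "cache"))
```

One consequence of the flattening is that two different source paths can collide, for example `a/b_c.mp4` and `a_b/c.mp4` both map to `a_b_c.mp4.tensors.pth` in the redirect directory.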
class TensorDataset(torch.utils.data.Dataset):
def __init__(self, base_path, metadata_path=None, steps_per_epoch=1000, redirected_tensor_path=None):
if metadata_path is not None and os.path.exists(metadata_path):
metadata = pd.read_csv(metadata_path)
self.path = [os.path.join(base_path, "train", file_name) for file_name in metadata["file_name"]]
print(len(self.path), "videos in metadata.")
if redirected_tensor_path is None:
self.path = [i + ".tensors.pth" for i in self.path if os.path.exists(i + ".tensors.pth")]
else:
cached_path = []
for path in self.path:
path = path.replace("/", "_").replace("\\", "_")
path = os.path.join(redirected_tensor_path, path)
if os.path.exists(path + ".tensors.pth"):
cached_path.append(path + ".tensors.pth")
self.path = cached_path
else:
print("Cannot find metadata.csv. Trying to search for tensor files.")
self.path = [os.path.join(base_path, i) for i in os.listdir(base_path) if i.endswith(".tensors.pth")]
print(len(self.path), "tensors cached in metadata.")
assert len(self.path) > 0
self.steps_per_epoch = steps_per_epoch
self.redirected_tensor_path = redirected_tensor_path
def __getitem__(self, index):
while True:
try:
data_id = torch.randint(0, len(self.path), (1,))[0]
data_id = (data_id + index) % len(self.path) # For fixed seed.
path = self.path[data_id]
data = torch.load(path, weights_only=True, map_location="cpu")
return data
except Exception:
continue
def __len__(self):
return self.steps_per_epoch
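Reviewer note: `TensorDataset` decouples the epoch length from the number of cached files. `__len__` reports `steps_per_epoch`, and `__getitem__` draws a random file, shifted by the requested index and wrapped around. The sketch below reproduces that indexing with the standard-library `random` module instead of `torch.randint`; the class name `FixedLengthSampler` and the file names are hypothetical.

```python
import random

class FixedLengthSampler:
    """Hypothetical stand-in for TensorDataset; uses `random` instead of `torch.randint`."""

    def __init__(self, items, steps_per_epoch):
        self.items = items
        self.steps_per_epoch = steps_per_epoch

    def __getitem__(self, index):
        # Draw a random offset, then shift by the requested index and wrap around.
        data_id = random.randrange(len(self.items))
        data_id = (data_id + index) % len(self.items)
        return self.items[data_id]

    def __len__(self):
        # The epoch length is a setting, not the number of cached files.
        return self.steps_per_epoch

sampler = FixedLengthSampler(["a.tensors.pth", "b.tensors.pth"], steps_per_epoch=5)
draws = [sampler[i] for i in range(len(sampler))]
print(len(draws), draws)
```

With two cached files and `steps_per_epoch=5`, one epoch still yields five samples, so files are revisited within the epoch.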
class LightningModelForTrain(pl.LightningModule):
def __init__(
self,
dit_path,
learning_rate=1e-5,
lora_rank=4, lora_alpha=4, train_architecture="lora", lora_target_modules="q,k,v,o,ffn.0,ffn.2", init_lora_weights="kaiming",
use_gradient_checkpointing=True, use_gradient_checkpointing_offload=False,
pretrained_lora_path=None
):
super().__init__()
model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu")
if os.path.isfile(dit_path):
model_manager.load_models([dit_path])
else:
dit_path = dit_path.split(",")
model_manager.load_models([dit_path])
self.pipe = WanVideoPipeline.from_model_manager(model_manager)
self.pipe.scheduler.set_timesteps(1000, training=True)
self.freeze_parameters()
state_dict = load_state_dict(dit_path, torch_dtype=torch.bfloat16)
state_dict, config = WanControlNetModel.state_dict_converter().from_base_model(state_dict)
self.pipe.controlnet = WanControlNetModel(**config).to(torch.bfloat16)
self.pipe.controlnet.load_state_dict(state_dict)
self.pipe.controlnet.train()
self.pipe.controlnet.requires_grad_(True)
self.learning_rate = learning_rate
self.use_gradient_checkpointing = use_gradient_checkpointing
self.use_gradient_checkpointing_offload = use_gradient_checkpointing_offload
def freeze_parameters(self):
# Freeze parameters
self.pipe.requires_grad_(False)
self.pipe.eval()
self.pipe.denoising_model().train()
def training_step(self, batch, batch_idx):
# Data
latents = batch["latents"].to(self.device)
controlnet_kwargs = batch["controlnet_kwargs"]
controlnet_kwargs["controlnet_conditioning"] = controlnet_kwargs["controlnet_conditioning"].to(self.device)
prompt_emb = batch["prompt_emb"]
prompt_emb["context"] = prompt_emb["context"][0].to(self.device)
image_emb = batch["image_emb"]
if "clip_feature" in image_emb:
image_emb["clip_feature"] = image_emb["clip_feature"][0].to(self.device)
if "y" in image_emb:
image_emb["y"] = image_emb["y"][0].to(self.device)
# Loss
self.pipe.device = self.device
noise = torch.randn_like(latents)
timestep_id = torch.randint(0, self.pipe.scheduler.num_train_timesteps, (1,))
timestep = self.pipe.scheduler.timesteps[timestep_id].to(dtype=self.pipe.torch_dtype, device=self.pipe.device)
extra_input = self.pipe.prepare_extra_input(latents)
noisy_latents = self.pipe.scheduler.add_noise(latents, noise, timestep)
training_target = self.pipe.scheduler.training_target(latents, noise, timestep)
# Compute loss
noise_pred = model_fn_wan_video(
dit=self.pipe.dit, controlnet=self.pipe.controlnet,
x=noisy_latents, timestep=timestep, **prompt_emb, **extra_input, **image_emb, **controlnet_kwargs,
use_gradient_checkpointing=self.use_gradient_checkpointing,
use_gradient_checkpointing_offload=self.use_gradient_checkpointing_offload
)
loss = torch.nn.functional.mse_loss(noise_pred.float(), training_target.float())
loss = loss * self.pipe.scheduler.training_weight(timestep)
# Record log
self.log("train_loss", loss, prog_bar=True)
return loss
def configure_optimizers(self):
trainable_modules = filter(lambda p: p.requires_grad, self.pipe.controlnet.parameters())
optimizer = torch.optim.AdamW(trainable_modules, lr=self.learning_rate)
return optimizer
def on_save_checkpoint(self, checkpoint):
checkpoint.clear()
trainable_param_names = list(filter(lambda named_param: named_param[1].requires_grad, self.pipe.controlnet.named_parameters()))
trainable_param_names = set([named_param[0] for named_param in trainable_param_names])
state_dict = self.pipe.controlnet.state_dict()
lora_state_dict = {}
for name, param in state_dict.items():
if name in trainable_param_names:
lora_state_dict[name] = param
checkpoint.update(lora_state_dict)
def parse_args():
parser = argparse.ArgumentParser(description="Simple example of a training script.")
parser.add_argument(
"--task",
type=str,
default="data_process",
required=True,
choices=["data_process", "train"],
help="Task. `data_process` or `train`.",
)
parser.add_argument(
"--dataset_path",
type=str,
default=None,
required=True,
help="The path of the Dataset.",
)
parser.add_argument(
"--output_path",
type=str,
default="./",
help="Path to save the model.",
)
parser.add_argument(
"--metadata_path",
type=str,
default=None,
help="Path to metadata.csv.",
)
parser.add_argument(
"--redirected_tensor_path",
type=str,
default=None,
help="Path to save cached tensors.",
)
parser.add_argument(
"--text_encoder_path",
type=str,
default=None,
help="Path of text encoder.",
)
parser.add_argument(
"--image_encoder_path",
type=str,
default=None,
help="Path of image encoder.",
)
parser.add_argument(
"--vae_path",
type=str,
default=None,
help="Path of VAE.",
)
parser.add_argument(
"--dit_path",
type=str,
default=None,
help="Path of DiT.",
)
parser.add_argument(
"--tiled",
default=False,
action="store_true",
help="Whether to enable tiled encoding in the VAE. This option can reduce the VRAM required.",
)
parser.add_argument(
"--tile_size_height",
type=int,
default=34,
help="Tile size (height) in VAE.",
)
parser.add_argument(
"--tile_size_width",
type=int,
default=34,
help="Tile size (width) in VAE.",
)
parser.add_argument(
"--tile_stride_height",
type=int,
default=18,
help="Tile stride (height) in VAE.",
)
parser.add_argument(
"--tile_stride_width",
type=int,
default=16,
help="Tile stride (width) in VAE.",
)
parser.add_argument(
"--steps_per_epoch",
type=int,
default=500,
help="Number of steps per epoch.",
)
parser.add_argument(
"--num_frames",
type=int,
default=81,
help="Number of frames.",
)
parser.add_argument(
"--target_fps",
type=int,
default=None,
help="Expected FPS for sampling frames.",
)
parser.add_argument(
"--height",
type=int,
default=480,
help="Image height.",
)
parser.add_argument(
"--width",
type=int,
default=832,
help="Image width.",
)
parser.add_argument(
"--dataloader_num_workers",
type=int,
default=1,
help="Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process.",
)
parser.add_argument(
"--learning_rate",
type=float,
default=1e-5,
help="Learning rate.",
)
parser.add_argument(
"--accumulate_grad_batches",
type=int,
default=1,
help="The number of batches in gradient accumulation.",
)
parser.add_argument(
"--max_epochs",
type=int,
default=1,
help="Number of epochs.",
)
parser.add_argument(
"--lora_target_modules",
type=str,
default="q,k,v,o,ffn.0,ffn.2",
help="Layers with LoRA modules.",
)
parser.add_argument(
"--init_lora_weights",
type=str,
default="kaiming",
choices=["gaussian", "kaiming"],
help="The initializing method of LoRA weight.",
)
parser.add_argument(
"--training_strategy",
type=str,
default="auto",
choices=["auto", "deepspeed_stage_1", "deepspeed_stage_2", "deepspeed_stage_3"],
help="Training strategy",
)
parser.add_argument(
"--lora_rank",
type=int,
default=4,
help="The dimension of the LoRA update matrices.",
)
parser.add_argument(
"--lora_alpha",
type=float,
default=4.0,
help="The weight of the LoRA update matrices.",
)
parser.add_argument(
"--use_gradient_checkpointing",
default=False,
action="store_true",
help="Whether to use gradient checkpointing.",
)
parser.add_argument(
"--use_gradient_checkpointing_offload",
default=False,
action="store_true",
help="Whether to use gradient checkpointing offload.",
)
parser.add_argument(
"--train_architecture",
type=str,
default="lora",
choices=["lora", "full"],
help="Model structure to train. LoRA training or full training.",
)
parser.add_argument(
"--pretrained_lora_path",
type=str,
default=None,
help="Pretrained LoRA path. Required if the training is resumed.",
)
parser.add_argument(
"--use_swanlab",
default=False,
action="store_true",
help="Whether to use SwanLab logger.",
)
parser.add_argument(
"--swanlab_mode",
default=None,
help="SwanLab mode (cloud or local).",
)
args = parser.parse_args()
return args
def data_process(args):
dataset = TextVideoDataset(
args.dataset_path,
os.path.join(args.dataset_path, "metadata.csv") if args.metadata_path is None else args.metadata_path,
max_num_frames=args.num_frames,
frame_interval=1,
num_frames=args.num_frames,
height=args.height,
width=args.width,
is_i2v=args.image_encoder_path is not None,
target_fps=args.target_fps,
)
dataloader = torch.utils.data.DataLoader(
dataset,
shuffle=False,
batch_size=1,
num_workers=args.dataloader_num_workers,
collate_fn=lambda x: x,
)
model = LightningModelForDataProcess(
text_encoder_path=args.text_encoder_path,
image_encoder_path=args.image_encoder_path,
vae_path=args.vae_path,
tiled=args.tiled,
tile_size=(args.tile_size_height, args.tile_size_width),
tile_stride=(args.tile_stride_height, args.tile_stride_width),
redirected_tensor_path=args.redirected_tensor_path,
)
trainer = pl.Trainer(
accelerator="gpu",
devices="auto",
default_root_dir=args.output_path,
)
trainer.test(model, dataloader)
def train(args):
dataset = TensorDataset(
args.dataset_path,
os.path.join(args.dataset_path, "metadata.csv") if args.metadata_path is None else args.metadata_path,
steps_per_epoch=args.steps_per_epoch,
redirected_tensor_path=args.redirected_tensor_path,
)
dataloader = torch.utils.data.DataLoader(
dataset,
shuffle=True,
batch_size=1,
num_workers=args.dataloader_num_workers
)
model = LightningModelForTrain(
dit_path=args.dit_path,
learning_rate=args.learning_rate,
train_architecture=args.train_architecture,
lora_rank=args.lora_rank,
lora_alpha=args.lora_alpha,
lora_target_modules=args.lora_target_modules,
init_lora_weights=args.init_lora_weights,
use_gradient_checkpointing=args.use_gradient_checkpointing,
use_gradient_checkpointing_offload=args.use_gradient_checkpointing_offload,
pretrained_lora_path=args.pretrained_lora_path,
)
if args.use_swanlab:
from swanlab.integration.pytorch_lightning import SwanLabLogger
swanlab_config = {"UPPERFRAMEWORK": "DiffSynth-Studio"}
swanlab_config.update(vars(args))
swanlab_logger = SwanLabLogger(
project="wan",
name="wan",
config=swanlab_config,
mode=args.swanlab_mode,
logdir=os.path.join(args.output_path, "swanlog"),
)
logger = [swanlab_logger]
else:
logger = None
trainer = pl.Trainer(
max_epochs=args.max_epochs,
accelerator="gpu",
devices="auto",
precision="bf16",
strategy=args.training_strategy,
default_root_dir=args.output_path,
accumulate_grad_batches=args.accumulate_grad_batches,
callbacks=[pl.pytorch.callbacks.ModelCheckpoint(save_top_k=-1)],
logger=logger,
)
trainer.fit(model, dataloader)
if __name__ == '__main__':
args = parse_args()
if args.task == "data_process":
data_process(args)
elif args.task == "train":
train(args)

@@ -0,0 +1,691 @@
import torch, os, imageio, argparse
from torchvision.transforms import v2
from einops import rearrange
import lightning as pl
import pandas as pd
from diffsynth import WanVideoPipeline, ModelManager, load_state_dict
from diffsynth.models.wan_video_motion_controller import WanMotionControllerModel
from diffsynth.pipelines.wan_video import model_fn_wan_video
from peft import LoraConfig, inject_adapter_in_model
import torchvision
from PIL import Image
import numpy as np
from tqdm import tqdm
class TextVideoDataset(torch.utils.data.Dataset):
def __init__(self, base_path, metadata_path, max_num_frames=81, frame_interval=1, num_frames=81, height=480, width=832, is_i2v=False, target_fps=None):
metadata = pd.read_csv(metadata_path)
self.path = [os.path.join(base_path, file_name) for file_name in metadata["file_name"]]
self.text = metadata["text"].to_list()
self.max_num_frames = max_num_frames
self.frame_interval = frame_interval
self.num_frames = num_frames
self.height = height
self.width = width
self.is_i2v = is_i2v
self.target_fps = target_fps
self.frame_process = v2.Compose([
v2.CenterCrop(size=(height, width)),
v2.Resize(size=(height, width), antialias=True),
v2.ToTensor(),
v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])
def crop_and_resize(self, image):
width, height = image.size
scale = max(self.width / width, self.height / height)
image = torchvision.transforms.functional.resize(
image,
(round(height*scale), round(width*scale)),
interpolation=torchvision.transforms.InterpolationMode.BILINEAR
)
return image
def load_frames_using_imageio(self, file_path, max_num_frames, start_frame_id, interval, num_frames, frame_process):
reader = imageio.get_reader(file_path)
if reader.count_frames() < max_num_frames or reader.count_frames() - 1 < start_frame_id + (num_frames - 1) * interval:
reader.close()
return None
frames = []
first_frame = None
for frame_id in range(num_frames):
frame = reader.get_data(start_frame_id + frame_id * interval)
frame = Image.fromarray(frame)
frame = self.crop_and_resize(frame)
if first_frame is None:
first_frame = np.array(frame)
frame = frame_process(frame)
frames.append(frame)
reader.close()
frames = torch.stack(frames, dim=0)
frames = rearrange(frames, "T C H W -> C T H W")
if self.is_i2v:
return frames, first_frame
else:
return frames
def load_video(self, file_path):
start_frame_id = 0
if self.target_fps is None:
frame_interval = self.frame_interval
else:
reader = imageio.get_reader(file_path)
fps = reader.get_meta_data()["fps"]
reader.close()
frame_interval = max(round(fps / self.target_fps), 1)
frames = self.load_frames_using_imageio(file_path, self.max_num_frames, start_frame_id, frame_interval, self.num_frames, self.frame_process)
return frames
def is_image(self, file_path):
file_ext_name = file_path.split(".")[-1]
if file_ext_name.lower() in ["jpg", "jpeg", "png", "webp"]:
return True
return False
def load_image(self, file_path):
frame = Image.open(file_path).convert("RGB")
frame = self.crop_and_resize(frame)
first_frame = frame
frame = self.frame_process(frame)
frame = rearrange(frame, "C H W -> C 1 H W")
return frame
def __getitem__(self, data_id):
text = self.text[data_id]
path = self.path[data_id]
try:
if self.is_image(path):
if self.is_i2v:
raise ValueError(f"{path} is not a video. I2V model doesn't support image-to-image training.")
video = self.load_image(path)
else:
video = self.load_video(path)
if self.is_i2v:
video, first_frame = video
data = {"text": text, "video": video, "path": path, "first_frame": first_frame}
else:
data = {"text": text, "video": video, "path": path}
except Exception:
data = None
return data
def __len__(self):
return len(self.path)
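Reviewer note: the `target_fps` support in `load_video` converts the source frame rate into a sampling stride, and `load_frames_using_imageio` then rejects videos that are too short for that stride. The sketch below isolates both calculations as plain functions so they can be checked without a video file. The function names are hypothetical; the arithmetic follows the code above.

```python
def frame_interval_for(fps, target_fps, default_interval=1):
    """Stride between sampled frames, following load_video."""
    if target_fps is None:
        return default_interval
    # Never sample more densely than every frame, even if target_fps > fps.
    return max(round(fps / target_fps), 1)

def has_enough_frames(total_frames, max_num_frames, start_frame_id, interval, num_frames):
    """Inverse of the rejection test in load_frames_using_imageio."""
    last_index = start_frame_id + (num_frames - 1) * interval
    return total_frames >= max_num_frames and total_frames - 1 >= last_index

interval = frame_interval_for(fps=30, target_fps=15)
print(interval)
# 81 frames at this stride need source indices 0..160, i.e. at least 161 frames.
print(has_enough_frames(161, 81, 0, interval, 81))
print(has_enough_frames(160, 81, 0, interval, 81))
```

Python's `round` rounds exact halves to the nearest even integer, so a 25 fps source with `target_fps=10` gives a stride of 2 rather than 3.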
class LightningModelForDataProcess(pl.LightningModule):
def __init__(self, text_encoder_path, vae_path, image_encoder_path=None, tiled=False, tile_size=(34, 34), tile_stride=(18, 16), redirected_tensor_path=None):
super().__init__()
model_path = [text_encoder_path, vae_path]
if image_encoder_path is not None:
model_path.append(image_encoder_path)
model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu")
model_manager.load_models(model_path)
self.pipe = WanVideoPipeline.from_model_manager(model_manager)
self.tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride}
self.redirected_tensor_path = redirected_tensor_path
def test_step(self, batch, batch_idx):
data = batch[0]
if data is None or data["video"] is None:
return
text, video, path = data["text"], data["video"].unsqueeze(0), data["path"]
self.pipe.device = self.device
if video is not None:
# prompt
prompt_emb = self.pipe.encode_prompt(text)
# video
video = video.to(dtype=self.pipe.torch_dtype, device=self.pipe.device)
latents = self.pipe.encode_video(video, **self.tiler_kwargs)[0]
# image
if "first_frame" in data:
first_frame = Image.fromarray(data["first_frame"])
_, _, num_frames, height, width = video.shape
image_emb = self.pipe.encode_image(first_frame, num_frames, height, width)
else:
image_emb = {}
data = {"latents": latents, "prompt_emb": prompt_emb, "image_emb": image_emb}
if self.redirected_tensor_path is not None:
path = path.replace("/", "_").replace("\\", "_")
path = os.path.join(self.redirected_tensor_path, path)
torch.save(data, path + ".tensors.pth")
class TensorDataset(torch.utils.data.Dataset):
def __init__(self, base_path, metadata_path=None, steps_per_epoch=1000, redirected_tensor_path=None):
if metadata_path is not None and os.path.exists(metadata_path):
metadata = pd.read_csv(metadata_path)
self.path = [os.path.join(base_path, "train", file_name) for file_name in metadata["file_name"]]
print(len(self.path), "videos in metadata.")
if redirected_tensor_path is None:
self.path = [i + ".tensors.pth" for i in self.path if os.path.exists(i + ".tensors.pth")]
else:
cached_path = []
for path in self.path:
path = path.replace("/", "_").replace("\\", "_")
path = os.path.join(redirected_tensor_path, path)
if os.path.exists(path + ".tensors.pth"):
cached_path.append(path + ".tensors.pth")
self.path = cached_path
else:
print("Cannot find metadata.csv. Trying to search for tensor files.")
self.path = [os.path.join(base_path, i) for i in os.listdir(base_path) if i.endswith(".tensors.pth")]
print(len(self.path), "tensors cached in metadata.")
assert len(self.path) > 0
self.steps_per_epoch = steps_per_epoch
self.redirected_tensor_path = redirected_tensor_path
def __getitem__(self, index):
while True:
try:
data_id = torch.randint(0, len(self.path), (1,))[0]
data_id = (data_id + index) % len(self.path) # For fixed seed.
path = self.path[data_id]
data = torch.load(path, weights_only=True, map_location="cpu")
return data
except Exception:
continue
def __len__(self):
return self.steps_per_epoch
class LightningModelForTrain(pl.LightningModule):
def __init__(
self,
dit_path,
learning_rate=1e-5,
lora_rank=4, lora_alpha=4, train_architecture="lora", lora_target_modules="q,k,v,o,ffn.0,ffn.2", init_lora_weights="kaiming",
use_gradient_checkpointing=True, use_gradient_checkpointing_offload=False,
pretrained_lora_path=None
):
super().__init__()
model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu")
if os.path.isfile(dit_path):
model_manager.load_models([dit_path])
else:
dit_path = dit_path.split(",")
model_manager.load_models([dit_path])
self.pipe = WanVideoPipeline.from_model_manager(model_manager)
self.pipe.scheduler.set_timesteps(1000, training=True)
self.freeze_parameters()
self.pipe.motion_controller = WanMotionControllerModel().to(torch.bfloat16)
self.pipe.motion_controller.init()
self.pipe.motion_controller.requires_grad_(True)
self.pipe.motion_controller.train()
self.motion_bucket_manager = MotionBucketManager()
self.learning_rate = learning_rate
self.use_gradient_checkpointing = use_gradient_checkpointing
self.use_gradient_checkpointing_offload = use_gradient_checkpointing_offload
def freeze_parameters(self):
# Freeze parameters
self.pipe.requires_grad_(False)
self.pipe.eval()
self.pipe.dit.train()
def add_lora_to_model(self, model, lora_rank=4, lora_alpha=4, lora_target_modules="q,k,v,o,ffn.0,ffn.2", init_lora_weights="kaiming", pretrained_lora_path=None, state_dict_converter=None):
# Add LoRA adapters to the given model
self.lora_alpha = lora_alpha
if init_lora_weights == "kaiming":
init_lora_weights = True
lora_config = LoraConfig(
r=lora_rank,
lora_alpha=lora_alpha,
init_lora_weights=init_lora_weights,
target_modules=lora_target_modules.split(","),
)
model = inject_adapter_in_model(lora_config, model)
for param in model.parameters():
# Upcast LoRA parameters into fp32
if param.requires_grad:
param.data = param.to(torch.float32)
# Load pretrained LoRA weights
if pretrained_lora_path is not None:
state_dict = load_state_dict(pretrained_lora_path)
if state_dict_converter is not None:
state_dict = state_dict_converter(state_dict)
missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
all_keys = [i for i, _ in model.named_parameters()]
num_updated_keys = len(all_keys) - len(missing_keys)
num_unexpected_keys = len(unexpected_keys)
print(f"{num_updated_keys} parameters are loaded from {pretrained_lora_path}. {num_unexpected_keys} parameters are unexpected.")
def training_step(self, batch, batch_idx):
# Data
latents = batch["latents"].to(self.device)
prompt_emb = batch["prompt_emb"]
prompt_emb["context"] = prompt_emb["context"][0].to(self.device)
image_emb = batch["image_emb"]
if "clip_feature" in image_emb:
image_emb["clip_feature"] = image_emb["clip_feature"][0].to(self.device)
if "y" in image_emb:
image_emb["y"] = image_emb["y"][0].to(self.device)
# Loss
self.pipe.device = self.device
noise = torch.randn_like(latents)
timestep_id = torch.randint(0, self.pipe.scheduler.num_train_timesteps, (1,))
timestep = self.pipe.scheduler.timesteps[timestep_id].to(dtype=self.pipe.torch_dtype, device=self.pipe.device)
extra_input = self.pipe.prepare_extra_input(latents)
noisy_latents = self.pipe.scheduler.add_noise(latents, noise, timestep)
training_target = self.pipe.scheduler.training_target(latents, noise, timestep)
motion_bucket_id = self.motion_bucket_manager(latents)
motion_bucket_kwargs = self.pipe.prepare_motion_bucket_id(motion_bucket_id)
# Compute loss
noise_pred = model_fn_wan_video(
dit=self.pipe.dit, motion_controller=self.pipe.motion_controller,
x=noisy_latents, timestep=timestep, **prompt_emb, **extra_input, **image_emb, **motion_bucket_kwargs,
use_gradient_checkpointing=self.use_gradient_checkpointing,
use_gradient_checkpointing_offload=self.use_gradient_checkpointing_offload
)
loss = torch.nn.functional.mse_loss(noise_pred.float(), training_target.float())
loss = loss * self.pipe.scheduler.training_weight(timestep)
# Record log
self.log("train_loss", loss, prog_bar=True)
return loss
def configure_optimizers(self):
trainable_modules = filter(lambda p: p.requires_grad, self.pipe.motion_controller.parameters())
optimizer = torch.optim.AdamW(trainable_modules, lr=self.learning_rate)
return optimizer
def on_save_checkpoint(self, checkpoint):
checkpoint.clear()
trainable_param_names = list(filter(lambda named_param: named_param[1].requires_grad, self.pipe.motion_controller.named_parameters()))
trainable_param_names = set([named_param[0] for named_param in trainable_param_names])
state_dict = self.pipe.motion_controller.state_dict()
lora_state_dict = {}
for name, param in state_dict.items():
if name in trainable_param_names:
lora_state_dict[name] = param
checkpoint.update(lora_state_dict)
class MotionBucketManager:
def __init__(self):
self.thresholds = [
0.093750000, 0.094726562, 0.100585938, 0.100585938, 0.108886719, 0.109375000, 0.118652344, 0.127929688, 0.127929688, 0.130859375,
0.133789062, 0.137695312, 0.138671875, 0.138671875, 0.139648438, 0.143554688, 0.143554688, 0.147460938, 0.149414062, 0.149414062,
0.152343750, 0.153320312, 0.154296875, 0.154296875, 0.157226562, 0.163085938, 0.163085938, 0.164062500, 0.165039062, 0.166992188,
0.173828125, 0.179687500, 0.180664062, 0.184570312, 0.187500000, 0.188476562, 0.188476562, 0.189453125, 0.189453125, 0.202148438,
0.206054688, 0.210937500, 0.210937500, 0.211914062, 0.214843750, 0.214843750, 0.216796875, 0.216796875, 0.216796875, 0.218750000,
0.218750000, 0.221679688, 0.222656250, 0.227539062, 0.229492188, 0.230468750, 0.236328125, 0.243164062, 0.243164062, 0.245117188,
0.253906250, 0.253906250, 0.255859375, 0.259765625, 0.275390625, 0.275390625, 0.277343750, 0.279296875, 0.279296875, 0.279296875,
0.292968750, 0.292968750, 0.302734375, 0.306640625, 0.312500000, 0.312500000, 0.326171875, 0.330078125, 0.332031250, 0.332031250,
0.337890625, 0.343750000, 0.343750000, 0.351562500, 0.355468750, 0.357421875, 0.361328125, 0.367187500, 0.382812500, 0.388671875,
0.392578125, 0.392578125, 0.392578125, 0.404296875, 0.404296875, 0.425781250, 0.433593750, 0.507812500, 0.519531250, 0.539062500,
]
def get_motion_score(self, frames):
score = frames[:, :, 1:, :, :].std(dim=2).mean().tolist()
return score
def get_bucket_id(self, motion_score):
for bucket_id in range(len(self.thresholds) - 1):
if self.thresholds[bucket_id + 1] > motion_score:
return bucket_id
return len(self.thresholds)
def __call__(self, frames):
score = self.get_motion_score(frames)
bucket_id = self.get_bucket_id(score)
return bucket_id
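Reviewer note: `get_bucket_id` maps a motion score to the first bucket whose next threshold exceeds it. The sketch below copies that loop into a standalone function and runs it on a four-entry toy threshold list (the values are hypothetical, chosen only to make the behaviour visible).

```python
def get_bucket_id(thresholds, motion_score):
    """Standalone copy of the loop in MotionBucketManager.get_bucket_id."""
    for bucket_id in range(len(thresholds) - 1):
        if thresholds[bucket_id + 1] > motion_score:
            return bucket_id
    # Scores at or above the last threshold fall through to here.
    return len(thresholds)

toy_thresholds = [0.1, 0.2, 0.3, 0.4]  # hypothetical values
for score in (0.05, 0.25, 0.35, 0.5):
    print(score, get_bucket_id(toy_thresholds, score))
```

With `n` thresholds the function returns `0` through `n - 2` from the loop and `n` on fall-through, so the value `n - 1` is never produced. With the 100 thresholds above, that means ids 0 to 98 and 100. Whether that gap is intended depends on how `WanMotionControllerModel` consumes the id, which is not visible in this diff.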
def parse_args():
parser = argparse.ArgumentParser(description="Simple example of a training script.")
parser.add_argument(
"--task",
type=str,
default="data_process",
required=True,
choices=["data_process", "train"],
help="Task. `data_process` or `train`.",
)
parser.add_argument(
"--dataset_path",
type=str,
default=None,
required=True,
help="The path of the Dataset.",
)
parser.add_argument(
"--output_path",
type=str,
default="./",
help="Path to save the model.",
)
parser.add_argument(
"--metadata_path",
type=str,
default=None,
help="Path to metadata.csv.",
)
parser.add_argument(
"--redirected_tensor_path",
type=str,
default=None,
help="Path to save cached tensors.",
)
parser.add_argument(
"--text_encoder_path",
type=str,
default=None,
help="Path of text encoder.",
)
parser.add_argument(
"--image_encoder_path",
type=str,
default=None,
help="Path of image encoder.",
)
parser.add_argument(
"--vae_path",
type=str,
default=None,
help="Path of VAE.",
)
parser.add_argument(
"--dit_path",
type=str,
default=None,
help="Path of DiT.",
)
parser.add_argument(
"--tiled",
default=False,
action="store_true",
help="Whether to enable tiled encoding in the VAE. This option can reduce the VRAM required.",
)
parser.add_argument(
"--tile_size_height",
type=int,
default=34,
help="Tile size (height) in VAE.",
)
parser.add_argument(
"--tile_size_width",
type=int,
default=34,
help="Tile size (width) in VAE.",
)
parser.add_argument(
"--tile_stride_height",
type=int,
default=18,
help="Tile stride (height) in VAE.",
)
parser.add_argument(
"--tile_stride_width",
type=int,
default=16,
help="Tile stride (width) in VAE.",
)
parser.add_argument(
"--steps_per_epoch",
type=int,
default=500,
help="Number of steps per epoch.",
)
parser.add_argument(
"--num_frames",
type=int,
default=81,
help="Number of frames.",
)
parser.add_argument(
"--target_fps",
type=int,
default=None,
help="Expected FPS for sampling frames.",
)
parser.add_argument(
"--height",
type=int,
default=480,
help="Image height.",
)
parser.add_argument(
"--width",
type=int,
default=832,
help="Image width.",
)
parser.add_argument(
"--dataloader_num_workers",
type=int,
default=1,
help="Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process.",
)
parser.add_argument(
"--learning_rate",
type=float,
default=1e-5,
help="Learning rate.",
)
parser.add_argument(
"--accumulate_grad_batches",
type=int,
default=1,
help="The number of batches in gradient accumulation.",
)
parser.add_argument(
"--max_epochs",
type=int,
default=1,
help="Number of epochs.",
)
parser.add_argument(
"--lora_target_modules",
type=str,
default="q,k,v,o,ffn.0,ffn.2",
help="Layers with LoRA modules.",
)
parser.add_argument(
"--init_lora_weights",
type=str,
default="kaiming",
choices=["gaussian", "kaiming"],
help="The initializing method of LoRA weight.",
)
parser.add_argument(
"--training_strategy",
type=str,
default="auto",
choices=["auto", "deepspeed_stage_1", "deepspeed_stage_2", "deepspeed_stage_3"],
help="Training strategy",
)
parser.add_argument(
"--lora_rank",
type=int,
default=4,
help="The dimension of the LoRA update matrices.",
)
parser.add_argument(
"--lora_alpha",
type=float,
default=4.0,
help="The weight of the LoRA update matrices.",
)
parser.add_argument(
"--use_gradient_checkpointing",
default=False,
action="store_true",
help="Whether to use gradient checkpointing.",
)
parser.add_argument(
"--use_gradient_checkpointing_offload",
default=False,
action="store_true",
help="Whether to use gradient checkpointing offload.",
)
parser.add_argument(
"--train_architecture",
type=str,
default="lora",
choices=["lora", "full"],
help="Model structure to train. LoRA training or full training.",
)
parser.add_argument(
"--pretrained_lora_path",
type=str,
default=None,
help="Pretrained LoRA path. Required if the training is resumed.",
)
parser.add_argument(
"--use_swanlab",
default=False,
action="store_true",
help="Whether to use SwanLab logger.",
)
parser.add_argument(
"--swanlab_mode",
default=None,
help="SwanLab mode (cloud or local).",
)
args = parser.parse_args()
return args
def data_process(args):
dataset = TextVideoDataset(
args.dataset_path,
os.path.join(args.dataset_path, "metadata.csv") if args.metadata_path is None else args.metadata_path,
max_num_frames=args.num_frames,
frame_interval=1,
num_frames=args.num_frames,
height=args.height,
width=args.width,
is_i2v=args.image_encoder_path is not None,
target_fps=args.target_fps,
)
dataloader = torch.utils.data.DataLoader(
dataset,
shuffle=False,
batch_size=1,
num_workers=args.dataloader_num_workers,
collate_fn=lambda x: x,
)
model = LightningModelForDataProcess(
text_encoder_path=args.text_encoder_path,
image_encoder_path=args.image_encoder_path,
vae_path=args.vae_path,
tiled=args.tiled,
tile_size=(args.tile_size_height, args.tile_size_width),
tile_stride=(args.tile_stride_height, args.tile_stride_width),
redirected_tensor_path=args.redirected_tensor_path,
)
trainer = pl.Trainer(
accelerator="gpu",
devices="auto",
default_root_dir=args.output_path,
)
trainer.test(model, dataloader)
def get_motion_thresholds(dataloader):
scores = []
for data in tqdm(dataloader):
scores.append(data["latents"][:, :, 1:, :, :].std(dim=2).mean().tolist())
scores = sorted(scores)
for i in range(100):
s = scores[int(i/100 * len(scores))]
print("%.9f" % s, end=", ")
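Reviewer note: `get_motion_thresholds` appears to be the offline helper that produced the hard-coded list in `MotionBucketManager`: it sorts the per-sample motion scores and prints the value at each 1% step. The sketch below shows the same selection on plain Python floats and returns the list instead of printing it. The function name and the `num_buckets` parameter are assumptions added for illustration; the original fixes the count at 100.

```python
def percentile_thresholds(scores, num_buckets=100):
    """Pick evenly spaced order statistics from the sorted scores."""
    scores = sorted(scores)
    # Same index rule as the original: int(i / num_buckets * len(scores)).
    return [scores[int(i / num_buckets * len(scores))] for i in range(num_buckets)]

toy_scores = [0.8, 0.1, 0.5, 0.3, 0.7, 0.2, 0.6, 0.4]  # hypothetical motion scores
print(percentile_thresholds(toy_scores, num_buckets=4))
```

When the dataset has fewer distinct scores than buckets, neighbouring thresholds repeat, which would explain the duplicated values in the hard-coded list.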


def train(args):
    dataset = TensorDataset(
        args.dataset_path,
        os.path.join(args.dataset_path, "metadata.csv") if args.metadata_path is None else args.metadata_path,
        steps_per_epoch=args.steps_per_epoch,
        redirected_tensor_path=args.redirected_tensor_path,
    )
    dataloader = torch.utils.data.DataLoader(
        dataset,
        shuffle=True,
        batch_size=1,
        num_workers=args.dataloader_num_workers
    )
    model = LightningModelForTrain(
        dit_path=args.dit_path,
        learning_rate=args.learning_rate,
        train_architecture=args.train_architecture,
        lora_rank=args.lora_rank,
        lora_alpha=args.lora_alpha,
        lora_target_modules=args.lora_target_modules,
        init_lora_weights=args.init_lora_weights,
        use_gradient_checkpointing=args.use_gradient_checkpointing,
        use_gradient_checkpointing_offload=args.use_gradient_checkpointing_offload,
        pretrained_lora_path=args.pretrained_lora_path,
    )
    if args.use_swanlab:
        from swanlab.integration.pytorch_lightning import SwanLabLogger
        swanlab_config = {"UPPERFRAMEWORK": "DiffSynth-Studio"}
        swanlab_config.update(vars(args))
        swanlab_logger = SwanLabLogger(
            project="wan",
            name="wan",
            config=swanlab_config,
            mode=args.swanlab_mode,
            logdir=os.path.join(args.output_path, "swanlog"),
        )
        logger = [swanlab_logger]
    else:
        logger = None
    trainer = pl.Trainer(
        max_epochs=args.max_epochs,
        accelerator="gpu",
        devices="auto",
        precision="bf16",
        strategy=args.training_strategy,
        default_root_dir=args.output_path,
        accumulate_grad_batches=args.accumulate_grad_batches,
        callbacks=[pl.pytorch.callbacks.ModelCheckpoint(save_top_k=-1)],
        logger=logger,
    )
    trainer.fit(model, dataloader)


if __name__ == '__main__':
    args = parse_args()
    if args.task == "data_process":
        data_process(args)
    elif args.task == "train":
        train(args)
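
`get_motion_thresholds` only prints 100 percentile values of the motion score; how those values are later turned into a `motion_bucket_id` is not shown in this diff. The sketch below is one plausible reading, not the repository's implementation: the helper names `percentile_thresholds` and `motion_bucket`, the synthetic scores, and the "last threshold not exceeding the score" rule are all assumptions made for illustration.

```python
from bisect import bisect_right


def percentile_thresholds(scores, num_buckets=100):
    """Return the score found at each of `num_buckets` evenly spaced ranks."""
    ordered = sorted(scores)
    return [ordered[int(i / num_buckets * len(ordered))] for i in range(num_buckets)]


def motion_bucket(score, thresholds):
    """Hypothetical mapping: index of the last threshold that is <= score."""
    return max(bisect_right(thresholds, score) - 1, 0)


# Synthetic motion scores standing in for the per-sample latent statistics.
scores = [0.01 * k for k in range(1, 201)]
thresholds = percentile_thresholds(scores)

slow = motion_bucket(0.0, thresholds)    # below every threshold
fast = motion_bucket(10.0, thresholds)   # above every threshold
print(slow, fast)
```

Under this assumed mapping, a nearly static clip lands in the lowest bucket and a very dynamic clip in the highest one, which matches how the inference example contrasts a small and a large `motion_bucket_id`.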

@@ -1,41 +0,0 @@
import torch
from diffsynth import ModelManager, WanVideoPipeline, save_video, VideoData
from modelscope import snapshot_download
# Download models
snapshot_download("Wan-AI/Wan2.1-T2V-1.3B", local_dir="models/Wan-AI/Wan2.1-T2V-1.3B")
snapshot_download("DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1", local_dir="models/DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1")
# Load models
model_manager = ModelManager(device="cpu")
model_manager.load_models(
[
"models/Wan-AI/Wan2.1-T2V-1.3B/diffusion_pytorch_model.safetensors",
"models/Wan-AI/Wan2.1-T2V-1.3B/models_t5_umt5-xxl-enc-bf16.pth",
"models/Wan-AI/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth",
"models/DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1/model.safetensors",
],
torch_dtype=torch.bfloat16, # You can set `torch_dtype=torch.float8_e4m3fn` to enable FP8 quantization.
)
pipe = WanVideoPipeline.from_model_manager(model_manager, torch_dtype=torch.bfloat16, device="cuda")
pipe.enable_vram_management(num_persistent_param_in_dit=None)
# Text-to-video
video = pipe(
# English gloss of the prompt: documentary-photography style; a lively puppy runs quickly across a lush green lawn. It has a brownish-yellow coat, pricked ears, and a focused, joyful expression. Sunlight makes its fur look especially soft and shiny. The background is an open meadow dotted with a few wildflowers, with blue sky and a few white clouds in the distance. Strong perspective, capturing the motion of the running puppy and the vitality of the grass. Medium shot, side-on tracking view.
# English gloss of the negative prompt: garish colors, overexposed, static, blurry details, subtitles, style, artwork, painting, still frame, overall gray, worst quality, low quality, JPEG compression artifacts, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn face, deformed, disfigured, malformed limbs, fused fingers, motionless frame, cluttered background, three legs, many people in the background, walking backwards.
prompt="纪实摄影风格画面,一只活泼的小狗在绿茵茵的草地上迅速奔跑。小狗毛色棕黄,两只耳朵立起,神情专注而欢快。阳光洒在它身上,使得毛发看上去格外柔软而闪亮。背景是一片开阔的草地,偶尔点缀着几朵野花,远处隐约可见蓝天和几片白云。透视感鲜明,捕捉小狗奔跑时的动感和四周草地的生机。中景侧面移动视角。",
negative_prompt="色调艳丽过曝静态细节模糊不清字幕风格作品画作画面静止整体发灰最差质量低质量JPEG压缩残留丑陋的残缺的多余的手指画得不好的手部画得不好的脸部畸形的毁容的形态畸形的肢体手指融合静止不动的画面杂乱的背景三条腿背景人很多倒着走",
num_inference_steps=50,
seed=1, tiled=True,
motion_bucket_id=0
)
save_video(video, "video_slow.mp4", fps=15, quality=5)
video = pipe(
# Same prompt, negative prompt, and seed as the previous call; only `motion_bucket_id` differs.
prompt="纪实摄影风格画面,一只活泼的小狗在绿茵茵的草地上迅速奔跑。小狗毛色棕黄,两只耳朵立起,神情专注而欢快。阳光洒在它身上,使得毛发看上去格外柔软而闪亮。背景是一片开阔的草地,偶尔点缀着几朵野花,远处隐约可见蓝天和几片白云。透视感鲜明,捕捉小狗奔跑时的动感和四周草地的生机。中景侧面移动视角。",
negative_prompt="色调艳丽过曝静态细节模糊不清字幕风格作品画作画面静止整体发灰最差质量低质量JPEG压缩残留丑陋的残缺的多余的手指画得不好的手部画得不好的脸部畸形的毁容的形态畸形的肢体手指融合静止不动的画面杂乱的背景三条腿背景人很多倒着走",
num_inference_steps=50,
seed=1, tiled=True,
motion_bucket_id=100
)
save_video(video, "video_fast.mp4", fps=15, quality=5)

@@ -44,28 +44,11 @@ class LitModel(pl.LightningModule):
     def configure_model(self):
         tp_mesh = self.device_mesh["tensor_parallel"]
-        plan = {
-            "text_embedding.0": ColwiseParallel(),
-            "text_embedding.2": RowwiseParallel(),
-            "time_projection.1": ColwiseParallel(output_layouts=Replicate()),
-            "text_embedding.0": ColwiseParallel(),
-            "text_embedding.2": RowwiseParallel(),
-            "blocks.0": PrepareModuleInput(
-                input_layouts=(Replicate(), None, None, None),
-                desired_input_layouts=(Replicate(), None, None, None),
-            ),
-            "head": PrepareModuleInput(
-                input_layouts=(Replicate(), None),
-                desired_input_layouts=(Replicate(), None),
-                use_local_output=True,
-            )
-        }
-        self.pipe.dit = parallelize_module(self.pipe.dit, tp_mesh, plan)
         for block_id, block in enumerate(self.pipe.dit.blocks):
             layer_tp_plan = {
                 "self_attn": PrepareModuleInput(
-                    input_layouts=(Shard(1), Replicate()),
-                    desired_input_layouts=(Shard(1), Shard(0)),
+                    input_layouts=(Replicate(), Replicate()),
+                    desired_input_layouts=(Replicate(), Shard(0)),
                 ),
                 "self_attn.q": SequenceParallel(),
                 "self_attn.k": SequenceParallel(),
@@ -76,11 +59,11 @@ class LitModel(pl.LightningModule):
                     input_layouts=(Shard(1), Shard(1), Shard(1)),
                     desired_input_layouts=(Shard(2), Shard(2), Shard(2)),
                 ),
-                "self_attn.o": RowwiseParallel(input_layouts=Shard(2), output_layouts=Replicate()),
+                "self_attn.o": ColwiseParallel(output_layouts=Replicate()),
                 "cross_attn": PrepareModuleInput(
-                    input_layouts=(Shard(1), Replicate()),
-                    desired_input_layouts=(Shard(1), Replicate()),
+                    input_layouts=(Replicate(), Replicate()),
+                    desired_input_layouts=(Replicate(), Replicate()),
                 ),
                 "cross_attn.q": SequenceParallel(),
                 "cross_attn.k": SequenceParallel(),
@@ -91,18 +74,10 @@ class LitModel(pl.LightningModule):
                     input_layouts=(Shard(1), Shard(1), Shard(1)),
                     desired_input_layouts=(Shard(2), Shard(2), Shard(2)),
                 ),
-                "cross_attn.o": RowwiseParallel(input_layouts=Shard(2), output_layouts=Replicate(), use_local_output=False),
-                "ffn.0": ColwiseParallel(input_layouts=Shard(1)),
-                "ffn.2": RowwiseParallel(output_layouts=Replicate()),
-                "norm1": SequenceParallel(use_local_output=True),
-                "norm2": SequenceParallel(use_local_output=True),
-                "norm3": SequenceParallel(use_local_output=True),
-                "gate": PrepareModuleInput(
-                    input_layouts=(Shard(1), Replicate(), Replicate()),
-                    desired_input_layouts=(Replicate(), Replicate(), Replicate()),
-                )
+                "cross_attn.o": ColwiseParallel(output_layouts=Replicate()),
+                "ffn.0": ColwiseParallel(),
+                "ffn.2": RowwiseParallel(),
             }
             parallelize_module(
                 module=block,
@@ -121,6 +96,7 @@ class LitModel(pl.LightningModule):
         save_video(video, output_path, fps=15, quality=5)
 if __name__ == "__main__":
     snapshot_download("Wan-AI/Wan2.1-T2V-14B", local_dir="models/Wan-AI/Wan2.1-T2V-14B")
     dataloader = torch.utils.data.DataLoader(
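
Both sides of this diff pair `"ffn.0": ColwiseParallel(...)` with `"ffn.2": RowwiseParallel(...)`. The pure-Python sketch below illustrates why that pairing works for a two-layer feed-forward block: splitting the first weight by columns and the second by rows lets each rank compute a partial output, and summing the partials (the role of the all-reduce) reproduces the unsharded result. The matrices, the ReLU activation, and every helper name are illustrative assumptions; no PyTorch distributed API is exercised here.

```python
def matmul(a, b):
    """Multiply two matrices given as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def relu(m):
    """Element-wise activation, so it commutes with a column split."""
    return [[max(v, 0.0) for v in row] for row in m]


def add(a, b):
    return [[x + y for x, y in zip(ra, rb)] for ra, rb in zip(a, b)]


def split_columns(m, parts):
    width = len(m[0]) // parts
    return [[row[p * width:(p + 1) * width] for row in m] for p in range(parts)]


def split_rows(m, parts):
    height = len(m) // parts
    return [m[p * height:(p + 1) * height] for p in range(parts)]


x = [[1.0, -2.0], [0.5, 3.0]]                             # (tokens, dim)
w1 = [[1.0, -1.0, 2.0, 0.0], [0.5, 1.0, -1.0, 2.0]]       # (dim, hidden)
w2 = [[1.0, 0.0], [0.0, 1.0], [2.0, -1.0], [-1.0, 1.0]]   # (hidden, dim)

# Unsharded reference computation.
reference = matmul(relu(matmul(x, w1)), w2)

# Each simulated rank holds one column shard of w1 and the matching row shard of w2.
partials = [
    matmul(relu(matmul(x, w1_shard)), w2_shard)
    for w1_shard, w2_shard in zip(split_columns(w1, 2), split_rows(w2, 2))
]

# Summing the partial outputs stands in for the all-reduce across ranks.
combined = add(partials[0], partials[1])
print(combined == reference)
```

The column-then-row order matters: the hidden activations never need to be gathered between the two layers, so only one collective is required per feed-forward block.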

@@ -1,58 +0,0 @@
import torch
from diffsynth import ModelManager, WanVideoPipeline, save_video, VideoData
from modelscope import snapshot_download
import torch.distributed as dist
# Download models
snapshot_download("Wan-AI/Wan2.1-T2V-14B", local_dir="models/Wan-AI/Wan2.1-T2V-14B")
# Load models
model_manager = ModelManager(device="cpu")
model_manager.load_models(
[
[
"models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00001-of-00006.safetensors",
"models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00002-of-00006.safetensors",
"models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00003-of-00006.safetensors",
"models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00004-of-00006.safetensors",
"models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00005-of-00006.safetensors",
"models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00006-of-00006.safetensors",
],
"models/Wan-AI/Wan2.1-T2V-14B/models_t5_umt5-xxl-enc-bf16.pth",
"models/Wan-AI/Wan2.1-T2V-14B/Wan2.1_VAE.pth",
],
torch_dtype=torch.float8_e4m3fn, # You can set `torch_dtype=torch.bfloat16` to disable FP8 quantization.
)
dist.init_process_group(
backend="nccl",
init_method="env://",
)
from xfuser.core.distributed import (initialize_model_parallel,
init_distributed_environment)
init_distributed_environment(
rank=dist.get_rank(), world_size=dist.get_world_size())
initialize_model_parallel(
sequence_parallel_degree=dist.get_world_size(),
ring_degree=1,
ulysses_degree=dist.get_world_size(),
)
torch.cuda.set_device(dist.get_rank())
pipe = WanVideoPipeline.from_model_manager(model_manager,
torch_dtype=torch.bfloat16,
device=f"cuda:{dist.get_rank()}",
use_usp=True if dist.get_world_size() > 1 else False)
pipe.enable_vram_management(num_persistent_param_in_dit=None) # You can set `num_persistent_param_in_dit` to a small number to reduce VRAM required.
# Text-to-video
video = pipe(
# English gloss of the prompt: an astronaut in a spacesuit, facing the camera, rides a mechanical horse across the surface of Mars. The red, desolate terrain stretches into the distance, dotted with huge craters and strange rock formations. The horse's stride is steady and kicks up faint dust, combining futuristic technology with primal exploration. The astronaut holds a control device and looks ahead with determination, as if opening a new frontier for humanity. The background shows deep space and the blue Earth; the scene is both science-fictional and hopeful.
# English gloss of the negative prompt: garish colors, overexposed, static, blurry details, subtitles, style, artwork, painting, still frame, overall gray, worst quality, low quality, JPEG compression artifacts, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn face, deformed, disfigured, malformed limbs, fused fingers, motionless frame, cluttered background, three legs, many people in the background, walking backwards.
prompt="一名宇航员身穿太空服,面朝镜头骑着一匹机械马在火星表面驰骋。红色的荒凉地表延伸至远方,点缀着巨大的陨石坑和奇特的岩石结构。机械马的步伐稳健,扬起微弱的尘埃,展现出未来科技与原始探索的完美结合。宇航员手持操控装置,目光坚定,仿佛正在开辟人类的新疆域。背景是深邃的宇宙和蔚蓝的地球,画面既科幻又充满希望,让人不禁畅想未来的星际生活。",
negative_prompt="色调艳丽过曝静态细节模糊不清字幕风格作品画作画面静止整体发灰最差质量低质量JPEG压缩残留丑陋的残缺的多余的手指画得不好的手部画得不好的脸部畸形的毁容的形态畸形的肢体手指融合静止不动的画面杂乱的背景三条腿背景人很多倒着走",
num_inference_steps=50,
seed=0, tiled=True
)
if dist.get_rank() == 0:
    save_video(video, "video1.mp4", fps=25, quality=5)

@@ -1,42 +0,0 @@
import torch
from diffsynth import ModelManager, WanVideoPipeline, save_video, VideoData
from modelscope import snapshot_download, dataset_snapshot_download
from PIL import Image
# Download models
snapshot_download("PAI/Wan2.1-Fun-1.3B-InP", local_dir="models/PAI/Wan2.1-Fun-1.3B-InP")
# Load models
model_manager = ModelManager(device="cpu")
model_manager.load_models(
[
"models/PAI/Wan2.1-Fun-1.3B-InP/diffusion_pytorch_model.safetensors",
"models/PAI/Wan2.1-Fun-1.3B-InP/models_t5_umt5-xxl-enc-bf16.pth",
"models/PAI/Wan2.1-Fun-1.3B-InP/Wan2.1_VAE.pth",
"models/PAI/Wan2.1-Fun-1.3B-InP/models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth",
],
torch_dtype=torch.bfloat16, # You can set `torch_dtype=torch.float8_e4m3fn` to enable FP8 quantization.
)
pipe = WanVideoPipeline.from_model_manager(model_manager, torch_dtype=torch.bfloat16, device="cuda")
pipe.enable_vram_management(num_persistent_param_in_dit=None)
# Download example image
dataset_snapshot_download(
dataset_id="DiffSynth-Studio/examples_in_diffsynth",
local_dir="./",
allow_file_pattern="data/examples/wan/input_image.jpg"
)
image = Image.open("data/examples/wan/input_image.jpg")
# Image-to-video
video = pipe(
# English gloss of the prompt: a small boat bravely sails forward through wind and waves. The azure sea is rough and white spray slaps the hull, but the boat is undaunted and heads steadily into the distance. Sunlight glitters gold on the water, adding warmth to the magnificent scene. The camera zooms in on the flag fluttering in the wind, a symbol of unyielding spirit and adventurous courage. The footage is powerful and inspiring, showing fearlessness and persistence in the face of challenges.
# English gloss of the negative prompt: garish colors, overexposed, static, blurry details, subtitles, style, artwork, painting, still frame, overall gray, worst quality, low quality, JPEG compression artifacts, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn face, deformed, disfigured, malformed limbs, fused fingers, motionless frame, cluttered background, three legs, many people in the background, walking backwards.
prompt="一艘小船正勇敢地乘风破浪前行。蔚蓝的大海波涛汹涌,白色的浪花拍打着船身,但小船毫不畏惧,坚定地驶向远方。阳光洒在水面上,闪烁着金色的光芒,为这壮丽的场景增添了一抹温暖。镜头拉近,可以看到船上的旗帜迎风飘扬,象征着不屈的精神与冒险的勇气。这段画面充满力量,激励人心,展现了面对挑战时的无畏与执着。",
negative_prompt="色调艳丽过曝静态细节模糊不清字幕风格作品画作画面静止整体发灰最差质量低质量JPEG压缩残留丑陋的残缺的多余的手指画得不好的手部画得不好的脸部畸形的毁容的形态畸形的肢体手指融合静止不动的画面杂乱的背景三条腿背景人很多倒着走",
num_inference_steps=50,
input_image=image,
# You can input `end_image=xxx` to control the last frame of the video.
# The model will automatically generate the dynamic content between `input_image` and `end_image`.
seed=1, tiled=True
)
save_video(video, "video1.mp4", fps=15, quality=5)

@@ -1,40 +0,0 @@
import torch
from diffsynth import ModelManager, WanVideoPipeline, save_video, VideoData
from modelscope import snapshot_download, dataset_snapshot_download
from PIL import Image
# Download models
snapshot_download("PAI/Wan2.1-Fun-1.3B-Control", local_dir="models/PAI/Wan2.1-Fun-1.3B-Control")
# Load models
model_manager = ModelManager(device="cpu")
model_manager.load_models(
[
"models/PAI/Wan2.1-Fun-1.3B-Control/diffusion_pytorch_model.safetensors",
"models/PAI/Wan2.1-Fun-1.3B-Control/models_t5_umt5-xxl-enc-bf16.pth",
"models/PAI/Wan2.1-Fun-1.3B-Control/Wan2.1_VAE.pth",
"models/PAI/Wan2.1-Fun-1.3B-Control/models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth",
],
torch_dtype=torch.bfloat16, # You can set `torch_dtype=torch.float8_e4m3fn` to enable FP8 quantization.
)
pipe = WanVideoPipeline.from_model_manager(model_manager, torch_dtype=torch.bfloat16, device="cuda")
pipe.enable_vram_management(num_persistent_param_in_dit=None)
# Download example video
dataset_snapshot_download(
dataset_id="DiffSynth-Studio/examples_in_diffsynth",
local_dir="./",
allow_file_pattern="data/examples/wan/control_video.mp4"
)
# Control-to-video
control_video = VideoData("data/examples/wan/control_video.mp4", height=832, width=576)
video = pipe(
# English gloss of the prompt: flat-style anime; a long-haired girl dances gracefully. She has delicate features, large bright eyes, and smooth, glossy long black hair. She wears a light blue T-shirt and dark blue denim shorts. The background is pink.
# English gloss of the negative prompt: garish colors, overexposed, static, blurry details, subtitles, style, artwork, painting, still frame, overall gray, worst quality, low quality, JPEG compression artifacts, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn face, deformed, disfigured, malformed limbs, fused fingers, motionless frame, cluttered background, three legs, many people in the background, walking backwards.
prompt="扁平风格动漫一位长发少女优雅起舞。她五官精致大眼睛明亮有神黑色长发柔顺光泽。身穿淡蓝色T恤和深蓝色牛仔短裤。背景是粉色。",
negative_prompt="色调艳丽过曝静态细节模糊不清字幕风格作品画作画面静止整体发灰最差质量低质量JPEG压缩残留丑陋的残缺的多余的手指画得不好的手部画得不好的脸部畸形的毁容的形态畸形的肢体手指融合静止不动的画面杂乱的背景三条腿背景人很多倒着走",
num_inference_steps=50,
control_video=control_video, height=832, width=576, num_frames=49,
seed=1, tiled=True
)
save_video(video, "video1.mp4", fps=15, quality=5)

@@ -14,7 +14,7 @@ else:
 setup(
     name="diffsynth",
-    version="1.1.7",
+    version="1.1.2",
     description="Enjoy the magic of Diffusion models!",
     author="Artiprocher",
     packages=find_packages(),

@@ -1,241 +0,0 @@
from diffsynth import ModelManager, FluxImagePipeline
from diffsynth.trainers.text_to_image import LightningModelForT2ILoRA, add_general_parsers, launch_training_task
from diffsynth.models.lora import FluxLoRAConverter
import torch, os, argparse
import lightning as pl
from diffsynth.data.image_pulse import SingleTaskDataset, MultiTaskDataset
from diffsynth.pipelines.flux_image import lets_dance_flux
from diffsynth.models.flux_reference_embedder import FluxReferenceEmbedder
os.environ["TOKENIZERS_PARALLELISM"] = "True"
class LightningModel(LightningModelForT2ILoRA):
def __init__(
self,
torch_dtype=torch.float16, pretrained_weights=[], preset_lora_path=None,
learning_rate=1e-4, use_gradient_checkpointing=True,
lora_rank=4, lora_alpha=4, lora_target_modules="to_q,to_k,to_v,to_out", init_lora_weights="kaiming", pretrained_lora_path=None,
state_dict_converter=None, quantize = None
):
super().__init__(learning_rate=learning_rate, use_gradient_checkpointing=use_gradient_checkpointing, state_dict_converter=state_dict_converter)
# Load models
model_manager = ModelManager(torch_dtype=torch_dtype, device=self.device)
if quantize is None:
model_manager.load_models(pretrained_weights)
else:
model_manager.load_models(pretrained_weights[1:])
model_manager.load_model(pretrained_weights[0], torch_dtype=quantize)
if preset_lora_path is not None:
preset_lora_path = preset_lora_path.split(",")
for path in preset_lora_path:
model_manager.load_lora(path)
self.pipe = FluxImagePipeline.from_model_manager(model_manager)
self.pipe.reference_embedder = FluxReferenceEmbedder()
self.pipe.reference_embedder.init()
if quantize is not None:
self.pipe.dit.quantize()
self.pipe.scheduler.set_timesteps(1000, training=True)
self.freeze_parameters()
self.pipe.reference_embedder.requires_grad_(True)
self.pipe.reference_embedder.train()
self.pipe.dit.requires_grad_(True)
self.pipe.dit.train()
# self.add_lora_to_model(
# self.pipe.denoising_model(),
# lora_rank=lora_rank,
# lora_alpha=lora_alpha,
# lora_target_modules=lora_target_modules,
# init_lora_weights=init_lora_weights,
# pretrained_lora_path=pretrained_lora_path,
# state_dict_converter=FluxLoRAConverter.align_to_diffsynth_format
# )
def training_step(self, batch, batch_idx):
# Data
text, image = batch["instruction"], batch["image_2"]
image_ref = batch["image_1"]
# Prepare input parameters
self.pipe.device = self.device
prompt_emb = self.pipe.encode_prompt(text, positive=True)
if "latents" in batch:
latents = batch["latents"].to(dtype=self.pipe.torch_dtype, device=self.device)
else:
latents = self.pipe.vae_encoder(image.to(dtype=self.pipe.torch_dtype, device=self.device))
noise = torch.randn_like(latents)
timestep_id = torch.randint(0, self.pipe.scheduler.num_train_timesteps, (1,))
timestep = self.pipe.scheduler.timesteps[timestep_id].to(self.device)
extra_input = self.pipe.prepare_extra_input(latents)
noisy_latents = self.pipe.scheduler.add_noise(latents, noise, timestep)
training_target = self.pipe.scheduler.training_target(latents, noise, timestep)
# Reference image
hidden_states_ref = self.pipe.vae_encoder(image_ref.to(dtype=self.pipe.torch_dtype, device=self.device))
# Compute loss
noise_pred = lets_dance_flux(
self.pipe.denoising_model(),
reference_embedder=self.pipe.reference_embedder,
hidden_states_ref=hidden_states_ref,
hidden_states=noisy_latents, timestep=timestep, **prompt_emb, **extra_input,
use_gradient_checkpointing=self.use_gradient_checkpointing
)
loss = torch.nn.functional.mse_loss(noise_pred.float(), training_target.float())
loss = loss * self.pipe.scheduler.training_weight(timestep)
# Record log
self.log("train_loss", loss, prog_bar=True)
return loss
def configure_optimizers(self):
trainable_modules = filter(lambda p: p.requires_grad, self.pipe.parameters())
optimizer = torch.optim.AdamW(trainable_modules, lr=self.learning_rate)
return optimizer
def on_save_checkpoint(self, checkpoint):
checkpoint.clear()
trainable_param_names = list(filter(lambda named_param: named_param[1].requires_grad, self.pipe.named_parameters()))
trainable_param_names = set([named_param[0] for named_param in trainable_param_names])
state_dict = self.pipe.state_dict()
lora_state_dict = {}
for name, param in state_dict.items():
if name in trainable_param_names:
lora_state_dict[name] = param
if self.state_dict_converter is not None:
lora_state_dict = self.state_dict_converter(lora_state_dict, alpha=self.lora_alpha)
checkpoint.update(lora_state_dict)
def parse_args():
parser = argparse.ArgumentParser(description="Simple example of a training script.")
parser.add_argument(
"--pretrained_text_encoder_path",
type=str,
default=None,
required=True,
help="Path to pretrained text encoder model. For example, `models/FLUX/FLUX.1-dev/text_encoder/model.safetensors`.",
)
parser.add_argument(
"--pretrained_text_encoder_2_path",
type=str,
default=None,
required=True,
help="Path to pretrained t5 text encoder model. For example, `models/FLUX/FLUX.1-dev/text_encoder_2`.",
)
parser.add_argument(
"--pretrained_dit_path",
type=str,
default=None,
required=True,
help="Path to pretrained dit model. For example, `models/FLUX/FLUX.1-dev/flux1-dev.safetensors`.",
)
parser.add_argument(
"--pretrained_vae_path",
type=str,
default=None,
required=True,
help="Path to pretrained vae model. For example, `models/FLUX/FLUX.1-dev/ae.safetensors`.",
)
parser.add_argument(
"--lora_target_modules",
type=str,
default="a_to_qkv,b_to_qkv,ff_a.0,ff_a.2,ff_b.0,ff_b.2,a_to_out,b_to_out,proj_out,norm.linear,norm1_a.linear,norm1_b.linear,to_qkv_mlp",
help="Layers with LoRA modules.",
)
parser.add_argument(
"--align_to_opensource_format",
default=False,
action="store_true",
help="Whether to export LoRA files in the format used by other open-source tools.",
)
parser.add_argument(
"--quantize",
type=str,
default=None,
choices=["float8_e4m3fn"],
help="Whether to use quantization when training the model, and in which format.",
)
parser.add_argument(
"--preset_lora_path",
type=str,
default=None,
help="Preset LoRA path.",
)
parser = add_general_parsers(parser)
args = parser.parse_args()
return args
if __name__ == '__main__':
args = parse_args()
model = LightningModel(
torch_dtype={"32": torch.float32, "bf16": torch.bfloat16}.get(args.precision, torch.float16),
pretrained_weights=[args.pretrained_dit_path, args.pretrained_text_encoder_path, args.pretrained_text_encoder_2_path, args.pretrained_vae_path],
preset_lora_path=args.preset_lora_path,
learning_rate=args.learning_rate,
use_gradient_checkpointing=args.use_gradient_checkpointing,
lora_rank=args.lora_rank,
lora_alpha=args.lora_alpha,
lora_target_modules=args.lora_target_modules,
init_lora_weights=args.init_lora_weights,
pretrained_lora_path=args.pretrained_lora_path,
state_dict_converter=FluxLoRAConverter.align_to_opensource_format if args.align_to_opensource_format else None,
quantize={"float8_e4m3fn": torch.float8_e4m3fn}.get(args.quantize, None),
)
# dataset and data loader
dataset = MultiTaskDataset(
dataset_list=[
SingleTaskDataset(
"/shark/zhongjie/data/image_pulse_datasets/task1/data/dataset_change_add_remove",
keys=(("image_1", "image_2", "editing_instruction"), ("image_2", "image_1", "reverse_editing_instruction")),
metadata_path="/shark/zhongjie/data/image_pulse_datasets/task1/data/metadata/20250418_dataset_change_add_remove.json",
height=512, width=512,
),
SingleTaskDataset(
"/shark/zhongjie/data/image_pulse_datasets/task1/data/dataset_zoomin_zoomout",
keys=(("image_1", "image_2", "editing_instruction"), ("image_2", "image_1", "reverse_editing_instruction")),
metadata_path="/shark/zhongjie/data/image_pulse_datasets/task1/data/metadata/20250418_dataset_zoomin_zoomout.json",
height=512, width=512,
),
SingleTaskDataset(
"/shark/zhongjie/data/image_pulse_datasets/task1/data/dataset_style_transfer",
keys=(("image_1", "image_4", "editing_instruction"), ("image_4", "image_1", "reverse_editing_instruction")),
metadata_path="/shark/zhongjie/data/image_pulse_datasets/task1/data/metadata/20250418_dataset_style_transfer.json",
height=512, width=512,
),
SingleTaskDataset(
"/shark/zhongjie/data/image_pulse_datasets/task1/data/dataset_faceid",
keys=(("image_1", "image_2", "editing_instruction"), ("image_2", "image_1", "reverse_editing_instruction")),
metadata_path="/shark/zhongjie/data/image_pulse_datasets/task1/data/metadata/20250418_dataset_faceid.json",
height=512, width=512,
),
],
dataset_weight=(4, 1, 4, 1),
steps_per_epoch=args.steps_per_epoch,
)
train_loader = torch.utils.data.DataLoader(
dataset,
shuffle=True,
batch_size=args.batch_size,
num_workers=args.dataloader_num_workers
)
# train
trainer = pl.Trainer(
max_epochs=args.max_epochs,
accelerator="gpu",
devices="auto",
precision=args.precision,
strategy=args.training_strategy,
default_root_dir=args.output_path,
accumulate_grad_batches=args.accumulate_grad_batches,
callbacks=[pl.pytorch.callbacks.ModelCheckpoint(save_top_k=-1)],
logger=None,
)
trainer.fit(model=model, train_dataloaders=train_loader)
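
The `training_step` in this removed script calls `scheduler.add_noise` and `scheduler.training_target`, then regresses the model output onto that target with an MSE loss. The sketch below shows one common formulation of such a pair, the rectified-flow (flow-matching) convention, on plain Python lists. Whether the scheduler used here follows exactly these formulas is an assumption, and the function names merely mirror the method names for readability.

```python
def add_noise(x0, noise, sigma):
    """Linearly interpolate between the clean sample and the noise."""
    return [(1 - sigma) * a + sigma * n for a, n in zip(x0, noise)]


def training_target(x0, noise):
    """Velocity pointing from the clean sample toward the noise."""
    return [n - a for a, n in zip(x0, noise)]


def mse(pred, target):
    return sum((p - t) ** 2 for p, t in zip(pred, target)) / len(pred)


x0 = [0.5, -1.0, 2.0]        # stand-in for clean latents
noise = [1.5, 0.25, -0.75]   # stand-in for Gaussian noise
sigma = 0.25                 # noise level at the sampled timestep

x_t = add_noise(x0, noise, sigma)
target = training_target(x0, noise)

# A model that predicts the target exactly lets us step back to the clean sample.
recovered = [xt - sigma * v for xt, v in zip(x_t, target)]
perfect_loss = mse(target, target)
print(recovered, perfect_loss)
```

Under this convention the regression target does not depend on the timestep, which is why the script can apply a separate per-timestep `training_weight` to the loss afterwards.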

@@ -1,248 +0,0 @@
from diffsynth import ModelManager, FluxImagePipeline
from diffsynth.trainers.text_to_image import LightningModelForT2ILoRA, add_general_parsers, launch_training_task
from diffsynth.models.lora import FluxLoRAConverter
import torch, os, argparse
import lightning as pl
from diffsynth.data.image_pulse import SingleTaskDataset, MultiTaskDataset
from diffsynth.pipelines.flux_image import lets_dance_flux
from diffsynth.models.flux_reference_embedder import FluxReferenceEmbedder
os.environ["TOKENIZERS_PARALLELISM"] = "True"
class LightningModel(LightningModelForT2ILoRA):
def __init__(
self,
torch_dtype=torch.float16, pretrained_weights=[], preset_lora_path=None,
learning_rate=1e-4, use_gradient_checkpointing=True,
lora_rank=4, lora_alpha=4, lora_target_modules="to_q,to_k,to_v,to_out", init_lora_weights="kaiming", pretrained_lora_path=None,
state_dict_converter=None, quantize = None
):
super().__init__(learning_rate=learning_rate, use_gradient_checkpointing=use_gradient_checkpointing, state_dict_converter=state_dict_converter)
# Load models
model_manager = ModelManager(torch_dtype=torch_dtype, device=self.device)
if quantize is None:
model_manager.load_models(pretrained_weights)
else:
model_manager.load_models(pretrained_weights[1:])
model_manager.load_model(pretrained_weights[0], torch_dtype=quantize)
if preset_lora_path is not None:
preset_lora_path = preset_lora_path.split(",")
for path in preset_lora_path:
model_manager.load_lora(path)
self.pipe = FluxImagePipeline.from_model_manager(model_manager)
self.pipe.reference_embedder = FluxReferenceEmbedder()
self.pipe.reference_embedder.init()
if quantize is not None:
self.pipe.dit.quantize()
self.pipe.scheduler.set_timesteps(1000, training=True)
self.freeze_parameters()
self.pipe.reference_embedder.requires_grad_(True)
self.pipe.reference_embedder.train()
self.pipe.dit.requires_grad_(True)
self.pipe.dit.train()
# self.add_lora_to_model(
# self.pipe.denoising_model(),
# lora_rank=lora_rank,
# lora_alpha=lora_alpha,
# lora_target_modules=lora_target_modules,
# init_lora_weights=init_lora_weights,
# pretrained_lora_path=pretrained_lora_path,
# state_dict_converter=FluxLoRAConverter.align_to_diffsynth_format
# )
def training_step(self, batch, batch_idx):
# Data
text, image = batch["instruction"], batch["image_2"]
image_ref = batch["image_1"]
# Prepare input parameters
self.pipe.device = self.device
prompt_emb = self.pipe.encode_prompt(text, positive=True)
if "latents" in batch:
latents = batch["latents"].to(dtype=self.pipe.torch_dtype, device=self.device)
else:
latents = self.pipe.vae_encoder(image.to(dtype=self.pipe.torch_dtype, device=self.device))
noise = torch.randn_like(latents)
timestep_id = torch.randint(0, self.pipe.scheduler.num_train_timesteps, (1,))
timestep = self.pipe.scheduler.timesteps[timestep_id].to(self.device)
extra_input = self.pipe.prepare_extra_input(latents)
noisy_latents = self.pipe.scheduler.add_noise(latents, noise, timestep)
training_target = self.pipe.scheduler.training_target(latents, noise, timestep)
# Reference image
hidden_states_ref = self.pipe.vae_encoder(image_ref.to(dtype=self.pipe.torch_dtype, device=self.device))
# Compute loss
noise_pred = lets_dance_flux(
self.pipe.denoising_model(),
reference_embedder=self.pipe.reference_embedder,
hidden_states_ref=hidden_states_ref,
hidden_states=noisy_latents, timestep=timestep, **prompt_emb, **extra_input,
use_gradient_checkpointing=self.use_gradient_checkpointing
)
loss = torch.nn.functional.mse_loss(noise_pred.float(), training_target.float())
loss = loss * self.pipe.scheduler.training_weight(timestep)
# Record log
self.log("train_loss", loss, prog_bar=True)
return loss
def configure_optimizers(self):
trainable_modules = filter(lambda p: p.requires_grad, self.pipe.parameters())
optimizer = torch.optim.AdamW(trainable_modules, lr=self.learning_rate)
return optimizer
def on_save_checkpoint(self, checkpoint):
checkpoint.clear()
trainable_param_names = list(filter(lambda named_param: named_param[1].requires_grad, self.pipe.named_parameters()))
trainable_param_names = set([named_param[0] for named_param in trainable_param_names])
state_dict = self.pipe.state_dict()
lora_state_dict = {}
for name, param in state_dict.items():
if name in trainable_param_names:
lora_state_dict[name] = param
if self.state_dict_converter is not None:
lora_state_dict = self.state_dict_converter(lora_state_dict, alpha=self.lora_alpha)
checkpoint.update(lora_state_dict)
def parse_args():
parser = argparse.ArgumentParser(description="Simple example of a training script.")
parser.add_argument(
"--pretrained_text_encoder_path",
type=str,
default=None,
required=True,
help="Path to pretrained text encoder model. For example, `models/FLUX/FLUX.1-dev/text_encoder/model.safetensors`.",
)
parser.add_argument(
"--pretrained_text_encoder_2_path",
type=str,
default=None,
required=True,
help="Path to pretrained t5 text encoder model. For example, `models/FLUX/FLUX.1-dev/text_encoder_2`.",
)
parser.add_argument(
"--pretrained_dit_path",
type=str,
default=None,
required=True,
help="Path to pretrained dit model. For example, `models/FLUX/FLUX.1-dev/flux1-dev.safetensors`.",
)
parser.add_argument(
"--pretrained_vae_path",
type=str,
default=None,
required=True,
help="Path to pretrained vae model. For example, `models/FLUX/FLUX.1-dev/ae.safetensors`.",
)
parser.add_argument(
"--lora_target_modules",
type=str,
default="a_to_qkv,b_to_qkv,ff_a.0,ff_a.2,ff_b.0,ff_b.2,a_to_out,b_to_out,proj_out,norm.linear,norm1_a.linear,norm1_b.linear,to_qkv_mlp",
help="Layers with LoRA modules.",
)
parser.add_argument(
"--align_to_opensource_format",
default=False,
action="store_true",
help="Whether to export LoRA files in the format used by other open-source tools.",
)
parser.add_argument(
"--quantize",
type=str,
default=None,
choices=["float8_e4m3fn"],
help="Whether to use quantization when training the model, and in which format.",
)
parser.add_argument(
"--preset_lora_path",
type=str,
default=None,
help="Preset LoRA path.",
)
parser.add_argument(
"--num_nodes",
type=int,
default=1,
help="Number of nodes used for distributed training.",
)
parser = add_general_parsers(parser)
args = parser.parse_args()
return args
if __name__ == '__main__':
args = parse_args()
model = LightningModel(
torch_dtype={"32": torch.float32, "bf16": torch.bfloat16}.get(args.precision, torch.float16),
pretrained_weights=[args.pretrained_dit_path, args.pretrained_text_encoder_path, args.pretrained_text_encoder_2_path, args.pretrained_vae_path],
preset_lora_path=args.preset_lora_path,
learning_rate=args.learning_rate,
use_gradient_checkpointing=args.use_gradient_checkpointing,
lora_rank=args.lora_rank,
lora_alpha=args.lora_alpha,
lora_target_modules=args.lora_target_modules,
init_lora_weights=args.init_lora_weights,
pretrained_lora_path=args.pretrained_lora_path,
state_dict_converter=FluxLoRAConverter.align_to_opensource_format if args.align_to_opensource_format else None,
quantize={"float8_e4m3fn": torch.float8_e4m3fn}.get(args.quantize, None),
)
# dataset and data loader
dataset = MultiTaskDataset(
dataset_list=[
SingleTaskDataset(
"/shark/zhongjie/data/image_pulse_datasets/task1/data/dataset_change_add_remove",
keys=(("image_1", "image_2", "editing_instruction"), ("image_2", "image_1", "reverse_editing_instruction")),
metadata_path="/shark/zhongjie/data/image_pulse_datasets/task1/data/metadata/20250418_dataset_change_add_remove.json",
height=512, width=512,
),
SingleTaskDataset(
"/shark/zhongjie/data/image_pulse_datasets/task1/data/dataset_zoomin_zoomout",
keys=(("image_1", "image_2", "editing_instruction"), ("image_2", "image_1", "reverse_editing_instruction")),
metadata_path="/shark/zhongjie/data/image_pulse_datasets/task1/data/metadata/20250418_dataset_zoomin_zoomout.json",
height=512, width=512,
),
SingleTaskDataset(
"/shark/zhongjie/data/image_pulse_datasets/task1/data/dataset_style_transfer",
keys=(("image_1", "image_4", "editing_instruction"), ("image_4", "image_1", "reverse_editing_instruction")),
metadata_path="/shark/zhongjie/data/image_pulse_datasets/task1/data/metadata/20250418_dataset_style_transfer.json",
height=512, width=512,
),
SingleTaskDataset(
"/shark/zhongjie/data/image_pulse_datasets/task1/data/dataset_faceid",
keys=(("image_1", "image_2", "editing_instruction"), ("image_2", "image_1", "reverse_editing_instruction")),
metadata_path="/shark/zhongjie/data/image_pulse_datasets/task1/data/metadata/20250418_dataset_faceid.json",
height=512, width=512,
),
],
dataset_weight=(4, 1, 4, 1),
steps_per_epoch=args.steps_per_epoch,
)
train_loader = torch.utils.data.DataLoader(
dataset,
shuffle=True,
batch_size=args.batch_size,
num_workers=args.dataloader_num_workers
)
# train
trainer = pl.Trainer(
max_epochs=args.max_epochs,
accelerator="gpu",
devices="auto",
num_nodes=args.num_nodes,
precision=args.precision,
strategy="ddp",
default_root_dir=args.output_path,
accumulate_grad_batches=args.accumulate_grad_batches,
callbacks=[pl.pytorch.callbacks.ModelCheckpoint(save_top_k=-1)],
logger=None,
)
trainer.fit(model=model, train_dataloaders=train_loader)
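
Both removed training scripts combine four `SingleTaskDataset` instances with `dataset_weight=(4, 1, 4, 1)` and a fixed `steps_per_epoch`. The implementation of `MultiTaskDataset` is not part of this diff, so the sketch below is only a guess at the general idea: draw a task in proportion to its weight, then draw a sample from that task. The class name `WeightedMixture`, its constructor arguments, and the toy data are hypothetical.

```python
import random


class WeightedMixture:
    """Hypothetical stand-in for a weighted multi-task dataset wrapper."""

    def __init__(self, datasets, weights, steps_per_epoch, seed=0):
        assert len(datasets) == len(weights)
        self.datasets = datasets
        self.weights = weights
        self.steps_per_epoch = steps_per_epoch
        self.rng = random.Random(seed)

    def __len__(self):
        # An "epoch" is a fixed number of draws, not one pass over the data.
        return self.steps_per_epoch

    def __getitem__(self, index):
        # The index is ignored: pick a task by weight, then a random sample from it.
        dataset = self.rng.choices(self.datasets, weights=self.weights, k=1)[0]
        return dataset[self.rng.randrange(len(dataset))]


tasks = [["edit-0", "edit-1"], ["zoom-0"], ["style-0", "style-1"], ["face-0"]]
mixture = WeightedMixture(tasks, weights=(4, 1, 4, 1), steps_per_epoch=500)
samples = [mixture[i] for i in range(len(mixture))]
print(len(samples))
```

With weights of this shape, the first and third tasks would each be drawn about four times as often as the second and fourth, regardless of how many items each underlying dataset contains.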