Compare commits

..

48 Commits

Author SHA1 Message Date
Zhongjie Duan
2b72ae0e56 Merge pull request #568 from mi804/nexus-gen
update nexus-gen readme
2025-05-13 14:11:25 +08:00
mi804
90588dcf97 update nexus-gen readme 2025-05-13 11:42:01 +08:00
xuyixuan.xyx
91fbb24e17 refine training 2025-05-12 14:19:00 +08:00
xuyixuan.xyx
f17558a4c4 train 2025-05-07 11:22:13 +08:00
Artiprocher
290ec469ca train 2025-05-06 17:54:32 +08:00
Artiprocher
1ed676b076 train 2025-05-06 17:53:56 +08:00
Artiprocher
f7737aff98 nexus-gen 2025-04-30 17:09:15 +08:00
Zhongjie Duan
ef2a7abad4 Step1x vram (#556)
* support step1x vram management
2025-04-28 10:13:20 +08:00
Zhongjie Duan
32f630ff5f Merge pull request #555 from modelscope/step1x
support step1x
2025-04-27 20:40:43 +08:00
Artiprocher
109a0a0d49 support step1x 2025-04-27 20:37:43 +08:00
Zhongjie Duan
4f01b37a2a Merge pull request #553 from modelscope/flex
Flex
2025-04-25 12:24:18 +08:00
Artiprocher
cc6306136c flex full support 2025-04-25 12:23:29 +08:00
Artiprocher
419ace37f3 flex full support 2025-04-25 11:32:13 +08:00
Artiprocher
ccf24c363f flex control 2025-04-24 19:18:54 +08:00
Artiprocher
b7a1ac6671 flex t2i 2025-04-24 14:51:40 +08:00
Zhongjie Duan
e54c0a8468 Merge pull request #548 from CD22104/main
liblib-controlnet
2025-04-22 14:54:16 +08:00
xuyixuan.xyx
5f4cb32255 liblib-controlnet 2025-04-22 13:45:49 +08:00
Zhongjie Duan
7b6cf39618 Merge pull request #544 from modelscope/Artiprocher-patch-1
Update train_wan_t2v.py
2025-04-17 15:39:44 +08:00
Zhongjie Duan
bf81de0c88 Update train_wan_t2v.py 2025-04-17 15:37:30 +08:00
Zhongjie Duan
b36cad6929 Merge pull request #543 from modelscope/wan-flf2v
bugfix
2025-04-17 15:24:36 +08:00
Zhongjie Duan
b161bd6dfd bugfix 2025-04-17 15:23:46 +08:00
Zhongjie Duan
538cfcbb77 Merge pull request #541 from modelscope/wan-flf2v
Wan flf2v
2025-04-17 14:51:08 +08:00
Artiprocher
a4105d2c0e support wan-flf2v 2025-04-17 14:48:55 +08:00
Artiprocher
553b341f5f support wan-flf2v 2025-04-17 14:47:55 +08:00
Zhongjie Duan
e9e24b8cf1 Merge pull request #537 from CD22104/main
issue523
2025-04-16 15:53:39 +08:00
CD22104
1b693d0028 issue523 2025-04-16 15:49:52 +08:00
Zhongjie Duan
a4c3c07229 Merge pull request #536 from modelscope/wan-vace-quant
support vace quant
2025-04-16 10:43:14 +08:00
Artiprocher
6b24748c80 support vace quant 2025-04-16 10:29:21 +08:00
Zhongjie Duan
8f2f8646eb Merge pull request #526 from mohui37/main
Update train_wan_t2v.py
2025-04-16 09:55:19 +08:00
Zhongjie Duan
e3ac438f5a Merge pull request #533 from modelscope/wan-vace
vace
2025-04-15 18:47:36 +08:00
Artiprocher
b731628112 vace 2025-04-15 17:52:25 +08:00
mohui37
0dc56d9dcc Update train_wan_t2v.py
在应用itv的管道处理数据时有bug,提交修复
2025-04-11 17:05:40 +08:00
Zhongjie Duan
b925b402e2 Merge pull request #522 from modelscope/Artiprocher-patch-1
Update README.md
2025-04-10 11:42:32 +08:00
Zhongjie Duan
61d9653536 Update README.md 2025-04-10 11:42:18 +08:00
Zhongjie Duan
53f01e72e6 Update setup.py 2025-04-09 15:38:17 +08:00
Zhongjie Duan
55e5e373dd Update publish.yaml 2025-04-09 15:37:46 +08:00
Zhongjie Duan
4a0921ada1 Update requirements.txt 2025-04-09 15:37:16 +08:00
Zhongjie Duan
5129d3dc52 Update setup.py 2025-04-09 15:34:02 +08:00
Zhongjie Duan
ee9bab80f2 Update requirements.txt 2025-04-09 15:33:21 +08:00
Zhongjie Duan
cd8884c9ef Update setup.py 2025-04-09 15:27:36 +08:00
Zhongjie Duan
46744362de Update requirements.txt 2025-04-09 15:26:13 +08:00
Zhongjie Duan
0f0cdc3afc Update setup.py 2025-04-09 15:15:18 +08:00
Zhongjie Duan
a33c63af87 Merge pull request #518 from modelscope/wan-fun
Wan fun
2025-04-08 19:25:12 +08:00
Artiprocher
3cc9764bc9 support more wan models 2025-04-08 19:22:53 +08:00
Artiprocher
f6c6e3c640 support more wan models 2025-04-08 17:19:54 +08:00
Artiprocher
60a9db706e support more wan models 2025-04-08 17:07:10 +08:00
lzw478614@alibaba-inc.com
a98700feb2 support wan-fun-inp generating 2025-04-06 22:55:42 +08:00
lzw478614@alibaba-inc.com
5418ca781e support load wan2.1-fun-inp-1.3B and 14B model 2025-04-03 16:37:59 +08:00
33 changed files with 5796 additions and 191 deletions

View File

@@ -20,7 +20,7 @@ jobs:
with:
python-version: '3.10'
- name: Install wheel
run: pip install wheel && pip install -r requirements.txt
run: pip install wheel==0.44.0 && pip install -r requirements.txt
- name: Build DiffSynth
run: python setup.py sdist bdist_wheel
- name: Publish package to PyPI

View File

@@ -42,6 +42,12 @@ Until now, DiffSynth-Studio has supported the following models:
* [Stable Diffusion](https://huggingface.co/runwayml/stable-diffusion-v1-5)
## News
- **May 1, 2025** 🔥🔥🔥 We propose Nexus-Gen, a unified model that synergizes the language reasoning capabilities of LLMs with the image synthesis power of diffusion models.
- Paper: [Nexus-Gen: A Unified Model for Image Understanding, Generation, and Editing](https://arxiv.org/pdf/2504.21356)
- Github Repo: https://github.com/modelscope/Nexus-Gen
- Model: [ModelScope](https://www.modelscope.cn/models/DiffSynth-Studio/Nexus-Gen), [HuggingFace](https://huggingface.co/modelscope/Nexus-Gen)
- Online Demo: [ModelScope Nexus-Gen Studio](https://www.modelscope.cn/studios/DiffSynth-Studio/Nexus-Gen)
- **March 31, 2025** We support InfiniteYou, an identity preserving method for FLUX. Please refer to [./examples/InfiniteYou/](./examples/InfiniteYou/) for more details.
- **March 25, 2025** 🔥🔥🔥 Our new open-source project, [DiffSynth-Engine](https://github.com/modelscope/DiffSynth-Engine), is now open-sourced! Focused on stable model deployment. Geared towards industry. Offers better engineering support, higher computational performance, and more stable functionality.

View File

@@ -59,6 +59,10 @@ from ..models.wan_video_dit import WanModel
from ..models.wan_video_text_encoder import WanTextEncoder
from ..models.wan_video_image_encoder import WanImageEncoder
from ..models.wan_video_vae import WanVideoVAE
from ..models.wan_video_motion_controller import WanMotionControllerModel
from ..models.wan_video_vace import VaceWanModel
from ..models.step1x_connector import Qwen2Connector
model_loader_configs = [
@@ -96,6 +100,7 @@ model_loader_configs = [
(None, "57b02550baab820169365b3ee3afa2c9", ["flux_dit"], [FluxDiT], "civitai"),
(None, "3394f306c4cbf04334b712bf5aaed95f", ["flux_dit"], [FluxDiT], "civitai"),
(None, "023f054d918a84ccf503481fd1e3379e", ["flux_dit"], [FluxDiT], "civitai"),
(None, "d02f41c13549fa5093d3521f62a5570a", ["flux_dit"], [FluxDiT], "civitai"),
(None, "605c56eab23e9e2af863ad8f0813a25d", ["flux_dit"], [FluxDiT], "diffusers"),
(None, "280189ee084bca10f70907bf6ce1649d", ["cog_vae_encoder", "cog_vae_decoder"], [CogVAEEncoder, CogVAEDecoder], "diffusers"),
(None, "9b9313d104ac4df27991352fec013fd4", ["rife"], [IFNet], "civitai"),
@@ -106,6 +111,7 @@ model_loader_configs = [
(None, "52357cb26250681367488a8954c271e8", ["flux_controlnet"], [FluxControlNet], "diffusers"),
(None, "0cfd1740758423a2a854d67c136d1e8c", ["flux_controlnet"], [FluxControlNet], "diffusers"),
(None, "7f9583eb8ba86642abb9a21a4b2c9e16", ["flux_controlnet"], [FluxControlNet], "diffusers"),
(None, "43ad5aaa27dd4ee01b832ed16773fa52", ["flux_controlnet"], [FluxControlNet], "diffusers"),
(None, "c07c0f04f5ff55e86b4e937c7a40d481", ["infiniteyou_image_projector"], [InfiniteYouImageProjector], "diffusers"),
(None, "4daaa66cc656a8fe369908693dad0a35", ["flux_ipadapter"], [FluxIpAdapter], "diffusers"),
(None, "51aed3d27d482fceb5e0739b03060e8f", ["sd3_dit", "sd3_vae_encoder", "sd3_vae_decoder"], [SD3DiT, SD3VAEEncoder, SD3VAEDecoder], "civitai"),
@@ -120,11 +126,19 @@ model_loader_configs = [
(None, "9269f8db9040a9d860eaca435be61814", ["wan_video_dit"], [WanModel], "civitai"),
(None, "aafcfd9672c3a2456dc46e1cb6e52c70", ["wan_video_dit"], [WanModel], "civitai"),
(None, "6bfcfb3b342cb286ce886889d519a77e", ["wan_video_dit"], [WanModel], "civitai"),
(None, "6d6ccde6845b95ad9114ab993d917893", ["wan_video_dit"], [WanModel], "civitai"),
(None, "6bfcfb3b342cb286ce886889d519a77e", ["wan_video_dit"], [WanModel], "civitai"),
(None, "349723183fc063b2bfc10bb2835cf677", ["wan_video_dit"], [WanModel], "civitai"),
(None, "efa44cddf936c70abd0ea28b6cbe946c", ["wan_video_dit"], [WanModel], "civitai"),
(None, "3ef3b1f8e1dab83d5b71fd7b617f859f", ["wan_video_dit"], [WanModel], "civitai"),
(None, "a61453409b67cd3246cf0c3bebad47ba", ["wan_video_dit", "wan_video_vace"], [WanModel, VaceWanModel], "civitai"),
(None, "cb104773c6c2cb6df4f9529ad5c60d0b", ["wan_video_dit"], [WanModel], "diffusers"),
(None, "9c8818c2cbea55eca56c7b447df170da", ["wan_video_text_encoder"], [WanTextEncoder], "civitai"),
(None, "5941c53e207d62f20f9025686193c40b", ["wan_video_image_encoder"], [WanImageEncoder], "civitai"),
(None, "1378ea763357eea97acdef78e65d6d96", ["wan_video_vae"], [WanVideoVAE], "civitai"),
(None, "ccc42284ea13e1ad04693284c7a09be6", ["wan_video_vae"], [WanVideoVAE], "civitai"),
(None, "dbd5ec76bbf977983f972c151d545389", ["wan_video_motion_controller"], [WanMotionControllerModel], "civitai"),
(None, "d30fb9e02b1dbf4e509142f05cf7dd50", ["flux_dit", "step1x_connector"], [FluxDiT, Qwen2Connector], "civitai"),
]
huggingface_model_loader_configs = [
# These configs are provided for detecting model type automatically.
@@ -140,6 +154,7 @@ huggingface_model_loader_configs = [
("LlamaForCausalLM", "diffsynth.models.hunyuan_video_text_encoder", "hunyuan_video_text_encoder_2", "HunyuanVideoLLMEncoder"),
("LlavaForConditionalGeneration", "diffsynth.models.hunyuan_video_text_encoder", "hunyuan_video_text_encoder_2", "HunyuanVideoMLLMEncoder"),
("Step1Model", "diffsynth.models.stepvideo_text_encoder", "stepvideo_text_encoder_2", "STEP1TextEncoder"),
("Qwen2_5_VLForConditionalGeneration", "diffsynth.models.qwenvl", "qwenvl", "Qwen25VL_7b_Embedder"),
]
patch_model_loader_configs = [
# These configs are provided for detecting model type automatically.

View File

@@ -320,6 +320,8 @@ class FluxControlNetStateDictConverter:
extra_kwargs = {"num_joint_blocks": 4, "num_single_blocks": 1}
elif hash_value == "7f9583eb8ba86642abb9a21a4b2c9e16":
extra_kwargs = {"num_joint_blocks": 4, "num_single_blocks": 10}
elif hash_value == "43ad5aaa27dd4ee01b832ed16773fa52":
extra_kwargs = {"num_joint_blocks": 6, "num_single_blocks": 0}
else:
extra_kwargs = {}
return state_dict_, extra_kwargs

View File

@@ -276,20 +276,22 @@ class AdaLayerNormContinuous(torch.nn.Module):
class FluxDiT(torch.nn.Module):
def __init__(self, disable_guidance_embedder=False):
def __init__(self, disable_guidance_embedder=False, input_dim=64, num_blocks=19):
super().__init__()
self.pos_embedder = RoPEEmbedding(3072, 10000, [16, 56, 56])
self.time_embedder = TimestepEmbeddings(256, 3072)
self.guidance_embedder = None if disable_guidance_embedder else TimestepEmbeddings(256, 3072)
self.pooled_text_embedder = torch.nn.Sequential(torch.nn.Linear(768, 3072), torch.nn.SiLU(), torch.nn.Linear(3072, 3072))
self.context_embedder = torch.nn.Linear(4096, 3072)
self.x_embedder = torch.nn.Linear(64, 3072)
self.x_embedder = torch.nn.Linear(input_dim, 3072)
self.blocks = torch.nn.ModuleList([FluxJointTransformerBlock(3072, 24) for _ in range(19)])
self.blocks = torch.nn.ModuleList([FluxJointTransformerBlock(3072, 24) for _ in range(num_blocks)])
self.single_blocks = torch.nn.ModuleList([FluxSingleTransformerBlock(3072, 24) for _ in range(38)])
self.final_norm_out = AdaLayerNormContinuous(3072)
self.final_proj_out = torch.nn.Linear(3072, 64)
self.input_dim = input_dim
def patchify(self, hidden_states):
@@ -738,5 +740,7 @@ class FluxDiTStateDictConverter:
pass
if "guidance_embedder.timestep_embedder.0.weight" not in state_dict_:
return state_dict_, {"disable_guidance_embedder": True}
elif "blocks.8.attn.norm_k_a.weight" not in state_dict_:
return state_dict_, {"input_dim": 196, "num_blocks": 8}
else:
return state_dict_

168
diffsynth/models/qwenvl.py Normal file
View File

@@ -0,0 +1,168 @@
import torch
class Qwen25VL_7b_Embedder(torch.nn.Module):
def __init__(self, model_path, max_length=640, dtype=torch.bfloat16, device="cuda"):
super(Qwen25VL_7b_Embedder, self).__init__()
self.max_length = max_length
self.dtype = dtype
self.device = device
from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration
self.model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
model_path,
torch_dtype=dtype,
).to(torch.cuda.current_device())
self.model.requires_grad_(False)
self.processor = AutoProcessor.from_pretrained(
model_path, min_pixels=256 * 28 * 28, max_pixels=324 * 28 * 28
)
Qwen25VL_7b_PREFIX = '''Given a user prompt, generate an "Enhanced prompt" that provides detailed visual descriptions suitable for image generation. Evaluate the level of detail in the user prompt:
- If the prompt is simple, focus on adding specifics about colors, shapes, sizes, textures, and spatial relationships to create vivid and concrete scenes.
- If the prompt is already detailed, refine and enhance the existing details slightly without overcomplicating.\n
Here are examples of how to transform or refine prompts:
- User Prompt: A cat sleeping -> Enhanced: A small, fluffy white cat curled up in a round shape, sleeping peacefully on a warm sunny windowsill, surrounded by pots of blooming red flowers.
- User Prompt: A busy city street -> Enhanced: A bustling city street scene at dusk, featuring glowing street lamps, a diverse crowd of people in colorful clothing, and a double-decker bus passing by towering glass skyscrapers.\n
Please generate only the enhanced description for the prompt below and avoid including any additional commentary or evaluations:
User Prompt:'''
self.prefix = Qwen25VL_7b_PREFIX
@staticmethod
def from_pretrained(path, torch_dtype=torch.bfloat16, device="cuda"):
return Qwen25VL_7b_Embedder(path, dtype=torch_dtype, device=device)
def forward(self, caption, ref_images):
text_list = caption
embs = torch.zeros(
len(text_list),
self.max_length,
self.model.config.hidden_size,
dtype=torch.bfloat16,
device=torch.cuda.current_device(),
)
hidden_states = torch.zeros(
len(text_list),
self.max_length,
self.model.config.hidden_size,
dtype=torch.bfloat16,
device=torch.cuda.current_device(),
)
masks = torch.zeros(
len(text_list),
self.max_length,
dtype=torch.long,
device=torch.cuda.current_device(),
)
input_ids_list = []
attention_mask_list = []
emb_list = []
def split_string(s):
s = s.replace("", '"').replace("", '"').replace("'", '''"''') # use english quotes
result = []
in_quotes = False
temp = ""
for idx,char in enumerate(s):
if char == '"' and idx>155:
temp += char
if not in_quotes:
result.append(temp)
temp = ""
in_quotes = not in_quotes
continue
if in_quotes:
if char.isspace():
pass # have space token
result.append("" + char + "")
else:
temp += char
if temp:
result.append(temp)
return result
for idx, (txt, imgs) in enumerate(zip(text_list, ref_images)):
messages = [{"role": "user", "content": []}]
messages[0]["content"].append({"type": "text", "text": f"{self.prefix}"})
messages[0]["content"].append({"type": "image", "image": imgs})
# 再添加 text
messages[0]["content"].append({"type": "text", "text": f"{txt}"})
# Preparation for inference
text = self.processor.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True, add_vision_id=True
)
image_inputs = [imgs]
inputs = self.processor(
text=[text],
images=image_inputs,
padding=True,
return_tensors="pt",
)
old_inputs_ids = inputs.input_ids
text_split_list = split_string(text)
token_list = []
for text_each in text_split_list:
txt_inputs = self.processor(
text=text_each,
images=None,
videos=None,
padding=True,
return_tensors="pt",
)
token_each = txt_inputs.input_ids
if token_each[0][0] == 2073 and token_each[0][-1] == 854:
token_each = token_each[:, 1:-1]
token_list.append(token_each)
else:
token_list.append(token_each)
new_txt_ids = torch.cat(token_list, dim=1).to("cuda")
new_txt_ids = new_txt_ids.to(old_inputs_ids.device)
idx1 = (old_inputs_ids == 151653).nonzero(as_tuple=True)[1][0]
idx2 = (new_txt_ids == 151653).nonzero(as_tuple=True)[1][0]
inputs.input_ids = (
torch.cat([old_inputs_ids[0, :idx1], new_txt_ids[0, idx2:]], dim=0)
.unsqueeze(0)
.to("cuda")
)
inputs.attention_mask = (inputs.input_ids > 0).long().to("cuda")
outputs = self.model(
input_ids=inputs.input_ids,
attention_mask=inputs.attention_mask,
pixel_values=inputs.pixel_values.to("cuda"),
image_grid_thw=inputs.image_grid_thw.to("cuda"),
output_hidden_states=True,
)
emb = outputs["hidden_states"][-1]
embs[idx, : min(self.max_length, emb.shape[1] - 217)] = emb[0, 217:][
: self.max_length
]
masks[idx, : min(self.max_length, emb.shape[1] - 217)] = torch.ones(
(min(self.max_length, emb.shape[1] - 217)),
dtype=torch.long,
device=torch.cuda.current_device(),
)
return embs, masks

View File

@@ -0,0 +1,683 @@
from typing import Optional
import torch, math
import torch.nn
from einops import rearrange
from torch import nn
from functools import partial
from einops import rearrange
def attention(q, k, v, attn_mask, mode="torch"):
q = q.transpose(1, 2)
k = k.transpose(1, 2)
v = v.transpose(1, 2)
x = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
x = rearrange(x, "b n s d -> b s (n d)")
return x
class MLP(nn.Module):
"""MLP as used in Vision Transformer, MLP-Mixer and related networks"""
def __init__(
self,
in_channels,
hidden_channels=None,
out_features=None,
act_layer=nn.GELU,
norm_layer=None,
bias=True,
drop=0.0,
use_conv=False,
device=None,
dtype=None,
):
super().__init__()
out_features = out_features or in_channels
hidden_channels = hidden_channels or in_channels
bias = (bias, bias)
drop_probs = (drop, drop)
linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear
self.fc1 = linear_layer(
in_channels, hidden_channels, bias=bias[0], device=device, dtype=dtype
)
self.act = act_layer()
self.drop1 = nn.Dropout(drop_probs[0])
self.norm = (
norm_layer(hidden_channels, device=device, dtype=dtype)
if norm_layer is not None
else nn.Identity()
)
self.fc2 = linear_layer(
hidden_channels, out_features, bias=bias[1], device=device, dtype=dtype
)
self.drop2 = nn.Dropout(drop_probs[1])
def forward(self, x):
x = self.fc1(x)
x = self.act(x)
x = self.drop1(x)
x = self.norm(x)
x = self.fc2(x)
x = self.drop2(x)
return x
class TextProjection(nn.Module):
"""
Projects text embeddings. Also handles dropout for classifier-free guidance.
Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/nets/PixArt_blocks.py
"""
def __init__(self, in_channels, hidden_size, act_layer, dtype=None, device=None):
factory_kwargs = {"dtype": dtype, "device": device}
super().__init__()
self.linear_1 = nn.Linear(
in_features=in_channels,
out_features=hidden_size,
bias=True,
**factory_kwargs,
)
self.act_1 = act_layer()
self.linear_2 = nn.Linear(
in_features=hidden_size,
out_features=hidden_size,
bias=True,
**factory_kwargs,
)
def forward(self, caption):
hidden_states = self.linear_1(caption)
hidden_states = self.act_1(hidden_states)
hidden_states = self.linear_2(hidden_states)
return hidden_states
class TimestepEmbedder(nn.Module):
"""
Embeds scalar timesteps into vector representations.
"""
def __init__(
self,
hidden_size,
act_layer,
frequency_embedding_size=256,
max_period=10000,
out_size=None,
dtype=None,
device=None,
):
factory_kwargs = {"dtype": dtype, "device": device}
super().__init__()
self.frequency_embedding_size = frequency_embedding_size
self.max_period = max_period
if out_size is None:
out_size = hidden_size
self.mlp = nn.Sequential(
nn.Linear(
frequency_embedding_size, hidden_size, bias=True, **factory_kwargs
),
act_layer(),
nn.Linear(hidden_size, out_size, bias=True, **factory_kwargs),
)
nn.init.normal_(self.mlp[0].weight, std=0.02) # type: ignore
nn.init.normal_(self.mlp[2].weight, std=0.02) # type: ignore
@staticmethod
def timestep_embedding(t, dim, max_period=10000):
"""
Create sinusoidal timestep embeddings.
Args:
t (torch.Tensor): a 1-D Tensor of N indices, one per batch element. These may be fractional.
dim (int): the dimension of the output.
max_period (int): controls the minimum frequency of the embeddings.
Returns:
embedding (torch.Tensor): An (N, D) Tensor of positional embeddings.
.. ref_link: https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
"""
half = dim // 2
freqs = torch.exp(
-math.log(max_period)
* torch.arange(start=0, end=half, dtype=torch.float32)
/ half
).to(device=t.device)
args = t[:, None].float() * freqs[None]
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
if dim % 2:
embedding = torch.cat(
[embedding, torch.zeros_like(embedding[:, :1])], dim=-1
)
return embedding
def forward(self, t):
t_freq = self.timestep_embedding(
t, self.frequency_embedding_size, self.max_period
).type(self.mlp[0].weight.dtype) # type: ignore
t_emb = self.mlp(t_freq)
return t_emb
def apply_gate(x, gate=None, tanh=False):
"""AI is creating summary for apply_gate
Args:
x (torch.Tensor): input tensor.
gate (torch.Tensor, optional): gate tensor. Defaults to None.
tanh (bool, optional): whether to use tanh function. Defaults to False.
Returns:
torch.Tensor: the output tensor after apply gate.
"""
if gate is None:
return x
if tanh:
return x * gate.unsqueeze(1).tanh()
else:
return x * gate.unsqueeze(1)
class RMSNorm(nn.Module):
def __init__(
self,
dim: int,
elementwise_affine=True,
eps: float = 1e-6,
device=None,
dtype=None,
):
"""
Initialize the RMSNorm normalization layer.
Args:
dim (int): The dimension of the input tensor.
eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.
Attributes:
eps (float): A small value added to the denominator for numerical stability.
weight (nn.Parameter): Learnable scaling parameter.
"""
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()
self.eps = eps
if elementwise_affine:
self.weight = nn.Parameter(torch.ones(dim, **factory_kwargs))
def _norm(self, x):
"""
Apply the RMSNorm normalization to the input tensor.
Args:
x (torch.Tensor): The input tensor.
Returns:
torch.Tensor: The normalized tensor.
"""
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
def forward(self, x):
"""
Forward pass through the RMSNorm layer.
Args:
x (torch.Tensor): The input tensor.
Returns:
torch.Tensor: The output tensor after applying RMSNorm.
"""
output = self._norm(x.float()).type_as(x)
if hasattr(self, "weight"):
output = output * self.weight
return output
def get_norm_layer(norm_layer):
"""
Get the normalization layer.
Args:
norm_layer (str): The type of normalization layer.
Returns:
norm_layer (nn.Module): The normalization layer.
"""
if norm_layer == "layer":
return nn.LayerNorm
elif norm_layer == "rms":
return RMSNorm
else:
raise NotImplementedError(f"Norm layer {norm_layer} is not implemented")
def get_activation_layer(act_type):
"""get activation layer
Args:
act_type (str): the activation type
Returns:
torch.nn.functional: the activation layer
"""
if act_type == "gelu":
return lambda: nn.GELU()
elif act_type == "gelu_tanh":
return lambda: nn.GELU(approximate="tanh")
elif act_type == "relu":
return nn.ReLU
elif act_type == "silu":
return nn.SiLU
else:
raise ValueError(f"Unknown activation type: {act_type}")
class IndividualTokenRefinerBlock(torch.nn.Module):
def __init__(
self,
hidden_size,
heads_num,
mlp_width_ratio: str = 4.0,
mlp_drop_rate: float = 0.0,
act_type: str = "silu",
qk_norm: bool = False,
qk_norm_type: str = "layer",
qkv_bias: bool = True,
need_CA: bool = False,
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
):
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()
self.need_CA = need_CA
self.heads_num = heads_num
head_dim = hidden_size // heads_num
mlp_hidden_dim = int(hidden_size * mlp_width_ratio)
self.norm1 = nn.LayerNorm(
hidden_size, elementwise_affine=True, eps=1e-6, **factory_kwargs
)
self.self_attn_qkv = nn.Linear(
hidden_size, hidden_size * 3, bias=qkv_bias, **factory_kwargs
)
qk_norm_layer = get_norm_layer(qk_norm_type)
self.self_attn_q_norm = (
qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)
if qk_norm
else nn.Identity()
)
self.self_attn_k_norm = (
qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)
if qk_norm
else nn.Identity()
)
self.self_attn_proj = nn.Linear(
hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs
)
self.norm2 = nn.LayerNorm(
hidden_size, elementwise_affine=True, eps=1e-6, **factory_kwargs
)
act_layer = get_activation_layer(act_type)
self.mlp = MLP(
in_channels=hidden_size,
hidden_channels=mlp_hidden_dim,
act_layer=act_layer,
drop=mlp_drop_rate,
**factory_kwargs,
)
self.adaLN_modulation = nn.Sequential(
act_layer(),
nn.Linear(hidden_size, 2 * hidden_size, bias=True, **factory_kwargs),
)
if self.need_CA:
self.cross_attnblock=CrossAttnBlock(hidden_size=hidden_size,
heads_num=heads_num,
mlp_width_ratio=mlp_width_ratio,
mlp_drop_rate=mlp_drop_rate,
act_type=act_type,
qk_norm=qk_norm,
qk_norm_type=qk_norm_type,
qkv_bias=qkv_bias,
**factory_kwargs,)
# Zero-initialize the modulation
nn.init.zeros_(self.adaLN_modulation[1].weight)
nn.init.zeros_(self.adaLN_modulation[1].bias)
def forward(
self,
x: torch.Tensor,
c: torch.Tensor, # timestep_aware_representations + context_aware_representations
attn_mask: torch.Tensor = None,
y: torch.Tensor = None,
):
gate_msa, gate_mlp = self.adaLN_modulation(c).chunk(2, dim=1)
norm_x = self.norm1(x)
qkv = self.self_attn_qkv(norm_x)
q, k, v = rearrange(qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num)
# Apply QK-Norm if needed
q = self.self_attn_q_norm(q).to(v)
k = self.self_attn_k_norm(k).to(v)
# Self-Attention
attn = attention(q, k, v, mode="torch", attn_mask=attn_mask)
x = x + apply_gate(self.self_attn_proj(attn), gate_msa)
if self.need_CA:
x = self.cross_attnblock(x, c, attn_mask, y)
# FFN Layer
x = x + apply_gate(self.mlp(self.norm2(x)), gate_mlp)
return x
class CrossAttnBlock(torch.nn.Module):
def __init__(
self,
hidden_size,
heads_num,
mlp_width_ratio: str = 4.0,
mlp_drop_rate: float = 0.0,
act_type: str = "silu",
qk_norm: bool = False,
qk_norm_type: str = "layer",
qkv_bias: bool = True,
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
):
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()
self.heads_num = heads_num
head_dim = hidden_size // heads_num
self.norm1 = nn.LayerNorm(
hidden_size, elementwise_affine=True, eps=1e-6, **factory_kwargs
)
self.norm1_2 = nn.LayerNorm(
hidden_size, elementwise_affine=True, eps=1e-6, **factory_kwargs
)
self.self_attn_q = nn.Linear(
hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs
)
self.self_attn_kv = nn.Linear(
hidden_size, hidden_size*2, bias=qkv_bias, **factory_kwargs
)
qk_norm_layer = get_norm_layer(qk_norm_type)
self.self_attn_q_norm = (
qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)
if qk_norm
else nn.Identity()
)
self.self_attn_k_norm = (
qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)
if qk_norm
else nn.Identity()
)
self.self_attn_proj = nn.Linear(
hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs
)
self.norm2 = nn.LayerNorm(
hidden_size, elementwise_affine=True, eps=1e-6, **factory_kwargs
)
act_layer = get_activation_layer(act_type)
self.adaLN_modulation = nn.Sequential(
act_layer(),
nn.Linear(hidden_size, 2 * hidden_size, bias=True, **factory_kwargs),
)
# Zero-initialize the modulation
nn.init.zeros_(self.adaLN_modulation[1].weight)
nn.init.zeros_(self.adaLN_modulation[1].bias)
def forward(
self,
x: torch.Tensor,
c: torch.Tensor, # timestep_aware_representations + context_aware_representations
attn_mask: torch.Tensor = None,
y: torch.Tensor=None,
):
gate_msa, gate_mlp = self.adaLN_modulation(c).chunk(2, dim=1)
norm_x = self.norm1(x)
norm_y = self.norm1_2(y)
q = self.self_attn_q(norm_x)
q = rearrange(q, "B L (H D) -> B L H D", H=self.heads_num)
kv = self.self_attn_kv(norm_y)
k, v = rearrange(kv, "B L (K H D) -> K B L H D", K=2, H=self.heads_num)
# Apply QK-Norm if needed
q = self.self_attn_q_norm(q).to(v)
k = self.self_attn_k_norm(k).to(v)
# Self-Attention
attn = attention(q, k, v, mode="torch", attn_mask=attn_mask)
x = x + apply_gate(self.self_attn_proj(attn), gate_msa)
return x
class IndividualTokenRefiner(torch.nn.Module):
def __init__(
self,
hidden_size,
heads_num,
depth,
mlp_width_ratio: float = 4.0,
mlp_drop_rate: float = 0.0,
act_type: str = "silu",
qk_norm: bool = False,
qk_norm_type: str = "layer",
qkv_bias: bool = True,
need_CA:bool=False,
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
):
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()
self.need_CA = need_CA
self.blocks = nn.ModuleList(
[
IndividualTokenRefinerBlock(
hidden_size=hidden_size,
heads_num=heads_num,
mlp_width_ratio=mlp_width_ratio,
mlp_drop_rate=mlp_drop_rate,
act_type=act_type,
qk_norm=qk_norm,
qk_norm_type=qk_norm_type,
qkv_bias=qkv_bias,
need_CA=self.need_CA,
**factory_kwargs,
)
for _ in range(depth)
]
)
def forward(
self,
x: torch.Tensor,
c: torch.LongTensor,
mask: Optional[torch.Tensor] = None,
y:torch.Tensor=None,
):
self_attn_mask = None
if mask is not None:
batch_size = mask.shape[0]
seq_len = mask.shape[1]
mask = mask.to(x.device)
# batch_size x 1 x seq_len x seq_len
self_attn_mask_1 = mask.view(batch_size, 1, 1, seq_len).repeat(
1, 1, seq_len, 1
)
# batch_size x 1 x seq_len x seq_len
self_attn_mask_2 = self_attn_mask_1.transpose(2, 3)
# batch_size x 1 x seq_len x seq_len, 1 for broadcasting of heads_num
self_attn_mask = (self_attn_mask_1 & self_attn_mask_2).bool()
# avoids self-attention weight being NaN for padding tokens
self_attn_mask[:, :, :, 0] = True
for block in self.blocks:
x = block(x, c, self_attn_mask,y)
return x
class SingleTokenRefiner(torch.nn.Module):
"""
A single token refiner block for llm text embedding refine.
"""
def __init__(
self,
in_channels,
hidden_size,
heads_num,
depth,
mlp_width_ratio: float = 4.0,
mlp_drop_rate: float = 0.0,
act_type: str = "silu",
qk_norm: bool = False,
qk_norm_type: str = "layer",
qkv_bias: bool = True,
need_CA:bool=False,
attn_mode: str = "torch",
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
):
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()
self.attn_mode = attn_mode
self.need_CA = need_CA
assert self.attn_mode == "torch", "Only support 'torch' mode for token refiner."
self.input_embedder = nn.Linear(
in_channels, hidden_size, bias=True, **factory_kwargs
)
if self.need_CA:
self.input_embedder_CA = nn.Linear(
in_channels, hidden_size, bias=True, **factory_kwargs
)
act_layer = get_activation_layer(act_type)
# Build timestep embedding layer
self.t_embedder = TimestepEmbedder(hidden_size, act_layer, **factory_kwargs)
# Build context embedding layer
self.c_embedder = TextProjection(
in_channels, hidden_size, act_layer, **factory_kwargs
)
self.individual_token_refiner = IndividualTokenRefiner(
hidden_size=hidden_size,
heads_num=heads_num,
depth=depth,
mlp_width_ratio=mlp_width_ratio,
mlp_drop_rate=mlp_drop_rate,
act_type=act_type,
qk_norm=qk_norm,
qk_norm_type=qk_norm_type,
qkv_bias=qkv_bias,
need_CA=need_CA,
**factory_kwargs,
)
def forward(
self,
x: torch.Tensor,
t: torch.LongTensor,
mask: Optional[torch.LongTensor] = None,
y: torch.LongTensor=None,
):
timestep_aware_representations = self.t_embedder(t)
if mask is None:
context_aware_representations = x.mean(dim=1)
else:
mask_float = mask.unsqueeze(-1) # [b, s1, 1]
context_aware_representations = (x * mask_float).sum(
dim=1
) / mask_float.sum(dim=1)
context_aware_representations = self.c_embedder(context_aware_representations)
c = timestep_aware_representations + context_aware_representations
x = self.input_embedder(x)
if self.need_CA:
y = self.input_embedder_CA(y)
x = self.individual_token_refiner(x, c, mask, y)
else:
x = self.individual_token_refiner(x, c, mask)
return x
class Qwen2Connector(torch.nn.Module):
def __init__(
self,
# biclip_dim=1024,
in_channels=3584,
hidden_size=4096,
heads_num=32,
depth=2,
need_CA=False,
device=None,
dtype=torch.bfloat16,
):
super().__init__()
factory_kwargs = {"device": device, "dtype":dtype}
self.S =SingleTokenRefiner(in_channels=in_channels,hidden_size=hidden_size,heads_num=heads_num,depth=depth,need_CA=need_CA,**factory_kwargs)
self.global_proj_out=nn.Linear(in_channels,768)
self.scale_factor = nn.Parameter(torch.zeros(1))
with torch.no_grad():
self.scale_factor.data += -(1 - 0.09)
def forward(self, x,t,mask):
mask_float = mask.unsqueeze(-1) # [b, s1, 1]
x_mean = (x * mask_float).sum(
dim=1
) / mask_float.sum(dim=1) * (1 + self.scale_factor)
global_out=self.global_proj_out(x_mean)
encoder_hidden_states = self.S(x,t,mask)
return encoder_hidden_states,global_out
@staticmethod
def state_dict_converter():
return Qwen2ConnectorStateDictConverter()
class Qwen2ConnectorStateDictConverter:
def __init__(self):
pass
def from_diffusers(self, state_dict):
return state_dict
def from_civitai(self, state_dict):
state_dict_ = {}
for name, param in state_dict.items():
if name.startswith("connector."):
name_ = name[len("connector."):]
state_dict_[name_] = param
return state_dict_

View File

@@ -36,6 +36,8 @@ def flash_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, num_heads
k = rearrange(k, "b s (n d) -> b s n d", n=num_heads)
v = rearrange(v, "b s (n d) -> b s n d", n=num_heads)
x = flash_attn_interface.flash_attn_func(q, k, v)
if isinstance(x,tuple):
x = x[0]
x = rearrange(x, "b s n d -> b s (n d)", n=num_heads)
elif FLASH_ATTN_2_AVAILABLE:
q = rearrange(q, "b s (n d) -> b s n d", n=num_heads)
@@ -221,7 +223,7 @@ class DiTBlock(nn.Module):
class MLP(torch.nn.Module):
def __init__(self, in_dim, out_dim):
def __init__(self, in_dim, out_dim, has_pos_emb=False):
super().__init__()
self.proj = torch.nn.Sequential(
nn.LayerNorm(in_dim),
@@ -230,8 +232,13 @@ class MLP(torch.nn.Module):
nn.Linear(in_dim, out_dim),
nn.LayerNorm(out_dim)
)
self.has_pos_emb = has_pos_emb
if has_pos_emb:
self.emb_pos = torch.nn.Parameter(torch.zeros((1, 514, 1280)))
def forward(self, x):
if self.has_pos_emb:
x = x + self.emb_pos.to(dtype=x.dtype, device=x.device)
return self.proj(x)
@@ -264,6 +271,7 @@ class WanModel(torch.nn.Module):
num_heads: int,
num_layers: int,
has_image_input: bool,
has_image_pos_emb: bool = False,
):
super().__init__()
self.dim = dim
@@ -294,7 +302,8 @@ class WanModel(torch.nn.Module):
self.freqs = precompute_freqs_cis_3d(head_dim)
if has_image_input:
self.img_emb = MLP(1280, dim) # clip_feature_dim = 1280
self.img_emb = MLP(1280, dim, has_pos_emb=has_image_pos_emb) # clip_feature_dim = 1280
self.has_image_pos_emb = has_image_pos_emb
def patchify(self, x: torch.Tensor):
x = self.patch_embedding(x)
@@ -451,6 +460,7 @@ class WanModelStateDictConverter:
return state_dict_, config
def from_civitai(self, state_dict):
state_dict = {name: param for name, param in state_dict.items() if not name.startswith("vace")}
if hash_state_dict_keys(state_dict) == "9269f8db9040a9d860eaca435be61814":
config = {
"has_image_input": False,
@@ -493,6 +503,77 @@ class WanModelStateDictConverter:
"num_layers": 40,
"eps": 1e-6
}
elif hash_state_dict_keys(state_dict) == "6d6ccde6845b95ad9114ab993d917893":
config = {
"has_image_input": True,
"patch_size": [1, 2, 2],
"in_dim": 36,
"dim": 1536,
"ffn_dim": 8960,
"freq_dim": 256,
"text_dim": 4096,
"out_dim": 16,
"num_heads": 12,
"num_layers": 30,
"eps": 1e-6
}
elif hash_state_dict_keys(state_dict) == "6bfcfb3b342cb286ce886889d519a77e":
config = {
"has_image_input": True,
"patch_size": [1, 2, 2],
"in_dim": 36,
"dim": 5120,
"ffn_dim": 13824,
"freq_dim": 256,
"text_dim": 4096,
"out_dim": 16,
"num_heads": 40,
"num_layers": 40,
"eps": 1e-6
}
elif hash_state_dict_keys(state_dict) == "349723183fc063b2bfc10bb2835cf677":
config = {
"has_image_input": True,
"patch_size": [1, 2, 2],
"in_dim": 48,
"dim": 1536,
"ffn_dim": 8960,
"freq_dim": 256,
"text_dim": 4096,
"out_dim": 16,
"num_heads": 12,
"num_layers": 30,
"eps": 1e-6
}
elif hash_state_dict_keys(state_dict) == "efa44cddf936c70abd0ea28b6cbe946c":
config = {
"has_image_input": True,
"patch_size": [1, 2, 2],
"in_dim": 48,
"dim": 5120,
"ffn_dim": 13824,
"freq_dim": 256,
"text_dim": 4096,
"out_dim": 16,
"num_heads": 40,
"num_layers": 40,
"eps": 1e-6
}
elif hash_state_dict_keys(state_dict) == "3ef3b1f8e1dab83d5b71fd7b617f859f":
config = {
"has_image_input": True,
"patch_size": [1, 2, 2],
"in_dim": 36,
"dim": 5120,
"ffn_dim": 13824,
"freq_dim": 256,
"text_dim": 4096,
"out_dim": 16,
"num_heads": 40,
"num_layers": 40,
"eps": 1e-6,
"has_image_pos_emb": True
}
else:
config = {}
return state_dict, config

View File

@@ -0,0 +1,44 @@
import torch
import torch.nn as nn
from .wan_video_dit import sinusoidal_embedding_1d
class WanMotionControllerModel(torch.nn.Module):
def __init__(self, freq_dim=256, dim=1536):
super().__init__()
self.freq_dim = freq_dim
self.linear = nn.Sequential(
nn.Linear(freq_dim, dim),
nn.SiLU(),
nn.Linear(dim, dim),
nn.SiLU(),
nn.Linear(dim, dim * 6),
)
def forward(self, motion_bucket_id):
emb = sinusoidal_embedding_1d(self.freq_dim, motion_bucket_id * 10)
emb = self.linear(emb)
return emb
def init(self):
state_dict = self.linear[-1].state_dict()
state_dict = {i: state_dict[i] * 0 for i in state_dict}
self.linear[-1].load_state_dict(state_dict)
@staticmethod
def state_dict_converter():
return WanMotionControllerModelDictConverter()
class WanMotionControllerModelDictConverter:
def __init__(self):
pass
def from_diffusers(self, state_dict):
return state_dict
def from_civitai(self, state_dict):
return state_dict

View File

@@ -0,0 +1,77 @@
import torch
from .wan_video_dit import DiTBlock
class VaceWanAttentionBlock(DiTBlock):
def __init__(self, has_image_input, dim, num_heads, ffn_dim, eps=1e-6, block_id=0):
super().__init__(has_image_input, dim, num_heads, ffn_dim, eps=eps)
self.block_id = block_id
if block_id == 0:
self.before_proj = torch.nn.Linear(self.dim, self.dim)
self.after_proj = torch.nn.Linear(self.dim, self.dim)
def forward(self, c, x, context, t_mod, freqs):
if self.block_id == 0:
c = self.before_proj(c) + x
all_c = []
else:
all_c = list(torch.unbind(c))
c = all_c.pop(-1)
c = super().forward(c, context, t_mod, freqs)
c_skip = self.after_proj(c)
all_c += [c_skip, c]
c = torch.stack(all_c)
return c
class VaceWanModel(torch.nn.Module):
def __init__(
self,
vace_layers=(0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28),
vace_in_dim=96,
patch_size=(1, 2, 2),
has_image_input=False,
dim=1536,
num_heads=12,
ffn_dim=8960,
eps=1e-6,
):
super().__init__()
self.vace_layers = vace_layers
self.vace_in_dim = vace_in_dim
self.vace_layers_mapping = {i: n for n, i in enumerate(self.vace_layers)}
# vace blocks
self.vace_blocks = torch.nn.ModuleList([
VaceWanAttentionBlock(has_image_input, dim, num_heads, ffn_dim, eps, block_id=i)
for i in self.vace_layers
])
# vace patch embeddings
self.vace_patch_embedding = torch.nn.Conv3d(vace_in_dim, dim, kernel_size=patch_size, stride=patch_size)
def forward(self, x, vace_context, context, t_mod, freqs):
c = [self.vace_patch_embedding(u.unsqueeze(0)) for u in vace_context]
c = [u.flatten(2).transpose(1, 2) for u in c]
c = torch.cat([
torch.cat([u, u.new_zeros(1, x.shape[1] - u.size(1), u.size(2))],
dim=1) for u in c
])
for block in self.vace_blocks:
c = block(c, x, context, t_mod, freqs)
hints = torch.unbind(c)[:-1]
return hints
@staticmethod
def state_dict_converter():
return VaceWanModelDictConverter()
class VaceWanModelDictConverter:
def __init__(self):
pass
def from_civitai(self, state_dict):
state_dict_ = {name: param for name, param in state_dict.items() if name.startswith("vace")}
return state_dict_

View File

@@ -1,4 +1,5 @@
from ..models import ModelManager, FluxDiT, SD3TextEncoder1, FluxTextEncoder2, FluxVAEDecoder, FluxVAEEncoder, FluxIpAdapter
from ..models.step1x_connector import Qwen2Connector
from ..controlnets import FluxMultiControlNetManager, ControlNetUnit, ControlNetConfigUnit, Annotator
from ..prompters import FluxPrompter
from ..schedulers import FlowMatchScheduler
@@ -32,105 +33,112 @@ class FluxImagePipeline(BasePipeline):
self.ipadapter: FluxIpAdapter = None
self.ipadapter_image_encoder: SiglipVisionModel = None
self.infinityou_processor: InfinitYou = None
self.model_names = ['text_encoder_1', 'text_encoder_2', 'dit', 'vae_decoder', 'vae_encoder', 'controlnet', 'ipadapter', 'ipadapter_image_encoder']
self.qwenvl = None
self.step1x_connector: Qwen2Connector = None
self.model_names = ['text_encoder_1', 'text_encoder_2', 'dit', 'vae_decoder', 'vae_encoder', 'controlnet', 'ipadapter', 'ipadapter_image_encoder', 'qwenvl', 'step1x_connector']
def enable_vram_management(self, num_persistent_param_in_dit=None):
dtype = next(iter(self.text_encoder_1.parameters())).dtype
enable_vram_management(
self.text_encoder_1,
module_map = {
torch.nn.Linear: AutoWrappedLinear,
torch.nn.Embedding: AutoWrappedModule,
torch.nn.LayerNorm: AutoWrappedModule,
},
module_config = dict(
offload_dtype=dtype,
offload_device="cpu",
onload_dtype=dtype,
onload_device="cpu",
computation_dtype=self.torch_dtype,
computation_device=self.device,
),
)
dtype = next(iter(self.text_encoder_2.parameters())).dtype
enable_vram_management(
self.text_encoder_2,
module_map = {
torch.nn.Linear: AutoWrappedLinear,
torch.nn.Embedding: AutoWrappedModule,
T5LayerNorm: AutoWrappedModule,
T5DenseActDense: AutoWrappedModule,
T5DenseGatedActDense: AutoWrappedModule,
},
module_config = dict(
offload_dtype=dtype,
offload_device="cpu",
onload_dtype=dtype,
onload_device="cpu",
computation_dtype=self.torch_dtype,
computation_device=self.device,
),
)
dtype = next(iter(self.dit.parameters())).dtype
enable_vram_management(
self.dit,
module_map = {
RMSNorm: AutoWrappedModule,
torch.nn.Linear: AutoWrappedLinear,
},
module_config = dict(
offload_dtype=dtype,
offload_device="cpu",
onload_dtype=dtype,
onload_device="cuda",
computation_dtype=self.torch_dtype,
computation_device=self.device,
),
max_num_param=num_persistent_param_in_dit,
overflow_module_config = dict(
offload_dtype=dtype,
offload_device="cpu",
onload_dtype=dtype,
onload_device="cpu",
computation_dtype=self.torch_dtype,
computation_device=self.device,
),
)
dtype = next(iter(self.vae_decoder.parameters())).dtype
enable_vram_management(
self.vae_decoder,
module_map = {
torch.nn.Linear: AutoWrappedLinear,
torch.nn.Conv2d: AutoWrappedModule,
torch.nn.GroupNorm: AutoWrappedModule,
},
module_config = dict(
offload_dtype=dtype,
offload_device="cpu",
onload_dtype=dtype,
onload_device="cpu",
computation_dtype=self.torch_dtype,
computation_device=self.device,
),
)
dtype = next(iter(self.vae_encoder.parameters())).dtype
enable_vram_management(
self.vae_encoder,
module_map = {
torch.nn.Linear: AutoWrappedLinear,
torch.nn.Conv2d: AutoWrappedModule,
torch.nn.GroupNorm: AutoWrappedModule,
},
module_config = dict(
offload_dtype=dtype,
offload_device="cpu",
onload_dtype=dtype,
onload_device="cpu",
computation_dtype=self.torch_dtype,
computation_device=self.device,
),
)
if self.text_encoder_1 is not None:
dtype = next(iter(self.text_encoder_1.parameters())).dtype
enable_vram_management(
self.text_encoder_1,
module_map = {
torch.nn.Linear: AutoWrappedLinear,
torch.nn.Embedding: AutoWrappedModule,
torch.nn.LayerNorm: AutoWrappedModule,
},
module_config = dict(
offload_dtype=dtype,
offload_device="cpu",
onload_dtype=dtype,
onload_device="cpu",
computation_dtype=self.torch_dtype,
computation_device=self.device,
),
)
if self.text_encoder_2 is not None:
dtype = next(iter(self.text_encoder_2.parameters())).dtype
enable_vram_management(
self.text_encoder_2,
module_map = {
torch.nn.Linear: AutoWrappedLinear,
torch.nn.Embedding: AutoWrappedModule,
T5LayerNorm: AutoWrappedModule,
T5DenseActDense: AutoWrappedModule,
T5DenseGatedActDense: AutoWrappedModule,
},
module_config = dict(
offload_dtype=dtype,
offload_device="cpu",
onload_dtype=dtype,
onload_device="cpu",
computation_dtype=self.torch_dtype,
computation_device=self.device,
),
)
if self.dit is not None:
dtype = next(iter(self.dit.parameters())).dtype
enable_vram_management(
self.dit,
module_map = {
RMSNorm: AutoWrappedModule,
torch.nn.Linear: AutoWrappedLinear,
},
module_config = dict(
offload_dtype=dtype,
offload_device="cpu",
onload_dtype=dtype,
onload_device="cuda",
computation_dtype=self.torch_dtype,
computation_device=self.device,
),
max_num_param=num_persistent_param_in_dit,
overflow_module_config = dict(
offload_dtype=dtype,
offload_device="cpu",
onload_dtype=dtype,
onload_device="cpu",
computation_dtype=self.torch_dtype,
computation_device=self.device,
),
)
if self.vae_decoder is not None:
dtype = next(iter(self.vae_decoder.parameters())).dtype
enable_vram_management(
self.vae_decoder,
module_map = {
torch.nn.Linear: AutoWrappedLinear,
torch.nn.Conv2d: AutoWrappedModule,
torch.nn.GroupNorm: AutoWrappedModule,
},
module_config = dict(
offload_dtype=dtype,
offload_device="cpu",
onload_dtype=dtype,
onload_device="cpu",
computation_dtype=self.torch_dtype,
computation_device=self.device,
),
)
if self.vae_encoder is not None:
dtype = next(iter(self.vae_encoder.parameters())).dtype
enable_vram_management(
self.vae_encoder,
module_map = {
torch.nn.Linear: AutoWrappedLinear,
torch.nn.Conv2d: AutoWrappedModule,
torch.nn.GroupNorm: AutoWrappedModule,
},
module_config = dict(
offload_dtype=dtype,
offload_device="cpu",
onload_dtype=dtype,
onload_device="cpu",
computation_dtype=self.torch_dtype,
computation_device=self.device,
),
)
self.enable_cpu_offload()
@@ -167,6 +175,10 @@ class FluxImagePipeline(BasePipeline):
self.image_proj_model = model_manager.fetch_model("infiniteyou_image_projector")
if self.image_proj_model is not None:
self.infinityou_processor = InfinitYou(device=self.device)
# Step1x
self.qwenvl = model_manager.fetch_model("qwenvl")
self.step1x_connector = model_manager.fetch_model("step1x_connector")
@staticmethod
@@ -190,11 +202,14 @@ class FluxImagePipeline(BasePipeline):
return image
def encode_prompt(self, prompt, positive=True, t5_sequence_length=512):
prompt_emb, pooled_prompt_emb, text_ids = self.prompter.encode_prompt(
prompt, device=self.device, positive=positive, t5_sequence_length=t5_sequence_length
)
return {"prompt_emb": prompt_emb, "pooled_prompt_emb": pooled_prompt_emb, "text_ids": text_ids}
def encode_prompt(self, prompt, positive=True, t5_sequence_length=512, image_emb=None):
if (self.text_encoder_1 is not None and self.text_encoder_2 is not None) or (image_emb is not None):
prompt_emb, pooled_prompt_emb, text_ids = self.prompter.encode_prompt(
prompt, device=self.device, positive=positive, t5_sequence_length=t5_sequence_length, image_emb=image_emb
)
return {"prompt_emb": prompt_emb, "pooled_prompt_emb": pooled_prompt_emb, "text_ids": text_ids}
else:
return {}
def prepare_extra_input(self, latents=None, guidance=1.0):
@@ -343,13 +358,13 @@ class FluxImagePipeline(BasePipeline):
return eligen_kwargs_posi, eligen_kwargs_nega, fg_mask, bg_mask
def prepare_prompts(self, prompt, local_prompts, masks, mask_scales, t5_sequence_length, negative_prompt, cfg_scale):
def prepare_prompts(self, prompt, local_prompts, masks, mask_scales, t5_sequence_length, negative_prompt, cfg_scale, image_emb=None):
# Extend prompt
self.load_models_to_device(['text_encoder_1', 'text_encoder_2'])
prompt, local_prompts, masks, mask_scales = self.extend_prompt(prompt, local_prompts, masks, mask_scales)
# Encode prompts
prompt_emb_posi = self.encode_prompt(prompt, t5_sequence_length=t5_sequence_length)
prompt_emb_posi = self.encode_prompt(prompt, t5_sequence_length=t5_sequence_length, image_emb=image_emb)
prompt_emb_nega = self.encode_prompt(negative_prompt, positive=False, t5_sequence_length=t5_sequence_length) if cfg_scale != 1.0 else None
prompt_emb_locals = [self.encode_prompt(prompt_local, t5_sequence_length=t5_sequence_length) for prompt_local in local_prompts]
return prompt_emb_posi, prompt_emb_nega, prompt_emb_locals
@@ -360,6 +375,46 @@ class FluxImagePipeline(BasePipeline):
return self.infinityou_processor.prepare_infinite_you(self.image_proj_model, id_image, controlnet_image, infinityou_guidance, height, width)
else:
return {}, controlnet_image
def prepare_flex_kwargs(self, latents, flex_inpaint_image=None, flex_inpaint_mask=None, flex_control_image=None, flex_control_strength=0.5, flex_control_stop=0.5, tiled=False, tile_size=64, tile_stride=32):
if self.dit.input_dim == 196:
if flex_inpaint_image is None:
flex_inpaint_image = torch.zeros_like(latents)
else:
flex_inpaint_image = self.preprocess_image(flex_inpaint_image).to(device=self.device, dtype=self.torch_dtype)
flex_inpaint_image = self.encode_image(flex_inpaint_image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
if flex_inpaint_mask is None:
flex_inpaint_mask = torch.ones_like(latents)[:, 0:1, :, :]
else:
flex_inpaint_mask = flex_inpaint_mask.resize((latents.shape[3], latents.shape[2]))
flex_inpaint_mask = self.preprocess_image(flex_inpaint_mask).to(device=self.device, dtype=self.torch_dtype)
flex_inpaint_mask = (flex_inpaint_mask[:, 0:1, :, :] + 1) / 2
flex_inpaint_image = flex_inpaint_image * (1 - flex_inpaint_mask)
if flex_control_image is None:
flex_control_image = torch.zeros_like(latents)
else:
flex_control_image = self.preprocess_image(flex_control_image).to(device=self.device, dtype=self.torch_dtype)
flex_control_image = self.encode_image(flex_control_image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride) * flex_control_strength
flex_condition = torch.concat([flex_inpaint_image, flex_inpaint_mask, flex_control_image], dim=1)
flex_uncondition = torch.concat([flex_inpaint_image, flex_inpaint_mask, torch.zeros_like(flex_control_image)], dim=1)
flex_control_stop_timestep = self.scheduler.timesteps[int(flex_control_stop * (len(self.scheduler.timesteps) - 1))]
flex_kwargs = {"flex_condition": flex_condition, "flex_uncondition": flex_uncondition, "flex_control_stop_timestep": flex_control_stop_timestep}
else:
flex_kwargs = {}
return flex_kwargs
def prepare_step1x_kwargs(self, prompt, negative_prompt, image):
if image is None:
return {}, {}
self.load_models_to_device(["qwenvl", "vae_encoder"])
captions = [prompt, negative_prompt]
ref_images = [image, image]
embs, masks = self.qwenvl(captions, ref_images)
image = self.preprocess_image(image).to(device=self.device, dtype=self.torch_dtype)
image = self.encode_image(image)
return {"step1x_llm_embedding": embs[0:1], "step1x_mask": masks[0:1], "step1x_reference_latents": image}, {"step1x_llm_embedding": embs[1:2], "step1x_mask": masks[1:2], "step1x_reference_latents": image}
@torch.no_grad()
@@ -377,6 +432,7 @@ class FluxImagePipeline(BasePipeline):
height=1024,
width=1024,
seed=None,
image_emb=None,
# Steps
num_inference_steps=30,
# local prompts
@@ -398,6 +454,14 @@ class FluxImagePipeline(BasePipeline):
# InfiniteYou
infinityou_id_image=None,
infinityou_guidance=1.0,
# Flex
flex_inpaint_image=None,
flex_inpaint_mask=None,
flex_control_image=None,
flex_control_strength=0.5,
flex_control_stop=0.5,
# Step1x
step1x_reference_image=None,
# TeaCache
tea_cache_l1_thresh=None,
# Tile
@@ -420,7 +484,7 @@ class FluxImagePipeline(BasePipeline):
latents, input_latents = self.prepare_latents(input_image, height, width, seed, tiled, tile_size, tile_stride)
# Prompt
prompt_emb_posi, prompt_emb_nega, prompt_emb_locals = self.prepare_prompts(prompt, local_prompts, masks, mask_scales, t5_sequence_length, negative_prompt, cfg_scale)
prompt_emb_posi, prompt_emb_nega, prompt_emb_locals = self.prepare_prompts(prompt, local_prompts, masks, mask_scales, t5_sequence_length, negative_prompt, cfg_scale, image_emb)
# Extra input
extra_input = self.prepare_extra_input(latents, guidance=embedded_guidance)
@@ -436,20 +500,26 @@ class FluxImagePipeline(BasePipeline):
# ControlNets
controlnet_kwargs_posi, controlnet_kwargs_nega, local_controlnet_kwargs = self.prepare_controlnet(controlnet_image, masks, controlnet_inpaint_mask, tiler_kwargs, enable_controlnet_on_negative)
# Flex
flex_kwargs = self.prepare_flex_kwargs(latents, flex_inpaint_image, flex_inpaint_mask, flex_control_image, flex_control_strength=flex_control_strength, flex_control_stop=flex_control_stop, **tiler_kwargs)
# Step1x
step1x_kwargs_posi, step1x_kwargs_nega = self.prepare_step1x_kwargs(prompt, negative_prompt, image=step1x_reference_image)
# TeaCache
tea_cache_kwargs = {"tea_cache": TeaCache(num_inference_steps, rel_l1_thresh=tea_cache_l1_thresh) if tea_cache_l1_thresh is not None else None}
# Denoise
self.load_models_to_device(['dit', 'controlnet'])
self.load_models_to_device(['dit', 'controlnet', 'step1x_connector'])
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
timestep = timestep.unsqueeze(0).to(self.device)
# Positive side
inference_callback = lambda prompt_emb_posi, controlnet_kwargs: lets_dance_flux(
dit=self.dit, controlnet=self.controlnet,
dit=self.dit, controlnet=self.controlnet, step1x_connector=self.step1x_connector,
hidden_states=latents, timestep=timestep,
**prompt_emb_posi, **tiler_kwargs, **extra_input, **controlnet_kwargs, **ipadapter_kwargs_list_posi, **eligen_kwargs_posi, **tea_cache_kwargs, **infiniteyou_kwargs
**prompt_emb_posi, **tiler_kwargs, **extra_input, **controlnet_kwargs, **ipadapter_kwargs_list_posi, **eligen_kwargs_posi, **tea_cache_kwargs, **infiniteyou_kwargs, **flex_kwargs, **step1x_kwargs_posi,
)
noise_pred_posi = self.control_noise_via_local_prompts(
prompt_emb_posi, prompt_emb_locals, masks, mask_scales, inference_callback,
@@ -464,9 +534,9 @@ class FluxImagePipeline(BasePipeline):
if cfg_scale != 1.0:
# Negative side
noise_pred_nega = lets_dance_flux(
dit=self.dit, controlnet=self.controlnet,
dit=self.dit, controlnet=self.controlnet, step1x_connector=self.step1x_connector,
hidden_states=latents, timestep=timestep,
**prompt_emb_nega, **tiler_kwargs, **extra_input, **controlnet_kwargs_nega, **ipadapter_kwargs_list_nega, **eligen_kwargs_nega, **infiniteyou_kwargs,
**prompt_emb_nega, **tiler_kwargs, **extra_input, **controlnet_kwargs_nega, **ipadapter_kwargs_list_nega, **eligen_kwargs_nega, **infiniteyou_kwargs, **flex_kwargs, **step1x_kwargs_nega,
)
noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
else:
@@ -586,6 +656,7 @@ class TeaCache:
def lets_dance_flux(
dit: FluxDiT,
controlnet: FluxMultiControlNetManager = None,
step1x_connector: Qwen2Connector = None,
hidden_states=None,
timestep=None,
prompt_emb=None,
@@ -602,7 +673,14 @@ def lets_dance_flux(
ipadapter_kwargs_list={},
id_emb=None,
infinityou_guidance=None,
flex_condition=None,
flex_uncondition=None,
flex_control_stop_timestep=None,
step1x_llm_embedding=None,
step1x_mask=None,
step1x_reference_latents=None,
tea_cache: TeaCache = None,
use_gradient_checkpointing=False,
**kwargs
):
if tiled:
@@ -652,6 +730,18 @@ def lets_dance_flux(
controlnet_res_stack, controlnet_single_res_stack = controlnet(
controlnet_frames, **controlnet_extra_kwargs
)
# Flex
if flex_condition is not None:
if timestep.tolist()[0] >= flex_control_stop_timestep:
hidden_states = torch.concat([hidden_states, flex_condition], dim=1)
else:
hidden_states = torch.concat([hidden_states, flex_uncondition], dim=1)
# Step1x
if step1x_llm_embedding is not None:
prompt_emb, pooled_prompt_emb = step1x_connector(step1x_llm_embedding, timestep / 1000, step1x_mask)
text_ids = torch.zeros((1, prompt_emb.shape[1], 3), dtype=prompt_emb.dtype, device=prompt_emb.device)
if image_ids is None:
image_ids = dit.prepare_image_ids(hidden_states)
@@ -663,6 +753,14 @@ def lets_dance_flux(
height, width = hidden_states.shape[-2:]
hidden_states = dit.patchify(hidden_states)
# Step1x
if step1x_reference_latents is not None:
step1x_reference_image_ids = dit.prepare_image_ids(step1x_reference_latents)
step1x_reference_latents = dit.patchify(step1x_reference_latents)
image_ids = torch.concat([image_ids, step1x_reference_image_ids], dim=-2)
hidden_states = torch.concat([hidden_states, step1x_reference_latents], dim=1)
hidden_states = dit.x_embedder(hidden_states)
if entity_prompt_emb is not None and entity_masks is not None:
@@ -677,20 +775,32 @@ def lets_dance_flux(
tea_cache_update = tea_cache.check(dit, hidden_states, conditioning)
else:
tea_cache_update = False
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs)
return custom_forward
if tea_cache_update:
hidden_states = tea_cache.update(hidden_states)
else:
# Joint Blocks
for block_id, block in enumerate(dit.blocks):
hidden_states, prompt_emb = block(
hidden_states,
prompt_emb,
conditioning,
image_rotary_emb,
attention_mask,
ipadapter_kwargs_list=ipadapter_kwargs_list.get(block_id, None)
)
if use_gradient_checkpointing:
hidden_states, prompt_emb = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
hidden_states, prompt_emb, conditioning, image_rotary_emb, attention_mask, ipadapter_kwargs_list.get(block_id, None),
use_reentrant=False,
)
else:
hidden_states, prompt_emb = block(
hidden_states,
prompt_emb,
conditioning,
image_rotary_emb,
attention_mask,
ipadapter_kwargs_list=ipadapter_kwargs_list.get(block_id, None)
)
# ControlNet
if controlnet is not None and controlnet_frames is not None:
hidden_states = hidden_states + controlnet_res_stack[block_id]
@@ -699,14 +809,21 @@ def lets_dance_flux(
hidden_states = torch.cat([prompt_emb, hidden_states], dim=1)
num_joint_blocks = len(dit.blocks)
for block_id, block in enumerate(dit.single_blocks):
hidden_states, prompt_emb = block(
hidden_states,
prompt_emb,
conditioning,
image_rotary_emb,
attention_mask,
ipadapter_kwargs_list=ipadapter_kwargs_list.get(block_id + num_joint_blocks, None)
)
if use_gradient_checkpointing:
hidden_states, prompt_emb = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
hidden_states, prompt_emb, conditioning, image_rotary_emb, attention_mask, ipadapter_kwargs_list.get(block_id + num_joint_blocks, None),
use_reentrant=False,
)
else:
hidden_states, prompt_emb = block(
hidden_states,
prompt_emb,
conditioning,
image_rotary_emb,
attention_mask,
ipadapter_kwargs_list=ipadapter_kwargs_list.get(block_id + num_joint_blocks, None)
)
# ControlNet
if controlnet is not None and controlnet_frames is not None:
hidden_states[:, prompt_emb.shape[1]:] = hidden_states[:, prompt_emb.shape[1]:] + controlnet_single_res_stack[block_id]
@@ -717,6 +834,11 @@ def lets_dance_flux(
hidden_states = dit.final_norm_out(hidden_states, conditioning)
hidden_states = dit.final_proj_out(hidden_states)
# Step1x
if step1x_reference_latents is not None:
hidden_states = hidden_states[:, :hidden_states.shape[1] // 2]
hidden_states = dit.unpatchify(hidden_states, height, width)
return hidden_states

View File

@@ -4,6 +4,7 @@ from ..models.wan_video_dit import WanModel
from ..models.wan_video_text_encoder import WanTextEncoder
from ..models.wan_video_vae import WanVideoVAE
from ..models.wan_video_image_encoder import WanImageEncoder
from ..models.wan_video_vace import VaceWanModel
from ..schedulers.flow_match import FlowMatchScheduler
from .base import BasePipeline
from ..prompters import WanPrompter
@@ -18,6 +19,7 @@ from ..vram_management import enable_vram_management, AutoWrappedModule, AutoWra
from ..models.wan_video_text_encoder import T5RelativeEmbedding, T5LayerNorm
from ..models.wan_video_dit import RMSNorm, sinusoidal_embedding_1d
from ..models.wan_video_vae import RMS_norm, CausalConv3d, Upsample
from ..models.wan_video_motion_controller import WanMotionControllerModel
@@ -31,7 +33,9 @@ class WanVideoPipeline(BasePipeline):
self.image_encoder: WanImageEncoder = None
self.dit: WanModel = None
self.vae: WanVideoVAE = None
self.model_names = ['text_encoder', 'dit', 'vae', 'image_encoder']
self.motion_controller: WanMotionControllerModel = None
self.vace: VaceWanModel = None
self.model_names = ['text_encoder', 'dit', 'vae', 'image_encoder', 'motion_controller', 'vace']
self.height_division_factor = 16
self.width_division_factor = 16
self.use_unified_sequence_parallel = False
@@ -122,6 +126,40 @@ class WanVideoPipeline(BasePipeline):
computation_device=self.device,
),
)
if self.motion_controller is not None:
dtype = next(iter(self.motion_controller.parameters())).dtype
enable_vram_management(
self.motion_controller,
module_map = {
torch.nn.Linear: AutoWrappedLinear,
},
module_config = dict(
offload_dtype=dtype,
offload_device="cpu",
onload_dtype=dtype,
onload_device="cpu",
computation_dtype=dtype,
computation_device=self.device,
),
)
if self.vace is not None:
enable_vram_management(
self.vace,
module_map = {
torch.nn.Linear: AutoWrappedLinear,
torch.nn.Conv3d: AutoWrappedModule,
torch.nn.LayerNorm: AutoWrappedModule,
RMSNorm: AutoWrappedModule,
},
module_config = dict(
offload_dtype=dtype,
offload_device="cpu",
onload_dtype=dtype,
onload_device=self.device,
computation_dtype=self.torch_dtype,
computation_device=self.device,
),
)
self.enable_cpu_offload()
@@ -134,6 +172,8 @@ class WanVideoPipeline(BasePipeline):
self.dit = model_manager.fetch_model("wan_video_dit")
self.vae = model_manager.fetch_model("wan_video_vae")
self.image_encoder = model_manager.fetch_model("wan_video_image_encoder")
self.motion_controller = model_manager.fetch_model("wan_video_motion_controller")
self.vace = model_manager.fetch_model("wan_video_vace")
@staticmethod
@@ -163,22 +203,50 @@ class WanVideoPipeline(BasePipeline):
return {"context": prompt_emb}
def encode_image(self, image, num_frames, height, width):
def encode_image(self, image, end_image, num_frames, height, width, tiled=False, tile_size=(34, 34), tile_stride=(18, 16)):
image = self.preprocess_image(image.resize((width, height))).to(self.device)
clip_context = self.image_encoder.encode_image([image])
msk = torch.ones(1, num_frames, height//8, width//8, device=self.device)
msk[:, 1:] = 0
if end_image is not None:
end_image = self.preprocess_image(end_image.resize((width, height))).to(self.device)
vae_input = torch.concat([image.transpose(0,1), torch.zeros(3, num_frames-2, height, width).to(image.device), end_image.transpose(0,1)],dim=1)
if self.dit.has_image_pos_emb:
clip_context = torch.concat([clip_context, self.image_encoder.encode_image([end_image])], dim=1)
msk[:, -1:] = 1
else:
vae_input = torch.concat([image.transpose(0, 1), torch.zeros(3, num_frames-1, height, width).to(image.device)], dim=1)
msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1)
msk = msk.view(1, msk.shape[1] // 4, 4, height//8, width//8)
msk = msk.transpose(1, 2)[0]
vae_input = torch.concat([image.transpose(0, 1), torch.zeros(3, num_frames-1, height, width).to(image.device)], dim=1)
y = self.vae.encode([vae_input.to(dtype=self.torch_dtype, device=self.device)], device=self.device)[0]
y = self.vae.encode([vae_input.to(dtype=self.torch_dtype, device=self.device)], device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)[0]
y = y.to(dtype=self.torch_dtype, device=self.device)
y = torch.concat([msk, y])
y = y.unsqueeze(0)
clip_context = clip_context.to(dtype=self.torch_dtype, device=self.device)
y = y.to(dtype=self.torch_dtype, device=self.device)
return {"clip_feature": clip_context, "y": y}
def encode_control_video(self, control_video, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)):
control_video = self.preprocess_images(control_video)
control_video = torch.stack(control_video, dim=2).to(dtype=self.torch_dtype, device=self.device)
latents = self.encode_video(control_video, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=self.torch_dtype, device=self.device)
return latents
def prepare_controlnet_kwargs(self, control_video, num_frames, height, width, clip_feature=None, y=None, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)):
if control_video is not None:
control_latents = self.encode_control_video(control_video, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
if clip_feature is None or y is None:
clip_feature = torch.zeros((1, 257, 1280), dtype=self.torch_dtype, device=self.device)
y = torch.zeros((1, 16, (num_frames - 1) // 4 + 1, height//8, width//8), dtype=self.torch_dtype, device=self.device)
else:
y = y[:, -16:]
y = torch.concat([control_latents, y], dim=1)
return {"clip_feature": clip_feature, "y": y}
def tensor2video(self, frames):
@@ -204,6 +272,62 @@ class WanVideoPipeline(BasePipeline):
def prepare_unified_sequence_parallel(self):
return {"use_unified_sequence_parallel": self.use_unified_sequence_parallel}
def prepare_motion_bucket_id(self, motion_bucket_id):
motion_bucket_id = torch.Tensor((motion_bucket_id,)).to(dtype=self.torch_dtype, device=self.device)
return {"motion_bucket_id": motion_bucket_id}
def prepare_vace_kwargs(
self,
latents,
vace_video=None, vace_mask=None, vace_reference_image=None, vace_scale=1.0,
height=480, width=832, num_frames=81,
seed=None, rand_device="cpu",
tiled=True, tile_size=(34, 34), tile_stride=(18, 16)
):
if vace_video is not None or vace_mask is not None or vace_reference_image is not None:
self.load_models_to_device(["vae"])
if vace_video is None:
vace_video = torch.zeros((1, 3, num_frames, height, width), dtype=self.torch_dtype, device=self.device)
else:
vace_video = self.preprocess_images(vace_video)
vace_video = torch.stack(vace_video, dim=2).to(dtype=self.torch_dtype, device=self.device)
if vace_mask is None:
vace_mask = torch.ones_like(vace_video)
else:
vace_mask = self.preprocess_images(vace_mask)
vace_mask = torch.stack(vace_mask, dim=2).to(dtype=self.torch_dtype, device=self.device)
inactive = vace_video * (1 - vace_mask) + 0 * vace_mask
reactive = vace_video * vace_mask + 0 * (1 - vace_mask)
inactive = self.encode_video(inactive, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=self.torch_dtype, device=self.device)
reactive = self.encode_video(reactive, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=self.torch_dtype, device=self.device)
vace_video_latents = torch.concat((inactive, reactive), dim=1)
vace_mask_latents = rearrange(vace_mask[0,0], "T (H P) (W Q) -> 1 (P Q) T H W", P=8, Q=8)
vace_mask_latents = torch.nn.functional.interpolate(vace_mask_latents, size=((vace_mask_latents.shape[2] + 3) // 4, vace_mask_latents.shape[3], vace_mask_latents.shape[4]), mode='nearest-exact')
if vace_reference_image is None:
pass
else:
vace_reference_image = self.preprocess_images([vace_reference_image])
vace_reference_image = torch.stack(vace_reference_image, dim=2).to(dtype=self.torch_dtype, device=self.device)
vace_reference_latents = self.encode_video(vace_reference_image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=self.torch_dtype, device=self.device)
vace_reference_latents = torch.concat((vace_reference_latents, torch.zeros_like(vace_reference_latents)), dim=1)
vace_video_latents = torch.concat((vace_reference_latents, vace_video_latents), dim=2)
vace_mask_latents = torch.concat((torch.zeros_like(vace_mask_latents[:, :, :1]), vace_mask_latents), dim=2)
noise = self.generate_noise((1, 16, 1, latents.shape[3], latents.shape[4]), seed=seed, device=rand_device, dtype=torch.float32)
noise = noise.to(dtype=self.torch_dtype, device=self.device)
latents = torch.concat((noise, latents), dim=2)
vace_context = torch.concat((vace_video_latents, vace_mask_latents), dim=1)
return latents, {"vace_context": vace_context, "vace_scale": vace_scale}
else:
return latents, {"vace_context": None, "vace_scale": vace_scale}
@torch.no_grad()
@@ -212,7 +336,13 @@ class WanVideoPipeline(BasePipeline):
prompt,
negative_prompt="",
input_image=None,
end_image=None,
input_video=None,
control_video=None,
vace_video=None,
vace_video_mask=None,
vace_reference_image=None,
vace_scale=1.0,
denoising_strength=1.0,
seed=None,
rand_device="cpu",
@@ -222,6 +352,7 @@ class WanVideoPipeline(BasePipeline):
cfg_scale=5.0,
num_inference_steps=50,
sigma_shift=5.0,
motion_bucket_id=None,
tiled=True,
tile_size=(30, 52),
tile_stride=(15, 26),
@@ -263,13 +394,30 @@ class WanVideoPipeline(BasePipeline):
# Encode image
if input_image is not None and self.image_encoder is not None:
self.load_models_to_device(["image_encoder", "vae"])
image_emb = self.encode_image(input_image, num_frames, height, width)
image_emb = self.encode_image(input_image, end_image, num_frames, height, width, **tiler_kwargs)
else:
image_emb = {}
# ControlNet
if control_video is not None:
self.load_models_to_device(["image_encoder", "vae"])
image_emb = self.prepare_controlnet_kwargs(control_video, num_frames, height, width, **image_emb, **tiler_kwargs)
# Motion Controller
if self.motion_controller is not None and motion_bucket_id is not None:
motion_kwargs = self.prepare_motion_bucket_id(motion_bucket_id)
else:
motion_kwargs = {}
# Extra input
extra_input = self.prepare_extra_input(latents)
# VACE
latents, vace_kwargs = self.prepare_vace_kwargs(
latents, vace_video, vace_video_mask, vace_reference_image, vace_scale,
height=height, width=width, num_frames=num_frames, seed=seed, rand_device=rand_device, **tiler_kwargs
)
# TeaCache
tea_cache_posi = {"tea_cache": TeaCache(num_inference_steps, rel_l1_thresh=tea_cache_l1_thresh, model_id=tea_cache_model_id) if tea_cache_l1_thresh is not None else None}
tea_cache_nega = {"tea_cache": TeaCache(num_inference_steps, rel_l1_thresh=tea_cache_l1_thresh, model_id=tea_cache_model_id) if tea_cache_l1_thresh is not None else None}
@@ -278,20 +426,33 @@ class WanVideoPipeline(BasePipeline):
usp_kwargs = self.prepare_unified_sequence_parallel()
# Denoise
self.load_models_to_device(["dit"])
self.load_models_to_device(["dit", "motion_controller", "vace"])
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device)
# Inference
noise_pred_posi = model_fn_wan_video(self.dit, latents, timestep=timestep, **prompt_emb_posi, **image_emb, **extra_input, **tea_cache_posi, **usp_kwargs)
noise_pred_posi = model_fn_wan_video(
self.dit, motion_controller=self.motion_controller, vace=self.vace,
x=latents, timestep=timestep,
**prompt_emb_posi, **image_emb, **extra_input,
**tea_cache_posi, **usp_kwargs, **motion_kwargs, **vace_kwargs,
)
if cfg_scale != 1.0:
noise_pred_nega = model_fn_wan_video(self.dit, latents, timestep=timestep, **prompt_emb_nega, **image_emb, **extra_input, **tea_cache_nega, **usp_kwargs)
noise_pred_nega = model_fn_wan_video(
self.dit, motion_controller=self.motion_controller, vace=self.vace,
x=latents, timestep=timestep,
**prompt_emb_nega, **image_emb, **extra_input,
**tea_cache_nega, **usp_kwargs, **motion_kwargs, **vace_kwargs,
)
noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
else:
noise_pred = noise_pred_posi
# Scheduler
latents = self.scheduler.step(noise_pred, self.scheduler.timesteps[progress_id], latents)
if vace_reference_image is not None:
latents = latents[:, :, 1:]
# Decode
self.load_models_to_device(['vae'])
@@ -358,13 +519,18 @@ class TeaCache:
def model_fn_wan_video(
dit: WanModel,
x: torch.Tensor,
timestep: torch.Tensor,
context: torch.Tensor,
motion_controller: WanMotionControllerModel = None,
vace: VaceWanModel = None,
x: torch.Tensor = None,
timestep: torch.Tensor = None,
context: torch.Tensor = None,
clip_feature: Optional[torch.Tensor] = None,
y: Optional[torch.Tensor] = None,
vace_context = None,
vace_scale = 1.0,
tea_cache: TeaCache = None,
use_unified_sequence_parallel: bool = False,
motion_bucket_id: Optional[torch.Tensor] = None,
**kwargs,
):
if use_unified_sequence_parallel:
@@ -375,6 +541,8 @@ def model_fn_wan_video(
t = dit.time_embedding(sinusoidal_embedding_1d(dit.freq_dim, timestep))
t_mod = dit.time_projection(t).unflatten(1, (6, dit.dim))
if motion_bucket_id is not None and motion_controller is not None:
t_mod = t_mod + motion_controller(motion_bucket_id).unflatten(1, (6, dit.dim))
context = dit.text_embedding(context)
if dit.has_image_input:
@@ -395,6 +563,9 @@ def model_fn_wan_video(
tea_cache_update = tea_cache.check(dit, x, t_mod)
else:
tea_cache_update = False
if vace_context is not None:
vace_hints = vace(x, vace_context, context, t_mod, freqs)
# blocks
if use_unified_sequence_parallel:
@@ -403,8 +574,10 @@ def model_fn_wan_video(
if tea_cache_update:
x = tea_cache.update(x)
else:
for block in dit.blocks:
for block_id, block in enumerate(dit.blocks):
x = block(x, context, t_mod, freqs)
if vace_context is not None and block_id in vace.vace_layers_mapping:
x = x + vace_hints[vace.vace_layers_mapping[block_id]] * vace_scale
if tea_cache is not None:
tea_cache.store(x)

View File

@@ -59,6 +59,7 @@ class FluxPrompter(BasePrompter):
positive=True,
device="cuda",
t5_sequence_length=512,
image_emb=None,
):
prompt = self.process_prompt(prompt, positive=positive)
@@ -66,7 +67,10 @@ class FluxPrompter(BasePrompter):
pooled_prompt_emb = self.encode_prompt_using_clip(prompt, self.text_encoder_1, self.tokenizer_1, 77, device)
# T5
prompt_emb = self.encode_prompt_using_t5(prompt, self.text_encoder_2, self.tokenizer_2, t5_sequence_length, device)
if image_emb is not None:
prompt_emb = image_emb
else:
prompt_emb = self.encode_prompt_using_t5(prompt, self.text_encoder_2, self.tokenizer_2, t5_sequence_length, device)
# text_ids
text_ids = torch.zeros(prompt_emb.shape[0], prompt_emb.shape[1], 3).to(device=device, dtype=prompt_emb.dtype)

View File

@@ -0,0 +1,49 @@
import torch
from diffsynth import ModelManager, FluxImagePipeline, download_models
from diffsynth.controlnets.processors import Annotator
import numpy as np
from PIL import Image
download_models(["FLUX.1-dev"])
model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cuda")
model_manager.load_models([
"models/FLUX/FLUX.1-dev/text_encoder/model.safetensors",
"models/FLUX/FLUX.1-dev/text_encoder_2",
"models/FLUX/FLUX.1-dev/ae.safetensors",
"models/ostris/Flex.2-preview/Flex.2-preview.safetensors"
])
pipe = FluxImagePipeline.from_model_manager(model_manager)
image = pipe(
prompt="portrait of a beautiful Asian girl, long hair, red t-shirt, sunshine, beach",
num_inference_steps=50, embedded_guidance=3.5,
seed=0
)
image.save("image_1.jpg")
mask = np.zeros((1024, 1024, 3), dtype=np.uint8)
mask[200:400, 400:700] = 255
mask = Image.fromarray(mask)
mask.save("image_mask.jpg")
inpaint_image = image
image = pipe(
prompt="portrait of a beautiful Asian girl with sunglasses, long hair, red t-shirt, sunshine, beach",
num_inference_steps=50, embedded_guidance=3.5,
flex_inpaint_image=inpaint_image, flex_inpaint_mask=mask,
seed=4
)
image.save("image_2.jpg")
control_image = Annotator("canny")(image)
control_image.save("image_control.jpg")
image = pipe(
prompt="portrait of a beautiful Asian girl with sunglasses, long hair, yellow t-shirt, sunshine, beach",
num_inference_steps=50, embedded_guidance=3.5,
flex_control_image=control_image,
seed=4
)
image.save("image_3.jpg")

35
examples/step1x/step1x.py Normal file
View File

@@ -0,0 +1,35 @@
import torch
from diffsynth import FluxImagePipeline, ModelManager
from modelscope import snapshot_download
from PIL import Image
import numpy as np
snapshot_download("Qwen/Qwen2.5-VL-7B-Instruct", cache_dir="./models")
snapshot_download("stepfun-ai/Step1X-Edit", cache_dir="./models")
model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cuda")
model_manager.load_models([
"models/Qwen/Qwen2.5-VL-7B-Instruct",
"models/stepfun-ai/Step1X-Edit/step1x-edit-i1258.safetensors",
"models/stepfun-ai/Step1X-Edit/vae.safetensors",
])
pipe = FluxImagePipeline.from_model_manager(model_manager)
pipe.enable_vram_management()
image = Image.fromarray(np.zeros((1248, 832, 3), dtype=np.uint8) + 255)
image = pipe(
prompt="draw red flowers in Chinese ink painting style",
step1x_reference_image=image,
width=832, height=1248, cfg_scale=6,
seed=1,
)
image.save("image_1.jpg")
image = pipe(
prompt="add more flowers in Chinese ink painting style",
step1x_reference_image=image,
width=832, height=1248, cfg_scale=6,
seed=2,
)
image.save("image_2.jpg")

View File

@@ -10,20 +10,93 @@ cd DiffSynth-Studio
pip install -e .
```
Wan-Video supports multiple Attention implementations. If you have installed any of the following Attention implementations, they will be enabled based on priority.
## Model Zoo
|Developer|Name|Link|Scripts|
|-|-|-|-|
|Wan Team|1.3B text-to-video|[Link](https://modelscope.cn/models/Wan-AI/Wan2.1-T2V-1.3B)|[wan_1.3b_text_to_video.py](./wan_1.3b_text_to_video.py)|
|Wan Team|14B text-to-video|[Link](https://modelscope.cn/models/Wan-AI/Wan2.1-T2V-14B)|[wan_14b_text_to_video.py](./wan_14b_text_to_video.py)|
|Wan Team|14B image-to-video 480P|[Link](https://modelscope.cn/models/Wan-AI/Wan2.1-I2V-14B-480P)|[wan_14b_image_to_video.py](./wan_14b_image_to_video.py)|
|Wan Team|14B image-to-video 720P|[Link](https://modelscope.cn/models/Wan-AI/Wan2.1-I2V-14B-720P)|[wan_14b_image_to_video.py](./wan_14b_image_to_video.py)|
|Wan Team|14B first-last-frame-to-video 720P|[Link](https://modelscope.cn/models/Wan-AI/Wan2.1-FLF2V-14B-720P)|[wan_14B_flf2v.py](./wan_14B_flf2v.py)|
|DiffSynth-Studio Team|1.3B aesthetics LoRA|[Link](https://modelscope.cn/models/DiffSynth-Studio/Wan2.1-1.3b-lora-aesthetics-v1)|Please see the [model card](https://modelscope.cn/models/DiffSynth-Studio/Wan2.1-1.3b-lora-aesthetics-v1).|
|DiffSynth-Studio Team|1.3B Highres-fix LoRA|[Link](https://modelscope.cn/models/DiffSynth-Studio/Wan2.1-1.3b-lora-highresfix-v1)|Please see the [model card](https://modelscope.cn/models/DiffSynth-Studio/Wan2.1-1.3b-lora-highresfix-v1).|
|DiffSynth-Studio Team|1.3B ExVideo LoRA|[Link](https://modelscope.cn/models/DiffSynth-Studio/Wan2.1-1.3b-lora-exvideo-v1)|Please see the [model card](https://modelscope.cn/models/DiffSynth-Studio/Wan2.1-1.3b-lora-exvideo-v1).|
|DiffSynth-Studio Team|1.3B Speed Control adapter|[Link](https://modelscope.cn/models/DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1)|[wan_1.3b_motion_controller.py](./wan_1.3b_motion_controller.py)|
|PAI Team|1.3B InP|[Link](https://modelscope.cn/models/PAI/Wan2.1-Fun-1.3B-InP)|[wan_fun_InP.py](./wan_fun_InP.py)|
|PAI Team|14B InP|[Link](https://modelscope.cn/models/PAI/Wan2.1-Fun-14B-InP)|[wan_fun_InP.py](./wan_fun_InP.py)|
|PAI Team|1.3B Control|[Link](https://modelscope.cn/models/PAI/Wan2.1-Fun-1.3B-Control)|[wan_fun_control.py](./wan_fun_control.py)|
|PAI Team|14B Control|[Link](https://modelscope.cn/models/PAI/Wan2.1-Fun-14B-Control)|[wan_fun_control.py](./wan_fun_control.py)|
|IIC Team|1.3B VACE|[Link](https://modelscope.cn/models/iic/VACE-Wan2.1-1.3B-Preview)|[wan_1.3b_vace.py](./wan_1.3b_vace.py)|
Base model features
||Text-to-video|Image-to-video|End frame|Control|Reference image|
|-|-|-|-|-|-|
|1.3B text-to-video|✅|||||
|14B text-to-video|✅|||||
|14B image-to-video 480P||✅||||
|14B image-to-video 720P||✅||||
|14B first-last-frame-to-video 720P||✅|✅|||
|1.3B InP||✅|✅|||
|14B InP||✅|✅|||
|1.3B Control||||✅||
|14B Control||||✅||
|1.3B VACE||||✅|✅|
Adapter model compatibility
||1.3B text-to-video|1.3B InP|1.3B VACE|
|-|-|-|-|
|1.3B aesthetics LoRA|✅||✅|
|1.3B Highres-fix LoRA|✅||✅|
|1.3B ExVideo LoRA|✅||✅|
|1.3B Speed Control adapter|✅|✅|✅|
## VRAM Usage
* Fine-grained offload: We recommend that users adjust the `num_persistent_param_in_dit` settings to find an optimal balance between speed and VRAM requirements. See [`./wan_14b_text_to_video.py`](./wan_14b_text_to_video.py).
* FP8 Quantization: You only need to adjust the `torch_dtype` in the `ModelManager` (not the pipeline!).
We present a detailed table here. The model (14B text-to-video) is tested on a single A100.
|`torch_dtype`|`num_persistent_param_in_dit`|Speed|Required VRAM|Default Setting|
|-|-|-|-|-|
|torch.bfloat16|None (unlimited)|18.5s/it|48G||
|torch.bfloat16|7*10**9 (7B)|20.8s/it|24G||
|torch.bfloat16|0|23.4s/it|10G||
|torch.float8_e4m3fn|None (unlimited)|18.3s/it|24G|yes|
|torch.float8_e4m3fn|0|24.0s/it|10G||
**We found that 14B image-to-video model is more sensitive to precision, so when the generated video content experiences issues such as artifacts, please switch to bfloat16 precision and use the `num_persistent_param_in_dit` parameter to control VRAM usage.**
## Efficient Attention Implementation
DiffSynth-Studio supports multiple Attention implementations. If you have installed any of the following Attention implementations, they will be enabled based on priority. However, we recommend to use the default torch SDPA.
* [Flash Attention 3](https://github.com/Dao-AILab/flash-attention)
* [Flash Attention 2](https://github.com/Dao-AILab/flash-attention)
* [Sage Attention](https://github.com/thu-ml/SageAttention)
* [torch SDPA](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html) (default. `torch>=2.5.0` is recommended.)
## Inference
## Acceleration
### Wan-Video-1.3B-T2V
We support multiple acceleration solutions:
* [TeaCache](https://github.com/ali-vilab/TeaCache): See [wan_1.3b_text_to_video_accelerate.py](./wan_1.3b_text_to_video_accelerate.py).
Wan-Video-1.3B-T2V supports text-to-video and video-to-video. See [`./wan_1.3b_text_to_video.py`](./wan_1.3b_text_to_video.py).
* [Unified Sequence Parallel](https://github.com/xdit-project/xDiT): See [wan_14b_text_to_video_usp.py](./wan_14b_text_to_video_usp.py)
Required VRAM: 6G
```bash
pip install xfuser>=0.4.3
torchrun --standalone --nproc_per_node=8 examples/wanvideo/wan_14b_text_to_video_usp.py
```
* Tensor Parallel: See [wan_14b_text_to_video_tensor_parallel.py](./wan_14b_text_to_video_tensor_parallel.py).
## Gallery
1.3B text-to-video.
https://github.com/user-attachments/assets/124397be-cd6a-4f29-a87c-e4c695aaabb8
@@ -31,50 +104,20 @@ Put sunglasses on the dog.
https://github.com/user-attachments/assets/272808d7-fbeb-4747-a6df-14a0860c75fb
[TeaCache](https://github.com/ali-vilab/TeaCache) is supported in both T2V and I2V models. It can significantly improve the efficiency. See [`./wan_1.3b_text_to_video_accelerate.py`](./wan_1.3b_text_to_video_accelerate.py).
### Wan-Video-14B-T2V
Wan-Video-14B-T2V is an enhanced version of Wan-Video-1.3B-T2V, offering greater size and power. To utilize this model, you need additional VRAM. We recommend that users adjust the `torch_dtype` and `num_persistent_param_in_dit` settings to find an optimal balance between speed and VRAM requirements. See [`./wan_14b_text_to_video.py`](./wan_14b_text_to_video.py).
We present a detailed table here. The model is tested on a single A100.
|`torch_dtype`|`num_persistent_param_in_dit`|Speed|Required VRAM|Default Setting|
|-|-|-|-|-|
|torch.bfloat16|None (unlimited)|18.5s/it|40G||
|torch.bfloat16|7*10**9 (7B)|20.8s/it|24G||
|torch.bfloat16|0|23.4s/it|10G||
|torch.float8_e4m3fn|None (unlimited)|18.3s/it|24G|yes|
|torch.float8_e4m3fn|0|24.0s/it|10G||
14B text-to-video.
https://github.com/user-attachments/assets/3908bc64-d451-485a-8b61-28f6d32dd92f
### Parallel Inference
1. Unified Sequence Parallel (USP)
```bash
pip install xfuser>=0.4.3
```
```bash
torchrun --standalone --nproc_per_node=8 examples/wanvideo/wan_14b_text_to_video_usp.py
```
2. Tensor Parallel
Tensor parallel module of Wan-Video-14B-T2V is still under development. An example script is provided in [`./wan_14b_text_to_video_tensor_parallel.py`](./wan_14b_text_to_video_tensor_parallel.py).
### Wan-Video-14B-I2V
Wan-Video-14B-I2V adds the functionality of image-to-video based on Wan-Video-14B-T2V. The model size remains the same, therefore the speed and VRAM requirements are also consistent. See [`./wan_14b_image_to_video.py`](./wan_14b_image_to_video.py).
**In the sample code, we use the same settings as the T2V 14B model, with FP8 quantization enabled by default. However, we found that this model is more sensitive to precision, so when the generated video content experiences issues such as artifacts, please switch to bfloat16 precision and use the `num_persistent_param_in_dit` parameter to control VRAM usage.**
![Image](https://github.com/user-attachments/assets/adf8047f-7943-4aaa-a555-2b32dc415f39)
14B image-to-video.
https://github.com/user-attachments/assets/c0bdd5ca-292f-45ed-b9bc-afe193156e75
14B first-last-frame-to-video
|First frame|Last frame|Video|
|-|-|-|
|![Image](https://github.com/user-attachments/assets/b0d8225b-aee0-4129-b8e5-58c8523221a6)|![Image](https://github.com/user-attachments/assets/2f0c9bc5-07e2-45fa-8320-53d63a4fd203)|https://github.com/user-attachments/assets/2a6a2681-622c-4512-b852-5f22e73830b1|
## Train
We support Wan-Video LoRA training and full training. Here is a tutorial. This is an experimental feature. Below is a video sample generated from the character Keqing LoRA:

View File

@@ -56,13 +56,16 @@ class TextVideoDataset(torch.utils.data.Dataset):
frame = Image.fromarray(frame)
frame = self.crop_and_resize(frame)
if first_frame is None:
first_frame = np.array(frame)
first_frame = frame
frame = frame_process(frame)
frames.append(frame)
reader.close()
frames = torch.stack(frames, dim=0)
frames = rearrange(frames, "T C H W -> C T H W")
first_frame = v2.functional.center_crop(first_frame, output_size=(self.height, self.width))
first_frame = np.array(first_frame)
if self.is_i2v:
return frames, first_frame
@@ -140,7 +143,7 @@ class LightningModelForDataProcess(pl.LightningModule):
if "first_frame" in batch:
first_frame = Image.fromarray(batch["first_frame"][0].cpu().numpy())
_, _, num_frames, height, width = video.shape
image_emb = self.pipe.encode_image(first_frame, num_frames, height, width)
image_emb = self.pipe.encode_image(first_frame, None, num_frames, height, width)
else:
image_emb = {}
data = {"latents": latents, "prompt_emb": prompt_emb, "image_emb": image_emb}

View File

@@ -0,0 +1,41 @@
import torch
from diffsynth import ModelManager, WanVideoPipeline, save_video, VideoData
from modelscope import snapshot_download
# Download models
snapshot_download("Wan-AI/Wan2.1-T2V-1.3B", local_dir="models/Wan-AI/Wan2.1-T2V-1.3B")
snapshot_download("DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1", local_dir="models/DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1")
# Load models
model_manager = ModelManager(device="cpu")
model_manager.load_models(
[
"models/Wan-AI/Wan2.1-T2V-1.3B/diffusion_pytorch_model.safetensors",
"models/Wan-AI/Wan2.1-T2V-1.3B/models_t5_umt5-xxl-enc-bf16.pth",
"models/Wan-AI/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth",
"models/DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1/model.safetensors",
],
torch_dtype=torch.bfloat16, # You can set `torch_dtype=torch.float8_e4m3fn` to enable FP8 quantization.
)
pipe = WanVideoPipeline.from_model_manager(model_manager, torch_dtype=torch.bfloat16, device="cuda")
pipe.enable_vram_management(num_persistent_param_in_dit=None)
# Text-to-video
video = pipe(
prompt="纪实摄影风格画面,一只活泼的小狗在绿茵茵的草地上迅速奔跑。小狗毛色棕黄,两只耳朵立起,神情专注而欢快。阳光洒在它身上,使得毛发看上去格外柔软而闪亮。背景是一片开阔的草地,偶尔点缀着几朵野花,远处隐约可见蓝天和几片白云。透视感鲜明,捕捉小狗奔跑时的动感和四周草地的生机。中景侧面移动视角。",
negative_prompt="色调艳丽过曝静态细节模糊不清字幕风格作品画作画面静止整体发灰最差质量低质量JPEG压缩残留丑陋的残缺的多余的手指画得不好的手部画得不好的脸部畸形的毁容的形态畸形的肢体手指融合静止不动的画面杂乱的背景三条腿背景人很多倒着走",
num_inference_steps=50,
seed=1, tiled=True,
motion_bucket_id=0
)
save_video(video, "video_slow.mp4", fps=15, quality=5)
video = pipe(
prompt="纪实摄影风格画面,一只活泼的小狗在绿茵茵的草地上迅速奔跑。小狗毛色棕黄,两只耳朵立起,神情专注而欢快。阳光洒在它身上,使得毛发看上去格外柔软而闪亮。背景是一片开阔的草地,偶尔点缀着几朵野花,远处隐约可见蓝天和几片白云。透视感鲜明,捕捉小狗奔跑时的动感和四周草地的生机。中景侧面移动视角。",
negative_prompt="色调艳丽过曝静态细节模糊不清字幕风格作品画作画面静止整体发灰最差质量低质量JPEG压缩残留丑陋的残缺的多余的手指画得不好的手部画得不好的脸部畸形的毁容的形态畸形的肢体手指融合静止不动的画面杂乱的背景三条腿背景人很多倒着走",
num_inference_steps=50,
seed=1, tiled=True,
motion_bucket_id=100
)
save_video(video, "video_fast.mp4", fps=15, quality=5)

View File

@@ -0,0 +1,63 @@
import torch
from diffsynth import ModelManager, WanVideoPipeline, save_video, VideoData
from modelscope import snapshot_download, dataset_snapshot_download
from PIL import Image
# Download models
snapshot_download("iic/VACE-Wan2.1-1.3B-Preview", local_dir="models/iic/VACE-Wan2.1-1.3B-Preview")
# Load models
model_manager = ModelManager(device="cpu")
model_manager.load_models(
[
"models/iic/VACE-Wan2.1-1.3B-Preview/diffusion_pytorch_model.safetensors",
"models/iic/VACE-Wan2.1-1.3B-Preview/models_t5_umt5-xxl-enc-bf16.pth",
"models/iic/VACE-Wan2.1-1.3B-Preview/Wan2.1_VAE.pth",
],
torch_dtype=torch.bfloat16,
)
pipe = WanVideoPipeline.from_model_manager(model_manager, torch_dtype=torch.bfloat16, device="cuda")
pipe.enable_vram_management(num_persistent_param_in_dit=None)
# Download example video
dataset_snapshot_download(
dataset_id="DiffSynth-Studio/examples_in_diffsynth",
local_dir="./",
allow_file_pattern=["data/examples/wan/depth_video.mp4", "data/examples/wan/cat_fightning.jpg"]
)
# Depth video -> Video
control_video = VideoData("data/examples/wan/depth_video.mp4", height=480, width=832)
video = pipe(
prompt="两只可爱的橘猫戴上拳击手套,站在一个拳击台上搏斗。",
negative_prompt="色调艳丽过曝静态细节模糊不清字幕风格作品画作画面静止整体发灰最差质量低质量JPEG压缩残留丑陋的残缺的多余的手指画得不好的手部画得不好的脸部畸形的毁容的形态畸形的肢体手指融合静止不动的画面杂乱的背景三条腿背景人很多倒着走",
num_inference_steps=50,
height=480, width=832, num_frames=81,
vace_video=control_video,
seed=1, tiled=True
)
save_video(video, "video1.mp4", fps=15, quality=5)
# Reference image -> Video
video = pipe(
prompt="两只可爱的橘猫戴上拳击手套,站在一个拳击台上搏斗。",
negative_prompt="色调艳丽过曝静态细节模糊不清字幕风格作品画作画面静止整体发灰最差质量低质量JPEG压缩残留丑陋的残缺的多余的手指画得不好的手部画得不好的脸部畸形的毁容的形态畸形的肢体手指融合静止不动的画面杂乱的背景三条腿背景人很多倒着走",
num_inference_steps=50,
height=480, width=832, num_frames=81,
vace_reference_image=Image.open("data/examples/wan/cat_fightning.jpg").resize((832, 480)),
seed=1, tiled=True
)
save_video(video, "video2.mp4", fps=15, quality=5)
# Depth video + Reference image -> Video
video = pipe(
prompt="两只可爱的橘猫戴上拳击手套,站在一个拳击台上搏斗。",
negative_prompt="色调艳丽过曝静态细节模糊不清字幕风格作品画作画面静止整体发灰最差质量低质量JPEG压缩残留丑陋的残缺的多余的手指画得不好的手部画得不好的脸部畸形的毁容的形态畸形的肢体手指融合静止不动的画面杂乱的背景三条腿背景人很多倒着走",
num_inference_steps=50,
height=480, width=832, num_frames=81,
vace_video=control_video,
vace_reference_image=Image.open("data/examples/wan/cat_fightning.jpg").resize((832, 480)),
seed=1, tiled=True
)
save_video(video, "video3.mp4", fps=15, quality=5)

View File

@@ -0,0 +1,52 @@
import torch
from diffsynth import ModelManager, WanVideoPipeline, save_video, VideoData
from modelscope import snapshot_download, dataset_snapshot_download
from PIL import Image
# Download models
snapshot_download("Wan-AI/Wan2.1-FLF2V-14B-720P", local_dir="models/Wan-AI/Wan2.1-FLF2V-14B-720P")
# Load models
model_manager = ModelManager(device="cpu")
model_manager.load_models(
["models/Wan-AI/Wan2.1-FLF2V-14B-720P/models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth"],
torch_dtype=torch.float32, # Image Encoder is loaded with float32
)
model_manager.load_models(
[
[
"models/Wan-AI/Wan2.1-FLF2V-14B-720P/diffusion_pytorch_model-00001-of-00007.safetensors",
"models/Wan-AI/Wan2.1-FLF2V-14B-720P/diffusion_pytorch_model-00002-of-00007.safetensors",
"models/Wan-AI/Wan2.1-FLF2V-14B-720P/diffusion_pytorch_model-00003-of-00007.safetensors",
"models/Wan-AI/Wan2.1-FLF2V-14B-720P/diffusion_pytorch_model-00004-of-00007.safetensors",
"models/Wan-AI/Wan2.1-FLF2V-14B-720P/diffusion_pytorch_model-00005-of-00007.safetensors",
"models/Wan-AI/Wan2.1-FLF2V-14B-720P/diffusion_pytorch_model-00006-of-00007.safetensors",
"models/Wan-AI/Wan2.1-FLF2V-14B-720P/diffusion_pytorch_model-00007-of-00007.safetensors",
],
"models/Wan-AI/Wan2.1-FLF2V-14B-720P/models_t5_umt5-xxl-enc-bf16.pth",
"models/Wan-AI/Wan2.1-FLF2V-14B-720P/Wan2.1_VAE.pth",
],
torch_dtype=torch.bfloat16, # You can set `torch_dtype=torch.float8_e4m3fn` to enable FP8 quantization.
)
pipe = WanVideoPipeline.from_model_manager(model_manager, torch_dtype=torch.bfloat16, device="cuda")
pipe.enable_vram_management(num_persistent_param_in_dit=None)
# Download example image
dataset_snapshot_download(
dataset_id="DiffSynth-Studio/examples_in_diffsynth",
local_dir="./",
allow_file_pattern=["data/examples/wan/first_frame.jpeg", "data/examples/wan/last_frame.jpeg"]
)
# First and last frame to video
video = pipe(
prompt="写实风格,一个女生手持枯萎的花站在花园中,镜头逐渐拉远,记录下花园的全貌。",
negative_prompt="色调艳丽过曝静态细节模糊不清字幕风格作品画作画面静止整体发灰最差质量低质量JPEG压缩残留丑陋的残缺的多余的手指画得不好的手部画得不好的脸部畸形的毁容的形态畸形的肢体手指融合静止不动的画面杂乱的背景三条腿背景人很多倒着走",
num_inference_steps=30,
input_image=Image.open("data/examples/wan/first_frame.jpeg").resize((960, 960)),
end_image=Image.open("data/examples/wan/last_frame.jpeg").resize((960, 960)),
height=960, width=960,
seed=1, tiled=True
)
save_video(video, "video.mp4", fps=15, quality=5)

View File

@@ -0,0 +1,42 @@
import torch
from diffsynth import ModelManager, WanVideoPipeline, save_video, VideoData
from modelscope import snapshot_download, dataset_snapshot_download
from PIL import Image
# Download models
snapshot_download("PAI/Wan2.1-Fun-1.3B-InP", local_dir="models/PAI/Wan2.1-Fun-1.3B-InP")
# Load models
model_manager = ModelManager(device="cpu")
model_manager.load_models(
[
"models/PAI/Wan2.1-Fun-1.3B-InP/diffusion_pytorch_model.safetensors",
"models/PAI/Wan2.1-Fun-1.3B-InP/models_t5_umt5-xxl-enc-bf16.pth",
"models/PAI/Wan2.1-Fun-1.3B-InP/Wan2.1_VAE.pth",
"models/PAI/Wan2.1-Fun-1.3B-InP/models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth",
],
torch_dtype=torch.bfloat16, # You can set `torch_dtype=torch.float8_e4m3fn` to enable FP8 quantization.
)
pipe = WanVideoPipeline.from_model_manager(model_manager, torch_dtype=torch.bfloat16, device="cuda")
pipe.enable_vram_management(num_persistent_param_in_dit=None)
# Download example image
dataset_snapshot_download(
dataset_id="DiffSynth-Studio/examples_in_diffsynth",
local_dir="./",
allow_file_pattern=f"data/examples/wan/input_image.jpg"
)
image = Image.open("data/examples/wan/input_image.jpg")
# Image-to-video
video = pipe(
prompt="一艘小船正勇敢地乘风破浪前行。蔚蓝的大海波涛汹涌,白色的浪花拍打着船身,但小船毫不畏惧,坚定地驶向远方。阳光洒在水面上,闪烁着金色的光芒,为这壮丽的场景增添了一抹温暖。镜头拉近,可以看到船上的旗帜迎风飘扬,象征着不屈的精神与冒险的勇气。这段画面充满力量,激励人心,展现了面对挑战时的无畏与执着。",
negative_prompt="色调艳丽过曝静态细节模糊不清字幕风格作品画作画面静止整体发灰最差质量低质量JPEG压缩残留丑陋的残缺的多余的手指画得不好的手部画得不好的脸部畸形的毁容的形态畸形的肢体手指融合静止不动的画面杂乱的背景三条腿背景人很多倒着走",
num_inference_steps=50,
input_image=image,
# You can input `end_image=xxx` to control the last frame of the video.
# The model will automatically generate the dynamic content between `input_image` and `end_image`.
seed=1, tiled=True
)
save_video(video, "video1.mp4", fps=15, quality=5)

View File

@@ -0,0 +1,40 @@
import torch
from diffsynth import ModelManager, WanVideoPipeline, save_video, VideoData
from modelscope import snapshot_download, dataset_snapshot_download
from PIL import Image
# Download models
snapshot_download("PAI/Wan2.1-Fun-1.3B-Control", local_dir="models/PAI/Wan2.1-Fun-1.3B-Control")
# Load models
model_manager = ModelManager(device="cpu")
model_manager.load_models(
[
"models/PAI/Wan2.1-Fun-1.3B-Control/diffusion_pytorch_model.safetensors",
"models/PAI/Wan2.1-Fun-1.3B-Control/models_t5_umt5-xxl-enc-bf16.pth",
"models/PAI/Wan2.1-Fun-1.3B-Control/Wan2.1_VAE.pth",
"models/PAI/Wan2.1-Fun-1.3B-Control/models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth",
],
torch_dtype=torch.bfloat16, # You can set `torch_dtype=torch.float8_e4m3fn` to enable FP8 quantization.
)
pipe = WanVideoPipeline.from_model_manager(model_manager, torch_dtype=torch.bfloat16, device="cuda")
pipe.enable_vram_management(num_persistent_param_in_dit=None)
# Download example video
dataset_snapshot_download(
dataset_id="DiffSynth-Studio/examples_in_diffsynth",
local_dir="./",
allow_file_pattern=f"data/examples/wan/control_video.mp4"
)
# Control-to-video
control_video = VideoData("data/examples/wan/control_video.mp4", height=832, width=576)
video = pipe(
prompt="扁平风格动漫一位长发少女优雅起舞。她五官精致大眼睛明亮有神黑色长发柔顺光泽。身穿淡蓝色T恤和深蓝色牛仔短裤。背景是粉色。",
negative_prompt="色调艳丽过曝静态细节模糊不清字幕风格作品画作画面静止整体发灰最差质量低质量JPEG压缩残留丑陋的残缺的多余的手指画得不好的手部画得不好的脸部畸形的毁容的形态畸形的肢体手指融合静止不动的画面杂乱的背景三条腿背景人很多倒着走",
num_inference_steps=50,
control_video=control_video, height=832, width=576, num_frames=49,
seed=1, tiled=True
)
save_video(video, "video1.mp4", fps=15, quality=5)

0
modeling/ar/__init__.py Normal file
View File

View File

@@ -0,0 +1,258 @@
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# This file was automatically generated from src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py.
# Do NOT edit this file manually as any edits will be overwritten by the generation of
# the file from the modular. If any change should be done, please apply the change to the
# modular_qwen2_5_vl.py file directly. One of our CI enforces this.
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# coding=utf-8
# Copyright 2025 The Qwen Team and The HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from transformers.configuration_utils import PretrainedConfig
from transformers.modeling_rope_utils import rope_config_validation
class Qwen2_5_VLVisionConfig(PretrainedConfig):
model_type = "qwen2_5_vl"
base_config_key = "vision_config"
def __init__(
self,
depth=32,
hidden_size=3584,
hidden_act="silu",
intermediate_size=3420,
num_heads=16,
in_channels=3,
patch_size=14,
spatial_merge_size=2,
temporal_patch_size=2,
tokens_per_second=4,
window_size=112,
out_hidden_size=3584,
fullatt_block_indexes=[7, 15, 23, 31],
**kwargs,
):
super().__init__(**kwargs)
self.depth = depth
self.hidden_size = hidden_size
self.hidden_act = hidden_act
self.intermediate_size = intermediate_size
self.num_heads = num_heads
self.in_channels = in_channels
self.patch_size = patch_size
self.spatial_merge_size = spatial_merge_size
self.temporal_patch_size = temporal_patch_size
self.tokens_per_second = tokens_per_second
self.window_size = window_size
self.fullatt_block_indexes = fullatt_block_indexes
self.out_hidden_size = out_hidden_size
class Qwen2_5_VLConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`Qwen2_5_VLModel`]. It is used to instantiate a
Qwen2-VL model according to the specified arguments, defining the model architecture. Instantiating a configuration
with the defaults will yield a similar configuration to that of
Qwen2-VL-7B-Instruct [Qwen/Qwen2-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct).
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
vocab_size (`int`, *optional*, defaults to 152064):
Vocabulary size of the Qwen2_5_VL model. Defines the number of different tokens that can be represented by the
`inputs_ids` passed when calling [`Qwen2_5_VLModel`]
hidden_size (`int`, *optional*, defaults to 8192):
Dimension of the hidden representations.
intermediate_size (`int`, *optional*, defaults to 29568):
Dimension of the MLP representations.
num_hidden_layers (`int`, *optional*, defaults to 80):
Number of hidden layers in the Transformer encoder.
num_attention_heads (`int`, *optional*, defaults to 64):
Number of attention heads for each attention layer in the Transformer encoder.
num_key_value_heads (`int`, *optional*, defaults to 8):
This is the number of key_value heads that should be used to implement Grouped Query Attention. If
`num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
`num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
by meanpooling all the original heads within that group. For more details checkout [this
paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `32`.
hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
The non-linear activation function (function or string) in the decoder.
max_position_embeddings (`int`, *optional*, defaults to 32768):
The maximum sequence length that this model might ever be used with.
initializer_range (`float`, *optional*, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
rms_norm_eps (`float`, *optional*, defaults to 1e-05):
The epsilon used by the rms normalization layers.
use_cache (`bool`, *optional*, defaults to `True`):
Whether or not the model should return the last key/values attentions (not used by all models). Only
relevant if `config.is_decoder=True`.
tie_word_embeddings (`bool`, *optional*, defaults to `False`):
Whether the model's input and output word embeddings should be tied.
rope_theta (`float`, *optional*, defaults to 1000000.0):
The base period of the RoPE embeddings.
use_sliding_window (`bool`, *optional*, defaults to `False`):
Whether to use sliding window attention.
sliding_window (`int`, *optional*, defaults to 4096):
Sliding window attention (SWA) window size. If not specified, will default to `4096`.
max_window_layers (`int`, *optional*, defaults to 80):
The number of layers that use SWA (Sliding Window Attention). The bottom layers use SWA while the top use full attention.
attention_dropout (`float`, *optional*, defaults to 0.0):
The dropout ratio for the attention probabilities.
vision_config (`Dict`, *optional*):
The config for the visual encoder initialization.
rope_scaling (`Dict`, *optional*):
Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
accordingly.
Expected contents:
`rope_type` (`str`):
The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
'llama3'], with 'default' being the original RoPE implementation.
`factor` (`float`, *optional*):
Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
most scaling types, a `factor` of x will enable the model to handle sequences of length x *
original maximum pre-trained length.
`original_max_position_embeddings` (`int`, *optional*):
Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
pretraining.
`attention_factor` (`float`, *optional*):
Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
computation. If unspecified, it defaults to value recommended by the implementation, using the
`factor` field to infer the suggested value.
`beta_fast` (`float`, *optional*):
Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
ramp function. If unspecified, it defaults to 32.
`beta_slow` (`float`, *optional*):
Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
ramp function. If unspecified, it defaults to 1.
`short_factor` (`List[float]`, *optional*):
Only used with 'longrope'. The scaling factor to be applied to short contexts (<
`original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
size divided by the number of attention heads divided by 2
`long_factor` (`List[float]`, *optional*):
Only used with 'longrope'. The scaling factor to be applied to long contexts (<
`original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
size divided by the number of attention heads divided by 2
`low_freq_factor` (`float`, *optional*):
Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
`high_freq_factor` (`float`, *optional*):
Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
```python
>>> from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2_5_VLConfig
>>> # Initializing a Qwen2_5_VL style configuration
>>> configuration = Qwen2_5_VLConfig()
>>> # Initializing a model from the Qwen2-VL-7B style configuration
>>> model = Qwen2_5_VLForConditionalGeneration(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```"""
model_type = "qwen2_5_vl"
sub_configs = {"vision_config": Qwen2_5_VLVisionConfig}
keys_to_ignore_at_inference = ["past_key_values"]
# Default tensor parallel plan for base model `Qwen2_5_VL`
base_model_tp_plan = {
"layers.*.self_attn.q_proj": "colwise",
"layers.*.self_attn.k_proj": "colwise",
"layers.*.self_attn.v_proj": "colwise",
"layers.*.self_attn.o_proj": "rowwise",
"layers.*.mlp.gate_proj": "colwise",
"layers.*.mlp.up_proj": "colwise",
"layers.*.mlp.down_proj": "rowwise",
}
base_model_pp_plan = {
"embed_tokens": (["input_ids"], ["inputs_embeds"]),
"layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
"norm": (["hidden_states"], ["hidden_states"]),
}
def __init__(
self,
vocab_size=152064,
hidden_size=8192,
intermediate_size=29568,
num_hidden_layers=80,
num_attention_heads=64,
num_key_value_heads=8,
hidden_act="silu",
max_position_embeddings=32768,
initializer_range=0.02,
rms_norm_eps=1e-05,
use_cache=True,
tie_word_embeddings=False,
rope_theta=1000000.0,
use_sliding_window=False,
sliding_window=4096,
max_window_layers=80,
attention_dropout=0.0,
vision_config=None,
rope_scaling=None,
**kwargs,
):
if isinstance(vision_config, dict):
self.vision_config = self.sub_configs["vision_config"](**vision_config)
elif vision_config is None:
self.vision_config = self.sub_configs["vision_config"]()
self.vocab_size = vocab_size
self.max_position_embeddings = max_position_embeddings
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.use_sliding_window = use_sliding_window
self.sliding_window = sliding_window
self.max_window_layers = max_window_layers
# for backward compatibility
if num_key_value_heads is None:
num_key_value_heads = num_attention_heads
self.num_key_value_heads = num_key_value_heads
self.hidden_act = hidden_act
self.initializer_range = initializer_range
self.rms_norm_eps = rms_norm_eps
self.use_cache = use_cache
self.rope_theta = rope_theta
self.attention_dropout = attention_dropout
self.rope_scaling = rope_scaling
# Validate the correctness of rotary position embeddings parameters
# BC: if there is a 'type' field, move it to 'rope_type'.
# and change type from 'mrope' to 'default' because `mrope` does default RoPE calculations
# one can set it to "linear"/"dynamic" etc. to have scaled RoPE
# TODO: @raushan update config in the hub
if self.rope_scaling is not None and "type" in self.rope_scaling:
if self.rope_scaling["type"] == "mrope":
self.rope_scaling["type"] = "default"
self.rope_scaling["rope_type"] = self.rope_scaling["type"]
rope_config_validation(self, ignore_keys={"mrope_section"})
super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
__all__ = ["Qwen2_5_VLConfig"]

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,235 @@
# Copyright 2025 The Qwen Team and The HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import re
from typing import List, Union
from transformers.feature_extraction_utils import BatchFeature
from transformers.image_utils import ImageInput, VideoInput
from transformers.processing_utils import ProcessingKwargs, ProcessorMixin, Unpack, VideosKwargs
from transformers.tokenization_utils_base import PreTokenizedInput, TextInput
class Qwen2_5_VLVideosProcessorKwargs(VideosKwargs, total=False):
fps: Union[List[float], float]
class Qwen2_5_VLProcessorKwargs(ProcessingKwargs, total=False):
videos_kwargs: Qwen2_5_VLVideosProcessorKwargs
_defaults = {
"text_kwargs": {
"padding": False,
},
"videos_kwargs": {"fps": 2.0},
}
class Qwen2_5_VLProcessor(ProcessorMixin):
r"""
Constructs a Qwen2.5-VL processor which wraps a Qwen2.5-VL image processor and a Qwen2 tokenizer into a single processor.
[`Qwen2_5_VLProcessor`] offers all the functionalities of [`Qwen2VLImageProcessor`] and [`Qwen2TokenizerFast`]. See the
[`~Qwen2_5_VLProcessor.__call__`] and [`~Qwen2_5_VLProcessor.decode`] for more information.
Args:
image_processor ([`Qwen2VLImageProcessor`], *optional*):
The image processor is a required input.
tokenizer ([`Qwen2TokenizerFast`], *optional*):
The tokenizer is a required input.
chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages
in a chat into a tokenizable string.
"""
attributes = ["image_processor", "tokenizer"]
valid_kwargs = ["chat_template"]
image_processor_class = "AutoImageProcessor"
tokenizer_class = ("Qwen2Tokenizer", "Qwen2TokenizerFast")
def __init__(self, image_processor=None, tokenizer=None, chat_template=None, **kwargs):
self.image_token = "<|image_pad|>" if not hasattr(tokenizer, "image_token") else tokenizer.image_token
self.video_token = "<|video_pad|>" if not hasattr(tokenizer, "video_token") else tokenizer.video_token
super().__init__(image_processor, tokenizer, chat_template=chat_template)
def __call__(
self,
images: ImageInput = None,
text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
videos: VideoInput = None,
**kwargs: Unpack[Qwen2_5_VLProcessorKwargs],
) -> BatchFeature:
"""
Main method to prepare for the model one or several sequences(s) and image(s). This method forwards the `text`
and `kwargs` arguments to Qwen2TokenizerFast's [`~Qwen2TokenizerFast.__call__`] if `text` is not `None` to encode
the text. To prepare the vision inputs, this method forwards the `vision_infos` and `kwrags` arguments to
Qwen2VLImageProcessor's [`~Qwen2VLImageProcessor.__call__`] if `vision_infos` is not `None`.
Args:
images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
tensor. Both channels-first and channels-last formats are supported.
text (`str`, `List[str]`, `List[List[str]]`):
The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
(pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
`is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
videos (`np.ndarray`, `torch.Tensor`, `List[np.ndarray]`, `List[torch.Tensor]`):
The image or batch of videos to be prepared. Each video can be a 4D NumPy array or PyTorch
tensor, or a nested list of 3D frames. Both channels-first and channels-last formats are supported.
return_tensors (`str` or [`~utils.TensorType`], *optional*):
If set, will return tensors of a particular framework. Acceptable values are:
- `'tf'`: Return TensorFlow `tf.constant` objects.
- `'pt'`: Return PyTorch `torch.Tensor` objects.
- `'np'`: Return NumPy `np.ndarray` objects.
- `'jax'`: Return JAX `jnp.ndarray` objects.
Returns:
[`BatchFeature`]: A [`BatchFeature`] with the following fields:
- **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
- **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
`return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
`None`).
- **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
- **pixel_values_videos** -- Pixel values of videos to be fed to a model. Returned when `videos` is not `None`.
- **image_grid_thw** -- List of image 3D grid in LLM. Returned when `images` is not `None`.
- **video_grid_thw** -- List of video 3D grid in LLM. Returned when `videos` is not `None`.
- **second_per_grid_ts** -- List of video seconds per time grid. Returned when `videos` is not `None`.
"""
output_kwargs = self._merge_kwargs(
Qwen2_5_VLProcessorKwargs,
tokenizer_init_kwargs=self.tokenizer.init_kwargs,
**kwargs,
)
if images is not None:
image_inputs = self.image_processor(images=images, videos=None, **output_kwargs["images_kwargs"])
image_grid_thw = image_inputs["image_grid_thw"]
else:
image_inputs = {}
image_grid_thw = None
if videos is not None:
videos_inputs = self.image_processor(images=None, videos=videos, **output_kwargs["images_kwargs"])
video_grid_thw = videos_inputs["video_grid_thw"]
fps = output_kwargs["videos_kwargs"].pop("fps", 2.0)
if isinstance(fps, (int, float)):
second_per_grid_ts = [self.image_processor.temporal_patch_size / fps] * len(video_grid_thw)
elif hasattr(fps, "__len__") and len(fps) == len(video_grid_thw):
second_per_grid_ts = [self.image_processor.temporal_patch_size / tmp for tmp in fps]
else:
raise ValueError(
f"The length of fps ({len(fps) if hasattr(fps, '__len__') else fps}) must be equal to the length of video_grid_thw ({len(video_grid_thw)}) or fps should be a single number."
)
videos_inputs.update({"second_per_grid_ts": second_per_grid_ts})
else:
videos_inputs = {}
video_grid_thw = None
if not isinstance(text, list):
text = [text]
if image_grid_thw is not None:
merge_length = self.image_processor.merge_size**2
index = 0
for i in range(len(text)):
while self.image_token in text[i]:
text[i] = text[i].replace(
self.image_token,
"<|placeholder|>" * (image_grid_thw[index].prod() // merge_length),
1,
)
index += 1
text[i] = text[i].replace("<|placeholder|>", self.image_token)
if video_grid_thw is not None:
merge_length = self.image_processor.merge_size**2
index = 0
for i in range(len(text)):
while self.video_token in text[i]:
text[i] = text[i].replace(
self.video_token,
"<|placeholder|>" * (video_grid_thw[index].prod() // merge_length),
1,
)
index += 1
text[i] = text[i].replace("<|placeholder|>", self.video_token)
text_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"])
return BatchFeature(data={**text_inputs, **image_inputs, **videos_inputs})
def batch_decode(self, *args, **kwargs):
"""
This method forwards all its arguments to Qwen2TokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
refer to the docstring of this method for more information.
"""
return self.tokenizer.batch_decode(*args, **kwargs)
def batch_decode_all2all(self, *args, **kwargs):
"""
This method forwards all its arguments to Qwen2TokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
refer to the docstring of this method for more information.
"""
decoded = self.tokenizer.batch_decode(*args, **kwargs)
pattern = r'<\|vision_start\|>.*?<\|vision_end\|>'
decoded_with_image_tag = [re.sub(pattern, '<image>', d, flags=re.DOTALL) for d in decoded]
decoded_with_image_tag = [re.sub(r'<\|im_end\|>', '', d) for d in decoded_with_image_tag]
return decoded_with_image_tag
def decode(self, *args, **kwargs):
"""
This method forwards all its arguments to Qwen2TokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to
the docstring of this method for more information.
"""
return self.tokenizer.decode(*args, **kwargs)
def post_process_image_text_to_text(
self, generated_outputs, skip_special_tokens=True, clean_up_tokenization_spaces=False, **kwargs
):
"""
Post-process the output of the model to decode the text.
Args:
generated_outputs (`torch.Tensor` or `np.ndarray`):
The output of the model `generate` function. The output is expected to be a tensor of shape `(batch_size, sequence_length)`
or `(sequence_length,)`.
skip_special_tokens (`bool`, *optional*, defaults to `True`):
Whether or not to remove special tokens in the output. Argument passed to the tokenizer's `batch_decode` method.
Clean_up_tokenization_spaces (`bool`, *optional*, defaults to `False`):
Whether or not to clean up the tokenization spaces. Argument passed to the tokenizer's `batch_decode` method.
**kwargs:
Additional arguments to be passed to the tokenizer's `batch_decode method`.
Returns:
`List[str]`: The decoded text.
"""
return self.tokenizer.batch_decode(
generated_outputs,
skip_special_tokens=skip_special_tokens,
clean_up_tokenization_spaces=clean_up_tokenization_spaces,
**kwargs,
)
@property
def model_input_names(self):
tokenizer_input_names = self.tokenizer.model_input_names
image_processor_input_names = self.image_processor.model_input_names
names_from_processor = list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
return names_from_processor + ["second_per_grid_ts"]
__all__ = ["Qwen2_5_VLProcessor"]

View File

@@ -0,0 +1,64 @@
import torch
from diffsynth import ModelManager
from .flux_image_pipeline import FluxImagePipelineAll2All
class FluxDecoder:
def __init__(self, flux_all2all_modelpath, flux_path, device='cuda', torch_dtype=torch.bfloat16):
self.device = device
self.torch_dtype = torch_dtype
self.pipe, self.adapter = self.get_pipe(flux_all2all_modelpath, flux_path, device, torch_dtype)
def get_pipe(self, flux_all2all_modelpath, flux_path, device="cuda", torch_dtype=torch.bfloat16):
model_manager = ModelManager(torch_dtype=torch_dtype, device=device)
model_manager.load_models([
f"{flux_path}/FLUX/FLUX.1-dev/text_encoder/model.safetensors",
f"{flux_path}/FLUX/FLUX.1-dev/text_encoder_2",
f"{flux_path}/FLUX/FLUX.1-dev/ae.safetensors",
f"{flux_path}/FLUX/FLUX.1-dev/flux1-dev.safetensors"
])
state_dict = torch.load(flux_all2all_modelpath, weights_only=True, map_location='cpu')
adapter_states = ['0.weight', '0.bias', '1.weight', '1.bias', '3.weight', '3.bias', '4.weight', '4.bias']
adapter_state_dict = {}
for key in adapter_states:
adapter_state_dict[key] = state_dict.pop(key)
in_channel = 3584
out_channel = 4096
expand_ratio = 1
adapter = torch.nn.Sequential(torch.nn.Linear(in_channel, out_channel * expand_ratio),
torch.nn.LayerNorm(out_channel * expand_ratio), torch.nn.ReLU(),
torch.nn.Linear(out_channel * expand_ratio, out_channel),
torch.nn.LayerNorm(out_channel))
adapter.load_state_dict(adapter_state_dict)
adapter.to(device, dtype=torch_dtype)
pipe = FluxImagePipelineAll2All.from_model_manager(model_manager)
pipe.dit.load_state_dict(state_dict)
return pipe, adapter
@torch.no_grad()
def decode_image_embeds(self,
output_image_embeddings,
height=512,
width=512,
num_inference_steps=50,
seed=42,
negative_prompt="",
cfg_scale=1.0,
**pipe_kwargs):
output_image_embeddings = output_image_embeddings.to(device=self.device, dtype=self.torch_dtype)
image_embed = self.adapter(output_image_embeddings)
image = self.pipe(prompt="",
image_embed=image_embed,
num_inference_steps=num_inference_steps,
embedded_guidance=3.5,
negative_prompt=negative_prompt,
cfg_scale=cfg_scale,
height=height,
width=width,
seed=seed,
**pipe_kwargs)
return image

View File

@@ -0,0 +1,192 @@
from typing import List
from tqdm import tqdm
import torch
from diffsynth.models import ModelManager
from diffsynth.controlnets import ControlNetConfigUnit
from diffsynth.prompters.flux_prompter import FluxPrompter
from diffsynth.pipelines.flux_image import FluxImagePipeline, lets_dance_flux, TeaCache
class FluxPrompterAll2All(FluxPrompter):
def encode_prompt(
self,
prompt,
positive=True,
device="cuda",
t5_sequence_length=512,
clip_only=False
):
prompt = self.process_prompt(prompt, positive=positive)
# CLIP
pooled_prompt_emb = self.encode_prompt_using_clip(prompt, self.text_encoder_1, self.tokenizer_1, 77, device)
if clip_only:
return None, pooled_prompt_emb, None
# T5
prompt_emb = self.encode_prompt_using_t5(prompt, self.text_encoder_2, self.tokenizer_2, t5_sequence_length, device)
# text_ids
text_ids = torch.zeros(prompt_emb.shape[0], prompt_emb.shape[1], 3).to(device=device, dtype=prompt_emb.dtype)
return prompt_emb, pooled_prompt_emb, text_ids
class FluxImagePipelineAll2All(FluxImagePipeline):
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.prompter = FluxPrompterAll2All()
def encode_prompt(self, prompt, positive=True, t5_sequence_length=512, clip_only=False):
prompt_emb, pooled_prompt_emb, text_ids = self.prompter.encode_prompt(
prompt, device=self.device, positive=positive, t5_sequence_length=t5_sequence_length, clip_only=clip_only
)
return {"prompt_emb": prompt_emb, "pooled_prompt_emb": pooled_prompt_emb, "text_ids": text_ids}
@staticmethod
def from_model_manager(model_manager: ModelManager, controlnet_config_units: List[ControlNetConfigUnit]=[], prompt_refiner_classes=[], prompt_extender_classes=[], device=None, torch_dtype=None):
pipe = FluxImagePipelineAll2All(
device=model_manager.device if device is None else device,
torch_dtype=model_manager.torch_dtype if torch_dtype is None else torch_dtype,
)
pipe.fetch_models(model_manager, controlnet_config_units, prompt_refiner_classes, prompt_extender_classes)
return pipe
def prepare_prompts(self, prompt, image_embed, local_prompts, masks, mask_scales, t5_sequence_length, negative_prompt, cfg_scale):
# Extend prompt
self.load_models_to_device(['text_encoder_1', 'text_encoder_2'])
prompt, local_prompts, masks, mask_scales = self.extend_prompt(prompt, local_prompts, masks, mask_scales)
# Encode prompts
if image_embed is not None:
image_embed = image_embed.to(self.torch_dtype)
prompt_emb_posi = self.encode_prompt("", positive=True, clip_only=True)
if len(image_embed.size()) == 2:
image_embed = image_embed.unsqueeze(0)
prompt_emb_posi['prompt_emb'] = image_embed
prompt_emb_posi['text_ids'] = torch.zeros(image_embed.shape[0], image_embed.shape[1], 3).to(device=self.device, dtype=self.torch_dtype)
else:
prompt_emb_posi = self.encode_prompt(prompt, t5_sequence_length=t5_sequence_length)
prompt_emb_nega = self.encode_prompt(negative_prompt, positive=False, t5_sequence_length=t5_sequence_length) if cfg_scale != 1.0 else None
prompt_emb_locals = [self.encode_prompt(prompt_local, t5_sequence_length=t5_sequence_length) for prompt_local in local_prompts]
return prompt_emb_posi, prompt_emb_nega, prompt_emb_locals
@torch.no_grad()
def __call__(
self,
# Prompt
prompt,
negative_prompt="",
cfg_scale=1.0,
embedded_guidance=3.5,
t5_sequence_length=512,
# Image
input_image=None,
denoising_strength=1.0,
height=1024,
width=1024,
seed=None,
# image_embed
image_embed=None,
# Steps
num_inference_steps=30,
# local prompts
local_prompts=(),
masks=(),
mask_scales=(),
# ControlNet
controlnet_image=None,
controlnet_inpaint_mask=None,
enable_controlnet_on_negative=False,
# IP-Adapter
ipadapter_images=None,
ipadapter_scale=1.0,
# EliGen
eligen_entity_prompts=None,
eligen_entity_masks=None,
enable_eligen_on_negative=False,
enable_eligen_inpaint=False,
# TeaCache
tea_cache_l1_thresh=None,
# Tile
tiled=False,
tile_size=128,
tile_stride=64,
# Progress bar
progress_bar_cmd=tqdm,
progress_bar_st=None,
):
height, width = self.check_resize_height_width(height, width)
# Tiler parameters
tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride}
# Prepare scheduler
self.scheduler.set_timesteps(num_inference_steps, denoising_strength)
# Prepare latent tensors
latents, input_latents = self.prepare_latents(input_image, height, width, seed, tiled, tile_size, tile_stride)
# Prompt
prompt_emb_posi, prompt_emb_nega, prompt_emb_locals = self.prepare_prompts(prompt, image_embed, local_prompts, masks, mask_scales, t5_sequence_length, negative_prompt, cfg_scale)
# Extra input
extra_input = self.prepare_extra_input(latents, guidance=embedded_guidance)
# Entity control
eligen_kwargs_posi, eligen_kwargs_nega, fg_mask, bg_mask = self.prepare_eligen(prompt_emb_nega, eligen_entity_prompts, eligen_entity_masks, width, height, t5_sequence_length, enable_eligen_inpaint, enable_eligen_on_negative, cfg_scale)
# IP-Adapter
ipadapter_kwargs_list_posi, ipadapter_kwargs_list_nega = self.prepare_ipadapter(ipadapter_images, ipadapter_scale)
# ControlNets
controlnet_kwargs_posi, controlnet_kwargs_nega, local_controlnet_kwargs = self.prepare_controlnet(controlnet_image, masks, controlnet_inpaint_mask, tiler_kwargs, enable_controlnet_on_negative)
# TeaCache
tea_cache_kwargs = {"tea_cache": TeaCache(num_inference_steps, rel_l1_thresh=tea_cache_l1_thresh) if tea_cache_l1_thresh is not None else None}
# Denoise
self.load_models_to_device(['dit', 'controlnet'])
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
timestep = timestep.unsqueeze(0).to(self.device)
# Positive side
inference_callback = lambda prompt_emb_posi, controlnet_kwargs: lets_dance_flux(
dit=self.dit, controlnet=self.controlnet,
hidden_states=latents, timestep=timestep,
**prompt_emb_posi, **tiler_kwargs, **extra_input, **controlnet_kwargs, **ipadapter_kwargs_list_posi, **eligen_kwargs_posi, **tea_cache_kwargs,
)
noise_pred_posi = self.control_noise_via_local_prompts(
prompt_emb_posi, prompt_emb_locals, masks, mask_scales, inference_callback,
special_kwargs=controlnet_kwargs_posi, special_local_kwargs_list=local_controlnet_kwargs
)
# Inpaint
if enable_eligen_inpaint:
noise_pred_posi = self.inpaint_fusion(latents, input_latents, noise_pred_posi, fg_mask, bg_mask, progress_id)
# Classifier-free guidance
if cfg_scale != 1.0:
# Negative side
noise_pred_nega = lets_dance_flux(
dit=self.dit, controlnet=self.controlnet,
hidden_states=latents, timestep=timestep,
**prompt_emb_nega, **tiler_kwargs, **extra_input, **controlnet_kwargs_nega, **ipadapter_kwargs_list_nega, **eligen_kwargs_nega,
)
noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
else:
noise_pred = noise_pred_posi
# Iterate
latents = self.scheduler.step(noise_pred, self.scheduler.timesteps[progress_id], latents)
# UI
if progress_bar_st is not None:
progress_bar_st.progress(progress_id / len(self.scheduler.timesteps))
# Decode image
self.load_models_to_device(['vae_decoder'])
image = self.decode_image(latents, **tiler_kwargs)
# Offload all models
self.load_models_to_device([])
return image

View File

@@ -1,7 +1,7 @@
torch>=2.0.0
torchvision
cupy-cuda12x
transformers==4.46.2
transformers==4.49.0
controlnet-aux==0.0.7
imageio
imageio[ffmpeg]
@@ -11,3 +11,4 @@ sentencepiece
protobuf
modelscope
ftfy
qwen_vl_utils

4
run_single.sh Normal file
View File

@@ -0,0 +1,4 @@
accelerate launch \
train.py \
--output_path models/nexus_v3 \
--steps_per_epoch 4000

View File

@@ -14,7 +14,7 @@ else:
setup(
name="diffsynth",
version="1.1.3",
version="1.1.7",
description="Enjoy the magic of Diffusion models!",
author="Artiprocher",
packages=find_packages(),

312
test.py Normal file
View File

@@ -0,0 +1,312 @@
from transformers import AutoConfig, AutoTokenizer
import torch, json, os, torchvision
from modeling.ar.modeling_qwen2_5_vl import Qwen2_5_VLForConditionalGeneration
from modeling.ar.processing_qwen2_5_vl import Qwen2_5_VLProcessor
from diffsynth import ModelManager, FluxImagePipeline, load_state_dict, hash_state_dict_keys
from qwen_vl_utils import smart_resize
from PIL import Image
import numpy as np
from torchvision.transforms import v2
class SingleTaskDataset(torch.utils.data.Dataset):
def __init__(
self,
base_path,
keys=(("image_1", "image_2", "editing_instruction"), ("image_2", "image_1", "reverse_editing_instruction"), (None, "image_1", "prompt")),
height=1024, width=1024, random=True, steps_per_epoch=1000, metadata_path=None
):
self.base_path = base_path
self.keys = keys
self.metadata = []
self.bad_data = []
self.height = height
self.width = width
self.random = random
self.steps_per_epoch = steps_per_epoch
self.image_process = v2.Compose([
v2.CenterCrop(size=(height, width)),
v2.ToImage(),
v2.ToDtype(torch.float32, scale=True),
v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])
if metadata_path is None:
self.search_for_data("", report_data_log=True)
self.report_data_log()
else:
with open(metadata_path, "r", encoding="utf-8-sig") as f:
self.metadata = json.load(f)
def report_data_log(self):
print(f"{len(self.metadata)} valid data, {len(self.bad_data)} invalid data.")
def dump_metadata(self, path):
with open(path, "w", encoding="utf-8") as f:
json.dump(self.metadata, f, ensure_ascii=False, indent=4)
def parse_json_file(self, absolute_path, relative_path):
data_list = []
with open(absolute_path, "r") as f:
metadata = json.load(f)
for image_1, image_2, instruction in self.keys:
image_1 = os.path.join(relative_path, metadata[image_1]) if image_1 is not None else None
image_2 = os.path.join(relative_path, metadata[image_2])
instruction = metadata[instruction]
data_list.append((image_1, image_2, instruction))
return data_list
def search_for_data(self, path, report_data_log=False):
now_path = os.path.join(self.base_path, path)
if os.path.isfile(now_path) and path.endswith(".json"):
try:
data_list = self.parse_json_file(now_path, os.path.dirname(path))
self.metadata.extend(data_list)
except:
self.bad_data.append(now_path)
elif os.path.isdir(now_path):
for sub_path in os.listdir(now_path):
self.search_for_data(os.path.join(path, sub_path))
if report_data_log and os.path.isdir(os.path.join(self.base_path, path, sub_path)):
self.report_data_log()
def load_image(self, image_path, skip_process=False):
image_path = os.path.join(self.base_path, image_path)
image = Image.open(image_path).convert("RGB")
width, height = image.size
scale = max(self.width / width, self.height / height)
image = torchvision.transforms.functional.resize(
image,
(round(height*scale), round(width*scale)),
interpolation=torchvision.transforms.InterpolationMode.BILINEAR
)
if skip_process:
return image
image = self.image_process(image)
return image
def load_data(self, data_id):
image_1, image_2, instruction = self.metadata[data_id]
image_1 = self.load_image(image_1, skip_process=True) if image_1 is not None else None
image_2 = self.load_image(image_2)
return {"image_1": image_1, "image_2": image_2, "instruction": instruction}
def __getitem__(self, data_id):
if self.random:
data_id = (torch.randint(0, len(self.metadata), (1,))[0] + data_id) % len(self.metadata)
data = self.load_data(data_id)
return data
else:
return self.load_data(data_id)
def __len__(self):
return self.steps_per_epoch if self.random else len(self.metadata)
class MultiTaskDataset(torch.utils.data.Dataset):
def __init__(self, dataset_list, dataset_weight, steps_per_epoch=1000):
self.dataset_list = dataset_list
self.dataset_weight = torch.tensor(dataset_weight, dtype=torch.float)
self.steps_per_epoch = steps_per_epoch
def __getitem__(self, data_id):
dataset_id = torch.multinomial(self.dataset_weight, 1).tolist()[0]
data_id = torch.randint(0, len(self.dataset_list[dataset_id]), (1,))[0]
data = self.dataset_list[dataset_id][data_id]
return data
def __len__(self):
return self.steps_per_epoch
class NexusGenQwenVLEncoder(torch.nn.Module):
def __init__(self, model_path, torch_dtype="auto", device="cpu"):
super().__init__()
model_config = AutoConfig.from_pretrained(model_path)
self.tokenizer = AutoTokenizer.from_pretrained(model_path)
self.model = Qwen2_5_VLForConditionalGeneration.from_pretrained(model_path, config=model_config, trust_remote_code=True, torch_dtype=torch_dtype, device_map=device)
self.processor = Qwen2_5_VLProcessor.from_pretrained(model_path)
self.t2i_template = "Here is an image based on the description: <|vision_start|><|image_pad|><|vision_end|>"
self.i2i_template = "Here is the image: <|vision_start|><|image_pad|><|vision_end|>"
@staticmethod
def from_pretrained(model_path, torch_dtype="auto", device="cpu"):
return NexusGenQwenVLEncoder(model_path, torch_dtype=torch_dtype, device=device).eval()
def process_images(self, images=None):
if images is None:
return None
# resize input to max_pixels to avoid oom
for j in range(len(images)):
input_image = images[j]
input_w, input_h = input_image.size
resized_height, resized_width = smart_resize(
input_h,
input_w,
max_pixels=262640,
)
images[j] = input_image.resize((resized_width, resized_height))
return images
def forward(self, prompt, images=None, num_img_tokens=81):
messages = [
{
"role": "user",
"content": [{
"type": "text",
"text": prompt
},],
},
{
"role": "assistant",
"content": [{
"type": "text",
"text": self.t2i_template if images is None else self.i2i_template
},],
}
]
images = self.process_images(images)
target_image = Image.fromarray(np.zeros((252, 252, 3), dtype=np.uint8))
if images is None:
images = [target_image]
else:
images = images + [target_image]
text = self.processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=False)
inputs = self.processor(
text=[text],
images=images,
padding=True,
return_tensors="pt",
)
inputs = inputs.to(self.model.device)
input_embeds = self.model.model.embed_tokens(inputs['input_ids'])
image_embeds = self.model.visual(inputs['pixel_values'], grid_thw=inputs['image_grid_thw'])
ground_truth_image_embeds = image_embeds[-num_img_tokens:]
input_image_embeds = image_embeds[:-num_img_tokens]
image_mask = inputs['input_ids'] == self.model.config.image_token_id
indices = image_mask.cumsum(dim=1)
input_image_mask = torch.logical_and(indices <= (image_embeds.shape[0] - ground_truth_image_embeds.shape[0]), image_mask)
gt_image_mask = torch.logical_and(image_mask, ~input_image_mask)
input_image_mask = input_image_mask.unsqueeze(-1).expand_as(input_embeds)
input_embeds = input_embeds.masked_scatter(input_image_mask, input_image_embeds)
position_ids, _ = self.model.get_rope_index(inputs['input_ids'],
inputs['image_grid_thw'],
attention_mask=inputs['attention_mask'])
position_ids = position_ids.contiguous()
outputs = self.model(inputs_embeds=input_embeds, position_ids=position_ids, attention_mask=inputs['attention_mask'], return_dict=True)
output_image_embeddings = outputs.image_embeddings[:, :-1, :] # shift right
output_image_embeddings = output_image_embeddings[gt_image_mask[:, 1:]]
output_image_embeddings = output_image_embeddings.unsqueeze(0)
return output_image_embeddings
model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cuda")
model_manager.load_models([
"models/FLUX/FLUX.1-dev/text_encoder/model.safetensors",
"models/FLUX/FLUX.1-dev/ae.safetensors",
"models/FLUX/FLUX.1-dev/flux1-dev.safetensors"
])
pipe = FluxImagePipeline.from_model_manager(model_manager)
# state_dict = load_state_dict("models/DiffSynth-Studio/Nexus-Gen/decoder_81_512.bin", torch_dtype=torch.bfloat16)
# pipe.dit.load_state_dict(state_dict, strict=False)
adapter = torch.nn.Sequential(torch.nn.Linear(3584, 4096), torch.nn.LayerNorm(4096), torch.nn.ReLU(), torch.nn.Linear(4096, 4096), torch.nn.LayerNorm(4096)).to(dtype=torch.bfloat16, device="cuda")
# adapter.load_state_dict(state_dict, strict=False)
qwenvl = NexusGenQwenVLEncoder.from_pretrained('models/DiffSynth-Studio/Nexus-Gen').to("cuda")
sd = {}
for i in range(1, 6):
print(i)
sd.update(load_state_dict(f"models/nexus_v3/epoch-19/model-0000{i}-of-00005.safetensors", torch_dtype=torch.bfloat16))
pipe.dit.load_state_dict({i.replace("pipe.dit.", ""): sd[i] for i in sd if i.startswith("pipe.dit.")})
qwenvl.load_state_dict({i.replace("qwenvl.", ""): sd[i] for i in sd if i.startswith("qwenvl.")})
adapter.load_state_dict({i.replace("adapter.", ""): sd[i] for i in sd if i.startswith("adapter.")})
dataset = MultiTaskDataset(
dataset_list=[
SingleTaskDataset(
"/shark/zhongjie/data/image_pulse_datasets/task1/data/dataset_change_add_remove",
keys=(("image_1", "image_2", "editing_instruction"), ("image_2", "image_1", "reverse_editing_instruction"), (None, "image_1", "prompt")),
height=1024, width=1024,
metadata_path="/shark/zhongjie/data/image_pulse_datasets/task1/data/metadata/20250507_dataset_change_add_remove.json",
),
SingleTaskDataset(
"/shark/zhongjie/data/image_pulse_datasets/task1/data/dataset_style_transfer",
keys=(("image_1", "image_4", "editing_instruction"), ("image_4", "image_1", "reverse_editing_instruction"), (None, "image_1", "prompt")),
height=1024, width=1024,
metadata_path="/shark/zhongjie/data/image_pulse_datasets/task1/data/metadata/20250507_dataset_style_transfer.json",
),
SingleTaskDataset(
"/shark/zhongjie/data/image_pulse_datasets/task1/data/dataset_faceid",
keys=(("image_1", "image_2", "editing_instruction"), ("image_2", "image_1", "reverse_editing_instruction")),
height=1024, width=1024,
metadata_path="/shark/zhongjie/data/image_pulse_datasets/task1/data/metadata/20250507_dataset_faceid.json",
),
],
dataset_weight=(4, 2, 1,),
steps_per_epoch=100000
)
torch.manual_seed(0)
for data_id, data in enumerate(dataset):
image_1 = data["image_1"]
image_2 = data["image_2"].cpu().float().permute(1, 2, 0).numpy()
image_2 = Image.fromarray(((image_2 / 2 + 0.5).clip(0, 1) * 255).astype("uint8"))
instruction = data["instruction"]
print(instruction)
if image_1 is None:
with torch.no_grad():
instruction = f"Generate an image according to the following description: {instruction}"
emb = qwenvl(instruction, images=None)
emb = adapter(emb)
image_3 = pipe("", image_emb=emb)
else:
with torch.no_grad():
instruction = f"<|vision_start|><|image_pad|><|vision_end|> {instruction}"
emb = qwenvl(instruction, images=[image_1])
emb = adapter(emb)
image_3 = pipe("", image_emb=emb)
if image_1 is not None:
image_1.save(f"data/output/{data_id}_1.jpg")
image_2.save(f"data/output/{data_id}_2.jpg")
image_3.save(f"data/output/{data_id}_3.jpg")
if data_id >= 100:
break
# with torch.no_grad():
# instruction = "Generate an image according to the following description: hyper-realistic and detailed 2010s movie still portrait of Josip Broz Tito, by Paolo Sorrentino, Leica SL2 50mm, clear color, high quality, high textured, dramatic light, cinematic"
# emb = qwenvl(instruction, images=None)
# emb = adapter(emb)
# image = pipe("", image_emb=emb)
# image.save("image_1.jpg")
# with torch.no_grad():
# instruction = "<|vision_start|><|image_pad|><|vision_end|> transform the image into a cartoon style with vibrant colors and a confident expression."
# emb = qwenvl(instruction, images=[Image.open("image_1.jpg")])
# emb = adapter(emb)
# image = pipe("", image_emb=emb)
# image.save("image_2.jpg")

421
train.py Normal file
View File

@@ -0,0 +1,421 @@
from diffsynth import ModelManager, FluxImagePipeline, load_state_dict
from diffsynth.trainers.text_to_image import LightningModelForT2ILoRA, add_general_parsers, launch_training_task
from diffsynth.models.lora import FluxLoRAConverter
import torch, os, argparse
from diffsynth.pipelines.flux_image import lets_dance_flux
from accelerate import Accelerator
from tqdm import tqdm
import torch, os, json, torchvision
from PIL import Image
from torchvision.transforms import v2
from transformers import AutoConfig, AutoTokenizer
import torch
from modeling.ar.modeling_qwen2_5_vl import Qwen2_5_VLForConditionalGeneration
from modeling.ar.processing_qwen2_5_vl import Qwen2_5_VLProcessor
from diffsynth import ModelManager, FluxImagePipeline, load_state_dict, hash_state_dict_keys
from qwen_vl_utils import smart_resize
from PIL import Image
import numpy as np
import lightning as pl
os.environ["TOKENIZERS_PARALLELISM"] = "True"
class SingleTaskDataset(torch.utils.data.Dataset):
def __init__(
self,
base_path,
keys=(("image_1", "image_2", "editing_instruction"), ("image_2", "image_1", "reverse_editing_instruction"), (None, "image_1", "prompt")),
height=1024, width=1024, random=True, steps_per_epoch=1000, metadata_path=None
):
self.base_path = base_path
self.keys = keys
self.metadata = []
self.bad_data = []
self.height = height
self.width = width
self.random = random
self.steps_per_epoch = steps_per_epoch
self.image_process = v2.Compose([
v2.CenterCrop(size=(height, width)),
v2.ToImage(),
v2.ToDtype(torch.float32, scale=True),
v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])
if metadata_path is None:
self.search_for_data("", report_data_log=True)
self.report_data_log()
else:
with open(metadata_path, "r", encoding="utf-8-sig") as f:
self.metadata = json.load(f)
def report_data_log(self):
print(f"{len(self.metadata)} valid data, {len(self.bad_data)} invalid data.")
def dump_metadata(self, path):
with open(path, "w", encoding="utf-8") as f:
json.dump(self.metadata, f, ensure_ascii=False, indent=4)
def parse_json_file(self, absolute_path, relative_path):
data_list = []
with open(absolute_path, "r") as f:
metadata = json.load(f)
for image_1, image_2, instruction in self.keys:
image_1 = os.path.join(relative_path, metadata[image_1]) if image_1 is not None else None
image_2 = os.path.join(relative_path, metadata[image_2])
instruction = metadata[instruction]
data_list.append((image_1, image_2, instruction))
return data_list
def search_for_data(self, path, report_data_log=False):
now_path = os.path.join(self.base_path, path)
if os.path.isfile(now_path) and path.endswith(".json"):
try:
data_list = self.parse_json_file(now_path, os.path.dirname(path))
self.metadata.extend(data_list)
except:
self.bad_data.append(now_path)
elif os.path.isdir(now_path):
for sub_path in os.listdir(now_path):
self.search_for_data(os.path.join(path, sub_path))
if report_data_log and os.path.isdir(os.path.join(self.base_path, path, sub_path)):
self.report_data_log()
def load_image(self, image_path, skip_process=False):
image_path = os.path.join(self.base_path, image_path)
image = Image.open(image_path).convert("RGB")
width, height = image.size
scale = max(self.width / width, self.height / height)
image = torchvision.transforms.functional.resize(
image,
(round(height*scale), round(width*scale)),
interpolation=torchvision.transforms.InterpolationMode.BILINEAR
)
if skip_process:
return image
image = self.image_process(image)
return image
def load_data(self, data_id):
image_1, image_2, instruction = self.metadata[data_id]
image_1 = self.load_image(image_1, skip_process=True) if image_1 is not None else None
image_2 = self.load_image(image_2)
return {"image_1": image_1, "image_2": image_2, "instruction": instruction}
def __getitem__(self, data_id):
if self.random:
while True:
try:
data_id = (torch.randint(0, len(self.metadata), (1,))[0] + data_id) % len(self.metadata)
data = self.load_data(data_id)
return data
except:
continue
else:
return self.load_data(data_id)
def __len__(self):
return self.steps_per_epoch if self.random else len(self.metadata)
class MultiTaskDataset(torch.utils.data.Dataset):
def __init__(self, dataset_list, dataset_weight, steps_per_epoch=1000):
self.dataset_list = dataset_list
self.dataset_weight = torch.tensor(dataset_weight, dtype=torch.float)
self.steps_per_epoch = steps_per_epoch
def __getitem__(self, data_id):
dataset_id = torch.multinomial(self.dataset_weight, 1).tolist()[0]
data_id = torch.randint(0, len(self.dataset_list[dataset_id]), (1,))[0]
data = self.dataset_list[dataset_id][data_id]
return data
def __len__(self):
return self.steps_per_epoch
class NexusGenQwenVLEncoder(torch.nn.Module):
def __init__(self, model_path, torch_dtype="auto", device="cpu"):
super().__init__()
model_config = AutoConfig.from_pretrained(model_path)
self.tokenizer = AutoTokenizer.from_pretrained(model_path)
self.model = Qwen2_5_VLForConditionalGeneration.from_pretrained(model_path, config=model_config, trust_remote_code=True, torch_dtype=torch_dtype, device_map=device)
self.processor = Qwen2_5_VLProcessor.from_pretrained(model_path)
self.t2i_template = "Here is an image based on the description: <|vision_start|><|image_pad|><|vision_end|>"
self.i2i_template = "Here is the image: <|vision_start|><|image_pad|><|vision_end|>"
@staticmethod
def from_pretrained(model_path, torch_dtype="auto", device="cpu"):
return NexusGenQwenVLEncoder(model_path, torch_dtype=torch_dtype, device=device).eval()
def process_images(self, images=None):
if images is None:
return None
# resize input to max_pixels to avoid oom
for j in range(len(images)):
input_image = images[j]
input_w, input_h = input_image.size
resized_height, resized_width = smart_resize(
input_h,
input_w,
max_pixels=262640,
)
images[j] = input_image.resize((resized_width, resized_height))
return images
def forward(self, prompt, images=None, num_img_tokens=81):
messages = [
{
"role": "user",
"content": [{
"type": "text",
"text": prompt
},],
},
{
"role": "assistant",
"content": [{
"type": "text",
"text": self.t2i_template if images is None else self.i2i_template
},],
}
]
images = self.process_images(images)
target_image = Image.fromarray(np.zeros((252, 252, 3), dtype=np.uint8))
if images is None:
images = [target_image]
else:
images = images + [target_image]
text = self.processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=False)
inputs = self.processor(
text=[text],
images=images,
padding=True,
return_tensors="pt",
)
inputs = inputs.to(self.model.device)
input_embeds = self.model.model.embed_tokens(inputs['input_ids'])
image_embeds = self.model.visual(inputs['pixel_values'], grid_thw=inputs['image_grid_thw'])
ground_truth_image_embeds = image_embeds[-num_img_tokens:]
input_image_embeds = image_embeds[:-num_img_tokens]
image_mask = inputs['input_ids'] == self.model.config.image_token_id
indices = image_mask.cumsum(dim=1)
input_image_mask = torch.logical_and(indices <= (image_embeds.shape[0] - ground_truth_image_embeds.shape[0]), image_mask)
gt_image_mask = torch.logical_and(image_mask, ~input_image_mask)
input_image_mask = input_image_mask.unsqueeze(-1).expand_as(input_embeds)
input_embeds = input_embeds.masked_scatter(input_image_mask, input_image_embeds)
position_ids, _ = self.model.get_rope_index(inputs['input_ids'],
inputs['image_grid_thw'],
attention_mask=inputs['attention_mask'])
position_ids = position_ids.contiguous()
outputs = self.model(inputs_embeds=input_embeds, position_ids=position_ids, attention_mask=inputs['attention_mask'], return_dict=True)
output_image_embeddings = outputs.image_embeddings[:, :-1, :] # shift right
output_image_embeddings = output_image_embeddings[gt_image_mask[:, 1:]]
output_image_embeddings = output_image_embeddings.unsqueeze(0)
return output_image_embeddings
class UnifiedModel(pl.LightningModule):
def __init__(self, flux_text_encoder_path, flux_vae_path, flux_dit_path, flux_decoder_path, qwenvl_path):
super().__init__()
# Load models
model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu")
model_manager.load_models([
flux_text_encoder_path,
flux_vae_path,
flux_dit_path
])
self.pipe = FluxImagePipeline.from_model_manager(model_manager)
state_dict = load_state_dict(flux_decoder_path, torch_dtype=torch.bfloat16)
self.pipe.dit.load_state_dict(state_dict, strict=False)
self.adapter = torch.nn.Sequential(torch.nn.Linear(3584, 4096), torch.nn.LayerNorm(4096), torch.nn.ReLU(), torch.nn.Linear(4096, 4096), torch.nn.LayerNorm(4096)).to(dtype=torch.bfloat16)
self.adapter.load_state_dict(state_dict, strict=False)
self.qwenvl = NexusGenQwenVLEncoder.from_pretrained(qwenvl_path)
self.pipe.vae_decoder.requires_grad_(False)
self.pipe.vae_encoder.requires_grad_(False)
self.pipe.text_encoder_1.requires_grad_(False)
self.qwenvl.requires_grad_(False)
self.qwenvl.model.visual.requires_grad_(False)
self.pipe.train()
self.adapter.train()
self.qwenvl.train()
self.qwenvl.model.visual.eval()
# self.qwenvl.model.model.gradient_checkpointing = True
self.pipe.scheduler.set_timesteps(1000, training=True)
def training_step(self, batch, batch_idx):
# Data
text, image = batch["instruction"], batch["image_2"]
image_ref = batch["image_1"]
image = image.unsqueeze(0)
# Prepare input parameters
self.pipe.device = self.device
latents = self.pipe.vae_encoder(image.to(dtype=self.pipe.torch_dtype, device=self.device))
noise = torch.randn_like(latents)
timestep_id = torch.randint(0, self.pipe.scheduler.num_train_timesteps, (1,))
timestep = self.pipe.scheduler.timesteps[timestep_id].to(self.device)
extra_input = self.pipe.prepare_extra_input(latents)
noisy_latents = self.pipe.scheduler.add_noise(latents, noise, timestep)
training_target = self.pipe.scheduler.training_target(latents, noise, timestep)
# Compute loss
if image_ref is None:
instruction = f"Generate an image according to the following description: {text}"
images_ref = None
else:
instruction = f"<|vision_start|><|image_pad|><|vision_end|> {text}"
images_ref = [image_ref]
emb = self.qwenvl(instruction, images=images_ref)
emb = self.adapter(emb)
prompt_emb = self.pipe.encode_prompt("", positive=True, image_emb=emb)
noise_pred = lets_dance_flux(
self.pipe.denoising_model(),
hidden_states=noisy_latents, timestep=timestep, **prompt_emb, **extra_input,
image_emb=emb,
use_gradient_checkpointing=False
)
loss = torch.nn.functional.mse_loss(noise_pred.float(), training_target.float())
loss = loss * self.pipe.scheduler.training_weight(timestep)
return loss
def forward(self, batch):
return self.training_step(batch, 0)
def search_for_last_checkpoint(path):
if not os.path.exists(path):
return None, 0
checkpoint_list = os.listdir(path)
checkpoint_list = [int(checkpoint.split("-")[1]) for checkpoint in checkpoint_list if checkpoint.startswith("epoch")]
if len(checkpoint_list) == 0:
return None, 0
else:
max_epoch_id = max(checkpoint_list)
return f"{path}/epoch-{max_epoch_id}/model.safetensors", max_epoch_id + 1
def parse_args():
parser = argparse.ArgumentParser(description="Simple example of a training script.")
parser.add_argument(
"--gradient_accumulation_steps",
type=int,
default=1,
help="gradient_accumulation_steps",
)
parser.add_argument(
"--steps_per_epoch",
type=int,
default=1000,
help="steps_per_epoch",
)
parser.add_argument(
"--output_path",
type=str,
default="./models",
help="output_path",
)
parser.add_argument(
"--learning_rate",
type=float,
default=1e-5,
help="learning_rate",
)
args = parser.parse_args()
return args
if __name__ == '__main__':
args = parse_args()
model = UnifiedModel(
"models/FLUX/FLUX.1-dev/text_encoder/model.safetensors",
"models/FLUX/FLUX.1-dev/ae.safetensors",
"models/FLUX/FLUX.1-dev/flux1-dev.safetensors",
"models/DiffSynth-Studio/Nexus-Gen/decoder_81_512.bin",
"models/DiffSynth-Studio/Nexus-Gen",
)
# dataset and data loader
accelerator = Accelerator(gradient_accumulation_steps=args.gradient_accumulation_steps)
dataset = MultiTaskDataset(
dataset_list=[
SingleTaskDataset(
"/shark/zhongjie/data/image_pulse_datasets/task1/data/dataset_change_add_remove",
keys=(("image_1", "image_2", "editing_instruction"), ("image_2", "image_1", "reverse_editing_instruction"), (None, "image_1", "prompt")),
height=1024, width=1024,
metadata_path="/shark/zhongjie/data/image_pulse_datasets/task1/data/metadata/20250507_dataset_change_add_remove.json",
),
SingleTaskDataset(
"/shark/zhongjie/data/image_pulse_datasets/task1/data/dataset_style_transfer",
keys=(("image_1", "image_4", "editing_instruction"), ("image_4", "image_1", "reverse_editing_instruction"), (None, "image_1", "prompt")),
height=1024, width=1024,
metadata_path="/shark/zhongjie/data/image_pulse_datasets/task1/data/metadata/20250507_dataset_style_transfer.json",
),
SingleTaskDataset(
"/shark/zhongjie/data/image_pulse_datasets/task1/data/dataset_faceid",
keys=(("image_1", "image_2", "editing_instruction"), ("image_2", "image_1", "reverse_editing_instruction")),
height=1024, width=1024,
metadata_path="/shark/zhongjie/data/image_pulse_datasets/task1/data/metadata/20250507_dataset_faceid.json",
),
],
dataset_weight=(4, 2, 1,),
steps_per_epoch=args.steps_per_epoch * accelerator.num_processes,
)
train_loader = torch.utils.data.DataLoader(
dataset,
shuffle=True,
batch_size=1,
num_workers=1,
collate_fn=lambda x: x[0]
)
# train
pretrained_path, start_epoch_id = search_for_last_checkpoint(args.output_path)
if pretrained_path is not None:
print(f"pretrained_path: {pretrained_path}")
model.load_state_dict(load_state_dict(pretrained_path, torch_dtype=torch.bfloat16), assign=True, strict=False)
model.to(torch.bfloat16)
model.to(accelerator.device)
trainable_modules = filter(lambda p: p.requires_grad, model.parameters())
optimizer = torch.optim.AdamW(trainable_modules, lr=args.learning_rate)
model, optimizer, train_loader = accelerator.prepare(model, optimizer, train_loader)
for epoch in range(start_epoch_id, 100000):
for batch in tqdm(train_loader, desc=f"epoch-{epoch}", disable=not accelerator.is_local_main_process):
with accelerator.accumulate(model):
optimizer.zero_grad()
loss = model(batch)
accelerator.backward(loss)
optimizer.step()
path = args.output_path
os.makedirs(path, exist_ok=True)
accelerator.wait_for_everyone()
accelerator.save_model(model, f"{path}/epoch-{epoch}", max_shard_size="10GB", safe_serialization=True)