Compare commits

..

34 Commits

Author SHA1 Message Date
Zhongjie Duan
5129d3dc52 Update setup.py 2025-04-09 15:34:02 +08:00
Zhongjie Duan
ee9bab80f2 Update requirements.txt 2025-04-09 15:33:21 +08:00
Zhongjie Duan
cd8884c9ef Update setup.py 2025-04-09 15:27:36 +08:00
Zhongjie Duan
46744362de Update requirements.txt 2025-04-09 15:26:13 +08:00
Zhongjie Duan
0f0cdc3afc Update setup.py 2025-04-09 15:15:18 +08:00
Zhongjie Duan
a33c63af87 Merge pull request #518 from modelscope/wan-fun
Wan fun
2025-04-08 19:25:12 +08:00
Artiprocher
3cc9764bc9 support more wan models 2025-04-08 19:22:53 +08:00
Artiprocher
f6c6e3c640 support more wan models 2025-04-08 17:19:54 +08:00
Artiprocher
60a9db706e support more wan models 2025-04-08 17:07:10 +08:00
lzw478614@alibaba-inc.com
a98700feb2 support wan-fun-inp generating 2025-04-06 22:55:42 +08:00
lzw478614@alibaba-inc.com
5418ca781e support load wan2.1-fun-inp-1.3B and 14B model 2025-04-03 16:37:59 +08:00
Zhongjie Duan
71eee780fb Merge pull request #511 from modelscope/version-update
Update setup.py
2025-04-02 16:35:01 +08:00
Zhongjie Duan
4864453e0a Update setup.py 2025-04-02 16:34:50 +08:00
Zhongjie Duan
c5a32f76c2 Merge pull request #509 from modelscope/wan-lora-converter
Update lora.py
2025-04-02 13:08:48 +08:00
Zhongjie Duan
c4ed3d3e4b Update lora.py 2025-04-02 13:08:16 +08:00
Zhongjie Duan
803ddcccc7 Merge pull request #505 from modelscope/infinityou
Infinityou
2025-03-31 20:21:10 +08:00
Artiprocher
4cd51fecf2 refine infinityou 2025-03-31 20:19:32 +08:00
Zhongjie Duan
3b0211a547 Merge pull request #499 from calmhawk/hotfix/tc_bug_with_usp
Fix TeaCache bug and optimize memory usage of WAN with USP feature
2025-03-31 16:24:03 +08:00
mi804
e88328d152 support infiniteyou 2025-03-31 14:29:15 +08:00
calmhawk
52896fa8dd Fix TeaCache bug with usp support integration and optimize memory usage by clearing attn cache 2025-03-30 01:13:34 +08:00
Zhongjie Duan
c7035ad911 Merge pull request #493 from modelscope/lzws-patch-1
Update wan_video.py
2025-03-26 19:48:33 +08:00
lzws
070811e517 Update wan_video.py
prompter.encode_prompt use pipe's deivce
2025-03-26 17:51:13 +08:00
Zhongjie Duan
7e010d88a5 Merge pull request #485 from modelscope/usp
support Unified Sequence Parallel
2025-03-25 19:28:42 +08:00
Artiprocher
4e43d4d461 fix usp dependency 2025-03-25 19:26:24 +08:00
Zhongjie Duan
d7efe7e539 Merge pull request #482 from modelscope/Artiprocher-patch-1
Update README.md
2025-03-25 16:44:48 +08:00
Zhongjie Duan
633f789c47 Update README.md 2025-03-25 16:44:05 +08:00
Zhongjie Duan
88607f404e Merge pull request #480 from mi804/wanx_tensor_parallel
update tensor parallel
2025-03-25 15:33:15 +08:00
mi804
6d405b669c update tensor parallel 2025-03-25 12:38:17 +08:00
ByteDance
d0fed6ba72 add usp for wanx 2025-03-25 11:51:37 +08:00
ByteDance
64eaa0d76a Merge branch 'usp' into xdit 2025-03-25 11:45:49 +08:00
Jinzhe Pan
54081bdcbb Merge pull request #1 from Eigensystem/fjr
fix some bugs
2025-03-17 17:07:07 +08:00
feifeibear
d8b250607a polish code 2025-03-17 09:04:51 +00:00
feifeibear
1e58e6ef82 fix some bugs 2025-03-17 09:00:52 +00:00
Jinzhe Pan
42cb7d96bb feat: sp for wan 2025-03-17 08:31:45 +00:00
33 changed files with 1114 additions and 1196 deletions

View File

@@ -13,9 +13,15 @@ Document: https://diffsynth-studio.readthedocs.io/zh-cn/latest/index.html
## Introduction
DiffSynth Studio is a Diffusion engine. We have restructured architectures including Text Encoder, UNet, VAE, among others, maintaining compatibility with models from the open-source community while enhancing computational performance. We provide many interesting features. Enjoy the magic of Diffusion models!
Welcome to the magic world of Diffusion models!
Until now, DiffSynth Studio has supported the following models:
DiffSynth consists of two open-source projects:
* [DiffSynth-Studio](https://github.com/modelscope/DiffSynth-Studio): Focused on aggressive technological exploration. Targeted at academia. Provides more cutting-edge technical support and novel inference capabilities.
* [DiffSynth-Engine](https://github.com/modelscope/DiffSynth-Engine): Focused on stable model deployment. Geared towards industry. Offers better engineering support, higher computational performance, and more stable functionality.
DiffSynth-Studio is an open-source project aimed at exploring innovations in AIGC technology. We have integrated numerous open-source Diffusion models, including FLUX and Wan, among others. Through this open-source project, we hope to connect models within the open-source community and explore new technologies based on diffusion models.
Until now, DiffSynth-Studio has supported the following models:
* [Wan-Video](https://github.com/Wan-Video/Wan2.1)
* [StepVideo](https://github.com/stepfun-ai/Step-Video-T2V)
@@ -36,7 +42,11 @@ Until now, DiffSynth Studio has supported the following models:
* [Stable Diffusion](https://huggingface.co/runwayml/stable-diffusion-v1-5)
## News
- **March 25, 2025** We support HunyuanVideo-I2V, the image-to-video generation version of HunyuanVideo open-sourced by Tencent. Please refer to [./examples/HunyuanVideo/](./examples/HunyuanVideo/) for more details.
- **March 31, 2025** We support InfiniteYou, an identity preserving method for FLUX. Please refer to [./examples/InfiniteYou/](./examples/InfiniteYou/) for more details.
- **March 25, 2025** 🔥🔥🔥 Our new open-source project, [DiffSynth-Engine](https://github.com/modelscope/DiffSynth-Engine), is now open-sourced! Focused on stable model deployment. Geared towards industry. Offers better engineering support, higher computational performance, and more stable functionality.
- **March 13, 2025** We support HunyuanVideo-I2V, the image-to-video generation version of HunyuanVideo open-sourced by Tencent. Please refer to [./examples/HunyuanVideo/](./examples/HunyuanVideo/) for more details.
- **February 25, 2025** We support Wan-Video, a collection of SOTA video synthesis models open-sourced by Alibaba. See [./examples/wanvideo/](./examples/wanvideo/).
@@ -73,7 +83,7 @@ Until now, DiffSynth Studio has supported the following models:
- Enable CFG and highres-fix to improve visual quality. See [here](/examples/image_synthesis/README.md)
- LoRA, ControlNet, and additional models will be available soon.
- **June 21, 2024.** 🔥🔥🔥 We propose ExVideo, a post-tuning technique aimed at enhancing the capability of video generation models. We have extended Stable Video Diffusion to achieve the generation of long videos up to 128 frames.
- **June 21, 2024.** We propose ExVideo, a post-tuning technique aimed at enhancing the capability of video generation models. We have extended Stable Video Diffusion to achieve the generation of long videos up to 128 frames.
- [Project Page](https://ecnu-cilab.github.io/ExVideoProjectPage/)
- Source code is released in this repo. See [`examples/ExVideo`](./examples/ExVideo/).
- Models are released on [HuggingFace](https://huggingface.co/ECNU-CILab/ExVideo-SVD-128f-v1) and [ModelScope](https://modelscope.cn/models/ECNU-CILab/ExVideo-SVD-128f-v1).

View File

@@ -37,6 +37,7 @@ from ..models.flux_text_encoder import FluxTextEncoder2
from ..models.flux_vae import FluxVAEEncoder, FluxVAEDecoder
from ..models.flux_controlnet import FluxControlNet
from ..models.flux_ipadapter import FluxIpAdapter
from ..models.flux_infiniteyou import InfiniteYouImageProjector
from ..models.cog_vae import CogVAEEncoder, CogVAEDecoder
from ..models.cog_dit import CogDiT
@@ -58,6 +59,7 @@ from ..models.wan_video_dit import WanModel
from ..models.wan_video_text_encoder import WanTextEncoder
from ..models.wan_video_image_encoder import WanImageEncoder
from ..models.wan_video_vae import WanVideoVAE
from ..models.wan_video_motion_controller import WanMotionControllerModel
model_loader_configs = [
@@ -104,6 +106,8 @@ model_loader_configs = [
(None, "b001c89139b5f053c715fe772362dd2a", ["flux_controlnet"], [FluxControlNet], "diffusers"),
(None, "52357cb26250681367488a8954c271e8", ["flux_controlnet"], [FluxControlNet], "diffusers"),
(None, "0cfd1740758423a2a854d67c136d1e8c", ["flux_controlnet"], [FluxControlNet], "diffusers"),
(None, "7f9583eb8ba86642abb9a21a4b2c9e16", ["flux_controlnet"], [FluxControlNet], "diffusers"),
(None, "c07c0f04f5ff55e86b4e937c7a40d481", ["infiniteyou_image_projector"], [InfiniteYouImageProjector], "diffusers"),
(None, "4daaa66cc656a8fe369908693dad0a35", ["flux_ipadapter"], [FluxIpAdapter], "diffusers"),
(None, "51aed3d27d482fceb5e0739b03060e8f", ["sd3_dit", "sd3_vae_encoder", "sd3_vae_decoder"], [SD3DiT, SD3VAEEncoder, SD3VAEDecoder], "civitai"),
(None, "98cc34ccc5b54ae0e56bdea8688dcd5a", ["sd3_text_encoder_2"], [SD3TextEncoder2], "civitai"),
@@ -117,11 +121,16 @@ model_loader_configs = [
(None, "9269f8db9040a9d860eaca435be61814", ["wan_video_dit"], [WanModel], "civitai"),
(None, "aafcfd9672c3a2456dc46e1cb6e52c70", ["wan_video_dit"], [WanModel], "civitai"),
(None, "6bfcfb3b342cb286ce886889d519a77e", ["wan_video_dit"], [WanModel], "civitai"),
(None, "6d6ccde6845b95ad9114ab993d917893", ["wan_video_dit"], [WanModel], "civitai"),
(None, "6bfcfb3b342cb286ce886889d519a77e", ["wan_video_dit"], [WanModel], "civitai"),
(None, "349723183fc063b2bfc10bb2835cf677", ["wan_video_dit"], [WanModel], "civitai"),
(None, "efa44cddf936c70abd0ea28b6cbe946c", ["wan_video_dit"], [WanModel], "civitai"),
(None, "cb104773c6c2cb6df4f9529ad5c60d0b", ["wan_video_dit"], [WanModel], "diffusers"),
(None, "9c8818c2cbea55eca56c7b447df170da", ["wan_video_text_encoder"], [WanTextEncoder], "civitai"),
(None, "5941c53e207d62f20f9025686193c40b", ["wan_video_image_encoder"], [WanImageEncoder], "civitai"),
(None, "1378ea763357eea97acdef78e65d6d96", ["wan_video_vae"], [WanVideoVAE], "civitai"),
(None, "ccc42284ea13e1ad04693284c7a09be6", ["wan_video_vae"], [WanVideoVAE], "civitai"),
(None, "dbd5ec76bbf977983f972c151d545389", ["wan_video_motion_controller"], [WanMotionControllerModel], "civitai"),
]
huggingface_model_loader_configs = [
# These configs are provided for detecting model type automatically.
@@ -598,6 +607,25 @@ preset_models_on_modelscope = {
"models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/image_encoder",
],
},
"InfiniteYou":{
"file_list":[
("ByteDance/InfiniteYou", "infu_flux_v1.0/aes_stage2/InfuseNetModel/diffusion_pytorch_model-00001-of-00002.safetensors", "models/InfiniteYou/InfuseNetModel"),
("ByteDance/InfiniteYou", "infu_flux_v1.0/aes_stage2/InfuseNetModel/diffusion_pytorch_model-00002-of-00002.safetensors", "models/InfiniteYou/InfuseNetModel"),
("ByteDance/InfiniteYou", "infu_flux_v1.0/aes_stage2/image_proj_model.bin", "models/InfiniteYou"),
("ByteDance/InfiniteYou", "supports/insightface/models/antelopev2/1k3d68.onnx", "models/InfiniteYou/insightface/models/antelopev2"),
("ByteDance/InfiniteYou", "supports/insightface/models/antelopev2/2d106det.onnx", "models/InfiniteYou/insightface/models/antelopev2"),
("ByteDance/InfiniteYou", "supports/insightface/models/antelopev2/genderage.onnx", "models/InfiniteYou/insightface/models/antelopev2"),
("ByteDance/InfiniteYou", "supports/insightface/models/antelopev2/glintr100.onnx", "models/InfiniteYou/insightface/models/antelopev2"),
("ByteDance/InfiniteYou", "supports/insightface/models/antelopev2/scrfd_10g_bnkps.onnx", "models/InfiniteYou/insightface/models/antelopev2"),
],
"load_path":[
[
"models/InfiniteYou/InfuseNetModel/diffusion_pytorch_model-00001-of-00002.safetensors",
"models/InfiniteYou/InfuseNetModel/diffusion_pytorch_model-00002-of-00002.safetensors"
],
"models/InfiniteYou/image_proj_model.bin",
],
},
# ESRGAN
"ESRGAN_x4": [
("AI-ModelScope/Real-ESRGAN", "RealESRGAN_x4.pth", "models/ESRGAN"),
@@ -757,6 +785,7 @@ Preset_model_id: TypeAlias = Literal[
"Shakker-Labs/FLUX.1-dev-ControlNet-Depth",
"Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro",
"InstantX/FLUX.1-dev-IP-Adapter",
"InfiniteYou",
"SDXL_lora_zyd232_ChineseInkStyle_SDXL_v1_0",
"QwenPrompt",
"OmostPrompt",

View File

@@ -0,0 +1,129 @@
import torch
from typing import Optional
from einops import rearrange
from xfuser.core.distributed import (get_sequence_parallel_rank,
get_sequence_parallel_world_size,
get_sp_group)
from xfuser.core.long_ctx_attention import xFuserLongContextAttention
def sinusoidal_embedding_1d(dim, position):
sinusoid = torch.outer(position.type(torch.float64), torch.pow(
10000, -torch.arange(dim//2, dtype=torch.float64, device=position.device).div(dim//2)))
x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1)
return x.to(position.dtype)
def pad_freqs(original_tensor, target_len):
seq_len, s1, s2 = original_tensor.shape
pad_size = target_len - seq_len
padding_tensor = torch.ones(
pad_size,
s1,
s2,
dtype=original_tensor.dtype,
device=original_tensor.device)
padded_tensor = torch.cat([original_tensor, padding_tensor], dim=0)
return padded_tensor
def rope_apply(x, freqs, num_heads):
x = rearrange(x, "b s (n d) -> b s n d", n=num_heads)
s_per_rank = x.shape[1]
x_out = torch.view_as_complex(x.to(torch.float64).reshape(
x.shape[0], x.shape[1], x.shape[2], -1, 2))
sp_size = get_sequence_parallel_world_size()
sp_rank = get_sequence_parallel_rank()
freqs = pad_freqs(freqs, s_per_rank * sp_size)
freqs_rank = freqs[(sp_rank * s_per_rank):((sp_rank + 1) * s_per_rank), :, :]
x_out = torch.view_as_real(x_out * freqs_rank).flatten(2)
return x_out.to(x.dtype)
def usp_dit_forward(self,
x: torch.Tensor,
timestep: torch.Tensor,
context: torch.Tensor,
clip_feature: Optional[torch.Tensor] = None,
y: Optional[torch.Tensor] = None,
use_gradient_checkpointing: bool = False,
use_gradient_checkpointing_offload: bool = False,
**kwargs,
):
t = self.time_embedding(
sinusoidal_embedding_1d(self.freq_dim, timestep))
t_mod = self.time_projection(t).unflatten(1, (6, self.dim))
context = self.text_embedding(context)
if self.has_image_input:
x = torch.cat([x, y], dim=1) # (b, c_x + c_y, f, h, w)
clip_embdding = self.img_emb(clip_feature)
context = torch.cat([clip_embdding, context], dim=1)
x, (f, h, w) = self.patchify(x)
freqs = torch.cat([
self.freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
self.freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
self.freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
], dim=-1).reshape(f * h * w, 1, -1).to(x.device)
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs)
return custom_forward
# Context Parallel
x = torch.chunk(
x, get_sequence_parallel_world_size(),
dim=1)[get_sequence_parallel_rank()]
for block in self.blocks:
if self.training and use_gradient_checkpointing:
if use_gradient_checkpointing_offload:
with torch.autograd.graph.save_on_cpu():
x = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
x, context, t_mod, freqs,
use_reentrant=False,
)
else:
x = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
x, context, t_mod, freqs,
use_reentrant=False,
)
else:
x = block(x, context, t_mod, freqs)
x = self.head(x, t)
# Context Parallel
x = get_sp_group().all_gather(x, dim=1)
# unpatchify
x = self.unpatchify(x, (f, h, w))
return x
def usp_attn_forward(self, x, freqs):
q = self.norm_q(self.q(x))
k = self.norm_k(self.k(x))
v = self.v(x)
q = rope_apply(q, freqs, self.num_heads)
k = rope_apply(k, freqs, self.num_heads)
q = rearrange(q, "b s (n d) -> b s n d", n=self.num_heads)
k = rearrange(k, "b s (n d) -> b s n d", n=self.num_heads)
v = rearrange(v, "b s (n d) -> b s n d", n=self.num_heads)
x = xFuserLongContextAttention()(
None,
query=q,
key=k,
value=v,
)
x = x.flatten(2)
del q, k, v
torch.cuda.empty_cache()
return self.o(x)

View File

@@ -318,6 +318,8 @@ class FluxControlNetStateDictConverter:
extra_kwargs = {"num_joint_blocks": 6, "num_single_blocks": 0, "additional_input_dim": 4}
elif hash_value == "0cfd1740758423a2a854d67c136d1e8c":
extra_kwargs = {"num_joint_blocks": 4, "num_single_blocks": 1}
elif hash_value == "7f9583eb8ba86642abb9a21a4b2c9e16":
extra_kwargs = {"num_joint_blocks": 4, "num_single_blocks": 10}
else:
extra_kwargs = {}
return state_dict_, extra_kwargs

View File

@@ -41,30 +41,6 @@ class RoPEEmbedding(torch.nn.Module):
emb = torch.cat([self.rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)], dim=-3)
return emb.unsqueeze(1)
class AdaLayerNorm(torch.nn.Module):
def __init__(self, dim, single=False, dual=False):
super().__init__()
self.single = single
self.dual = dual
self.linear = torch.nn.Linear(dim, dim * [[6, 2][single], 9][dual])
self.norm = torch.nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
def forward(self, x, emb, **kwargs):
emb = self.linear(torch.nn.functional.silu(emb),**kwargs)
if self.single:
scale, shift = emb.unsqueeze(1).chunk(2, dim=2)
x = self.norm(x) * (1 + scale) + shift
return x
elif self.dual:
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp, shift_msa2, scale_msa2, gate_msa2 = emb.unsqueeze(1).chunk(9, dim=2)
norm_x = self.norm(x)
x = norm_x * (1 + scale_msa) + shift_msa
norm_x2 = norm_x * (1 + scale_msa2) + shift_msa2
return x, gate_msa, shift_mlp, scale_mlp, gate_mlp, norm_x2, gate_msa2
else:
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.unsqueeze(1).chunk(6, dim=2)
x = self.norm(x) * (1 + scale_msa) + shift_msa
return x, gate_msa, shift_mlp, scale_mlp, gate_mlp
class FluxJointAttention(torch.nn.Module):
@@ -94,17 +70,17 @@ class FluxJointAttention(torch.nn.Module):
xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
def forward(self, hidden_states_a, hidden_states_b, image_rotary_emb, attn_mask=None, ipadapter_kwargs_list=None, **kwargs):
def forward(self, hidden_states_a, hidden_states_b, image_rotary_emb, attn_mask=None, ipadapter_kwargs_list=None):
batch_size = hidden_states_a.shape[0]
# Part A
qkv_a = self.a_to_qkv(hidden_states_a,**kwargs)
qkv_a = self.a_to_qkv(hidden_states_a)
qkv_a = qkv_a.view(batch_size, -1, 3 * self.num_heads, self.head_dim).transpose(1, 2)
q_a, k_a, v_a = qkv_a.chunk(3, dim=1)
q_a, k_a = self.norm_q_a(q_a), self.norm_k_a(k_a)
# Part B
qkv_b = self.b_to_qkv(hidden_states_b,**kwargs)
qkv_b = self.b_to_qkv(hidden_states_b)
qkv_b = qkv_b.view(batch_size, -1, 3 * self.num_heads, self.head_dim).transpose(1, 2)
q_b, k_b, v_b = qkv_b.chunk(3, dim=1)
q_b, k_b = self.norm_q_b(q_b), self.norm_k_b(k_b)
@@ -121,25 +97,13 @@ class FluxJointAttention(torch.nn.Module):
hidden_states_b, hidden_states_a = hidden_states[:, :hidden_states_b.shape[1]], hidden_states[:, hidden_states_b.shape[1]:]
if ipadapter_kwargs_list is not None:
hidden_states_a = interact_with_ipadapter(hidden_states_a, q_a, **ipadapter_kwargs_list)
hidden_states_a = self.a_to_out(hidden_states_a,**kwargs)
hidden_states_a = self.a_to_out(hidden_states_a)
if self.only_out_a:
return hidden_states_a
else:
hidden_states_b = self.b_to_out(hidden_states_b,**kwargs)
hidden_states_b = self.b_to_out(hidden_states_b)
return hidden_states_a, hidden_states_b
class AutoSequential(torch.nn.Sequential):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def forward(self, input, **kwargs):
for module in self:
if isinstance(module, torch.nn.Linear):
# print("##"*10)
input = module(input, **kwargs)
else:
input = module(input)
return input
class FluxJointTransformerBlock(torch.nn.Module):
@@ -156,11 +120,6 @@ class FluxJointTransformerBlock(torch.nn.Module):
torch.nn.GELU(approximate="tanh"),
torch.nn.Linear(dim*4, dim)
)
# self.ff_a = AutoSequential(
# torch.nn.Linear(dim, dim*4),
# torch.nn.GELU(approximate="tanh"),
# torch.nn.Linear(dim*4, dim)
# )
self.norm2_b = torch.nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
self.ff_b = torch.nn.Sequential(
@@ -168,18 +127,14 @@ class FluxJointTransformerBlock(torch.nn.Module):
torch.nn.GELU(approximate="tanh"),
torch.nn.Linear(dim*4, dim)
)
# self.ff_b = AutoSequential(
# torch.nn.Linear(dim, dim*4),
# torch.nn.GELU(approximate="tanh"),
# torch.nn.Linear(dim*4, dim)
# )
def forward(self, hidden_states_a, hidden_states_b, temb, image_rotary_emb, attn_mask=None, ipadapter_kwargs_list=None, **kwargs):
norm_hidden_states_a, gate_msa_a, shift_mlp_a, scale_mlp_a, gate_mlp_a = self.norm1_a(hidden_states_a, emb=temb, **kwargs)
norm_hidden_states_b, gate_msa_b, shift_mlp_b, scale_mlp_b, gate_mlp_b = self.norm1_b(hidden_states_b, emb=temb, **kwargs)
def forward(self, hidden_states_a, hidden_states_b, temb, image_rotary_emb, attn_mask=None, ipadapter_kwargs_list=None):
norm_hidden_states_a, gate_msa_a, shift_mlp_a, scale_mlp_a, gate_mlp_a = self.norm1_a(hidden_states_a, emb=temb)
norm_hidden_states_b, gate_msa_b, shift_mlp_b, scale_mlp_b, gate_mlp_b = self.norm1_b(hidden_states_b, emb=temb)
# Attention
attn_output_a, attn_output_b = self.attn(norm_hidden_states_a, norm_hidden_states_b, image_rotary_emb, attn_mask, ipadapter_kwargs_list, **kwargs)
attn_output_a, attn_output_b = self.attn(norm_hidden_states_a, norm_hidden_states_b, image_rotary_emb, attn_mask, ipadapter_kwargs_list)
# Part A
hidden_states_a = hidden_states_a + gate_msa_a * attn_output_a
@@ -194,6 +149,7 @@ class FluxJointTransformerBlock(torch.nn.Module):
return hidden_states_a, hidden_states_b
class FluxSingleAttention(torch.nn.Module):
def __init__(self, dim_a, dim_b, num_heads, head_dim):
super().__init__()
@@ -214,10 +170,10 @@ class FluxSingleAttention(torch.nn.Module):
return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
def forward(self, hidden_states, image_rotary_emb, **kwargs):
def forward(self, hidden_states, image_rotary_emb):
batch_size = hidden_states.shape[0]
qkv_a = self.a_to_qkv(hidden_states,**kwargs)
qkv_a = self.a_to_qkv(hidden_states)
qkv_a = qkv_a.view(batch_size, -1, 3 * self.num_heads, self.head_dim).transpose(1, 2)
q_a, k_a, v = qkv_a.chunk(3, dim=1)
q_a, k_a = self.norm_q_a(q_a), self.norm_k_a(k_a)
@@ -239,8 +195,8 @@ class AdaLayerNormSingle(torch.nn.Module):
self.norm = torch.nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
def forward(self, x, emb, **kwargs):
emb = self.linear(self.silu(emb),**kwargs)
def forward(self, x, emb):
emb = self.linear(self.silu(emb))
shift_msa, scale_msa, gate_msa = emb.chunk(3, dim=1)
x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
return x, gate_msa
@@ -270,7 +226,7 @@ class FluxSingleTransformerBlock(torch.nn.Module):
return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
def process_attention(self, hidden_states, image_rotary_emb, attn_mask=None, ipadapter_kwargs_list=None, **kwargs):
def process_attention(self, hidden_states, image_rotary_emb, attn_mask=None, ipadapter_kwargs_list=None):
batch_size = hidden_states.shape[0]
qkv = hidden_states.view(batch_size, -1, 3 * self.num_heads, self.head_dim).transpose(1, 2)
@@ -287,17 +243,17 @@ class FluxSingleTransformerBlock(torch.nn.Module):
return hidden_states
def forward(self, hidden_states_a, hidden_states_b, temb, image_rotary_emb, attn_mask=None, ipadapter_kwargs_list=None, **kwargs):
def forward(self, hidden_states_a, hidden_states_b, temb, image_rotary_emb, attn_mask=None, ipadapter_kwargs_list=None):
residual = hidden_states_a
norm_hidden_states, gate = self.norm(hidden_states_a, emb=temb, **kwargs)
hidden_states_a = self.to_qkv_mlp(norm_hidden_states, **kwargs)
norm_hidden_states, gate = self.norm(hidden_states_a, emb=temb)
hidden_states_a = self.to_qkv_mlp(norm_hidden_states)
attn_output, mlp_hidden_states = hidden_states_a[:, :, :self.dim * 3], hidden_states_a[:, :, self.dim * 3:]
attn_output = self.process_attention(attn_output, image_rotary_emb, attn_mask, ipadapter_kwargs_list, **kwargs)
attn_output = self.process_attention(attn_output, image_rotary_emb, attn_mask, ipadapter_kwargs_list)
mlp_hidden_states = torch.nn.functional.gelu(mlp_hidden_states, approximate="tanh")
hidden_states_a = torch.cat([attn_output, mlp_hidden_states], dim=2)
hidden_states_a = gate.unsqueeze(1) * self.proj_out(hidden_states_a, **kwargs)
hidden_states_a = gate.unsqueeze(1) * self.proj_out(hidden_states_a)
hidden_states_a = residual + hidden_states_a
return hidden_states_a, hidden_states_b
@@ -311,13 +267,14 @@ class AdaLayerNormContinuous(torch.nn.Module):
self.linear = torch.nn.Linear(dim, dim * 2, bias=True)
self.norm = torch.nn.LayerNorm(dim, eps=1e-6, elementwise_affine=False)
def forward(self, x, conditioning, **kwargs):
emb = self.linear(self.silu(conditioning),**kwargs)
def forward(self, x, conditioning):
emb = self.linear(self.silu(conditioning))
scale, shift = torch.chunk(emb, 2, dim=1)
x = self.norm(x) * (1 + scale)[:, None] + shift[:, None]
return x
class FluxDiT(torch.nn.Module):
def __init__(self, disable_guidance_embedder=False):
super().__init__()
@@ -325,8 +282,6 @@ class FluxDiT(torch.nn.Module):
self.time_embedder = TimestepEmbeddings(256, 3072)
self.guidance_embedder = None if disable_guidance_embedder else TimestepEmbeddings(256, 3072)
self.pooled_text_embedder = torch.nn.Sequential(torch.nn.Linear(768, 3072), torch.nn.SiLU(), torch.nn.Linear(3072, 3072))
# self.pooled_text_embedder = AutoSequential(torch.nn.Linear(768, 3072), torch.nn.SiLU(), torch.nn.Linear(3072, 3072))
self.context_embedder = torch.nn.Linear(4096, 3072)
self.x_embedder = torch.nn.Linear(64, 3072)
@@ -473,12 +428,12 @@ class FluxDiT(torch.nn.Module):
height, width = hidden_states.shape[-2:]
hidden_states = self.patchify(hidden_states)
hidden_states = self.x_embedder(hidden_states,**kwargs)
hidden_states = self.x_embedder(hidden_states)
if entity_prompt_emb is not None and entity_masks is not None:
prompt_emb, image_rotary_emb, attention_mask = self.process_entity_masks(hidden_states, prompt_emb, entity_prompt_emb, entity_masks, text_ids, image_ids)
else:
prompt_emb = self.context_embedder(prompt_emb, **kwargs)
prompt_emb = self.context_embedder(prompt_emb)
image_rotary_emb = self.pos_embedder(torch.cat((text_ids, image_ids), dim=1))
attention_mask = None
@@ -491,26 +446,26 @@ class FluxDiT(torch.nn.Module):
if self.training and use_gradient_checkpointing:
hidden_states, prompt_emb = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
hidden_states, prompt_emb, conditioning, image_rotary_emb, attention_mask, **kwargs,
hidden_states, prompt_emb, conditioning, image_rotary_emb, attention_mask,
use_reentrant=False,
)
else:
hidden_states, prompt_emb = block(hidden_states, prompt_emb, conditioning, image_rotary_emb, attention_mask, **kwargs)
hidden_states, prompt_emb = block(hidden_states, prompt_emb, conditioning, image_rotary_emb, attention_mask)
hidden_states = torch.cat([prompt_emb, hidden_states], dim=1)
for block in self.single_blocks:
if self.training and use_gradient_checkpointing:
hidden_states, prompt_emb = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
hidden_states, prompt_emb, conditioning, image_rotary_emb, attention_mask, **kwargs,
hidden_states, prompt_emb, conditioning, image_rotary_emb, attention_mask,
use_reentrant=False,
)
else:
hidden_states, prompt_emb = block(hidden_states, prompt_emb, conditioning, image_rotary_emb, attention_mask, **kwargs)
hidden_states, prompt_emb = block(hidden_states, prompt_emb, conditioning, image_rotary_emb, attention_mask)
hidden_states = hidden_states[:, prompt_emb.shape[1]:]
hidden_states = self.final_norm_out(hidden_states, conditioning, **kwargs)
hidden_states = self.final_proj_out(hidden_states, **kwargs)
hidden_states = self.final_norm_out(hidden_states, conditioning)
hidden_states = self.final_proj_out(hidden_states)
hidden_states = self.unpatchify(hidden_states, height, width)
return hidden_states
@@ -651,10 +606,6 @@ class FluxDiTStateDictConverter:
for name, param in state_dict.items():
if name.endswith(".weight") or name.endswith(".bias"):
suffix = ".weight" if name.endswith(".weight") else ".bias"
if "lora_B" in name:
suffix = ".lora_B" + suffix
if "lora_A" in name:
suffix = ".lora_A" + suffix
prefix = name[:-len(suffix)]
if prefix in global_rename_dict:
state_dict_[global_rename_dict[prefix] + suffix] = param
@@ -679,73 +630,29 @@ class FluxDiTStateDictConverter:
for name in list(state_dict_.keys()):
if "single_blocks." in name and ".a_to_q." in name:
mlp = state_dict_.get(name.replace(".a_to_q.", ".proj_in_besides_attn."), None)
if mlp is None:
dim = 4
if 'lora_A' in name:
dim = 1
mlp = torch.zeros(dim * state_dict_[name].shape[0],
mlp = torch.zeros(4 * state_dict_[name].shape[0],
*state_dict_[name].shape[1:],
dtype=state_dict_[name].dtype)
else:
# print('$$'*10)
# mlp_name = name.replace(".a_to_q.", ".proj_in_besides_attn.")
# print(f'mlp name: {mlp_name}')
# print(f'mlp shape: {mlp.shape}')
state_dict_.pop(name.replace(".a_to_q.", ".proj_in_besides_attn."))
# print(f'mlp shape: {mlp.shape}')
if 'lora_A' in name:
param = torch.concat([
state_dict_.pop(name),
state_dict_.pop(name.replace(".a_to_q.", ".a_to_k.")),
state_dict_.pop(name.replace(".a_to_q.", ".a_to_v.")),
mlp,
], dim=0)
elif 'lora_B' in name:
# create zreo matrix
d, r = state_dict_[name].shape
# print('--'*10)
# print(d, r)
param = torch.zeros((3*d+mlp.shape[0], 3*r+mlp.shape[1]), dtype=state_dict_[name].dtype, device=state_dict_[name].device)
param[:d, :r] = state_dict_.pop(name)
param[d:2*d, r:2*r] = state_dict_.pop(name.replace(".a_to_q.", ".a_to_k."))
param[2*d:3*d, 2*r:3*r] = state_dict_.pop(name.replace(".a_to_q.", ".a_to_v."))
param[3*d:, 3*r:] = mlp
else:
param = torch.concat([
state_dict_.pop(name),
state_dict_.pop(name.replace(".a_to_q.", ".a_to_k.")),
state_dict_.pop(name.replace(".a_to_q.", ".a_to_v.")),
mlp,
], dim=0)
param = torch.concat([
state_dict_.pop(name),
state_dict_.pop(name.replace(".a_to_q.", ".a_to_k.")),
state_dict_.pop(name.replace(".a_to_q.", ".a_to_v.")),
mlp,
], dim=0)
name_ = name.replace(".a_to_q.", ".to_qkv_mlp.")
state_dict_[name_] = param
for name in list(state_dict_.keys()):
for component in ["a", "b"]:
if f".{component}_to_q." in name:
name_ = name.replace(f".{component}_to_q.", f".{component}_to_qkv.")
concat_dim = 0
if 'lora_A' in name:
param = torch.concat([
state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_q.")],
state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_k.")],
state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_v.")],
], dim=0)
elif 'lora_B' in name:
origin = state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_q.")]
d, r = origin.shape
# print(d, r)
param = torch.zeros((3*d, 3*r), dtype=origin.dtype, device=origin.device)
param[:d, :r] = state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_q.")]
param[d:2*d, r:2*r] = state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_k.")]
param[2*d:3*d, 2*r:3*r] = state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_v.")]
else:
param = torch.concat([
state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_q.")],
state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_k.")],
state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_v.")],
], dim=0)
param = torch.concat([
state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_q.")],
state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_k.")],
state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_v.")],
], dim=0)
state_dict_[name_] = param
state_dict_.pop(name.replace(f".{component}_to_q.", f".{component}_to_q."))
state_dict_.pop(name.replace(f".{component}_to_q.", f".{component}_to_k."))
@@ -811,48 +718,22 @@ class FluxDiTStateDictConverter:
"norm.query_norm.scale": "norm_q_a.weight",
}
state_dict_ = {}
for name, param in state_dict.items():
# match lora load
l_name = ''
if 'lora_A' in name :
l_name = 'lora_A'
if 'lora_B' in name :
l_name = 'lora_B'
if l_name != '':
name = name.replace(l_name+'.', '')
if name.startswith("model.diffusion_model."):
name = name[len("model.diffusion_model."):]
names = name.split(".")
if name in rename_dict:
rename = rename_dict[name]
if name.startswith("final_layer.adaLN_modulation.1."):
if l_name == 'lora_A':
param = torch.concat([param[:,3072:], param[:,:3072]], dim=1)
else:
param = torch.concat([param[3072:], param[:3072]], dim=0)
if l_name != '':
state_dict_[rename.replace('weight',l_name+'.weight')] = param
else:
state_dict_[rename] = param
param = torch.concat([param[3072:], param[:3072]], dim=0)
state_dict_[rename] = param
elif names[0] == "double_blocks":
rename = f"blocks.{names[1]}." + suffix_rename_dict[".".join(names[2:])]
if l_name != '':
state_dict_[rename.replace('weight',l_name+'.weight')] = param
else:
state_dict_[rename] = param
state_dict_[rename] = param
elif names[0] == "single_blocks":
if ".".join(names[2:]) in suffix_rename_dict:
rename = f"single_blocks.{names[1]}." + suffix_rename_dict[".".join(names[2:])]
if l_name != '':
state_dict_[rename.replace('weight',l_name+'.weight')] = param
else:
state_dict_[rename] = param
state_dict_[rename] = param
else:
pass
if "guidance_embedder.timestep_embedder.0.weight" not in state_dict_:

View File

@@ -0,0 +1,128 @@
import math
import torch
import torch.nn as nn
# FFN
def FeedForward(dim, mult=4):
inner_dim = int(dim * mult)
return nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, inner_dim, bias=False),
nn.GELU(),
nn.Linear(inner_dim, dim, bias=False),
)
def reshape_tensor(x, heads):
bs, length, width = x.shape
#(bs, length, width) --> (bs, length, n_heads, dim_per_head)
x = x.view(bs, length, heads, -1)
# (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head)
x = x.transpose(1, 2)
# (bs, n_heads, length, dim_per_head) --> (bs*n_heads, length, dim_per_head)
x = x.reshape(bs, heads, length, -1)
return x
class PerceiverAttention(nn.Module):
def __init__(self, *, dim, dim_head=64, heads=8):
super().__init__()
self.scale = dim_head**-0.5
self.dim_head = dim_head
self.heads = heads
inner_dim = dim_head * heads
self.norm1 = nn.LayerNorm(dim)
self.norm2 = nn.LayerNorm(dim)
self.to_q = nn.Linear(dim, inner_dim, bias=False)
self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
self.to_out = nn.Linear(inner_dim, dim, bias=False)
def forward(self, x, latents):
"""
Args:
x (torch.Tensor): image features
shape (b, n1, D)
latent (torch.Tensor): latent features
shape (b, n2, D)
"""
x = self.norm1(x)
latents = self.norm2(latents)
b, l, _ = latents.shape
q = self.to_q(latents)
kv_input = torch.cat((x, latents), dim=-2)
k, v = self.to_kv(kv_input).chunk(2, dim=-1)
q = reshape_tensor(q, self.heads)
k = reshape_tensor(k, self.heads)
v = reshape_tensor(v, self.heads)
# attention
scale = 1 / math.sqrt(math.sqrt(self.dim_head))
weight = (q * scale) @ (k * scale).transpose(-2, -1) # More stable with f16 than dividing afterwards
weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
out = weight @ v
out = out.permute(0, 2, 1, 3).reshape(b, l, -1)
return self.to_out(out)
class InfiniteYouImageProjector(nn.Module):
def __init__(
self,
dim=1280,
depth=4,
dim_head=64,
heads=20,
num_queries=8,
embedding_dim=512,
output_dim=4096,
ff_mult=4,
):
super().__init__()
self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim**0.5)
self.proj_in = nn.Linear(embedding_dim, dim)
self.proj_out = nn.Linear(dim, output_dim)
self.norm_out = nn.LayerNorm(output_dim)
self.layers = nn.ModuleList([])
for _ in range(depth):
self.layers.append(
nn.ModuleList([
PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
FeedForward(dim=dim, mult=ff_mult),
]))
def forward(self, x):
latents = self.latents.repeat(x.size(0), 1, 1)
x = self.proj_in(x)
for attn, ff in self.layers:
latents = attn(x, latents) + latents
latents = ff(latents) + latents
latents = self.proj_out(latents)
return self.norm_out(latents)
@staticmethod
def state_dict_converter():
return FluxInfiniteYouImageProjectorStateDictConverter()
class FluxInfiniteYouImageProjectorStateDictConverter:
def __init__(self):
pass
def from_diffusers(self, state_dict):
return state_dict['image_proj']

View File

@@ -26,12 +26,6 @@ class LoRAFromCivitai:
return self.convert_state_dict_up_down(state_dict, lora_prefix, alpha)
return self.convert_state_dict_AB(state_dict, lora_prefix, alpha)
def convert_state_name(self, state_dict, lora_prefix="lora_unet_", alpha=1.0):
for key in state_dict:
if ".lora_up" in key:
return self.convert_state_name_up_down(state_dict, lora_prefix, alpha)
return self.convert_state_name_AB(state_dict, lora_prefix, alpha)
def convert_state_dict_up_down(self, state_dict, lora_prefix="lora_unet_", alpha=1.0):
renamed_lora_prefix = self.renamed_lora_prefix.get(lora_prefix, "")
@@ -56,37 +50,13 @@ class LoRAFromCivitai:
return state_dict_
def convert_state_name_up_down(self, state_dict, lora_prefix="lora_unet_", alpha=1.0):
renamed_lora_prefix = self.renamed_lora_prefix.get(lora_prefix, "")
state_dict_ = {}
for key in state_dict:
if ".lora_up" not in key:
continue
if not key.startswith(lora_prefix):
continue
weight_up = state_dict[key].to(device="cuda", dtype=torch.float16)
weight_down = state_dict[key.replace(".lora_up", ".lora_down")].to(device="cuda", dtype=torch.float16)
if len(weight_up.shape) == 4:
weight_up = weight_up.squeeze(3).squeeze(2).to(torch.float32)
weight_down = weight_down.squeeze(3).squeeze(2).to(torch.float32)
target_name = key.split(".")[0].replace(lora_prefix, renamed_lora_prefix).replace("_", ".") + ".weight"
for special_key in self.special_keys:
target_name = target_name.replace(special_key, self.special_keys[special_key])
state_dict_[target_name.replace(".weight",".lora_B.weight")] = weight_up.cpu()
state_dict_[target_name.replace(".weight",".lora_A.weight")] = weight_down.cpu()
return state_dict_
def convert_state_dict_AB(self, state_dict, lora_prefix="", alpha=1.0, device="cuda", torch_dtype=torch.float16):
state_dict_ = {}
for key in state_dict:
if ".lora_B." not in key:
continue
if not key.startswith(lora_prefix):
continue
weight_up = state_dict[key].to(device=device, dtype=torch_dtype)
weight_down = state_dict[key.replace(".lora_B.", ".lora_A.")].to(device=device, dtype=torch_dtype)
if len(weight_up.shape) == 4:
@@ -97,39 +67,11 @@ class LoRAFromCivitai:
lora_weight = alpha * torch.mm(weight_up, weight_down)
keys = key.split(".")
keys.pop(keys.index("lora_B"))
target_name = ".".join(keys)
target_name = target_name[len(lora_prefix):]
state_dict_[target_name] = lora_weight.cpu()
return state_dict_
def convert_state_name_AB(self, state_dict, lora_prefix="", alpha=1.0, device="cuda", torch_dtype=torch.float16):
state_dict_ = {}
for key in state_dict:
if ".lora_B." not in key:
continue
if not key.startswith(lora_prefix):
continue
weight_up = state_dict[key].to(device=device, dtype=torch_dtype)
weight_down = state_dict[key.replace(".lora_B.", ".lora_A.")].to(device=device, dtype=torch_dtype)
if len(weight_up.shape) == 4:
weight_up = weight_up.squeeze(3).squeeze(2)
weight_down = weight_down.squeeze(3).squeeze(2)
keys = key.split(".")
keys.pop(keys.index("lora_B"))
target_name = ".".join(keys)
target_name = target_name[len(lora_prefix):]
state_dict_[target_name.replace(".weight",".lora_B.weight")] = weight_up.cpu()
state_dict_[target_name.replace(".weight",".lora_A.weight")] = weight_down.cpu()
return state_dict_
def load(self, model, state_dict_lora, lora_prefix, alpha=1.0, model_resource=None):
state_dict_model = model.state_dict()
@@ -158,16 +100,13 @@ class LoRAFromCivitai:
for lora_prefix, model_class in zip(self.lora_prefix, self.supported_model_classes):
if not isinstance(model, model_class):
continue
# print(f'lora_prefix: {lora_prefix}')
state_dict_model = model.state_dict()
for model_resource in ["diffusers", "civitai"]:
try:
state_dict_lora_ = self.convert_state_dict(state_dict_lora, lora_prefix=lora_prefix, alpha=1.0)
# print(f'after convert_state_dict lora state_dict:{state_dict_lora_.keys()}')
converter_fn = model.__class__.state_dict_converter().from_diffusers if model_resource == "diffusers" \
else model.__class__.state_dict_converter().from_civitai
state_dict_lora_ = converter_fn(state_dict_lora_)
# print(f'after converter_fn lora state_dict:{state_dict_lora_.keys()}')
if isinstance(state_dict_lora_, tuple):
state_dict_lora_ = state_dict_lora_[0]
if len(state_dict_lora_) == 0:
@@ -181,35 +120,7 @@ class LoRAFromCivitai:
pass
return None
def get_converted_lora_state_dict(self, model, state_dict_lora):
for lora_prefix, model_class in zip(self.lora_prefix, self.supported_model_classes):
if not isinstance(model, model_class):
continue
state_dict_model = model.state_dict()
for model_resource in ["diffusers","civitai"]:
try:
state_dict_lora_ = self.convert_state_name(state_dict_lora, lora_prefix=lora_prefix, alpha=1.0)
converter_fn = model.__class__.state_dict_converter().from_diffusers if model_resource == 'diffusers' \
else model.__class__.state_dict_converter().from_civitai
state_dict_lora_ = converter_fn(state_dict_lora_)
if isinstance(state_dict_lora_, tuple):
state_dict_lora_ = state_dict_lora_[0]
if len(state_dict_lora_) == 0:
continue
# return state_dict_lora_
for name in state_dict_lora_:
if name.replace('.lora_B','').replace('.lora_A','') not in state_dict_model:
print(f" lora's {name} is not in model.")
break
else:
return state_dict_lora_
except Exception as e:
print(f"error {str(e)}")
return None
class SDLoRAFromCivitai(LoRAFromCivitai):
def __init__(self):
@@ -284,85 +195,73 @@ class FluxLoRAFromCivitai(LoRAFromCivitai):
"txt.mod": "txt_mod",
}
class GeneralLoRAFromPeft:
def __init__(self):
self.supported_model_classes = [SDUNet, SDXLUNet, SD3DiT, HunyuanDiT, FluxDiT, CogDiT, WanModel]
def fetch_device_dtype_from_state_dict(self, state_dict):
device, torch_dtype = None, None
for name, param in state_dict.items():
device, torch_dtype = param.device, param.dtype
break
return device, torch_dtype
def convert_state_dict(self, state_dict, alpha=1.0, target_state_dict={}):
device, torch_dtype = self.fetch_device_dtype_from_state_dict(target_state_dict)
if torch_dtype == torch.float8_e4m3fn:
torch_dtype = torch.float32
state_dict_ = {}
for key in state_dict:
def get_name_dict(self, lora_state_dict):
lora_name_dict = {}
for key in lora_state_dict:
if ".lora_B." not in key:
continue
weight_up = state_dict[key].to(device=device, dtype=torch_dtype)
weight_down = state_dict[key.replace(".lora_B.", ".lora_A.")].to(device=device, dtype=torch_dtype)
if len(weight_up.shape) == 4:
weight_up = weight_up.squeeze(3).squeeze(2)
weight_down = weight_down.squeeze(3).squeeze(2)
lora_weight = alpha * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3)
else:
lora_weight = alpha * torch.mm(weight_up, weight_down)
keys = key.split(".")
if len(keys) > keys.index("lora_B") + 2:
keys.pop(keys.index("lora_B") + 1)
keys.pop(keys.index("lora_B"))
if keys[0] == "diffusion_model":
keys.pop(0)
target_name = ".".join(keys)
if target_name.startswith("diffusion_model."):
target_name = target_name[len("diffusion_model."):]
if target_name not in target_state_dict:
return {}
state_dict_[target_name] = lora_weight.cpu()
return state_dict_
lora_name_dict[target_name] = (key, key.replace(".lora_B.", ".lora_A."))
return lora_name_dict
def match(self, model: torch.nn.Module, state_dict_lora):
lora_name_dict = self.get_name_dict(state_dict_lora)
model_name_dict = {name: None for name, _ in model.named_parameters()}
matched_num = sum([i in model_name_dict for i in lora_name_dict])
if matched_num == len(lora_name_dict):
return "", ""
else:
return None
def fetch_device_and_dtype(self, state_dict):
device, dtype = None, None
for name, param in state_dict.items():
device, dtype = param.device, param.dtype
break
computation_device = device
computation_dtype = dtype
if computation_device == torch.device("cpu"):
if torch.cuda.is_available():
computation_device = torch.device("cuda")
if computation_dtype == torch.float8_e4m3fn:
computation_dtype = torch.float32
return device, dtype, computation_device, computation_dtype
def load(self, model, state_dict_lora, lora_prefix="", alpha=1.0, model_resource=""):
state_dict_model = model.state_dict()
state_dict_lora = self.convert_state_dict(state_dict_lora, alpha=alpha, target_state_dict=state_dict_model)
if len(state_dict_lora) > 0:
print(f" {len(state_dict_lora)} tensors are updated.")
for name in state_dict_lora:
if state_dict_model[name].dtype == torch.float8_e4m3fn:
weight = state_dict_model[name].to(torch.float32)
lora_weight = state_dict_lora[name].to(
dtype=torch.float32,
device=state_dict_model[name].device
)
state_dict_model[name] = (weight + lora_weight).to(
dtype=state_dict_model[name].dtype,
device=state_dict_model[name].device
)
else:
state_dict_model[name] += state_dict_lora[name].to(
dtype=state_dict_model[name].dtype,
device=state_dict_model[name].device
)
model.load_state_dict(state_dict_model)
device, dtype, computation_device, computation_dtype = self.fetch_device_and_dtype(state_dict_model)
lora_name_dict = self.get_name_dict(state_dict_lora)
for name in lora_name_dict:
weight_up = state_dict_lora[lora_name_dict[name][0]].to(device=computation_device, dtype=computation_dtype)
weight_down = state_dict_lora[lora_name_dict[name][1]].to(device=computation_device, dtype=computation_dtype)
if len(weight_up.shape) == 4:
weight_up = weight_up.squeeze(3).squeeze(2)
weight_down = weight_down.squeeze(3).squeeze(2)
weight_lora = alpha * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3)
else:
weight_lora = alpha * torch.mm(weight_up, weight_down)
weight_model = state_dict_model[name].to(device=computation_device, dtype=computation_dtype)
weight_patched = weight_model + weight_lora
state_dict_model[name] = weight_patched.to(device=device, dtype=dtype)
print(f" {len(lora_name_dict)} tensors are updated.")
model.load_state_dict(state_dict_model)
def match(self, model, state_dict_lora):
for model_class in self.supported_model_classes:
if not isinstance(model, model_class):
continue
state_dict_model = model.state_dict()
try:
state_dict_lora_ = self.convert_state_dict(state_dict_lora, alpha=1.0, target_state_dict=state_dict_model)
if len(state_dict_lora_) > 0:
return "", ""
except:
pass
return None
class HunyuanVideoLoRAFromCivitai(LoRAFromCivitai):
@@ -466,7 +365,22 @@ class FluxLoRAConverter:
else:
state_dict_[name] = param
return state_dict_
class WanLoRAConverter:
def __init__(self):
pass
@staticmethod
def align_to_opensource_format(state_dict, **kwargs):
state_dict = {"diffusion_model." + name.replace(".default.", "."): param for name, param in state_dict.items()}
return state_dict
@staticmethod
def align_to_diffsynth_format(state_dict, **kwargs):
state_dict = {name.replace("diffusion_model.", "").replace(".lora_A.weight", ".lora_A.default.weight").replace(".lora_B.weight", ".lora_B.default.weight"): param for name, param in state_dict.items()}
return state_dict
def get_lora_loaders():
return [SDLoRAFromCivitai(), SDXLLoRAFromCivitai(), FluxLoRAFromCivitai(), HunyuanVideoLoRAFromCivitai(), GeneralLoRAFromPeft()]

View File

@@ -62,26 +62,25 @@ def load_state_dict_from_folder(file_path, torch_dtype=None):
return state_dict
def load_state_dict(file_path, torch_dtype=None, device="cpu"):
def load_state_dict(file_path, torch_dtype=None):
if file_path.endswith(".safetensors"):
return load_state_dict_from_safetensors(file_path, torch_dtype=torch_dtype, device=device)
return load_state_dict_from_safetensors(file_path, torch_dtype=torch_dtype)
else:
return load_state_dict_from_bin(file_path, torch_dtype=torch_dtype, device=device)
return load_state_dict_from_bin(file_path, torch_dtype=torch_dtype)
def load_state_dict_from_safetensors(file_path, torch_dtype=None, device="cpu"):
def load_state_dict_from_safetensors(file_path, torch_dtype=None):
state_dict = {}
with safe_open(file_path, framework="pt", device="cpu") as f:
for k in f.keys():
state_dict[k] = f.get_tensor(k)
if torch_dtype is not None:
state_dict[k] = state_dict[k].to(torch_dtype)
state_dict[k] = state_dict[k].to(device)
return state_dict
def load_state_dict_from_bin(file_path, torch_dtype=None, device="cpu"):
state_dict = torch.load(file_path, map_location=device, weights_only=True)
def load_state_dict_from_bin(file_path, torch_dtype=None):
state_dict = torch.load(file_path, map_location="cpu", weights_only=True)
if torch_dtype is not None:
for i in state_dict:
if isinstance(state_dict[i], torch.Tensor):

View File

@@ -183,6 +183,13 @@ class CrossAttention(nn.Module):
return self.o(x)
class GateModule(nn.Module):
def __init__(self,):
super().__init__()
def forward(self, x, gate, residual):
return x + gate * residual
class DiTBlock(nn.Module):
def __init__(self, has_image_input: bool, dim: int, num_heads: int, ffn_dim: int, eps: float = 1e-6):
super().__init__()
@@ -199,16 +206,17 @@ class DiTBlock(nn.Module):
self.ffn = nn.Sequential(nn.Linear(dim, ffn_dim), nn.GELU(
approximate='tanh'), nn.Linear(ffn_dim, dim))
self.modulation = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5)
self.gate = GateModule()
def forward(self, x, context, t_mod, freqs):
# msa: multi-head self-attention mlp: multi-layer perceptron
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
self.modulation.to(dtype=t_mod.dtype, device=t_mod.device) + t_mod).chunk(6, dim=1)
input_x = modulate(self.norm1(x), shift_msa, scale_msa)
x = x + gate_msa * self.self_attn(input_x, freqs)
x = self.gate(x, gate_msa, self.self_attn(input_x, freqs))
x = x + self.cross_attn(self.norm3(x), context)
input_x = modulate(self.norm2(x), shift_mlp, scale_mlp)
x = x + gate_mlp * self.ffn(input_x)
x = self.gate(x, gate_mlp, self.ffn(input_x))
return x
@@ -485,6 +493,62 @@ class WanModelStateDictConverter:
"num_layers": 40,
"eps": 1e-6
}
elif hash_state_dict_keys(state_dict) == "6d6ccde6845b95ad9114ab993d917893":
config = {
"has_image_input": True,
"patch_size": [1, 2, 2],
"in_dim": 36,
"dim": 1536,
"ffn_dim": 8960,
"freq_dim": 256,
"text_dim": 4096,
"out_dim": 16,
"num_heads": 12,
"num_layers": 30,
"eps": 1e-6
}
elif hash_state_dict_keys(state_dict) == "6bfcfb3b342cb286ce886889d519a77e":
config = {
"has_image_input": True,
"patch_size": [1, 2, 2],
"in_dim": 36,
"dim": 5120,
"ffn_dim": 13824,
"freq_dim": 256,
"text_dim": 4096,
"out_dim": 16,
"num_heads": 40,
"num_layers": 40,
"eps": 1e-6
}
elif hash_state_dict_keys(state_dict) == "349723183fc063b2bfc10bb2835cf677":
config = {
"has_image_input": True,
"patch_size": [1, 2, 2],
"in_dim": 48,
"dim": 1536,
"ffn_dim": 8960,
"freq_dim": 256,
"text_dim": 4096,
"out_dim": 16,
"num_heads": 12,
"num_layers": 30,
"eps": 1e-6
}
elif hash_state_dict_keys(state_dict) == "efa44cddf936c70abd0ea28b6cbe946c":
config = {
"has_image_input": True,
"patch_size": [1, 2, 2],
"in_dim": 48,
"dim": 5120,
"ffn_dim": 13824,
"freq_dim": 256,
"text_dim": 4096,
"out_dim": 16,
"num_heads": 40,
"num_layers": 40,
"eps": 1e-6
}
else:
config = {}
return state_dict, config

View File

@@ -0,0 +1,44 @@
import torch
import torch.nn as nn
from .wan_video_dit import sinusoidal_embedding_1d
class WanMotionControllerModel(torch.nn.Module):
def __init__(self, freq_dim=256, dim=1536):
super().__init__()
self.freq_dim = freq_dim
self.linear = nn.Sequential(
nn.Linear(freq_dim, dim),
nn.SiLU(),
nn.Linear(dim, dim),
nn.SiLU(),
nn.Linear(dim, dim * 6),
)
def forward(self, motion_bucket_id):
emb = sinusoidal_embedding_1d(self.freq_dim, motion_bucket_id * 10)
emb = self.linear(emb)
return emb
def init(self):
state_dict = self.linear[-1].state_dict()
state_dict = {i: state_dict[i] * 0 for i in state_dict}
self.linear[-1].load_state_dict(state_dict)
@staticmethod
def state_dict_converter():
return WanMotionControllerModelDictConverter()
class WanMotionControllerModelDictConverter:
def __init__(self):
pass
def from_diffusers(self, state_dict):
return state_dict
def from_civitai(self, state_dict):
return state_dict

View File

@@ -13,7 +13,7 @@ from transformers import SiglipVisionModel
from copy import deepcopy
from transformers.models.t5.modeling_t5 import T5LayerNorm, T5DenseActDense, T5DenseGatedActDense
from ..models.flux_dit import RMSNorm
from ..vram_management import enable_vram_management, enable_auto_lora, AutoLoRALinear, AutoWrappedModule, AutoWrappedLinear
from ..vram_management import enable_vram_management, AutoWrappedModule, AutoWrappedLinear
class FluxImagePipeline(BasePipeline):
@@ -31,6 +31,7 @@ class FluxImagePipeline(BasePipeline):
self.controlnet: FluxMultiControlNetManager = None
self.ipadapter: FluxIpAdapter = None
self.ipadapter_image_encoder: SiglipVisionModel = None
self.infinityou_processor: InfinitYou = None
self.model_names = ['text_encoder_1', 'text_encoder_2', 'dit', 'vae_decoder', 'vae_encoder', 'controlnet', 'ipadapter', 'ipadapter_image_encoder']
@@ -132,15 +133,6 @@ class FluxImagePipeline(BasePipeline):
)
self.enable_cpu_offload()
def enable_auto_lora(self):
enable_auto_lora(
self.dit,
module_map={
RMSNorm: AutoWrappedModule,
torch.nn.Linear: AutoLoRALinear,
},
name_prefix=''
)
def denoising_model(self):
return self.dit
@@ -171,6 +163,11 @@ class FluxImagePipeline(BasePipeline):
self.ipadapter = model_manager.fetch_model("flux_ipadapter")
self.ipadapter_image_encoder = model_manager.fetch_model("siglip_vision_model")
# InfiniteYou
self.image_proj_model = model_manager.fetch_model("infiniteyou_image_projector")
if self.image_proj_model is not None:
self.infinityou_processor = InfinitYou(device=self.device)
@staticmethod
def from_model_manager(model_manager: ModelManager, controlnet_config_units: List[ControlNetConfigUnit]=[], prompt_refiner_classes=[], prompt_extender_classes=[], device=None, torch_dtype=None):
@@ -356,6 +353,13 @@ class FluxImagePipeline(BasePipeline):
prompt_emb_nega = self.encode_prompt(negative_prompt, positive=False, t5_sequence_length=t5_sequence_length) if cfg_scale != 1.0 else None
prompt_emb_locals = [self.encode_prompt(prompt_local, t5_sequence_length=t5_sequence_length) for prompt_local in local_prompts]
return prompt_emb_posi, prompt_emb_nega, prompt_emb_locals
def prepare_infinite_you(self, id_image, controlnet_image, infinityou_guidance, height, width):
if self.infinityou_processor is not None and id_image is not None:
return self.infinityou_processor.prepare_infinite_you(self.image_proj_model, id_image, controlnet_image, infinityou_guidance, height, width)
else:
return {}, controlnet_image
@torch.no_grad()
@@ -391,6 +395,9 @@ class FluxImagePipeline(BasePipeline):
eligen_entity_masks=None,
enable_eligen_on_negative=False,
enable_eligen_inpaint=False,
# InfiniteYou
infinityou_id_image=None,
infinityou_guidance=1.0,
# TeaCache
tea_cache_l1_thresh=None,
# Tile
@@ -400,9 +407,6 @@ class FluxImagePipeline(BasePipeline):
# Progress bar
progress_bar_cmd=tqdm,
progress_bar_st=None,
lora_state_dicts=[],
lora_alphas=[],
lora_patcher=None,
):
height, width = self.check_resize_height_width(height, width)
@@ -421,6 +425,9 @@ class FluxImagePipeline(BasePipeline):
# Extra input
extra_input = self.prepare_extra_input(latents, guidance=embedded_guidance)
# InfiniteYou
infiniteyou_kwargs, controlnet_image = self.prepare_infinite_you(infinityou_id_image, controlnet_image, infinityou_guidance, height, width)
# Entity control
eligen_kwargs_posi, eligen_kwargs_nega, fg_mask, bg_mask = self.prepare_eligen(prompt_emb_nega, eligen_entity_prompts, eligen_entity_masks, width, height, t5_sequence_length, enable_eligen_inpaint, enable_eligen_on_negative, cfg_scale)
@@ -442,10 +449,7 @@ class FluxImagePipeline(BasePipeline):
inference_callback = lambda prompt_emb_posi, controlnet_kwargs: lets_dance_flux(
dit=self.dit, controlnet=self.controlnet,
hidden_states=latents, timestep=timestep,
lora_state_dicts=lora_state_dicts,
lora_alphas = lora_alphas,
lora_patcher=lora_patcher,
**prompt_emb_posi, **tiler_kwargs, **extra_input, **controlnet_kwargs, **ipadapter_kwargs_list_posi, **eligen_kwargs_posi, **tea_cache_kwargs,
**prompt_emb_posi, **tiler_kwargs, **extra_input, **controlnet_kwargs, **ipadapter_kwargs_list_posi, **eligen_kwargs_posi, **tea_cache_kwargs, **infiniteyou_kwargs
)
noise_pred_posi = self.control_noise_via_local_prompts(
prompt_emb_posi, prompt_emb_locals, masks, mask_scales, inference_callback,
@@ -462,10 +466,7 @@ class FluxImagePipeline(BasePipeline):
noise_pred_nega = lets_dance_flux(
dit=self.dit, controlnet=self.controlnet,
hidden_states=latents, timestep=timestep,
lora_state_dicts=lora_state_dicts,
lora_alphas = lora_alphas,
lora_patcher=lora_patcher,
**prompt_emb_nega, **tiler_kwargs, **extra_input, **controlnet_kwargs_nega, **ipadapter_kwargs_list_nega, **eligen_kwargs_nega,
**prompt_emb_nega, **tiler_kwargs, **extra_input, **controlnet_kwargs_nega, **ipadapter_kwargs_list_nega, **eligen_kwargs_nega, **infiniteyou_kwargs,
)
noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
else:
@@ -485,6 +486,58 @@ class FluxImagePipeline(BasePipeline):
# Offload all models
self.load_models_to_device([])
return image
class InfinitYou:
def __init__(self, device="cuda", torch_dtype=torch.bfloat16):
from facexlib.recognition import init_recognition_model
from insightface.app import FaceAnalysis
self.device = device
self.torch_dtype = torch_dtype
insightface_root_path = 'models/InfiniteYou/insightface'
self.app_640 = FaceAnalysis(name='antelopev2', root=insightface_root_path, providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
self.app_640.prepare(ctx_id=0, det_size=(640, 640))
self.app_320 = FaceAnalysis(name='antelopev2', root=insightface_root_path, providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
self.app_320.prepare(ctx_id=0, det_size=(320, 320))
self.app_160 = FaceAnalysis(name='antelopev2', root=insightface_root_path, providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
self.app_160.prepare(ctx_id=0, det_size=(160, 160))
self.arcface_model = init_recognition_model('arcface', device=self.device)
def _detect_face(self, id_image_cv2):
face_info = self.app_640.get(id_image_cv2)
if len(face_info) > 0:
return face_info
face_info = self.app_320.get(id_image_cv2)
if len(face_info) > 0:
return face_info
face_info = self.app_160.get(id_image_cv2)
return face_info
def extract_arcface_bgr_embedding(self, in_image, landmark):
from insightface.utils import face_align
arc_face_image = face_align.norm_crop(in_image, landmark=np.array(landmark), image_size=112)
arc_face_image = torch.from_numpy(arc_face_image).unsqueeze(0).permute(0, 3, 1, 2) / 255.
arc_face_image = 2 * arc_face_image - 1
arc_face_image = arc_face_image.contiguous().to(self.device)
face_emb = self.arcface_model(arc_face_image)[0] # [512], normalized
return face_emb
def prepare_infinite_you(self, model, id_image, controlnet_image, infinityou_guidance, height, width):
import cv2
if id_image is None:
return {'id_emb': None}, controlnet_image
id_image_cv2 = cv2.cvtColor(np.array(id_image), cv2.COLOR_RGB2BGR)
face_info = self._detect_face(id_image_cv2)
if len(face_info) == 0:
raise ValueError('No face detected in the input ID image')
landmark = sorted(face_info, key=lambda x:(x['bbox'][2]-x['bbox'][0])*(x['bbox'][3]-x['bbox'][1]))[-1]['kps'] # only use the maximum face
id_emb = self.extract_arcface_bgr_embedding(id_image_cv2, landmark)
id_emb = model(id_emb.unsqueeze(0).reshape([1, -1, 512]).to(dtype=self.torch_dtype))
if controlnet_image is None:
controlnet_image = Image.fromarray(np.zeros([height, width, 3]).astype(np.uint8))
infinityou_guidance = torch.Tensor([infinityou_guidance]).to(device=self.device, dtype=self.torch_dtype)
return {'id_emb': id_emb, 'infinityou_guidance': infinityou_guidance}, controlnet_image
class TeaCache:
@@ -529,6 +582,7 @@ class TeaCache:
hidden_states = hidden_states + self.previous_residual
return hidden_states
def lets_dance_flux(
dit: FluxDiT,
controlnet: FluxMultiControlNetManager = None,
@@ -546,11 +600,11 @@ def lets_dance_flux(
entity_prompt_emb=None,
entity_masks=None,
ipadapter_kwargs_list={},
id_emb=None,
infinityou_guidance=None,
tea_cache: TeaCache = None,
use_gradient_checkpointing=False,
**kwargs
):
if tiled:
def flux_forward_fn(hl, hr, wl, wr):
tiled_controlnet_frames = [f[:, :, hl: hr, wl: wr] for f in controlnet_frames] if controlnet_frames is not None else None
@@ -592,6 +646,9 @@ def lets_dance_flux(
"tile_size": tile_size,
"tile_stride": tile_stride,
}
if id_emb is not None:
controlnet_text_ids = torch.zeros(id_emb.shape[0], id_emb.shape[1], 3).to(device=hidden_states.device, dtype=hidden_states.dtype)
controlnet_extra_kwargs.update({"prompt_emb": id_emb, 'text_ids': controlnet_text_ids, 'guidance': infinityou_guidance})
controlnet_res_stack, controlnet_single_res_stack = controlnet(
controlnet_frames, **controlnet_extra_kwargs
)
@@ -614,11 +671,6 @@ def lets_dance_flux(
prompt_emb = dit.context_embedder(prompt_emb)
image_rotary_emb = dit.pos_embedder(torch.cat((text_ids, image_ids), dim=1))
attention_mask = None
def create_custom_forward(module):
def custom_forward(*inputs, **kwargs):
return module(*inputs, **kwargs)
return custom_forward
# TeaCache
if tea_cache is not None:
@@ -631,22 +683,14 @@ def lets_dance_flux(
else:
# Joint Blocks
for block_id, block in enumerate(dit.blocks):
if use_gradient_checkpointing:
hidden_states, prompt_emb = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
hidden_states, prompt_emb, conditioning, image_rotary_emb, attention_mask, ipadapter_kwargs_list.get(block_id, None), **kwargs,
use_reentrant=False,
)
else:
hidden_states, prompt_emb = block(
hidden_states,
prompt_emb,
conditioning,
image_rotary_emb,
attention_mask,
ipadapter_kwargs_list=ipadapter_kwargs_list.get(block_id, None),
**kwargs
)
hidden_states, prompt_emb = block(
hidden_states,
prompt_emb,
conditioning,
image_rotary_emb,
attention_mask,
ipadapter_kwargs_list=ipadapter_kwargs_list.get(block_id, None)
)
# ControlNet
if controlnet is not None and controlnet_frames is not None:
hidden_states = hidden_states + controlnet_res_stack[block_id]
@@ -655,22 +699,14 @@ def lets_dance_flux(
hidden_states = torch.cat([prompt_emb, hidden_states], dim=1)
num_joint_blocks = len(dit.blocks)
for block_id, block in enumerate(dit.single_blocks):
if use_gradient_checkpointing:
hidden_states, prompt_emb = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
hidden_states, prompt_emb, conditioning, image_rotary_emb, attention_mask, ipadapter_kwargs_list.get(block_id + num_joint_blocks, None), **kwargs,
use_reentrant=False,
)
else:
hidden_states, prompt_emb = block(
hidden_states,
prompt_emb,
conditioning,
image_rotary_emb,
attention_mask,
ipadapter_kwargs_list=ipadapter_kwargs_list.get(block_id + num_joint_blocks, None),
**kwargs
)
hidden_states, prompt_emb = block(
hidden_states,
prompt_emb,
conditioning,
image_rotary_emb,
attention_mask,
ipadapter_kwargs_list=ipadapter_kwargs_list.get(block_id + num_joint_blocks, None)
)
# ControlNet
if controlnet is not None and controlnet_frames is not None:
hidden_states[:, prompt_emb.shape[1]:] = hidden_states[:, prompt_emb.shape[1]:] + controlnet_single_res_stack[block_id]
@@ -679,8 +715,8 @@ def lets_dance_flux(
if tea_cache is not None:
tea_cache.store(hidden_states)
hidden_states = dit.final_norm_out(hidden_states, conditioning, **kwargs)
hidden_states = dit.final_proj_out(hidden_states, **kwargs)
hidden_states = dit.final_norm_out(hidden_states, conditioning)
hidden_states = dit.final_proj_out(hidden_states)
hidden_states = dit.unpatchify(hidden_states, height, width)
return hidden_states

View File

@@ -1,3 +1,4 @@
import types
from ..models import ModelManager
from ..models.wan_video_dit import WanModel
from ..models.wan_video_text_encoder import WanTextEncoder
@@ -17,6 +18,7 @@ from ..vram_management import enable_vram_management, AutoWrappedModule, AutoWra
from ..models.wan_video_text_encoder import T5RelativeEmbedding, T5LayerNorm
from ..models.wan_video_dit import RMSNorm, sinusoidal_embedding_1d
from ..models.wan_video_vae import RMS_norm, CausalConv3d, Upsample
from ..models.wan_video_motion_controller import WanMotionControllerModel
@@ -30,9 +32,11 @@ class WanVideoPipeline(BasePipeline):
self.image_encoder: WanImageEncoder = None
self.dit: WanModel = None
self.vae: WanVideoVAE = None
self.model_names = ['text_encoder', 'dit', 'vae']
self.motion_controller: WanMotionControllerModel = None
self.model_names = ['text_encoder', 'dit', 'vae', 'image_encoder', 'motion_controller']
self.height_division_factor = 16
self.width_division_factor = 16
self.use_unified_sequence_parallel = False
def enable_vram_management(self, num_persistent_param_in_dit=None):
@@ -120,6 +124,22 @@ class WanVideoPipeline(BasePipeline):
computation_device=self.device,
),
)
if self.motion_controller is not None:
dtype = next(iter(self.motion_controller.parameters())).dtype
enable_vram_management(
self.motion_controller,
module_map = {
torch.nn.Linear: AutoWrappedLinear,
},
module_config = dict(
offload_dtype=dtype,
offload_device="cpu",
onload_dtype=dtype,
onload_device="cpu",
computation_dtype=dtype,
computation_device=self.device,
),
)
self.enable_cpu_offload()
@@ -132,14 +152,24 @@ class WanVideoPipeline(BasePipeline):
self.dit = model_manager.fetch_model("wan_video_dit")
self.vae = model_manager.fetch_model("wan_video_vae")
self.image_encoder = model_manager.fetch_model("wan_video_image_encoder")
self.motion_controller = model_manager.fetch_model("wan_video_motion_controller")
@staticmethod
def from_model_manager(model_manager: ModelManager, torch_dtype=None, device=None):
def from_model_manager(model_manager: ModelManager, torch_dtype=None, device=None, use_usp=False):
if device is None: device = model_manager.device
if torch_dtype is None: torch_dtype = model_manager.torch_dtype
pipe = WanVideoPipeline(device=device, torch_dtype=torch_dtype)
pipe.fetch_models(model_manager)
if use_usp:
from xfuser.core.distributed import get_sequence_parallel_world_size
from ..distributed.xdit_context_parallel import usp_attn_forward, usp_dit_forward
for block in pipe.dit.blocks:
block.self_attn.forward = types.MethodType(usp_attn_forward, block.self_attn)
pipe.dit.forward = types.MethodType(usp_dit_forward, pipe.dit)
pipe.sp_size = get_sequence_parallel_world_size()
pipe.use_unified_sequence_parallel = True
return pipe
@@ -148,26 +178,51 @@ class WanVideoPipeline(BasePipeline):
def encode_prompt(self, prompt, positive=True):
prompt_emb = self.prompter.encode_prompt(prompt, positive=positive)
prompt_emb = self.prompter.encode_prompt(prompt, positive=positive, device=self.device)
return {"context": prompt_emb}
def encode_image(self, image, num_frames, height, width):
def encode_image(self, image, end_image, num_frames, height, width):
image = self.preprocess_image(image.resize((width, height))).to(self.device)
clip_context = self.image_encoder.encode_image([image])
msk = torch.ones(1, num_frames, height//8, width//8, device=self.device)
msk[:, 1:] = 0
if end_image is not None:
end_image = self.preprocess_image(end_image.resize((width, height))).to(self.device)
vae_input = torch.concat([image.transpose(0,1), torch.zeros(3, num_frames-2, height, width).to(image.device), end_image.transpose(0,1)],dim=1)
msk[:, -1:] = 1
else:
vae_input = torch.concat([image.transpose(0, 1), torch.zeros(3, num_frames-1, height, width).to(image.device)], dim=1)
msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1)
msk = msk.view(1, msk.shape[1] // 4, 4, height//8, width//8)
msk = msk.transpose(1, 2)[0]
vae_input = torch.concat([image.transpose(0, 1), torch.zeros(3, num_frames-1, height, width).to(image.device)], dim=1)
y = self.vae.encode([vae_input.to(dtype=self.torch_dtype, device=self.device)], device=self.device)[0]
y = torch.concat([msk, y])
y = y.unsqueeze(0)
clip_context = clip_context.to(dtype=self.torch_dtype, device=self.device)
y = y.to(dtype=self.torch_dtype, device=self.device)
return {"clip_feature": clip_context, "y": y}
def encode_control_video(self, control_video, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)):
control_video = self.preprocess_images(control_video)
control_video = torch.stack(control_video, dim=2).to(dtype=self.torch_dtype, device=self.device)
latents = self.encode_video(control_video, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=self.torch_dtype, device=self.device)
return latents
def prepare_controlnet_kwargs(self, control_video, num_frames, height, width, clip_feature=None, y=None, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)):
if control_video is not None:
control_latents = self.encode_control_video(control_video, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
if clip_feature is None or y is None:
clip_feature = torch.zeros((1, 257, 1280), dtype=self.torch_dtype, device=self.device)
y = torch.zeros((1, 16, (num_frames - 1) // 4 + 1, height//8, width//8), dtype=self.torch_dtype, device=self.device)
else:
y = y[:, -16:]
y = torch.concat([control_latents, y], dim=1)
return {"clip_feature": clip_feature, "y": y}
def tensor2video(self, frames):
@@ -189,6 +244,15 @@ class WanVideoPipeline(BasePipeline):
def decode_video(self, latents, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)):
frames = self.vae.decode(latents, device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
return frames
def prepare_unified_sequence_parallel(self):
return {"use_unified_sequence_parallel": self.use_unified_sequence_parallel}
def prepare_motion_bucket_id(self, motion_bucket_id):
motion_bucket_id = torch.Tensor((motion_bucket_id,)).to(dtype=self.torch_dtype, device=self.device)
return {"motion_bucket_id": motion_bucket_id}
@torch.no_grad()
@@ -197,7 +261,9 @@ class WanVideoPipeline(BasePipeline):
prompt,
negative_prompt="",
input_image=None,
end_image=None,
input_video=None,
control_video=None,
denoising_strength=1.0,
seed=None,
rand_device="cpu",
@@ -207,6 +273,7 @@ class WanVideoPipeline(BasePipeline):
cfg_scale=5.0,
num_inference_steps=50,
sigma_shift=5.0,
motion_bucket_id=None,
tiled=True,
tile_size=(30, 52),
tile_stride=(15, 26),
@@ -248,26 +315,50 @@ class WanVideoPipeline(BasePipeline):
# Encode image
if input_image is not None and self.image_encoder is not None:
self.load_models_to_device(["image_encoder", "vae"])
image_emb = self.encode_image(input_image, num_frames, height, width)
image_emb = self.encode_image(input_image, end_image, num_frames, height, width)
else:
image_emb = {}
# ControlNet
if control_video is not None:
self.load_models_to_device(["image_encoder", "vae"])
image_emb = self.prepare_controlnet_kwargs(control_video, num_frames, height, width, **image_emb, **tiler_kwargs)
# Motion Controller
if self.motion_controller is not None and motion_bucket_id is not None:
motion_kwargs = self.prepare_motion_bucket_id(motion_bucket_id)
else:
motion_kwargs = {}
# Extra input
extra_input = self.prepare_extra_input(latents)
# TeaCache
tea_cache_posi = {"tea_cache": TeaCache(num_inference_steps, rel_l1_thresh=tea_cache_l1_thresh, model_id=tea_cache_model_id) if tea_cache_l1_thresh is not None else None}
tea_cache_nega = {"tea_cache": TeaCache(num_inference_steps, rel_l1_thresh=tea_cache_l1_thresh, model_id=tea_cache_model_id) if tea_cache_l1_thresh is not None else None}
# Unified Sequence Parallel
usp_kwargs = self.prepare_unified_sequence_parallel()
# Denoise
self.load_models_to_device(["dit"])
self.load_models_to_device(["dit", "motion_controller"])
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device)
# Inference
noise_pred_posi = model_fn_wan_video(self.dit, latents, timestep=timestep, **prompt_emb_posi, **image_emb, **extra_input, **tea_cache_posi)
noise_pred_posi = model_fn_wan_video(
self.dit, motion_controller=self.motion_controller,
x=latents, timestep=timestep,
**prompt_emb_posi, **image_emb, **extra_input,
**tea_cache_posi, **usp_kwargs, **motion_kwargs
)
if cfg_scale != 1.0:
noise_pred_nega = model_fn_wan_video(self.dit, latents, timestep=timestep, **prompt_emb_nega, **image_emb, **extra_input, **tea_cache_nega)
noise_pred_nega = model_fn_wan_video(
self.dit, motion_controller=self.motion_controller,
x=latents, timestep=timestep,
**prompt_emb_nega, **image_emb, **extra_input,
**tea_cache_nega, **usp_kwargs, **motion_kwargs
)
noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
else:
noise_pred = noise_pred_posi
@@ -340,16 +431,27 @@ class TeaCache:
def model_fn_wan_video(
dit: WanModel,
x: torch.Tensor,
timestep: torch.Tensor,
context: torch.Tensor,
motion_controller: WanMotionControllerModel = None,
x: torch.Tensor = None,
timestep: torch.Tensor = None,
context: torch.Tensor = None,
clip_feature: Optional[torch.Tensor] = None,
y: Optional[torch.Tensor] = None,
tea_cache: TeaCache = None,
use_unified_sequence_parallel: bool = False,
motion_bucket_id: Optional[torch.Tensor] = None,
**kwargs,
):
if use_unified_sequence_parallel:
import torch.distributed as dist
from xfuser.core.distributed import (get_sequence_parallel_rank,
get_sequence_parallel_world_size,
get_sp_group)
t = dit.time_embedding(sinusoidal_embedding_1d(dit.freq_dim, timestep))
t_mod = dit.time_projection(t).unflatten(1, (6, dit.dim))
if motion_bucket_id is not None and motion_controller is not None:
t_mod = t_mod + motion_controller(motion_bucket_id).unflatten(1, (6, dit.dim))
context = dit.text_embedding(context)
if dit.has_image_input:
@@ -371,15 +473,21 @@ def model_fn_wan_video(
else:
tea_cache_update = False
# blocks
if use_unified_sequence_parallel:
if dist.is_initialized() and dist.get_world_size() > 1:
x = torch.chunk(x, get_sequence_parallel_world_size(), dim=1)[get_sequence_parallel_rank()]
if tea_cache_update:
x = tea_cache.update(x)
else:
# blocks
for block in dit.blocks:
x = block(x, context, t_mod, freqs)
if tea_cache is not None:
tea_cache.store(x)
x = dit.head(x, t)
if use_unified_sequence_parallel:
if dist.is_initialized() and dist.get_world_size() > 1:
x = get_sp_group().all_gather(x, dim=1)
x = dit.unpatchify(x, (f, h, w))
return x

View File

@@ -70,56 +70,6 @@ class AutoWrappedLinear(torch.nn.Linear):
bias = None if self.bias is None else cast_to(self.bias, self.computation_dtype, self.computation_device)
return torch.nn.functional.linear(x, weight, bias)
class AutoLoRALinear(torch.nn.Linear):
def __init__(self, name='', in_features=1, out_features=2, bias=True, device=None, dtype=None):
super().__init__(in_features, out_features, bias, device, dtype)
self.name = name
def forward(self, x, lora_state_dicts=[], lora_alphas=[1.0,1.0], lora_patcher=None, **kwargs):
out = torch.nn.functional.linear(x, self.weight, self.bias)
lora_a_name = f'{self.name}.lora_A.default.weight'
lora_b_name = f'{self.name}.lora_B.default.weight'
lora_output = []
for i, lora_state_dict in enumerate(lora_state_dicts):
if lora_state_dict is None:
break
if lora_a_name in lora_state_dict and lora_b_name in lora_state_dict:
lora_A = lora_state_dict[lora_a_name].to(dtype=self.weight.dtype,device=self.weight.device)
lora_B = lora_state_dict[lora_b_name].to(dtype=self.weight.dtype,device=self.weight.device)
out_lora = x @ lora_A.T @ lora_B.T
lora_output.append(out_lora)
if len(lora_output) > 0:
lora_output = torch.stack(lora_output)
out = lora_patcher(out, lora_output, self.name)
return out
def enable_auto_lora(model:torch.nn.Module, module_map: dict, name_prefix=''):
targets = list(module_map.keys())
for name, module in model.named_children():
if name_prefix != '':
full_name = name_prefix + '.' + name
else:
full_name = name
if isinstance(module,targets[1]):
# print(full_name)
# print(module)
# ToDo: replace the linear to the AutoLoRALinear
new_module = AutoLoRALinear(
name=full_name,
in_features=module.in_features,
out_features=module.out_features,
bias=module.bias is not None,
device=module.weight.device,
dtype=module.weight.dtype)
new_module.weight.data.copy_(module.weight.data)
new_module.bias.data.copy_(module.bias.data)
setattr(model, name, new_module)
elif isinstance(module, targets[0]):
pass
else:
enable_auto_lora(module, module_map, full_name)
def enable_vram_management_recursively(model: torch.nn.Module, module_map: dict, module_config: dict, max_num_param=None, overflow_module_config: dict = None, total_num_param=0):
for name, module in model.named_children():

View File

@@ -0,0 +1,7 @@
# InfiniteYou: Flexible Photo Recrafting While Preserving Your Identity
We support the identity preserving feature of InfiniteYou. See [./infiniteyou.py](./infiniteyou.py) for example. The visualization of the result is shown below.
|Identity Image|Generated Image|
|-|-|
|![man_id](https://github.com/user-attachments/assets/bbc38a91-966e-49e8-a0d7-c5467582ad1f)|![man](https://github.com/user-attachments/assets/0decd5e1-5f65-437c-98fa-90991b6f23c1)|
|![woman_id](https://github.com/user-attachments/assets/b2894695-690e-465b-929c-61e5dc57feeb)|![woman](https://github.com/user-attachments/assets/67cc7496-c4d3-4de1-a8f1-9eb4991d95e8)|

View File

@@ -0,0 +1,58 @@
import importlib
import torch
from diffsynth import ModelManager, FluxImagePipeline, download_models, ControlNetConfigUnit
from modelscope import dataset_snapshot_download
from PIL import Image
if importlib.util.find_spec("facexlib") is None:
raise ImportError("You are using InifiniteYou. It depends on facexlib, which is not installed. Please install it with `pip install facexlib`.")
if importlib.util.find_spec("insightface") is None:
raise ImportError("You are using InifiniteYou. It depends on insightface, which is not installed. Please install it with `pip install insightface`.")
download_models(["InfiniteYou"])
model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cuda", model_id_list=["FLUX.1-dev"])
model_manager.load_models([
[
"models/InfiniteYou/InfuseNetModel/diffusion_pytorch_model-00001-of-00002.safetensors",
"models/InfiniteYou/InfuseNetModel/diffusion_pytorch_model-00002-of-00002.safetensors"
],
"models/InfiniteYou/image_proj_model.bin",
])
pipe = FluxImagePipeline.from_model_manager(
model_manager,
controlnet_config_units=[
ControlNetConfigUnit(
processor_id="none",
model_path=[
'models/InfiniteYou/InfuseNetModel/diffusion_pytorch_model-00001-of-00002.safetensors',
'models/InfiniteYou/InfuseNetModel/diffusion_pytorch_model-00002-of-00002.safetensors'
],
scale=1.0
)
]
)
dataset_snapshot_download(dataset_id="DiffSynth-Studio/examples_in_diffsynth", local_dir="./", allow_file_pattern=f"data/examples/infiniteyou/*")
prompt = "A man, portrait, cinematic"
id_image = "data/examples/infiniteyou/man.jpg"
id_image = Image.open(id_image).convert('RGB')
image = pipe(
prompt=prompt, seed=1,
infinityou_id_image=id_image, infinityou_guidance=1.0,
num_inference_steps=50, embedded_guidance=3.5,
height=1024, width=1024,
)
image.save("man.jpg")
prompt = "A woman, portrait, cinematic"
id_image = "data/examples/infiniteyou/woman.jpg"
id_image = Image.open(id_image).convert('RGB')
image = pipe(
prompt=prompt, seed=1,
infinityou_id_image=id_image, infinityou_guidance=1.0,
num_inference_steps=50, embedded_guidance=3.5,
height=1024, width=1024,
)
image.save("woman.jpg")

View File

@@ -10,34 +10,52 @@ cd DiffSynth-Studio
pip install -e .
```
Wan-Video supports multiple Attention implementations. If you have installed any of the following Attention implementations, they will be enabled based on priority.
## Model Zoo
* [Flash Attention 3](https://github.com/Dao-AILab/flash-attention)
* [Flash Attention 2](https://github.com/Dao-AILab/flash-attention)
* [Sage Attention](https://github.com/thu-ml/SageAttention)
* [torch SDPA](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html) (default. `torch>=2.5.0` is recommended.)
|Developer|Name|Link|Scripts|
|-|-|-|-|
|Wan Team|1.3B text-to-video|[Link](https://modelscope.cn/models/Wan-AI/Wan2.1-T2V-1.3B)|[wan_1.3b_text_to_video.py](./wan_1.3b_text_to_video.py)|
|Wan Team|14B text-to-video|[Link](https://modelscope.cn/models/Wan-AI/Wan2.1-T2V-14B)|[wan_14b_text_to_video.py](./wan_14b_text_to_video.py)|
|Wan Team|14B image-to-video 480P|[Link](https://modelscope.cn/models/Wan-AI/Wan2.1-I2V-14B-480P)|[wan_14b_image_to_video.py](./wan_14b_image_to_video.py)|
|Wan Team|14B image-to-video 720P|[Link](https://modelscope.cn/models/Wan-AI/Wan2.1-I2V-14B-720P)|[wan_14b_image_to_video.py](./wan_14b_image_to_video.py)|
|DiffSynth-Studio Team|1.3B aesthetics LoRA|[Link](https://modelscope.cn/models/DiffSynth-Studio/Wan2.1-1.3b-lora-aesthetics-v1)|Please see the [model card](https://modelscope.cn/models/DiffSynth-Studio/Wan2.1-1.3b-lora-aesthetics-v1).|
|DiffSynth-Studio Team|1.3B Highres-fix LoRA|[Link](https://modelscope.cn/models/DiffSynth-Studio/Wan2.1-1.3b-lora-highresfix-v1)|Please see the [model card](https://modelscope.cn/models/DiffSynth-Studio/Wan2.1-1.3b-lora-highresfix-v1).|
|DiffSynth-Studio Team|1.3B ExVideo LoRA|[Link](https://modelscope.cn/models/DiffSynth-Studio/Wan2.1-1.3b-lora-exvideo-v1)|Please see the [model card](https://modelscope.cn/models/DiffSynth-Studio/Wan2.1-1.3b-lora-exvideo-v1).|
|DiffSynth-Studio Team|1.3B Speed Control adapter|[Link](https://modelscope.cn/models/DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1)|[wan_1.3b_motion_controller.py](./wan_1.3b_motion_controller.py)|
|PAI Team|1.3B InP|[Link](https://modelscope.cn/models/PAI/Wan2.1-Fun-1.3B-InP)|[wan_fun_InP.py](./wan_fun_InP.py)|
|PAI Team|14B InP|[Link](https://modelscope.cn/models/PAI/Wan2.1-Fun-14B-InP)|[wan_fun_InP.py](./wan_fun_InP.py)|
|PAI Team|1.3B Control|[Link](https://modelscope.cn/models/PAI/Wan2.1-Fun-1.3B-Control)|[wan_fun_control.py](./wan_fun_control.py)|
|PAI Team|14B Control|[Link](https://modelscope.cn/models/PAI/Wan2.1-Fun-14B-Control)|[wan_fun_control.py](./wan_fun_control.py)|
## Inference
Base model features
### Wan-Video-1.3B-T2V
||Text-to-video|Image-to-video|End frame|Control|
|-|-|-|-|-|
|1.3B text-to-video|✅||||
|14B text-to-video|✅||||
|14B image-to-video 480P||✅|||
|14B image-to-video 720P||✅|||
|1.3B InP||✅|✅||
|14B InP||✅|✅||
|1.3B Control||||✅|
|14B Control||||✅|
Wan-Video-1.3B-T2V supports text-to-video and video-to-video. See [`./wan_1.3b_text_to_video.py`](./wan_1.3b_text_to_video.py).
Adapter model compatibility
Required VRAM: 6G
||1.3B text-to-video|1.3B InP|
|-|-|-|
|1.3B aesthetics LoRA|✅||
|1.3B Highres-fix LoRA|✅||
|1.3B ExVideo LoRA|✅||
|1.3B Speed Control adapter|✅|✅|
https://github.com/user-attachments/assets/124397be-cd6a-4f29-a87c-e4c695aaabb8
## VRAM Usage
Put sunglasses on the dog.
* Fine-grained offload: We recommend that users adjust the `num_persistent_param_in_dit` settings to find an optimal balance between speed and VRAM requirements. See [`./wan_14b_text_to_video.py`](./wan_14b_text_to_video.py).
https://github.com/user-attachments/assets/272808d7-fbeb-4747-a6df-14a0860c75fb
* FP8 Quantization: You only need to adjust the `torch_dtype` in the `ModelManager` (not the pipeline!).
[TeaCache](https://github.com/ali-vilab/TeaCache) is supported in both T2V and I2V models. It can significantly improve the efficiency. See [`./wan_1.3b_text_to_video_accelerate.py`](./wan_1.3b_text_to_video_accelerate.py).
### Wan-Video-14B-T2V
Wan-Video-14B-T2V is an enhanced version of Wan-Video-1.3B-T2V, offering greater size and power. To utilize this model, you need additional VRAM. We recommend that users adjust the `torch_dtype` and `num_persistent_param_in_dit` settings to find an optimal balance between speed and VRAM requirements. See [`./wan_14b_text_to_video.py`](./wan_14b_text_to_video.py).
We present a detailed table here. The model is tested on a single A100.
We present a detailed table here. The model (14B text-to-video) is tested on a single A100.
|`torch_dtype`|`num_persistent_param_in_dit`|Speed|Required VRAM|Default Setting|
|-|-|-|-|-|
@@ -47,17 +65,46 @@ We present a detailed table here. The model is tested on a single A100.
|torch.float8_e4m3fn|None (unlimited)|18.3s/it|24G|yes|
|torch.float8_e4m3fn|0|24.0s/it|10G||
**We found that 14B image-to-video model is more sensitive to precision, so when the generated video content experiences issues such as artifacts, please switch to bfloat16 precision and use the `num_persistent_param_in_dit` parameter to control VRAM usage.**
## Efficient Attention Implementation
DiffSynth-Studio supports multiple Attention implementations. If you have installed any of the following Attention implementations, they will be enabled based on priority. However, we recommend to use the default torch SDPA.
* [Flash Attention 3](https://github.com/Dao-AILab/flash-attention)
* [Flash Attention 2](https://github.com/Dao-AILab/flash-attention)
* [Sage Attention](https://github.com/thu-ml/SageAttention)
* [torch SDPA](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html) (default. `torch>=2.5.0` is recommended.)
## Acceleration
We support multiple acceleration solutions:
* [TeaCache](https://github.com/ali-vilab/TeaCache): See [wan_1.3b_text_to_video_accelerate.py](./wan_1.3b_text_to_video_accelerate.py).
* [Unified Sequence Parallel](https://github.com/xdit-project/xDiT): See [wan_14b_text_to_video_usp.py](./wan_14b_text_to_video_usp.py)
```bash
pip install xfuser>=0.4.3
torchrun --standalone --nproc_per_node=8 examples/wanvideo/wan_14b_text_to_video_usp.py
```
* Tensor Parallel: See [wan_14b_text_to_video_tensor_parallel.py](./wan_14b_text_to_video_tensor_parallel.py).
## Gallery
1.3B text-to-video.
https://github.com/user-attachments/assets/124397be-cd6a-4f29-a87c-e4c695aaabb8
Put sunglasses on the dog.
https://github.com/user-attachments/assets/272808d7-fbeb-4747-a6df-14a0860c75fb
14B text-to-video.
https://github.com/user-attachments/assets/3908bc64-d451-485a-8b61-28f6d32dd92f
Tensor parallel module of Wan-Video-14B-T2V is still under development. An example script is provided in [`./wan_14b_text_to_video_tensor_parallel.py`](./wan_14b_text_to_video_tensor_parallel.py).
### Wan-Video-14B-I2V
Wan-Video-14B-I2V adds the functionality of image-to-video based on Wan-Video-14B-T2V. The model size remains the same, therefore the speed and VRAM requirements are also consistent. See [`./wan_14b_image_to_video.py`](./wan_14b_image_to_video.py).
**In the sample code, we use the same settings as the T2V 14B model, with FP8 quantization enabled by default. However, we found that this model is more sensitive to precision, so when the generated video content experiences issues such as artifacts, please switch to bfloat16 precision and use the `num_persistent_param_in_dit` parameter to control VRAM usage.**
![Image](https://github.com/user-attachments/assets/adf8047f-7943-4aaa-a555-2b32dc415f39)
14B image-to-video.
https://github.com/user-attachments/assets/c0bdd5ca-292f-45ed-b9bc-afe193156e75

View File

@@ -0,0 +1,41 @@
import torch
from diffsynth import ModelManager, WanVideoPipeline, save_video, VideoData
from modelscope import snapshot_download
# Download models
snapshot_download("Wan-AI/Wan2.1-T2V-1.3B", local_dir="models/Wan-AI/Wan2.1-T2V-1.3B")
snapshot_download("DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1", local_dir="models/DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1")
# Load models
model_manager = ModelManager(device="cpu")
model_manager.load_models(
[
"models/Wan-AI/Wan2.1-T2V-1.3B/diffusion_pytorch_model.safetensors",
"models/Wan-AI/Wan2.1-T2V-1.3B/models_t5_umt5-xxl-enc-bf16.pth",
"models/Wan-AI/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth",
"models/DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1/model.safetensors",
],
torch_dtype=torch.bfloat16, # You can set `torch_dtype=torch.float8_e4m3fn` to enable FP8 quantization.
)
pipe = WanVideoPipeline.from_model_manager(model_manager, torch_dtype=torch.bfloat16, device="cuda")
pipe.enable_vram_management(num_persistent_param_in_dit=None)
# Text-to-video
video = pipe(
prompt="纪实摄影风格画面,一只活泼的小狗在绿茵茵的草地上迅速奔跑。小狗毛色棕黄,两只耳朵立起,神情专注而欢快。阳光洒在它身上,使得毛发看上去格外柔软而闪亮。背景是一片开阔的草地,偶尔点缀着几朵野花,远处隐约可见蓝天和几片白云。透视感鲜明,捕捉小狗奔跑时的动感和四周草地的生机。中景侧面移动视角。",
negative_prompt="色调艳丽过曝静态细节模糊不清字幕风格作品画作画面静止整体发灰最差质量低质量JPEG压缩残留丑陋的残缺的多余的手指画得不好的手部画得不好的脸部畸形的毁容的形态畸形的肢体手指融合静止不动的画面杂乱的背景三条腿背景人很多倒着走",
num_inference_steps=50,
seed=1, tiled=True,
motion_bucket_id=0
)
save_video(video, "video_slow.mp4", fps=15, quality=5)
video = pipe(
prompt="纪实摄影风格画面,一只活泼的小狗在绿茵茵的草地上迅速奔跑。小狗毛色棕黄,两只耳朵立起,神情专注而欢快。阳光洒在它身上,使得毛发看上去格外柔软而闪亮。背景是一片开阔的草地,偶尔点缀着几朵野花,远处隐约可见蓝天和几片白云。透视感鲜明,捕捉小狗奔跑时的动感和四周草地的生机。中景侧面移动视角。",
negative_prompt="色调艳丽过曝静态细节模糊不清字幕风格作品画作画面静止整体发灰最差质量低质量JPEG压缩残留丑陋的残缺的多余的手指画得不好的手部画得不好的脸部畸形的毁容的形态畸形的肢体手指融合静止不动的画面杂乱的背景三条腿背景人很多倒着走",
num_inference_steps=50,
seed=1, tiled=True,
motion_bucket_id=100
)
save_video(video, "video_fast.mp4", fps=15, quality=5)

View File

@@ -44,11 +44,28 @@ class LitModel(pl.LightningModule):
def configure_model(self):
tp_mesh = self.device_mesh["tensor_parallel"]
plan = {
"text_embedding.0": ColwiseParallel(),
"text_embedding.2": RowwiseParallel(),
"time_projection.1": ColwiseParallel(output_layouts=Replicate()),
"text_embedding.0": ColwiseParallel(),
"text_embedding.2": RowwiseParallel(),
"blocks.0": PrepareModuleInput(
input_layouts=(Replicate(), None, None, None),
desired_input_layouts=(Replicate(), None, None, None),
),
"head": PrepareModuleInput(
input_layouts=(Replicate(), None),
desired_input_layouts=(Replicate(), None),
use_local_output=True,
)
}
self.pipe.dit = parallelize_module(self.pipe.dit, tp_mesh, plan)
for block_id, block in enumerate(self.pipe.dit.blocks):
layer_tp_plan = {
"self_attn": PrepareModuleInput(
input_layouts=(Replicate(), Replicate()),
desired_input_layouts=(Replicate(), Shard(0)),
input_layouts=(Shard(1), Replicate()),
desired_input_layouts=(Shard(1), Shard(0)),
),
"self_attn.q": SequenceParallel(),
"self_attn.k": SequenceParallel(),
@@ -59,11 +76,11 @@ class LitModel(pl.LightningModule):
input_layouts=(Shard(1), Shard(1), Shard(1)),
desired_input_layouts=(Shard(2), Shard(2), Shard(2)),
),
"self_attn.o": ColwiseParallel(output_layouts=Replicate()),
"self_attn.o": RowwiseParallel(input_layouts=Shard(2), output_layouts=Replicate()),
"cross_attn": PrepareModuleInput(
input_layouts=(Replicate(), Replicate()),
desired_input_layouts=(Replicate(), Replicate()),
input_layouts=(Shard(1), Replicate()),
desired_input_layouts=(Shard(1), Replicate()),
),
"cross_attn.q": SequenceParallel(),
"cross_attn.k": SequenceParallel(),
@@ -74,18 +91,26 @@ class LitModel(pl.LightningModule):
input_layouts=(Shard(1), Shard(1), Shard(1)),
desired_input_layouts=(Shard(2), Shard(2), Shard(2)),
),
"cross_attn.o": ColwiseParallel(output_layouts=Replicate()),
"ffn.0": ColwiseParallel(),
"ffn.2": RowwiseParallel(),
"cross_attn.o": RowwiseParallel(input_layouts=Shard(2), output_layouts=Replicate(), use_local_output=False),
"ffn.0": ColwiseParallel(input_layouts=Shard(1)),
"ffn.2": RowwiseParallel(output_layouts=Replicate()),
"norm1": SequenceParallel(use_local_output=True),
"norm2": SequenceParallel(use_local_output=True),
"norm3": SequenceParallel(use_local_output=True),
"gate": PrepareModuleInput(
input_layouts=(Shard(1), Replicate(), Replicate()),
desired_input_layouts=(Replicate(), Replicate(), Replicate()),
)
}
parallelize_module(
module=block,
device_mesh=tp_mesh,
parallelize_plan=layer_tp_plan,
)
def test_step(self, batch):
data = batch[0]
data["progress_bar_cmd"] = tqdm if self.local_rank == 0 else lambda x: x
@@ -94,9 +119,8 @@ class LitModel(pl.LightningModule):
video = self.pipe(**data)
if self.local_rank == 0:
save_video(video, output_path, fps=15, quality=5)
if __name__ == "__main__":
snapshot_download("Wan-AI/Wan2.1-T2V-14B", local_dir="models/Wan-AI/Wan2.1-T2V-14B")
dataloader = torch.utils.data.DataLoader(

View File

@@ -0,0 +1,58 @@
import torch
from diffsynth import ModelManager, WanVideoPipeline, save_video, VideoData
from modelscope import snapshot_download
import torch.distributed as dist
# Download models
snapshot_download("Wan-AI/Wan2.1-T2V-14B", local_dir="models/Wan-AI/Wan2.1-T2V-14B")
# Load models
model_manager = ModelManager(device="cpu")
model_manager.load_models(
[
[
"models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00001-of-00006.safetensors",
"models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00002-of-00006.safetensors",
"models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00003-of-00006.safetensors",
"models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00004-of-00006.safetensors",
"models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00005-of-00006.safetensors",
"models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00006-of-00006.safetensors",
],
"models/Wan-AI/Wan2.1-T2V-14B/models_t5_umt5-xxl-enc-bf16.pth",
"models/Wan-AI/Wan2.1-T2V-14B/Wan2.1_VAE.pth",
],
torch_dtype=torch.float8_e4m3fn, # You can set `torch_dtype=torch.bfloat16` to disable FP8 quantization.
)
dist.init_process_group(
backend="nccl",
init_method="env://",
)
from xfuser.core.distributed import (initialize_model_parallel,
init_distributed_environment)
init_distributed_environment(
rank=dist.get_rank(), world_size=dist.get_world_size())
initialize_model_parallel(
sequence_parallel_degree=dist.get_world_size(),
ring_degree=1,
ulysses_degree=dist.get_world_size(),
)
torch.cuda.set_device(dist.get_rank())
pipe = WanVideoPipeline.from_model_manager(model_manager,
torch_dtype=torch.bfloat16,
device=f"cuda:{dist.get_rank()}",
use_usp=True if dist.get_world_size() > 1 else False)
pipe.enable_vram_management(num_persistent_param_in_dit=None) # You can set `num_persistent_param_in_dit` to a small number to reduce VRAM required.
# Text-to-video
video = pipe(
prompt="一名宇航员身穿太空服,面朝镜头骑着一匹机械马在火星表面驰骋。红色的荒凉地表延伸至远方,点缀着巨大的陨石坑和奇特的岩石结构。机械马的步伐稳健,扬起微弱的尘埃,展现出未来科技与原始探索的完美结合。宇航员手持操控装置,目光坚定,仿佛正在开辟人类的新疆域。背景是深邃的宇宙和蔚蓝的地球,画面既科幻又充满希望,让人不禁畅想未来的星际生活。",
negative_prompt="色调艳丽过曝静态细节模糊不清字幕风格作品画作画面静止整体发灰最差质量低质量JPEG压缩残留丑陋的残缺的多余的手指画得不好的手部画得不好的脸部畸形的毁容的形态畸形的肢体手指融合静止不动的画面杂乱的背景三条腿背景人很多倒着走",
num_inference_steps=50,
seed=0, tiled=True
)
if dist.get_rank() == 0:
save_video(video, "video1.mp4", fps=25, quality=5)

View File

@@ -0,0 +1,42 @@
import torch
from diffsynth import ModelManager, WanVideoPipeline, save_video, VideoData
from modelscope import snapshot_download, dataset_snapshot_download
from PIL import Image
# Download models
snapshot_download("PAI/Wan2.1-Fun-1.3B-InP", local_dir="models/PAI/Wan2.1-Fun-1.3B-InP")
# Load models
model_manager = ModelManager(device="cpu")
model_manager.load_models(
[
"models/PAI/Wan2.1-Fun-1.3B-InP/diffusion_pytorch_model.safetensors",
"models/PAI/Wan2.1-Fun-1.3B-InP/models_t5_umt5-xxl-enc-bf16.pth",
"models/PAI/Wan2.1-Fun-1.3B-InP/Wan2.1_VAE.pth",
"models/PAI/Wan2.1-Fun-1.3B-InP/models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth",
],
torch_dtype=torch.bfloat16, # You can set `torch_dtype=torch.float8_e4m3fn` to enable FP8 quantization.
)
pipe = WanVideoPipeline.from_model_manager(model_manager, torch_dtype=torch.bfloat16, device="cuda")
pipe.enable_vram_management(num_persistent_param_in_dit=None)
# Download example image
dataset_snapshot_download(
dataset_id="DiffSynth-Studio/examples_in_diffsynth",
local_dir="./",
allow_file_pattern=f"data/examples/wan/input_image.jpg"
)
image = Image.open("data/examples/wan/input_image.jpg")
# Image-to-video
video = pipe(
prompt="一艘小船正勇敢地乘风破浪前行。蔚蓝的大海波涛汹涌,白色的浪花拍打着船身,但小船毫不畏惧,坚定地驶向远方。阳光洒在水面上,闪烁着金色的光芒,为这壮丽的场景增添了一抹温暖。镜头拉近,可以看到船上的旗帜迎风飘扬,象征着不屈的精神与冒险的勇气。这段画面充满力量,激励人心,展现了面对挑战时的无畏与执着。",
negative_prompt="色调艳丽过曝静态细节模糊不清字幕风格作品画作画面静止整体发灰最差质量低质量JPEG压缩残留丑陋的残缺的多余的手指画得不好的手部画得不好的脸部畸形的毁容的形态畸形的肢体手指融合静止不动的画面杂乱的背景三条腿背景人很多倒着走",
num_inference_steps=50,
input_image=image,
# You can input `end_image=xxx` to control the last frame of the video.
# The model will automatically generate the dynamic content between `input_image` and `end_image`.
seed=1, tiled=True
)
save_video(video, "video1.mp4", fps=15, quality=5)

View File

@@ -0,0 +1,40 @@
import torch
from diffsynth import ModelManager, WanVideoPipeline, save_video, VideoData
from modelscope import snapshot_download, dataset_snapshot_download
from PIL import Image
# Download models
snapshot_download("PAI/Wan2.1-Fun-1.3B-Control", local_dir="models/PAI/Wan2.1-Fun-1.3B-Control")
# Load models
model_manager = ModelManager(device="cpu")
model_manager.load_models(
[
"models/PAI/Wan2.1-Fun-1.3B-Control/diffusion_pytorch_model.safetensors",
"models/PAI/Wan2.1-Fun-1.3B-Control/models_t5_umt5-xxl-enc-bf16.pth",
"models/PAI/Wan2.1-Fun-1.3B-Control/Wan2.1_VAE.pth",
"models/PAI/Wan2.1-Fun-1.3B-Control/models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth",
],
torch_dtype=torch.bfloat16, # You can set `torch_dtype=torch.float8_e4m3fn` to enable FP8 quantization.
)
pipe = WanVideoPipeline.from_model_manager(model_manager, torch_dtype=torch.bfloat16, device="cuda")
pipe.enable_vram_management(num_persistent_param_in_dit=None)
# Download example video
dataset_snapshot_download(
dataset_id="DiffSynth-Studio/examples_in_diffsynth",
local_dir="./",
allow_file_pattern=f"data/examples/wan/control_video.mp4"
)
# Control-to-video
control_video = VideoData("data/examples/wan/control_video.mp4", height=832, width=576)
video = pipe(
prompt="扁平风格动漫一位长发少女优雅起舞。她五官精致大眼睛明亮有神黑色长发柔顺光泽。身穿淡蓝色T恤和深蓝色牛仔短裤。背景是粉色。",
negative_prompt="色调艳丽过曝静态细节模糊不清字幕风格作品画作画面静止整体发灰最差质量低质量JPEG压缩残留丑陋的残缺的多余的手指画得不好的手部画得不好的脸部畸形的毁容的形态畸形的肢体手指融合静止不动的画面杂乱的背景三条腿背景人很多倒着走",
num_inference_steps=50,
control_video=control_video, height=832, width=576, num_frames=49,
seed=1, tiled=True
)
save_video(video, "video1.mp4", fps=15, quality=5)

View File

@@ -1,54 +0,0 @@
import torch, os
import pandas as pd
from PIL import Image
from torchvision.transforms import v2
from diffsynth.data.video import crop_and_resize
class LoraDataset(torch.utils.data.Dataset):
def __init__(self, base_path, metadata_path, steps_per_epoch=1000, loras_per_item=1):
self.base_path = base_path
data_df = pd.read_csv(metadata_path)
self.model_file = data_df["model_file"].tolist()
self.image_file = data_df["image_file"].tolist()
self.text = data_df["text"].tolist()
self.max_resolution = 1920 * 1080
self.steps_per_epoch = steps_per_epoch
self.loras_per_item = loras_per_item
def read_image(self, image_file):
image = Image.open(image_file).convert("RGB")
width, height = image.size
if width * height > self.max_resolution:
scale = (width * height / self.max_resolution) ** 0.5
image = image.resize((int(width / scale), int(height / scale)))
width, height = image.size
if width % 16 != 0 or height % 16 != 0:
image = crop_and_resize(image, height // 16 * 16, width // 16 * 16)
image = v2.functional.to_image(image)
image = v2.functional.to_dtype(image, dtype=torch.float32, scale=True)
image = v2.functional.normalize(image, [0.5], [0.5])
return image
def get_data(self, data_id):
data = {
"model_file": os.path.join(self.base_path, self.model_file[data_id]),
"image": self.read_image(os.path.join(self.base_path, self.image_file[data_id])),
"text": self.text[data_id]
}
return data
def __getitem__(self, index):
data = []
while len(data) < self.loras_per_item:
data_id = torch.randint(0, len(self.model_file), (1,))[0]
data_id = (data_id + index) % len(self.model_file) # For fixed seed.
data.append(self.get_data(data_id))
return data
def __len__(self):
return self.steps_per_epoch

View File

@@ -1,61 +0,0 @@
import torch
class LoraMerger(torch.nn.Module):
def __init__(self, dim):
super().__init__()
self.weight_base = torch.nn.Parameter(torch.randn((dim,)))
self.weight_lora = torch.nn.Parameter(torch.randn((dim,)))
self.weight_cross = torch.nn.Parameter(torch.randn((dim,)))
self.weight_out = torch.nn.Parameter(torch.ones((dim,)))
self.bias = torch.nn.Parameter(torch.randn((dim,)))
self.activation = torch.nn.Sigmoid()
self.norm_base = torch.nn.LayerNorm(dim, eps=1e-5)
self.norm_lora = torch.nn.LayerNorm(dim, eps=1e-5)
def forward(self, base_output, lora_outputs):
norm_base_output = self.norm_base(base_output)
norm_lora_outputs = self.norm_lora(lora_outputs)
gate = self.activation(
norm_base_output * self.weight_base \
+ norm_lora_outputs * self.weight_lora \
+ norm_base_output * norm_lora_outputs * self.weight_cross + self.bias
)
output = base_output + (self.weight_out * gate * lora_outputs).sum(dim=0)
return output
class LoraPatcher(torch.nn.Module):
def __init__(self, lora_patterns=None):
super().__init__()
if lora_patterns is None:
lora_patterns = self.default_lora_patterns()
model_dict = {}
for lora_pattern in lora_patterns:
name, dim = lora_pattern["name"], lora_pattern["dim"]
model_dict[name.replace(".", "___")] = LoraMerger(dim)
self.model_dict = torch.nn.ModuleDict(model_dict)
def default_lora_patterns(self):
lora_patterns = []
lora_dict = {
"attn.a_to_qkv": 9216, "attn.a_to_out": 3072, "ff_a.0": 12288, "ff_a.2": 3072, "norm1_a.linear": 18432,
"attn.b_to_qkv": 9216, "attn.b_to_out": 3072, "ff_b.0": 12288, "ff_b.2": 3072, "norm1_b.linear": 18432,
}
for i in range(19):
for suffix in lora_dict:
lora_patterns.append({
"name": f"blocks.{i}.{suffix}",
"dim": lora_dict[suffix]
})
lora_dict = {"to_qkv_mlp": 21504, "proj_out": 3072, "norm.linear": 9216}
for i in range(38):
for suffix in lora_dict:
lora_patterns.append({
"name": f"single_blocks.{i}.{suffix}",
"dim": lora_dict[suffix]
})
return lora_patterns
def forward(self, base_output, lora_outputs, name):
return self.model_dict[name.replace(".", "___")](base_output, lora_outputs)

View File

@@ -1,149 +0,0 @@
import torch
from diffsynth import SDTextEncoder
from diffsynth.models.sd3_text_encoder import SD3TextEncoder1StateDictConverter
from diffsynth.models.sd_text_encoder import CLIPEncoderLayer
class LoRALayerBlock(torch.nn.Module):
def __init__(self, L, dim_in):
super().__init__()
self.x = torch.nn.Parameter(torch.randn(1, L, dim_in))
def forward(self, lora_A, lora_B):
out = self.x @ lora_A.T @ lora_B.T
return out
class LoRAEmbedder(torch.nn.Module):
def __init__(self, lora_patterns=None, L=1, out_dim=2048):
super().__init__()
if lora_patterns is None:
lora_patterns = self.default_lora_patterns()
model_dict = {}
for lora_pattern in lora_patterns:
name, dim = lora_pattern["name"], lora_pattern["dim"][0]
model_dict[name.replace(".", "___")] = LoRALayerBlock(L, dim)
self.model_dict = torch.nn.ModuleDict(model_dict)
proj_dict = {}
for lora_pattern in lora_patterns:
layer_type, dim = lora_pattern["type"], lora_pattern["dim"][1]
if layer_type not in proj_dict:
proj_dict[layer_type.replace(".", "___")] = torch.nn.Linear(dim, out_dim)
self.proj_dict = torch.nn.ModuleDict(proj_dict)
self.lora_patterns = lora_patterns
def default_lora_patterns(self):
lora_patterns = []
lora_dict = {
"attn.a_to_qkv": (3072, 9216), "attn.a_to_out": (3072, 3072), "ff_a.0": (3072, 12288), "ff_a.2": (12288, 3072), "norm1_a.linear": (3072, 18432),
"attn.b_to_qkv": (3072, 9216), "attn.b_to_out": (3072, 3072), "ff_b.0": (3072, 12288), "ff_b.2": (12288, 3072), "norm1_b.linear": (3072, 18432),
}
for i in range(19):
for suffix in lora_dict:
lora_patterns.append({
"name": f"blocks.{i}.{suffix}",
"dim": lora_dict[suffix],
"type": suffix,
})
lora_dict = {"to_qkv_mlp": (3072, 21504), "proj_out": (15360, 3072), "norm.linear": (3072, 9216)}
for i in range(38):
for suffix in lora_dict:
lora_patterns.append({
"name": f"single_blocks.{i}.{suffix}",
"dim": lora_dict[suffix],
"type": suffix,
})
return lora_patterns
def forward(self, lora):
lora_emb = []
for lora_pattern in self.lora_patterns:
name, layer_type = lora_pattern["name"], lora_pattern["type"]
lora_A = lora[name + ".lora_A.default.weight"]
lora_B = lora[name + ".lora_B.default.weight"]
lora_out = self.model_dict[name.replace(".", "___")](lora_A, lora_B)
lora_out = self.proj_dict[layer_type.replace(".", "___")](lora_out)
lora_emb.append(lora_out)
lora_emb = torch.concat(lora_emb, dim=1)
return lora_emb
class TextEncoder(torch.nn.Module):
def __init__(self, embed_dim=768, vocab_size=49408, max_position_embeddings=77, num_encoder_layers=12, encoder_intermediate_size=3072):
super().__init__()
# token_embedding
self.token_embedding = torch.nn.Embedding(vocab_size, embed_dim)
# position_embeds (This is a fixed tensor)
self.position_embeds = torch.nn.Parameter(torch.zeros(1, max_position_embeddings, embed_dim))
# encoders
self.encoders = torch.nn.ModuleList([CLIPEncoderLayer(embed_dim, encoder_intermediate_size) for _ in range(num_encoder_layers)])
# attn_mask
self.attn_mask = self.attention_mask(max_position_embeddings)
# final_layer_norm
self.final_layer_norm = torch.nn.LayerNorm(embed_dim)
def attention_mask(self, length):
mask = torch.empty(length, length)
mask.fill_(float("-inf"))
mask.triu_(1)
return mask
def forward(self, input_ids, clip_skip=1):
embeds = self.token_embedding(input_ids) + self.position_embeds
attn_mask = self.attn_mask.to(device=embeds.device, dtype=embeds.dtype)
for encoder_id, encoder in enumerate(self.encoders):
embeds = encoder(embeds, attn_mask=attn_mask)
if encoder_id + clip_skip == len(self.encoders):
break
embeds = self.final_layer_norm(embeds)
pooled_embeds = embeds[torch.arange(embeds.shape[0]), input_ids.to(dtype=torch.int).argmax(dim=-1)]
return pooled_embeds
@staticmethod
def state_dict_converter():
return SD3TextEncoder1StateDictConverter()
class LoRAEncoder(torch.nn.Module):
def __init__(self, embed_dim=768, max_position_embeddings=304, num_encoder_layers=2, encoder_intermediate_size=3072, L=1):
super().__init__()
max_position_embeddings *= L
# Embedder
self.embedder = LoRAEmbedder(L=L, out_dim=embed_dim)
# position_embeds (This is a fixed tensor)
self.position_embeds = torch.nn.Parameter(torch.zeros(1, max_position_embeddings, embed_dim))
# encoders
self.encoders = torch.nn.ModuleList([CLIPEncoderLayer(embed_dim, encoder_intermediate_size) for _ in range(num_encoder_layers)])
# attn_mask
self.attn_mask = self.attention_mask(max_position_embeddings)
# final_layer_norm
self.final_layer_norm = torch.nn.LayerNorm(embed_dim)
def attention_mask(self, length):
mask = torch.empty(length, length)
mask.fill_(float("-inf"))
mask.triu_(1)
return mask
def forward(self, lora):
embeds = self.embedder(lora) + self.position_embeds
attn_mask = self.attn_mask.to(device=embeds.device, dtype=embeds.dtype)
for encoder_id, encoder in enumerate(self.encoders):
embeds = encoder(embeds, attn_mask=attn_mask)
embeds = self.final_layer_norm(embeds)
embeds = embeds.mean(dim=1)
return embeds

View File

@@ -1,46 +0,0 @@
from diffsynth import FluxImagePipeline, ModelManager, load_state_dict
from diffsynth.models.lora import FluxLoRAConverter
from diffsynth.pipelines.flux_image import lets_dance_flux
from lora.dataset import LoraDataset
from lora.merger import LoraPatcher
from lora.utils import load_lora
import torch, os
from accelerate import Accelerator, DistributedDataParallelKwargs
from tqdm import tqdm
model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cuda", model_id_list=["FLUX.1-dev"])
pipe = FluxImagePipeline.from_model_manager(model_manager)
pipe.enable_auto_lora()
lora_patcher = LoraPatcher().to(dtype=torch.bfloat16, device="cuda")
lora_patcher.load_state_dict(load_state_dict("models/lora_merger/epoch-3.safetensors"))
dataset = LoraDataset("data/lora/models", "data/lora/lora_dataset_1000.csv", steps_per_epoch=800, loras_per_item=4)
for seed in range(100):
batch = dataset[0]
num_lora = torch.randint(1, len(batch), (1,))[0]
lora_state_dicts = [
FluxLoRAConverter.align_to_diffsynth_format(load_lora(batch[i]["model_file"], device="cuda")) for i in range(num_lora)
]
image = pipe(
prompt=batch[0]["text"],
seed=seed,
)
image.save(f"data/lora/lora_outputs/image_{seed}_nolora.jpg")
for i in range(num_lora):
image = pipe(
prompt=batch[0]["text"],
lora_state_dicts=[lora_state_dicts[i]],
lora_patcher=lora_patcher,
seed=seed,
)
image.save(f"data/lora/lora_outputs/image_{seed}_{i}.jpg")
image = pipe(
prompt=batch[0]["text"],
lora_state_dicts=lora_state_dicts,
lora_patcher=lora_patcher,
seed=seed,
)
image.save(f"data/lora/lora_outputs/image_{seed}_merger.jpg")

View File

@@ -1,148 +0,0 @@
from diffsynth import FluxImagePipeline, ModelManager, load_state_dict
from diffsynth.models.lora import FluxLoRAConverter
from diffsynth.pipelines.flux_image import lets_dance_flux
from lora.dataset import LoraDataset
from lora.retriever import TextEncoder, LoRAEncoder
from lora.merger import LoraPatcher
from lora.utils import load_lora
import torch, os
from accelerate import Accelerator, DistributedDataParallelKwargs
from tqdm import tqdm
from transformers import CLIPTokenizer, CLIPModel
import pandas as pd
class LoRARetrieverTrainingModel(torch.nn.Module):
def __init__(self, pretrained_path):
super().__init__()
self.text_encoder = TextEncoder().to(torch.bfloat16)
state_dict = load_state_dict("models/FLUX/FLUX.1-dev/text_encoder/model.safetensors")
self.text_encoder.load_state_dict(TextEncoder.state_dict_converter().from_civitai(state_dict))
self.text_encoder.requires_grad_(False)
self.text_encoder.eval()
self.lora_encoder = LoRAEncoder().to(torch.bfloat16)
state_dict = load_state_dict(pretrained_path)
self.lora_encoder.load_state_dict(state_dict)
self.tokenizer = CLIPTokenizer.from_pretrained("diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_1")
def to(self, *args, **kwargs):
device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs)
if device is not None:
self.device = device
if dtype is not None:
self.torch_dtype = dtype
super().to(*args, **kwargs)
return self
def forward(self, batch):
text = [data["text"] for data in batch]
input_ids = self.tokenizer(
text,
return_tensors="pt",
padding="max_length",
max_length=77,
truncation=True
).input_ids.to(self.device)
text_emb = self.text_encoder(input_ids)
text_emb = text_emb / text_emb.norm()
lora_emb = []
for data in batch:
lora = FluxLoRAConverter.align_to_diffsynth_format(load_lora(data["model_file"], device=self.device))
lora_emb.append(self.lora_encoder(lora))
lora_emb = torch.concat(lora_emb)
lora_emb = lora_emb / lora_emb.norm()
similarity = text_emb @ lora_emb.T
print(similarity)
loss = -torch.log(torch.softmax(similarity, dim=0).diag()) - torch.log(torch.softmax(similarity, dim=1).diag())
loss = 10 * loss.mean()
return loss
def trainable_modules(self):
return self.lora_encoder.parameters()
@torch.no_grad()
def process_lora_list(self, lora_list):
lora_emb = []
for lora in tqdm(lora_list):
lora = FluxLoRAConverter.align_to_diffsynth_format(load_lora(lora, device="cuda"))
lora_emb.append(self.lora_encoder(lora))
lora_emb = torch.concat(lora_emb)
lora_emb = lora_emb / lora_emb.norm()
self.lora_emb = lora_emb
self.lora_list = lora_list
@torch.no_grad()
def retrieve(self, text, k=1):
input_ids = self.tokenizer(
text,
return_tensors="pt",
padding="max_length",
max_length=77,
truncation=True
).input_ids.to(self.device)
text_emb = self.text_encoder(input_ids)
text_emb = text_emb / text_emb.norm()
similarity = text_emb @ self.lora_emb.T
topk = torch.topk(similarity, k, dim=1).indices[0]
lora_list = []
model_url_list = []
for lora_id in topk:
print(self.lora_list[lora_id])
lora = FluxLoRAConverter.align_to_diffsynth_format(load_lora(self.lora_list[lora_id], device="cuda"))
lora_list.append(lora)
model_id = self.lora_list[lora_id].split("/")[3:5]
model_url_list.append(f"https://www.modelscope.cn/models/{model_id[0]}/{model_id[1]}")
return lora_list, model_url_list
model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cuda", model_id_list=["FLUX.1-dev"])
pipe = FluxImagePipeline.from_model_manager(model_manager)
pipe.enable_auto_lora()
lora_patcher = LoraPatcher().to(dtype=torch.bfloat16, device="cuda")
lora_patcher.load_state_dict(load_state_dict("models/lora_merger/epoch-9.safetensors"))
retriever = LoRARetrieverTrainingModel("models/lora_retriever/epoch-3.safetensors").to(dtype=torch.bfloat16, device="cuda")
retriever.process_lora_list(list(set("data/lora/models/" + i for i in pd.read_csv("data/lora/lora_dataset_1000.csv")["model_file"])))
dataset = LoraDataset("data/lora/models", "data/lora/lora_dataset_1000.csv", steps_per_epoch=800, loras_per_item=1)
text_list = []
model_url_list = []
for seed in range(100):
text = dataset[0][0]["text"]
print(text)
loras, urls = retriever.retrieve(text, k=3)
print(urls)
image = pipe(
prompt=text,
seed=seed,
)
image.save(f"data/lora/lora_outputs/image_{seed}_top0.jpg")
for i in range(2, 3):
image = pipe(
prompt=text,
lora_state_dicts=loras[:i+1],
lora_patcher=lora_patcher,
seed=seed,
)
image.save(f"data/lora/lora_outputs/image_{seed}_top{i+1}.jpg")
text_list.append(text)
model_url_list.append(urls)
df = pd.DataFrame()
df["text"] = text_list
df["models"] = [",".join(i) for i in model_url_list]
df.to_csv("data/lora/lora_outputs/metadata.csv", index=False, encoding="utf-8-sig")

View File

@@ -1,119 +0,0 @@
from diffsynth import FluxImagePipeline, ModelManager
from diffsynth.models.lora import FluxLoRAConverter
from diffsynth.pipelines.flux_image import lets_dance_flux
from lora.dataset import LoraDataset
from lora.merger import LoraPatcher
from lora.utils import load_lora
import torch, os
from accelerate import Accelerator, DistributedDataParallelKwargs
from tqdm import tqdm
class LoRAMergerTrainingModel(torch.nn.Module):
def __init__(self):
super().__init__()
model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu", model_id_list=["FLUX.1-dev"])
self.pipe = FluxImagePipeline.from_model_manager(model_manager)
self.lora_patcher = LoraPatcher()
self.pipe.enable_auto_lora()
self.freeze_parameters()
self.switch_to_training_mode()
self.use_gradient_checkpointing = True
self.state_dict_converter = FluxLoRAConverter.align_to_diffsynth_format
self.device = "cuda"
def to(self, *args, **kwargs):
device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs)
if device is not None:
self.device = device
if dtype is not None:
self.torch_dtype = dtype
super().to(*args, **kwargs)
return self
def switch_to_training_mode(self):
self.pipe.scheduler.set_timesteps(1000, training=True)
def freeze_parameters(self):
self.pipe.requires_grad_(False)
self.pipe.eval()
self.pipe.denoising_model().train()
self.lora_patcher.requires_grad_(True)
def forward(self, batch):
# Data
text, image = batch[0]["text"], batch[0]["image"].unsqueeze(0)
num_lora = torch.randint(1, len(batch), (1,))[0]
lora_state_dicts = [
self.state_dict_converter(load_lora(batch[i]["model_file"], device=self.device)) for i in range(num_lora)
]
lora_alphas = None
# Prepare input parameters
self.pipe.device = self.device
prompt_emb = self.pipe.encode_prompt(text, positive=True)
latents = self.pipe.vae_encoder(image.to(dtype=self.pipe.torch_dtype, device=self.device))
noise = torch.randn_like(latents)
timestep_id = torch.randint(0, self.pipe.scheduler.num_train_timesteps, (1,))
timestep = self.pipe.scheduler.timesteps[timestep_id].to(self.device)
extra_input = self.pipe.prepare_extra_input(latents)
noisy_latents = self.pipe.scheduler.add_noise(latents, noise, timestep)
training_target = self.pipe.scheduler.training_target(latents, noise, timestep)
# Compute loss
noise_pred = lets_dance_flux(
self.pipe.dit,
hidden_states=noisy_latents, timestep=timestep, **prompt_emb, **extra_input,
lora_state_dicts=lora_state_dicts, lora_alphas=lora_alphas, lora_patcher=self.lora_patcher,
use_gradient_checkpointing=self.use_gradient_checkpointing
)
loss = torch.nn.functional.mse_loss(noise_pred.float(), training_target.float())
loss = loss * self.pipe.scheduler.training_weight(timestep)
return loss
def trainable_modules(self):
return self.lora_patcher.parameters()
class ModelLogger:
def __init__(self, output_path, remove_prefix_in_ckpt=None):
self.output_path = output_path
self.remove_prefix_in_ckpt = remove_prefix_in_ckpt
def on_step_end(self, loss):
pass
def on_epoch_end(self, accelerator, model, epoch_id):
accelerator.wait_for_everyone()
if accelerator.is_main_process:
state_dict = accelerator.unwrap_model(model).lora_patcher.state_dict()
os.makedirs(self.output_path, exist_ok=True)
path = os.path.join(self.output_path, f"epoch-{epoch_id}.safetensors")
accelerator.save(state_dict, path, safe_serialization=True)
if __name__ == '__main__':
model = LoRAMergerTrainingModel()
dataset = LoraDataset("data/lora/models/", "data/lora/lora_dataset_1000.csv", steps_per_epoch=800, loras_per_item=4)
dataloader = torch.utils.data.DataLoader(dataset, shuffle=True, batch_size=1, num_workers=1, collate_fn=lambda x: x[0])
optimizer = torch.optim.AdamW(model.trainable_modules(), lr=1e-4)
model_logger = ModelLogger("models/lora_merger")
accelerator = Accelerator(kwargs_handlers=[DistributedDataParallelKwargs(find_unused_parameters=True)])
model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)
for epoch_id in range(1000000):
for data in tqdm(dataloader):
with accelerator.accumulate(model):
optimizer.zero_grad()
loss = model(data)
accelerator.backward(loss)
optimizer.step()
model_logger.on_epoch_end(accelerator, model, epoch_id)

View File

@@ -1,105 +0,0 @@
from diffsynth import FluxImagePipeline, ModelManager, load_state_dict
from diffsynth.models.lora import FluxLoRAConverter
from diffsynth.pipelines.flux_image import lets_dance_flux
from lora.dataset import LoraDataset
from lora.retriever import TextEncoder, LoRAEncoder
from lora.utils import load_lora
import torch, os
from accelerate import Accelerator, DistributedDataParallelKwargs
from tqdm import tqdm
from transformers import CLIPTokenizer, CLIPModel
class LoRARetrieverTrainingModel(torch.nn.Module):
def __init__(self):
super().__init__()
self.text_encoder = TextEncoder().to(torch.bfloat16)
state_dict = load_state_dict("models/FLUX/FLUX.1-dev/text_encoder/model.safetensors")
self.text_encoder.load_state_dict(TextEncoder.state_dict_converter().from_civitai(state_dict))
self.text_encoder.requires_grad_(False)
self.text_encoder.eval()
self.lora_encoder = LoRAEncoder().to(torch.bfloat16)
self.tokenizer = CLIPTokenizer.from_pretrained("diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_1")
def to(self, *args, **kwargs):
device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs)
if device is not None:
self.device = device
if dtype is not None:
self.torch_dtype = dtype
super().to(*args, **kwargs)
return self
def forward(self, batch):
text = [data["text"] for data in batch]
input_ids = self.tokenizer(
text,
return_tensors="pt",
padding="max_length",
max_length=77,
truncation=True
).input_ids.to(self.device)
text_emb = self.text_encoder(input_ids)
text_emb = text_emb / text_emb.norm()
lora_emb = []
for data in batch:
lora = FluxLoRAConverter.align_to_diffsynth_format(load_lora(data["model_file"], device=self.device))
lora_emb.append(self.lora_encoder(lora))
lora_emb = torch.concat(lora_emb)
lora_emb = lora_emb / lora_emb.norm()
similarity = text_emb @ lora_emb.T
print(similarity)
loss = -torch.log(torch.softmax(similarity, dim=0).diag()) - torch.log(torch.softmax(similarity, dim=1).diag())
loss = 10 * loss.mean()
return loss
def trainable_modules(self):
return self.lora_encoder.parameters()
class ModelLogger:
def __init__(self, output_path, remove_prefix_in_ckpt=None):
self.output_path = output_path
self.remove_prefix_in_ckpt = remove_prefix_in_ckpt
def on_step_end(self, loss):
pass
def on_epoch_end(self, accelerator, model, epoch_id):
accelerator.wait_for_everyone()
if accelerator.is_main_process:
state_dict = accelerator.unwrap_model(model).lora_encoder.state_dict()
os.makedirs(self.output_path, exist_ok=True)
path = os.path.join(self.output_path, f"epoch-{epoch_id}.safetensors")
accelerator.save(state_dict, path, safe_serialization=True)
if __name__ == '__main__':
model = LoRARetrieverTrainingModel()
dataset = LoraDataset("data/lora/models/", "data/lora/lora_dataset_1000.csv", steps_per_epoch=100, loras_per_item=32)
dataloader = torch.utils.data.DataLoader(dataset, shuffle=True, batch_size=1, num_workers=1, collate_fn=lambda x: x[0])
optimizer = torch.optim.AdamW(model.trainable_modules(), lr=1e-4)
model_logger = ModelLogger("models/lora_retriever")
accelerator = Accelerator(kwargs_handlers=[DistributedDataParallelKwargs(find_unused_parameters=True)])
model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)
for epoch_id in range(1000000):
for data in tqdm(dataloader):
with accelerator.accumulate(model):
optimizer.zero_grad()
loss = model(data)
accelerator.backward(loss)
optimizer.step()
print(loss)
model_logger.on_epoch_end(accelerator, model, epoch_id)

View File

@@ -1,12 +0,0 @@
from diffsynth import load_state_dict
import math, torch
def load_lora(file_path, device):
sd = load_state_dict(file_path, torch_dtype=torch.bfloat16, device=device)
scale = math.sqrt(sd["lora_unet_single_blocks_9_modulation_lin.alpha"] / sd["lora_unet_single_blocks_9_modulation_lin.lora_down.weight"].shape[0])
if scale != 1:
sd = {i: sd[i] * scale for i in sd}
return sd

View File

View File

@@ -1,3 +1,4 @@
wheel==0.44.0
torch>=2.0.0
torchvision
cupy-cuda12x

View File

@@ -14,7 +14,7 @@ else:
setup(
name="diffsynth",
version="1.1.2",
version="1.1.6",
description="Enjoy the magic of Diffusion models!",
author="Artiprocher",
packages=find_packages(),