mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-19 06:48:12 +00:00
Compare commits
32 Commits
lora-retri
...
v1.1.5
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
cd8884c9ef | ||
|
|
46744362de | ||
|
|
0f0cdc3afc | ||
|
|
a33c63af87 | ||
|
|
3cc9764bc9 | ||
|
|
f6c6e3c640 | ||
|
|
60a9db706e | ||
|
|
a98700feb2 | ||
|
|
5418ca781e | ||
|
|
71eee780fb | ||
|
|
4864453e0a | ||
|
|
c5a32f76c2 | ||
|
|
c4ed3d3e4b | ||
|
|
803ddcccc7 | ||
|
|
4cd51fecf2 | ||
|
|
3b0211a547 | ||
|
|
e88328d152 | ||
|
|
52896fa8dd | ||
|
|
c7035ad911 | ||
|
|
070811e517 | ||
|
|
7e010d88a5 | ||
|
|
4e43d4d461 | ||
|
|
d7efe7e539 | ||
|
|
633f789c47 | ||
|
|
88607f404e | ||
|
|
6d405b669c | ||
|
|
d0fed6ba72 | ||
|
|
64eaa0d76a | ||
|
|
54081bdcbb | ||
|
|
d8b250607a | ||
|
|
1e58e6ef82 | ||
|
|
42cb7d96bb |
18
README.md
18
README.md
@@ -13,9 +13,15 @@ Document: https://diffsynth-studio.readthedocs.io/zh-cn/latest/index.html
|
||||
|
||||
## Introduction
|
||||
|
||||
DiffSynth Studio is a Diffusion engine. We have restructured architectures including Text Encoder, UNet, VAE, among others, maintaining compatibility with models from the open-source community while enhancing computational performance. We provide many interesting features. Enjoy the magic of Diffusion models!
|
||||
Welcome to the magic world of Diffusion models!
|
||||
|
||||
Until now, DiffSynth Studio has supported the following models:
|
||||
DiffSynth consists of two open-source projects:
|
||||
* [DiffSynth-Studio](https://github.com/modelscope/DiffSynth-Studio): Focused on aggressive technological exploration. Targeted at academia. Provides more cutting-edge technical support and novel inference capabilities.
|
||||
* [DiffSynth-Engine](https://github.com/modelscope/DiffSynth-Engine): Focused on stable model deployment. Geared towards industry. Offers better engineering support, higher computational performance, and more stable functionality.
|
||||
|
||||
DiffSynth-Studio is an open-source project aimed at exploring innovations in AIGC technology. We have integrated numerous open-source Diffusion models, including FLUX and Wan, among others. Through this open-source project, we hope to connect models within the open-source community and explore new technologies based on diffusion models.
|
||||
|
||||
Until now, DiffSynth-Studio has supported the following models:
|
||||
|
||||
* [Wan-Video](https://github.com/Wan-Video/Wan2.1)
|
||||
* [StepVideo](https://github.com/stepfun-ai/Step-Video-T2V)
|
||||
@@ -36,7 +42,11 @@ Until now, DiffSynth Studio has supported the following models:
|
||||
* [Stable Diffusion](https://huggingface.co/runwayml/stable-diffusion-v1-5)
|
||||
|
||||
## News
|
||||
- **March 25, 2025** We support HunyuanVideo-I2V, the image-to-video generation version of HunyuanVideo open-sourced by Tencent. Please refer to [./examples/HunyuanVideo/](./examples/HunyuanVideo/) for more details.
|
||||
- **March 31, 2025** We support InfiniteYou, an identity preserving method for FLUX. Please refer to [./examples/InfiniteYou/](./examples/InfiniteYou/) for more details.
|
||||
|
||||
- **March 25, 2025** 🔥🔥🔥 Our new open-source project, [DiffSynth-Engine](https://github.com/modelscope/DiffSynth-Engine), is now open-sourced! Focused on stable model deployment. Geared towards industry. Offers better engineering support, higher computational performance, and more stable functionality.
|
||||
|
||||
- **March 13, 2025** We support HunyuanVideo-I2V, the image-to-video generation version of HunyuanVideo open-sourced by Tencent. Please refer to [./examples/HunyuanVideo/](./examples/HunyuanVideo/) for more details.
|
||||
|
||||
- **February 25, 2025** We support Wan-Video, a collection of SOTA video synthesis models open-sourced by Alibaba. See [./examples/wanvideo/](./examples/wanvideo/).
|
||||
|
||||
@@ -73,7 +83,7 @@ Until now, DiffSynth Studio has supported the following models:
|
||||
- Enable CFG and highres-fix to improve visual quality. See [here](/examples/image_synthesis/README.md)
|
||||
- LoRA, ControlNet, and additional models will be available soon.
|
||||
|
||||
- **June 21, 2024.** 🔥🔥🔥 We propose ExVideo, a post-tuning technique aimed at enhancing the capability of video generation models. We have extended Stable Video Diffusion to achieve the generation of long videos up to 128 frames.
|
||||
- **June 21, 2024.** We propose ExVideo, a post-tuning technique aimed at enhancing the capability of video generation models. We have extended Stable Video Diffusion to achieve the generation of long videos up to 128 frames.
|
||||
- [Project Page](https://ecnu-cilab.github.io/ExVideoProjectPage/)
|
||||
- Source code is released in this repo. See [`examples/ExVideo`](./examples/ExVideo/).
|
||||
- Models are released on [HuggingFace](https://huggingface.co/ECNU-CILab/ExVideo-SVD-128f-v1) and [ModelScope](https://modelscope.cn/models/ECNU-CILab/ExVideo-SVD-128f-v1).
|
||||
|
||||
@@ -37,6 +37,7 @@ from ..models.flux_text_encoder import FluxTextEncoder2
|
||||
from ..models.flux_vae import FluxVAEEncoder, FluxVAEDecoder
|
||||
from ..models.flux_controlnet import FluxControlNet
|
||||
from ..models.flux_ipadapter import FluxIpAdapter
|
||||
from ..models.flux_infiniteyou import InfiniteYouImageProjector
|
||||
|
||||
from ..models.cog_vae import CogVAEEncoder, CogVAEDecoder
|
||||
from ..models.cog_dit import CogDiT
|
||||
@@ -58,6 +59,7 @@ from ..models.wan_video_dit import WanModel
|
||||
from ..models.wan_video_text_encoder import WanTextEncoder
|
||||
from ..models.wan_video_image_encoder import WanImageEncoder
|
||||
from ..models.wan_video_vae import WanVideoVAE
|
||||
from ..models.wan_video_motion_controller import WanMotionControllerModel
|
||||
|
||||
|
||||
model_loader_configs = [
|
||||
@@ -104,6 +106,8 @@ model_loader_configs = [
|
||||
(None, "b001c89139b5f053c715fe772362dd2a", ["flux_controlnet"], [FluxControlNet], "diffusers"),
|
||||
(None, "52357cb26250681367488a8954c271e8", ["flux_controlnet"], [FluxControlNet], "diffusers"),
|
||||
(None, "0cfd1740758423a2a854d67c136d1e8c", ["flux_controlnet"], [FluxControlNet], "diffusers"),
|
||||
(None, "7f9583eb8ba86642abb9a21a4b2c9e16", ["flux_controlnet"], [FluxControlNet], "diffusers"),
|
||||
(None, "c07c0f04f5ff55e86b4e937c7a40d481", ["infiniteyou_image_projector"], [InfiniteYouImageProjector], "diffusers"),
|
||||
(None, "4daaa66cc656a8fe369908693dad0a35", ["flux_ipadapter"], [FluxIpAdapter], "diffusers"),
|
||||
(None, "51aed3d27d482fceb5e0739b03060e8f", ["sd3_dit", "sd3_vae_encoder", "sd3_vae_decoder"], [SD3DiT, SD3VAEEncoder, SD3VAEDecoder], "civitai"),
|
||||
(None, "98cc34ccc5b54ae0e56bdea8688dcd5a", ["sd3_text_encoder_2"], [SD3TextEncoder2], "civitai"),
|
||||
@@ -117,11 +121,16 @@ model_loader_configs = [
|
||||
(None, "9269f8db9040a9d860eaca435be61814", ["wan_video_dit"], [WanModel], "civitai"),
|
||||
(None, "aafcfd9672c3a2456dc46e1cb6e52c70", ["wan_video_dit"], [WanModel], "civitai"),
|
||||
(None, "6bfcfb3b342cb286ce886889d519a77e", ["wan_video_dit"], [WanModel], "civitai"),
|
||||
(None, "6d6ccde6845b95ad9114ab993d917893", ["wan_video_dit"], [WanModel], "civitai"),
|
||||
(None, "6bfcfb3b342cb286ce886889d519a77e", ["wan_video_dit"], [WanModel], "civitai"),
|
||||
(None, "349723183fc063b2bfc10bb2835cf677", ["wan_video_dit"], [WanModel], "civitai"),
|
||||
(None, "efa44cddf936c70abd0ea28b6cbe946c", ["wan_video_dit"], [WanModel], "civitai"),
|
||||
(None, "cb104773c6c2cb6df4f9529ad5c60d0b", ["wan_video_dit"], [WanModel], "diffusers"),
|
||||
(None, "9c8818c2cbea55eca56c7b447df170da", ["wan_video_text_encoder"], [WanTextEncoder], "civitai"),
|
||||
(None, "5941c53e207d62f20f9025686193c40b", ["wan_video_image_encoder"], [WanImageEncoder], "civitai"),
|
||||
(None, "1378ea763357eea97acdef78e65d6d96", ["wan_video_vae"], [WanVideoVAE], "civitai"),
|
||||
(None, "ccc42284ea13e1ad04693284c7a09be6", ["wan_video_vae"], [WanVideoVAE], "civitai"),
|
||||
(None, "dbd5ec76bbf977983f972c151d545389", ["wan_video_motion_controller"], [WanMotionControllerModel], "civitai"),
|
||||
]
|
||||
huggingface_model_loader_configs = [
|
||||
# These configs are provided for detecting model type automatically.
|
||||
@@ -598,6 +607,25 @@ preset_models_on_modelscope = {
|
||||
"models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/image_encoder",
|
||||
],
|
||||
},
|
||||
"InfiniteYou":{
|
||||
"file_list":[
|
||||
("ByteDance/InfiniteYou", "infu_flux_v1.0/aes_stage2/InfuseNetModel/diffusion_pytorch_model-00001-of-00002.safetensors", "models/InfiniteYou/InfuseNetModel"),
|
||||
("ByteDance/InfiniteYou", "infu_flux_v1.0/aes_stage2/InfuseNetModel/diffusion_pytorch_model-00002-of-00002.safetensors", "models/InfiniteYou/InfuseNetModel"),
|
||||
("ByteDance/InfiniteYou", "infu_flux_v1.0/aes_stage2/image_proj_model.bin", "models/InfiniteYou"),
|
||||
("ByteDance/InfiniteYou", "supports/insightface/models/antelopev2/1k3d68.onnx", "models/InfiniteYou/insightface/models/antelopev2"),
|
||||
("ByteDance/InfiniteYou", "supports/insightface/models/antelopev2/2d106det.onnx", "models/InfiniteYou/insightface/models/antelopev2"),
|
||||
("ByteDance/InfiniteYou", "supports/insightface/models/antelopev2/genderage.onnx", "models/InfiniteYou/insightface/models/antelopev2"),
|
||||
("ByteDance/InfiniteYou", "supports/insightface/models/antelopev2/glintr100.onnx", "models/InfiniteYou/insightface/models/antelopev2"),
|
||||
("ByteDance/InfiniteYou", "supports/insightface/models/antelopev2/scrfd_10g_bnkps.onnx", "models/InfiniteYou/insightface/models/antelopev2"),
|
||||
],
|
||||
"load_path":[
|
||||
[
|
||||
"models/InfiniteYou/InfuseNetModel/diffusion_pytorch_model-00001-of-00002.safetensors",
|
||||
"models/InfiniteYou/InfuseNetModel/diffusion_pytorch_model-00002-of-00002.safetensors"
|
||||
],
|
||||
"models/InfiniteYou/image_proj_model.bin",
|
||||
],
|
||||
},
|
||||
# ESRGAN
|
||||
"ESRGAN_x4": [
|
||||
("AI-ModelScope/Real-ESRGAN", "RealESRGAN_x4.pth", "models/ESRGAN"),
|
||||
@@ -757,6 +785,7 @@ Preset_model_id: TypeAlias = Literal[
|
||||
"Shakker-Labs/FLUX.1-dev-ControlNet-Depth",
|
||||
"Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro",
|
||||
"InstantX/FLUX.1-dev-IP-Adapter",
|
||||
"InfiniteYou",
|
||||
"SDXL_lora_zyd232_ChineseInkStyle_SDXL_v1_0",
|
||||
"QwenPrompt",
|
||||
"OmostPrompt",
|
||||
|
||||
129
diffsynth/distributed/xdit_context_parallel.py
Normal file
129
diffsynth/distributed/xdit_context_parallel.py
Normal file
@@ -0,0 +1,129 @@
|
||||
import torch
|
||||
from typing import Optional
|
||||
from einops import rearrange
|
||||
from xfuser.core.distributed import (get_sequence_parallel_rank,
|
||||
get_sequence_parallel_world_size,
|
||||
get_sp_group)
|
||||
from xfuser.core.long_ctx_attention import xFuserLongContextAttention
|
||||
|
||||
def sinusoidal_embedding_1d(dim, position):
|
||||
sinusoid = torch.outer(position.type(torch.float64), torch.pow(
|
||||
10000, -torch.arange(dim//2, dtype=torch.float64, device=position.device).div(dim//2)))
|
||||
x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1)
|
||||
return x.to(position.dtype)
|
||||
|
||||
def pad_freqs(original_tensor, target_len):
|
||||
seq_len, s1, s2 = original_tensor.shape
|
||||
pad_size = target_len - seq_len
|
||||
padding_tensor = torch.ones(
|
||||
pad_size,
|
||||
s1,
|
||||
s2,
|
||||
dtype=original_tensor.dtype,
|
||||
device=original_tensor.device)
|
||||
padded_tensor = torch.cat([original_tensor, padding_tensor], dim=0)
|
||||
return padded_tensor
|
||||
|
||||
def rope_apply(x, freqs, num_heads):
|
||||
x = rearrange(x, "b s (n d) -> b s n d", n=num_heads)
|
||||
s_per_rank = x.shape[1]
|
||||
|
||||
x_out = torch.view_as_complex(x.to(torch.float64).reshape(
|
||||
x.shape[0], x.shape[1], x.shape[2], -1, 2))
|
||||
|
||||
sp_size = get_sequence_parallel_world_size()
|
||||
sp_rank = get_sequence_parallel_rank()
|
||||
freqs = pad_freqs(freqs, s_per_rank * sp_size)
|
||||
freqs_rank = freqs[(sp_rank * s_per_rank):((sp_rank + 1) * s_per_rank), :, :]
|
||||
|
||||
x_out = torch.view_as_real(x_out * freqs_rank).flatten(2)
|
||||
return x_out.to(x.dtype)
|
||||
|
||||
def usp_dit_forward(self,
|
||||
x: torch.Tensor,
|
||||
timestep: torch.Tensor,
|
||||
context: torch.Tensor,
|
||||
clip_feature: Optional[torch.Tensor] = None,
|
||||
y: Optional[torch.Tensor] = None,
|
||||
use_gradient_checkpointing: bool = False,
|
||||
use_gradient_checkpointing_offload: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
t = self.time_embedding(
|
||||
sinusoidal_embedding_1d(self.freq_dim, timestep))
|
||||
t_mod = self.time_projection(t).unflatten(1, (6, self.dim))
|
||||
context = self.text_embedding(context)
|
||||
|
||||
if self.has_image_input:
|
||||
x = torch.cat([x, y], dim=1) # (b, c_x + c_y, f, h, w)
|
||||
clip_embdding = self.img_emb(clip_feature)
|
||||
context = torch.cat([clip_embdding, context], dim=1)
|
||||
|
||||
x, (f, h, w) = self.patchify(x)
|
||||
|
||||
freqs = torch.cat([
|
||||
self.freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
|
||||
self.freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
|
||||
self.freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
|
||||
], dim=-1).reshape(f * h * w, 1, -1).to(x.device)
|
||||
|
||||
def create_custom_forward(module):
|
||||
def custom_forward(*inputs):
|
||||
return module(*inputs)
|
||||
return custom_forward
|
||||
|
||||
# Context Parallel
|
||||
x = torch.chunk(
|
||||
x, get_sequence_parallel_world_size(),
|
||||
dim=1)[get_sequence_parallel_rank()]
|
||||
|
||||
for block in self.blocks:
|
||||
if self.training and use_gradient_checkpointing:
|
||||
if use_gradient_checkpointing_offload:
|
||||
with torch.autograd.graph.save_on_cpu():
|
||||
x = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(block),
|
||||
x, context, t_mod, freqs,
|
||||
use_reentrant=False,
|
||||
)
|
||||
else:
|
||||
x = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(block),
|
||||
x, context, t_mod, freqs,
|
||||
use_reentrant=False,
|
||||
)
|
||||
else:
|
||||
x = block(x, context, t_mod, freqs)
|
||||
|
||||
x = self.head(x, t)
|
||||
|
||||
# Context Parallel
|
||||
x = get_sp_group().all_gather(x, dim=1)
|
||||
|
||||
# unpatchify
|
||||
x = self.unpatchify(x, (f, h, w))
|
||||
return x
|
||||
|
||||
|
||||
def usp_attn_forward(self, x, freqs):
|
||||
q = self.norm_q(self.q(x))
|
||||
k = self.norm_k(self.k(x))
|
||||
v = self.v(x)
|
||||
|
||||
q = rope_apply(q, freqs, self.num_heads)
|
||||
k = rope_apply(k, freqs, self.num_heads)
|
||||
q = rearrange(q, "b s (n d) -> b s n d", n=self.num_heads)
|
||||
k = rearrange(k, "b s (n d) -> b s n d", n=self.num_heads)
|
||||
v = rearrange(v, "b s (n d) -> b s n d", n=self.num_heads)
|
||||
|
||||
x = xFuserLongContextAttention()(
|
||||
None,
|
||||
query=q,
|
||||
key=k,
|
||||
value=v,
|
||||
)
|
||||
x = x.flatten(2)
|
||||
|
||||
del q, k, v
|
||||
torch.cuda.empty_cache()
|
||||
return self.o(x)
|
||||
@@ -318,6 +318,8 @@ class FluxControlNetStateDictConverter:
|
||||
extra_kwargs = {"num_joint_blocks": 6, "num_single_blocks": 0, "additional_input_dim": 4}
|
||||
elif hash_value == "0cfd1740758423a2a854d67c136d1e8c":
|
||||
extra_kwargs = {"num_joint_blocks": 4, "num_single_blocks": 1}
|
||||
elif hash_value == "7f9583eb8ba86642abb9a21a4b2c9e16":
|
||||
extra_kwargs = {"num_joint_blocks": 4, "num_single_blocks": 10}
|
||||
else:
|
||||
extra_kwargs = {}
|
||||
return state_dict_, extra_kwargs
|
||||
|
||||
@@ -41,30 +41,6 @@ class RoPEEmbedding(torch.nn.Module):
|
||||
emb = torch.cat([self.rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)], dim=-3)
|
||||
return emb.unsqueeze(1)
|
||||
|
||||
class AdaLayerNorm(torch.nn.Module):
|
||||
def __init__(self, dim, single=False, dual=False):
|
||||
super().__init__()
|
||||
self.single = single
|
||||
self.dual = dual
|
||||
self.linear = torch.nn.Linear(dim, dim * [[6, 2][single], 9][dual])
|
||||
self.norm = torch.nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
|
||||
|
||||
def forward(self, x, emb, **kwargs):
|
||||
emb = self.linear(torch.nn.functional.silu(emb),**kwargs)
|
||||
if self.single:
|
||||
scale, shift = emb.unsqueeze(1).chunk(2, dim=2)
|
||||
x = self.norm(x) * (1 + scale) + shift
|
||||
return x
|
||||
elif self.dual:
|
||||
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp, shift_msa2, scale_msa2, gate_msa2 = emb.unsqueeze(1).chunk(9, dim=2)
|
||||
norm_x = self.norm(x)
|
||||
x = norm_x * (1 + scale_msa) + shift_msa
|
||||
norm_x2 = norm_x * (1 + scale_msa2) + shift_msa2
|
||||
return x, gate_msa, shift_mlp, scale_mlp, gate_mlp, norm_x2, gate_msa2
|
||||
else:
|
||||
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.unsqueeze(1).chunk(6, dim=2)
|
||||
x = self.norm(x) * (1 + scale_msa) + shift_msa
|
||||
return x, gate_msa, shift_mlp, scale_mlp, gate_mlp
|
||||
|
||||
|
||||
class FluxJointAttention(torch.nn.Module):
|
||||
@@ -94,17 +70,17 @@ class FluxJointAttention(torch.nn.Module):
|
||||
xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
|
||||
return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
|
||||
|
||||
def forward(self, hidden_states_a, hidden_states_b, image_rotary_emb, attn_mask=None, ipadapter_kwargs_list=None, **kwargs):
|
||||
def forward(self, hidden_states_a, hidden_states_b, image_rotary_emb, attn_mask=None, ipadapter_kwargs_list=None):
|
||||
batch_size = hidden_states_a.shape[0]
|
||||
|
||||
# Part A
|
||||
qkv_a = self.a_to_qkv(hidden_states_a,**kwargs)
|
||||
qkv_a = self.a_to_qkv(hidden_states_a)
|
||||
qkv_a = qkv_a.view(batch_size, -1, 3 * self.num_heads, self.head_dim).transpose(1, 2)
|
||||
q_a, k_a, v_a = qkv_a.chunk(3, dim=1)
|
||||
q_a, k_a = self.norm_q_a(q_a), self.norm_k_a(k_a)
|
||||
|
||||
# Part B
|
||||
qkv_b = self.b_to_qkv(hidden_states_b,**kwargs)
|
||||
qkv_b = self.b_to_qkv(hidden_states_b)
|
||||
qkv_b = qkv_b.view(batch_size, -1, 3 * self.num_heads, self.head_dim).transpose(1, 2)
|
||||
q_b, k_b, v_b = qkv_b.chunk(3, dim=1)
|
||||
q_b, k_b = self.norm_q_b(q_b), self.norm_k_b(k_b)
|
||||
@@ -121,25 +97,13 @@ class FluxJointAttention(torch.nn.Module):
|
||||
hidden_states_b, hidden_states_a = hidden_states[:, :hidden_states_b.shape[1]], hidden_states[:, hidden_states_b.shape[1]:]
|
||||
if ipadapter_kwargs_list is not None:
|
||||
hidden_states_a = interact_with_ipadapter(hidden_states_a, q_a, **ipadapter_kwargs_list)
|
||||
hidden_states_a = self.a_to_out(hidden_states_a,**kwargs)
|
||||
hidden_states_a = self.a_to_out(hidden_states_a)
|
||||
if self.only_out_a:
|
||||
return hidden_states_a
|
||||
else:
|
||||
hidden_states_b = self.b_to_out(hidden_states_b,**kwargs)
|
||||
hidden_states_b = self.b_to_out(hidden_states_b)
|
||||
return hidden_states_a, hidden_states_b
|
||||
|
||||
class AutoSequential(torch.nn.Sequential):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
def forward(self, input, **kwargs):
|
||||
for module in self:
|
||||
|
||||
if isinstance(module, torch.nn.Linear):
|
||||
# print("##"*10)
|
||||
input = module(input, **kwargs)
|
||||
else:
|
||||
input = module(input)
|
||||
return input
|
||||
|
||||
|
||||
class FluxJointTransformerBlock(torch.nn.Module):
|
||||
@@ -156,11 +120,6 @@ class FluxJointTransformerBlock(torch.nn.Module):
|
||||
torch.nn.GELU(approximate="tanh"),
|
||||
torch.nn.Linear(dim*4, dim)
|
||||
)
|
||||
# self.ff_a = AutoSequential(
|
||||
# torch.nn.Linear(dim, dim*4),
|
||||
# torch.nn.GELU(approximate="tanh"),
|
||||
# torch.nn.Linear(dim*4, dim)
|
||||
# )
|
||||
|
||||
self.norm2_b = torch.nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
|
||||
self.ff_b = torch.nn.Sequential(
|
||||
@@ -168,18 +127,14 @@ class FluxJointTransformerBlock(torch.nn.Module):
|
||||
torch.nn.GELU(approximate="tanh"),
|
||||
torch.nn.Linear(dim*4, dim)
|
||||
)
|
||||
# self.ff_b = AutoSequential(
|
||||
# torch.nn.Linear(dim, dim*4),
|
||||
# torch.nn.GELU(approximate="tanh"),
|
||||
# torch.nn.Linear(dim*4, dim)
|
||||
# )
|
||||
|
||||
def forward(self, hidden_states_a, hidden_states_b, temb, image_rotary_emb, attn_mask=None, ipadapter_kwargs_list=None, **kwargs):
|
||||
norm_hidden_states_a, gate_msa_a, shift_mlp_a, scale_mlp_a, gate_mlp_a = self.norm1_a(hidden_states_a, emb=temb, **kwargs)
|
||||
norm_hidden_states_b, gate_msa_b, shift_mlp_b, scale_mlp_b, gate_mlp_b = self.norm1_b(hidden_states_b, emb=temb, **kwargs)
|
||||
|
||||
def forward(self, hidden_states_a, hidden_states_b, temb, image_rotary_emb, attn_mask=None, ipadapter_kwargs_list=None):
|
||||
norm_hidden_states_a, gate_msa_a, shift_mlp_a, scale_mlp_a, gate_mlp_a = self.norm1_a(hidden_states_a, emb=temb)
|
||||
norm_hidden_states_b, gate_msa_b, shift_mlp_b, scale_mlp_b, gate_mlp_b = self.norm1_b(hidden_states_b, emb=temb)
|
||||
|
||||
# Attention
|
||||
attn_output_a, attn_output_b = self.attn(norm_hidden_states_a, norm_hidden_states_b, image_rotary_emb, attn_mask, ipadapter_kwargs_list, **kwargs)
|
||||
attn_output_a, attn_output_b = self.attn(norm_hidden_states_a, norm_hidden_states_b, image_rotary_emb, attn_mask, ipadapter_kwargs_list)
|
||||
|
||||
# Part A
|
||||
hidden_states_a = hidden_states_a + gate_msa_a * attn_output_a
|
||||
@@ -194,6 +149,7 @@ class FluxJointTransformerBlock(torch.nn.Module):
|
||||
return hidden_states_a, hidden_states_b
|
||||
|
||||
|
||||
|
||||
class FluxSingleAttention(torch.nn.Module):
|
||||
def __init__(self, dim_a, dim_b, num_heads, head_dim):
|
||||
super().__init__()
|
||||
@@ -214,10 +170,10 @@ class FluxSingleAttention(torch.nn.Module):
|
||||
return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
|
||||
|
||||
|
||||
def forward(self, hidden_states, image_rotary_emb, **kwargs):
|
||||
def forward(self, hidden_states, image_rotary_emb):
|
||||
batch_size = hidden_states.shape[0]
|
||||
|
||||
qkv_a = self.a_to_qkv(hidden_states,**kwargs)
|
||||
qkv_a = self.a_to_qkv(hidden_states)
|
||||
qkv_a = qkv_a.view(batch_size, -1, 3 * self.num_heads, self.head_dim).transpose(1, 2)
|
||||
q_a, k_a, v = qkv_a.chunk(3, dim=1)
|
||||
q_a, k_a = self.norm_q_a(q_a), self.norm_k_a(k_a)
|
||||
@@ -239,8 +195,8 @@ class AdaLayerNormSingle(torch.nn.Module):
|
||||
self.norm = torch.nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
|
||||
|
||||
|
||||
def forward(self, x, emb, **kwargs):
|
||||
emb = self.linear(self.silu(emb),**kwargs)
|
||||
def forward(self, x, emb):
|
||||
emb = self.linear(self.silu(emb))
|
||||
shift_msa, scale_msa, gate_msa = emb.chunk(3, dim=1)
|
||||
x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
|
||||
return x, gate_msa
|
||||
@@ -270,7 +226,7 @@ class FluxSingleTransformerBlock(torch.nn.Module):
|
||||
return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
|
||||
|
||||
|
||||
def process_attention(self, hidden_states, image_rotary_emb, attn_mask=None, ipadapter_kwargs_list=None, **kwargs):
|
||||
def process_attention(self, hidden_states, image_rotary_emb, attn_mask=None, ipadapter_kwargs_list=None):
|
||||
batch_size = hidden_states.shape[0]
|
||||
|
||||
qkv = hidden_states.view(batch_size, -1, 3 * self.num_heads, self.head_dim).transpose(1, 2)
|
||||
@@ -287,17 +243,17 @@ class FluxSingleTransformerBlock(torch.nn.Module):
|
||||
return hidden_states
|
||||
|
||||
|
||||
def forward(self, hidden_states_a, hidden_states_b, temb, image_rotary_emb, attn_mask=None, ipadapter_kwargs_list=None, **kwargs):
|
||||
def forward(self, hidden_states_a, hidden_states_b, temb, image_rotary_emb, attn_mask=None, ipadapter_kwargs_list=None):
|
||||
residual = hidden_states_a
|
||||
norm_hidden_states, gate = self.norm(hidden_states_a, emb=temb, **kwargs)
|
||||
hidden_states_a = self.to_qkv_mlp(norm_hidden_states, **kwargs)
|
||||
norm_hidden_states, gate = self.norm(hidden_states_a, emb=temb)
|
||||
hidden_states_a = self.to_qkv_mlp(norm_hidden_states)
|
||||
attn_output, mlp_hidden_states = hidden_states_a[:, :, :self.dim * 3], hidden_states_a[:, :, self.dim * 3:]
|
||||
|
||||
attn_output = self.process_attention(attn_output, image_rotary_emb, attn_mask, ipadapter_kwargs_list, **kwargs)
|
||||
attn_output = self.process_attention(attn_output, image_rotary_emb, attn_mask, ipadapter_kwargs_list)
|
||||
mlp_hidden_states = torch.nn.functional.gelu(mlp_hidden_states, approximate="tanh")
|
||||
|
||||
hidden_states_a = torch.cat([attn_output, mlp_hidden_states], dim=2)
|
||||
hidden_states_a = gate.unsqueeze(1) * self.proj_out(hidden_states_a, **kwargs)
|
||||
hidden_states_a = gate.unsqueeze(1) * self.proj_out(hidden_states_a)
|
||||
hidden_states_a = residual + hidden_states_a
|
||||
|
||||
return hidden_states_a, hidden_states_b
|
||||
@@ -311,13 +267,14 @@ class AdaLayerNormContinuous(torch.nn.Module):
|
||||
self.linear = torch.nn.Linear(dim, dim * 2, bias=True)
|
||||
self.norm = torch.nn.LayerNorm(dim, eps=1e-6, elementwise_affine=False)
|
||||
|
||||
def forward(self, x, conditioning, **kwargs):
|
||||
emb = self.linear(self.silu(conditioning),**kwargs)
|
||||
def forward(self, x, conditioning):
|
||||
emb = self.linear(self.silu(conditioning))
|
||||
scale, shift = torch.chunk(emb, 2, dim=1)
|
||||
x = self.norm(x) * (1 + scale)[:, None] + shift[:, None]
|
||||
return x
|
||||
|
||||
|
||||
|
||||
class FluxDiT(torch.nn.Module):
|
||||
def __init__(self, disable_guidance_embedder=False):
|
||||
super().__init__()
|
||||
@@ -325,8 +282,6 @@ class FluxDiT(torch.nn.Module):
|
||||
self.time_embedder = TimestepEmbeddings(256, 3072)
|
||||
self.guidance_embedder = None if disable_guidance_embedder else TimestepEmbeddings(256, 3072)
|
||||
self.pooled_text_embedder = torch.nn.Sequential(torch.nn.Linear(768, 3072), torch.nn.SiLU(), torch.nn.Linear(3072, 3072))
|
||||
|
||||
# self.pooled_text_embedder = AutoSequential(torch.nn.Linear(768, 3072), torch.nn.SiLU(), torch.nn.Linear(3072, 3072))
|
||||
self.context_embedder = torch.nn.Linear(4096, 3072)
|
||||
self.x_embedder = torch.nn.Linear(64, 3072)
|
||||
|
||||
@@ -473,12 +428,12 @@ class FluxDiT(torch.nn.Module):
|
||||
|
||||
height, width = hidden_states.shape[-2:]
|
||||
hidden_states = self.patchify(hidden_states)
|
||||
hidden_states = self.x_embedder(hidden_states,**kwargs)
|
||||
hidden_states = self.x_embedder(hidden_states)
|
||||
|
||||
if entity_prompt_emb is not None and entity_masks is not None:
|
||||
prompt_emb, image_rotary_emb, attention_mask = self.process_entity_masks(hidden_states, prompt_emb, entity_prompt_emb, entity_masks, text_ids, image_ids)
|
||||
else:
|
||||
prompt_emb = self.context_embedder(prompt_emb, **kwargs)
|
||||
prompt_emb = self.context_embedder(prompt_emb)
|
||||
image_rotary_emb = self.pos_embedder(torch.cat((text_ids, image_ids), dim=1))
|
||||
attention_mask = None
|
||||
|
||||
@@ -491,26 +446,26 @@ class FluxDiT(torch.nn.Module):
|
||||
if self.training and use_gradient_checkpointing:
|
||||
hidden_states, prompt_emb = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(block),
|
||||
hidden_states, prompt_emb, conditioning, image_rotary_emb, attention_mask, **kwargs,
|
||||
hidden_states, prompt_emb, conditioning, image_rotary_emb, attention_mask,
|
||||
use_reentrant=False,
|
||||
)
|
||||
else:
|
||||
hidden_states, prompt_emb = block(hidden_states, prompt_emb, conditioning, image_rotary_emb, attention_mask, **kwargs)
|
||||
hidden_states, prompt_emb = block(hidden_states, prompt_emb, conditioning, image_rotary_emb, attention_mask)
|
||||
|
||||
hidden_states = torch.cat([prompt_emb, hidden_states], dim=1)
|
||||
for block in self.single_blocks:
|
||||
if self.training and use_gradient_checkpointing:
|
||||
hidden_states, prompt_emb = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(block),
|
||||
hidden_states, prompt_emb, conditioning, image_rotary_emb, attention_mask, **kwargs,
|
||||
hidden_states, prompt_emb, conditioning, image_rotary_emb, attention_mask,
|
||||
use_reentrant=False,
|
||||
)
|
||||
else:
|
||||
hidden_states, prompt_emb = block(hidden_states, prompt_emb, conditioning, image_rotary_emb, attention_mask, **kwargs)
|
||||
hidden_states, prompt_emb = block(hidden_states, prompt_emb, conditioning, image_rotary_emb, attention_mask)
|
||||
hidden_states = hidden_states[:, prompt_emb.shape[1]:]
|
||||
|
||||
hidden_states = self.final_norm_out(hidden_states, conditioning, **kwargs)
|
||||
hidden_states = self.final_proj_out(hidden_states, **kwargs)
|
||||
hidden_states = self.final_norm_out(hidden_states, conditioning)
|
||||
hidden_states = self.final_proj_out(hidden_states)
|
||||
hidden_states = self.unpatchify(hidden_states, height, width)
|
||||
|
||||
return hidden_states
|
||||
@@ -651,10 +606,6 @@ class FluxDiTStateDictConverter:
|
||||
for name, param in state_dict.items():
|
||||
if name.endswith(".weight") or name.endswith(".bias"):
|
||||
suffix = ".weight" if name.endswith(".weight") else ".bias"
|
||||
if "lora_B" in name:
|
||||
suffix = ".lora_B" + suffix
|
||||
if "lora_A" in name:
|
||||
suffix = ".lora_A" + suffix
|
||||
prefix = name[:-len(suffix)]
|
||||
if prefix in global_rename_dict:
|
||||
state_dict_[global_rename_dict[prefix] + suffix] = param
|
||||
@@ -679,73 +630,29 @@ class FluxDiTStateDictConverter:
|
||||
for name in list(state_dict_.keys()):
|
||||
if "single_blocks." in name and ".a_to_q." in name:
|
||||
mlp = state_dict_.get(name.replace(".a_to_q.", ".proj_in_besides_attn."), None)
|
||||
|
||||
if mlp is None:
|
||||
dim = 4
|
||||
if 'lora_A' in name:
|
||||
dim = 1
|
||||
mlp = torch.zeros(dim * state_dict_[name].shape[0],
|
||||
mlp = torch.zeros(4 * state_dict_[name].shape[0],
|
||||
*state_dict_[name].shape[1:],
|
||||
dtype=state_dict_[name].dtype)
|
||||
else:
|
||||
# print('$$'*10)
|
||||
# mlp_name = name.replace(".a_to_q.", ".proj_in_besides_attn.")
|
||||
# print(f'mlp name: {mlp_name}')
|
||||
# print(f'mlp shape: {mlp.shape}')
|
||||
state_dict_.pop(name.replace(".a_to_q.", ".proj_in_besides_attn."))
|
||||
# print(f'mlp shape: {mlp.shape}')
|
||||
if 'lora_A' in name:
|
||||
|
||||
param = torch.concat([
|
||||
state_dict_.pop(name),
|
||||
state_dict_.pop(name.replace(".a_to_q.", ".a_to_k.")),
|
||||
state_dict_.pop(name.replace(".a_to_q.", ".a_to_v.")),
|
||||
mlp,
|
||||
], dim=0)
|
||||
elif 'lora_B' in name:
|
||||
# create zreo matrix
|
||||
d, r = state_dict_[name].shape
|
||||
# print('--'*10)
|
||||
# print(d, r)
|
||||
param = torch.zeros((3*d+mlp.shape[0], 3*r+mlp.shape[1]), dtype=state_dict_[name].dtype, device=state_dict_[name].device)
|
||||
param[:d, :r] = state_dict_.pop(name)
|
||||
param[d:2*d, r:2*r] = state_dict_.pop(name.replace(".a_to_q.", ".a_to_k."))
|
||||
param[2*d:3*d, 2*r:3*r] = state_dict_.pop(name.replace(".a_to_q.", ".a_to_v."))
|
||||
param[3*d:, 3*r:] = mlp
|
||||
else:
|
||||
param = torch.concat([
|
||||
state_dict_.pop(name),
|
||||
state_dict_.pop(name.replace(".a_to_q.", ".a_to_k.")),
|
||||
state_dict_.pop(name.replace(".a_to_q.", ".a_to_v.")),
|
||||
mlp,
|
||||
], dim=0)
|
||||
param = torch.concat([
|
||||
state_dict_.pop(name),
|
||||
state_dict_.pop(name.replace(".a_to_q.", ".a_to_k.")),
|
||||
state_dict_.pop(name.replace(".a_to_q.", ".a_to_v.")),
|
||||
mlp,
|
||||
], dim=0)
|
||||
name_ = name.replace(".a_to_q.", ".to_qkv_mlp.")
|
||||
state_dict_[name_] = param
|
||||
for name in list(state_dict_.keys()):
|
||||
for component in ["a", "b"]:
|
||||
if f".{component}_to_q." in name:
|
||||
name_ = name.replace(f".{component}_to_q.", f".{component}_to_qkv.")
|
||||
concat_dim = 0
|
||||
if 'lora_A' in name:
|
||||
param = torch.concat([
|
||||
state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_q.")],
|
||||
state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_k.")],
|
||||
state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_v.")],
|
||||
], dim=0)
|
||||
elif 'lora_B' in name:
|
||||
origin = state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_q.")]
|
||||
d, r = origin.shape
|
||||
# print(d, r)
|
||||
param = torch.zeros((3*d, 3*r), dtype=origin.dtype, device=origin.device)
|
||||
param[:d, :r] = state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_q.")]
|
||||
param[d:2*d, r:2*r] = state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_k.")]
|
||||
param[2*d:3*d, 2*r:3*r] = state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_v.")]
|
||||
else:
|
||||
param = torch.concat([
|
||||
state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_q.")],
|
||||
state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_k.")],
|
||||
state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_v.")],
|
||||
], dim=0)
|
||||
param = torch.concat([
|
||||
state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_q.")],
|
||||
state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_k.")],
|
||||
state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_v.")],
|
||||
], dim=0)
|
||||
state_dict_[name_] = param
|
||||
state_dict_.pop(name.replace(f".{component}_to_q.", f".{component}_to_q."))
|
||||
state_dict_.pop(name.replace(f".{component}_to_q.", f".{component}_to_k."))
|
||||
@@ -811,48 +718,22 @@ class FluxDiTStateDictConverter:
|
||||
"norm.query_norm.scale": "norm_q_a.weight",
|
||||
}
|
||||
state_dict_ = {}
|
||||
|
||||
|
||||
for name, param in state_dict.items():
|
||||
# match lora load
|
||||
l_name = ''
|
||||
if 'lora_A' in name :
|
||||
l_name = 'lora_A'
|
||||
if 'lora_B' in name :
|
||||
l_name = 'lora_B'
|
||||
if l_name != '':
|
||||
name = name.replace(l_name+'.', '')
|
||||
|
||||
|
||||
if name.startswith("model.diffusion_model."):
|
||||
name = name[len("model.diffusion_model."):]
|
||||
names = name.split(".")
|
||||
if name in rename_dict:
|
||||
rename = rename_dict[name]
|
||||
if name.startswith("final_layer.adaLN_modulation.1."):
|
||||
if l_name == 'lora_A':
|
||||
param = torch.concat([param[:,3072:], param[:,:3072]], dim=1)
|
||||
else:
|
||||
param = torch.concat([param[3072:], param[:3072]], dim=0)
|
||||
if l_name != '':
|
||||
state_dict_[rename.replace('weight',l_name+'.weight')] = param
|
||||
else:
|
||||
state_dict_[rename] = param
|
||||
|
||||
param = torch.concat([param[3072:], param[:3072]], dim=0)
|
||||
state_dict_[rename] = param
|
||||
elif names[0] == "double_blocks":
|
||||
rename = f"blocks.{names[1]}." + suffix_rename_dict[".".join(names[2:])]
|
||||
if l_name != '':
|
||||
state_dict_[rename.replace('weight',l_name+'.weight')] = param
|
||||
else:
|
||||
state_dict_[rename] = param
|
||||
|
||||
state_dict_[rename] = param
|
||||
elif names[0] == "single_blocks":
|
||||
if ".".join(names[2:]) in suffix_rename_dict:
|
||||
rename = f"single_blocks.{names[1]}." + suffix_rename_dict[".".join(names[2:])]
|
||||
if l_name != '':
|
||||
state_dict_[rename.replace('weight',l_name+'.weight')] = param
|
||||
else:
|
||||
state_dict_[rename] = param
|
||||
state_dict_[rename] = param
|
||||
else:
|
||||
pass
|
||||
if "guidance_embedder.timestep_embedder.0.weight" not in state_dict_:
|
||||
|
||||
128
diffsynth/models/flux_infiniteyou.py
Normal file
128
diffsynth/models/flux_infiniteyou.py
Normal file
@@ -0,0 +1,128 @@
|
||||
import math
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
# FFN
|
||||
def FeedForward(dim, mult=4):
|
||||
inner_dim = int(dim * mult)
|
||||
return nn.Sequential(
|
||||
nn.LayerNorm(dim),
|
||||
nn.Linear(dim, inner_dim, bias=False),
|
||||
nn.GELU(),
|
||||
nn.Linear(inner_dim, dim, bias=False),
|
||||
)
|
||||
|
||||
|
||||
def reshape_tensor(x, heads):
|
||||
bs, length, width = x.shape
|
||||
#(bs, length, width) --> (bs, length, n_heads, dim_per_head)
|
||||
x = x.view(bs, length, heads, -1)
|
||||
# (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head)
|
||||
x = x.transpose(1, 2)
|
||||
# (bs, n_heads, length, dim_per_head) --> (bs*n_heads, length, dim_per_head)
|
||||
x = x.reshape(bs, heads, length, -1)
|
||||
return x
|
||||
|
||||
|
||||
class PerceiverAttention(nn.Module):
|
||||
|
||||
def __init__(self, *, dim, dim_head=64, heads=8):
|
||||
super().__init__()
|
||||
self.scale = dim_head**-0.5
|
||||
self.dim_head = dim_head
|
||||
self.heads = heads
|
||||
inner_dim = dim_head * heads
|
||||
|
||||
self.norm1 = nn.LayerNorm(dim)
|
||||
self.norm2 = nn.LayerNorm(dim)
|
||||
|
||||
self.to_q = nn.Linear(dim, inner_dim, bias=False)
|
||||
self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
|
||||
self.to_out = nn.Linear(inner_dim, dim, bias=False)
|
||||
|
||||
def forward(self, x, latents):
|
||||
"""
|
||||
Args:
|
||||
x (torch.Tensor): image features
|
||||
shape (b, n1, D)
|
||||
latent (torch.Tensor): latent features
|
||||
shape (b, n2, D)
|
||||
"""
|
||||
x = self.norm1(x)
|
||||
latents = self.norm2(latents)
|
||||
|
||||
b, l, _ = latents.shape
|
||||
|
||||
q = self.to_q(latents)
|
||||
kv_input = torch.cat((x, latents), dim=-2)
|
||||
k, v = self.to_kv(kv_input).chunk(2, dim=-1)
|
||||
|
||||
q = reshape_tensor(q, self.heads)
|
||||
k = reshape_tensor(k, self.heads)
|
||||
v = reshape_tensor(v, self.heads)
|
||||
|
||||
# attention
|
||||
scale = 1 / math.sqrt(math.sqrt(self.dim_head))
|
||||
weight = (q * scale) @ (k * scale).transpose(-2, -1) # More stable with f16 than dividing afterwards
|
||||
weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
|
||||
out = weight @ v
|
||||
|
||||
out = out.permute(0, 2, 1, 3).reshape(b, l, -1)
|
||||
|
||||
return self.to_out(out)
|
||||
|
||||
|
||||
class InfiniteYouImageProjector(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim=1280,
|
||||
depth=4,
|
||||
dim_head=64,
|
||||
heads=20,
|
||||
num_queries=8,
|
||||
embedding_dim=512,
|
||||
output_dim=4096,
|
||||
ff_mult=4,
|
||||
):
|
||||
super().__init__()
|
||||
self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim**0.5)
|
||||
self.proj_in = nn.Linear(embedding_dim, dim)
|
||||
|
||||
self.proj_out = nn.Linear(dim, output_dim)
|
||||
self.norm_out = nn.LayerNorm(output_dim)
|
||||
|
||||
self.layers = nn.ModuleList([])
|
||||
for _ in range(depth):
|
||||
self.layers.append(
|
||||
nn.ModuleList([
|
||||
PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
|
||||
FeedForward(dim=dim, mult=ff_mult),
|
||||
]))
|
||||
|
||||
def forward(self, x):
|
||||
|
||||
latents = self.latents.repeat(x.size(0), 1, 1)
|
||||
|
||||
x = self.proj_in(x)
|
||||
|
||||
for attn, ff in self.layers:
|
||||
latents = attn(x, latents) + latents
|
||||
latents = ff(latents) + latents
|
||||
|
||||
latents = self.proj_out(latents)
|
||||
return self.norm_out(latents)
|
||||
|
||||
@staticmethod
|
||||
def state_dict_converter():
|
||||
return FluxInfiniteYouImageProjectorStateDictConverter()
|
||||
|
||||
|
||||
class FluxInfiniteYouImageProjectorStateDictConverter:
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def from_diffusers(self, state_dict):
|
||||
return state_dict['image_proj']
|
||||
@@ -26,12 +26,6 @@ class LoRAFromCivitai:
|
||||
return self.convert_state_dict_up_down(state_dict, lora_prefix, alpha)
|
||||
return self.convert_state_dict_AB(state_dict, lora_prefix, alpha)
|
||||
|
||||
def convert_state_name(self, state_dict, lora_prefix="lora_unet_", alpha=1.0):
|
||||
for key in state_dict:
|
||||
if ".lora_up" in key:
|
||||
return self.convert_state_name_up_down(state_dict, lora_prefix, alpha)
|
||||
return self.convert_state_name_AB(state_dict, lora_prefix, alpha)
|
||||
|
||||
|
||||
def convert_state_dict_up_down(self, state_dict, lora_prefix="lora_unet_", alpha=1.0):
|
||||
renamed_lora_prefix = self.renamed_lora_prefix.get(lora_prefix, "")
|
||||
@@ -56,37 +50,13 @@ class LoRAFromCivitai:
|
||||
return state_dict_
|
||||
|
||||
|
||||
def convert_state_name_up_down(self, state_dict, lora_prefix="lora_unet_", alpha=1.0):
|
||||
renamed_lora_prefix = self.renamed_lora_prefix.get(lora_prefix, "")
|
||||
state_dict_ = {}
|
||||
for key in state_dict:
|
||||
if ".lora_up" not in key:
|
||||
continue
|
||||
if not key.startswith(lora_prefix):
|
||||
continue
|
||||
weight_up = state_dict[key].to(device="cuda", dtype=torch.float16)
|
||||
weight_down = state_dict[key.replace(".lora_up", ".lora_down")].to(device="cuda", dtype=torch.float16)
|
||||
if len(weight_up.shape) == 4:
|
||||
weight_up = weight_up.squeeze(3).squeeze(2).to(torch.float32)
|
||||
weight_down = weight_down.squeeze(3).squeeze(2).to(torch.float32)
|
||||
target_name = key.split(".")[0].replace(lora_prefix, renamed_lora_prefix).replace("_", ".") + ".weight"
|
||||
for special_key in self.special_keys:
|
||||
target_name = target_name.replace(special_key, self.special_keys[special_key])
|
||||
|
||||
state_dict_[target_name.replace(".weight",".lora_B.weight")] = weight_up.cpu()
|
||||
state_dict_[target_name.replace(".weight",".lora_A.weight")] = weight_down.cpu()
|
||||
return state_dict_
|
||||
|
||||
|
||||
def convert_state_dict_AB(self, state_dict, lora_prefix="", alpha=1.0, device="cuda", torch_dtype=torch.float16):
|
||||
state_dict_ = {}
|
||||
|
||||
for key in state_dict:
|
||||
if ".lora_B." not in key:
|
||||
continue
|
||||
if not key.startswith(lora_prefix):
|
||||
continue
|
||||
|
||||
weight_up = state_dict[key].to(device=device, dtype=torch_dtype)
|
||||
weight_down = state_dict[key.replace(".lora_B.", ".lora_A.")].to(device=device, dtype=torch_dtype)
|
||||
if len(weight_up.shape) == 4:
|
||||
@@ -97,39 +67,11 @@ class LoRAFromCivitai:
|
||||
lora_weight = alpha * torch.mm(weight_up, weight_down)
|
||||
keys = key.split(".")
|
||||
keys.pop(keys.index("lora_B"))
|
||||
|
||||
target_name = ".".join(keys)
|
||||
|
||||
target_name = target_name[len(lora_prefix):]
|
||||
|
||||
state_dict_[target_name] = lora_weight.cpu()
|
||||
return state_dict_
|
||||
|
||||
def convert_state_name_AB(self, state_dict, lora_prefix="", alpha=1.0, device="cuda", torch_dtype=torch.float16):
|
||||
state_dict_ = {}
|
||||
|
||||
for key in state_dict:
|
||||
if ".lora_B." not in key:
|
||||
continue
|
||||
if not key.startswith(lora_prefix):
|
||||
continue
|
||||
|
||||
weight_up = state_dict[key].to(device=device, dtype=torch_dtype)
|
||||
weight_down = state_dict[key.replace(".lora_B.", ".lora_A.")].to(device=device, dtype=torch_dtype)
|
||||
if len(weight_up.shape) == 4:
|
||||
weight_up = weight_up.squeeze(3).squeeze(2)
|
||||
weight_down = weight_down.squeeze(3).squeeze(2)
|
||||
|
||||
keys = key.split(".")
|
||||
keys.pop(keys.index("lora_B"))
|
||||
|
||||
target_name = ".".join(keys)
|
||||
|
||||
target_name = target_name[len(lora_prefix):]
|
||||
|
||||
state_dict_[target_name.replace(".weight",".lora_B.weight")] = weight_up.cpu()
|
||||
state_dict_[target_name.replace(".weight",".lora_A.weight")] = weight_down.cpu()
|
||||
return state_dict_
|
||||
|
||||
def load(self, model, state_dict_lora, lora_prefix, alpha=1.0, model_resource=None):
|
||||
state_dict_model = model.state_dict()
|
||||
@@ -158,16 +100,13 @@ class LoRAFromCivitai:
|
||||
for lora_prefix, model_class in zip(self.lora_prefix, self.supported_model_classes):
|
||||
if not isinstance(model, model_class):
|
||||
continue
|
||||
# print(f'lora_prefix: {lora_prefix}')
|
||||
state_dict_model = model.state_dict()
|
||||
for model_resource in ["diffusers", "civitai"]:
|
||||
try:
|
||||
state_dict_lora_ = self.convert_state_dict(state_dict_lora, lora_prefix=lora_prefix, alpha=1.0)
|
||||
# print(f'after convert_state_dict lora state_dict:{state_dict_lora_.keys()}')
|
||||
converter_fn = model.__class__.state_dict_converter().from_diffusers if model_resource == "diffusers" \
|
||||
else model.__class__.state_dict_converter().from_civitai
|
||||
state_dict_lora_ = converter_fn(state_dict_lora_)
|
||||
# print(f'after converter_fn lora state_dict:{state_dict_lora_.keys()}')
|
||||
if isinstance(state_dict_lora_, tuple):
|
||||
state_dict_lora_ = state_dict_lora_[0]
|
||||
if len(state_dict_lora_) == 0:
|
||||
@@ -181,35 +120,7 @@ class LoRAFromCivitai:
|
||||
pass
|
||||
return None
|
||||
|
||||
def get_converted_lora_state_dict(self, model, state_dict_lora):
|
||||
for lora_prefix, model_class in zip(self.lora_prefix, self.supported_model_classes):
|
||||
if not isinstance(model, model_class):
|
||||
continue
|
||||
|
||||
state_dict_model = model.state_dict()
|
||||
for model_resource in ["diffusers","civitai"]:
|
||||
try:
|
||||
state_dict_lora_ = self.convert_state_name(state_dict_lora, lora_prefix=lora_prefix, alpha=1.0)
|
||||
|
||||
converter_fn = model.__class__.state_dict_converter().from_diffusers if model_resource == 'diffusers' \
|
||||
else model.__class__.state_dict_converter().from_civitai
|
||||
state_dict_lora_ = converter_fn(state_dict_lora_)
|
||||
|
||||
if isinstance(state_dict_lora_, tuple):
|
||||
state_dict_lora_ = state_dict_lora_[0]
|
||||
|
||||
if len(state_dict_lora_) == 0:
|
||||
continue
|
||||
# return state_dict_lora_
|
||||
for name in state_dict_lora_:
|
||||
if name.replace('.lora_B','').replace('.lora_A','') not in state_dict_model:
|
||||
print(f" lora's {name} is not in model.")
|
||||
break
|
||||
else:
|
||||
return state_dict_lora_
|
||||
except Exception as e:
|
||||
print(f"error {str(e)}")
|
||||
return None
|
||||
|
||||
class SDLoRAFromCivitai(LoRAFromCivitai):
|
||||
def __init__(self):
|
||||
@@ -284,85 +195,73 @@ class FluxLoRAFromCivitai(LoRAFromCivitai):
|
||||
"txt.mod": "txt_mod",
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
class GeneralLoRAFromPeft:
|
||||
def __init__(self):
|
||||
self.supported_model_classes = [SDUNet, SDXLUNet, SD3DiT, HunyuanDiT, FluxDiT, CogDiT, WanModel]
|
||||
|
||||
|
||||
def fetch_device_dtype_from_state_dict(self, state_dict):
|
||||
device, torch_dtype = None, None
|
||||
for name, param in state_dict.items():
|
||||
device, torch_dtype = param.device, param.dtype
|
||||
break
|
||||
return device, torch_dtype
|
||||
|
||||
|
||||
def convert_state_dict(self, state_dict, alpha=1.0, target_state_dict={}):
|
||||
device, torch_dtype = self.fetch_device_dtype_from_state_dict(target_state_dict)
|
||||
if torch_dtype == torch.float8_e4m3fn:
|
||||
torch_dtype = torch.float32
|
||||
state_dict_ = {}
|
||||
for key in state_dict:
|
||||
|
||||
|
||||
def get_name_dict(self, lora_state_dict):
|
||||
lora_name_dict = {}
|
||||
for key in lora_state_dict:
|
||||
if ".lora_B." not in key:
|
||||
continue
|
||||
weight_up = state_dict[key].to(device=device, dtype=torch_dtype)
|
||||
weight_down = state_dict[key.replace(".lora_B.", ".lora_A.")].to(device=device, dtype=torch_dtype)
|
||||
if len(weight_up.shape) == 4:
|
||||
weight_up = weight_up.squeeze(3).squeeze(2)
|
||||
weight_down = weight_down.squeeze(3).squeeze(2)
|
||||
lora_weight = alpha * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3)
|
||||
else:
|
||||
lora_weight = alpha * torch.mm(weight_up, weight_down)
|
||||
keys = key.split(".")
|
||||
if len(keys) > keys.index("lora_B") + 2:
|
||||
keys.pop(keys.index("lora_B") + 1)
|
||||
keys.pop(keys.index("lora_B"))
|
||||
if keys[0] == "diffusion_model":
|
||||
keys.pop(0)
|
||||
target_name = ".".join(keys)
|
||||
if target_name.startswith("diffusion_model."):
|
||||
target_name = target_name[len("diffusion_model."):]
|
||||
if target_name not in target_state_dict:
|
||||
return {}
|
||||
state_dict_[target_name] = lora_weight.cpu()
|
||||
return state_dict_
|
||||
lora_name_dict[target_name] = (key, key.replace(".lora_B.", ".lora_A."))
|
||||
return lora_name_dict
|
||||
|
||||
|
||||
def match(self, model: torch.nn.Module, state_dict_lora):
|
||||
lora_name_dict = self.get_name_dict(state_dict_lora)
|
||||
model_name_dict = {name: None for name, _ in model.named_parameters()}
|
||||
matched_num = sum([i in model_name_dict for i in lora_name_dict])
|
||||
if matched_num == len(lora_name_dict):
|
||||
return "", ""
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
def fetch_device_and_dtype(self, state_dict):
|
||||
device, dtype = None, None
|
||||
for name, param in state_dict.items():
|
||||
device, dtype = param.device, param.dtype
|
||||
break
|
||||
computation_device = device
|
||||
computation_dtype = dtype
|
||||
if computation_device == torch.device("cpu"):
|
||||
if torch.cuda.is_available():
|
||||
computation_device = torch.device("cuda")
|
||||
if computation_dtype == torch.float8_e4m3fn:
|
||||
computation_dtype = torch.float32
|
||||
return device, dtype, computation_device, computation_dtype
|
||||
|
||||
|
||||
def load(self, model, state_dict_lora, lora_prefix="", alpha=1.0, model_resource=""):
|
||||
state_dict_model = model.state_dict()
|
||||
state_dict_lora = self.convert_state_dict(state_dict_lora, alpha=alpha, target_state_dict=state_dict_model)
|
||||
if len(state_dict_lora) > 0:
|
||||
print(f" {len(state_dict_lora)} tensors are updated.")
|
||||
for name in state_dict_lora:
|
||||
if state_dict_model[name].dtype == torch.float8_e4m3fn:
|
||||
weight = state_dict_model[name].to(torch.float32)
|
||||
lora_weight = state_dict_lora[name].to(
|
||||
dtype=torch.float32,
|
||||
device=state_dict_model[name].device
|
||||
)
|
||||
state_dict_model[name] = (weight + lora_weight).to(
|
||||
dtype=state_dict_model[name].dtype,
|
||||
device=state_dict_model[name].device
|
||||
)
|
||||
else:
|
||||
state_dict_model[name] += state_dict_lora[name].to(
|
||||
dtype=state_dict_model[name].dtype,
|
||||
device=state_dict_model[name].device
|
||||
)
|
||||
model.load_state_dict(state_dict_model)
|
||||
device, dtype, computation_device, computation_dtype = self.fetch_device_and_dtype(state_dict_model)
|
||||
lora_name_dict = self.get_name_dict(state_dict_lora)
|
||||
for name in lora_name_dict:
|
||||
weight_up = state_dict_lora[lora_name_dict[name][0]].to(device=computation_device, dtype=computation_dtype)
|
||||
weight_down = state_dict_lora[lora_name_dict[name][1]].to(device=computation_device, dtype=computation_dtype)
|
||||
if len(weight_up.shape) == 4:
|
||||
weight_up = weight_up.squeeze(3).squeeze(2)
|
||||
weight_down = weight_down.squeeze(3).squeeze(2)
|
||||
weight_lora = alpha * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3)
|
||||
else:
|
||||
weight_lora = alpha * torch.mm(weight_up, weight_down)
|
||||
weight_model = state_dict_model[name].to(device=computation_device, dtype=computation_dtype)
|
||||
weight_patched = weight_model + weight_lora
|
||||
state_dict_model[name] = weight_patched.to(device=device, dtype=dtype)
|
||||
print(f" {len(lora_name_dict)} tensors are updated.")
|
||||
model.load_state_dict(state_dict_model)
|
||||
|
||||
|
||||
def match(self, model, state_dict_lora):
|
||||
for model_class in self.supported_model_classes:
|
||||
if not isinstance(model, model_class):
|
||||
continue
|
||||
state_dict_model = model.state_dict()
|
||||
try:
|
||||
state_dict_lora_ = self.convert_state_dict(state_dict_lora, alpha=1.0, target_state_dict=state_dict_model)
|
||||
if len(state_dict_lora_) > 0:
|
||||
return "", ""
|
||||
except:
|
||||
pass
|
||||
return None
|
||||
|
||||
|
||||
class HunyuanVideoLoRAFromCivitai(LoRAFromCivitai):
|
||||
@@ -466,7 +365,22 @@ class FluxLoRAConverter:
|
||||
else:
|
||||
state_dict_[name] = param
|
||||
return state_dict_
|
||||
|
||||
|
||||
class WanLoRAConverter:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def align_to_opensource_format(state_dict, **kwargs):
|
||||
state_dict = {"diffusion_model." + name.replace(".default.", "."): param for name, param in state_dict.items()}
|
||||
return state_dict
|
||||
|
||||
@staticmethod
|
||||
def align_to_diffsynth_format(state_dict, **kwargs):
|
||||
state_dict = {name.replace("diffusion_model.", "").replace(".lora_A.weight", ".lora_A.default.weight").replace(".lora_B.weight", ".lora_B.default.weight"): param for name, param in state_dict.items()}
|
||||
return state_dict
|
||||
|
||||
|
||||
def get_lora_loaders():
|
||||
return [SDLoRAFromCivitai(), SDXLLoRAFromCivitai(), FluxLoRAFromCivitai(), HunyuanVideoLoRAFromCivitai(), GeneralLoRAFromPeft()]
|
||||
|
||||
@@ -62,26 +62,25 @@ def load_state_dict_from_folder(file_path, torch_dtype=None):
|
||||
return state_dict
|
||||
|
||||
|
||||
def load_state_dict(file_path, torch_dtype=None, device="cpu"):
|
||||
def load_state_dict(file_path, torch_dtype=None):
|
||||
if file_path.endswith(".safetensors"):
|
||||
return load_state_dict_from_safetensors(file_path, torch_dtype=torch_dtype, device=device)
|
||||
return load_state_dict_from_safetensors(file_path, torch_dtype=torch_dtype)
|
||||
else:
|
||||
return load_state_dict_from_bin(file_path, torch_dtype=torch_dtype, device=device)
|
||||
return load_state_dict_from_bin(file_path, torch_dtype=torch_dtype)
|
||||
|
||||
|
||||
def load_state_dict_from_safetensors(file_path, torch_dtype=None, device="cpu"):
|
||||
def load_state_dict_from_safetensors(file_path, torch_dtype=None):
|
||||
state_dict = {}
|
||||
with safe_open(file_path, framework="pt", device="cpu") as f:
|
||||
for k in f.keys():
|
||||
state_dict[k] = f.get_tensor(k)
|
||||
if torch_dtype is not None:
|
||||
state_dict[k] = state_dict[k].to(torch_dtype)
|
||||
state_dict[k] = state_dict[k].to(device)
|
||||
return state_dict
|
||||
|
||||
|
||||
def load_state_dict_from_bin(file_path, torch_dtype=None, device="cpu"):
|
||||
state_dict = torch.load(file_path, map_location=device, weights_only=True)
|
||||
def load_state_dict_from_bin(file_path, torch_dtype=None):
|
||||
state_dict = torch.load(file_path, map_location="cpu", weights_only=True)
|
||||
if torch_dtype is not None:
|
||||
for i in state_dict:
|
||||
if isinstance(state_dict[i], torch.Tensor):
|
||||
|
||||
@@ -183,6 +183,13 @@ class CrossAttention(nn.Module):
|
||||
return self.o(x)
|
||||
|
||||
|
||||
class GateModule(nn.Module):
|
||||
def __init__(self,):
|
||||
super().__init__()
|
||||
|
||||
def forward(self, x, gate, residual):
|
||||
return x + gate * residual
|
||||
|
||||
class DiTBlock(nn.Module):
|
||||
def __init__(self, has_image_input: bool, dim: int, num_heads: int, ffn_dim: int, eps: float = 1e-6):
|
||||
super().__init__()
|
||||
@@ -199,16 +206,17 @@ class DiTBlock(nn.Module):
|
||||
self.ffn = nn.Sequential(nn.Linear(dim, ffn_dim), nn.GELU(
|
||||
approximate='tanh'), nn.Linear(ffn_dim, dim))
|
||||
self.modulation = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5)
|
||||
self.gate = GateModule()
|
||||
|
||||
def forward(self, x, context, t_mod, freqs):
|
||||
# msa: multi-head self-attention mlp: multi-layer perceptron
|
||||
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
|
||||
self.modulation.to(dtype=t_mod.dtype, device=t_mod.device) + t_mod).chunk(6, dim=1)
|
||||
input_x = modulate(self.norm1(x), shift_msa, scale_msa)
|
||||
x = x + gate_msa * self.self_attn(input_x, freqs)
|
||||
x = self.gate(x, gate_msa, self.self_attn(input_x, freqs))
|
||||
x = x + self.cross_attn(self.norm3(x), context)
|
||||
input_x = modulate(self.norm2(x), shift_mlp, scale_mlp)
|
||||
x = x + gate_mlp * self.ffn(input_x)
|
||||
x = self.gate(x, gate_mlp, self.ffn(input_x))
|
||||
return x
|
||||
|
||||
|
||||
@@ -485,6 +493,62 @@ class WanModelStateDictConverter:
|
||||
"num_layers": 40,
|
||||
"eps": 1e-6
|
||||
}
|
||||
elif hash_state_dict_keys(state_dict) == "6d6ccde6845b95ad9114ab993d917893":
|
||||
config = {
|
||||
"has_image_input": True,
|
||||
"patch_size": [1, 2, 2],
|
||||
"in_dim": 36,
|
||||
"dim": 1536,
|
||||
"ffn_dim": 8960,
|
||||
"freq_dim": 256,
|
||||
"text_dim": 4096,
|
||||
"out_dim": 16,
|
||||
"num_heads": 12,
|
||||
"num_layers": 30,
|
||||
"eps": 1e-6
|
||||
}
|
||||
elif hash_state_dict_keys(state_dict) == "6bfcfb3b342cb286ce886889d519a77e":
|
||||
config = {
|
||||
"has_image_input": True,
|
||||
"patch_size": [1, 2, 2],
|
||||
"in_dim": 36,
|
||||
"dim": 5120,
|
||||
"ffn_dim": 13824,
|
||||
"freq_dim": 256,
|
||||
"text_dim": 4096,
|
||||
"out_dim": 16,
|
||||
"num_heads": 40,
|
||||
"num_layers": 40,
|
||||
"eps": 1e-6
|
||||
}
|
||||
elif hash_state_dict_keys(state_dict) == "349723183fc063b2bfc10bb2835cf677":
|
||||
config = {
|
||||
"has_image_input": True,
|
||||
"patch_size": [1, 2, 2],
|
||||
"in_dim": 48,
|
||||
"dim": 1536,
|
||||
"ffn_dim": 8960,
|
||||
"freq_dim": 256,
|
||||
"text_dim": 4096,
|
||||
"out_dim": 16,
|
||||
"num_heads": 12,
|
||||
"num_layers": 30,
|
||||
"eps": 1e-6
|
||||
}
|
||||
elif hash_state_dict_keys(state_dict) == "efa44cddf936c70abd0ea28b6cbe946c":
|
||||
config = {
|
||||
"has_image_input": True,
|
||||
"patch_size": [1, 2, 2],
|
||||
"in_dim": 48,
|
||||
"dim": 5120,
|
||||
"ffn_dim": 13824,
|
||||
"freq_dim": 256,
|
||||
"text_dim": 4096,
|
||||
"out_dim": 16,
|
||||
"num_heads": 40,
|
||||
"num_layers": 40,
|
||||
"eps": 1e-6
|
||||
}
|
||||
else:
|
||||
config = {}
|
||||
return state_dict, config
|
||||
|
||||
44
diffsynth/models/wan_video_motion_controller.py
Normal file
44
diffsynth/models/wan_video_motion_controller.py
Normal file
@@ -0,0 +1,44 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from .wan_video_dit import sinusoidal_embedding_1d
|
||||
|
||||
|
||||
|
||||
class WanMotionControllerModel(torch.nn.Module):
|
||||
def __init__(self, freq_dim=256, dim=1536):
|
||||
super().__init__()
|
||||
self.freq_dim = freq_dim
|
||||
self.linear = nn.Sequential(
|
||||
nn.Linear(freq_dim, dim),
|
||||
nn.SiLU(),
|
||||
nn.Linear(dim, dim),
|
||||
nn.SiLU(),
|
||||
nn.Linear(dim, dim * 6),
|
||||
)
|
||||
|
||||
def forward(self, motion_bucket_id):
|
||||
emb = sinusoidal_embedding_1d(self.freq_dim, motion_bucket_id * 10)
|
||||
emb = self.linear(emb)
|
||||
return emb
|
||||
|
||||
def init(self):
|
||||
state_dict = self.linear[-1].state_dict()
|
||||
state_dict = {i: state_dict[i] * 0 for i in state_dict}
|
||||
self.linear[-1].load_state_dict(state_dict)
|
||||
|
||||
@staticmethod
|
||||
def state_dict_converter():
|
||||
return WanMotionControllerModelDictConverter()
|
||||
|
||||
|
||||
|
||||
class WanMotionControllerModelDictConverter:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def from_diffusers(self, state_dict):
|
||||
return state_dict
|
||||
|
||||
def from_civitai(self, state_dict):
|
||||
return state_dict
|
||||
|
||||
@@ -13,7 +13,7 @@ from transformers import SiglipVisionModel
|
||||
from copy import deepcopy
|
||||
from transformers.models.t5.modeling_t5 import T5LayerNorm, T5DenseActDense, T5DenseGatedActDense
|
||||
from ..models.flux_dit import RMSNorm
|
||||
from ..vram_management import enable_vram_management, enable_auto_lora, AutoLoRALinear, AutoWrappedModule, AutoWrappedLinear
|
||||
from ..vram_management import enable_vram_management, AutoWrappedModule, AutoWrappedLinear
|
||||
|
||||
|
||||
class FluxImagePipeline(BasePipeline):
|
||||
@@ -31,6 +31,7 @@ class FluxImagePipeline(BasePipeline):
|
||||
self.controlnet: FluxMultiControlNetManager = None
|
||||
self.ipadapter: FluxIpAdapter = None
|
||||
self.ipadapter_image_encoder: SiglipVisionModel = None
|
||||
self.infinityou_processor: InfinitYou = None
|
||||
self.model_names = ['text_encoder_1', 'text_encoder_2', 'dit', 'vae_decoder', 'vae_encoder', 'controlnet', 'ipadapter', 'ipadapter_image_encoder']
|
||||
|
||||
|
||||
@@ -132,15 +133,6 @@ class FluxImagePipeline(BasePipeline):
|
||||
)
|
||||
self.enable_cpu_offload()
|
||||
|
||||
def enable_auto_lora(self):
|
||||
enable_auto_lora(
|
||||
self.dit,
|
||||
module_map={
|
||||
RMSNorm: AutoWrappedModule,
|
||||
torch.nn.Linear: AutoLoRALinear,
|
||||
},
|
||||
name_prefix=''
|
||||
)
|
||||
|
||||
def denoising_model(self):
|
||||
return self.dit
|
||||
@@ -171,6 +163,11 @@ class FluxImagePipeline(BasePipeline):
|
||||
self.ipadapter = model_manager.fetch_model("flux_ipadapter")
|
||||
self.ipadapter_image_encoder = model_manager.fetch_model("siglip_vision_model")
|
||||
|
||||
# InfiniteYou
|
||||
self.image_proj_model = model_manager.fetch_model("infiniteyou_image_projector")
|
||||
if self.image_proj_model is not None:
|
||||
self.infinityou_processor = InfinitYou(device=self.device)
|
||||
|
||||
|
||||
@staticmethod
|
||||
def from_model_manager(model_manager: ModelManager, controlnet_config_units: List[ControlNetConfigUnit]=[], prompt_refiner_classes=[], prompt_extender_classes=[], device=None, torch_dtype=None):
|
||||
@@ -356,6 +353,13 @@ class FluxImagePipeline(BasePipeline):
|
||||
prompt_emb_nega = self.encode_prompt(negative_prompt, positive=False, t5_sequence_length=t5_sequence_length) if cfg_scale != 1.0 else None
|
||||
prompt_emb_locals = [self.encode_prompt(prompt_local, t5_sequence_length=t5_sequence_length) for prompt_local in local_prompts]
|
||||
return prompt_emb_posi, prompt_emb_nega, prompt_emb_locals
|
||||
|
||||
|
||||
def prepare_infinite_you(self, id_image, controlnet_image, infinityou_guidance, height, width):
|
||||
if self.infinityou_processor is not None and id_image is not None:
|
||||
return self.infinityou_processor.prepare_infinite_you(self.image_proj_model, id_image, controlnet_image, infinityou_guidance, height, width)
|
||||
else:
|
||||
return {}, controlnet_image
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
@@ -391,6 +395,9 @@ class FluxImagePipeline(BasePipeline):
|
||||
eligen_entity_masks=None,
|
||||
enable_eligen_on_negative=False,
|
||||
enable_eligen_inpaint=False,
|
||||
# InfiniteYou
|
||||
infinityou_id_image=None,
|
||||
infinityou_guidance=1.0,
|
||||
# TeaCache
|
||||
tea_cache_l1_thresh=None,
|
||||
# Tile
|
||||
@@ -400,9 +407,6 @@ class FluxImagePipeline(BasePipeline):
|
||||
# Progress bar
|
||||
progress_bar_cmd=tqdm,
|
||||
progress_bar_st=None,
|
||||
lora_state_dicts=[],
|
||||
lora_alphas=[],
|
||||
lora_patcher=None,
|
||||
):
|
||||
height, width = self.check_resize_height_width(height, width)
|
||||
|
||||
@@ -421,6 +425,9 @@ class FluxImagePipeline(BasePipeline):
|
||||
# Extra input
|
||||
extra_input = self.prepare_extra_input(latents, guidance=embedded_guidance)
|
||||
|
||||
# InfiniteYou
|
||||
infiniteyou_kwargs, controlnet_image = self.prepare_infinite_you(infinityou_id_image, controlnet_image, infinityou_guidance, height, width)
|
||||
|
||||
# Entity control
|
||||
eligen_kwargs_posi, eligen_kwargs_nega, fg_mask, bg_mask = self.prepare_eligen(prompt_emb_nega, eligen_entity_prompts, eligen_entity_masks, width, height, t5_sequence_length, enable_eligen_inpaint, enable_eligen_on_negative, cfg_scale)
|
||||
|
||||
@@ -442,10 +449,7 @@ class FluxImagePipeline(BasePipeline):
|
||||
inference_callback = lambda prompt_emb_posi, controlnet_kwargs: lets_dance_flux(
|
||||
dit=self.dit, controlnet=self.controlnet,
|
||||
hidden_states=latents, timestep=timestep,
|
||||
lora_state_dicts=lora_state_dicts,
|
||||
lora_alphas = lora_alphas,
|
||||
lora_patcher=lora_patcher,
|
||||
**prompt_emb_posi, **tiler_kwargs, **extra_input, **controlnet_kwargs, **ipadapter_kwargs_list_posi, **eligen_kwargs_posi, **tea_cache_kwargs,
|
||||
**prompt_emb_posi, **tiler_kwargs, **extra_input, **controlnet_kwargs, **ipadapter_kwargs_list_posi, **eligen_kwargs_posi, **tea_cache_kwargs, **infiniteyou_kwargs
|
||||
)
|
||||
noise_pred_posi = self.control_noise_via_local_prompts(
|
||||
prompt_emb_posi, prompt_emb_locals, masks, mask_scales, inference_callback,
|
||||
@@ -462,10 +466,7 @@ class FluxImagePipeline(BasePipeline):
|
||||
noise_pred_nega = lets_dance_flux(
|
||||
dit=self.dit, controlnet=self.controlnet,
|
||||
hidden_states=latents, timestep=timestep,
|
||||
lora_state_dicts=lora_state_dicts,
|
||||
lora_alphas = lora_alphas,
|
||||
lora_patcher=lora_patcher,
|
||||
**prompt_emb_nega, **tiler_kwargs, **extra_input, **controlnet_kwargs_nega, **ipadapter_kwargs_list_nega, **eligen_kwargs_nega,
|
||||
**prompt_emb_nega, **tiler_kwargs, **extra_input, **controlnet_kwargs_nega, **ipadapter_kwargs_list_nega, **eligen_kwargs_nega, **infiniteyou_kwargs,
|
||||
)
|
||||
noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
|
||||
else:
|
||||
@@ -485,6 +486,58 @@ class FluxImagePipeline(BasePipeline):
|
||||
# Offload all models
|
||||
self.load_models_to_device([])
|
||||
return image
|
||||
|
||||
|
||||
|
||||
class InfinitYou:
|
||||
def __init__(self, device="cuda", torch_dtype=torch.bfloat16):
|
||||
from facexlib.recognition import init_recognition_model
|
||||
from insightface.app import FaceAnalysis
|
||||
self.device = device
|
||||
self.torch_dtype = torch_dtype
|
||||
insightface_root_path = 'models/InfiniteYou/insightface'
|
||||
self.app_640 = FaceAnalysis(name='antelopev2', root=insightface_root_path, providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
|
||||
self.app_640.prepare(ctx_id=0, det_size=(640, 640))
|
||||
self.app_320 = FaceAnalysis(name='antelopev2', root=insightface_root_path, providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
|
||||
self.app_320.prepare(ctx_id=0, det_size=(320, 320))
|
||||
self.app_160 = FaceAnalysis(name='antelopev2', root=insightface_root_path, providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
|
||||
self.app_160.prepare(ctx_id=0, det_size=(160, 160))
|
||||
self.arcface_model = init_recognition_model('arcface', device=self.device)
|
||||
|
||||
def _detect_face(self, id_image_cv2):
|
||||
face_info = self.app_640.get(id_image_cv2)
|
||||
if len(face_info) > 0:
|
||||
return face_info
|
||||
face_info = self.app_320.get(id_image_cv2)
|
||||
if len(face_info) > 0:
|
||||
return face_info
|
||||
face_info = self.app_160.get(id_image_cv2)
|
||||
return face_info
|
||||
|
||||
def extract_arcface_bgr_embedding(self, in_image, landmark):
|
||||
from insightface.utils import face_align
|
||||
arc_face_image = face_align.norm_crop(in_image, landmark=np.array(landmark), image_size=112)
|
||||
arc_face_image = torch.from_numpy(arc_face_image).unsqueeze(0).permute(0, 3, 1, 2) / 255.
|
||||
arc_face_image = 2 * arc_face_image - 1
|
||||
arc_face_image = arc_face_image.contiguous().to(self.device)
|
||||
face_emb = self.arcface_model(arc_face_image)[0] # [512], normalized
|
||||
return face_emb
|
||||
|
||||
def prepare_infinite_you(self, model, id_image, controlnet_image, infinityou_guidance, height, width):
|
||||
import cv2
|
||||
if id_image is None:
|
||||
return {'id_emb': None}, controlnet_image
|
||||
id_image_cv2 = cv2.cvtColor(np.array(id_image), cv2.COLOR_RGB2BGR)
|
||||
face_info = self._detect_face(id_image_cv2)
|
||||
if len(face_info) == 0:
|
||||
raise ValueError('No face detected in the input ID image')
|
||||
landmark = sorted(face_info, key=lambda x:(x['bbox'][2]-x['bbox'][0])*(x['bbox'][3]-x['bbox'][1]))[-1]['kps'] # only use the maximum face
|
||||
id_emb = self.extract_arcface_bgr_embedding(id_image_cv2, landmark)
|
||||
id_emb = model(id_emb.unsqueeze(0).reshape([1, -1, 512]).to(dtype=self.torch_dtype))
|
||||
if controlnet_image is None:
|
||||
controlnet_image = Image.fromarray(np.zeros([height, width, 3]).astype(np.uint8))
|
||||
infinityou_guidance = torch.Tensor([infinityou_guidance]).to(device=self.device, dtype=self.torch_dtype)
|
||||
return {'id_emb': id_emb, 'infinityou_guidance': infinityou_guidance}, controlnet_image
|
||||
|
||||
|
||||
class TeaCache:
|
||||
@@ -529,6 +582,7 @@ class TeaCache:
|
||||
hidden_states = hidden_states + self.previous_residual
|
||||
return hidden_states
|
||||
|
||||
|
||||
def lets_dance_flux(
|
||||
dit: FluxDiT,
|
||||
controlnet: FluxMultiControlNetManager = None,
|
||||
@@ -546,11 +600,11 @@ def lets_dance_flux(
|
||||
entity_prompt_emb=None,
|
||||
entity_masks=None,
|
||||
ipadapter_kwargs_list={},
|
||||
id_emb=None,
|
||||
infinityou_guidance=None,
|
||||
tea_cache: TeaCache = None,
|
||||
use_gradient_checkpointing=False,
|
||||
**kwargs
|
||||
):
|
||||
|
||||
if tiled:
|
||||
def flux_forward_fn(hl, hr, wl, wr):
|
||||
tiled_controlnet_frames = [f[:, :, hl: hr, wl: wr] for f in controlnet_frames] if controlnet_frames is not None else None
|
||||
@@ -592,6 +646,9 @@ def lets_dance_flux(
|
||||
"tile_size": tile_size,
|
||||
"tile_stride": tile_stride,
|
||||
}
|
||||
if id_emb is not None:
|
||||
controlnet_text_ids = torch.zeros(id_emb.shape[0], id_emb.shape[1], 3).to(device=hidden_states.device, dtype=hidden_states.dtype)
|
||||
controlnet_extra_kwargs.update({"prompt_emb": id_emb, 'text_ids': controlnet_text_ids, 'guidance': infinityou_guidance})
|
||||
controlnet_res_stack, controlnet_single_res_stack = controlnet(
|
||||
controlnet_frames, **controlnet_extra_kwargs
|
||||
)
|
||||
@@ -614,11 +671,6 @@ def lets_dance_flux(
|
||||
prompt_emb = dit.context_embedder(prompt_emb)
|
||||
image_rotary_emb = dit.pos_embedder(torch.cat((text_ids, image_ids), dim=1))
|
||||
attention_mask = None
|
||||
|
||||
def create_custom_forward(module):
|
||||
def custom_forward(*inputs, **kwargs):
|
||||
return module(*inputs, **kwargs)
|
||||
return custom_forward
|
||||
|
||||
# TeaCache
|
||||
if tea_cache is not None:
|
||||
@@ -631,22 +683,14 @@ def lets_dance_flux(
|
||||
else:
|
||||
# Joint Blocks
|
||||
for block_id, block in enumerate(dit.blocks):
|
||||
if use_gradient_checkpointing:
|
||||
hidden_states, prompt_emb = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(block),
|
||||
hidden_states, prompt_emb, conditioning, image_rotary_emb, attention_mask, ipadapter_kwargs_list.get(block_id, None), **kwargs,
|
||||
use_reentrant=False,
|
||||
)
|
||||
else:
|
||||
hidden_states, prompt_emb = block(
|
||||
hidden_states,
|
||||
prompt_emb,
|
||||
conditioning,
|
||||
image_rotary_emb,
|
||||
attention_mask,
|
||||
ipadapter_kwargs_list=ipadapter_kwargs_list.get(block_id, None),
|
||||
**kwargs
|
||||
)
|
||||
hidden_states, prompt_emb = block(
|
||||
hidden_states,
|
||||
prompt_emb,
|
||||
conditioning,
|
||||
image_rotary_emb,
|
||||
attention_mask,
|
||||
ipadapter_kwargs_list=ipadapter_kwargs_list.get(block_id, None)
|
||||
)
|
||||
# ControlNet
|
||||
if controlnet is not None and controlnet_frames is not None:
|
||||
hidden_states = hidden_states + controlnet_res_stack[block_id]
|
||||
@@ -655,22 +699,14 @@ def lets_dance_flux(
|
||||
hidden_states = torch.cat([prompt_emb, hidden_states], dim=1)
|
||||
num_joint_blocks = len(dit.blocks)
|
||||
for block_id, block in enumerate(dit.single_blocks):
|
||||
if use_gradient_checkpointing:
|
||||
hidden_states, prompt_emb = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(block),
|
||||
hidden_states, prompt_emb, conditioning, image_rotary_emb, attention_mask, ipadapter_kwargs_list.get(block_id + num_joint_blocks, None), **kwargs,
|
||||
use_reentrant=False,
|
||||
)
|
||||
else:
|
||||
hidden_states, prompt_emb = block(
|
||||
hidden_states,
|
||||
prompt_emb,
|
||||
conditioning,
|
||||
image_rotary_emb,
|
||||
attention_mask,
|
||||
ipadapter_kwargs_list=ipadapter_kwargs_list.get(block_id + num_joint_blocks, None),
|
||||
**kwargs
|
||||
)
|
||||
hidden_states, prompt_emb = block(
|
||||
hidden_states,
|
||||
prompt_emb,
|
||||
conditioning,
|
||||
image_rotary_emb,
|
||||
attention_mask,
|
||||
ipadapter_kwargs_list=ipadapter_kwargs_list.get(block_id + num_joint_blocks, None)
|
||||
)
|
||||
# ControlNet
|
||||
if controlnet is not None and controlnet_frames is not None:
|
||||
hidden_states[:, prompt_emb.shape[1]:] = hidden_states[:, prompt_emb.shape[1]:] + controlnet_single_res_stack[block_id]
|
||||
@@ -679,8 +715,8 @@ def lets_dance_flux(
|
||||
if tea_cache is not None:
|
||||
tea_cache.store(hidden_states)
|
||||
|
||||
hidden_states = dit.final_norm_out(hidden_states, conditioning, **kwargs)
|
||||
hidden_states = dit.final_proj_out(hidden_states, **kwargs)
|
||||
hidden_states = dit.final_norm_out(hidden_states, conditioning)
|
||||
hidden_states = dit.final_proj_out(hidden_states)
|
||||
hidden_states = dit.unpatchify(hidden_states, height, width)
|
||||
|
||||
return hidden_states
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import types
|
||||
from ..models import ModelManager
|
||||
from ..models.wan_video_dit import WanModel
|
||||
from ..models.wan_video_text_encoder import WanTextEncoder
|
||||
@@ -17,6 +18,7 @@ from ..vram_management import enable_vram_management, AutoWrappedModule, AutoWra
|
||||
from ..models.wan_video_text_encoder import T5RelativeEmbedding, T5LayerNorm
|
||||
from ..models.wan_video_dit import RMSNorm, sinusoidal_embedding_1d
|
||||
from ..models.wan_video_vae import RMS_norm, CausalConv3d, Upsample
|
||||
from ..models.wan_video_motion_controller import WanMotionControllerModel
|
||||
|
||||
|
||||
|
||||
@@ -30,9 +32,11 @@ class WanVideoPipeline(BasePipeline):
|
||||
self.image_encoder: WanImageEncoder = None
|
||||
self.dit: WanModel = None
|
||||
self.vae: WanVideoVAE = None
|
||||
self.model_names = ['text_encoder', 'dit', 'vae']
|
||||
self.motion_controller: WanMotionControllerModel = None
|
||||
self.model_names = ['text_encoder', 'dit', 'vae', 'image_encoder', 'motion_controller']
|
||||
self.height_division_factor = 16
|
||||
self.width_division_factor = 16
|
||||
self.use_unified_sequence_parallel = False
|
||||
|
||||
|
||||
def enable_vram_management(self, num_persistent_param_in_dit=None):
|
||||
@@ -120,6 +124,22 @@ class WanVideoPipeline(BasePipeline):
|
||||
computation_device=self.device,
|
||||
),
|
||||
)
|
||||
if self.motion_controller is not None:
|
||||
dtype = next(iter(self.motion_controller.parameters())).dtype
|
||||
enable_vram_management(
|
||||
self.motion_controller,
|
||||
module_map = {
|
||||
torch.nn.Linear: AutoWrappedLinear,
|
||||
},
|
||||
module_config = dict(
|
||||
offload_dtype=dtype,
|
||||
offload_device="cpu",
|
||||
onload_dtype=dtype,
|
||||
onload_device="cpu",
|
||||
computation_dtype=dtype,
|
||||
computation_device=self.device,
|
||||
),
|
||||
)
|
||||
self.enable_cpu_offload()
|
||||
|
||||
|
||||
@@ -132,14 +152,24 @@ class WanVideoPipeline(BasePipeline):
|
||||
self.dit = model_manager.fetch_model("wan_video_dit")
|
||||
self.vae = model_manager.fetch_model("wan_video_vae")
|
||||
self.image_encoder = model_manager.fetch_model("wan_video_image_encoder")
|
||||
self.motion_controller = model_manager.fetch_model("wan_video_motion_controller")
|
||||
|
||||
|
||||
@staticmethod
|
||||
def from_model_manager(model_manager: ModelManager, torch_dtype=None, device=None):
|
||||
def from_model_manager(model_manager: ModelManager, torch_dtype=None, device=None, use_usp=False):
|
||||
if device is None: device = model_manager.device
|
||||
if torch_dtype is None: torch_dtype = model_manager.torch_dtype
|
||||
pipe = WanVideoPipeline(device=device, torch_dtype=torch_dtype)
|
||||
pipe.fetch_models(model_manager)
|
||||
if use_usp:
|
||||
from xfuser.core.distributed import get_sequence_parallel_world_size
|
||||
from ..distributed.xdit_context_parallel import usp_attn_forward, usp_dit_forward
|
||||
|
||||
for block in pipe.dit.blocks:
|
||||
block.self_attn.forward = types.MethodType(usp_attn_forward, block.self_attn)
|
||||
pipe.dit.forward = types.MethodType(usp_dit_forward, pipe.dit)
|
||||
pipe.sp_size = get_sequence_parallel_world_size()
|
||||
pipe.use_unified_sequence_parallel = True
|
||||
return pipe
|
||||
|
||||
|
||||
@@ -148,26 +178,51 @@ class WanVideoPipeline(BasePipeline):
|
||||
|
||||
|
||||
def encode_prompt(self, prompt, positive=True):
|
||||
prompt_emb = self.prompter.encode_prompt(prompt, positive=positive)
|
||||
prompt_emb = self.prompter.encode_prompt(prompt, positive=positive, device=self.device)
|
||||
return {"context": prompt_emb}
|
||||
|
||||
|
||||
def encode_image(self, image, num_frames, height, width):
|
||||
def encode_image(self, image, end_image, num_frames, height, width):
|
||||
image = self.preprocess_image(image.resize((width, height))).to(self.device)
|
||||
clip_context = self.image_encoder.encode_image([image])
|
||||
msk = torch.ones(1, num_frames, height//8, width//8, device=self.device)
|
||||
msk[:, 1:] = 0
|
||||
if end_image is not None:
|
||||
end_image = self.preprocess_image(end_image.resize((width, height))).to(self.device)
|
||||
vae_input = torch.concat([image.transpose(0,1), torch.zeros(3, num_frames-2, height, width).to(image.device), end_image.transpose(0,1)],dim=1)
|
||||
msk[:, -1:] = 1
|
||||
else:
|
||||
vae_input = torch.concat([image.transpose(0, 1), torch.zeros(3, num_frames-1, height, width).to(image.device)], dim=1)
|
||||
|
||||
msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1)
|
||||
msk = msk.view(1, msk.shape[1] // 4, 4, height//8, width//8)
|
||||
msk = msk.transpose(1, 2)[0]
|
||||
|
||||
vae_input = torch.concat([image.transpose(0, 1), torch.zeros(3, num_frames-1, height, width).to(image.device)], dim=1)
|
||||
y = self.vae.encode([vae_input.to(dtype=self.torch_dtype, device=self.device)], device=self.device)[0]
|
||||
y = torch.concat([msk, y])
|
||||
y = y.unsqueeze(0)
|
||||
clip_context = clip_context.to(dtype=self.torch_dtype, device=self.device)
|
||||
y = y.to(dtype=self.torch_dtype, device=self.device)
|
||||
return {"clip_feature": clip_context, "y": y}
|
||||
|
||||
|
||||
def encode_control_video(self, control_video, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)):
|
||||
control_video = self.preprocess_images(control_video)
|
||||
control_video = torch.stack(control_video, dim=2).to(dtype=self.torch_dtype, device=self.device)
|
||||
latents = self.encode_video(control_video, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=self.torch_dtype, device=self.device)
|
||||
return latents
|
||||
|
||||
|
||||
def prepare_controlnet_kwargs(self, control_video, num_frames, height, width, clip_feature=None, y=None, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)):
|
||||
if control_video is not None:
|
||||
control_latents = self.encode_control_video(control_video, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
|
||||
if clip_feature is None or y is None:
|
||||
clip_feature = torch.zeros((1, 257, 1280), dtype=self.torch_dtype, device=self.device)
|
||||
y = torch.zeros((1, 16, (num_frames - 1) // 4 + 1, height//8, width//8), dtype=self.torch_dtype, device=self.device)
|
||||
else:
|
||||
y = y[:, -16:]
|
||||
y = torch.concat([control_latents, y], dim=1)
|
||||
return {"clip_feature": clip_feature, "y": y}
|
||||
|
||||
|
||||
def tensor2video(self, frames):
|
||||
@@ -189,6 +244,15 @@ class WanVideoPipeline(BasePipeline):
|
||||
def decode_video(self, latents, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)):
|
||||
frames = self.vae.decode(latents, device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
|
||||
return frames
|
||||
|
||||
|
||||
def prepare_unified_sequence_parallel(self):
|
||||
return {"use_unified_sequence_parallel": self.use_unified_sequence_parallel}
|
||||
|
||||
|
||||
def prepare_motion_bucket_id(self, motion_bucket_id):
|
||||
motion_bucket_id = torch.Tensor((motion_bucket_id,)).to(dtype=self.torch_dtype, device=self.device)
|
||||
return {"motion_bucket_id": motion_bucket_id}
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
@@ -197,7 +261,9 @@ class WanVideoPipeline(BasePipeline):
|
||||
prompt,
|
||||
negative_prompt="",
|
||||
input_image=None,
|
||||
end_image=None,
|
||||
input_video=None,
|
||||
control_video=None,
|
||||
denoising_strength=1.0,
|
||||
seed=None,
|
||||
rand_device="cpu",
|
||||
@@ -207,6 +273,7 @@ class WanVideoPipeline(BasePipeline):
|
||||
cfg_scale=5.0,
|
||||
num_inference_steps=50,
|
||||
sigma_shift=5.0,
|
||||
motion_bucket_id=None,
|
||||
tiled=True,
|
||||
tile_size=(30, 52),
|
||||
tile_stride=(15, 26),
|
||||
@@ -248,26 +315,50 @@ class WanVideoPipeline(BasePipeline):
|
||||
# Encode image
|
||||
if input_image is not None and self.image_encoder is not None:
|
||||
self.load_models_to_device(["image_encoder", "vae"])
|
||||
image_emb = self.encode_image(input_image, num_frames, height, width)
|
||||
image_emb = self.encode_image(input_image, end_image, num_frames, height, width)
|
||||
else:
|
||||
image_emb = {}
|
||||
|
||||
# ControlNet
|
||||
if control_video is not None:
|
||||
self.load_models_to_device(["image_encoder", "vae"])
|
||||
image_emb = self.prepare_controlnet_kwargs(control_video, num_frames, height, width, **image_emb, **tiler_kwargs)
|
||||
|
||||
# Motion Controller
|
||||
if self.motion_controller is not None and motion_bucket_id is not None:
|
||||
motion_kwargs = self.prepare_motion_bucket_id(motion_bucket_id)
|
||||
else:
|
||||
motion_kwargs = {}
|
||||
|
||||
# Extra input
|
||||
extra_input = self.prepare_extra_input(latents)
|
||||
|
||||
# TeaCache
|
||||
tea_cache_posi = {"tea_cache": TeaCache(num_inference_steps, rel_l1_thresh=tea_cache_l1_thresh, model_id=tea_cache_model_id) if tea_cache_l1_thresh is not None else None}
|
||||
tea_cache_nega = {"tea_cache": TeaCache(num_inference_steps, rel_l1_thresh=tea_cache_l1_thresh, model_id=tea_cache_model_id) if tea_cache_l1_thresh is not None else None}
|
||||
|
||||
# Unified Sequence Parallel
|
||||
usp_kwargs = self.prepare_unified_sequence_parallel()
|
||||
|
||||
# Denoise
|
||||
self.load_models_to_device(["dit"])
|
||||
self.load_models_to_device(["dit", "motion_controller"])
|
||||
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
|
||||
timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device)
|
||||
|
||||
# Inference
|
||||
noise_pred_posi = model_fn_wan_video(self.dit, latents, timestep=timestep, **prompt_emb_posi, **image_emb, **extra_input, **tea_cache_posi)
|
||||
noise_pred_posi = model_fn_wan_video(
|
||||
self.dit, motion_controller=self.motion_controller,
|
||||
x=latents, timestep=timestep,
|
||||
**prompt_emb_posi, **image_emb, **extra_input,
|
||||
**tea_cache_posi, **usp_kwargs, **motion_kwargs
|
||||
)
|
||||
if cfg_scale != 1.0:
|
||||
noise_pred_nega = model_fn_wan_video(self.dit, latents, timestep=timestep, **prompt_emb_nega, **image_emb, **extra_input, **tea_cache_nega)
|
||||
noise_pred_nega = model_fn_wan_video(
|
||||
self.dit, motion_controller=self.motion_controller,
|
||||
x=latents, timestep=timestep,
|
||||
**prompt_emb_nega, **image_emb, **extra_input,
|
||||
**tea_cache_nega, **usp_kwargs, **motion_kwargs
|
||||
)
|
||||
noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
|
||||
else:
|
||||
noise_pred = noise_pred_posi
|
||||
@@ -340,16 +431,27 @@ class TeaCache:
|
||||
|
||||
def model_fn_wan_video(
|
||||
dit: WanModel,
|
||||
x: torch.Tensor,
|
||||
timestep: torch.Tensor,
|
||||
context: torch.Tensor,
|
||||
motion_controller: WanMotionControllerModel = None,
|
||||
x: torch.Tensor = None,
|
||||
timestep: torch.Tensor = None,
|
||||
context: torch.Tensor = None,
|
||||
clip_feature: Optional[torch.Tensor] = None,
|
||||
y: Optional[torch.Tensor] = None,
|
||||
tea_cache: TeaCache = None,
|
||||
use_unified_sequence_parallel: bool = False,
|
||||
motion_bucket_id: Optional[torch.Tensor] = None,
|
||||
**kwargs,
|
||||
):
|
||||
if use_unified_sequence_parallel:
|
||||
import torch.distributed as dist
|
||||
from xfuser.core.distributed import (get_sequence_parallel_rank,
|
||||
get_sequence_parallel_world_size,
|
||||
get_sp_group)
|
||||
|
||||
t = dit.time_embedding(sinusoidal_embedding_1d(dit.freq_dim, timestep))
|
||||
t_mod = dit.time_projection(t).unflatten(1, (6, dit.dim))
|
||||
if motion_bucket_id is not None and motion_controller is not None:
|
||||
t_mod = t_mod + motion_controller(motion_bucket_id).unflatten(1, (6, dit.dim))
|
||||
context = dit.text_embedding(context)
|
||||
|
||||
if dit.has_image_input:
|
||||
@@ -371,15 +473,21 @@ def model_fn_wan_video(
|
||||
else:
|
||||
tea_cache_update = False
|
||||
|
||||
# blocks
|
||||
if use_unified_sequence_parallel:
|
||||
if dist.is_initialized() and dist.get_world_size() > 1:
|
||||
x = torch.chunk(x, get_sequence_parallel_world_size(), dim=1)[get_sequence_parallel_rank()]
|
||||
if tea_cache_update:
|
||||
x = tea_cache.update(x)
|
||||
else:
|
||||
# blocks
|
||||
for block in dit.blocks:
|
||||
x = block(x, context, t_mod, freqs)
|
||||
if tea_cache is not None:
|
||||
tea_cache.store(x)
|
||||
|
||||
x = dit.head(x, t)
|
||||
if use_unified_sequence_parallel:
|
||||
if dist.is_initialized() and dist.get_world_size() > 1:
|
||||
x = get_sp_group().all_gather(x, dim=1)
|
||||
x = dit.unpatchify(x, (f, h, w))
|
||||
return x
|
||||
|
||||
@@ -70,56 +70,6 @@ class AutoWrappedLinear(torch.nn.Linear):
|
||||
bias = None if self.bias is None else cast_to(self.bias, self.computation_dtype, self.computation_device)
|
||||
return torch.nn.functional.linear(x, weight, bias)
|
||||
|
||||
class AutoLoRALinear(torch.nn.Linear):
|
||||
def __init__(self, name='', in_features=1, out_features=2, bias=True, device=None, dtype=None):
|
||||
super().__init__(in_features, out_features, bias, device, dtype)
|
||||
self.name = name
|
||||
|
||||
def forward(self, x, lora_state_dicts=[], lora_alphas=[1.0,1.0], lora_patcher=None, **kwargs):
|
||||
out = torch.nn.functional.linear(x, self.weight, self.bias)
|
||||
lora_a_name = f'{self.name}.lora_A.default.weight'
|
||||
lora_b_name = f'{self.name}.lora_B.default.weight'
|
||||
|
||||
lora_output = []
|
||||
for i, lora_state_dict in enumerate(lora_state_dicts):
|
||||
if lora_state_dict is None:
|
||||
break
|
||||
if lora_a_name in lora_state_dict and lora_b_name in lora_state_dict:
|
||||
lora_A = lora_state_dict[lora_a_name].to(dtype=self.weight.dtype,device=self.weight.device)
|
||||
lora_B = lora_state_dict[lora_b_name].to(dtype=self.weight.dtype,device=self.weight.device)
|
||||
out_lora = x @ lora_A.T @ lora_B.T
|
||||
lora_output.append(out_lora)
|
||||
if len(lora_output) > 0:
|
||||
lora_output = torch.stack(lora_output)
|
||||
out = lora_patcher(out, lora_output, self.name)
|
||||
return out
|
||||
|
||||
def enable_auto_lora(model:torch.nn.Module, module_map: dict, name_prefix=''):
|
||||
targets = list(module_map.keys())
|
||||
for name, module in model.named_children():
|
||||
if name_prefix != '':
|
||||
full_name = name_prefix + '.' + name
|
||||
else:
|
||||
full_name = name
|
||||
if isinstance(module,targets[1]):
|
||||
# print(full_name)
|
||||
# print(module)
|
||||
# ToDo: replace the linear to the AutoLoRALinear
|
||||
new_module = AutoLoRALinear(
|
||||
name=full_name,
|
||||
in_features=module.in_features,
|
||||
out_features=module.out_features,
|
||||
bias=module.bias is not None,
|
||||
device=module.weight.device,
|
||||
dtype=module.weight.dtype)
|
||||
new_module.weight.data.copy_(module.weight.data)
|
||||
new_module.bias.data.copy_(module.bias.data)
|
||||
setattr(model, name, new_module)
|
||||
elif isinstance(module, targets[0]):
|
||||
pass
|
||||
else:
|
||||
enable_auto_lora(module, module_map, full_name)
|
||||
|
||||
|
||||
def enable_vram_management_recursively(model: torch.nn.Module, module_map: dict, module_config: dict, max_num_param=None, overflow_module_config: dict = None, total_num_param=0):
|
||||
for name, module in model.named_children():
|
||||
|
||||
7
examples/InfiniteYou/README.md
Normal file
7
examples/InfiniteYou/README.md
Normal file
@@ -0,0 +1,7 @@
|
||||
# InfiniteYou: Flexible Photo Recrafting While Preserving Your Identity
|
||||
We support the identity preserving feature of InfiniteYou. See [./infiniteyou.py](./infiniteyou.py) for example. The visualization of the result is shown below.
|
||||
|
||||
|Identity Image|Generated Image|
|
||||
|-|-|
|
||||
|||
|
||||
|||
|
||||
58
examples/InfiniteYou/infiniteyou.py
Normal file
58
examples/InfiniteYou/infiniteyou.py
Normal file
@@ -0,0 +1,58 @@
|
||||
import importlib
|
||||
import torch
|
||||
from diffsynth import ModelManager, FluxImagePipeline, download_models, ControlNetConfigUnit
|
||||
from modelscope import dataset_snapshot_download
|
||||
from PIL import Image
|
||||
|
||||
if importlib.util.find_spec("facexlib") is None:
|
||||
raise ImportError("You are using InifiniteYou. It depends on facexlib, which is not installed. Please install it with `pip install facexlib`.")
|
||||
if importlib.util.find_spec("insightface") is None:
|
||||
raise ImportError("You are using InifiniteYou. It depends on insightface, which is not installed. Please install it with `pip install insightface`.")
|
||||
|
||||
download_models(["InfiniteYou"])
|
||||
model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cuda", model_id_list=["FLUX.1-dev"])
|
||||
model_manager.load_models([
|
||||
[
|
||||
"models/InfiniteYou/InfuseNetModel/diffusion_pytorch_model-00001-of-00002.safetensors",
|
||||
"models/InfiniteYou/InfuseNetModel/diffusion_pytorch_model-00002-of-00002.safetensors"
|
||||
],
|
||||
"models/InfiniteYou/image_proj_model.bin",
|
||||
])
|
||||
|
||||
|
||||
pipe = FluxImagePipeline.from_model_manager(
|
||||
model_manager,
|
||||
controlnet_config_units=[
|
||||
ControlNetConfigUnit(
|
||||
processor_id="none",
|
||||
model_path=[
|
||||
'models/InfiniteYou/InfuseNetModel/diffusion_pytorch_model-00001-of-00002.safetensors',
|
||||
'models/InfiniteYou/InfuseNetModel/diffusion_pytorch_model-00002-of-00002.safetensors'
|
||||
],
|
||||
scale=1.0
|
||||
)
|
||||
]
|
||||
)
|
||||
dataset_snapshot_download(dataset_id="DiffSynth-Studio/examples_in_diffsynth", local_dir="./", allow_file_pattern=f"data/examples/infiniteyou/*")
|
||||
|
||||
prompt = "A man, portrait, cinematic"
|
||||
id_image = "data/examples/infiniteyou/man.jpg"
|
||||
id_image = Image.open(id_image).convert('RGB')
|
||||
image = pipe(
|
||||
prompt=prompt, seed=1,
|
||||
infinityou_id_image=id_image, infinityou_guidance=1.0,
|
||||
num_inference_steps=50, embedded_guidance=3.5,
|
||||
height=1024, width=1024,
|
||||
)
|
||||
image.save("man.jpg")
|
||||
|
||||
prompt = "A woman, portrait, cinematic"
|
||||
id_image = "data/examples/infiniteyou/woman.jpg"
|
||||
id_image = Image.open(id_image).convert('RGB')
|
||||
image = pipe(
|
||||
prompt=prompt, seed=1,
|
||||
infinityou_id_image=id_image, infinityou_guidance=1.0,
|
||||
num_inference_steps=50, embedded_guidance=3.5,
|
||||
height=1024, width=1024,
|
||||
)
|
||||
image.save("woman.jpg")
|
||||
@@ -10,34 +10,52 @@ cd DiffSynth-Studio
|
||||
pip install -e .
|
||||
```
|
||||
|
||||
Wan-Video supports multiple Attention implementations. If you have installed any of the following Attention implementations, they will be enabled based on priority.
|
||||
## Model Zoo
|
||||
|
||||
* [Flash Attention 3](https://github.com/Dao-AILab/flash-attention)
|
||||
* [Flash Attention 2](https://github.com/Dao-AILab/flash-attention)
|
||||
* [Sage Attention](https://github.com/thu-ml/SageAttention)
|
||||
* [torch SDPA](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html) (default. `torch>=2.5.0` is recommended.)
|
||||
|Developer|Name|Link|Scripts|
|
||||
|-|-|-|-|
|
||||
|Wan Team|1.3B text-to-video|[Link](https://modelscope.cn/models/Wan-AI/Wan2.1-T2V-1.3B)|[wan_1.3b_text_to_video.py](./wan_1.3b_text_to_video.py)|
|
||||
|Wan Team|14B text-to-video|[Link](https://modelscope.cn/models/Wan-AI/Wan2.1-T2V-14B)|[wan_14b_text_to_video.py](./wan_14b_text_to_video.py)|
|
||||
|Wan Team|14B image-to-video 480P|[Link](https://modelscope.cn/models/Wan-AI/Wan2.1-I2V-14B-480P)|[wan_14b_image_to_video.py](./wan_14b_image_to_video.py)|
|
||||
|Wan Team|14B image-to-video 720P|[Link](https://modelscope.cn/models/Wan-AI/Wan2.1-I2V-14B-720P)|[wan_14b_image_to_video.py](./wan_14b_image_to_video.py)|
|
||||
|DiffSynth-Studio Team|1.3B aesthetics LoRA|[Link](https://modelscope.cn/models/DiffSynth-Studio/Wan2.1-1.3b-lora-aesthetics-v1)|Please see the [model card](https://modelscope.cn/models/DiffSynth-Studio/Wan2.1-1.3b-lora-aesthetics-v1).|
|
||||
|DiffSynth-Studio Team|1.3B Highres-fix LoRA|[Link](https://modelscope.cn/models/DiffSynth-Studio/Wan2.1-1.3b-lora-highresfix-v1)|Please see the [model card](https://modelscope.cn/models/DiffSynth-Studio/Wan2.1-1.3b-lora-highresfix-v1).|
|
||||
|DiffSynth-Studio Team|1.3B ExVideo LoRA|[Link](https://modelscope.cn/models/DiffSynth-Studio/Wan2.1-1.3b-lora-exvideo-v1)|Please see the [model card](https://modelscope.cn/models/DiffSynth-Studio/Wan2.1-1.3b-lora-exvideo-v1).|
|
||||
|DiffSynth-Studio Team|1.3B Speed Control adapter|[Link](https://modelscope.cn/models/DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1)|[wan_1.3b_motion_controller.py](./wan_1.3b_motion_controller.py)|
|
||||
|PAI Team|1.3B InP|[Link](https://modelscope.cn/models/PAI/Wan2.1-Fun-1.3B-InP)|[wan_fun_InP.py](./wan_fun_InP.py)|
|
||||
|PAI Team|14B InP|[Link](https://modelscope.cn/models/PAI/Wan2.1-Fun-14B-InP)|[wan_fun_InP.py](./wan_fun_InP.py)|
|
||||
|PAI Team|1.3B Control|[Link](https://modelscope.cn/models/PAI/Wan2.1-Fun-1.3B-Control)|[wan_fun_control.py](./wan_fun_control.py)|
|
||||
|PAI Team|14B Control|[Link](https://modelscope.cn/models/PAI/Wan2.1-Fun-14B-Control)|[wan_fun_control.py](./wan_fun_control.py)|
|
||||
|
||||
## Inference
|
||||
Base model features
|
||||
|
||||
### Wan-Video-1.3B-T2V
|
||||
||Text-to-video|Image-to-video|End frame|Control|
|
||||
|-|-|-|-|-|
|
||||
|1.3B text-to-video|✅||||
|
||||
|14B text-to-video|✅||||
|
||||
|14B image-to-video 480P||✅|||
|
||||
|14B image-to-video 720P||✅|||
|
||||
|1.3B InP||✅|✅||
|
||||
|14B InP||✅|✅||
|
||||
|1.3B Control||||✅|
|
||||
|14B Control||||✅|
|
||||
|
||||
Wan-Video-1.3B-T2V supports text-to-video and video-to-video. See [`./wan_1.3b_text_to_video.py`](./wan_1.3b_text_to_video.py).
|
||||
Adapter model compatibility
|
||||
|
||||
Required VRAM: 6G
|
||||
||1.3B text-to-video|1.3B InP|
|
||||
|-|-|-|
|
||||
|1.3B aesthetics LoRA|✅||
|
||||
|1.3B Highres-fix LoRA|✅||
|
||||
|1.3B ExVideo LoRA|✅||
|
||||
|1.3B Speed Control adapter|✅|✅|
|
||||
|
||||
https://github.com/user-attachments/assets/124397be-cd6a-4f29-a87c-e4c695aaabb8
|
||||
## VRAM Usage
|
||||
|
||||
Put sunglasses on the dog.
|
||||
* Fine-grained offload: We recommend that users adjust the `num_persistent_param_in_dit` settings to find an optimal balance between speed and VRAM requirements. See [`./wan_14b_text_to_video.py`](./wan_14b_text_to_video.py).
|
||||
|
||||
https://github.com/user-attachments/assets/272808d7-fbeb-4747-a6df-14a0860c75fb
|
||||
* FP8 Quantization: You only need to adjust the `torch_dtype` in the `ModelManager` (not the pipeline!).
|
||||
|
||||
[TeaCache](https://github.com/ali-vilab/TeaCache) is supported in both T2V and I2V models. It can significantly improve the efficiency. See [`./wan_1.3b_text_to_video_accelerate.py`](./wan_1.3b_text_to_video_accelerate.py).
|
||||
|
||||
### Wan-Video-14B-T2V
|
||||
|
||||
Wan-Video-14B-T2V is an enhanced version of Wan-Video-1.3B-T2V, offering greater size and power. To utilize this model, you need additional VRAM. We recommend that users adjust the `torch_dtype` and `num_persistent_param_in_dit` settings to find an optimal balance between speed and VRAM requirements. See [`./wan_14b_text_to_video.py`](./wan_14b_text_to_video.py).
|
||||
|
||||
We present a detailed table here. The model is tested on a single A100.
|
||||
We present a detailed table here. The model (14B text-to-video) is tested on a single A100.
|
||||
|
||||
|`torch_dtype`|`num_persistent_param_in_dit`|Speed|Required VRAM|Default Setting|
|
||||
|-|-|-|-|-|
|
||||
@@ -47,17 +65,46 @@ We present a detailed table here. The model is tested on a single A100.
|
||||
|torch.float8_e4m3fn|None (unlimited)|18.3s/it|24G|yes|
|
||||
|torch.float8_e4m3fn|0|24.0s/it|10G||
|
||||
|
||||
**We found that 14B image-to-video model is more sensitive to precision, so when the generated video content experiences issues such as artifacts, please switch to bfloat16 precision and use the `num_persistent_param_in_dit` parameter to control VRAM usage.**
|
||||
|
||||
## Efficient Attention Implementation
|
||||
|
||||
DiffSynth-Studio supports multiple Attention implementations. If you have installed any of the following Attention implementations, they will be enabled based on priority. However, we recommend to use the default torch SDPA.
|
||||
|
||||
* [Flash Attention 3](https://github.com/Dao-AILab/flash-attention)
|
||||
* [Flash Attention 2](https://github.com/Dao-AILab/flash-attention)
|
||||
* [Sage Attention](https://github.com/thu-ml/SageAttention)
|
||||
* [torch SDPA](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html) (default. `torch>=2.5.0` is recommended.)
|
||||
|
||||
## Acceleration
|
||||
|
||||
We support multiple acceleration solutions:
|
||||
* [TeaCache](https://github.com/ali-vilab/TeaCache): See [wan_1.3b_text_to_video_accelerate.py](./wan_1.3b_text_to_video_accelerate.py).
|
||||
|
||||
* [Unified Sequence Parallel](https://github.com/xdit-project/xDiT): See [wan_14b_text_to_video_usp.py](./wan_14b_text_to_video_usp.py)
|
||||
|
||||
```bash
|
||||
pip install xfuser>=0.4.3
|
||||
torchrun --standalone --nproc_per_node=8 examples/wanvideo/wan_14b_text_to_video_usp.py
|
||||
```
|
||||
|
||||
* Tensor Parallel: See [wan_14b_text_to_video_tensor_parallel.py](./wan_14b_text_to_video_tensor_parallel.py).
|
||||
|
||||
## Gallery
|
||||
|
||||
1.3B text-to-video.
|
||||
|
||||
https://github.com/user-attachments/assets/124397be-cd6a-4f29-a87c-e4c695aaabb8
|
||||
|
||||
Put sunglasses on the dog.
|
||||
|
||||
https://github.com/user-attachments/assets/272808d7-fbeb-4747-a6df-14a0860c75fb
|
||||
|
||||
14B text-to-video.
|
||||
|
||||
https://github.com/user-attachments/assets/3908bc64-d451-485a-8b61-28f6d32dd92f
|
||||
|
||||
Tensor parallel module of Wan-Video-14B-T2V is still under development. An example script is provided in [`./wan_14b_text_to_video_tensor_parallel.py`](./wan_14b_text_to_video_tensor_parallel.py).
|
||||
|
||||
### Wan-Video-14B-I2V
|
||||
|
||||
Wan-Video-14B-I2V adds the functionality of image-to-video based on Wan-Video-14B-T2V. The model size remains the same, therefore the speed and VRAM requirements are also consistent. See [`./wan_14b_image_to_video.py`](./wan_14b_image_to_video.py).
|
||||
|
||||
**In the sample code, we use the same settings as the T2V 14B model, with FP8 quantization enabled by default. However, we found that this model is more sensitive to precision, so when the generated video content experiences issues such as artifacts, please switch to bfloat16 precision and use the `num_persistent_param_in_dit` parameter to control VRAM usage.**
|
||||
|
||||

|
||||
14B image-to-video.
|
||||
|
||||
https://github.com/user-attachments/assets/c0bdd5ca-292f-45ed-b9bc-afe193156e75
|
||||
|
||||
|
||||
41
examples/wanvideo/wan_1.3b_motion_controller.py
Normal file
41
examples/wanvideo/wan_1.3b_motion_controller.py
Normal file
@@ -0,0 +1,41 @@
|
||||
import torch
|
||||
from diffsynth import ModelManager, WanVideoPipeline, save_video, VideoData
|
||||
from modelscope import snapshot_download
|
||||
|
||||
|
||||
# Download models
|
||||
snapshot_download("Wan-AI/Wan2.1-T2V-1.3B", local_dir="models/Wan-AI/Wan2.1-T2V-1.3B")
|
||||
snapshot_download("DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1", local_dir="models/DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1")
|
||||
|
||||
# Load models
|
||||
model_manager = ModelManager(device="cpu")
|
||||
model_manager.load_models(
|
||||
[
|
||||
"models/Wan-AI/Wan2.1-T2V-1.3B/diffusion_pytorch_model.safetensors",
|
||||
"models/Wan-AI/Wan2.1-T2V-1.3B/models_t5_umt5-xxl-enc-bf16.pth",
|
||||
"models/Wan-AI/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth",
|
||||
"models/DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1/model.safetensors",
|
||||
],
|
||||
torch_dtype=torch.bfloat16, # You can set `torch_dtype=torch.float8_e4m3fn` to enable FP8 quantization.
|
||||
)
|
||||
pipe = WanVideoPipeline.from_model_manager(model_manager, torch_dtype=torch.bfloat16, device="cuda")
|
||||
pipe.enable_vram_management(num_persistent_param_in_dit=None)
|
||||
|
||||
# Text-to-video
|
||||
video = pipe(
|
||||
prompt="纪实摄影风格画面,一只活泼的小狗在绿茵茵的草地上迅速奔跑。小狗毛色棕黄,两只耳朵立起,神情专注而欢快。阳光洒在它身上,使得毛发看上去格外柔软而闪亮。背景是一片开阔的草地,偶尔点缀着几朵野花,远处隐约可见蓝天和几片白云。透视感鲜明,捕捉小狗奔跑时的动感和四周草地的生机。中景侧面移动视角。",
|
||||
negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
|
||||
num_inference_steps=50,
|
||||
seed=1, tiled=True,
|
||||
motion_bucket_id=0
|
||||
)
|
||||
save_video(video, "video_slow.mp4", fps=15, quality=5)
|
||||
|
||||
video = pipe(
|
||||
prompt="纪实摄影风格画面,一只活泼的小狗在绿茵茵的草地上迅速奔跑。小狗毛色棕黄,两只耳朵立起,神情专注而欢快。阳光洒在它身上,使得毛发看上去格外柔软而闪亮。背景是一片开阔的草地,偶尔点缀着几朵野花,远处隐约可见蓝天和几片白云。透视感鲜明,捕捉小狗奔跑时的动感和四周草地的生机。中景侧面移动视角。",
|
||||
negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
|
||||
num_inference_steps=50,
|
||||
seed=1, tiled=True,
|
||||
motion_bucket_id=100
|
||||
)
|
||||
save_video(video, "video_fast.mp4", fps=15, quality=5)
|
||||
@@ -44,11 +44,28 @@ class LitModel(pl.LightningModule):
|
||||
|
||||
def configure_model(self):
|
||||
tp_mesh = self.device_mesh["tensor_parallel"]
|
||||
plan = {
|
||||
"text_embedding.0": ColwiseParallel(),
|
||||
"text_embedding.2": RowwiseParallel(),
|
||||
"time_projection.1": ColwiseParallel(output_layouts=Replicate()),
|
||||
"text_embedding.0": ColwiseParallel(),
|
||||
"text_embedding.2": RowwiseParallel(),
|
||||
"blocks.0": PrepareModuleInput(
|
||||
input_layouts=(Replicate(), None, None, None),
|
||||
desired_input_layouts=(Replicate(), None, None, None),
|
||||
),
|
||||
"head": PrepareModuleInput(
|
||||
input_layouts=(Replicate(), None),
|
||||
desired_input_layouts=(Replicate(), None),
|
||||
use_local_output=True,
|
||||
)
|
||||
}
|
||||
self.pipe.dit = parallelize_module(self.pipe.dit, tp_mesh, plan)
|
||||
for block_id, block in enumerate(self.pipe.dit.blocks):
|
||||
layer_tp_plan = {
|
||||
"self_attn": PrepareModuleInput(
|
||||
input_layouts=(Replicate(), Replicate()),
|
||||
desired_input_layouts=(Replicate(), Shard(0)),
|
||||
input_layouts=(Shard(1), Replicate()),
|
||||
desired_input_layouts=(Shard(1), Shard(0)),
|
||||
),
|
||||
"self_attn.q": SequenceParallel(),
|
||||
"self_attn.k": SequenceParallel(),
|
||||
@@ -59,11 +76,11 @@ class LitModel(pl.LightningModule):
|
||||
input_layouts=(Shard(1), Shard(1), Shard(1)),
|
||||
desired_input_layouts=(Shard(2), Shard(2), Shard(2)),
|
||||
),
|
||||
"self_attn.o": ColwiseParallel(output_layouts=Replicate()),
|
||||
|
||||
"self_attn.o": RowwiseParallel(input_layouts=Shard(2), output_layouts=Replicate()),
|
||||
|
||||
"cross_attn": PrepareModuleInput(
|
||||
input_layouts=(Replicate(), Replicate()),
|
||||
desired_input_layouts=(Replicate(), Replicate()),
|
||||
input_layouts=(Shard(1), Replicate()),
|
||||
desired_input_layouts=(Shard(1), Replicate()),
|
||||
),
|
||||
"cross_attn.q": SequenceParallel(),
|
||||
"cross_attn.k": SequenceParallel(),
|
||||
@@ -74,18 +91,26 @@ class LitModel(pl.LightningModule):
|
||||
input_layouts=(Shard(1), Shard(1), Shard(1)),
|
||||
desired_input_layouts=(Shard(2), Shard(2), Shard(2)),
|
||||
),
|
||||
"cross_attn.o": ColwiseParallel(output_layouts=Replicate()),
|
||||
|
||||
"ffn.0": ColwiseParallel(),
|
||||
"ffn.2": RowwiseParallel(),
|
||||
"cross_attn.o": RowwiseParallel(input_layouts=Shard(2), output_layouts=Replicate(), use_local_output=False),
|
||||
|
||||
"ffn.0": ColwiseParallel(input_layouts=Shard(1)),
|
||||
"ffn.2": RowwiseParallel(output_layouts=Replicate()),
|
||||
|
||||
"norm1": SequenceParallel(use_local_output=True),
|
||||
"norm2": SequenceParallel(use_local_output=True),
|
||||
"norm3": SequenceParallel(use_local_output=True),
|
||||
"gate": PrepareModuleInput(
|
||||
input_layouts=(Shard(1), Replicate(), Replicate()),
|
||||
desired_input_layouts=(Replicate(), Replicate(), Replicate()),
|
||||
)
|
||||
}
|
||||
parallelize_module(
|
||||
module=block,
|
||||
device_mesh=tp_mesh,
|
||||
parallelize_plan=layer_tp_plan,
|
||||
)
|
||||
|
||||
|
||||
|
||||
|
||||
def test_step(self, batch):
|
||||
data = batch[0]
|
||||
data["progress_bar_cmd"] = tqdm if self.local_rank == 0 else lambda x: x
|
||||
@@ -94,9 +119,8 @@ class LitModel(pl.LightningModule):
|
||||
video = self.pipe(**data)
|
||||
if self.local_rank == 0:
|
||||
save_video(video, output_path, fps=15, quality=5)
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
snapshot_download("Wan-AI/Wan2.1-T2V-14B", local_dir="models/Wan-AI/Wan2.1-T2V-14B")
|
||||
dataloader = torch.utils.data.DataLoader(
|
||||
|
||||
58
examples/wanvideo/wan_14b_text_to_video_usp.py
Normal file
58
examples/wanvideo/wan_14b_text_to_video_usp.py
Normal file
@@ -0,0 +1,58 @@
|
||||
import torch
|
||||
from diffsynth import ModelManager, WanVideoPipeline, save_video, VideoData
|
||||
from modelscope import snapshot_download
|
||||
import torch.distributed as dist
|
||||
|
||||
|
||||
# Download models
|
||||
snapshot_download("Wan-AI/Wan2.1-T2V-14B", local_dir="models/Wan-AI/Wan2.1-T2V-14B")
|
||||
|
||||
# Load models
|
||||
model_manager = ModelManager(device="cpu")
|
||||
model_manager.load_models(
|
||||
[
|
||||
[
|
||||
"models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00001-of-00006.safetensors",
|
||||
"models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00002-of-00006.safetensors",
|
||||
"models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00003-of-00006.safetensors",
|
||||
"models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00004-of-00006.safetensors",
|
||||
"models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00005-of-00006.safetensors",
|
||||
"models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00006-of-00006.safetensors",
|
||||
],
|
||||
"models/Wan-AI/Wan2.1-T2V-14B/models_t5_umt5-xxl-enc-bf16.pth",
|
||||
"models/Wan-AI/Wan2.1-T2V-14B/Wan2.1_VAE.pth",
|
||||
],
|
||||
torch_dtype=torch.float8_e4m3fn, # You can set `torch_dtype=torch.bfloat16` to disable FP8 quantization.
|
||||
)
|
||||
|
||||
dist.init_process_group(
|
||||
backend="nccl",
|
||||
init_method="env://",
|
||||
)
|
||||
from xfuser.core.distributed import (initialize_model_parallel,
|
||||
init_distributed_environment)
|
||||
init_distributed_environment(
|
||||
rank=dist.get_rank(), world_size=dist.get_world_size())
|
||||
|
||||
initialize_model_parallel(
|
||||
sequence_parallel_degree=dist.get_world_size(),
|
||||
ring_degree=1,
|
||||
ulysses_degree=dist.get_world_size(),
|
||||
)
|
||||
torch.cuda.set_device(dist.get_rank())
|
||||
|
||||
pipe = WanVideoPipeline.from_model_manager(model_manager,
|
||||
torch_dtype=torch.bfloat16,
|
||||
device=f"cuda:{dist.get_rank()}",
|
||||
use_usp=True if dist.get_world_size() > 1 else False)
|
||||
pipe.enable_vram_management(num_persistent_param_in_dit=None) # You can set `num_persistent_param_in_dit` to a small number to reduce VRAM required.
|
||||
|
||||
# Text-to-video
|
||||
video = pipe(
|
||||
prompt="一名宇航员身穿太空服,面朝镜头骑着一匹机械马在火星表面驰骋。红色的荒凉地表延伸至远方,点缀着巨大的陨石坑和奇特的岩石结构。机械马的步伐稳健,扬起微弱的尘埃,展现出未来科技与原始探索的完美结合。宇航员手持操控装置,目光坚定,仿佛正在开辟人类的新疆域。背景是深邃的宇宙和蔚蓝的地球,画面既科幻又充满希望,让人不禁畅想未来的星际生活。",
|
||||
negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
|
||||
num_inference_steps=50,
|
||||
seed=0, tiled=True
|
||||
)
|
||||
if dist.get_rank() == 0:
|
||||
save_video(video, "video1.mp4", fps=25, quality=5)
|
||||
42
examples/wanvideo/wan_fun_InP.py
Normal file
42
examples/wanvideo/wan_fun_InP.py
Normal file
@@ -0,0 +1,42 @@
|
||||
import torch
|
||||
from diffsynth import ModelManager, WanVideoPipeline, save_video, VideoData
|
||||
from modelscope import snapshot_download, dataset_snapshot_download
|
||||
from PIL import Image
|
||||
|
||||
|
||||
# Download models
|
||||
snapshot_download("PAI/Wan2.1-Fun-1.3B-InP", local_dir="models/PAI/Wan2.1-Fun-1.3B-InP")
|
||||
|
||||
# Load models
|
||||
model_manager = ModelManager(device="cpu")
|
||||
model_manager.load_models(
|
||||
[
|
||||
"models/PAI/Wan2.1-Fun-1.3B-InP/diffusion_pytorch_model.safetensors",
|
||||
"models/PAI/Wan2.1-Fun-1.3B-InP/models_t5_umt5-xxl-enc-bf16.pth",
|
||||
"models/PAI/Wan2.1-Fun-1.3B-InP/Wan2.1_VAE.pth",
|
||||
"models/PAI/Wan2.1-Fun-1.3B-InP/models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth",
|
||||
],
|
||||
torch_dtype=torch.bfloat16, # You can set `torch_dtype=torch.float8_e4m3fn` to enable FP8 quantization.
|
||||
)
|
||||
pipe = WanVideoPipeline.from_model_manager(model_manager, torch_dtype=torch.bfloat16, device="cuda")
|
||||
pipe.enable_vram_management(num_persistent_param_in_dit=None)
|
||||
|
||||
# Download example image
|
||||
dataset_snapshot_download(
|
||||
dataset_id="DiffSynth-Studio/examples_in_diffsynth",
|
||||
local_dir="./",
|
||||
allow_file_pattern=f"data/examples/wan/input_image.jpg"
|
||||
)
|
||||
image = Image.open("data/examples/wan/input_image.jpg")
|
||||
|
||||
# Image-to-video
|
||||
video = pipe(
|
||||
prompt="一艘小船正勇敢地乘风破浪前行。蔚蓝的大海波涛汹涌,白色的浪花拍打着船身,但小船毫不畏惧,坚定地驶向远方。阳光洒在水面上,闪烁着金色的光芒,为这壮丽的场景增添了一抹温暖。镜头拉近,可以看到船上的旗帜迎风飘扬,象征着不屈的精神与冒险的勇气。这段画面充满力量,激励人心,展现了面对挑战时的无畏与执着。",
|
||||
negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
|
||||
num_inference_steps=50,
|
||||
input_image=image,
|
||||
# You can input `end_image=xxx` to control the last frame of the video.
|
||||
# The model will automatically generate the dynamic content between `input_image` and `end_image`.
|
||||
seed=1, tiled=True
|
||||
)
|
||||
save_video(video, "video1.mp4", fps=15, quality=5)
|
||||
40
examples/wanvideo/wan_fun_control.py
Normal file
40
examples/wanvideo/wan_fun_control.py
Normal file
@@ -0,0 +1,40 @@
|
||||
import torch
|
||||
from diffsynth import ModelManager, WanVideoPipeline, save_video, VideoData
|
||||
from modelscope import snapshot_download, dataset_snapshot_download
|
||||
from PIL import Image
|
||||
|
||||
|
||||
# Download models
|
||||
snapshot_download("PAI/Wan2.1-Fun-1.3B-Control", local_dir="models/PAI/Wan2.1-Fun-1.3B-Control")
|
||||
|
||||
# Load models
|
||||
model_manager = ModelManager(device="cpu")
|
||||
model_manager.load_models(
|
||||
[
|
||||
"models/PAI/Wan2.1-Fun-1.3B-Control/diffusion_pytorch_model.safetensors",
|
||||
"models/PAI/Wan2.1-Fun-1.3B-Control/models_t5_umt5-xxl-enc-bf16.pth",
|
||||
"models/PAI/Wan2.1-Fun-1.3B-Control/Wan2.1_VAE.pth",
|
||||
"models/PAI/Wan2.1-Fun-1.3B-Control/models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth",
|
||||
],
|
||||
torch_dtype=torch.bfloat16, # You can set `torch_dtype=torch.float8_e4m3fn` to enable FP8 quantization.
|
||||
)
|
||||
pipe = WanVideoPipeline.from_model_manager(model_manager, torch_dtype=torch.bfloat16, device="cuda")
|
||||
pipe.enable_vram_management(num_persistent_param_in_dit=None)
|
||||
|
||||
# Download example video
|
||||
dataset_snapshot_download(
|
||||
dataset_id="DiffSynth-Studio/examples_in_diffsynth",
|
||||
local_dir="./",
|
||||
allow_file_pattern=f"data/examples/wan/control_video.mp4"
|
||||
)
|
||||
|
||||
# Control-to-video
|
||||
control_video = VideoData("data/examples/wan/control_video.mp4", height=832, width=576)
|
||||
video = pipe(
|
||||
prompt="扁平风格动漫,一位长发少女优雅起舞。她五官精致,大眼睛明亮有神,黑色长发柔顺光泽。身穿淡蓝色T恤和深蓝色牛仔短裤。背景是粉色。",
|
||||
negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
|
||||
num_inference_steps=50,
|
||||
control_video=control_video, height=832, width=576, num_frames=49,
|
||||
seed=1, tiled=True
|
||||
)
|
||||
save_video(video, "video1.mp4", fps=15, quality=5)
|
||||
@@ -1,54 +0,0 @@
|
||||
import torch, os
|
||||
import pandas as pd
|
||||
from PIL import Image
|
||||
from torchvision.transforms import v2
|
||||
from diffsynth.data.video import crop_and_resize
|
||||
|
||||
|
||||
class LoraDataset(torch.utils.data.Dataset):
|
||||
def __init__(self, base_path, metadata_path, steps_per_epoch=1000, loras_per_item=1):
|
||||
self.base_path = base_path
|
||||
data_df = pd.read_csv(metadata_path)
|
||||
self.model_file = data_df["model_file"].tolist()
|
||||
self.image_file = data_df["image_file"].tolist()
|
||||
self.text = data_df["text"].tolist()
|
||||
self.max_resolution = 1920 * 1080
|
||||
self.steps_per_epoch = steps_per_epoch
|
||||
self.loras_per_item = loras_per_item
|
||||
|
||||
|
||||
def read_image(self, image_file):
|
||||
image = Image.open(image_file).convert("RGB")
|
||||
width, height = image.size
|
||||
if width * height > self.max_resolution:
|
||||
scale = (width * height / self.max_resolution) ** 0.5
|
||||
image = image.resize((int(width / scale), int(height / scale)))
|
||||
width, height = image.size
|
||||
if width % 16 != 0 or height % 16 != 0:
|
||||
image = crop_and_resize(image, height // 16 * 16, width // 16 * 16)
|
||||
image = v2.functional.to_image(image)
|
||||
image = v2.functional.to_dtype(image, dtype=torch.float32, scale=True)
|
||||
image = v2.functional.normalize(image, [0.5], [0.5])
|
||||
return image
|
||||
|
||||
|
||||
def get_data(self, data_id):
|
||||
data = {
|
||||
"model_file": os.path.join(self.base_path, self.model_file[data_id]),
|
||||
"image": self.read_image(os.path.join(self.base_path, self.image_file[data_id])),
|
||||
"text": self.text[data_id]
|
||||
}
|
||||
return data
|
||||
|
||||
|
||||
def __getitem__(self, index):
|
||||
data = []
|
||||
while len(data) < self.loras_per_item:
|
||||
data_id = torch.randint(0, len(self.model_file), (1,))[0]
|
||||
data_id = (data_id + index) % len(self.model_file) # For fixed seed.
|
||||
data.append(self.get_data(data_id))
|
||||
return data
|
||||
|
||||
|
||||
def __len__(self):
|
||||
return self.steps_per_epoch
|
||||
@@ -1,61 +0,0 @@
|
||||
import torch
|
||||
|
||||
|
||||
class LoraMerger(torch.nn.Module):
|
||||
def __init__(self, dim):
|
||||
super().__init__()
|
||||
self.weight_base = torch.nn.Parameter(torch.randn((dim,)))
|
||||
self.weight_lora = torch.nn.Parameter(torch.randn((dim,)))
|
||||
self.weight_cross = torch.nn.Parameter(torch.randn((dim,)))
|
||||
self.weight_out = torch.nn.Parameter(torch.ones((dim,)))
|
||||
self.bias = torch.nn.Parameter(torch.randn((dim,)))
|
||||
self.activation = torch.nn.Sigmoid()
|
||||
self.norm_base = torch.nn.LayerNorm(dim, eps=1e-5)
|
||||
self.norm_lora = torch.nn.LayerNorm(dim, eps=1e-5)
|
||||
|
||||
def forward(self, base_output, lora_outputs):
|
||||
norm_base_output = self.norm_base(base_output)
|
||||
norm_lora_outputs = self.norm_lora(lora_outputs)
|
||||
gate = self.activation(
|
||||
norm_base_output * self.weight_base \
|
||||
+ norm_lora_outputs * self.weight_lora \
|
||||
+ norm_base_output * norm_lora_outputs * self.weight_cross + self.bias
|
||||
)
|
||||
output = base_output + (self.weight_out * gate * lora_outputs).sum(dim=0)
|
||||
return output
|
||||
|
||||
|
||||
class LoraPatcher(torch.nn.Module):
|
||||
def __init__(self, lora_patterns=None):
|
||||
super().__init__()
|
||||
if lora_patterns is None:
|
||||
lora_patterns = self.default_lora_patterns()
|
||||
model_dict = {}
|
||||
for lora_pattern in lora_patterns:
|
||||
name, dim = lora_pattern["name"], lora_pattern["dim"]
|
||||
model_dict[name.replace(".", "___")] = LoraMerger(dim)
|
||||
self.model_dict = torch.nn.ModuleDict(model_dict)
|
||||
|
||||
def default_lora_patterns(self):
|
||||
lora_patterns = []
|
||||
lora_dict = {
|
||||
"attn.a_to_qkv": 9216, "attn.a_to_out": 3072, "ff_a.0": 12288, "ff_a.2": 3072, "norm1_a.linear": 18432,
|
||||
"attn.b_to_qkv": 9216, "attn.b_to_out": 3072, "ff_b.0": 12288, "ff_b.2": 3072, "norm1_b.linear": 18432,
|
||||
}
|
||||
for i in range(19):
|
||||
for suffix in lora_dict:
|
||||
lora_patterns.append({
|
||||
"name": f"blocks.{i}.{suffix}",
|
||||
"dim": lora_dict[suffix]
|
||||
})
|
||||
lora_dict = {"to_qkv_mlp": 21504, "proj_out": 3072, "norm.linear": 9216}
|
||||
for i in range(38):
|
||||
for suffix in lora_dict:
|
||||
lora_patterns.append({
|
||||
"name": f"single_blocks.{i}.{suffix}",
|
||||
"dim": lora_dict[suffix]
|
||||
})
|
||||
return lora_patterns
|
||||
|
||||
def forward(self, base_output, lora_outputs, name):
|
||||
return self.model_dict[name.replace(".", "___")](base_output, lora_outputs)
|
||||
@@ -1,149 +0,0 @@
|
||||
import torch
|
||||
from diffsynth import SDTextEncoder
|
||||
from diffsynth.models.sd3_text_encoder import SD3TextEncoder1StateDictConverter
|
||||
from diffsynth.models.sd_text_encoder import CLIPEncoderLayer
|
||||
|
||||
|
||||
class LoRALayerBlock(torch.nn.Module):
|
||||
def __init__(self, L, dim_in):
|
||||
super().__init__()
|
||||
self.x = torch.nn.Parameter(torch.randn(1, L, dim_in))
|
||||
|
||||
def forward(self, lora_A, lora_B):
|
||||
out = self.x @ lora_A.T @ lora_B.T
|
||||
return out
|
||||
|
||||
|
||||
class LoRAEmbedder(torch.nn.Module):
|
||||
def __init__(self, lora_patterns=None, L=1, out_dim=2048):
|
||||
super().__init__()
|
||||
if lora_patterns is None:
|
||||
lora_patterns = self.default_lora_patterns()
|
||||
|
||||
model_dict = {}
|
||||
for lora_pattern in lora_patterns:
|
||||
name, dim = lora_pattern["name"], lora_pattern["dim"][0]
|
||||
model_dict[name.replace(".", "___")] = LoRALayerBlock(L, dim)
|
||||
self.model_dict = torch.nn.ModuleDict(model_dict)
|
||||
|
||||
proj_dict = {}
|
||||
for lora_pattern in lora_patterns:
|
||||
layer_type, dim = lora_pattern["type"], lora_pattern["dim"][1]
|
||||
if layer_type not in proj_dict:
|
||||
proj_dict[layer_type.replace(".", "___")] = torch.nn.Linear(dim, out_dim)
|
||||
self.proj_dict = torch.nn.ModuleDict(proj_dict)
|
||||
|
||||
self.lora_patterns = lora_patterns
|
||||
|
||||
|
||||
def default_lora_patterns(self):
|
||||
lora_patterns = []
|
||||
lora_dict = {
|
||||
"attn.a_to_qkv": (3072, 9216), "attn.a_to_out": (3072, 3072), "ff_a.0": (3072, 12288), "ff_a.2": (12288, 3072), "norm1_a.linear": (3072, 18432),
|
||||
"attn.b_to_qkv": (3072, 9216), "attn.b_to_out": (3072, 3072), "ff_b.0": (3072, 12288), "ff_b.2": (12288, 3072), "norm1_b.linear": (3072, 18432),
|
||||
}
|
||||
for i in range(19):
|
||||
for suffix in lora_dict:
|
||||
lora_patterns.append({
|
||||
"name": f"blocks.{i}.{suffix}",
|
||||
"dim": lora_dict[suffix],
|
||||
"type": suffix,
|
||||
})
|
||||
lora_dict = {"to_qkv_mlp": (3072, 21504), "proj_out": (15360, 3072), "norm.linear": (3072, 9216)}
|
||||
for i in range(38):
|
||||
for suffix in lora_dict:
|
||||
lora_patterns.append({
|
||||
"name": f"single_blocks.{i}.{suffix}",
|
||||
"dim": lora_dict[suffix],
|
||||
"type": suffix,
|
||||
})
|
||||
return lora_patterns
|
||||
|
||||
def forward(self, lora):
|
||||
lora_emb = []
|
||||
for lora_pattern in self.lora_patterns:
|
||||
name, layer_type = lora_pattern["name"], lora_pattern["type"]
|
||||
lora_A = lora[name + ".lora_A.default.weight"]
|
||||
lora_B = lora[name + ".lora_B.default.weight"]
|
||||
lora_out = self.model_dict[name.replace(".", "___")](lora_A, lora_B)
|
||||
lora_out = self.proj_dict[layer_type.replace(".", "___")](lora_out)
|
||||
lora_emb.append(lora_out)
|
||||
lora_emb = torch.concat(lora_emb, dim=1)
|
||||
return lora_emb
|
||||
|
||||
|
||||
class TextEncoder(torch.nn.Module):
|
||||
def __init__(self, embed_dim=768, vocab_size=49408, max_position_embeddings=77, num_encoder_layers=12, encoder_intermediate_size=3072):
|
||||
super().__init__()
|
||||
|
||||
# token_embedding
|
||||
self.token_embedding = torch.nn.Embedding(vocab_size, embed_dim)
|
||||
|
||||
# position_embeds (This is a fixed tensor)
|
||||
self.position_embeds = torch.nn.Parameter(torch.zeros(1, max_position_embeddings, embed_dim))
|
||||
|
||||
# encoders
|
||||
self.encoders = torch.nn.ModuleList([CLIPEncoderLayer(embed_dim, encoder_intermediate_size) for _ in range(num_encoder_layers)])
|
||||
|
||||
# attn_mask
|
||||
self.attn_mask = self.attention_mask(max_position_embeddings)
|
||||
|
||||
# final_layer_norm
|
||||
self.final_layer_norm = torch.nn.LayerNorm(embed_dim)
|
||||
|
||||
def attention_mask(self, length):
|
||||
mask = torch.empty(length, length)
|
||||
mask.fill_(float("-inf"))
|
||||
mask.triu_(1)
|
||||
return mask
|
||||
|
||||
def forward(self, input_ids, clip_skip=1):
|
||||
embeds = self.token_embedding(input_ids) + self.position_embeds
|
||||
attn_mask = self.attn_mask.to(device=embeds.device, dtype=embeds.dtype)
|
||||
for encoder_id, encoder in enumerate(self.encoders):
|
||||
embeds = encoder(embeds, attn_mask=attn_mask)
|
||||
if encoder_id + clip_skip == len(self.encoders):
|
||||
break
|
||||
embeds = self.final_layer_norm(embeds)
|
||||
pooled_embeds = embeds[torch.arange(embeds.shape[0]), input_ids.to(dtype=torch.int).argmax(dim=-1)]
|
||||
return pooled_embeds
|
||||
|
||||
@staticmethod
|
||||
def state_dict_converter():
|
||||
return SD3TextEncoder1StateDictConverter()
|
||||
|
||||
|
||||
class LoRAEncoder(torch.nn.Module):
|
||||
def __init__(self, embed_dim=768, max_position_embeddings=304, num_encoder_layers=2, encoder_intermediate_size=3072, L=1):
|
||||
super().__init__()
|
||||
max_position_embeddings *= L
|
||||
|
||||
# Embedder
|
||||
self.embedder = LoRAEmbedder(L=L, out_dim=embed_dim)
|
||||
|
||||
# position_embeds (This is a fixed tensor)
|
||||
self.position_embeds = torch.nn.Parameter(torch.zeros(1, max_position_embeddings, embed_dim))
|
||||
|
||||
# encoders
|
||||
self.encoders = torch.nn.ModuleList([CLIPEncoderLayer(embed_dim, encoder_intermediate_size) for _ in range(num_encoder_layers)])
|
||||
|
||||
# attn_mask
|
||||
self.attn_mask = self.attention_mask(max_position_embeddings)
|
||||
|
||||
# final_layer_norm
|
||||
self.final_layer_norm = torch.nn.LayerNorm(embed_dim)
|
||||
|
||||
def attention_mask(self, length):
|
||||
mask = torch.empty(length, length)
|
||||
mask.fill_(float("-inf"))
|
||||
mask.triu_(1)
|
||||
return mask
|
||||
|
||||
def forward(self, lora):
|
||||
embeds = self.embedder(lora) + self.position_embeds
|
||||
attn_mask = self.attn_mask.to(device=embeds.device, dtype=embeds.dtype)
|
||||
for encoder_id, encoder in enumerate(self.encoders):
|
||||
embeds = encoder(embeds, attn_mask=attn_mask)
|
||||
embeds = self.final_layer_norm(embeds)
|
||||
embeds = embeds.mean(dim=1)
|
||||
return embeds
|
||||
@@ -1,46 +0,0 @@
|
||||
from diffsynth import FluxImagePipeline, ModelManager, load_state_dict
|
||||
from diffsynth.models.lora import FluxLoRAConverter
|
||||
from diffsynth.pipelines.flux_image import lets_dance_flux
|
||||
from lora.dataset import LoraDataset
|
||||
from lora.merger import LoraPatcher
|
||||
from lora.utils import load_lora
|
||||
import torch, os
|
||||
from accelerate import Accelerator, DistributedDataParallelKwargs
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cuda", model_id_list=["FLUX.1-dev"])
|
||||
pipe = FluxImagePipeline.from_model_manager(model_manager)
|
||||
pipe.enable_auto_lora()
|
||||
|
||||
lora_patcher = LoraPatcher().to(dtype=torch.bfloat16, device="cuda")
|
||||
lora_patcher.load_state_dict(load_state_dict("models/lora_merger/epoch-3.safetensors"))
|
||||
|
||||
dataset = LoraDataset("data/lora/models", "data/lora/lora_dataset_1000.csv", steps_per_epoch=800, loras_per_item=4)
|
||||
|
||||
for seed in range(100):
|
||||
batch = dataset[0]
|
||||
num_lora = torch.randint(1, len(batch), (1,))[0]
|
||||
lora_state_dicts = [
|
||||
FluxLoRAConverter.align_to_diffsynth_format(load_lora(batch[i]["model_file"], device="cuda")) for i in range(num_lora)
|
||||
]
|
||||
image = pipe(
|
||||
prompt=batch[0]["text"],
|
||||
seed=seed,
|
||||
)
|
||||
image.save(f"data/lora/lora_outputs/image_{seed}_nolora.jpg")
|
||||
for i in range(num_lora):
|
||||
image = pipe(
|
||||
prompt=batch[0]["text"],
|
||||
lora_state_dicts=[lora_state_dicts[i]],
|
||||
lora_patcher=lora_patcher,
|
||||
seed=seed,
|
||||
)
|
||||
image.save(f"data/lora/lora_outputs/image_{seed}_{i}.jpg")
|
||||
image = pipe(
|
||||
prompt=batch[0]["text"],
|
||||
lora_state_dicts=lora_state_dicts,
|
||||
lora_patcher=lora_patcher,
|
||||
seed=seed,
|
||||
)
|
||||
image.save(f"data/lora/lora_outputs/image_{seed}_merger.jpg")
|
||||
@@ -1,148 +0,0 @@
|
||||
from diffsynth import FluxImagePipeline, ModelManager, load_state_dict
|
||||
from diffsynth.models.lora import FluxLoRAConverter
|
||||
from diffsynth.pipelines.flux_image import lets_dance_flux
|
||||
from lora.dataset import LoraDataset
|
||||
from lora.retriever import TextEncoder, LoRAEncoder
|
||||
from lora.merger import LoraPatcher
|
||||
from lora.utils import load_lora
|
||||
import torch, os
|
||||
from accelerate import Accelerator, DistributedDataParallelKwargs
|
||||
from tqdm import tqdm
|
||||
from transformers import CLIPTokenizer, CLIPModel
|
||||
import pandas as pd
|
||||
|
||||
|
||||
|
||||
class LoRARetrieverTrainingModel(torch.nn.Module):
|
||||
def __init__(self, pretrained_path):
|
||||
super().__init__()
|
||||
|
||||
self.text_encoder = TextEncoder().to(torch.bfloat16)
|
||||
state_dict = load_state_dict("models/FLUX/FLUX.1-dev/text_encoder/model.safetensors")
|
||||
self.text_encoder.load_state_dict(TextEncoder.state_dict_converter().from_civitai(state_dict))
|
||||
self.text_encoder.requires_grad_(False)
|
||||
self.text_encoder.eval()
|
||||
|
||||
self.lora_encoder = LoRAEncoder().to(torch.bfloat16)
|
||||
state_dict = load_state_dict(pretrained_path)
|
||||
self.lora_encoder.load_state_dict(state_dict)
|
||||
|
||||
self.tokenizer = CLIPTokenizer.from_pretrained("diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_1")
|
||||
|
||||
|
||||
def to(self, *args, **kwargs):
|
||||
device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs)
|
||||
if device is not None:
|
||||
self.device = device
|
||||
if dtype is not None:
|
||||
self.torch_dtype = dtype
|
||||
super().to(*args, **kwargs)
|
||||
return self
|
||||
|
||||
|
||||
def forward(self, batch):
|
||||
text = [data["text"] for data in batch]
|
||||
input_ids = self.tokenizer(
|
||||
text,
|
||||
return_tensors="pt",
|
||||
padding="max_length",
|
||||
max_length=77,
|
||||
truncation=True
|
||||
).input_ids.to(self.device)
|
||||
text_emb = self.text_encoder(input_ids)
|
||||
text_emb = text_emb / text_emb.norm()
|
||||
|
||||
lora_emb = []
|
||||
for data in batch:
|
||||
lora = FluxLoRAConverter.align_to_diffsynth_format(load_lora(data["model_file"], device=self.device))
|
||||
lora_emb.append(self.lora_encoder(lora))
|
||||
lora_emb = torch.concat(lora_emb)
|
||||
lora_emb = lora_emb / lora_emb.norm()
|
||||
|
||||
similarity = text_emb @ lora_emb.T
|
||||
print(similarity)
|
||||
loss = -torch.log(torch.softmax(similarity, dim=0).diag()) - torch.log(torch.softmax(similarity, dim=1).diag())
|
||||
loss = 10 * loss.mean()
|
||||
return loss
|
||||
|
||||
|
||||
def trainable_modules(self):
|
||||
return self.lora_encoder.parameters()
|
||||
|
||||
@torch.no_grad()
|
||||
def process_lora_list(self, lora_list):
|
||||
lora_emb = []
|
||||
for lora in tqdm(lora_list):
|
||||
lora = FluxLoRAConverter.align_to_diffsynth_format(load_lora(lora, device="cuda"))
|
||||
lora_emb.append(self.lora_encoder(lora))
|
||||
lora_emb = torch.concat(lora_emb)
|
||||
lora_emb = lora_emb / lora_emb.norm()
|
||||
self.lora_emb = lora_emb
|
||||
self.lora_list = lora_list
|
||||
|
||||
@torch.no_grad()
|
||||
def retrieve(self, text, k=1):
|
||||
input_ids = self.tokenizer(
|
||||
text,
|
||||
return_tensors="pt",
|
||||
padding="max_length",
|
||||
max_length=77,
|
||||
truncation=True
|
||||
).input_ids.to(self.device)
|
||||
text_emb = self.text_encoder(input_ids)
|
||||
text_emb = text_emb / text_emb.norm()
|
||||
|
||||
similarity = text_emb @ self.lora_emb.T
|
||||
topk = torch.topk(similarity, k, dim=1).indices[0]
|
||||
|
||||
lora_list = []
|
||||
model_url_list = []
|
||||
for lora_id in topk:
|
||||
print(self.lora_list[lora_id])
|
||||
lora = FluxLoRAConverter.align_to_diffsynth_format(load_lora(self.lora_list[lora_id], device="cuda"))
|
||||
lora_list.append(lora)
|
||||
model_id = self.lora_list[lora_id].split("/")[3:5]
|
||||
model_url_list.append(f"https://www.modelscope.cn/models/{model_id[0]}/{model_id[1]}")
|
||||
return lora_list, model_url_list
|
||||
|
||||
|
||||
|
||||
model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cuda", model_id_list=["FLUX.1-dev"])
|
||||
pipe = FluxImagePipeline.from_model_manager(model_manager)
|
||||
pipe.enable_auto_lora()
|
||||
|
||||
lora_patcher = LoraPatcher().to(dtype=torch.bfloat16, device="cuda")
|
||||
lora_patcher.load_state_dict(load_state_dict("models/lora_merger/epoch-9.safetensors"))
|
||||
|
||||
retriever = LoRARetrieverTrainingModel("models/lora_retriever/epoch-3.safetensors").to(dtype=torch.bfloat16, device="cuda")
|
||||
retriever.process_lora_list(list(set("data/lora/models/" + i for i in pd.read_csv("data/lora/lora_dataset_1000.csv")["model_file"])))
|
||||
|
||||
dataset = LoraDataset("data/lora/models", "data/lora/lora_dataset_1000.csv", steps_per_epoch=800, loras_per_item=1)
|
||||
|
||||
text_list = []
|
||||
model_url_list = []
|
||||
for seed in range(100):
|
||||
text = dataset[0][0]["text"]
|
||||
print(text)
|
||||
loras, urls = retriever.retrieve(text, k=3)
|
||||
print(urls)
|
||||
image = pipe(
|
||||
prompt=text,
|
||||
seed=seed,
|
||||
)
|
||||
image.save(f"data/lora/lora_outputs/image_{seed}_top0.jpg")
|
||||
for i in range(2, 3):
|
||||
image = pipe(
|
||||
prompt=text,
|
||||
lora_state_dicts=loras[:i+1],
|
||||
lora_patcher=lora_patcher,
|
||||
seed=seed,
|
||||
)
|
||||
image.save(f"data/lora/lora_outputs/image_{seed}_top{i+1}.jpg")
|
||||
|
||||
text_list.append(text)
|
||||
model_url_list.append(urls)
|
||||
df = pd.DataFrame()
|
||||
df["text"] = text_list
|
||||
df["models"] = [",".join(i) for i in model_url_list]
|
||||
df.to_csv("data/lora/lora_outputs/metadata.csv", index=False, encoding="utf-8-sig")
|
||||
@@ -1,119 +0,0 @@
|
||||
from diffsynth import FluxImagePipeline, ModelManager
|
||||
from diffsynth.models.lora import FluxLoRAConverter
|
||||
from diffsynth.pipelines.flux_image import lets_dance_flux
|
||||
from lora.dataset import LoraDataset
|
||||
from lora.merger import LoraPatcher
|
||||
from lora.utils import load_lora
|
||||
import torch, os
|
||||
from accelerate import Accelerator, DistributedDataParallelKwargs
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
|
||||
class LoRAMergerTrainingModel(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu", model_id_list=["FLUX.1-dev"])
|
||||
self.pipe = FluxImagePipeline.from_model_manager(model_manager)
|
||||
self.lora_patcher = LoraPatcher()
|
||||
self.pipe.enable_auto_lora()
|
||||
self.freeze_parameters()
|
||||
self.switch_to_training_mode()
|
||||
self.use_gradient_checkpointing = True
|
||||
self.state_dict_converter = FluxLoRAConverter.align_to_diffsynth_format
|
||||
self.device = "cuda"
|
||||
|
||||
|
||||
def to(self, *args, **kwargs):
|
||||
device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs)
|
||||
if device is not None:
|
||||
self.device = device
|
||||
if dtype is not None:
|
||||
self.torch_dtype = dtype
|
||||
super().to(*args, **kwargs)
|
||||
return self
|
||||
|
||||
|
||||
def switch_to_training_mode(self):
|
||||
self.pipe.scheduler.set_timesteps(1000, training=True)
|
||||
|
||||
|
||||
def freeze_parameters(self):
|
||||
self.pipe.requires_grad_(False)
|
||||
self.pipe.eval()
|
||||
self.pipe.denoising_model().train()
|
||||
self.lora_patcher.requires_grad_(True)
|
||||
|
||||
|
||||
def forward(self, batch):
|
||||
# Data
|
||||
text, image = batch[0]["text"], batch[0]["image"].unsqueeze(0)
|
||||
num_lora = torch.randint(1, len(batch), (1,))[0]
|
||||
lora_state_dicts = [
|
||||
self.state_dict_converter(load_lora(batch[i]["model_file"], device=self.device)) for i in range(num_lora)
|
||||
]
|
||||
lora_alphas = None
|
||||
|
||||
# Prepare input parameters
|
||||
self.pipe.device = self.device
|
||||
prompt_emb = self.pipe.encode_prompt(text, positive=True)
|
||||
latents = self.pipe.vae_encoder(image.to(dtype=self.pipe.torch_dtype, device=self.device))
|
||||
noise = torch.randn_like(latents)
|
||||
timestep_id = torch.randint(0, self.pipe.scheduler.num_train_timesteps, (1,))
|
||||
timestep = self.pipe.scheduler.timesteps[timestep_id].to(self.device)
|
||||
extra_input = self.pipe.prepare_extra_input(latents)
|
||||
noisy_latents = self.pipe.scheduler.add_noise(latents, noise, timestep)
|
||||
training_target = self.pipe.scheduler.training_target(latents, noise, timestep)
|
||||
|
||||
# Compute loss
|
||||
noise_pred = lets_dance_flux(
|
||||
self.pipe.dit,
|
||||
hidden_states=noisy_latents, timestep=timestep, **prompt_emb, **extra_input,
|
||||
lora_state_dicts=lora_state_dicts, lora_alphas=lora_alphas, lora_patcher=self.lora_patcher,
|
||||
use_gradient_checkpointing=self.use_gradient_checkpointing
|
||||
)
|
||||
loss = torch.nn.functional.mse_loss(noise_pred.float(), training_target.float())
|
||||
loss = loss * self.pipe.scheduler.training_weight(timestep)
|
||||
return loss
|
||||
|
||||
|
||||
def trainable_modules(self):
|
||||
return self.lora_patcher.parameters()
|
||||
|
||||
|
||||
class ModelLogger:
|
||||
def __init__(self, output_path, remove_prefix_in_ckpt=None):
|
||||
self.output_path = output_path
|
||||
self.remove_prefix_in_ckpt = remove_prefix_in_ckpt
|
||||
|
||||
|
||||
def on_step_end(self, loss):
|
||||
pass
|
||||
|
||||
|
||||
def on_epoch_end(self, accelerator, model, epoch_id):
|
||||
accelerator.wait_for_everyone()
|
||||
if accelerator.is_main_process:
|
||||
state_dict = accelerator.unwrap_model(model).lora_patcher.state_dict()
|
||||
os.makedirs(self.output_path, exist_ok=True)
|
||||
path = os.path.join(self.output_path, f"epoch-{epoch_id}.safetensors")
|
||||
accelerator.save(state_dict, path, safe_serialization=True)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
model = LoRAMergerTrainingModel()
|
||||
dataset = LoraDataset("data/lora/models/", "data/lora/lora_dataset_1000.csv", steps_per_epoch=800, loras_per_item=4)
|
||||
dataloader = torch.utils.data.DataLoader(dataset, shuffle=True, batch_size=1, num_workers=1, collate_fn=lambda x: x[0])
|
||||
optimizer = torch.optim.AdamW(model.trainable_modules(), lr=1e-4)
|
||||
model_logger = ModelLogger("models/lora_merger")
|
||||
accelerator = Accelerator(kwargs_handlers=[DistributedDataParallelKwargs(find_unused_parameters=True)])
|
||||
model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)
|
||||
|
||||
for epoch_id in range(1000000):
|
||||
for data in tqdm(dataloader):
|
||||
with accelerator.accumulate(model):
|
||||
optimizer.zero_grad()
|
||||
loss = model(data)
|
||||
accelerator.backward(loss)
|
||||
optimizer.step()
|
||||
model_logger.on_epoch_end(accelerator, model, epoch_id)
|
||||
@@ -1,105 +0,0 @@
|
||||
from diffsynth import FluxImagePipeline, ModelManager, load_state_dict
|
||||
from diffsynth.models.lora import FluxLoRAConverter
|
||||
from diffsynth.pipelines.flux_image import lets_dance_flux
|
||||
from lora.dataset import LoraDataset
|
||||
from lora.retriever import TextEncoder, LoRAEncoder
|
||||
from lora.utils import load_lora
|
||||
import torch, os
|
||||
from accelerate import Accelerator, DistributedDataParallelKwargs
|
||||
from tqdm import tqdm
|
||||
from transformers import CLIPTokenizer, CLIPModel
|
||||
|
||||
|
||||
|
||||
class LoRARetrieverTrainingModel(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
self.text_encoder = TextEncoder().to(torch.bfloat16)
|
||||
state_dict = load_state_dict("models/FLUX/FLUX.1-dev/text_encoder/model.safetensors")
|
||||
self.text_encoder.load_state_dict(TextEncoder.state_dict_converter().from_civitai(state_dict))
|
||||
self.text_encoder.requires_grad_(False)
|
||||
self.text_encoder.eval()
|
||||
|
||||
self.lora_encoder = LoRAEncoder().to(torch.bfloat16)
|
||||
|
||||
self.tokenizer = CLIPTokenizer.from_pretrained("diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_1")
|
||||
|
||||
|
||||
def to(self, *args, **kwargs):
|
||||
device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs)
|
||||
if device is not None:
|
||||
self.device = device
|
||||
if dtype is not None:
|
||||
self.torch_dtype = dtype
|
||||
super().to(*args, **kwargs)
|
||||
return self
|
||||
|
||||
|
||||
def forward(self, batch):
|
||||
text = [data["text"] for data in batch]
|
||||
input_ids = self.tokenizer(
|
||||
text,
|
||||
return_tensors="pt",
|
||||
padding="max_length",
|
||||
max_length=77,
|
||||
truncation=True
|
||||
).input_ids.to(self.device)
|
||||
text_emb = self.text_encoder(input_ids)
|
||||
text_emb = text_emb / text_emb.norm()
|
||||
|
||||
lora_emb = []
|
||||
for data in batch:
|
||||
lora = FluxLoRAConverter.align_to_diffsynth_format(load_lora(data["model_file"], device=self.device))
|
||||
lora_emb.append(self.lora_encoder(lora))
|
||||
lora_emb = torch.concat(lora_emb)
|
||||
lora_emb = lora_emb / lora_emb.norm()
|
||||
|
||||
similarity = text_emb @ lora_emb.T
|
||||
print(similarity)
|
||||
loss = -torch.log(torch.softmax(similarity, dim=0).diag()) - torch.log(torch.softmax(similarity, dim=1).diag())
|
||||
loss = 10 * loss.mean()
|
||||
return loss
|
||||
|
||||
|
||||
def trainable_modules(self):
|
||||
return self.lora_encoder.parameters()
|
||||
|
||||
|
||||
class ModelLogger:
|
||||
def __init__(self, output_path, remove_prefix_in_ckpt=None):
|
||||
self.output_path = output_path
|
||||
self.remove_prefix_in_ckpt = remove_prefix_in_ckpt
|
||||
|
||||
|
||||
def on_step_end(self, loss):
|
||||
pass
|
||||
|
||||
|
||||
def on_epoch_end(self, accelerator, model, epoch_id):
|
||||
accelerator.wait_for_everyone()
|
||||
if accelerator.is_main_process:
|
||||
state_dict = accelerator.unwrap_model(model).lora_encoder.state_dict()
|
||||
os.makedirs(self.output_path, exist_ok=True)
|
||||
path = os.path.join(self.output_path, f"epoch-{epoch_id}.safetensors")
|
||||
accelerator.save(state_dict, path, safe_serialization=True)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
model = LoRARetrieverTrainingModel()
|
||||
dataset = LoraDataset("data/lora/models/", "data/lora/lora_dataset_1000.csv", steps_per_epoch=100, loras_per_item=32)
|
||||
dataloader = torch.utils.data.DataLoader(dataset, shuffle=True, batch_size=1, num_workers=1, collate_fn=lambda x: x[0])
|
||||
optimizer = torch.optim.AdamW(model.trainable_modules(), lr=1e-4)
|
||||
model_logger = ModelLogger("models/lora_retriever")
|
||||
accelerator = Accelerator(kwargs_handlers=[DistributedDataParallelKwargs(find_unused_parameters=True)])
|
||||
model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)
|
||||
|
||||
for epoch_id in range(1000000):
|
||||
for data in tqdm(dataloader):
|
||||
with accelerator.accumulate(model):
|
||||
optimizer.zero_grad()
|
||||
loss = model(data)
|
||||
accelerator.backward(loss)
|
||||
optimizer.step()
|
||||
print(loss)
|
||||
model_logger.on_epoch_end(accelerator, model, epoch_id)
|
||||
@@ -1,12 +0,0 @@
|
||||
from diffsynth import load_state_dict
|
||||
import math, torch
|
||||
|
||||
|
||||
def load_lora(file_path, device):
|
||||
sd = load_state_dict(file_path, torch_dtype=torch.bfloat16, device=device)
|
||||
scale = math.sqrt(sd["lora_unet_single_blocks_9_modulation_lin.alpha"] / sd["lora_unet_single_blocks_9_modulation_lin.lora_down.weight"].shape[0])
|
||||
if scale != 1:
|
||||
sd = {i: sd[i] * scale for i in sd}
|
||||
return sd
|
||||
|
||||
|
||||
0
models/lora/Put lora files here.txt
Normal file
0
models/lora/Put lora files here.txt
Normal file
@@ -2,7 +2,6 @@ torch>=2.0.0
|
||||
torchvision
|
||||
cupy-cuda12x
|
||||
transformers==4.46.2
|
||||
controlnet-aux==0.0.7
|
||||
imageio
|
||||
imageio[ffmpeg]
|
||||
safetensors
|
||||
|
||||
Reference in New Issue
Block a user