mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-19 14:58:12 +00:00
Compare commits
67 Commits
vram-bugfi
...
doc
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
b13963382c | ||
|
|
e11cf9e318 | ||
|
|
4f93be2f5a | ||
|
|
c9c6be2201 | ||
|
|
2b07df1c7a | ||
|
|
31161317e3 | ||
|
|
787813ab98 | ||
|
|
9fef3ee281 | ||
|
|
b205513041 | ||
|
|
97dd398f17 | ||
|
|
85ac23c0c3 | ||
|
|
b4073d2643 | ||
|
|
9583f16a43 | ||
|
|
b0633ac8bb | ||
|
|
9166a6742c | ||
|
|
10cfa6d711 | ||
|
|
b78ffbe09e | ||
|
|
64af33fe33 | ||
|
|
1180f450ca | ||
|
|
99726e02de | ||
|
|
e0c09ed53d | ||
|
|
250ebf5c72 | ||
|
|
47a2f86f7b | ||
|
|
e2d9710d86 | ||
|
|
384ea0dc69 | ||
|
|
e0ef3eea60 | ||
|
|
ac67acd235 | ||
|
|
fe68a3d1bb | ||
|
|
deff4512f7 | ||
|
|
29efb1c828 | ||
|
|
e833a31909 | ||
|
|
b626d2aad7 | ||
|
|
0bc89f973e | ||
|
|
3eeaa1cd32 | ||
|
|
4b25495921 | ||
|
|
ac2b187b9f | ||
|
|
eece711313 | ||
|
|
4e1cea64ad | ||
|
|
1a2ce26d37 | ||
|
|
b17a0297a2 | ||
|
|
9af2d08a33 | ||
|
|
b2df73d033 | ||
|
|
3514eba956 | ||
|
|
fb0e5d1f38 | ||
|
|
b43cc35dd9 | ||
|
|
34ca18a217 | ||
|
|
550d780cd6 | ||
|
|
ded2882e87 | ||
|
|
f6e676cdf9 | ||
|
|
157ba2e426 | ||
|
|
1a004ffe81 | ||
|
|
70c4ff4121 | ||
|
|
883d26abb4 | ||
|
|
105d4ffbc2 | ||
|
|
24b78148b8 | ||
|
|
793062e141 | ||
|
|
98f07f2435 | ||
|
|
ca4b9c8bf4 | ||
|
|
a2ab597eb0 | ||
|
|
950fb486d6 | ||
|
|
28b4a5313e | ||
|
|
d9d37568a7 | ||
|
|
55f1a10255 | ||
|
|
677ecbf1d2 | ||
|
|
5a06ac5e31 | ||
|
|
41f58e2d41 | ||
|
|
7f6e35fe35 |
BIN
.github/workflows/logo.gif
vendored
BIN
.github/workflows/logo.gif
vendored
Binary file not shown.
|
Before Width: | Height: | Size: 146 KiB |
4
.github/workflows/publish.yaml
vendored
4
.github/workflows/publish.yaml
vendored
@@ -20,9 +20,9 @@ jobs:
|
||||
with:
|
||||
python-version: '3.10'
|
||||
- name: Install wheel
|
||||
run: pip install wheel==0.44.0 && pip install -r requirements.txt
|
||||
run: pip install wheel && pip install -r requirements.txt
|
||||
- name: Build DiffSynth
|
||||
run: python -m build
|
||||
run: python setup.py sdist bdist_wheel
|
||||
- name: Publish package to PyPI
|
||||
run: |
|
||||
pip install twine
|
||||
|
||||
175
.gitignore
vendored
175
.gitignore
vendored
@@ -1,175 +0,0 @@
|
||||
/data
|
||||
/models
|
||||
/scripts
|
||||
/diffusers
|
||||
*.pkl
|
||||
*.safetensors
|
||||
*.pth
|
||||
*.ckpt
|
||||
*.pt
|
||||
*.bin
|
||||
*.DS_Store
|
||||
*.msc
|
||||
*.mv
|
||||
log*.txt
|
||||
|
||||
# Byte-compiled / optimized / DLL files
|
||||
__pycache__/
|
||||
*.py[cod]
|
||||
*$py.class
|
||||
|
||||
# C extensions
|
||||
*.so
|
||||
|
||||
# Distribution / packaging
|
||||
.Python
|
||||
build/
|
||||
develop-eggs/
|
||||
dist/
|
||||
downloads/
|
||||
eggs/
|
||||
.eggs/
|
||||
lib/
|
||||
lib64/
|
||||
parts/
|
||||
sdist/
|
||||
var/
|
||||
wheels/
|
||||
share/python-wheels/
|
||||
*.egg-info/
|
||||
.installed.cfg
|
||||
*.egg
|
||||
MANIFEST
|
||||
|
||||
# PyInstaller
|
||||
# Usually these files are written by a python script from a template
|
||||
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
||||
*.manifest
|
||||
*.spec
|
||||
|
||||
# Installer logs
|
||||
pip-log.txt
|
||||
pip-delete-this-directory.txt
|
||||
|
||||
# Unit test / coverage reports
|
||||
htmlcov/
|
||||
.tox/
|
||||
.nox/
|
||||
.coverage
|
||||
.coverage.*
|
||||
.cache
|
||||
nosetests.xml
|
||||
coverage.xml
|
||||
*.cover
|
||||
*.py,cover
|
||||
.hypothesis/
|
||||
.pytest_cache/
|
||||
cover/
|
||||
|
||||
# Translations
|
||||
*.mo
|
||||
*.pot
|
||||
|
||||
# Django stuff:
|
||||
*.log
|
||||
local_settings.py
|
||||
db.sqlite3
|
||||
db.sqlite3-journal
|
||||
|
||||
# Flask stuff:
|
||||
instance/
|
||||
.webassets-cache
|
||||
|
||||
# Scrapy stuff:
|
||||
.scrapy
|
||||
|
||||
# Sphinx documentation
|
||||
docs/_build/
|
||||
|
||||
# PyBuilder
|
||||
.pybuilder/
|
||||
target/
|
||||
|
||||
# Jupyter Notebook
|
||||
.ipynb_checkpoints
|
||||
|
||||
# IPython
|
||||
profile_default/
|
||||
ipython_config.py
|
||||
|
||||
# pyenv
|
||||
# For a library or package, you might want to ignore these files since the code is
|
||||
# intended to run in multiple environments; otherwise, check them in:
|
||||
# .python-version
|
||||
|
||||
# pipenv
|
||||
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
||||
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
||||
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
||||
# install all needed dependencies.
|
||||
#Pipfile.lock
|
||||
|
||||
# poetry
|
||||
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
||||
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
||||
# commonly ignored for libraries.
|
||||
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
||||
#poetry.lock
|
||||
|
||||
# pdm
|
||||
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
||||
#pdm.lock
|
||||
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
|
||||
# in version control.
|
||||
# https://pdm.fming.dev/#use-with-ide
|
||||
.pdm.toml
|
||||
|
||||
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
||||
__pypackages__/
|
||||
|
||||
# Celery stuff
|
||||
celerybeat-schedule
|
||||
celerybeat.pid
|
||||
|
||||
# SageMath parsed files
|
||||
*.sage.py
|
||||
|
||||
# Environments
|
||||
.env
|
||||
.venv
|
||||
env/
|
||||
venv/
|
||||
ENV/
|
||||
env.bak/
|
||||
venv.bak/
|
||||
|
||||
# Spyder project settings
|
||||
.spyderproject
|
||||
.spyproject
|
||||
|
||||
# Rope project settings
|
||||
.ropeproject
|
||||
|
||||
# mkdocs documentation
|
||||
/site
|
||||
|
||||
# mypy
|
||||
.mypy_cache/
|
||||
.dmypy.json
|
||||
dmypy.json
|
||||
|
||||
# Pyre type checker
|
||||
.pyre/
|
||||
|
||||
# pytype static type analyzer
|
||||
.pytype/
|
||||
|
||||
# Cython debug symbols
|
||||
cython_debug/
|
||||
|
||||
# PyCharm
|
||||
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
||||
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
||||
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
||||
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
||||
#.idea/
|
||||
1012
README_zh.md
1012
README_zh.md
File diff suppressed because it is too large
Load Diff
252
apps/gradio/DiffSynth_Studio.py
Normal file
252
apps/gradio/DiffSynth_Studio.py
Normal file
@@ -0,0 +1,252 @@
|
||||
import gradio as gr
|
||||
from diffsynth import ModelManager, SDImagePipeline, SDXLImagePipeline, SD3ImagePipeline, HunyuanDiTImagePipeline, FluxImagePipeline
|
||||
import os, torch
|
||||
from PIL import Image
|
||||
import numpy as np
|
||||
|
||||
|
||||
config = {
|
||||
"model_config": {
|
||||
"Stable Diffusion": {
|
||||
"model_folder": "models/stable_diffusion",
|
||||
"pipeline_class": SDImagePipeline,
|
||||
"default_parameters": {
|
||||
"cfg_scale": 7.0,
|
||||
"height": 512,
|
||||
"width": 512,
|
||||
}
|
||||
},
|
||||
"Stable Diffusion XL": {
|
||||
"model_folder": "models/stable_diffusion_xl",
|
||||
"pipeline_class": SDXLImagePipeline,
|
||||
"default_parameters": {
|
||||
"cfg_scale": 7.0,
|
||||
}
|
||||
},
|
||||
"Stable Diffusion 3": {
|
||||
"model_folder": "models/stable_diffusion_3",
|
||||
"pipeline_class": SD3ImagePipeline,
|
||||
"default_parameters": {
|
||||
"cfg_scale": 7.0,
|
||||
}
|
||||
},
|
||||
"Stable Diffusion XL Turbo": {
|
||||
"model_folder": "models/stable_diffusion_xl_turbo",
|
||||
"pipeline_class": SDXLImagePipeline,
|
||||
"default_parameters": {
|
||||
"negative_prompt": "",
|
||||
"cfg_scale": 1.0,
|
||||
"num_inference_steps": 1,
|
||||
"height": 512,
|
||||
"width": 512,
|
||||
}
|
||||
},
|
||||
"Kolors": {
|
||||
"model_folder": "models/kolors",
|
||||
"pipeline_class": SDXLImagePipeline,
|
||||
"default_parameters": {
|
||||
"cfg_scale": 7.0,
|
||||
}
|
||||
},
|
||||
"HunyuanDiT": {
|
||||
"model_folder": "models/HunyuanDiT",
|
||||
"pipeline_class": HunyuanDiTImagePipeline,
|
||||
"default_parameters": {
|
||||
"cfg_scale": 7.0,
|
||||
}
|
||||
},
|
||||
"FLUX": {
|
||||
"model_folder": "models/FLUX",
|
||||
"pipeline_class": FluxImagePipeline,
|
||||
"default_parameters": {
|
||||
"cfg_scale": 1.0,
|
||||
}
|
||||
}
|
||||
},
|
||||
"max_num_painter_layers": 8,
|
||||
"max_num_model_cache": 1,
|
||||
}
|
||||
|
||||
|
||||
def load_model_list(model_type):
|
||||
if model_type is None:
|
||||
return []
|
||||
folder = config["model_config"][model_type]["model_folder"]
|
||||
file_list = [i for i in os.listdir(folder) if i.endswith(".safetensors")]
|
||||
if model_type in ["HunyuanDiT", "Kolors", "FLUX"]:
|
||||
file_list += [i for i in os.listdir(folder) if os.path.isdir(os.path.join(folder, i))]
|
||||
file_list = sorted(file_list)
|
||||
return file_list
|
||||
|
||||
|
||||
def load_model(model_type, model_path):
|
||||
global model_dict
|
||||
model_key = f"{model_type}:{model_path}"
|
||||
if model_key in model_dict:
|
||||
return model_dict[model_key]
|
||||
model_path = os.path.join(config["model_config"][model_type]["model_folder"], model_path)
|
||||
model_manager = ModelManager()
|
||||
if model_type == "HunyuanDiT":
|
||||
model_manager.load_models([
|
||||
os.path.join(model_path, "clip_text_encoder/pytorch_model.bin"),
|
||||
os.path.join(model_path, "mt5/pytorch_model.bin"),
|
||||
os.path.join(model_path, "model/pytorch_model_ema.pt"),
|
||||
os.path.join(model_path, "sdxl-vae-fp16-fix/diffusion_pytorch_model.bin"),
|
||||
])
|
||||
elif model_type == "Kolors":
|
||||
model_manager.load_models([
|
||||
os.path.join(model_path, "text_encoder"),
|
||||
os.path.join(model_path, "unet/diffusion_pytorch_model.safetensors"),
|
||||
os.path.join(model_path, "vae/diffusion_pytorch_model.safetensors"),
|
||||
])
|
||||
elif model_type == "FLUX":
|
||||
model_manager.torch_dtype = torch.bfloat16
|
||||
file_list = [
|
||||
os.path.join(model_path, "text_encoder/model.safetensors"),
|
||||
os.path.join(model_path, "text_encoder_2"),
|
||||
]
|
||||
for file_name in os.listdir(model_path):
|
||||
if file_name.endswith(".safetensors"):
|
||||
file_list.append(os.path.join(model_path, file_name))
|
||||
model_manager.load_models(file_list)
|
||||
else:
|
||||
model_manager.load_model(model_path)
|
||||
pipe = config["model_config"][model_type]["pipeline_class"].from_model_manager(model_manager)
|
||||
while len(model_dict) + 1 > config["max_num_model_cache"]:
|
||||
key = next(iter(model_dict.keys()))
|
||||
model_manager_to_release, _ = model_dict[key]
|
||||
model_manager_to_release.to("cpu")
|
||||
del model_dict[key]
|
||||
torch.cuda.empty_cache()
|
||||
model_dict[model_key] = model_manager, pipe
|
||||
return model_manager, pipe
|
||||
|
||||
|
||||
model_dict = {}
|
||||
|
||||
with gr.Blocks() as app:
|
||||
gr.Markdown("# DiffSynth-Studio Painter")
|
||||
with gr.Row():
|
||||
with gr.Column(scale=382, min_width=100):
|
||||
|
||||
with gr.Accordion(label="Model"):
|
||||
model_type = gr.Dropdown(choices=[i for i in config["model_config"]], label="Model type")
|
||||
model_path = gr.Dropdown(choices=[], interactive=True, label="Model path")
|
||||
|
||||
@gr.on(inputs=model_type, outputs=model_path, triggers=model_type.change)
|
||||
def model_type_to_model_path(model_type):
|
||||
return gr.Dropdown(choices=load_model_list(model_type))
|
||||
|
||||
with gr.Accordion(label="Prompt"):
|
||||
prompt = gr.Textbox(label="Prompt", lines=3)
|
||||
negative_prompt = gr.Textbox(label="Negative prompt", lines=1)
|
||||
cfg_scale = gr.Slider(minimum=1.0, maximum=10.0, value=7.0, step=0.1, interactive=True, label="Classifier-free guidance scale")
|
||||
embedded_guidance = gr.Slider(minimum=0.0, maximum=10.0, value=0.0, step=0.1, interactive=True, label="Embedded guidance scale (only for FLUX)")
|
||||
|
||||
with gr.Accordion(label="Image"):
|
||||
num_inference_steps = gr.Slider(minimum=1, maximum=100, value=20, step=1, interactive=True, label="Inference steps")
|
||||
height = gr.Slider(minimum=64, maximum=2048, value=1024, step=64, interactive=True, label="Height")
|
||||
width = gr.Slider(minimum=64, maximum=2048, value=1024, step=64, interactive=True, label="Width")
|
||||
with gr.Column():
|
||||
use_fixed_seed = gr.Checkbox(value=True, interactive=False, label="Use fixed seed")
|
||||
seed = gr.Number(minimum=0, maximum=10**9, value=0, interactive=True, label="Random seed", show_label=False)
|
||||
|
||||
@gr.on(
|
||||
inputs=[model_type, model_path, prompt, negative_prompt, cfg_scale, embedded_guidance, num_inference_steps, height, width],
|
||||
outputs=[prompt, negative_prompt, cfg_scale, embedded_guidance, num_inference_steps, height, width],
|
||||
triggers=model_path.change
|
||||
)
|
||||
def model_path_to_default_params(model_type, model_path, prompt, negative_prompt, cfg_scale, embedded_guidance, num_inference_steps, height, width):
|
||||
load_model(model_type, model_path)
|
||||
cfg_scale = config["model_config"][model_type]["default_parameters"].get("cfg_scale", cfg_scale)
|
||||
embedded_guidance = config["model_config"][model_type]["default_parameters"].get("embedded_guidance", embedded_guidance)
|
||||
num_inference_steps = config["model_config"][model_type]["default_parameters"].get("num_inference_steps", num_inference_steps)
|
||||
height = config["model_config"][model_type]["default_parameters"].get("height", height)
|
||||
width = config["model_config"][model_type]["default_parameters"].get("width", width)
|
||||
return prompt, negative_prompt, cfg_scale, embedded_guidance, num_inference_steps, height, width
|
||||
|
||||
|
||||
with gr.Column(scale=618, min_width=100):
|
||||
with gr.Accordion(label="Painter"):
|
||||
enable_local_prompt_list = []
|
||||
local_prompt_list = []
|
||||
mask_scale_list = []
|
||||
canvas_list = []
|
||||
for painter_layer_id in range(config["max_num_painter_layers"]):
|
||||
with gr.Tab(label=f"Layer {painter_layer_id}"):
|
||||
enable_local_prompt = gr.Checkbox(label="Enable", value=False, key=f"enable_local_prompt_{painter_layer_id}")
|
||||
local_prompt = gr.Textbox(label="Local prompt", key=f"local_prompt_{painter_layer_id}")
|
||||
mask_scale = gr.Slider(minimum=0.0, maximum=5.0, value=1.0, step=0.1, interactive=True, label="Mask scale", key=f"mask_scale_{painter_layer_id}")
|
||||
canvas = gr.ImageEditor(canvas_size=(512, 1), sources=None, layers=False, interactive=True, image_mode="RGBA",
|
||||
brush=gr.Brush(default_size=100, default_color="#000000", colors=["#000000"]),
|
||||
label="Painter", key=f"canvas_{painter_layer_id}")
|
||||
@gr.on(inputs=[height, width, canvas], outputs=canvas, triggers=[height.change, width.change, canvas.clear, enable_local_prompt.change], show_progress="hidden")
|
||||
def resize_canvas(height, width, canvas):
|
||||
h, w = canvas["background"].shape[:2]
|
||||
if h != height or width != w:
|
||||
return np.ones((height, width, 3), dtype=np.uint8) * 255
|
||||
else:
|
||||
return canvas
|
||||
|
||||
enable_local_prompt_list.append(enable_local_prompt)
|
||||
local_prompt_list.append(local_prompt)
|
||||
mask_scale_list.append(mask_scale)
|
||||
canvas_list.append(canvas)
|
||||
with gr.Accordion(label="Results"):
|
||||
run_button = gr.Button(value="Generate", variant="primary")
|
||||
output_image = gr.Image(sources=None, show_label=False, interactive=False, type="pil")
|
||||
with gr.Row():
|
||||
with gr.Column():
|
||||
output_to_painter_button = gr.Button(value="Set as painter's background")
|
||||
with gr.Column():
|
||||
output_to_input_button = gr.Button(value="Set as input image")
|
||||
painter_background = gr.State(None)
|
||||
input_background = gr.State(None)
|
||||
@gr.on(
|
||||
inputs=[model_type, model_path, prompt, negative_prompt, cfg_scale, embedded_guidance, num_inference_steps, height, width, seed] + enable_local_prompt_list + local_prompt_list + mask_scale_list + canvas_list,
|
||||
outputs=[output_image],
|
||||
triggers=run_button.click
|
||||
)
|
||||
def generate_image(model_type, model_path, prompt, negative_prompt, cfg_scale, embedded_guidance, num_inference_steps, height, width, seed, *args, progress=gr.Progress()):
|
||||
_, pipe = load_model(model_type, model_path)
|
||||
input_params = {
|
||||
"prompt": prompt,
|
||||
"negative_prompt": negative_prompt,
|
||||
"cfg_scale": cfg_scale,
|
||||
"num_inference_steps": num_inference_steps,
|
||||
"height": height,
|
||||
"width": width,
|
||||
"progress_bar_cmd": progress.tqdm,
|
||||
}
|
||||
if isinstance(pipe, FluxImagePipeline):
|
||||
input_params["embedded_guidance"] = embedded_guidance
|
||||
enable_local_prompt_list, local_prompt_list, mask_scale_list, canvas_list = (
|
||||
args[0 * config["max_num_painter_layers"]: 1 * config["max_num_painter_layers"]],
|
||||
args[1 * config["max_num_painter_layers"]: 2 * config["max_num_painter_layers"]],
|
||||
args[2 * config["max_num_painter_layers"]: 3 * config["max_num_painter_layers"]],
|
||||
args[3 * config["max_num_painter_layers"]: 4 * config["max_num_painter_layers"]]
|
||||
)
|
||||
local_prompts, masks, mask_scales = [], [], []
|
||||
for enable_local_prompt, local_prompt, mask_scale, canvas in zip(
|
||||
enable_local_prompt_list, local_prompt_list, mask_scale_list, canvas_list
|
||||
):
|
||||
if enable_local_prompt:
|
||||
local_prompts.append(local_prompt)
|
||||
masks.append(Image.fromarray(canvas["layers"][0][:, :, -1]).convert("RGB"))
|
||||
mask_scales.append(mask_scale)
|
||||
input_params.update({
|
||||
"local_prompts": local_prompts,
|
||||
"masks": masks,
|
||||
"mask_scales": mask_scales,
|
||||
})
|
||||
torch.manual_seed(seed)
|
||||
image = pipe(**input_params)
|
||||
return image
|
||||
|
||||
@gr.on(inputs=[output_image] + canvas_list, outputs=canvas_list, triggers=output_to_painter_button.click)
|
||||
def send_output_to_painter_background(output_image, *canvas_list):
|
||||
for canvas in canvas_list:
|
||||
h, w = canvas["background"].shape[:2]
|
||||
canvas["background"] = output_image.resize((w, h))
|
||||
return tuple(canvas_list)
|
||||
app.launch()
|
||||
15
apps/streamlit/DiffSynth_Studio.py
Normal file
15
apps/streamlit/DiffSynth_Studio.py
Normal file
@@ -0,0 +1,15 @@
|
||||
# Set web page format
|
||||
import streamlit as st
|
||||
st.set_page_config(layout="wide")
|
||||
# Diasble virtual VRAM on windows system
|
||||
import torch
|
||||
torch.cuda.set_per_process_memory_fraction(0.999, 0)
|
||||
|
||||
|
||||
st.markdown("""
|
||||
# DiffSynth Studio
|
||||
|
||||
[Source Code](https://github.com/Artiprocher/DiffSynth-Studio)
|
||||
|
||||
Welcome to DiffSynth Studio.
|
||||
""")
|
||||
362
apps/streamlit/pages/1_Image_Creator.py
Normal file
362
apps/streamlit/pages/1_Image_Creator.py
Normal file
@@ -0,0 +1,362 @@
|
||||
import torch, os, io, json, time
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
import streamlit as st
|
||||
st.set_page_config(layout="wide")
|
||||
from streamlit_drawable_canvas import st_canvas
|
||||
from diffsynth.models import ModelManager
|
||||
from diffsynth.pipelines import SDImagePipeline, SDXLImagePipeline, SD3ImagePipeline, HunyuanDiTImagePipeline, FluxImagePipeline
|
||||
from diffsynth.data.video import crop_and_resize
|
||||
|
||||
|
||||
config = {
|
||||
"Stable Diffusion": {
|
||||
"model_folder": "models/stable_diffusion",
|
||||
"pipeline_class": SDImagePipeline,
|
||||
"fixed_parameters": {}
|
||||
},
|
||||
"Stable Diffusion XL": {
|
||||
"model_folder": "models/stable_diffusion_xl",
|
||||
"pipeline_class": SDXLImagePipeline,
|
||||
"fixed_parameters": {}
|
||||
},
|
||||
"Stable Diffusion 3": {
|
||||
"model_folder": "models/stable_diffusion_3",
|
||||
"pipeline_class": SD3ImagePipeline,
|
||||
"fixed_parameters": {}
|
||||
},
|
||||
"Stable Diffusion XL Turbo": {
|
||||
"model_folder": "models/stable_diffusion_xl_turbo",
|
||||
"pipeline_class": SDXLImagePipeline,
|
||||
"fixed_parameters": {
|
||||
"negative_prompt": "",
|
||||
"cfg_scale": 1.0,
|
||||
"num_inference_steps": 1,
|
||||
"height": 512,
|
||||
"width": 512,
|
||||
}
|
||||
},
|
||||
"Kolors": {
|
||||
"model_folder": "models/kolors",
|
||||
"pipeline_class": SDXLImagePipeline,
|
||||
"fixed_parameters": {}
|
||||
},
|
||||
"HunyuanDiT": {
|
||||
"model_folder": "models/HunyuanDiT",
|
||||
"pipeline_class": HunyuanDiTImagePipeline,
|
||||
"fixed_parameters": {
|
||||
"height": 1024,
|
||||
"width": 1024,
|
||||
}
|
||||
},
|
||||
"FLUX": {
|
||||
"model_folder": "models/FLUX",
|
||||
"pipeline_class": FluxImagePipeline,
|
||||
"fixed_parameters": {
|
||||
"cfg_scale": 1.0,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
def load_model_list(model_type):
|
||||
folder = config[model_type]["model_folder"]
|
||||
file_list = [i for i in os.listdir(folder) if i.endswith(".safetensors")]
|
||||
if model_type in ["HunyuanDiT", "Kolors", "FLUX"]:
|
||||
file_list += [i for i in os.listdir(folder) if os.path.isdir(os.path.join(folder, i))]
|
||||
file_list = sorted(file_list)
|
||||
return file_list
|
||||
|
||||
|
||||
def release_model():
|
||||
if "model_manager" in st.session_state:
|
||||
st.session_state["model_manager"].to("cpu")
|
||||
del st.session_state["loaded_model_path"]
|
||||
del st.session_state["model_manager"]
|
||||
del st.session_state["pipeline"]
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
|
||||
def load_model(model_type, model_path):
|
||||
model_manager = ModelManager()
|
||||
if model_type == "HunyuanDiT":
|
||||
model_manager.load_models([
|
||||
os.path.join(model_path, "clip_text_encoder/pytorch_model.bin"),
|
||||
os.path.join(model_path, "mt5/pytorch_model.bin"),
|
||||
os.path.join(model_path, "model/pytorch_model_ema.pt"),
|
||||
os.path.join(model_path, "sdxl-vae-fp16-fix/diffusion_pytorch_model.bin"),
|
||||
])
|
||||
elif model_type == "Kolors":
|
||||
model_manager.load_models([
|
||||
os.path.join(model_path, "text_encoder"),
|
||||
os.path.join(model_path, "unet/diffusion_pytorch_model.safetensors"),
|
||||
os.path.join(model_path, "vae/diffusion_pytorch_model.safetensors"),
|
||||
])
|
||||
elif model_type == "FLUX":
|
||||
model_manager.torch_dtype = torch.bfloat16
|
||||
file_list = [
|
||||
os.path.join(model_path, "text_encoder/model.safetensors"),
|
||||
os.path.join(model_path, "text_encoder_2"),
|
||||
]
|
||||
for file_name in os.listdir(model_path):
|
||||
if file_name.endswith(".safetensors"):
|
||||
file_list.append(os.path.join(model_path, file_name))
|
||||
model_manager.load_models(file_list)
|
||||
else:
|
||||
model_manager.load_model(model_path)
|
||||
pipeline = config[model_type]["pipeline_class"].from_model_manager(model_manager)
|
||||
st.session_state.loaded_model_path = model_path
|
||||
st.session_state.model_manager = model_manager
|
||||
st.session_state.pipeline = pipeline
|
||||
return model_manager, pipeline
|
||||
|
||||
|
||||
def use_output_image_as_input(update=True):
|
||||
# Search for input image
|
||||
output_image_id = 0
|
||||
selected_output_image = None
|
||||
while True:
|
||||
if f"use_output_as_input_{output_image_id}" not in st.session_state:
|
||||
break
|
||||
if st.session_state[f"use_output_as_input_{output_image_id}"]:
|
||||
selected_output_image = st.session_state["output_images"][output_image_id]
|
||||
break
|
||||
output_image_id += 1
|
||||
if update and selected_output_image is not None:
|
||||
st.session_state["input_image"] = selected_output_image
|
||||
return selected_output_image is not None
|
||||
|
||||
|
||||
def apply_stroke_to_image(stroke_image, image):
|
||||
image = np.array(image.convert("RGB")).astype(np.float32)
|
||||
height, width, _ = image.shape
|
||||
|
||||
stroke_image = np.array(Image.fromarray(stroke_image).resize((width, height))).astype(np.float32)
|
||||
weight = stroke_image[:, :, -1:] / 255
|
||||
stroke_image = stroke_image[:, :, :-1]
|
||||
|
||||
image = stroke_image * weight + image * (1 - weight)
|
||||
image = np.clip(image, 0, 255).astype(np.uint8)
|
||||
image = Image.fromarray(image)
|
||||
return image
|
||||
|
||||
|
||||
@st.cache_data
|
||||
def image2bits(image):
|
||||
image_byte = io.BytesIO()
|
||||
image.save(image_byte, format="PNG")
|
||||
image_byte = image_byte.getvalue()
|
||||
return image_byte
|
||||
|
||||
|
||||
def show_output_image(image):
|
||||
st.image(image, use_column_width="always")
|
||||
st.button("Use it as input image", key=f"use_output_as_input_{image_id}")
|
||||
st.download_button("Download", data=image2bits(image), file_name="image.png", mime="image/png", key=f"download_output_{image_id}")
|
||||
|
||||
|
||||
column_input, column_output = st.columns(2)
|
||||
with st.sidebar:
|
||||
# Select a model
|
||||
with st.expander("Model", expanded=True):
|
||||
model_type = st.selectbox("Model type", [model_type_ for model_type_ in config])
|
||||
fixed_parameters = config[model_type]["fixed_parameters"]
|
||||
model_path_list = ["None"] + load_model_list(model_type)
|
||||
model_path = st.selectbox("Model path", model_path_list)
|
||||
|
||||
# Load the model
|
||||
if model_path == "None":
|
||||
# No models are selected. Release VRAM.
|
||||
st.markdown("No models are selected.")
|
||||
release_model()
|
||||
else:
|
||||
# A model is selected.
|
||||
model_path = os.path.join(config[model_type]["model_folder"], model_path)
|
||||
if st.session_state.get("loaded_model_path", "") != model_path:
|
||||
# The loaded model is not the selected model. Reload it.
|
||||
st.markdown(f"Loading model at {model_path}.")
|
||||
st.markdown("Please wait a moment...")
|
||||
release_model()
|
||||
model_manager, pipeline = load_model(model_type, model_path)
|
||||
st.markdown("Done.")
|
||||
else:
|
||||
# The loaded model is not the selected model. Fetch it from `st.session_state`.
|
||||
st.markdown(f"Loading model at {model_path}.")
|
||||
st.markdown("Please wait a moment...")
|
||||
model_manager, pipeline = st.session_state.model_manager, st.session_state.pipeline
|
||||
st.markdown("Done.")
|
||||
|
||||
# Show parameters
|
||||
with st.expander("Prompt", expanded=True):
|
||||
prompt = st.text_area("Positive prompt")
|
||||
if "negative_prompt" in fixed_parameters:
|
||||
negative_prompt = fixed_parameters["negative_prompt"]
|
||||
else:
|
||||
negative_prompt = st.text_area("Negative prompt")
|
||||
if "cfg_scale" in fixed_parameters:
|
||||
cfg_scale = fixed_parameters["cfg_scale"]
|
||||
else:
|
||||
cfg_scale = st.slider("Classifier-free guidance scale", min_value=1.0, max_value=10.0, value=7.5)
|
||||
with st.expander("Image", expanded=True):
|
||||
if "num_inference_steps" in fixed_parameters:
|
||||
num_inference_steps = fixed_parameters["num_inference_steps"]
|
||||
else:
|
||||
num_inference_steps = st.slider("Inference steps", min_value=1, max_value=100, value=20)
|
||||
if "height" in fixed_parameters:
|
||||
height = fixed_parameters["height"]
|
||||
else:
|
||||
height = st.select_slider("Height", options=[256, 512, 768, 1024, 2048], value=512)
|
||||
if "width" in fixed_parameters:
|
||||
width = fixed_parameters["width"]
|
||||
else:
|
||||
width = st.select_slider("Width", options=[256, 512, 768, 1024, 2048], value=512)
|
||||
num_images = st.number_input("Number of images", value=2)
|
||||
use_fixed_seed = st.checkbox("Use fixed seed", value=False)
|
||||
if use_fixed_seed:
|
||||
seed = st.number_input("Random seed", min_value=0, max_value=10**9, step=1, value=0)
|
||||
|
||||
# Other fixed parameters
|
||||
denoising_strength = 1.0
|
||||
repetition = 1
|
||||
|
||||
|
||||
# Show input image
|
||||
with column_input:
|
||||
with st.expander("Input image (Optional)", expanded=True):
|
||||
with st.container(border=True):
|
||||
column_white_board, column_upload_image = st.columns([1, 2])
|
||||
with column_white_board:
|
||||
create_white_board = st.button("Create white board")
|
||||
delete_input_image = st.button("Delete input image")
|
||||
with column_upload_image:
|
||||
upload_image = st.file_uploader("Upload image", type=["png", "jpg"], key="upload_image")
|
||||
|
||||
if upload_image is not None:
|
||||
st.session_state["input_image"] = crop_and_resize(Image.open(upload_image), height, width)
|
||||
elif create_white_board:
|
||||
st.session_state["input_image"] = Image.fromarray(np.ones((height, width, 3), dtype=np.uint8) * 255)
|
||||
else:
|
||||
use_output_image_as_input()
|
||||
|
||||
if delete_input_image and "input_image" in st.session_state:
|
||||
del st.session_state.input_image
|
||||
if delete_input_image and "upload_image" in st.session_state:
|
||||
del st.session_state.upload_image
|
||||
|
||||
input_image = st.session_state.get("input_image", None)
|
||||
if input_image is not None:
|
||||
with st.container(border=True):
|
||||
column_drawing_mode, column_color_1, column_color_2 = st.columns([4, 1, 1])
|
||||
with column_drawing_mode:
|
||||
drawing_mode = st.radio("Drawing tool", ["transform", "freedraw", "line", "rect"], horizontal=True, index=1)
|
||||
with column_color_1:
|
||||
stroke_color = st.color_picker("Stroke color")
|
||||
with column_color_2:
|
||||
fill_color = st.color_picker("Fill color")
|
||||
stroke_width = st.slider("Stroke width", min_value=1, max_value=50, value=10)
|
||||
with st.container(border=True):
|
||||
denoising_strength = st.slider("Denoising strength", min_value=0.0, max_value=1.0, value=0.7)
|
||||
repetition = st.slider("Repetition", min_value=1, max_value=8, value=1)
|
||||
with st.container(border=True):
|
||||
input_width, input_height = input_image.size
|
||||
canvas_result = st_canvas(
|
||||
fill_color=fill_color,
|
||||
stroke_width=stroke_width,
|
||||
stroke_color=stroke_color,
|
||||
background_color="rgba(255, 255, 255, 0)",
|
||||
background_image=input_image,
|
||||
update_streamlit=True,
|
||||
height=int(512 / input_width * input_height),
|
||||
width=512,
|
||||
drawing_mode=drawing_mode,
|
||||
key="canvas"
|
||||
)
|
||||
|
||||
num_painter_layer = st.number_input("Number of painter layers", min_value=0, max_value=10, step=1, value=0)
|
||||
local_prompts, masks, mask_scales = [], [], []
|
||||
white_board = Image.fromarray(np.ones((512, 512, 3), dtype=np.uint8) * 255)
|
||||
painter_layers_json_data = []
|
||||
for painter_tab_id in range(num_painter_layer):
|
||||
with st.expander(f"Painter layer {painter_tab_id}", expanded=True):
|
||||
enable_local_prompt = st.checkbox(f"Enable prompt {painter_tab_id}", value=True)
|
||||
local_prompt = st.text_area(f"Prompt {painter_tab_id}")
|
||||
mask_scale = st.slider(f"Mask scale {painter_tab_id}", min_value=0.0, max_value=3.0, value=1.0)
|
||||
stroke_width = st.slider(f"Stroke width {painter_tab_id}", min_value=1, max_value=300, value=100)
|
||||
canvas_result_local = st_canvas(
|
||||
fill_color="#000000",
|
||||
stroke_width=stroke_width,
|
||||
stroke_color="#000000",
|
||||
background_color="rgba(255, 255, 255, 0)",
|
||||
background_image=white_board,
|
||||
update_streamlit=True,
|
||||
height=512,
|
||||
width=512,
|
||||
drawing_mode="freedraw",
|
||||
key=f"canvas_{painter_tab_id}"
|
||||
)
|
||||
if canvas_result_local.json_data is not None:
|
||||
painter_layers_json_data.append(canvas_result_local.json_data.copy())
|
||||
painter_layers_json_data[-1]["prompt"] = local_prompt
|
||||
if enable_local_prompt:
|
||||
local_prompts.append(local_prompt)
|
||||
if canvas_result_local.image_data is not None:
|
||||
mask = apply_stroke_to_image(canvas_result_local.image_data, white_board)
|
||||
else:
|
||||
mask = white_board
|
||||
mask = Image.fromarray(255 - np.array(mask))
|
||||
masks.append(mask)
|
||||
mask_scales.append(mask_scale)
|
||||
save_painter_layers = st.button("Save painter layers")
|
||||
if save_painter_layers:
|
||||
os.makedirs("data/painter_layers", exist_ok=True)
|
||||
json_file_path = f"data/painter_layers/{time.time_ns()}.json"
|
||||
with open(json_file_path, "w") as f:
|
||||
json.dump(painter_layers_json_data, f, indent=4)
|
||||
st.markdown(f"Painter layers are saved in {json_file_path}.")
|
||||
|
||||
|
||||
with column_output:
|
||||
run_button = st.button("Generate image", type="primary")
|
||||
auto_update = st.checkbox("Auto update", value=False)
|
||||
num_image_columns = st.slider("Columns", min_value=1, max_value=8, value=2)
|
||||
image_columns = st.columns(num_image_columns)
|
||||
|
||||
# Run
|
||||
if (run_button or auto_update) and model_path != "None":
|
||||
|
||||
if input_image is not None:
|
||||
input_image = input_image.resize((width, height))
|
||||
if canvas_result.image_data is not None:
|
||||
input_image = apply_stroke_to_image(canvas_result.image_data, input_image)
|
||||
|
||||
output_images = []
|
||||
for image_id in range(num_images * repetition):
|
||||
if use_fixed_seed:
|
||||
torch.manual_seed(seed + image_id)
|
||||
else:
|
||||
torch.manual_seed(np.random.randint(0, 10**9))
|
||||
if image_id >= num_images:
|
||||
input_image = output_images[image_id - num_images]
|
||||
with image_columns[image_id % num_image_columns]:
|
||||
progress_bar_st = st.progress(0.0)
|
||||
image = pipeline(
|
||||
prompt, negative_prompt=negative_prompt,
|
||||
local_prompts=local_prompts, masks=masks, mask_scales=mask_scales,
|
||||
cfg_scale=cfg_scale, num_inference_steps=num_inference_steps,
|
||||
height=height, width=width,
|
||||
input_image=input_image, denoising_strength=denoising_strength,
|
||||
progress_bar_st=progress_bar_st
|
||||
)
|
||||
output_images.append(image)
|
||||
progress_bar_st.progress(1.0)
|
||||
show_output_image(image)
|
||||
st.session_state["output_images"] = output_images
|
||||
|
||||
elif "output_images" in st.session_state:
|
||||
for image_id in range(len(st.session_state.output_images)):
|
||||
with image_columns[image_id % num_image_columns]:
|
||||
image = st.session_state.output_images[image_id]
|
||||
progress_bar = st.progress(1.0)
|
||||
show_output_image(image)
|
||||
if "upload_image" in st.session_state and use_output_image_as_input(update=False):
|
||||
st.markdown("If you want to use an output image as input image, please delete the uploaded image manually.")
|
||||
197
apps/streamlit/pages/2_Video_Creator.py
Normal file
197
apps/streamlit/pages/2_Video_Creator.py
Normal file
@@ -0,0 +1,197 @@
|
||||
import streamlit as st
|
||||
st.set_page_config(layout="wide")
|
||||
from diffsynth import SDVideoPipelineRunner
|
||||
import os
|
||||
import numpy as np
|
||||
|
||||
|
||||
def load_model_list(folder):
|
||||
file_list = os.listdir(folder)
|
||||
file_list = [i for i in file_list if i.endswith(".safetensors") or i.endswith(".pth") or i.endswith(".ckpt")]
|
||||
file_list = sorted(file_list)
|
||||
return file_list
|
||||
|
||||
|
||||
def match_processor_id(model_name, supported_processor_id_list):
|
||||
sorted_processor_id = [i[1] for i in sorted([(-len(i), i) for i in supported_processor_id_list])]
|
||||
for processor_id in sorted_processor_id:
|
||||
if processor_id in model_name:
|
||||
return supported_processor_id_list.index(processor_id) + 1
|
||||
return 0
|
||||
|
||||
|
||||
config = {
|
||||
"models": {
|
||||
"model_list": [],
|
||||
"textual_inversion_folder": "models/textual_inversion",
|
||||
"device": "cuda",
|
||||
"lora_alphas": [],
|
||||
"controlnet_units": []
|
||||
},
|
||||
"data": {
|
||||
"input_frames": None,
|
||||
"controlnet_frames": [],
|
||||
"output_folder": "output",
|
||||
"fps": 60
|
||||
},
|
||||
"pipeline": {
|
||||
"seed": 0,
|
||||
"pipeline_inputs": {}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
with st.expander("Model", expanded=True):
|
||||
stable_diffusion_ckpt = st.selectbox("Stable Diffusion", ["None"] + load_model_list("models/stable_diffusion"))
|
||||
if stable_diffusion_ckpt != "None":
|
||||
config["models"]["model_list"].append(os.path.join("models/stable_diffusion", stable_diffusion_ckpt))
|
||||
animatediff_ckpt = st.selectbox("AnimateDiff", ["None"] + load_model_list("models/AnimateDiff"))
|
||||
if animatediff_ckpt != "None":
|
||||
config["models"]["model_list"].append(os.path.join("models/AnimateDiff", animatediff_ckpt))
|
||||
column_lora, column_lora_alpha = st.columns([2, 1])
|
||||
with column_lora:
|
||||
sd_lora_ckpt = st.selectbox("LoRA", ["None"] + load_model_list("models/lora"))
|
||||
with column_lora_alpha:
|
||||
lora_alpha = st.slider("LoRA Alpha", min_value=-4.0, max_value=4.0, value=1.0, step=0.1)
|
||||
if sd_lora_ckpt != "None":
|
||||
config["models"]["model_list"].append(os.path.join("models/lora", sd_lora_ckpt))
|
||||
config["models"]["lora_alphas"].append(lora_alpha)
|
||||
|
||||
|
||||
with st.expander("Data", expanded=True):
|
||||
with st.container(border=True):
|
||||
input_video = st.text_input("Input Video File Path (e.g., data/your_video.mp4)", value="")
|
||||
column_height, column_width, column_start_frame_index, column_end_frame_index = st.columns([2, 2, 1, 1])
|
||||
with column_height:
|
||||
height = st.select_slider("Height", options=[256, 512, 768, 1024, 1536, 2048], value=1024)
|
||||
with column_width:
|
||||
width = st.select_slider("Width", options=[256, 512, 768, 1024, 1536, 2048], value=1024)
|
||||
with column_start_frame_index:
|
||||
start_frame_id = st.number_input("Start Frame id", value=0)
|
||||
with column_end_frame_index:
|
||||
end_frame_id = st.number_input("End Frame id", value=16)
|
||||
if input_video != "":
|
||||
config["data"]["input_frames"] = {
|
||||
"video_file": input_video,
|
||||
"image_folder": None,
|
||||
"height": height,
|
||||
"width": width,
|
||||
"start_frame_id": start_frame_id,
|
||||
"end_frame_id": end_frame_id
|
||||
}
|
||||
with st.container(border=True):
|
||||
output_video = st.text_input("Output Video File Path (e.g., data/a_folder_to_save_something)", value="output")
|
||||
fps = st.number_input("FPS", value=60)
|
||||
config["data"]["output_folder"] = output_video
|
||||
config["data"]["fps"] = fps
|
||||
|
||||
|
||||
with st.expander("ControlNet Units", expanded=True):
|
||||
supported_processor_id_list = ["canny", "depth", "softedge", "lineart", "lineart_anime", "openpose", "tile"]
|
||||
controlnet_units = st.tabs(["ControlNet Unit 0", "ControlNet Unit 1", "ControlNet Unit 2"])
|
||||
for controlnet_id in range(len(controlnet_units)):
|
||||
with controlnet_units[controlnet_id]:
|
||||
controlnet_ckpt = st.selectbox("ControlNet", ["None"] + load_model_list("models/ControlNet"),
|
||||
key=f"controlnet_ckpt_{controlnet_id}")
|
||||
processor_id = st.selectbox("Processor", ["None"] + supported_processor_id_list,
|
||||
index=match_processor_id(controlnet_ckpt, supported_processor_id_list),
|
||||
disabled=controlnet_ckpt == "None", key=f"processor_id_{controlnet_id}")
|
||||
controlnet_scale = st.slider("Scale", min_value=0.0, max_value=1.0, step=0.01, value=0.5,
|
||||
disabled=controlnet_ckpt == "None", key=f"controlnet_scale_{controlnet_id}")
|
||||
use_input_video_as_controlnet_input = st.checkbox("Use input video as ControlNet input", value=True,
|
||||
disabled=controlnet_ckpt == "None",
|
||||
key=f"use_input_video_as_controlnet_input_{controlnet_id}")
|
||||
if not use_input_video_as_controlnet_input:
|
||||
controlnet_input_video = st.text_input("ControlNet Input Video File Path", value="",
|
||||
disabled=controlnet_ckpt == "None", key=f"controlnet_input_video_{controlnet_id}")
|
||||
column_height, column_width, column_start_frame_index, column_end_frame_index = st.columns([2, 2, 1, 1])
|
||||
with column_height:
|
||||
height = st.select_slider("Height", options=[256, 512, 768, 1024, 1536, 2048], value=1024,
|
||||
disabled=controlnet_ckpt == "None", key=f"controlnet_height_{controlnet_id}")
|
||||
with column_width:
|
||||
width = st.select_slider("Width", options=[256, 512, 768, 1024, 1536, 2048], value=1024,
|
||||
disabled=controlnet_ckpt == "None", key=f"controlnet_width_{controlnet_id}")
|
||||
with column_start_frame_index:
|
||||
start_frame_id = st.number_input("Start Frame id", value=0,
|
||||
disabled=controlnet_ckpt == "None", key=f"controlnet_start_frame_id_{controlnet_id}")
|
||||
with column_end_frame_index:
|
||||
end_frame_id = st.number_input("End Frame id", value=16,
|
||||
disabled=controlnet_ckpt == "None", key=f"controlnet_end_frame_id_{controlnet_id}")
|
||||
if input_video != "":
|
||||
config["data"]["input_video"] = {
|
||||
"video_file": input_video,
|
||||
"image_folder": None,
|
||||
"height": height,
|
||||
"width": width,
|
||||
"start_frame_id": start_frame_id,
|
||||
"end_frame_id": end_frame_id
|
||||
}
|
||||
if controlnet_ckpt != "None":
|
||||
config["models"]["model_list"].append(os.path.join("models/ControlNet", controlnet_ckpt))
|
||||
config["models"]["controlnet_units"].append({
|
||||
"processor_id": processor_id,
|
||||
"model_path": os.path.join("models/ControlNet", controlnet_ckpt),
|
||||
"scale": controlnet_scale,
|
||||
})
|
||||
if use_input_video_as_controlnet_input:
|
||||
config["data"]["controlnet_frames"].append(config["data"]["input_frames"])
|
||||
else:
|
||||
config["data"]["controlnet_frames"].append({
|
||||
"video_file": input_video,
|
||||
"image_folder": None,
|
||||
"height": height,
|
||||
"width": width,
|
||||
"start_frame_id": start_frame_id,
|
||||
"end_frame_id": end_frame_id
|
||||
})
|
||||
|
||||
|
||||
with st.container(border=True):
|
||||
with st.expander("Seed", expanded=True):
|
||||
use_fixed_seed = st.checkbox("Use fixed seed", value=False)
|
||||
if use_fixed_seed:
|
||||
seed = st.number_input("Random seed", min_value=0, max_value=10**9, step=1, value=0)
|
||||
else:
|
||||
seed = np.random.randint(0, 10**9)
|
||||
with st.expander("Textual Guidance", expanded=True):
|
||||
prompt = st.text_area("Positive prompt")
|
||||
negative_prompt = st.text_area("Negative prompt")
|
||||
column_cfg_scale, column_clip_skip = st.columns(2)
|
||||
with column_cfg_scale:
|
||||
cfg_scale = st.slider("Classifier-free guidance scale", min_value=1.0, max_value=10.0, value=7.0)
|
||||
with column_clip_skip:
|
||||
clip_skip = st.slider("Clip Skip", min_value=1, max_value=4, value=1)
|
||||
with st.expander("Denoising", expanded=True):
|
||||
column_num_inference_steps, column_denoising_strength = st.columns(2)
|
||||
with column_num_inference_steps:
|
||||
num_inference_steps = st.slider("Inference steps", min_value=1, max_value=100, value=10)
|
||||
with column_denoising_strength:
|
||||
denoising_strength = st.slider("Denoising strength", min_value=0.0, max_value=1.0, value=1.0)
|
||||
with st.expander("Efficiency", expanded=False):
|
||||
animatediff_batch_size = st.slider("Animatediff batch size (sliding window size)", min_value=1, max_value=32, value=16, step=1)
|
||||
animatediff_stride = st.slider("Animatediff stride",
|
||||
min_value=1,
|
||||
max_value=max(2, animatediff_batch_size),
|
||||
value=max(1, animatediff_batch_size // 2),
|
||||
step=1)
|
||||
unet_batch_size = st.slider("UNet batch size", min_value=1, max_value=32, value=1, step=1)
|
||||
controlnet_batch_size = st.slider("ControlNet batch size", min_value=1, max_value=32, value=1, step=1)
|
||||
cross_frame_attention = st.checkbox("Enable Cross-Frame Attention", value=False)
|
||||
config["pipeline"]["seed"] = seed
|
||||
config["pipeline"]["pipeline_inputs"] = {
|
||||
"prompt": prompt,
|
||||
"negative_prompt": negative_prompt,
|
||||
"cfg_scale": cfg_scale,
|
||||
"clip_skip": clip_skip,
|
||||
"denoising_strength": denoising_strength,
|
||||
"num_inference_steps": num_inference_steps,
|
||||
"animatediff_batch_size": animatediff_batch_size,
|
||||
"animatediff_stride": animatediff_stride,
|
||||
"unet_batch_size": unet_batch_size,
|
||||
"controlnet_batch_size": controlnet_batch_size,
|
||||
"cross_frame_attention": cross_frame_attention,
|
||||
}
|
||||
|
||||
run_button = st.button("☢️Run☢️", type="primary")
|
||||
if run_button:
|
||||
SDVideoPipelineRunner(in_streamlit=True).run(config)
|
||||
@@ -1 +1,6 @@
|
||||
from .core import *
|
||||
from .data import *
|
||||
from .models import *
|
||||
from .prompters import *
|
||||
from .schedulers import *
|
||||
from .pipelines import *
|
||||
from .controlnets import *
|
||||
|
||||
@@ -1,2 +0,0 @@
|
||||
from .model_configs import MODEL_CONFIGS
|
||||
from .vram_management_module_maps import VRAM_MANAGEMENT_MODULE_MAPS, VERSION_CHECKER_MAPS
|
||||
|
||||
358
diffsynth/configs/model_config.py
Normal file
358
diffsynth/configs/model_config.py
Normal file
@@ -0,0 +1,358 @@
|
||||
from typing_extensions import Literal, TypeAlias
|
||||
|
||||
from ..models.sd_text_encoder import SDTextEncoder
|
||||
from ..models.sd_unet import SDUNet
|
||||
from ..models.sd_vae_encoder import SDVAEEncoder
|
||||
from ..models.sd_vae_decoder import SDVAEDecoder
|
||||
|
||||
from ..models.sdxl_text_encoder import SDXLTextEncoder, SDXLTextEncoder2
|
||||
from ..models.sdxl_unet import SDXLUNet
|
||||
from ..models.sdxl_vae_decoder import SDXLVAEDecoder
|
||||
from ..models.sdxl_vae_encoder import SDXLVAEEncoder
|
||||
|
||||
from ..models.sd3_text_encoder import SD3TextEncoder1, SD3TextEncoder2, SD3TextEncoder3
|
||||
from ..models.sd3_dit import SD3DiT
|
||||
from ..models.sd3_vae_decoder import SD3VAEDecoder
|
||||
from ..models.sd3_vae_encoder import SD3VAEEncoder
|
||||
|
||||
from ..models.sd_controlnet import SDControlNet
|
||||
from ..models.sdxl_controlnet import SDXLControlNetUnion
|
||||
|
||||
from ..models.sd_motion import SDMotionModel
|
||||
from ..models.sdxl_motion import SDXLMotionModel
|
||||
|
||||
from ..models.svd_image_encoder import SVDImageEncoder
|
||||
from ..models.svd_unet import SVDUNet
|
||||
from ..models.svd_vae_decoder import SVDVAEDecoder
|
||||
from ..models.svd_vae_encoder import SVDVAEEncoder
|
||||
|
||||
from ..models.sd_ipadapter import SDIpAdapter, IpAdapterCLIPImageEmbedder
|
||||
from ..models.sdxl_ipadapter import SDXLIpAdapter, IpAdapterXLCLIPImageEmbedder
|
||||
|
||||
from ..models.hunyuan_dit_text_encoder import HunyuanDiTCLIPTextEncoder, HunyuanDiTT5TextEncoder
|
||||
from ..models.hunyuan_dit import HunyuanDiT
|
||||
|
||||
from ..models.flux_dit import FluxDiT
|
||||
from ..models.flux_text_encoder import FluxTextEncoder1, FluxTextEncoder2
|
||||
from ..models.flux_vae import FluxVAEEncoder, FluxVAEDecoder
|
||||
|
||||
from ..models.cog_vae import CogVAEEncoder, CogVAEDecoder
|
||||
from ..models.cog_dit import CogDiT
|
||||
|
||||
from ..extensions.RIFE import IFNet
|
||||
from ..extensions.ESRGAN import RRDBNet
|
||||
|
||||
|
||||
|
||||
model_loader_configs = [
|
||||
# These configs are provided for detecting model type automatically.
|
||||
# The format is (state_dict_keys_hash, state_dict_keys_hash_with_shape, model_names, model_classes, model_resource)
|
||||
(None, "091b0e30e77c76626b3ba62acdf95343", ["sd_controlnet"], [SDControlNet], "civitai"),
|
||||
(None, "4a6c8306a27d916dea81263c8c88f450", ["hunyuan_dit_clip_text_encoder"], [HunyuanDiTCLIPTextEncoder], "civitai"),
|
||||
(None, "f4aec400fe394297961218c768004521", ["hunyuan_dit"], [HunyuanDiT], "civitai"),
|
||||
(None, "9e6e58043a5a2e332803ed42f6ee7181", ["hunyuan_dit_t5_text_encoder"], [HunyuanDiTT5TextEncoder], "civitai"),
|
||||
(None, "13115dd45a6e1c39860f91ab073b8a78", ["sdxl_vae_encoder", "sdxl_vae_decoder"], [SDXLVAEEncoder, SDXLVAEDecoder], "diffusers"),
|
||||
(None, "d78aa6797382a6d455362358a3295ea9", ["sd_ipadapter_clip_image_encoder"], [IpAdapterCLIPImageEmbedder], "diffusers"),
|
||||
(None, "e291636cc15e803186b47404262ef812", ["sd_ipadapter"], [SDIpAdapter], "civitai"),
|
||||
(None, "399c81f2f8de8d1843d0127a00f3c224", ["sdxl_ipadapter_clip_image_encoder"], [IpAdapterXLCLIPImageEmbedder], "diffusers"),
|
||||
(None, "a64eac9aa0db4b9602213bc0131281c7", ["sdxl_ipadapter"], [SDXLIpAdapter], "civitai"),
|
||||
(None, "52817e4fdd89df154f02749ca6f692ac", ["sdxl_unet"], [SDXLUNet], "diffusers"),
|
||||
(None, "03343c606f16d834d6411d0902b53636", ["sd_text_encoder", "sd_unet", "sd_vae_decoder", "sd_vae_encoder"], [SDTextEncoder, SDUNet, SDVAEDecoder, SDVAEEncoder], "civitai"),
|
||||
(None, "d4ba77a7ece070679b4a987f58f201e9", ["sd_text_encoder"], [SDTextEncoder], "civitai"),
|
||||
(None, "d0c89e55c5a57cf3981def0cb1c9e65a", ["sd_vae_decoder", "sd_vae_encoder"], [SDVAEDecoder, SDVAEEncoder], "civitai"),
|
||||
(None, "3926bf373b39a67eeafd7901478a47a7", ["sd_unet"], [SDUNet], "civitai"),
|
||||
(None, "1e0c39ec176b9007c05f76d52b554a4d", ["sd3_text_encoder_1", "sd3_text_encoder_2", "sd3_dit", "sd3_vae_encoder", "sd3_vae_decoder"], [SD3TextEncoder1, SD3TextEncoder2, SD3DiT, SD3VAEEncoder, SD3VAEDecoder], "civitai"),
|
||||
(None, "d9e0290829ba8d98e28e1a2b1407db4a", ["sd3_text_encoder_1", "sd3_text_encoder_2", "sd3_text_encoder_3", "sd3_dit", "sd3_vae_encoder", "sd3_vae_decoder"], [SD3TextEncoder1, SD3TextEncoder2, SD3TextEncoder3, SD3DiT, SD3VAEEncoder, SD3VAEDecoder], "civitai"),
|
||||
(None, "5072d0b24e406b49507abe861cf97691", ["sd3_text_encoder_3"], [SD3TextEncoder3], "civitai"),
|
||||
(None, "4cf64a799d04260df438c6f33c9a047e", ["sdxl_text_encoder", "sdxl_text_encoder_2", "sdxl_unet", "sdxl_vae_decoder", "sdxl_vae_encoder"], [SDXLTextEncoder, SDXLTextEncoder2, SDXLUNet, SDXLVAEDecoder, SDXLVAEEncoder], "civitai"),
|
||||
(None, "d9b008a867c498ab12ad24042eff8e3f", ["sdxl_text_encoder", "sdxl_text_encoder_2", "sdxl_unet", "sdxl_vae_decoder", "sdxl_vae_encoder"], [SDXLTextEncoder, SDXLTextEncoder2, SDXLUNet, SDXLVAEDecoder, SDXLVAEEncoder], "civitai"), # SDXL-Turbo
|
||||
(None, "025bb7452e531a3853d951d77c63f032", ["sdxl_text_encoder", "sdxl_text_encoder_2"], [SDXLTextEncoder, SDXLTextEncoder2], "civitai"),
|
||||
(None, "298997b403a4245c04102c9f36aac348", ["sdxl_unet"], [SDXLUNet], "civitai"),
|
||||
(None, "2a07abce74b4bdc696b76254ab474da6", ["svd_image_encoder", "svd_unet", "svd_vae_decoder", "svd_vae_encoder"], [SVDImageEncoder, SVDUNet, SVDVAEDecoder, SVDVAEEncoder], "civitai"),
|
||||
(None, "c96a285a6888465f87de22a984d049fb", ["sd_motion_modules"], [SDMotionModel], "civitai"),
|
||||
(None, "72907b92caed19bdb2adb89aa4063fe2", ["sdxl_motion_modules"], [SDXLMotionModel], "civitai"),
|
||||
(None, "31d2d9614fba60511fc9bf2604aa01f7", ["sdxl_controlnet"], [SDXLControlNetUnion], "diffusers"),
|
||||
(None, "94eefa3dac9cec93cb1ebaf1747d7b78", ["flux_text_encoder_1"], [FluxTextEncoder1], "diffusers"),
|
||||
(None, "1aafa3cc91716fb6b300cc1cd51b85a3", ["flux_vae_encoder", "flux_vae_decoder"], [FluxVAEEncoder, FluxVAEDecoder], "diffusers"),
|
||||
(None, "21ea55f476dfc4fd135587abb59dfe5d", ["flux_vae_encoder", "flux_vae_decoder"], [FluxVAEEncoder, FluxVAEDecoder], "civitai"),
|
||||
(None, "a29710fea6dddb0314663ee823598e50", ["flux_dit"], [FluxDiT], "civitai"),
|
||||
(None, "280189ee084bca10f70907bf6ce1649d", ["cog_vae_encoder", "cog_vae_decoder"], [CogVAEEncoder, CogVAEDecoder], "diffusers"),
|
||||
(None, "9b9313d104ac4df27991352fec013fd4", ["rife"], [IFNet], "civitai"),
|
||||
(None, "6b7116078c4170bfbeaedc8fe71f6649", ["esrgan"], [RRDBNet], "civitai"),
|
||||
]
|
||||
huggingface_model_loader_configs = [
|
||||
# These configs are provided for detecting model type automatically.
|
||||
# The format is (architecture_in_huggingface_config, huggingface_lib, model_name, redirected_architecture)
|
||||
("ChatGLMModel", "diffsynth.models.kolors_text_encoder", "kolors_text_encoder", None),
|
||||
("MarianMTModel", "transformers.models.marian.modeling_marian", "translator", None),
|
||||
("BloomForCausalLM", "transformers.models.bloom.modeling_bloom", "beautiful_prompt", None),
|
||||
("Qwen2ForCausalLM", "transformers.models.qwen2.modeling_qwen2", "qwen_prompt", None),
|
||||
("LlamaForCausalLM", "transformers.models.llama.modeling_llama", "omost_prompt", None),
|
||||
("T5EncoderModel", "diffsynth.models.flux_text_encoder", "flux_text_encoder_2", "FluxTextEncoder2"),
|
||||
("CogVideoXTransformer3DModel", "diffsynth.models.cog_dit", "cog_dit", "CogDiT"),
|
||||
]
|
||||
patch_model_loader_configs = [
|
||||
# These configs are provided for detecting model type automatically.
|
||||
# The format is (state_dict_keys_hash_with_shape, model_name, model_class, extra_kwargs)
|
||||
("9a4ab6869ac9b7d6e31f9854e397c867", ["svd_unet"], [SVDUNet], {"add_positional_conv": 128}),
|
||||
]
|
||||
|
||||
preset_models_on_huggingface = {
|
||||
"HunyuanDiT": [
|
||||
("Tencent-Hunyuan/HunyuanDiT", "t2i/clip_text_encoder/pytorch_model.bin", "models/HunyuanDiT/t2i/clip_text_encoder"),
|
||||
("Tencent-Hunyuan/HunyuanDiT", "t2i/mt5/pytorch_model.bin", "models/HunyuanDiT/t2i/mt5"),
|
||||
("Tencent-Hunyuan/HunyuanDiT", "t2i/model/pytorch_model_ema.pt", "models/HunyuanDiT/t2i/model"),
|
||||
("Tencent-Hunyuan/HunyuanDiT", "t2i/sdxl-vae-fp16-fix/diffusion_pytorch_model.bin", "models/HunyuanDiT/t2i/sdxl-vae-fp16-fix"),
|
||||
],
|
||||
"stable-video-diffusion-img2vid-xt": [
|
||||
("stabilityai/stable-video-diffusion-img2vid-xt", "svd_xt.safetensors", "models/stable_video_diffusion"),
|
||||
],
|
||||
"ExVideo-SVD-128f-v1": [
|
||||
("ECNU-CILab/ExVideo-SVD-128f-v1", "model.fp16.safetensors", "models/stable_video_diffusion"),
|
||||
],
|
||||
# Kolors
|
||||
"Kolors": [
|
||||
("Kwai-Kolors/Kolors", "text_encoder/config.json", "models/kolors/Kolors/text_encoder"),
|
||||
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model.bin.index.json", "models/kolors/Kolors/text_encoder"),
|
||||
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00001-of-00007.bin", "models/kolors/Kolors/text_encoder"),
|
||||
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00002-of-00007.bin", "models/kolors/Kolors/text_encoder"),
|
||||
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00003-of-00007.bin", "models/kolors/Kolors/text_encoder"),
|
||||
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00004-of-00007.bin", "models/kolors/Kolors/text_encoder"),
|
||||
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00005-of-00007.bin", "models/kolors/Kolors/text_encoder"),
|
||||
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00006-of-00007.bin", "models/kolors/Kolors/text_encoder"),
|
||||
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00007-of-00007.bin", "models/kolors/Kolors/text_encoder"),
|
||||
("Kwai-Kolors/Kolors", "unet/diffusion_pytorch_model.safetensors", "models/kolors/Kolors/unet"),
|
||||
("Kwai-Kolors/Kolors", "vae/diffusion_pytorch_model.safetensors", "models/kolors/Kolors/vae"),
|
||||
],
|
||||
# FLUX
|
||||
"FLUX.1-dev": [
|
||||
("black-forest-labs/FLUX.1-dev", "text_encoder/model.safetensors", "models/FLUX/FLUX.1-dev/text_encoder"),
|
||||
("black-forest-labs/FLUX.1-dev", "text_encoder_2/config.json", "models/FLUX/FLUX.1-dev/text_encoder_2"),
|
||||
("black-forest-labs/FLUX.1-dev", "text_encoder_2/model-00001-of-00002.safetensors", "models/FLUX/FLUX.1-dev/text_encoder_2"),
|
||||
("black-forest-labs/FLUX.1-dev", "text_encoder_2/model-00002-of-00002.safetensors", "models/FLUX/FLUX.1-dev/text_encoder_2"),
|
||||
("black-forest-labs/FLUX.1-dev", "text_encoder_2/model.safetensors.index.json", "models/FLUX/FLUX.1-dev/text_encoder_2"),
|
||||
("black-forest-labs/FLUX.1-dev", "ae.safetensors", "models/FLUX/FLUX.1-dev"),
|
||||
("black-forest-labs/FLUX.1-dev", "flux1-dev.safetensors", "models/FLUX/FLUX.1-dev"),
|
||||
],
|
||||
}
|
||||
preset_models_on_modelscope = {
|
||||
# Hunyuan DiT
|
||||
"HunyuanDiT": [
|
||||
("modelscope/HunyuanDiT", "t2i/clip_text_encoder/pytorch_model.bin", "models/HunyuanDiT/t2i/clip_text_encoder"),
|
||||
("modelscope/HunyuanDiT", "t2i/mt5/pytorch_model.bin", "models/HunyuanDiT/t2i/mt5"),
|
||||
("modelscope/HunyuanDiT", "t2i/model/pytorch_model_ema.pt", "models/HunyuanDiT/t2i/model"),
|
||||
("modelscope/HunyuanDiT", "t2i/sdxl-vae-fp16-fix/diffusion_pytorch_model.bin", "models/HunyuanDiT/t2i/sdxl-vae-fp16-fix"),
|
||||
],
|
||||
# Stable Video Diffusion
|
||||
"stable-video-diffusion-img2vid-xt": [
|
||||
("AI-ModelScope/stable-video-diffusion-img2vid-xt", "svd_xt.safetensors", "models/stable_video_diffusion"),
|
||||
],
|
||||
# ExVideo
|
||||
"ExVideo-SVD-128f-v1": [
|
||||
("ECNU-CILab/ExVideo-SVD-128f-v1", "model.fp16.safetensors", "models/stable_video_diffusion"),
|
||||
],
|
||||
# Stable Diffusion
|
||||
"StableDiffusion_v15": [
|
||||
("AI-ModelScope/stable-diffusion-v1-5", "v1-5-pruned-emaonly.safetensors", "models/stable_diffusion"),
|
||||
],
|
||||
"DreamShaper_8": [
|
||||
("sd_lora/dreamshaper_8", "dreamshaper_8.safetensors", "models/stable_diffusion"),
|
||||
],
|
||||
"AingDiffusion_v12": [
|
||||
("sd_lora/aingdiffusion_v12", "aingdiffusion_v12.safetensors", "models/stable_diffusion"),
|
||||
],
|
||||
"Flat2DAnimerge_v45Sharp": [
|
||||
("sd_lora/Flat-2D-Animerge", "flat2DAnimerge_v45Sharp.safetensors", "models/stable_diffusion"),
|
||||
],
|
||||
# Textual Inversion
|
||||
"TextualInversion_VeryBadImageNegative_v1.3": [
|
||||
("sd_lora/verybadimagenegative_v1.3", "verybadimagenegative_v1.3.pt", "models/textual_inversion"),
|
||||
],
|
||||
# Stable Diffusion XL
|
||||
"StableDiffusionXL_v1": [
|
||||
("AI-ModelScope/stable-diffusion-xl-base-1.0", "sd_xl_base_1.0.safetensors", "models/stable_diffusion_xl"),
|
||||
],
|
||||
"BluePencilXL_v200": [
|
||||
("sd_lora/bluePencilXL_v200", "bluePencilXL_v200.safetensors", "models/stable_diffusion_xl"),
|
||||
],
|
||||
"StableDiffusionXL_Turbo": [
|
||||
("AI-ModelScope/sdxl-turbo", "sd_xl_turbo_1.0_fp16.safetensors", "models/stable_diffusion_xl_turbo"),
|
||||
],
|
||||
"SDXL_lora_zyd232_ChineseInkStyle_SDXL_v1_0": [
|
||||
("sd_lora/zyd232_ChineseInkStyle_SDXL_v1_0", "zyd232_ChineseInkStyle_SDXL_v1_0.safetensors", "models/lora"),
|
||||
],
|
||||
# Stable Diffusion 3
|
||||
"StableDiffusion3": [
|
||||
("AI-ModelScope/stable-diffusion-3-medium", "sd3_medium_incl_clips_t5xxlfp16.safetensors", "models/stable_diffusion_3"),
|
||||
],
|
||||
"StableDiffusion3_without_T5": [
|
||||
("AI-ModelScope/stable-diffusion-3-medium", "sd3_medium_incl_clips.safetensors", "models/stable_diffusion_3"),
|
||||
],
|
||||
# ControlNet
|
||||
"ControlNet_v11f1p_sd15_depth": [
|
||||
("AI-ModelScope/ControlNet-v1-1", "control_v11f1p_sd15_depth.pth", "models/ControlNet"),
|
||||
("sd_lora/Annotators", "dpt_hybrid-midas-501f0c75.pt", "models/Annotators")
|
||||
],
|
||||
"ControlNet_v11p_sd15_softedge": [
|
||||
("AI-ModelScope/ControlNet-v1-1", "control_v11p_sd15_softedge.pth", "models/ControlNet"),
|
||||
("sd_lora/Annotators", "ControlNetHED.pth", "models/Annotators")
|
||||
],
|
||||
"ControlNet_v11f1e_sd15_tile": [
|
||||
("AI-ModelScope/ControlNet-v1-1", "control_v11f1e_sd15_tile.pth", "models/ControlNet")
|
||||
],
|
||||
"ControlNet_v11p_sd15_lineart": [
|
||||
("AI-ModelScope/ControlNet-v1-1", "control_v11p_sd15_lineart.pth", "models/ControlNet"),
|
||||
("sd_lora/Annotators", "sk_model.pth", "models/Annotators"),
|
||||
("sd_lora/Annotators", "sk_model2.pth", "models/Annotators")
|
||||
],
|
||||
"ControlNet_union_sdxl_promax": [
|
||||
("AI-ModelScope/controlnet-union-sdxl-1.0", "diffusion_pytorch_model_promax.safetensors", "models/ControlNet/controlnet_union"),
|
||||
("sd_lora/Annotators", "dpt_hybrid-midas-501f0c75.pt", "models/Annotators")
|
||||
],
|
||||
# AnimateDiff
|
||||
"AnimateDiff_v2": [
|
||||
("Shanghai_AI_Laboratory/animatediff", "mm_sd_v15_v2.ckpt", "models/AnimateDiff"),
|
||||
],
|
||||
"AnimateDiff_xl_beta": [
|
||||
("Shanghai_AI_Laboratory/animatediff", "mm_sdxl_v10_beta.ckpt", "models/AnimateDiff"),
|
||||
],
|
||||
# RIFE
|
||||
"RIFE": [
|
||||
("Damo_XR_Lab/cv_rife_video-frame-interpolation", "flownet.pkl", "models/RIFE"),
|
||||
],
|
||||
# Qwen Prompt
|
||||
"QwenPrompt": [
|
||||
("qwen/Qwen2-1.5B-Instruct", "config.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
|
||||
("qwen/Qwen2-1.5B-Instruct", "generation_config.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
|
||||
("qwen/Qwen2-1.5B-Instruct", "model.safetensors", "models/QwenPrompt/qwen2-1.5b-instruct"),
|
||||
("qwen/Qwen2-1.5B-Instruct", "special_tokens_map.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
|
||||
("qwen/Qwen2-1.5B-Instruct", "tokenizer.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
|
||||
("qwen/Qwen2-1.5B-Instruct", "tokenizer_config.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
|
||||
("qwen/Qwen2-1.5B-Instruct", "merges.txt", "models/QwenPrompt/qwen2-1.5b-instruct"),
|
||||
("qwen/Qwen2-1.5B-Instruct", "vocab.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
|
||||
],
|
||||
# Beautiful Prompt
|
||||
"BeautifulPrompt": [
|
||||
("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "config.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
|
||||
("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "generation_config.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
|
||||
("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "model.safetensors", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
|
||||
("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "special_tokens_map.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
|
||||
("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "tokenizer.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
|
||||
("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "tokenizer_config.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
|
||||
],
|
||||
# Omost prompt
|
||||
"OmostPrompt":[
|
||||
("Omost/omost-llama-3-8b-4bits", "model-00001-of-00002.safetensors", "models/OmostPrompt/omost-llama-3-8b-4bits"),
|
||||
("Omost/omost-llama-3-8b-4bits", "model-00002-of-00002.safetensors", "models/OmostPrompt/omost-llama-3-8b-4bits"),
|
||||
("Omost/omost-llama-3-8b-4bits", "tokenizer.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
|
||||
("Omost/omost-llama-3-8b-4bits", "tokenizer_config.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
|
||||
("Omost/omost-llama-3-8b-4bits", "config.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
|
||||
("Omost/omost-llama-3-8b-4bits", "generation_config.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
|
||||
("Omost/omost-llama-3-8b-4bits", "model.safetensors.index.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
|
||||
("Omost/omost-llama-3-8b-4bits", "special_tokens_map.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
|
||||
],
|
||||
|
||||
# Translator
|
||||
"opus-mt-zh-en": [
|
||||
("moxying/opus-mt-zh-en", "config.json", "models/translator/opus-mt-zh-en"),
|
||||
("moxying/opus-mt-zh-en", "generation_config.json", "models/translator/opus-mt-zh-en"),
|
||||
("moxying/opus-mt-zh-en", "metadata.json", "models/translator/opus-mt-zh-en"),
|
||||
("moxying/opus-mt-zh-en", "pytorch_model.bin", "models/translator/opus-mt-zh-en"),
|
||||
("moxying/opus-mt-zh-en", "source.spm", "models/translator/opus-mt-zh-en"),
|
||||
("moxying/opus-mt-zh-en", "target.spm", "models/translator/opus-mt-zh-en"),
|
||||
("moxying/opus-mt-zh-en", "tokenizer_config.json", "models/translator/opus-mt-zh-en"),
|
||||
("moxying/opus-mt-zh-en", "vocab.json", "models/translator/opus-mt-zh-en"),
|
||||
],
|
||||
# IP-Adapter
|
||||
"IP-Adapter-SD": [
|
||||
("AI-ModelScope/IP-Adapter", "models/image_encoder/model.safetensors", "models/IpAdapter/stable_diffusion/image_encoder"),
|
||||
("AI-ModelScope/IP-Adapter", "models/ip-adapter_sd15.bin", "models/IpAdapter/stable_diffusion"),
|
||||
],
|
||||
"IP-Adapter-SDXL": [
|
||||
("AI-ModelScope/IP-Adapter", "sdxl_models/image_encoder/model.safetensors", "models/IpAdapter/stable_diffusion_xl/image_encoder"),
|
||||
("AI-ModelScope/IP-Adapter", "sdxl_models/ip-adapter_sdxl.bin", "models/IpAdapter/stable_diffusion_xl"),
|
||||
],
|
||||
# Kolors
|
||||
"Kolors": [
|
||||
("Kwai-Kolors/Kolors", "text_encoder/config.json", "models/kolors/Kolors/text_encoder"),
|
||||
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model.bin.index.json", "models/kolors/Kolors/text_encoder"),
|
||||
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00001-of-00007.bin", "models/kolors/Kolors/text_encoder"),
|
||||
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00002-of-00007.bin", "models/kolors/Kolors/text_encoder"),
|
||||
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00003-of-00007.bin", "models/kolors/Kolors/text_encoder"),
|
||||
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00004-of-00007.bin", "models/kolors/Kolors/text_encoder"),
|
||||
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00005-of-00007.bin", "models/kolors/Kolors/text_encoder"),
|
||||
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00006-of-00007.bin", "models/kolors/Kolors/text_encoder"),
|
||||
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00007-of-00007.bin", "models/kolors/Kolors/text_encoder"),
|
||||
("Kwai-Kolors/Kolors", "unet/diffusion_pytorch_model.safetensors", "models/kolors/Kolors/unet"),
|
||||
("Kwai-Kolors/Kolors", "vae/diffusion_pytorch_model.safetensors", "models/kolors/Kolors/vae"),
|
||||
],
|
||||
"SDXL-vae-fp16-fix": [
|
||||
("AI-ModelScope/sdxl-vae-fp16-fix", "diffusion_pytorch_model.safetensors", "models/sdxl-vae-fp16-fix")
|
||||
],
|
||||
# FLUX
|
||||
"FLUX.1-dev": [
|
||||
("AI-ModelScope/FLUX.1-dev", "text_encoder/model.safetensors", "models/FLUX/FLUX.1-dev/text_encoder"),
|
||||
("AI-ModelScope/FLUX.1-dev", "text_encoder_2/config.json", "models/FLUX/FLUX.1-dev/text_encoder_2"),
|
||||
("AI-ModelScope/FLUX.1-dev", "text_encoder_2/model-00001-of-00002.safetensors", "models/FLUX/FLUX.1-dev/text_encoder_2"),
|
||||
("AI-ModelScope/FLUX.1-dev", "text_encoder_2/model-00002-of-00002.safetensors", "models/FLUX/FLUX.1-dev/text_encoder_2"),
|
||||
("AI-ModelScope/FLUX.1-dev", "text_encoder_2/model.safetensors.index.json", "models/FLUX/FLUX.1-dev/text_encoder_2"),
|
||||
("AI-ModelScope/FLUX.1-dev", "ae.safetensors", "models/FLUX/FLUX.1-dev"),
|
||||
("AI-ModelScope/FLUX.1-dev", "flux1-dev.safetensors", "models/FLUX/FLUX.1-dev"),
|
||||
],
|
||||
# ESRGAN
|
||||
"ESRGAN_x4": [
|
||||
("AI-ModelScope/Real-ESRGAN", "RealESRGAN_x4.pth", "models/ESRGAN"),
|
||||
],
|
||||
# RIFE
|
||||
"RIFE": [
|
||||
("AI-ModelScope/RIFE", "flownet.pkl", "models/RIFE"),
|
||||
],
|
||||
# CogVideo
|
||||
"CogVideoX-5B": [
|
||||
("ZhipuAI/CogVideoX-5b", "text_encoder/config.json", "models/CogVideo/CogVideoX-5b/text_encoder"),
|
||||
("ZhipuAI/CogVideoX-5b", "text_encoder/model.safetensors.index.json", "models/CogVideo/CogVideoX-5b/text_encoder"),
|
||||
("ZhipuAI/CogVideoX-5b", "text_encoder/model-00001-of-00002.safetensors", "models/CogVideo/CogVideoX-5b/text_encoder"),
|
||||
("ZhipuAI/CogVideoX-5b", "text_encoder/model-00002-of-00002.safetensors", "models/CogVideo/CogVideoX-5b/text_encoder"),
|
||||
("ZhipuAI/CogVideoX-5b", "transformer/config.json", "models/CogVideo/CogVideoX-5b/transformer"),
|
||||
("ZhipuAI/CogVideoX-5b", "transformer/diffusion_pytorch_model.safetensors.index.json", "models/CogVideo/CogVideoX-5b/transformer"),
|
||||
("ZhipuAI/CogVideoX-5b", "transformer/diffusion_pytorch_model-00001-of-00002.safetensors", "models/CogVideo/CogVideoX-5b/transformer"),
|
||||
("ZhipuAI/CogVideoX-5b", "transformer/diffusion_pytorch_model-00002-of-00002.safetensors", "models/CogVideo/CogVideoX-5b/transformer"),
|
||||
("ZhipuAI/CogVideoX-5b", "vae/diffusion_pytorch_model.safetensors", "models/CogVideo/CogVideoX-5b/vae"),
|
||||
],
|
||||
}
|
||||
Preset_model_id: TypeAlias = Literal[
|
||||
"HunyuanDiT",
|
||||
"stable-video-diffusion-img2vid-xt",
|
||||
"ExVideo-SVD-128f-v1",
|
||||
"StableDiffusion_v15",
|
||||
"DreamShaper_8",
|
||||
"AingDiffusion_v12",
|
||||
"Flat2DAnimerge_v45Sharp",
|
||||
"TextualInversion_VeryBadImageNegative_v1.3",
|
||||
"StableDiffusionXL_v1",
|
||||
"BluePencilXL_v200",
|
||||
"StableDiffusionXL_Turbo",
|
||||
"ControlNet_v11f1p_sd15_depth",
|
||||
"ControlNet_v11p_sd15_softedge",
|
||||
"ControlNet_v11f1e_sd15_tile",
|
||||
"ControlNet_v11p_sd15_lineart",
|
||||
"AnimateDiff_v2",
|
||||
"AnimateDiff_xl_beta",
|
||||
"RIFE",
|
||||
"BeautifulPrompt",
|
||||
"opus-mt-zh-en",
|
||||
"IP-Adapter-SD",
|
||||
"IP-Adapter-SDXL",
|
||||
"StableDiffusion3",
|
||||
"StableDiffusion3_without_T5",
|
||||
"Kolors",
|
||||
"SDXL-vae-fp16-fix",
|
||||
"ControlNet_union_sdxl_promax",
|
||||
"FLUX.1-dev",
|
||||
"SDXL_lora_zyd232_ChineseInkStyle_SDXL_v1_0",
|
||||
"QwenPrompt",
|
||||
"OmostPrompt",
|
||||
"ESRGAN_x4",
|
||||
"RIFE",
|
||||
"CogVideoX-5B",
|
||||
]
|
||||
@@ -1,738 +0,0 @@
|
||||
qwen_image_series = [
|
||||
{
|
||||
# Example: ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors")
|
||||
"model_hash": "0319a1cb19835fb510907dd3367c95ff",
|
||||
"model_name": "qwen_image_dit",
|
||||
"model_class": "diffsynth.models.qwen_image_dit.QwenImageDiT",
|
||||
},
|
||||
{
|
||||
# Example: ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors")
|
||||
"model_hash": "8004730443f55db63092006dd9f7110e",
|
||||
"model_name": "qwen_image_text_encoder",
|
||||
"model_class": "diffsynth.models.qwen_image_text_encoder.QwenImageTextEncoder",
|
||||
"state_dict_converter": "diffsynth.utils.state_dict_converters.qwen_image_text_encoder.QwenImageTextEncoderStateDictConverter",
|
||||
},
|
||||
{
|
||||
# Example: ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors")
|
||||
"model_hash": "ed4ea5824d55ec3107b09815e318123a",
|
||||
"model_name": "qwen_image_vae",
|
||||
"model_class": "diffsynth.models.qwen_image_vae.QwenImageVAE",
|
||||
},
|
||||
{
|
||||
# Example: ModelConfig(model_id="DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth", origin_file_pattern="model.safetensors")
|
||||
"model_hash": "073bce9cf969e317e5662cd570c3e79c",
|
||||
"model_name": "qwen_image_blockwise_controlnet",
|
||||
"model_class": "diffsynth.models.qwen_image_controlnet.QwenImageBlockWiseControlNet",
|
||||
},
|
||||
{
|
||||
# Example: ModelConfig(model_id="DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint", origin_file_pattern="model.safetensors")
|
||||
"model_hash": "a9e54e480a628f0b956a688a81c33bab",
|
||||
"model_name": "qwen_image_blockwise_controlnet",
|
||||
"model_class": "diffsynth.models.qwen_image_controlnet.QwenImageBlockWiseControlNet",
|
||||
"extra_kwargs": {"additional_in_dim": 4},
|
||||
},
|
||||
{
|
||||
# Example: ModelConfig(model_id="DiffSynth-Studio/General-Image-Encoders", origin_file_pattern="SigLIP2-G384/model.safetensors")
|
||||
"model_hash": "469c78b61e3e31bc9eec0d0af3d3f2f8",
|
||||
"model_name": "siglip2_image_encoder",
|
||||
"model_class": "diffsynth.models.siglip2_image_encoder.Siglip2ImageEncoder",
|
||||
},
|
||||
{
|
||||
# Example: ModelConfig(model_id="DiffSynth-Studio/General-Image-Encoders", origin_file_pattern="DINOv3-7B/model.safetensors")
|
||||
"model_hash": "5722b5c873720009de96422993b15682",
|
||||
"model_name": "dinov3_image_encoder",
|
||||
"model_class": "diffsynth.models.dinov3_image_encoder.DINOv3ImageEncoder",
|
||||
},
|
||||
{
|
||||
# Example:
|
||||
"model_hash": "a166c33455cdbd89c0888a3645ca5c0f",
|
||||
"model_name": "qwen_image_image2lora_coarse",
|
||||
"model_class": "diffsynth.models.qwen_image_image2lora.QwenImageImage2LoRAModel",
|
||||
},
|
||||
{
|
||||
# Example:
|
||||
"model_hash": "a5476e691767a4da6d3a6634a10f7408",
|
||||
"model_name": "qwen_image_image2lora_fine",
|
||||
"model_class": "diffsynth.models.qwen_image_image2lora.QwenImageImage2LoRAModel",
|
||||
"extra_kwargs": {"residual_length": 37*37+7, "residual_mid_dim": 64}
|
||||
},
|
||||
{
|
||||
# Example:
|
||||
"model_hash": "0aad514690602ecaff932c701cb4b0bb",
|
||||
"model_name": "qwen_image_image2lora_style",
|
||||
"model_class": "diffsynth.models.qwen_image_image2lora.QwenImageImage2LoRAModel",
|
||||
"extra_kwargs": {"compress_dim": 64, "use_residual": False}
|
||||
},
|
||||
{
|
||||
# Example: ModelConfig(model_id="Qwen/Qwen-Image-Layered", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors")
|
||||
"model_hash": "8dc8cda05de16c73afa755e2c1ce2839",
|
||||
"model_name": "qwen_image_dit",
|
||||
"model_class": "diffsynth.models.qwen_image_dit.QwenImageDiT",
|
||||
"extra_kwargs": {"use_layer3d_rope": True, "use_additional_t_cond": True}
|
||||
},
|
||||
{
|
||||
# Example: ModelConfig(model_id="Qwen/Qwen-Image-Layered", origin_file_pattern="vae/diffusion_pytorch_model.safetensors")
|
||||
"model_hash": "44b39ddc499e027cfb24f7878d7416b9",
|
||||
"model_name": "qwen_image_vae",
|
||||
"model_class": "diffsynth.models.qwen_image_vae.QwenImageVAE",
|
||||
"extra_kwargs": {"image_channels": 4}
|
||||
},
|
||||
]
|
||||
|
||||
wan_series = [
|
||||
{
|
||||
# Example: ModelConfig(model_id="krea/krea-realtime-video", origin_file_pattern="krea-realtime-video-14b.safetensors")
|
||||
"model_hash": "5ec04e02b42d2580483ad69f4e76346a",
|
||||
"model_name": "wan_video_dit",
|
||||
"model_class": "diffsynth.models.wan_video_dit.WanModel",
|
||||
"extra_kwargs": {'has_image_input': False, 'patch_size': [1, 2, 2], 'in_dim': 16, 'dim': 5120, 'ffn_dim': 13824, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 40, 'num_layers': 40, 'eps': 1e-06},
|
||||
"state_dict_converter": "diffsynth.utils.state_dict_converters.wan_video_dit.WanVideoDiTStateDictConverter",
|
||||
},
|
||||
{
|
||||
# Example: ModelConfig(model_id="Wan-AI/Wan2.1-T2V-14B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth")
|
||||
"model_hash": "9c8818c2cbea55eca56c7b447df170da",
|
||||
"model_name": "wan_video_text_encoder",
|
||||
"model_class": "diffsynth.models.wan_video_text_encoder.WanTextEncoder",
|
||||
},
|
||||
{
|
||||
# Example: ModelConfig(model_id="Wan-AI/Wan2.1-T2V-14B", origin_file_pattern="Wan2.1_VAE.pth")
|
||||
"model_hash": "ccc42284ea13e1ad04693284c7a09be6",
|
||||
"model_name": "wan_video_vae",
|
||||
"model_class": "diffsynth.models.wan_video_vae.WanVideoVAE",
|
||||
"state_dict_converter": "diffsynth.utils.state_dict_converters.wan_video_vae.WanVideoVAEStateDictConverter",
|
||||
},
|
||||
{
|
||||
# Example: ModelConfig(model_id="meituan-longcat/LongCat-Video", origin_file_pattern="dit/diffusion_pytorch_model*.safetensors")
|
||||
"model_hash": "8b27900f680d7251ce44e2dc8ae1ffef",
|
||||
"model_name": "wan_video_dit",
|
||||
"model_class": "diffsynth.models.longcat_video_dit.LongCatVideoTransformer3DModel",
|
||||
},
|
||||
{
|
||||
# Example: ModelConfig(model_id="ByteDance/Video-As-Prompt-Wan2.1-14B", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors")
|
||||
"model_hash": "5f90e66a0672219f12d9a626c8c21f61",
|
||||
"model_name": "wan_video_dit",
|
||||
"model_class": "diffsynth.models.wan_video_dit.WanModel",
|
||||
"extra_kwargs": {'has_image_input': True, 'patch_size': [1, 2, 2], 'in_dim': 36, 'dim': 5120, 'ffn_dim': 13824, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 40, 'num_layers': 40, 'eps': 1e-06},
|
||||
"state_dict_converter": "diffsynth.utils.state_dict_converters.wan_video_dit.WanVideoDiTFromDiffusers"
|
||||
},
|
||||
{
|
||||
# Example: ModelConfig(model_id="ByteDance/Video-As-Prompt-Wan2.1-14B", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors")
|
||||
"model_hash": "5f90e66a0672219f12d9a626c8c21f61",
|
||||
"model_name": "wan_video_vap",
|
||||
"model_class": "diffsynth.models.wan_video_mot.MotWanModel",
|
||||
"state_dict_converter": "diffsynth.utils.state_dict_converters.wan_video_mot.WanVideoMotStateDictConverter"
|
||||
},
|
||||
{
|
||||
# Example: ModelConfig(model_id="Wan-AI/Wan2.1-I2V-14B-480P", origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth")
|
||||
"model_hash": "5941c53e207d62f20f9025686193c40b",
|
||||
"model_name": "wan_video_image_encoder",
|
||||
"model_class": "diffsynth.models.wan_video_image_encoder.WanImageEncoder",
|
||||
"state_dict_converter": "diffsynth.utils.state_dict_converters.wan_video_image_encoder.WanImageEncoderStateDictConverter"
|
||||
},
|
||||
{
|
||||
# Example: ModelConfig(model_id="DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1", origin_file_pattern="model.safetensors")
|
||||
"model_hash": "dbd5ec76bbf977983f972c151d545389",
|
||||
"model_name": "wan_video_motion_controller",
|
||||
"model_class": "diffsynth.models.wan_video_motion_controller.WanMotionControllerModel",
|
||||
},
|
||||
{
|
||||
# Example: ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="diffusion_pytorch_model*.safetensors")
|
||||
"model_hash": "9269f8db9040a9d860eaca435be61814",
|
||||
"model_name": "wan_video_dit",
|
||||
"model_class": "diffsynth.models.wan_video_dit.WanModel",
|
||||
"extra_kwargs": {'has_image_input': False, 'patch_size': [1, 2, 2], 'in_dim': 16, 'dim': 1536, 'ffn_dim': 8960, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 12, 'num_layers': 30, 'eps': 1e-06}
|
||||
},
|
||||
{
|
||||
# Example: ModelConfig(model_id="Wan-AI/Wan2.1-FLF2V-14B-720P", origin_file_pattern="diffusion_pytorch_model*.safetensors")
|
||||
"model_hash": "3ef3b1f8e1dab83d5b71fd7b617f859f",
|
||||
"model_name": "wan_video_dit",
|
||||
"model_class": "diffsynth.models.wan_video_dit.WanModel",
|
||||
"extra_kwargs": {'has_image_input': True, 'patch_size': [1, 2, 2], 'in_dim': 36, 'dim': 5120, 'ffn_dim': 13824, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 40, 'num_layers': 40, 'eps': 1e-06, 'has_image_pos_emb': True}
|
||||
},
|
||||
{
|
||||
# Example: ModelConfig(model_id="PAI/Wan2.1-Fun-1.3B-Control", origin_file_pattern="diffusion_pytorch_model*.safetensors")
|
||||
"model_hash": "349723183fc063b2bfc10bb2835cf677",
|
||||
"model_name": "wan_video_dit",
|
||||
"model_class": "diffsynth.models.wan_video_dit.WanModel",
|
||||
"extra_kwargs": {'has_image_input': True, 'patch_size': [1, 2, 2], 'in_dim': 48, 'dim': 1536, 'ffn_dim': 8960, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 12, 'num_layers': 30, 'eps': 1e-06}
|
||||
},
|
||||
{
|
||||
# Example: ModelConfig(model_id="PAI/Wan2.1-Fun-1.3B-InP", origin_file_pattern="diffusion_pytorch_model*.safetensors")
|
||||
"model_hash": "6d6ccde6845b95ad9114ab993d917893",
|
||||
"model_name": "wan_video_dit",
|
||||
"model_class": "diffsynth.models.wan_video_dit.WanModel",
|
||||
"extra_kwargs": {'has_image_input': True, 'patch_size': [1, 2, 2], 'in_dim': 36, 'dim': 1536, 'ffn_dim': 8960, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 12, 'num_layers': 30, 'eps': 1e-06}
|
||||
},
|
||||
{
|
||||
# Example: ModelConfig(model_id="PAI/Wan2.1-Fun-14B-Control", origin_file_pattern="diffusion_pytorch_model*.safetensors")
|
||||
"model_hash": "efa44cddf936c70abd0ea28b6cbe946c",
|
||||
"model_name": "wan_video_dit",
|
||||
"model_class": "diffsynth.models.wan_video_dit.WanModel",
|
||||
"extra_kwargs": {'has_image_input': True, 'patch_size': [1, 2, 2], 'in_dim': 48, 'dim': 5120, 'ffn_dim': 13824, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 40, 'num_layers': 40, 'eps': 1e-06}
|
||||
},
|
||||
{
|
||||
# Example: ModelConfig(model_id="PAI/Wan2.1-Fun-14B-InP", origin_file_pattern="diffusion_pytorch_model*.safetensors")
|
||||
"model_hash": "6bfcfb3b342cb286ce886889d519a77e",
|
||||
"model_name": "wan_video_dit",
|
||||
"model_class": "diffsynth.models.wan_video_dit.WanModel",
|
||||
"extra_kwargs": {'has_image_input': True, 'patch_size': [1, 2, 2], 'in_dim': 36, 'dim': 5120, 'ffn_dim': 13824, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 40, 'num_layers': 40, 'eps': 1e-06}
|
||||
},
|
||||
{
|
||||
# Example: ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera", origin_file_pattern="diffusion_pytorch_model*.safetensors")
|
||||
"model_hash": "ac6a5aa74f4a0aab6f64eb9a72f19901",
|
||||
"model_name": "wan_video_dit",
|
||||
"model_class": "diffsynth.models.wan_video_dit.WanModel",
|
||||
"extra_kwargs": {'has_image_input': True, 'patch_size': [1, 2, 2], 'in_dim': 32, 'dim': 1536, 'ffn_dim': 8960, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 12, 'num_layers': 30, 'eps': 1e-06, 'has_ref_conv': False, 'add_control_adapter': True, 'in_dim_control_adapter': 24}
|
||||
},
|
||||
{
|
||||
# Example: ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-1.3B-Control", origin_file_pattern="diffusion_pytorch_model*.safetensors")
|
||||
"model_hash": "70ddad9d3a133785da5ea371aae09504",
|
||||
"model_name": "wan_video_dit",
|
||||
"model_class": "diffsynth.models.wan_video_dit.WanModel",
|
||||
"extra_kwargs": {'has_image_input': True, 'patch_size': [1, 2, 2], 'in_dim': 48, 'dim': 1536, 'ffn_dim': 8960, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 12, 'num_layers': 30, 'eps': 1e-06, 'has_ref_conv': True}
|
||||
},
|
||||
{
|
||||
# Example: ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-14B-Control-Camera", origin_file_pattern="diffusion_pytorch_model*.safetensors")
|
||||
"model_hash": "b61c605c2adbd23124d152ed28e049ae",
|
||||
"model_name": "wan_video_dit",
|
||||
"model_class": "diffsynth.models.wan_video_dit.WanModel",
|
||||
"extra_kwargs": {'has_image_input': True, 'patch_size': [1, 2, 2], 'in_dim': 32, 'dim': 5120, 'ffn_dim': 13824, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 40, 'num_layers': 40, 'eps': 1e-06, 'has_ref_conv': False, 'add_control_adapter': True, 'in_dim_control_adapter': 24}
|
||||
},
|
||||
{
|
||||
# Example: ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-14B-Control", origin_file_pattern="diffusion_pytorch_model*.safetensors")
|
||||
"model_hash": "26bde73488a92e64cc20b0a7485b9e5b",
|
||||
"model_name": "wan_video_dit",
|
||||
"model_class": "diffsynth.models.wan_video_dit.WanModel",
|
||||
"extra_kwargs": {'has_image_input': True, 'patch_size': [1, 2, 2], 'in_dim': 48, 'dim': 5120, 'ffn_dim': 13824, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 40, 'num_layers': 40, 'eps': 1e-06, 'has_ref_conv': True}
|
||||
},
|
||||
{
|
||||
# Example: ModelConfig(model_id="Wan-AI/Wan2.1-T2V-14B", origin_file_pattern="diffusion_pytorch_model*.safetensors")
|
||||
"model_hash": "aafcfd9672c3a2456dc46e1cb6e52c70",
|
||||
"model_name": "wan_video_dit",
|
||||
"model_class": "diffsynth.models.wan_video_dit.WanModel",
|
||||
"extra_kwargs": {'has_image_input': False, 'patch_size': [1, 2, 2], 'in_dim': 16, 'dim': 5120, 'ffn_dim': 13824, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 40, 'num_layers': 40, 'eps': 1e-06}
|
||||
},
|
||||
{
|
||||
# Example: ModelConfig(model_id="iic/VACE-Wan2.1-1.3B-Preview", origin_file_pattern="diffusion_pytorch_model*.safetensors")
|
||||
"model_hash": "a61453409b67cd3246cf0c3bebad47ba",
|
||||
"model_name": "wan_video_dit",
|
||||
"model_class": "diffsynth.models.wan_video_dit.WanModel",
|
||||
"extra_kwargs": {'has_image_input': False, 'patch_size': [1, 2, 2], 'in_dim': 16, 'dim': 1536, 'ffn_dim': 8960, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 12, 'num_layers': 30, 'eps': 1e-06},
|
||||
"state_dict_converter": "diffsynth.utils.state_dict_converters.wan_video_dit.WanVideoDiTStateDictConverter",
|
||||
},
|
||||
{
|
||||
# Example: ModelConfig(model_id="iic/VACE-Wan2.1-1.3B-Preview", origin_file_pattern="diffusion_pytorch_model*.safetensors")
|
||||
"model_hash": "a61453409b67cd3246cf0c3bebad47ba",
|
||||
"model_name": "wan_video_vace",
|
||||
"model_class": "diffsynth.models.wan_video_vace.VaceWanModel",
|
||||
"state_dict_converter": "diffsynth.utils.state_dict_converters.wan_video_vace.VaceWanModelDictConverter"
|
||||
},
|
||||
{
|
||||
# Example: ModelConfig(model_id="Wan-AI/Wan2.1-VACE-14B", origin_file_pattern="diffusion_pytorch_model*.safetensors")
|
||||
"model_hash": "7a513e1f257a861512b1afd387a8ecd9",
|
||||
"model_name": "wan_video_dit",
|
||||
"model_class": "diffsynth.models.wan_video_dit.WanModel",
|
||||
"extra_kwargs": {'has_image_input': False, 'patch_size': [1, 2, 2], 'in_dim': 16, 'dim': 5120, 'ffn_dim': 13824, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 40, 'num_layers': 40, 'eps': 1e-06},
|
||||
"state_dict_converter": "diffsynth.utils.state_dict_converters.wan_video_dit.WanVideoDiTStateDictConverter",
|
||||
},
|
||||
{
|
||||
# Example: ModelConfig(model_id="Wan-AI/Wan2.1-VACE-14B", origin_file_pattern="diffusion_pytorch_model*.safetensors")
|
||||
"model_hash": "7a513e1f257a861512b1afd387a8ecd9",
|
||||
"model_name": "wan_video_vace",
|
||||
"model_class": "diffsynth.models.wan_video_vace.VaceWanModel",
|
||||
"extra_kwargs": {'vace_layers': (0, 5, 10, 15, 20, 25, 30, 35), 'vace_in_dim': 96, 'patch_size': (1, 2, 2), 'has_image_input': False, 'dim': 5120, 'num_heads': 40, 'ffn_dim': 13824, 'eps': 1e-06},
|
||||
"state_dict_converter": "diffsynth.utils.state_dict_converters.wan_video_vace.VaceWanModelDictConverter"
|
||||
},
|
||||
{
|
||||
# Example: ModelConfig(model_id="Wan-AI/Wan2.2-Animate-14B", origin_file_pattern="diffusion_pytorch_model*.safetensors")
|
||||
"model_hash": "31fa352acb8a1b1d33cd8764273d80a2",
|
||||
"model_name": "wan_video_dit",
|
||||
"model_class": "diffsynth.models.wan_video_dit.WanModel",
|
||||
"extra_kwargs": {'has_image_input': True, 'patch_size': [1, 2, 2], 'in_dim': 36, 'dim': 5120, 'ffn_dim': 13824, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 40, 'num_layers': 40, 'eps': 1e-06},
|
||||
"state_dict_converter": "diffsynth.utils.state_dict_converters.wan_video_dit.WanVideoDiTStateDictConverter"
|
||||
},
|
||||
{
|
||||
# Example: ModelConfig(model_id="Wan-AI/Wan2.2-Animate-14B", origin_file_pattern="diffusion_pytorch_model*.safetensors")
|
||||
"model_hash": "31fa352acb8a1b1d33cd8764273d80a2",
|
||||
"model_name": "wan_video_animate_adapter",
|
||||
"model_class": "diffsynth.models.wan_video_animate_adapter.WanAnimateAdapter",
|
||||
"state_dict_converter": "diffsynth.utils.state_dict_converters.wan_video_animate_adapter.WanAnimateAdapterStateDictConverter"
|
||||
},
|
||||
{
|
||||
# Example: ModelConfig(model_id="PAI/Wan2.2-Fun-A14B-Control-Camera", origin_file_pattern="high_noise_model/diffusion_pytorch_model*.safetensors")
|
||||
"model_hash": "47dbeab5e560db3180adf51dc0232fb1",
|
||||
"model_name": "wan_video_dit",
|
||||
"model_class": "diffsynth.models.wan_video_dit.WanModel",
|
||||
"extra_kwargs": {'has_image_input': False, 'patch_size': [1, 2, 2], 'in_dim': 36, 'dim': 5120, 'ffn_dim': 13824, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 40, 'num_layers': 40, 'eps': 1e-06, 'has_ref_conv': False, 'add_control_adapter': True, 'in_dim_control_adapter': 24, 'require_clip_embedding': False}
|
||||
},
|
||||
{
|
||||
# Example: ModelConfig(model_id="PAI/Wan2.2-Fun-A14B-Control", origin_file_pattern="high_noise_model/diffusion_pytorch_model*.safetensors")
|
||||
"model_hash": "2267d489f0ceb9f21836532952852ee5",
|
||||
"model_name": "wan_video_dit",
|
||||
"model_class": "diffsynth.models.wan_video_dit.WanModel",
|
||||
"extra_kwargs": {'has_image_input': False, 'patch_size': [1, 2, 2], 'in_dim': 52, 'dim': 5120, 'ffn_dim': 13824, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 40, 'num_layers': 40, 'eps': 1e-06, 'has_ref_conv': True, 'require_clip_embedding': False},
|
||||
},
|
||||
{
|
||||
# Example: ModelConfig(model_id="Wan-AI/Wan2.2-I2V-A14B", origin_file_pattern="high_noise_model/diffusion_pytorch_model*.safetensors")
|
||||
"model_hash": "5b013604280dd715f8457c6ed6d6a626",
|
||||
"model_name": "wan_video_dit",
|
||||
"model_class": "diffsynth.models.wan_video_dit.WanModel",
|
||||
"extra_kwargs": {'has_image_input': False, 'patch_size': [1, 2, 2], 'in_dim': 36, 'dim': 5120, 'ffn_dim': 13824, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 40, 'num_layers': 40, 'eps': 1e-06, 'require_clip_embedding': False}
|
||||
},
|
||||
{
|
||||
# Example: ModelConfig(model_id="Wan-AI/Wan2.2-S2V-14B", origin_file_pattern="diffusion_pytorch_model*.safetensors")
|
||||
"model_hash": "966cffdcc52f9c46c391768b27637614",
|
||||
"model_name": "wan_video_dit",
|
||||
"model_class": "diffsynth.models.wan_video_dit_s2v.WanS2VModel",
|
||||
"extra_kwargs": {'dim': 5120, 'in_dim': 16, 'ffn_dim': 13824, 'out_dim': 16, 'text_dim': 4096, 'freq_dim': 256, 'eps': 1e-06, 'patch_size': (1, 2, 2), 'num_heads': 40, 'num_layers': 40, 'cond_dim': 16, 'audio_dim': 1024, 'num_audio_token': 4}
|
||||
},
|
||||
{
|
||||
# Example: ModelConfig(model_id="Wan-AI/Wan2.2-TI2V-5B", origin_file_pattern="diffusion_pytorch_model*.safetensors")
|
||||
"model_hash": "1f5ab7703c6fc803fdded85ff040c316",
|
||||
"model_name": "wan_video_dit",
|
||||
"model_class": "diffsynth.models.wan_video_dit.WanModel",
|
||||
"extra_kwargs": {'has_image_input': False, 'patch_size': [1, 2, 2], 'in_dim': 48, 'dim': 3072, 'ffn_dim': 14336, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 48, 'num_heads': 24, 'num_layers': 30, 'eps': 1e-06, 'seperated_timestep': True, 'require_clip_embedding': False, 'require_vae_embedding': False, 'fuse_vae_embedding_in_latents': True}
|
||||
},
|
||||
{
|
||||
# Example: ModelConfig(model_id="Wan-AI/Wan2.2-TI2V-5B", origin_file_pattern="Wan2.2_VAE.pth")
|
||||
"model_hash": "e1de6c02cdac79f8b739f4d3698cd216",
|
||||
"model_name": "wan_video_vae",
|
||||
"model_class": "diffsynth.models.wan_video_vae.WanVideoVAE38",
|
||||
"state_dict_converter": "diffsynth.utils.state_dict_converters.wan_video_vae.WanVideoVAEStateDictConverter",
|
||||
},
|
||||
{
|
||||
# Example: ModelConfig(model_id="Wan-AI/Wan2.2-S2V-14B", origin_file_pattern="wav2vec2-large-xlsr-53-english/model.safetensors")
|
||||
"model_hash": "06be60f3a4526586d8431cd038a71486",
|
||||
"model_name": "wans2v_audio_encoder",
|
||||
"model_class": "diffsynth.models.wav2vec.WanS2VAudioEncoder",
|
||||
"state_dict_converter": "diffsynth.utils.state_dict_converters.wans2v_audio_encoder.WanS2VAudioEncoderStateDictConverter",
|
||||
},
|
||||
]
|
||||
|
||||
flux_series = [
|
||||
{
|
||||
# Example: ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors")
|
||||
"model_hash": "a29710fea6dddb0314663ee823598e50",
|
||||
"model_name": "flux_dit",
|
||||
"model_class": "diffsynth.models.flux_dit.FluxDiT",
|
||||
"state_dict_converter": "diffsynth.utils.state_dict_converters.flux_dit.FluxDiTStateDictConverter",
|
||||
},
|
||||
{
|
||||
# Supported due to historical reasons.
|
||||
"model_hash": "605c56eab23e9e2af863ad8f0813a25d",
|
||||
"model_name": "flux_dit",
|
||||
"model_class": "diffsynth.models.flux_dit.FluxDiT",
|
||||
"state_dict_converter": "diffsynth.utils.state_dict_converters.flux_dit.FluxDiTStateDictConverterFromDiffusers",
|
||||
},
|
||||
{
|
||||
# Example: ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors")
|
||||
"model_hash": "94eefa3dac9cec93cb1ebaf1747d7b78",
|
||||
"model_name": "flux_text_encoder_clip",
|
||||
"model_class": "diffsynth.models.flux_text_encoder_clip.FluxTextEncoderClip",
|
||||
"state_dict_converter": "diffsynth.utils.state_dict_converters.flux_text_encoder_clip.FluxTextEncoderClipStateDictConverter",
|
||||
},
|
||||
{
|
||||
# Example: ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/*.safetensors")
|
||||
"model_hash": "22540b49eaedbc2f2784b2091a234c7c",
|
||||
"model_name": "flux_text_encoder_t5",
|
||||
"model_class": "diffsynth.models.flux_text_encoder_t5.FluxTextEncoderT5",
|
||||
"state_dict_converter": "diffsynth.utils.state_dict_converters.flux_text_encoder_t5.FluxTextEncoderT5StateDictConverter",
|
||||
},
|
||||
{
|
||||
# Example: ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors")
|
||||
"model_hash": "21ea55f476dfc4fd135587abb59dfe5d",
|
||||
"model_name": "flux_vae_encoder",
|
||||
"model_class": "diffsynth.models.flux_vae.FluxVAEEncoder",
|
||||
"state_dict_converter": "diffsynth.utils.state_dict_converters.flux_vae.FluxVAEEncoderStateDictConverter",
|
||||
},
|
||||
{
|
||||
# Example: ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors")
|
||||
"model_hash": "21ea55f476dfc4fd135587abb59dfe5d",
|
||||
"model_name": "flux_vae_decoder",
|
||||
"model_class": "diffsynth.models.flux_vae.FluxVAEDecoder",
|
||||
"state_dict_converter": "diffsynth.utils.state_dict_converters.flux_vae.FluxVAEDecoderStateDictConverter",
|
||||
},
|
||||
{
|
||||
# Example: ModelConfig(model_id="ostris/Flex.2-preview", origin_file_pattern="Flex.2-preview.safetensors")
|
||||
"model_hash": "d02f41c13549fa5093d3521f62a5570a",
|
||||
"model_name": "flux_dit",
|
||||
"model_class": "diffsynth.models.flux_dit.FluxDiT",
|
||||
"extra_kwargs": {'input_dim': 196, 'num_blocks': 8},
|
||||
"state_dict_converter": "diffsynth.utils.state_dict_converters.flux_dit.FluxDiTStateDictConverter",
|
||||
},
|
||||
{
|
||||
# Example: ModelConfig(model_id="DiffSynth-Studio/AttriCtrl-FLUX.1-Dev", origin_file_pattern="models/brightness.safetensors")
|
||||
"model_hash": "0629116fce1472503a66992f96f3eb1a",
|
||||
"model_name": "flux_value_controller",
|
||||
"model_class": "diffsynth.models.flux_value_control.SingleValueEncoder",
|
||||
},
|
||||
{
|
||||
# Example: ModelConfig(model_id="alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta", origin_file_pattern="diffusion_pytorch_model.safetensors")
|
||||
"model_hash": "52357cb26250681367488a8954c271e8",
|
||||
"model_name": "flux_controlnet",
|
||||
"model_class": "diffsynth.models.flux_controlnet.FluxControlNet",
|
||||
"state_dict_converter": "diffsynth.utils.state_dict_converters.flux_controlnet.FluxControlNetStateDictConverter",
|
||||
"extra_kwargs": {"num_joint_blocks": 6, "num_single_blocks": 0, "additional_input_dim": 4},
|
||||
},
|
||||
{
|
||||
# Example: ModelConfig(model_id="InstantX/FLUX.1-dev-Controlnet-Union-alpha", origin_file_pattern="diffusion_pytorch_model.safetensors")
|
||||
"model_hash": "78d18b9101345ff695f312e7e62538c0",
|
||||
"model_name": "flux_controlnet",
|
||||
"model_class": "diffsynth.models.flux_controlnet.FluxControlNet",
|
||||
"state_dict_converter": "diffsynth.utils.state_dict_converters.flux_controlnet.FluxControlNetStateDictConverter",
|
||||
"extra_kwargs": {"num_mode": 10, "mode_dict": {"canny": 0, "tile": 1, "depth": 2, "blur": 3, "pose": 4, "gray": 5, "lq": 6}},
|
||||
},
|
||||
{
|
||||
# Example: ModelConfig(model_id="jasperai/Flux.1-dev-Controlnet-Upscaler", origin_file_pattern="diffusion_pytorch_model.safetensors")
|
||||
"model_hash": "b001c89139b5f053c715fe772362dd2a",
|
||||
"model_name": "flux_controlnet",
|
||||
"model_class": "diffsynth.models.flux_controlnet.FluxControlNet",
|
||||
"state_dict_converter": "diffsynth.utils.state_dict_converters.flux_controlnet.FluxControlNetStateDictConverter",
|
||||
"extra_kwargs": {"num_single_blocks": 0},
|
||||
},
|
||||
{
|
||||
# Example: ModelConfig(model_id="ByteDance/InfiniteYou", origin_file_pattern="infu_flux_v1.0/aes_stage2/image_proj_model.bin")
|
||||
"model_hash": "c07c0f04f5ff55e86b4e937c7a40d481",
|
||||
"model_name": "infiniteyou_image_projector",
|
||||
"model_class": "diffsynth.models.flux_infiniteyou.InfiniteYouImageProjector",
|
||||
"state_dict_converter": "diffsynth.utils.state_dict_converters.flux_infiniteyou.FluxInfiniteYouImageProjectorStateDictConverter",
|
||||
},
|
||||
{
|
||||
# Example: ModelConfig(model_id="ByteDance/InfiniteYou", origin_file_pattern="infu_flux_v1.0/aes_stage2/InfuseNetModel/*.safetensors")
|
||||
"model_hash": "7f9583eb8ba86642abb9a21a4b2c9e16",
|
||||
"model_name": "flux_controlnet",
|
||||
"model_class": "diffsynth.models.flux_controlnet.FluxControlNet",
|
||||
"state_dict_converter": "diffsynth.utils.state_dict_converters.flux_controlnet.FluxControlNetStateDictConverter",
|
||||
"extra_kwargs": {"num_joint_blocks": 4, "num_single_blocks": 10},
|
||||
},
|
||||
{
|
||||
# Example: ModelConfig(model_id="DiffSynth-Studio/LoRA-Encoder-FLUX.1-Dev", origin_file_pattern="model.safetensors")
|
||||
"model_hash": "77c2e4dd2440269eb33bfaa0d004f6ab",
|
||||
"model_name": "flux_lora_encoder",
|
||||
"model_class": "diffsynth.models.flux_lora_encoder.FluxLoRAEncoder",
|
||||
},
|
||||
{
|
||||
# Example: ModelConfig(model_id="DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev", origin_file_pattern="model.safetensors")
|
||||
"model_hash": "30143afb2dea73d1ac580e0787628f8c",
|
||||
"model_name": "flux_lora_patcher",
|
||||
"model_class": "diffsynth.models.flux_lora_patcher.FluxLoraPatcher",
|
||||
},
|
||||
{
|
||||
# Example: ModelConfig(model_id="DiffSynth-Studio/Nexus-GenV2", origin_file_pattern="model*.safetensors")
|
||||
"model_hash": "2bd19e845116e4f875a0a048e27fc219",
|
||||
"model_name": "nexus_gen_llm",
|
||||
"model_class": "diffsynth.models.nexus_gen.NexusGenAutoregressiveModel",
|
||||
"state_dict_converter": "diffsynth.utils.state_dict_converters.nexus_gen.NexusGenAutoregressiveModelStateDictConverter",
|
||||
},
|
||||
{
|
||||
# Example: ModelConfig(model_id="DiffSynth-Studio/Nexus-GenV2", origin_file_pattern="edit_decoder.bin")
|
||||
"model_hash": "63c969fd37cce769a90aa781fbff5f81",
|
||||
"model_name": "nexus_gen_editing_adapter",
|
||||
"model_class": "diffsynth.models.nexus_gen_projector.NexusGenImageEmbeddingMerger",
|
||||
"state_dict_converter": "diffsynth.utils.state_dict_converters.nexus_gen_projector.NexusGenMergerStateDictConverter",
|
||||
},
|
||||
{
|
||||
# Example: ModelConfig(model_id="DiffSynth-Studio/Nexus-GenV2", origin_file_pattern="edit_decoder.bin")
|
||||
"model_hash": "63c969fd37cce769a90aa781fbff5f81",
|
||||
"model_name": "flux_dit",
|
||||
"model_class": "diffsynth.models.flux_dit.FluxDiT",
|
||||
"state_dict_converter": "diffsynth.utils.state_dict_converters.flux_dit.FluxDiTStateDictConverter",
|
||||
},
|
||||
{
|
||||
# Example: ModelConfig(model_id="DiffSynth-Studio/Nexus-GenV2", origin_file_pattern="generation_decoder.bin")
|
||||
"model_hash": "3e6c61b0f9471135fc9c6d6a98e98b6d",
|
||||
"model_name": "nexus_gen_generation_adapter",
|
||||
"model_class": "diffsynth.models.nexus_gen_projector.NexusGenAdapter",
|
||||
"state_dict_converter": "diffsynth.utils.state_dict_converters.nexus_gen_projector.NexusGenAdapterStateDictConverter",
|
||||
},
|
||||
{
|
||||
# Example: ModelConfig(model_id="DiffSynth-Studio/Nexus-GenV2", origin_file_pattern="generation_decoder.bin")
|
||||
"model_hash": "3e6c61b0f9471135fc9c6d6a98e98b6d",
|
||||
"model_name": "flux_dit",
|
||||
"model_class": "diffsynth.models.flux_dit.FluxDiT",
|
||||
"state_dict_converter": "diffsynth.utils.state_dict_converters.flux_dit.FluxDiTStateDictConverter",
|
||||
},
|
||||
{
|
||||
# Example: ModelConfig(model_id="InstantX/FLUX.1-dev-IP-Adapter", origin_file_pattern="ip-adapter.bin")
|
||||
"model_hash": "4daaa66cc656a8fe369908693dad0a35",
|
||||
"model_name": "flux_ipadapter",
|
||||
"model_class": "diffsynth.models.flux_ipadapter.FluxIpAdapter",
|
||||
"state_dict_converter": "diffsynth.utils.state_dict_converters.flux_ipadapter.FluxIpAdapterStateDictConverter",
|
||||
},
|
||||
{
|
||||
# Example: ModelConfig(model_id="google/siglip-so400m-patch14-384", origin_file_pattern="model.safetensors")
|
||||
"model_hash": "04d8c1e20a1f1b25f7434f111992a33f",
|
||||
"model_name": "siglip_vision_model",
|
||||
"model_class": "diffsynth.models.flux_ipadapter.SiglipVisionModelSO400M",
|
||||
"state_dict_converter": "diffsynth.utils.state_dict_converters.flux_ipadapter.SiglipStateDictConverter",
|
||||
},
|
||||
{
|
||||
# Example: ModelConfig(model_id="stepfun-ai/Step1X-Edit", origin_file_pattern="step1x-edit-i1258.safetensors"),
|
||||
"model_hash": "d30fb9e02b1dbf4e509142f05cf7dd50",
|
||||
"model_name": "step1x_connector",
|
||||
"model_class": "diffsynth.models.step1x_connector.Qwen2Connector",
|
||||
"state_dict_converter": "diffsynth.utils.state_dict_converters.step1x_connector.Qwen2ConnectorStateDictConverter",
|
||||
},
|
||||
{
|
||||
# Example: ModelConfig(model_id="stepfun-ai/Step1X-Edit", origin_file_pattern="step1x-edit-i1258.safetensors"),
|
||||
"model_hash": "d30fb9e02b1dbf4e509142f05cf7dd50",
|
||||
"model_name": "flux_dit",
|
||||
"model_class": "diffsynth.models.flux_dit.FluxDiT",
|
||||
"state_dict_converter": "diffsynth.utils.state_dict_converters.flux_dit.FluxDiTStateDictConverter",
|
||||
"extra_kwargs": {"disable_guidance_embedder": True},
|
||||
},
|
||||
{
|
||||
# Example: ModelConfig(model_id="MAILAND/majicflus_v1", origin_file_pattern="majicflus_v134.safetensors")
|
||||
"model_hash": "3394f306c4cbf04334b712bf5aaed95f",
|
||||
"model_name": "flux_dit",
|
||||
"model_class": "diffsynth.models.flux_dit.FluxDiT",
|
||||
"state_dict_converter": "diffsynth.utils.state_dict_converters.flux_dit.FluxDiTStateDictConverter",
|
||||
},
|
||||
]
|
||||
|
||||
flux2_series = [
|
||||
{
|
||||
# Example: ModelConfig(model_id="black-forest-labs/FLUX.2-dev", origin_file_pattern="text_encoder/*.safetensors")
|
||||
"model_hash": "28fca3d8e5bf2a2d1271748a773f6757",
|
||||
"model_name": "flux2_text_encoder",
|
||||
"model_class": "diffsynth.models.flux2_text_encoder.Flux2TextEncoder",
|
||||
"state_dict_converter": "diffsynth.utils.state_dict_converters.flux2_text_encoder.Flux2TextEncoderStateDictConverter",
|
||||
},
|
||||
{
|
||||
# Example: ModelConfig(model_id="black-forest-labs/FLUX.2-dev", origin_file_pattern="transformer/*.safetensors")
|
||||
"model_hash": "d38e1d5c5aec3b0a11e79327ac6e3b0f",
|
||||
"model_name": "flux2_dit",
|
||||
"model_class": "diffsynth.models.flux2_dit.Flux2DiT",
|
||||
},
|
||||
{
|
||||
# Example: ModelConfig(model_id="black-forest-labs/FLUX.2-dev", origin_file_pattern="vae/diffusion_pytorch_model.safetensors")
|
||||
"model_hash": "c54288e3ee12ca215898840682337b95",
|
||||
"model_name": "flux2_vae",
|
||||
"model_class": "diffsynth.models.flux2_vae.Flux2VAE",
|
||||
},
|
||||
{
|
||||
# Example: ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="transformer/*.safetensors")
|
||||
"model_hash": "3bde7b817fec8143028b6825a63180df",
|
||||
"model_name": "flux2_dit",
|
||||
"model_class": "diffsynth.models.flux2_dit.Flux2DiT",
|
||||
"extra_kwargs": {"guidance_embeds": False, "joint_attention_dim": 7680, "num_attention_heads": 24, "num_layers": 5, "num_single_layers": 20}
|
||||
},
|
||||
{
|
||||
# Example: ModelConfig(model_id="black-forest-labs/FLUX.2-klein-9B", origin_file_pattern="text_encoder/*.safetensors")
|
||||
"model_hash": "9195f3ea256fcd0ae6d929c203470754",
|
||||
"model_name": "z_image_text_encoder",
|
||||
"model_class": "diffsynth.models.z_image_text_encoder.ZImageTextEncoder",
|
||||
"extra_kwargs": {"model_size": "8B"},
|
||||
"state_dict_converter": "diffsynth.utils.state_dict_converters.z_image_text_encoder.ZImageTextEncoderStateDictConverter",
|
||||
},
|
||||
{
|
||||
# Example: ModelConfig(model_id="black-forest-labs/FLUX.2-klein-9B", origin_file_pattern="transformer/*.safetensors")
|
||||
"model_hash": "39c6fc48f07bebecedbbaa971ff466c8",
|
||||
"model_name": "flux2_dit",
|
||||
"model_class": "diffsynth.models.flux2_dit.Flux2DiT",
|
||||
"extra_kwargs": {"guidance_embeds": False, "joint_attention_dim": 12288, "num_attention_heads": 32, "num_layers": 8, "num_single_layers": 24}
|
||||
},
|
||||
]
|
||||
|
||||
z_image_series = [
|
||||
{
|
||||
# Example: ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="transformer/*.safetensors")
|
||||
"model_hash": "fc3a8a1247fe185ce116ccbe0e426c28",
|
||||
"model_name": "z_image_dit",
|
||||
"model_class": "diffsynth.models.z_image_dit.ZImageDiT",
|
||||
},
|
||||
{
|
||||
# Example: ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="text_encoder/*.safetensors")
|
||||
"model_hash": "0f050f62a88876fea6eae0a18dac5a2e",
|
||||
"model_name": "z_image_text_encoder",
|
||||
"model_class": "diffsynth.models.z_image_text_encoder.ZImageTextEncoder",
|
||||
},
|
||||
{
|
||||
# Example: ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="vae/vae/diffusion_pytorch_model.safetensors")
|
||||
"model_hash": "1aafa3cc91716fb6b300cc1cd51b85a3",
|
||||
"model_name": "flux_vae_encoder",
|
||||
"model_class": "diffsynth.models.flux_vae.FluxVAEEncoder",
|
||||
"state_dict_converter": "diffsynth.utils.state_dict_converters.flux_vae.FluxVAEEncoderStateDictConverterDiffusers",
|
||||
"extra_kwargs": {"use_conv_attention": False},
|
||||
},
|
||||
{
|
||||
# Example: ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="vae/vae/diffusion_pytorch_model.safetensors")
|
||||
"model_hash": "1aafa3cc91716fb6b300cc1cd51b85a3",
|
||||
"model_name": "flux_vae_decoder",
|
||||
"model_class": "diffsynth.models.flux_vae.FluxVAEDecoder",
|
||||
"state_dict_converter": "diffsynth.utils.state_dict_converters.flux_vae.FluxVAEDecoderStateDictConverterDiffusers",
|
||||
"extra_kwargs": {"use_conv_attention": False},
|
||||
},
|
||||
{
|
||||
# Example: ModelConfig(model_id="Tongyi-MAI/Z-Image-Omni-Base", origin_file_pattern="transformer/*.safetensors")
|
||||
"model_hash": "aa3563718e5c3ecde3dfbb020ca61180",
|
||||
"model_name": "z_image_dit",
|
||||
"model_class": "diffsynth.models.z_image_dit.ZImageDiT",
|
||||
"extra_kwargs": {"siglip_feat_dim": 1152},
|
||||
},
|
||||
{
|
||||
# Example: ModelConfig(model_id="Tongyi-MAI/Z-Image-Omni-Base", origin_file_pattern="siglip/model.safetensors")
|
||||
"model_hash": "89d48e420f45cff95115a9f3e698d44a",
|
||||
"model_name": "siglip_vision_model_428m",
|
||||
"model_class": "diffsynth.models.siglip2_image_encoder.Siglip2ImageEncoder428M",
|
||||
},
|
||||
{
|
||||
# Example: ModelConfig(model_id="PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1", origin_file_pattern="Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.safetensors")
|
||||
"model_hash": "1677708d40029ab380a95f6c731a57d7",
|
||||
"model_name": "z_image_controlnet",
|
||||
"model_class": "diffsynth.models.z_image_controlnet.ZImageControlNet",
|
||||
},
|
||||
{
|
||||
# Example: ???
|
||||
"model_hash": "9510cb8cd1dd34ee0e4f111c24905510",
|
||||
"model_name": "z_image_image2lora_style",
|
||||
"model_class": "diffsynth.models.z_image_image2lora.ZImageImage2LoRAModel",
|
||||
"extra_kwargs": {"compress_dim": 128},
|
||||
},
|
||||
{
|
||||
# Example: ModelConfig(model_id="Qwen/Qwen3-0.6B", origin_file_pattern="model.safetensors")
|
||||
"model_hash": "1392adecee344136041e70553f875f31",
|
||||
"model_name": "z_image_text_encoder",
|
||||
"model_class": "diffsynth.models.z_image_text_encoder.ZImageTextEncoder",
|
||||
"extra_kwargs": {"model_size": "0.6B"},
|
||||
"state_dict_converter": "diffsynth.utils.state_dict_converters.z_image_text_encoder.ZImageTextEncoderStateDictConverter",
|
||||
},
|
||||
]
|
||||
"""
|
||||
Offical model repo: https://www.modelscope.cn/models/Lightricks/LTX-2
|
||||
Repackaged model repo: https://www.modelscope.cn/models/DiffSynth-Studio/LTX-2-Repackage
|
||||
For base models of LTX-2, offical checkpoint (with model config ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors"))
|
||||
and repackaged checkpoints (with model config ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="*.safetensors")) are both supported.
|
||||
We have repackeged the official checkpoints in DiffSynth-Studio/LTX-2-Repackage repo to support separate loading of different submodules,
|
||||
and avoid redundant memory usage when users only want to use part of the model.
|
||||
"""
|
||||
ltx2_series = [
|
||||
{
|
||||
# Example: ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors")
|
||||
"model_hash": "aca7b0bbf8415e9c98360750268915fc",
|
||||
"model_name": "ltx2_dit",
|
||||
"model_class": "diffsynth.models.ltx2_dit.LTXModel",
|
||||
"state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_dit.LTXModelStateDictConverter",
|
||||
},
|
||||
{
|
||||
# Example: ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="transformer.safetensors")
|
||||
"model_hash": "c567aaa37d5ed7454c73aa6024458661",
|
||||
"model_name": "ltx2_dit",
|
||||
"model_class": "diffsynth.models.ltx2_dit.LTXModel",
|
||||
"state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_dit.LTXModelStateDictConverter",
|
||||
},
|
||||
{
|
||||
# Example: ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors")
|
||||
"model_hash": "aca7b0bbf8415e9c98360750268915fc",
|
||||
"model_name": "ltx2_video_vae_encoder",
|
||||
"model_class": "diffsynth.models.ltx2_video_vae.LTX2VideoEncoder",
|
||||
"state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_video_vae.LTX2VideoEncoderStateDictConverter",
|
||||
},
|
||||
{
|
||||
# Example: ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="video_vae_encoder.safetensors")
|
||||
"model_hash": "7f7e904a53260ec0351b05f32153754b",
|
||||
"model_name": "ltx2_video_vae_encoder",
|
||||
"model_class": "diffsynth.models.ltx2_video_vae.LTX2VideoEncoder",
|
||||
"state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_video_vae.LTX2VideoEncoderStateDictConverter",
|
||||
},
|
||||
{
|
||||
# Example: ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors")
|
||||
"model_hash": "aca7b0bbf8415e9c98360750268915fc",
|
||||
"model_name": "ltx2_video_vae_decoder",
|
||||
"model_class": "diffsynth.models.ltx2_video_vae.LTX2VideoDecoder",
|
||||
"state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_video_vae.LTX2VideoDecoderStateDictConverter",
|
||||
},
|
||||
{
|
||||
# Example: ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="video_vae_decoder.safetensors")
|
||||
"model_hash": "dc6029ca2825147872b45e35a2dc3a97",
|
||||
"model_name": "ltx2_video_vae_decoder",
|
||||
"model_class": "diffsynth.models.ltx2_video_vae.LTX2VideoDecoder",
|
||||
"state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_video_vae.LTX2VideoDecoderStateDictConverter",
|
||||
},
|
||||
{
|
||||
# Example: ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors")
|
||||
"model_hash": "aca7b0bbf8415e9c98360750268915fc",
|
||||
"model_name": "ltx2_audio_vae_decoder",
|
||||
"model_class": "diffsynth.models.ltx2_audio_vae.LTX2AudioDecoder",
|
||||
"state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_audio_vae.LTX2AudioDecoderStateDictConverter",
|
||||
},
|
||||
{
|
||||
# Example: ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vae_decoder.safetensors")
|
||||
"model_hash": "7d7823dde8f1ea0b50fb07ac329dd4cb",
|
||||
"model_name": "ltx2_audio_vae_decoder",
|
||||
"model_class": "diffsynth.models.ltx2_audio_vae.LTX2AudioDecoder",
|
||||
"state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_audio_vae.LTX2AudioDecoderStateDictConverter",
|
||||
},
|
||||
{
|
||||
# Example: ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors")
|
||||
"model_hash": "aca7b0bbf8415e9c98360750268915fc",
|
||||
"model_name": "ltx2_audio_vocoder",
|
||||
"model_class": "diffsynth.models.ltx2_audio_vae.LTX2Vocoder",
|
||||
"state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_audio_vae.LTX2VocoderStateDictConverter",
|
||||
},
|
||||
{
|
||||
# Example: ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vocoder.safetensors")
|
||||
"model_hash": "f471360f6b24bef702ab73133d9f8bb9",
|
||||
"model_name": "ltx2_audio_vocoder",
|
||||
"model_class": "diffsynth.models.ltx2_audio_vae.LTX2Vocoder",
|
||||
"state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_audio_vae.LTX2VocoderStateDictConverter",
|
||||
},
|
||||
{
|
||||
# Example: ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors")
|
||||
"model_hash": "aca7b0bbf8415e9c98360750268915fc",
|
||||
"model_name": "ltx2_audio_vae_encoder",
|
||||
"model_class": "diffsynth.models.ltx2_audio_vae.LTX2AudioEncoder",
|
||||
"state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_audio_vae.LTX2AudioEncoderStateDictConverter",
|
||||
},
|
||||
{
|
||||
# Example: ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vae_encoder.safetensors")
|
||||
"model_hash": "29338f3b95e7e312a3460a482e4f4554",
|
||||
"model_name": "ltx2_audio_vae_encoder",
|
||||
"model_class": "diffsynth.models.ltx2_audio_vae.LTX2AudioEncoder",
|
||||
"state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_audio_vae.LTX2AudioEncoderStateDictConverter",
|
||||
},
|
||||
{
|
||||
# Example: ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors")
|
||||
"model_hash": "aca7b0bbf8415e9c98360750268915fc",
|
||||
"model_name": "ltx2_text_encoder_post_modules",
|
||||
"model_class": "diffsynth.models.ltx2_text_encoder.LTX2TextEncoderPostModules",
|
||||
"state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_text_encoder.LTX2TextEncoderPostModulesStateDictConverter",
|
||||
},
|
||||
{
|
||||
# Example: ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="text_encoder_post_modules.safetensors")
|
||||
"model_hash": "981629689c8be92a712ab3c5eb4fc3f6",
|
||||
"model_name": "ltx2_text_encoder_post_modules",
|
||||
"model_class": "diffsynth.models.ltx2_text_encoder.LTX2TextEncoderPostModules",
|
||||
"state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_text_encoder.LTX2TextEncoderPostModulesStateDictConverter",
|
||||
},
|
||||
{
|
||||
# Example: ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors")
|
||||
"model_hash": "33917f31c4a79196171154cca39f165e",
|
||||
"model_name": "ltx2_text_encoder",
|
||||
"model_class": "diffsynth.models.ltx2_text_encoder.LTX2TextEncoder",
|
||||
"state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_text_encoder.LTX2TextEncoderStateDictConverter",
|
||||
},
|
||||
{
|
||||
# Example: ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors")
|
||||
"model_hash": "c79c458c6e99e0e14d47e676761732d2",
|
||||
"model_name": "ltx2_latent_upsampler",
|
||||
"model_class": "diffsynth.models.ltx2_upsampler.LTX2LatentUpsampler",
|
||||
},
|
||||
]
|
||||
anima_series = [
|
||||
{
|
||||
# Example: ModelConfig(model_id="circlestone-labs/Anima", origin_file_pattern="split_files/vae/qwen_image_vae.safetensors")
|
||||
"model_hash": "a9995952c2d8e63cf82e115005eb61b9",
|
||||
"model_name": "z_image_text_encoder",
|
||||
"model_class": "diffsynth.models.z_image_text_encoder.ZImageTextEncoder",
|
||||
"extra_kwargs": {"model_size": "0.6B"},
|
||||
},
|
||||
{
|
||||
# Example: ModelConfig(model_id="circlestone-labs/Anima", origin_file_pattern="split_files/diffusion_models/anima-preview.safetensors")
|
||||
"model_hash": "417673936471e79e31ed4d186d7a3f4a",
|
||||
"model_name": "anima_dit",
|
||||
"model_class": "diffsynth.models.anima_dit.AnimaDiT",
|
||||
"state_dict_converter": "diffsynth.utils.state_dict_converters.anima_dit.AnimaDiTStateDictConverter",
|
||||
}
|
||||
]
|
||||
MODEL_CONFIGS = qwen_image_series + wan_series + flux_series + flux2_series + z_image_series + ltx2_series + anima_series
|
||||
@@ -1,266 +0,0 @@
|
||||
flux_general_vram_config = {
|
||||
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
|
||||
"torch.nn.Embedding": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
"torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
"torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
"torch.nn.GroupNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
"diffsynth.models.general_modules.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
"diffsynth.models.flux_lora_encoder.LoRALayerBlock": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
"diffsynth.models.flux_lora_patcher.LoraMerger": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
}
|
||||
|
||||
VRAM_MANAGEMENT_MODULE_MAPS = {
|
||||
"diffsynth.models.qwen_image_dit.QwenImageDiT": {
|
||||
"diffsynth.models.qwen_image_dit.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
|
||||
"torch.nn.Embedding": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
},
|
||||
"diffsynth.models.qwen_image_text_encoder.QwenImageTextEncoder": {
|
||||
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
|
||||
"torch.nn.Embedding": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
"transformers.models.qwen2_5_vl.modeling_qwen2_5_vl.Qwen2_5_VLRotaryEmbedding": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
"transformers.models.qwen2_5_vl.modeling_qwen2_5_vl.Qwen2RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
"transformers.models.qwen2_5_vl.modeling_qwen2_5_vl.Qwen2_5_VisionPatchEmbed": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
"transformers.models.qwen2_5_vl.modeling_qwen2_5_vl.Qwen2_5_VisionRotaryEmbedding": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
},
|
||||
"diffsynth.models.qwen_image_vae.QwenImageVAE": {
|
||||
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
|
||||
"torch.nn.Conv3d": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
"torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
"diffsynth.models.qwen_image_vae.QwenImageRMS_norm": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
},
|
||||
"diffsynth.models.qwen_image_controlnet.BlockWiseControlBlock": {
|
||||
"diffsynth.models.qwen_image_dit.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
|
||||
},
|
||||
"diffsynth.models.siglip2_image_encoder.Siglip2ImageEncoder": {
|
||||
"transformers.models.siglip.modeling_siglip.SiglipVisionEmbeddings": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
"transformers.models.siglip.modeling_siglip.SiglipMultiheadAttentionPoolingHead": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
"torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
"torch.nn.Embedding": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
"torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
|
||||
},
|
||||
"diffsynth.models.dinov3_image_encoder.DINOv3ImageEncoder": {
|
||||
"transformers.models.dinov3_vit.modeling_dinov3_vit.DINOv3ViTLayerScale": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
"transformers.models.dinov3_vit.modeling_dinov3_vit.DINOv3ViTRopePositionEmbedding": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
"transformers.models.dinov3_vit.modeling_dinov3_vit.DINOv3ViTEmbeddings": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
"torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
"torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
|
||||
},
|
||||
"diffsynth.models.qwen_image_image2lora.QwenImageImage2LoRAModel": {
|
||||
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
|
||||
},
|
||||
"diffsynth.models.wan_video_animate_adapter.WanAnimateAdapter": {
|
||||
"diffsynth.models.wan_video_animate_adapter.FaceEncoder": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
"diffsynth.models.wan_video_animate_adapter.EqualLinear": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
"diffsynth.models.wan_video_animate_adapter.ConvLayer": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
"diffsynth.models.wan_video_animate_adapter.FusedLeakyReLU": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
"diffsynth.models.wan_video_animate_adapter.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
|
||||
"torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
"torch.nn.Conv1d": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
"torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
"torch.nn.Conv3d": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
},
|
||||
"diffsynth.models.wan_video_dit_s2v.WanS2VModel": {
|
||||
"diffsynth.models.wan_video_dit.Head": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
"diffsynth.models.wan_video_dit_s2v.WanS2VDiTBlock": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
"diffsynth.models.wan_video_dit_s2v.CausalAudioEncoder": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
"torch.nn.Embedding": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
|
||||
"torch.nn.Conv3d": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
"torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
"diffsynth.models.wan_video_dit.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
"torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
},
|
||||
"diffsynth.models.wan_video_dit.WanModel": {
|
||||
"diffsynth.models.wan_video_dit.MLP": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
"diffsynth.models.wan_video_dit.DiTBlock": "diffsynth.core.vram.layers.AutoWrappedNonRecurseModule",
|
||||
"diffsynth.models.wan_video_dit.Head": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
|
||||
"torch.nn.Conv3d": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
"torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
"diffsynth.models.wan_video_dit.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
"torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
},
|
||||
"diffsynth.models.wan_video_image_encoder.WanImageEncoder": {
|
||||
"diffsynth.models.wan_video_image_encoder.VisionTransformer": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
|
||||
"torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
"torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
},
|
||||
"diffsynth.models.wan_video_mot.MotWanModel": {
|
||||
"diffsynth.models.wan_video_mot.MotWanAttentionBlock": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
"torch.nn.Conv3d": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
|
||||
"torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
},
|
||||
"diffsynth.models.wan_video_motion_controller.WanMotionControllerModel": {
|
||||
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
|
||||
},
|
||||
"diffsynth.models.wan_video_text_encoder.WanTextEncoder": {
|
||||
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
|
||||
"torch.nn.Embedding": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
"diffsynth.models.wan_video_text_encoder.T5RelativeEmbedding": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
"diffsynth.models.wan_video_text_encoder.T5LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
},
|
||||
"diffsynth.models.wan_video_vace.VaceWanModel": {
|
||||
"diffsynth.models.wan_video_dit.DiTBlock": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
|
||||
"torch.nn.Conv3d": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
"torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
"diffsynth.models.wan_video_dit.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
},
|
||||
"diffsynth.models.wan_video_vae.WanVideoVAE": {
|
||||
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
|
||||
"torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
"diffsynth.models.wan_video_vae.RMS_norm": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
"diffsynth.models.wan_video_vae.CausalConv3d": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
"diffsynth.models.wan_video_vae.Upsample": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
"torch.nn.SiLU": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
"torch.nn.Dropout": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
},
|
||||
"diffsynth.models.wan_video_vae.WanVideoVAE38": {
|
||||
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
|
||||
"torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
"diffsynth.models.wan_video_vae.RMS_norm": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
"diffsynth.models.wan_video_vae.CausalConv3d": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
"diffsynth.models.wan_video_vae.Upsample": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
"torch.nn.SiLU": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
"torch.nn.Dropout": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
},
|
||||
"diffsynth.models.wav2vec.WanS2VAudioEncoder": {
|
||||
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
|
||||
"torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
"torch.nn.Conv1d": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
},
|
||||
"diffsynth.models.longcat_video_dit.LongCatVideoTransformer3DModel": {
|
||||
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
|
||||
"torch.nn.Conv3d": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
"diffsynth.models.longcat_video_dit.RMSNorm_FP32": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
"diffsynth.models.longcat_video_dit.LayerNorm_FP32": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
},
|
||||
"diffsynth.models.flux_dit.FluxDiT": {
|
||||
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
|
||||
"diffsynth.models.flux_dit.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
},
|
||||
"diffsynth.models.flux_text_encoder_clip.FluxTextEncoderClip": flux_general_vram_config,
|
||||
"diffsynth.models.flux_vae.FluxVAEEncoder": flux_general_vram_config,
|
||||
"diffsynth.models.flux_vae.FluxVAEDecoder": flux_general_vram_config,
|
||||
"diffsynth.models.flux_controlnet.FluxControlNet": flux_general_vram_config,
|
||||
"diffsynth.models.flux_infiniteyou.InfiniteYouImageProjector": flux_general_vram_config,
|
||||
"diffsynth.models.flux_ipadapter.FluxIpAdapter": flux_general_vram_config,
|
||||
"diffsynth.models.flux_lora_patcher.FluxLoraPatcher": flux_general_vram_config,
|
||||
"diffsynth.models.step1x_connector.Qwen2Connector": flux_general_vram_config,
|
||||
"diffsynth.models.flux_lora_encoder.FluxLoRAEncoder": flux_general_vram_config,
|
||||
"diffsynth.models.flux_text_encoder_t5.FluxTextEncoderT5": {
|
||||
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
|
||||
"torch.nn.Embedding": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
"transformers.models.t5.modeling_t5.T5LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
"transformers.models.t5.modeling_t5.T5DenseActDense": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
"transformers.models.t5.modeling_t5.T5DenseGatedActDense": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
},
|
||||
"diffsynth.models.flux_ipadapter.SiglipVisionModelSO400M": {
|
||||
"transformers.models.siglip.modeling_siglip.SiglipVisionEmbeddings": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
"transformers.models.siglip.modeling_siglip.SiglipEncoder": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
"transformers.models.siglip.modeling_siglip.SiglipMultiheadAttentionPoolingHead": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
"torch.nn.MultiheadAttention": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
|
||||
"torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
},
|
||||
"diffsynth.models.flux2_dit.Flux2DiT": {
|
||||
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
|
||||
"torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
"torch.nn.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
},
|
||||
"diffsynth.models.flux2_text_encoder.Flux2TextEncoder": {
|
||||
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
|
||||
"torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
"torch.nn.Embedding": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
"transformers.models.mistral.modeling_mistral.MistralRMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
},
|
||||
"diffsynth.models.flux2_vae.Flux2VAE": {
|
||||
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
|
||||
"torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
"torch.nn.GroupNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
},
|
||||
"diffsynth.models.z_image_text_encoder.ZImageTextEncoder": {
|
||||
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
|
||||
"transformers.models.qwen3.modeling_qwen3.Qwen3RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
"torch.nn.Embedding": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
},
|
||||
"diffsynth.models.z_image_dit.ZImageDiT": {
|
||||
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
|
||||
"diffsynth.models.z_image_dit.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
},
|
||||
"diffsynth.models.z_image_controlnet.ZImageControlNet": {
|
||||
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
|
||||
"diffsynth.models.z_image_dit.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
},
|
||||
"diffsynth.models.z_image_image2lora.ZImageImage2LoRAModel": {
|
||||
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
|
||||
},
|
||||
"diffsynth.models.siglip2_image_encoder.Siglip2ImageEncoder428M": {
|
||||
"transformers.models.siglip2.modeling_siglip2.Siglip2VisionEmbeddings": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
"transformers.models.siglip2.modeling_siglip2.Siglip2MultiheadAttentionPoolingHead": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
"torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
"torch.nn.Embedding": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
"torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
|
||||
},
|
||||
"diffsynth.models.ltx2_dit.LTXModel": {
|
||||
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
|
||||
"torch.nn.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
},
|
||||
"diffsynth.models.ltx2_upsampler.LTX2LatentUpsampler": {
|
||||
"torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
"torch.nn.Conv3d": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
"torch.nn.GroupNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
},
|
||||
"diffsynth.models.ltx2_video_vae.LTX2VideoEncoder": {
|
||||
"torch.nn.Conv3d": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
},
|
||||
"diffsynth.models.ltx2_video_vae.LTX2VideoDecoder": {
|
||||
"torch.nn.Conv3d": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
},
|
||||
"diffsynth.models.ltx2_audio_vae.LTX2AudioDecoder": {
|
||||
"torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
},
|
||||
"diffsynth.models.ltx2_audio_vae.LTX2Vocoder": {
|
||||
"torch.nn.Conv1d": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
"torch.nn.ConvTranspose1d": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
},
|
||||
"diffsynth.models.ltx2_text_encoder.LTX2TextEncoderPostModules": {
|
||||
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
|
||||
"torch.nn.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
"diffsynth.models.ltx2_text_encoder.Embeddings1DConnector": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
},
|
||||
"diffsynth.models.ltx2_text_encoder.LTX2TextEncoder": {
|
||||
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
|
||||
"transformers.models.gemma3.modeling_gemma3.Gemma3MultiModalProjector": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
"transformers.models.gemma3.modeling_gemma3.Gemma3RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
"transformers.models.gemma3.modeling_gemma3.Gemma3TextScaledWordEmbedding": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
},
|
||||
"diffsynth.models.anima_dit.AnimaDiT": {
|
||||
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
|
||||
"torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
"torch.nn.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
"torch.nn.Embedding": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
},
|
||||
}
|
||||
|
||||
def QwenImageTextEncoder_Module_Map_Updater():
|
||||
current = VRAM_MANAGEMENT_MODULE_MAPS["diffsynth.models.qwen_image_text_encoder.QwenImageTextEncoder"]
|
||||
from packaging import version
|
||||
import transformers
|
||||
if version.parse(transformers.__version__) >= version.parse("5.2.0"):
|
||||
# The Qwen2RMSNorm in transformers 5.2.0+ has been renamed to Qwen2_5_VLRMSNorm, so we need to update the module map accordingly
|
||||
current.pop("transformers.models.qwen2_5_vl.modeling_qwen2_5_vl.Qwen2RMSNorm", None)
|
||||
current["transformers.models.qwen2_5_vl.modeling_qwen2_5_vl.Qwen2_5_VLRMSNorm"] = "diffsynth.core.vram.layers.AutoWrappedModule"
|
||||
return current
|
||||
|
||||
VERSION_CHECKER_MAPS = {
|
||||
"diffsynth.models.qwen_image_text_encoder.QwenImageTextEncoder": QwenImageTextEncoder_Module_Map_Updater,
|
||||
}
|
||||
2
diffsynth/controlnets/__init__.py
Normal file
2
diffsynth/controlnets/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
from .controlnet_unit import ControlNetConfigUnit, ControlNetUnit, MultiControlNetManager
|
||||
from .processors import Annotator
|
||||
62
diffsynth/controlnets/controlnet_unit.py
Normal file
62
diffsynth/controlnets/controlnet_unit.py
Normal file
@@ -0,0 +1,62 @@
|
||||
import torch
|
||||
import numpy as np
|
||||
from .processors import Processor_id
|
||||
|
||||
|
||||
class ControlNetConfigUnit:
|
||||
def __init__(self, processor_id: Processor_id, model_path, scale=1.0):
|
||||
self.processor_id = processor_id
|
||||
self.model_path = model_path
|
||||
self.scale = scale
|
||||
|
||||
|
||||
class ControlNetUnit:
|
||||
def __init__(self, processor, model, scale=1.0):
|
||||
self.processor = processor
|
||||
self.model = model
|
||||
self.scale = scale
|
||||
|
||||
|
||||
class MultiControlNetManager:
|
||||
def __init__(self, controlnet_units=[]):
|
||||
self.processors = [unit.processor for unit in controlnet_units]
|
||||
self.models = [unit.model for unit in controlnet_units]
|
||||
self.scales = [unit.scale for unit in controlnet_units]
|
||||
|
||||
def cpu(self):
|
||||
for model in self.models:
|
||||
model.cpu()
|
||||
|
||||
def to(self, device):
|
||||
for model in self.models:
|
||||
model.to(device)
|
||||
|
||||
def process_image(self, image, processor_id=None):
|
||||
if processor_id is None:
|
||||
processed_image = [processor(image) for processor in self.processors]
|
||||
else:
|
||||
processed_image = [self.processors[processor_id](image)]
|
||||
processed_image = torch.concat([
|
||||
torch.Tensor(np.array(image_, dtype=np.float32) / 255).permute(2, 0, 1).unsqueeze(0)
|
||||
for image_ in processed_image
|
||||
], dim=0)
|
||||
return processed_image
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
sample, timestep, encoder_hidden_states, conditionings,
|
||||
tiled=False, tile_size=64, tile_stride=32, **kwargs
|
||||
):
|
||||
res_stack = None
|
||||
for processor, conditioning, model, scale in zip(self.processors, conditionings, self.models, self.scales):
|
||||
res_stack_ = model(
|
||||
sample, timestep, encoder_hidden_states, conditioning, **kwargs,
|
||||
tiled=tiled, tile_size=tile_size, tile_stride=tile_stride,
|
||||
processor_id=processor.processor_id
|
||||
)
|
||||
res_stack_ = [res * scale for res in res_stack_]
|
||||
if res_stack is None:
|
||||
res_stack = res_stack_
|
||||
else:
|
||||
res_stack = [i + j for i, j in zip(res_stack, res_stack_)]
|
||||
return res_stack
|
||||
51
diffsynth/controlnets/processors.py
Normal file
51
diffsynth/controlnets/processors.py
Normal file
@@ -0,0 +1,51 @@
|
||||
from typing_extensions import Literal, TypeAlias
|
||||
import warnings
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore")
|
||||
from controlnet_aux.processor import (
|
||||
CannyDetector, MidasDetector, HEDdetector, LineartDetector, LineartAnimeDetector, OpenposeDetector
|
||||
)
|
||||
|
||||
|
||||
Processor_id: TypeAlias = Literal[
|
||||
"canny", "depth", "softedge", "lineart", "lineart_anime", "openpose", "tile"
|
||||
]
|
||||
|
||||
class Annotator:
|
||||
def __init__(self, processor_id: Processor_id, model_path="models/Annotators", detect_resolution=None, device='cuda'):
|
||||
if processor_id == "canny":
|
||||
self.processor = CannyDetector()
|
||||
elif processor_id == "depth":
|
||||
self.processor = MidasDetector.from_pretrained(model_path).to(device)
|
||||
elif processor_id == "softedge":
|
||||
self.processor = HEDdetector.from_pretrained(model_path).to(device)
|
||||
elif processor_id == "lineart":
|
||||
self.processor = LineartDetector.from_pretrained(model_path).to(device)
|
||||
elif processor_id == "lineart_anime":
|
||||
self.processor = LineartAnimeDetector.from_pretrained(model_path).to(device)
|
||||
elif processor_id == "openpose":
|
||||
self.processor = OpenposeDetector.from_pretrained(model_path).to(device)
|
||||
elif processor_id == "tile":
|
||||
self.processor = None
|
||||
else:
|
||||
raise ValueError(f"Unsupported processor_id: {processor_id}")
|
||||
|
||||
self.processor_id = processor_id
|
||||
self.detect_resolution = detect_resolution
|
||||
|
||||
def __call__(self, image):
|
||||
width, height = image.size
|
||||
if self.processor_id == "openpose":
|
||||
kwargs = {
|
||||
"include_body": True,
|
||||
"include_hand": True,
|
||||
"include_face": True
|
||||
}
|
||||
else:
|
||||
kwargs = {}
|
||||
if self.processor is not None:
|
||||
detect_resolution = self.detect_resolution if self.detect_resolution is not None else min(width, height)
|
||||
image = self.processor(image, detect_resolution=detect_resolution, image_resolution=min(width, height), **kwargs)
|
||||
image = image.resize((width, height))
|
||||
return image
|
||||
|
||||
@@ -1,6 +0,0 @@
|
||||
from .attention import *
|
||||
from .data import *
|
||||
from .gradient import *
|
||||
from .loader import *
|
||||
from .vram import *
|
||||
from .device import *
|
||||
@@ -1 +0,0 @@
|
||||
from .attention import attention_forward
|
||||
@@ -1,121 +0,0 @@
|
||||
import torch, os
|
||||
from einops import rearrange
|
||||
|
||||
|
||||
try:
|
||||
import flash_attn_interface
|
||||
FLASH_ATTN_3_AVAILABLE = True
|
||||
except ModuleNotFoundError:
|
||||
FLASH_ATTN_3_AVAILABLE = False
|
||||
|
||||
try:
|
||||
import flash_attn
|
||||
FLASH_ATTN_2_AVAILABLE = True
|
||||
except ModuleNotFoundError:
|
||||
FLASH_ATTN_2_AVAILABLE = False
|
||||
|
||||
try:
|
||||
from sageattention import sageattn
|
||||
SAGE_ATTN_AVAILABLE = True
|
||||
except ModuleNotFoundError:
|
||||
SAGE_ATTN_AVAILABLE = False
|
||||
|
||||
try:
|
||||
import xformers.ops as xops
|
||||
XFORMERS_AVAILABLE = True
|
||||
except ModuleNotFoundError:
|
||||
XFORMERS_AVAILABLE = False
|
||||
|
||||
|
||||
def initialize_attention_priority():
|
||||
if os.environ.get('DIFFSYNTH_ATTENTION_IMPLEMENTATION') is not None:
|
||||
return os.environ.get('DIFFSYNTH_ATTENTION_IMPLEMENTATION').lower()
|
||||
elif FLASH_ATTN_3_AVAILABLE:
|
||||
return "flash_attention_3"
|
||||
elif FLASH_ATTN_2_AVAILABLE:
|
||||
return "flash_attention_2"
|
||||
elif SAGE_ATTN_AVAILABLE:
|
||||
return "sage_attention"
|
||||
elif XFORMERS_AVAILABLE:
|
||||
return "xformers"
|
||||
else:
|
||||
return "torch"
|
||||
|
||||
|
||||
ATTENTION_IMPLEMENTATION = initialize_attention_priority()
|
||||
|
||||
|
||||
def rearrange_qkv(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, q_pattern="b n s d", k_pattern="b n s d", v_pattern="b n s d", required_in_pattern="b n s d", dims=None):
|
||||
dims = {} if dims is None else dims
|
||||
if q_pattern != required_in_pattern:
|
||||
q = rearrange(q, f"{q_pattern} -> {required_in_pattern}", **dims)
|
||||
if k_pattern != required_in_pattern:
|
||||
k = rearrange(k, f"{k_pattern} -> {required_in_pattern}", **dims)
|
||||
if v_pattern != required_in_pattern:
|
||||
v = rearrange(v, f"{v_pattern} -> {required_in_pattern}", **dims)
|
||||
return q, k, v
|
||||
|
||||
|
||||
def rearrange_out(out: torch.Tensor, out_pattern="b n s d", required_out_pattern="b n s d", dims=None):
|
||||
dims = {} if dims is None else dims
|
||||
if out_pattern != required_out_pattern:
|
||||
out = rearrange(out, f"{required_out_pattern} -> {out_pattern}", **dims)
|
||||
return out
|
||||
|
||||
|
||||
def torch_sdpa(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, q_pattern="b n s d", k_pattern="b n s d", v_pattern="b n s d", out_pattern="b n s d", dims=None, attn_mask=None, scale=None):
|
||||
required_in_pattern, required_out_pattern= "b n s d", "b n s d"
|
||||
q, k, v = rearrange_qkv(q, k, v, q_pattern, k_pattern, v_pattern, required_in_pattern, dims)
|
||||
out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask, scale=scale)
|
||||
out = rearrange_out(out, out_pattern, required_out_pattern, dims)
|
||||
return out
|
||||
|
||||
|
||||
def flash_attention_3(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, q_pattern="b n s d", k_pattern="b n s d", v_pattern="b n s d", out_pattern="b n s d", dims=None, scale=None):
|
||||
required_in_pattern, required_out_pattern= "b s n d", "b s n d"
|
||||
q, k, v = rearrange_qkv(q, k, v, q_pattern, k_pattern, v_pattern, required_in_pattern, dims)
|
||||
out = flash_attn_interface.flash_attn_func(q, k, v, softmax_scale=scale)
|
||||
if isinstance(out, tuple):
|
||||
out = out[0]
|
||||
out = rearrange_out(out, out_pattern, required_out_pattern, dims)
|
||||
return out
|
||||
|
||||
|
||||
def flash_attention_2(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, q_pattern="b n s d", k_pattern="b n s d", v_pattern="b n s d", out_pattern="b n s d", dims=None, scale=None):
|
||||
required_in_pattern, required_out_pattern= "b s n d", "b s n d"
|
||||
q, k, v = rearrange_qkv(q, k, v, q_pattern, k_pattern, v_pattern, required_in_pattern, dims)
|
||||
out = flash_attn.flash_attn_func(q, k, v, softmax_scale=scale)
|
||||
out = rearrange_out(out, out_pattern, required_out_pattern, dims)
|
||||
return out
|
||||
|
||||
|
||||
def sage_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, q_pattern="b n s d", k_pattern="b n s d", v_pattern="b n s d", out_pattern="b n s d", dims=None, scale=None):
|
||||
required_in_pattern, required_out_pattern= "b n s d", "b n s d"
|
||||
q, k, v = rearrange_qkv(q, k, v, q_pattern, k_pattern, v_pattern, required_in_pattern, dims)
|
||||
out = sageattn(q, k, v, sm_scale=scale)
|
||||
out = rearrange_out(out, out_pattern, required_out_pattern, dims)
|
||||
return out
|
||||
|
||||
|
||||
def xformers_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, q_pattern="b n s d", k_pattern="b n s d", v_pattern="b n s d", out_pattern="b n s d", dims=None, scale=None):
|
||||
required_in_pattern, required_out_pattern= "b s n d", "b s n d"
|
||||
q, k, v = rearrange_qkv(q, k, v, q_pattern, k_pattern, v_pattern, required_in_pattern, dims)
|
||||
out = xops.memory_efficient_attention(q, k, v, scale=scale)
|
||||
out = rearrange_out(out, out_pattern, required_out_pattern, dims)
|
||||
return out
|
||||
|
||||
|
||||
def attention_forward(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, q_pattern="b n s d", k_pattern="b n s d", v_pattern="b n s d", out_pattern="b n s d", dims=None, attn_mask=None, scale=None, compatibility_mode=False):
|
||||
if compatibility_mode or (attn_mask is not None):
|
||||
return torch_sdpa(q, k, v, q_pattern, k_pattern, v_pattern, out_pattern, dims, attn_mask=attn_mask, scale=scale)
|
||||
else:
|
||||
if ATTENTION_IMPLEMENTATION == "flash_attention_3":
|
||||
return flash_attention_3(q, k, v, q_pattern, k_pattern, v_pattern, out_pattern, dims, scale=scale)
|
||||
elif ATTENTION_IMPLEMENTATION == "flash_attention_2":
|
||||
return flash_attention_2(q, k, v, q_pattern, k_pattern, v_pattern, out_pattern, dims, scale=scale)
|
||||
elif ATTENTION_IMPLEMENTATION == "sage_attention":
|
||||
return sage_attention(q, k, v, q_pattern, k_pattern, v_pattern, out_pattern, dims, scale=scale)
|
||||
elif ATTENTION_IMPLEMENTATION == "xformers":
|
||||
return xformers_attention(q, k, v, q_pattern, k_pattern, v_pattern, out_pattern, dims, scale=scale)
|
||||
else:
|
||||
return torch_sdpa(q, k, v, q_pattern, k_pattern, v_pattern, out_pattern, dims, scale=scale)
|
||||
@@ -1 +0,0 @@
|
||||
from .unified_dataset import UnifiedDataset
|
||||
@@ -1,237 +0,0 @@
|
||||
import torch, torchvision, imageio, os
|
||||
import imageio.v3 as iio
|
||||
from PIL import Image
|
||||
|
||||
|
||||
class DataProcessingPipeline:
|
||||
def __init__(self, operators=None):
|
||||
self.operators: list[DataProcessingOperator] = [] if operators is None else operators
|
||||
|
||||
def __call__(self, data):
|
||||
for operator in self.operators:
|
||||
data = operator(data)
|
||||
return data
|
||||
|
||||
def __rshift__(self, pipe):
|
||||
if isinstance(pipe, DataProcessingOperator):
|
||||
pipe = DataProcessingPipeline([pipe])
|
||||
return DataProcessingPipeline(self.operators + pipe.operators)
|
||||
|
||||
|
||||
class DataProcessingOperator:
|
||||
def __call__(self, data):
|
||||
raise NotImplementedError("DataProcessingOperator cannot be called directly.")
|
||||
|
||||
def __rshift__(self, pipe):
|
||||
if isinstance(pipe, DataProcessingOperator):
|
||||
pipe = DataProcessingPipeline([pipe])
|
||||
return DataProcessingPipeline([self]).__rshift__(pipe)
|
||||
|
||||
|
||||
class DataProcessingOperatorRaw(DataProcessingOperator):
|
||||
def __call__(self, data):
|
||||
return data
|
||||
|
||||
|
||||
class ToInt(DataProcessingOperator):
|
||||
def __call__(self, data):
|
||||
return int(data)
|
||||
|
||||
|
||||
class ToFloat(DataProcessingOperator):
|
||||
def __call__(self, data):
|
||||
return float(data)
|
||||
|
||||
|
||||
class ToStr(DataProcessingOperator):
|
||||
def __init__(self, none_value=""):
|
||||
self.none_value = none_value
|
||||
|
||||
def __call__(self, data):
|
||||
if data is None: data = self.none_value
|
||||
return str(data)
|
||||
|
||||
|
||||
class LoadImage(DataProcessingOperator):
|
||||
def __init__(self, convert_RGB=True, convert_RGBA=False):
|
||||
self.convert_RGB = convert_RGB
|
||||
self.convert_RGBA = convert_RGBA
|
||||
|
||||
def __call__(self, data: str):
|
||||
image = Image.open(data)
|
||||
if self.convert_RGB: image = image.convert("RGB")
|
||||
if self.convert_RGBA: image = image.convert("RGBA")
|
||||
return image
|
||||
|
||||
|
||||
class ImageCropAndResize(DataProcessingOperator):
|
||||
def __init__(self, height=None, width=None, max_pixels=None, height_division_factor=1, width_division_factor=1):
|
||||
self.height = height
|
||||
self.width = width
|
||||
self.max_pixels = max_pixels
|
||||
self.height_division_factor = height_division_factor
|
||||
self.width_division_factor = width_division_factor
|
||||
|
||||
def crop_and_resize(self, image, target_height, target_width):
|
||||
width, height = image.size
|
||||
scale = max(target_width / width, target_height / height)
|
||||
image = torchvision.transforms.functional.resize(
|
||||
image,
|
||||
(round(height*scale), round(width*scale)),
|
||||
interpolation=torchvision.transforms.InterpolationMode.BILINEAR
|
||||
)
|
||||
image = torchvision.transforms.functional.center_crop(image, (target_height, target_width))
|
||||
return image
|
||||
|
||||
def get_height_width(self, image):
|
||||
if self.height is None or self.width is None:
|
||||
width, height = image.size
|
||||
if width * height > self.max_pixels:
|
||||
scale = (width * height / self.max_pixels) ** 0.5
|
||||
height, width = int(height / scale), int(width / scale)
|
||||
height = height // self.height_division_factor * self.height_division_factor
|
||||
width = width // self.width_division_factor * self.width_division_factor
|
||||
else:
|
||||
height, width = self.height, self.width
|
||||
return height, width
|
||||
|
||||
def __call__(self, data: Image.Image):
|
||||
image = self.crop_and_resize(data, *self.get_height_width(data))
|
||||
return image
|
||||
|
||||
|
||||
class ToList(DataProcessingOperator):
|
||||
def __call__(self, data):
|
||||
return [data]
|
||||
|
||||
|
||||
class LoadVideo(DataProcessingOperator):
|
||||
def __init__(self, num_frames=81, time_division_factor=4, time_division_remainder=1, frame_processor=lambda x: x):
|
||||
self.num_frames = num_frames
|
||||
self.time_division_factor = time_division_factor
|
||||
self.time_division_remainder = time_division_remainder
|
||||
# frame_processor is build in the video loader for high efficiency.
|
||||
self.frame_processor = frame_processor
|
||||
|
||||
def get_num_frames(self, reader):
|
||||
num_frames = self.num_frames
|
||||
if int(reader.count_frames()) < num_frames:
|
||||
num_frames = int(reader.count_frames())
|
||||
while num_frames > 1 and num_frames % self.time_division_factor != self.time_division_remainder:
|
||||
num_frames -= 1
|
||||
return num_frames
|
||||
|
||||
def __call__(self, data: str):
|
||||
reader = imageio.get_reader(data)
|
||||
num_frames = self.get_num_frames(reader)
|
||||
frames = []
|
||||
for frame_id in range(num_frames):
|
||||
frame = reader.get_data(frame_id)
|
||||
frame = Image.fromarray(frame)
|
||||
frame = self.frame_processor(frame)
|
||||
frames.append(frame)
|
||||
reader.close()
|
||||
return frames
|
||||
|
||||
|
||||
class SequencialProcess(DataProcessingOperator):
|
||||
def __init__(self, operator=lambda x: x):
|
||||
self.operator = operator
|
||||
|
||||
def __call__(self, data):
|
||||
return [self.operator(i) for i in data]
|
||||
|
||||
|
||||
class LoadGIF(DataProcessingOperator):
|
||||
def __init__(self, num_frames=81, time_division_factor=4, time_division_remainder=1, frame_processor=lambda x: x):
|
||||
self.num_frames = num_frames
|
||||
self.time_division_factor = time_division_factor
|
||||
self.time_division_remainder = time_division_remainder
|
||||
# frame_processor is build in the video loader for high efficiency.
|
||||
self.frame_processor = frame_processor
|
||||
|
||||
def get_num_frames(self, path):
|
||||
num_frames = self.num_frames
|
||||
images = iio.imread(path, mode="RGB")
|
||||
if len(images) < num_frames:
|
||||
num_frames = len(images)
|
||||
while num_frames > 1 and num_frames % self.time_division_factor != self.time_division_remainder:
|
||||
num_frames -= 1
|
||||
return num_frames
|
||||
|
||||
def __call__(self, data: str):
|
||||
num_frames = self.get_num_frames(data)
|
||||
frames = []
|
||||
images = iio.imread(data, mode="RGB")
|
||||
for img in images:
|
||||
frame = Image.fromarray(img)
|
||||
frame = self.frame_processor(frame)
|
||||
frames.append(frame)
|
||||
if len(frames) >= num_frames:
|
||||
break
|
||||
return frames
|
||||
|
||||
|
||||
class RouteByExtensionName(DataProcessingOperator):
|
||||
def __init__(self, operator_map):
|
||||
self.operator_map = operator_map
|
||||
|
||||
def __call__(self, data: str):
|
||||
file_ext_name = data.split(".")[-1].lower()
|
||||
for ext_names, operator in self.operator_map:
|
||||
if ext_names is None or file_ext_name in ext_names:
|
||||
return operator(data)
|
||||
raise ValueError(f"Unsupported file: {data}")
|
||||
|
||||
|
||||
class RouteByType(DataProcessingOperator):
|
||||
def __init__(self, operator_map):
|
||||
self.operator_map = operator_map
|
||||
|
||||
def __call__(self, data):
|
||||
for dtype, operator in self.operator_map:
|
||||
if dtype is None or isinstance(data, dtype):
|
||||
return operator(data)
|
||||
raise ValueError(f"Unsupported data: {data}")
|
||||
|
||||
|
||||
class LoadTorchPickle(DataProcessingOperator):
|
||||
def __init__(self, map_location="cpu"):
|
||||
self.map_location = map_location
|
||||
|
||||
def __call__(self, data):
|
||||
return torch.load(data, map_location=self.map_location, weights_only=False)
|
||||
|
||||
|
||||
class ToAbsolutePath(DataProcessingOperator):
|
||||
def __init__(self, base_path=""):
|
||||
self.base_path = base_path
|
||||
|
||||
def __call__(self, data):
|
||||
return os.path.join(self.base_path, data)
|
||||
|
||||
|
||||
class LoadAudio(DataProcessingOperator):
|
||||
def __init__(self, sr=16000):
|
||||
self.sr = sr
|
||||
def __call__(self, data: str):
|
||||
import librosa
|
||||
input_audio, sample_rate = librosa.load(data, sr=self.sr)
|
||||
return input_audio
|
||||
|
||||
|
||||
class LoadAudioWithTorchaudio(DataProcessingOperator):
|
||||
def __init__(self, duration=5):
|
||||
self.duration = duration
|
||||
|
||||
def __call__(self, data: str):
|
||||
import torchaudio
|
||||
waveform, sample_rate = torchaudio.load(data)
|
||||
target_samples = int(self.duration * sample_rate)
|
||||
current_samples = waveform.shape[-1]
|
||||
if current_samples > target_samples:
|
||||
waveform = waveform[..., :target_samples]
|
||||
elif current_samples < target_samples:
|
||||
padding = target_samples - current_samples
|
||||
waveform = torch.nn.functional.pad(waveform, (0, padding))
|
||||
return waveform, sample_rate
|
||||
@@ -1,116 +0,0 @@
|
||||
from .operators import *
|
||||
import torch, json, pandas
|
||||
|
||||
|
||||
class UnifiedDataset(torch.utils.data.Dataset):
|
||||
def __init__(
|
||||
self,
|
||||
base_path=None, metadata_path=None,
|
||||
repeat=1,
|
||||
data_file_keys=tuple(),
|
||||
main_data_operator=lambda x: x,
|
||||
special_operator_map=None,
|
||||
max_data_items=None,
|
||||
):
|
||||
self.base_path = base_path
|
||||
self.metadata_path = metadata_path
|
||||
self.repeat = repeat
|
||||
self.data_file_keys = data_file_keys
|
||||
self.main_data_operator = main_data_operator
|
||||
self.cached_data_operator = LoadTorchPickle()
|
||||
self.special_operator_map = {} if special_operator_map is None else special_operator_map
|
||||
self.max_data_items = max_data_items
|
||||
self.data = []
|
||||
self.cached_data = []
|
||||
self.load_from_cache = metadata_path is None
|
||||
self.load_metadata(metadata_path)
|
||||
|
||||
@staticmethod
|
||||
def default_image_operator(
|
||||
base_path="",
|
||||
max_pixels=1920*1080, height=None, width=None,
|
||||
height_division_factor=16, width_division_factor=16,
|
||||
):
|
||||
return RouteByType(operator_map=[
|
||||
(str, ToAbsolutePath(base_path) >> LoadImage() >> ImageCropAndResize(height, width, max_pixels, height_division_factor, width_division_factor)),
|
||||
(list, SequencialProcess(ToAbsolutePath(base_path) >> LoadImage() >> ImageCropAndResize(height, width, max_pixels, height_division_factor, width_division_factor))),
|
||||
])
|
||||
|
||||
@staticmethod
|
||||
def default_video_operator(
|
||||
base_path="",
|
||||
max_pixels=1920*1080, height=None, width=None,
|
||||
height_division_factor=16, width_division_factor=16,
|
||||
num_frames=81, time_division_factor=4, time_division_remainder=1,
|
||||
):
|
||||
return RouteByType(operator_map=[
|
||||
(str, ToAbsolutePath(base_path) >> RouteByExtensionName(operator_map=[
|
||||
(("jpg", "jpeg", "png", "webp"), LoadImage() >> ImageCropAndResize(height, width, max_pixels, height_division_factor, width_division_factor) >> ToList()),
|
||||
(("gif",), LoadGIF(
|
||||
num_frames, time_division_factor, time_division_remainder,
|
||||
frame_processor=ImageCropAndResize(height, width, max_pixels, height_division_factor, width_division_factor),
|
||||
)),
|
||||
(("mp4", "avi", "mov", "wmv", "mkv", "flv", "webm"), LoadVideo(
|
||||
num_frames, time_division_factor, time_division_remainder,
|
||||
frame_processor=ImageCropAndResize(height, width, max_pixels, height_division_factor, width_division_factor),
|
||||
)),
|
||||
])),
|
||||
])
|
||||
|
||||
def search_for_cached_data_files(self, path):
|
||||
for file_name in os.listdir(path):
|
||||
subpath = os.path.join(path, file_name)
|
||||
if os.path.isdir(subpath):
|
||||
self.search_for_cached_data_files(subpath)
|
||||
elif subpath.endswith(".pth"):
|
||||
self.cached_data.append(subpath)
|
||||
|
||||
def load_metadata(self, metadata_path):
|
||||
if metadata_path is None:
|
||||
print("No metadata_path. Searching for cached data files.")
|
||||
self.search_for_cached_data_files(self.base_path)
|
||||
print(f"{len(self.cached_data)} cached data files found.")
|
||||
elif metadata_path.endswith(".json"):
|
||||
with open(metadata_path, "r") as f:
|
||||
metadata = json.load(f)
|
||||
self.data = metadata
|
||||
elif metadata_path.endswith(".jsonl"):
|
||||
metadata = []
|
||||
with open(metadata_path, 'r') as f:
|
||||
for line in f:
|
||||
metadata.append(json.loads(line.strip()))
|
||||
self.data = metadata
|
||||
else:
|
||||
metadata = pandas.read_csv(metadata_path)
|
||||
self.data = [metadata.iloc[i].to_dict() for i in range(len(metadata))]
|
||||
|
||||
def __getitem__(self, data_id):
|
||||
if self.load_from_cache:
|
||||
data = self.cached_data[data_id % len(self.cached_data)]
|
||||
data = self.cached_data_operator(data)
|
||||
else:
|
||||
data = self.data[data_id % len(self.data)].copy()
|
||||
for key in self.data_file_keys:
|
||||
if key in data:
|
||||
if key in self.special_operator_map:
|
||||
data[key] = self.special_operator_map[key](data[key])
|
||||
elif key in self.data_file_keys:
|
||||
data[key] = self.main_data_operator(data[key])
|
||||
return data
|
||||
|
||||
def __len__(self):
|
||||
if self.max_data_items is not None:
|
||||
return self.max_data_items
|
||||
elif self.load_from_cache:
|
||||
return len(self.cached_data) * self.repeat
|
||||
else:
|
||||
return len(self.data) * self.repeat
|
||||
|
||||
def check_data_equal(self, data1, data2):
|
||||
# Debug only
|
||||
if len(data1) != len(data2):
|
||||
return False
|
||||
for k in data1:
|
||||
if data1[k] != data2[k]:
|
||||
return False
|
||||
return True
|
||||
@@ -1,2 +0,0 @@
|
||||
from .npu_compatible_device import parse_device_type, parse_nccl_backend, get_available_device_type, get_device_name
|
||||
from .npu_compatible_device import IS_NPU_AVAILABLE, IS_CUDA_AVAILABLE
|
||||
@@ -1,107 +0,0 @@
|
||||
import importlib
|
||||
import torch
|
||||
from typing import Any
|
||||
|
||||
|
||||
def is_torch_npu_available():
|
||||
return importlib.util.find_spec("torch_npu") is not None
|
||||
|
||||
|
||||
IS_CUDA_AVAILABLE = torch.cuda.is_available()
|
||||
IS_NPU_AVAILABLE = is_torch_npu_available() and torch.npu.is_available()
|
||||
|
||||
if IS_NPU_AVAILABLE:
|
||||
import torch_npu
|
||||
|
||||
torch.npu.config.allow_internal_format = False
|
||||
|
||||
|
||||
def get_device_type() -> str:
|
||||
"""Get device type based on current machine, currently only support CPU, CUDA, NPU."""
|
||||
if IS_CUDA_AVAILABLE:
|
||||
device = "cuda"
|
||||
elif IS_NPU_AVAILABLE:
|
||||
device = "npu"
|
||||
else:
|
||||
device = "cpu"
|
||||
|
||||
return device
|
||||
|
||||
|
||||
def get_torch_device() -> Any:
|
||||
"""Get torch attribute based on device type, e.g. torch.cuda or torch.npu"""
|
||||
device_name = get_device_type()
|
||||
|
||||
try:
|
||||
return getattr(torch, device_name)
|
||||
except AttributeError:
|
||||
print(f"Device namespace '{device_name}' not found in torch, try to load 'torch.cuda'.")
|
||||
return torch.cuda
|
||||
|
||||
|
||||
def get_device_id() -> int:
|
||||
"""Get current device id based on device type."""
|
||||
return get_torch_device().current_device()
|
||||
|
||||
|
||||
def get_device_name() -> str:
|
||||
"""Get current device name based on device type."""
|
||||
return f"{get_device_type()}:{get_device_id()}"
|
||||
|
||||
|
||||
def synchronize() -> None:
|
||||
"""Execute torch synchronize operation."""
|
||||
get_torch_device().synchronize()
|
||||
|
||||
|
||||
def empty_cache() -> None:
|
||||
"""Execute torch empty cache operation."""
|
||||
get_torch_device().empty_cache()
|
||||
|
||||
|
||||
def get_nccl_backend() -> str:
|
||||
"""Return distributed communication backend type based on device type."""
|
||||
if IS_CUDA_AVAILABLE:
|
||||
return "nccl"
|
||||
elif IS_NPU_AVAILABLE:
|
||||
return "hccl"
|
||||
else:
|
||||
raise RuntimeError(f"No available distributed communication backend found on device type {get_device_type()}.")
|
||||
|
||||
|
||||
def enable_high_precision_for_bf16():
|
||||
"""
|
||||
Set high accumulation dtype for matmul and reduction.
|
||||
"""
|
||||
if IS_CUDA_AVAILABLE:
|
||||
torch.backends.cuda.matmul.allow_tf32 = False
|
||||
torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = False
|
||||
|
||||
if IS_NPU_AVAILABLE:
|
||||
torch.npu.matmul.allow_tf32 = False
|
||||
torch.npu.matmul.allow_bf16_reduced_precision_reduction = False
|
||||
|
||||
|
||||
def parse_device_type(device):
|
||||
if isinstance(device, str):
|
||||
if device.startswith("cuda"):
|
||||
return "cuda"
|
||||
elif device.startswith("npu"):
|
||||
return "npu"
|
||||
else:
|
||||
return "cpu"
|
||||
elif isinstance(device, torch.device):
|
||||
return device.type
|
||||
|
||||
|
||||
def parse_nccl_backend(device_type):
|
||||
if device_type == "cuda":
|
||||
return "nccl"
|
||||
elif device_type == "npu":
|
||||
return "hccl"
|
||||
else:
|
||||
raise RuntimeError(f"No available distributed communication backend found on device type {device_type}.")
|
||||
|
||||
|
||||
def get_available_device_type():
|
||||
return get_device_type()
|
||||
@@ -1 +0,0 @@
|
||||
from .gradient_checkpoint import gradient_checkpoint_forward
|
||||
@@ -1,34 +0,0 @@
|
||||
import torch
|
||||
|
||||
|
||||
def create_custom_forward(module):
|
||||
def custom_forward(*inputs, **kwargs):
|
||||
return module(*inputs, **kwargs)
|
||||
return custom_forward
|
||||
|
||||
|
||||
def gradient_checkpoint_forward(
|
||||
model,
|
||||
use_gradient_checkpointing,
|
||||
use_gradient_checkpointing_offload,
|
||||
*args,
|
||||
**kwargs,
|
||||
):
|
||||
if use_gradient_checkpointing_offload:
|
||||
with torch.autograd.graph.save_on_cpu():
|
||||
model_output = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(model),
|
||||
*args,
|
||||
**kwargs,
|
||||
use_reentrant=False,
|
||||
)
|
||||
elif use_gradient_checkpointing:
|
||||
model_output = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(model),
|
||||
*args,
|
||||
**kwargs,
|
||||
use_reentrant=False,
|
||||
)
|
||||
else:
|
||||
model_output = model(*args, **kwargs)
|
||||
return model_output
|
||||
@@ -1,3 +0,0 @@
|
||||
from .file import load_state_dict, hash_state_dict_keys, hash_model_file
|
||||
from .model import load_model, load_model_with_disk_offload
|
||||
from .config import ModelConfig
|
||||
@@ -1,119 +0,0 @@
|
||||
import torch, glob, os
|
||||
from typing import Optional, Union, Dict
|
||||
from dataclasses import dataclass
|
||||
from modelscope import snapshot_download
|
||||
from huggingface_hub import snapshot_download as hf_snapshot_download
|
||||
from typing import Optional
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelConfig:
|
||||
path: Union[str, list[str]] = None
|
||||
model_id: str = None
|
||||
origin_file_pattern: Union[str, list[str]] = None
|
||||
download_source: str = None
|
||||
local_model_path: str = None
|
||||
skip_download: bool = None
|
||||
offload_device: Optional[Union[str, torch.device]] = None
|
||||
offload_dtype: Optional[torch.dtype] = None
|
||||
onload_device: Optional[Union[str, torch.device]] = None
|
||||
onload_dtype: Optional[torch.dtype] = None
|
||||
preparing_device: Optional[Union[str, torch.device]] = None
|
||||
preparing_dtype: Optional[torch.dtype] = None
|
||||
computation_device: Optional[Union[str, torch.device]] = None
|
||||
computation_dtype: Optional[torch.dtype] = None
|
||||
clear_parameters: bool = False
|
||||
state_dict: Dict[str, torch.Tensor] = None
|
||||
|
||||
def check_input(self):
|
||||
if self.path is None and self.model_id is None:
|
||||
raise ValueError(f"""No valid model files. Please use `ModelConfig(path="xxx")` or `ModelConfig(model_id="xxx/yyy", origin_file_pattern="zzz")`. `skip_download=True` only supports the first one.""")
|
||||
|
||||
def parse_original_file_pattern(self):
|
||||
if self.origin_file_pattern in [None, "", "./"]:
|
||||
return "*"
|
||||
elif self.origin_file_pattern.endswith("/"):
|
||||
return self.origin_file_pattern + "*"
|
||||
else:
|
||||
return self.origin_file_pattern
|
||||
|
||||
def parse_download_source(self):
|
||||
if self.download_source is None:
|
||||
if os.environ.get('DIFFSYNTH_DOWNLOAD_SOURCE') is not None:
|
||||
return os.environ.get('DIFFSYNTH_DOWNLOAD_SOURCE')
|
||||
else:
|
||||
return "modelscope"
|
||||
else:
|
||||
return self.download_source
|
||||
|
||||
def parse_skip_download(self):
|
||||
if self.skip_download is None:
|
||||
if os.environ.get('DIFFSYNTH_SKIP_DOWNLOAD') is not None:
|
||||
if os.environ.get('DIFFSYNTH_SKIP_DOWNLOAD').lower() == "true":
|
||||
return True
|
||||
elif os.environ.get('DIFFSYNTH_SKIP_DOWNLOAD').lower() == "false":
|
||||
return False
|
||||
else:
|
||||
return False
|
||||
else:
|
||||
return self.skip_download
|
||||
|
||||
def download(self):
|
||||
origin_file_pattern = self.parse_original_file_pattern()
|
||||
downloaded_files = glob.glob(origin_file_pattern, root_dir=os.path.join(self.local_model_path, self.model_id))
|
||||
download_source = self.parse_download_source()
|
||||
if download_source.lower() == "modelscope":
|
||||
snapshot_download(
|
||||
self.model_id,
|
||||
local_dir=os.path.join(self.local_model_path, self.model_id),
|
||||
allow_file_pattern=origin_file_pattern,
|
||||
ignore_file_pattern=downloaded_files,
|
||||
local_files_only=False
|
||||
)
|
||||
elif download_source.lower() == "huggingface":
|
||||
hf_snapshot_download(
|
||||
self.model_id,
|
||||
local_dir=os.path.join(self.local_model_path, self.model_id),
|
||||
allow_patterns=origin_file_pattern,
|
||||
ignore_patterns=downloaded_files,
|
||||
local_files_only=False
|
||||
)
|
||||
else:
|
||||
raise ValueError("`download_source` should be `modelscope` or `huggingface`.")
|
||||
|
||||
def require_downloading(self):
|
||||
if self.path is not None:
|
||||
return False
|
||||
skip_download = self.parse_skip_download()
|
||||
return not skip_download
|
||||
|
||||
def reset_local_model_path(self):
|
||||
if os.environ.get('DIFFSYNTH_MODEL_BASE_PATH') is not None:
|
||||
self.local_model_path = os.environ.get('DIFFSYNTH_MODEL_BASE_PATH')
|
||||
elif self.local_model_path is None:
|
||||
self.local_model_path = "./models"
|
||||
|
||||
def download_if_necessary(self):
|
||||
self.check_input()
|
||||
self.reset_local_model_path()
|
||||
if self.require_downloading():
|
||||
self.download()
|
||||
if self.path is None:
|
||||
if self.origin_file_pattern in [None, "", "./"]:
|
||||
self.path = os.path.join(self.local_model_path, self.model_id)
|
||||
else:
|
||||
self.path = glob.glob(os.path.join(self.local_model_path, self.model_id, self.origin_file_pattern))
|
||||
if isinstance(self.path, list) and len(self.path) == 1:
|
||||
self.path = self.path[0]
|
||||
|
||||
def vram_config(self):
|
||||
return {
|
||||
"offload_device": self.offload_device,
|
||||
"offload_dtype": self.offload_dtype,
|
||||
"onload_device": self.onload_device,
|
||||
"onload_dtype": self.onload_dtype,
|
||||
"preparing_device": self.preparing_device,
|
||||
"preparing_dtype": self.preparing_dtype,
|
||||
"computation_device": self.computation_device,
|
||||
"computation_dtype": self.computation_dtype,
|
||||
}
|
||||
@@ -1,130 +0,0 @@
|
||||
from safetensors import safe_open
|
||||
import torch, hashlib
|
||||
|
||||
|
||||
def load_state_dict(file_path, torch_dtype=None, device="cpu", pin_memory=False, verbose=0):
|
||||
if isinstance(file_path, list):
|
||||
state_dict = {}
|
||||
for file_path_ in file_path:
|
||||
state_dict.update(load_state_dict(file_path_, torch_dtype, device, pin_memory=pin_memory, verbose=verbose))
|
||||
else:
|
||||
if verbose >= 1:
|
||||
print(f"Loading file [started]: {file_path}")
|
||||
if file_path.endswith(".safetensors"):
|
||||
state_dict = load_state_dict_from_safetensors(file_path, torch_dtype=torch_dtype, device=device)
|
||||
else:
|
||||
state_dict = load_state_dict_from_bin(file_path, torch_dtype=torch_dtype, device=device)
|
||||
# If load state dict in CPU memory, `pin_memory=True` will make `model.to("cuda")` faster.
|
||||
if pin_memory:
|
||||
for i in state_dict:
|
||||
state_dict[i] = state_dict[i].pin_memory()
|
||||
if verbose >= 1:
|
||||
print(f"Loading file [done]: {file_path}")
|
||||
return state_dict
|
||||
|
||||
|
||||
def load_state_dict_from_safetensors(file_path, torch_dtype=None, device="cpu"):
|
||||
state_dict = {}
|
||||
with safe_open(file_path, framework="pt", device=str(device)) as f:
|
||||
for k in f.keys():
|
||||
state_dict[k] = f.get_tensor(k)
|
||||
if torch_dtype is not None:
|
||||
state_dict[k] = state_dict[k].to(torch_dtype)
|
||||
return state_dict
|
||||
|
||||
|
||||
def load_state_dict_from_bin(file_path, torch_dtype=None, device="cpu"):
|
||||
state_dict = torch.load(file_path, map_location=device, weights_only=True)
|
||||
if len(state_dict) == 1:
|
||||
if "state_dict" in state_dict:
|
||||
state_dict = state_dict["state_dict"]
|
||||
elif "module" in state_dict:
|
||||
state_dict = state_dict["module"]
|
||||
elif "model_state" in state_dict:
|
||||
state_dict = state_dict["model_state"]
|
||||
if torch_dtype is not None:
|
||||
for i in state_dict:
|
||||
if isinstance(state_dict[i], torch.Tensor):
|
||||
state_dict[i] = state_dict[i].to(torch_dtype)
|
||||
return state_dict
|
||||
|
||||
|
||||
def convert_state_dict_keys_to_single_str(state_dict, with_shape=True):
|
||||
keys = []
|
||||
for key, value in state_dict.items():
|
||||
if isinstance(key, str):
|
||||
if isinstance(value, torch.Tensor):
|
||||
if with_shape:
|
||||
shape = "_".join(map(str, list(value.shape)))
|
||||
keys.append(key + ":" + shape)
|
||||
keys.append(key)
|
||||
elif isinstance(value, dict):
|
||||
keys.append(key + "|" + convert_state_dict_keys_to_single_str(value, with_shape=with_shape))
|
||||
keys.sort()
|
||||
keys_str = ",".join(keys)
|
||||
return keys_str
|
||||
|
||||
|
||||
def hash_state_dict_keys(state_dict, with_shape=True):
|
||||
keys_str = convert_state_dict_keys_to_single_str(state_dict, with_shape=with_shape)
|
||||
keys_str = keys_str.encode(encoding="UTF-8")
|
||||
return hashlib.md5(keys_str).hexdigest()
|
||||
|
||||
|
||||
def load_keys_dict(file_path):
|
||||
if isinstance(file_path, list):
|
||||
state_dict = {}
|
||||
for file_path_ in file_path:
|
||||
state_dict.update(load_keys_dict(file_path_))
|
||||
return state_dict
|
||||
if file_path.endswith(".safetensors"):
|
||||
return load_keys_dict_from_safetensors(file_path)
|
||||
else:
|
||||
return load_keys_dict_from_bin(file_path)
|
||||
|
||||
|
||||
def load_keys_dict_from_safetensors(file_path):
|
||||
keys_dict = {}
|
||||
with safe_open(file_path, framework="pt", device="cpu") as f:
|
||||
for k in f.keys():
|
||||
keys_dict[k] = f.get_slice(k).get_shape()
|
||||
return keys_dict
|
||||
|
||||
|
||||
def convert_state_dict_to_keys_dict(state_dict):
|
||||
keys_dict = {}
|
||||
for k, v in state_dict.items():
|
||||
if isinstance(v, torch.Tensor):
|
||||
keys_dict[k] = list(v.shape)
|
||||
else:
|
||||
keys_dict[k] = convert_state_dict_to_keys_dict(v)
|
||||
return keys_dict
|
||||
|
||||
|
||||
def load_keys_dict_from_bin(file_path):
|
||||
state_dict = load_state_dict_from_bin(file_path)
|
||||
keys_dict = convert_state_dict_to_keys_dict(state_dict)
|
||||
return keys_dict
|
||||
|
||||
|
||||
def convert_keys_dict_to_single_str(state_dict, with_shape=True):
|
||||
keys = []
|
||||
for key, value in state_dict.items():
|
||||
if isinstance(key, str):
|
||||
if isinstance(value, dict):
|
||||
keys.append(key + "|" + convert_keys_dict_to_single_str(value, with_shape=with_shape))
|
||||
else:
|
||||
if with_shape:
|
||||
shape = "_".join(map(str, list(value)))
|
||||
keys.append(key + ":" + shape)
|
||||
keys.append(key)
|
||||
keys.sort()
|
||||
keys_str = ",".join(keys)
|
||||
return keys_str
|
||||
|
||||
|
||||
def hash_model_file(path, with_shape=True):
|
||||
keys_dict = load_keys_dict(path)
|
||||
keys_str = convert_keys_dict_to_single_str(keys_dict, with_shape=with_shape)
|
||||
keys_str = keys_str.encode(encoding="UTF-8")
|
||||
return hashlib.md5(keys_str).hexdigest()
|
||||
@@ -1,105 +0,0 @@
|
||||
from ..vram.initialization import skip_model_initialization
|
||||
from ..vram.disk_map import DiskMap
|
||||
from ..vram.layers import enable_vram_management
|
||||
from .file import load_state_dict
|
||||
import torch
|
||||
from contextlib import contextmanager
|
||||
from transformers.integrations import is_deepspeed_zero3_enabled
|
||||
from transformers.utils import ContextManagers
|
||||
|
||||
|
||||
def load_model(model_class, path, config=None, torch_dtype=torch.bfloat16, device="cpu", state_dict_converter=None, use_disk_map=False, module_map=None, vram_config=None, vram_limit=None, state_dict=None):
|
||||
config = {} if config is None else config
|
||||
with ContextManagers(get_init_context(torch_dtype=torch_dtype, device=device)):
|
||||
model = model_class(**config)
|
||||
# What is `module_map`?
|
||||
# This is a module mapping table for VRAM management.
|
||||
if module_map is not None:
|
||||
devices = [vram_config["offload_device"], vram_config["onload_device"], vram_config["preparing_device"], vram_config["computation_device"]]
|
||||
device = [d for d in devices if d != "disk"][0]
|
||||
dtypes = [vram_config["offload_dtype"], vram_config["onload_dtype"], vram_config["preparing_dtype"], vram_config["computation_dtype"]]
|
||||
dtype = [d for d in dtypes if d != "disk"][0]
|
||||
if vram_config["offload_device"] != "disk":
|
||||
if state_dict is None: state_dict = DiskMap(path, device, torch_dtype=dtype)
|
||||
if state_dict_converter is not None:
|
||||
state_dict = state_dict_converter(state_dict)
|
||||
else:
|
||||
state_dict = {i: state_dict[i] for i in state_dict}
|
||||
model.load_state_dict(state_dict, assign=True)
|
||||
model = enable_vram_management(model, module_map, vram_config=vram_config, disk_map=None, vram_limit=vram_limit)
|
||||
else:
|
||||
disk_map = DiskMap(path, device, state_dict_converter=state_dict_converter)
|
||||
model = enable_vram_management(model, module_map, vram_config=vram_config, disk_map=disk_map, vram_limit=vram_limit)
|
||||
else:
|
||||
# Why do we use `DiskMap`?
|
||||
# Sometimes a model file contains multiple models,
|
||||
# and DiskMap can load only the parameters of a single model,
|
||||
# avoiding the need to load all parameters in the file.
|
||||
if state_dict is not None:
|
||||
pass
|
||||
elif use_disk_map:
|
||||
state_dict = DiskMap(path, device, torch_dtype=torch_dtype)
|
||||
else:
|
||||
state_dict = load_state_dict(path, torch_dtype, device)
|
||||
# Why do we use `state_dict_converter`?
|
||||
# Some models are saved in complex formats,
|
||||
# and we need to convert the state dict into the appropriate format.
|
||||
if state_dict_converter is not None:
|
||||
state_dict = state_dict_converter(state_dict)
|
||||
else:
|
||||
state_dict = {i: state_dict[i] for i in state_dict}
|
||||
# Why does DeepSpeed ZeRO Stage 3 need to be handled separately?
|
||||
# Because at this stage, model parameters are partitioned across multiple GPUs.
|
||||
# Loading them directly could lead to excessive GPU memory consumption.
|
||||
if is_deepspeed_zero3_enabled():
|
||||
from transformers.integrations.deepspeed import _load_state_dict_into_zero3_model
|
||||
_load_state_dict_into_zero3_model(model, state_dict)
|
||||
else:
|
||||
model.load_state_dict(state_dict, assign=True)
|
||||
# Why do we call `to()`?
|
||||
# Because some models override the behavior of `to()`,
|
||||
# especially those from libraries like Transformers.
|
||||
model = model.to(dtype=torch_dtype, device=device)
|
||||
if hasattr(model, "eval"):
|
||||
model = model.eval()
|
||||
return model
|
||||
|
||||
|
||||
def load_model_with_disk_offload(model_class, path, config=None, torch_dtype=torch.bfloat16, device="cpu", state_dict_converter=None, module_map=None):
|
||||
if isinstance(path, str):
|
||||
path = [path]
|
||||
config = {} if config is None else config
|
||||
with skip_model_initialization():
|
||||
model = model_class(**config)
|
||||
if hasattr(model, "eval"):
|
||||
model = model.eval()
|
||||
disk_map = DiskMap(path, device, state_dict_converter=state_dict_converter)
|
||||
vram_config = {
|
||||
"offload_dtype": "disk",
|
||||
"offload_device": "disk",
|
||||
"onload_dtype": "disk",
|
||||
"onload_device": "disk",
|
||||
"preparing_dtype": torch.float8_e4m3fn,
|
||||
"preparing_device": device,
|
||||
"computation_dtype": torch_dtype,
|
||||
"computation_device": device,
|
||||
}
|
||||
enable_vram_management(model, module_map, vram_config=vram_config, disk_map=disk_map, vram_limit=80)
|
||||
return model
|
||||
|
||||
|
||||
def get_init_context(torch_dtype, device):
|
||||
if is_deepspeed_zero3_enabled():
|
||||
from transformers.modeling_utils import set_zero3_state
|
||||
import deepspeed
|
||||
# Why do we use "deepspeed.zero.Init"?
|
||||
# Weight segmentation of the model can be performed on the CPU side
|
||||
# and loading the segmented weights onto the computing card
|
||||
init_contexts = [deepspeed.zero.Init(remote_device=device, dtype=torch_dtype), set_zero3_state()]
|
||||
else:
|
||||
# Why do we use `skip_model_initialization`?
|
||||
# It skips the random initialization of model parameters,
|
||||
# thereby speeding up model loading and avoiding excessive memory usage.
|
||||
init_contexts = [skip_model_initialization()]
|
||||
|
||||
return init_contexts
|
||||
@@ -1,30 +0,0 @@
|
||||
import torch
|
||||
from ..device.npu_compatible_device import get_device_type
|
||||
try:
|
||||
import torch_npu
|
||||
except:
|
||||
pass
|
||||
|
||||
|
||||
def rms_norm_forward_npu(self, hidden_states):
|
||||
"npu rms fused operator for RMSNorm.forward from diffsynth\models\general_modules.py"
|
||||
if hidden_states.dtype != self.weight.dtype:
|
||||
hidden_states = hidden_states.to(self.weight.dtype)
|
||||
return torch_npu.npu_rms_norm(hidden_states, self.weight, self.eps)[0]
|
||||
|
||||
|
||||
def rms_norm_forward_transformers_npu(self, hidden_states):
|
||||
"npu rms fused operator for transformers"
|
||||
if hidden_states.dtype != self.weight.dtype:
|
||||
hidden_states = hidden_states.to(self.weight.dtype)
|
||||
return torch_npu.npu_rms_norm(hidden_states, self.weight, self.variance_epsilon)[0]
|
||||
|
||||
|
||||
def rotary_emb_Zimage_npu(self, x_in: torch.Tensor, freqs_cis: torch.Tensor):
|
||||
"npu rope fused operator for Zimage"
|
||||
with torch.amp.autocast(get_device_type(), enabled=False):
|
||||
freqs_cis = freqs_cis.unsqueeze(2)
|
||||
cos, sin = torch.chunk(torch.view_as_real(freqs_cis), 2, dim=-1)
|
||||
cos = cos.expand(-1, -1, -1, -1, 2).flatten(-2)
|
||||
sin = sin.expand(-1, -1, -1, -1, 2).flatten(-2)
|
||||
return torch_npu.npu_rotary_mul(x_in, cos, sin, rotary_mode="interleave").to(x_in)
|
||||
@@ -1,2 +0,0 @@
|
||||
from .initialization import skip_model_initialization
|
||||
from .layers import *
|
||||
@@ -1,93 +0,0 @@
|
||||
from safetensors import safe_open
|
||||
import torch, os
|
||||
|
||||
|
||||
class SafetensorsCompatibleTensor:
|
||||
def __init__(self, tensor):
|
||||
self.tensor = tensor
|
||||
|
||||
def get_shape(self):
|
||||
return list(self.tensor.shape)
|
||||
|
||||
|
||||
class SafetensorsCompatibleBinaryLoader:
|
||||
def __init__(self, path, device):
|
||||
print("Detected non-safetensors files, which may cause slower loading. It's recommended to convert it to a safetensors file.")
|
||||
self.state_dict = torch.load(path, weights_only=True, map_location=device)
|
||||
|
||||
def keys(self):
|
||||
return self.state_dict.keys()
|
||||
|
||||
def get_tensor(self, name):
|
||||
return self.state_dict[name]
|
||||
|
||||
def get_slice(self, name):
|
||||
return SafetensorsCompatibleTensor(self.state_dict[name])
|
||||
|
||||
|
||||
class DiskMap:
|
||||
|
||||
def __init__(self, path, device, torch_dtype=None, state_dict_converter=None, buffer_size=10**9):
|
||||
self.path = path if isinstance(path, list) else [path]
|
||||
self.device = device
|
||||
self.torch_dtype = torch_dtype
|
||||
if os.environ.get('DIFFSYNTH_DISK_MAP_BUFFER_SIZE') is not None:
|
||||
self.buffer_size = int(os.environ.get('DIFFSYNTH_DISK_MAP_BUFFER_SIZE'))
|
||||
else:
|
||||
self.buffer_size = buffer_size
|
||||
self.files = []
|
||||
self.flush_files()
|
||||
self.name_map = {}
|
||||
for file_id, file in enumerate(self.files):
|
||||
for name in file.keys():
|
||||
self.name_map[name] = file_id
|
||||
self.rename_dict = self.fetch_rename_dict(state_dict_converter)
|
||||
|
||||
def flush_files(self):
|
||||
if len(self.files) == 0:
|
||||
for path in self.path:
|
||||
if path.endswith(".safetensors"):
|
||||
self.files.append(safe_open(path, framework="pt", device=str(self.device)))
|
||||
else:
|
||||
self.files.append(SafetensorsCompatibleBinaryLoader(path, device=self.device))
|
||||
else:
|
||||
for i, path in enumerate(self.path):
|
||||
if path.endswith(".safetensors"):
|
||||
self.files[i] = safe_open(path, framework="pt", device=str(self.device))
|
||||
self.num_params = 0
|
||||
|
||||
def __getitem__(self, name):
|
||||
if self.rename_dict is not None: name = self.rename_dict[name]
|
||||
file_id = self.name_map[name]
|
||||
param = self.files[file_id].get_tensor(name)
|
||||
if self.torch_dtype is not None and isinstance(param, torch.Tensor):
|
||||
param = param.to(self.torch_dtype)
|
||||
if isinstance(param, torch.Tensor) and param.device == "cpu":
|
||||
param = param.clone()
|
||||
if isinstance(param, torch.Tensor):
|
||||
self.num_params += param.numel()
|
||||
if self.num_params > self.buffer_size:
|
||||
self.flush_files()
|
||||
return param
|
||||
|
||||
def fetch_rename_dict(self, state_dict_converter):
|
||||
if state_dict_converter is None:
|
||||
return None
|
||||
state_dict = {}
|
||||
for file in self.files:
|
||||
for name in file.keys():
|
||||
state_dict[name] = name
|
||||
state_dict = state_dict_converter(state_dict)
|
||||
return state_dict
|
||||
|
||||
def __iter__(self):
|
||||
if self.rename_dict is not None:
|
||||
return self.rename_dict.__iter__()
|
||||
else:
|
||||
return self.name_map.__iter__()
|
||||
|
||||
def __contains__(self, x):
|
||||
if self.rename_dict is not None:
|
||||
return x in self.rename_dict
|
||||
else:
|
||||
return x in self.name_map
|
||||
@@ -1,21 +0,0 @@
|
||||
import torch
|
||||
from contextlib import contextmanager
|
||||
|
||||
|
||||
@contextmanager
|
||||
def skip_model_initialization(device=torch.device("meta")):
|
||||
|
||||
def register_empty_parameter(module, name, param):
|
||||
old_register_parameter(module, name, param)
|
||||
if param is not None:
|
||||
param_cls = type(module._parameters[name])
|
||||
kwargs = module._parameters[name].__dict__
|
||||
kwargs["requires_grad"] = param.requires_grad
|
||||
module._parameters[name] = param_cls(module._parameters[name].to(device), **kwargs)
|
||||
|
||||
old_register_parameter = torch.nn.Module.register_parameter
|
||||
torch.nn.Module.register_parameter = register_empty_parameter
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
torch.nn.Module.register_parameter = old_register_parameter
|
||||
@@ -1,479 +0,0 @@
|
||||
import torch, copy
|
||||
from typing import Union
|
||||
from .initialization import skip_model_initialization
|
||||
from .disk_map import DiskMap
|
||||
from ..device import parse_device_type, get_device_name, IS_NPU_AVAILABLE
|
||||
|
||||
|
||||
class AutoTorchModule(torch.nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
offload_dtype: torch.dtype = None,
|
||||
offload_device: Union[str, torch.device] = None,
|
||||
onload_dtype: torch.dtype = None,
|
||||
onload_device: Union[str, torch.device] = None,
|
||||
preparing_dtype: torch.dtype = None,
|
||||
preparing_device: Union[str, torch.device] = None,
|
||||
computation_dtype: torch.dtype = None,
|
||||
computation_device: Union[str, torch.device] = None,
|
||||
vram_limit: float = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.set_dtype_and_device(
|
||||
offload_dtype,
|
||||
offload_device,
|
||||
onload_dtype,
|
||||
onload_device,
|
||||
preparing_dtype,
|
||||
preparing_device,
|
||||
computation_dtype,
|
||||
computation_device,
|
||||
vram_limit,
|
||||
)
|
||||
self.state = 0
|
||||
self.name = ""
|
||||
self.computation_device_type = parse_device_type(self.computation_device)
|
||||
|
||||
def set_dtype_and_device(
|
||||
self,
|
||||
offload_dtype: torch.dtype = None,
|
||||
offload_device: Union[str, torch.device] = None,
|
||||
onload_dtype: torch.dtype = None,
|
||||
onload_device: Union[str, torch.device] = None,
|
||||
preparing_dtype: torch.dtype = None,
|
||||
preparing_device: Union[str, torch.device] = None,
|
||||
computation_dtype: torch.dtype = None,
|
||||
computation_device: Union[str, torch.device] = None,
|
||||
vram_limit: float = None,
|
||||
):
|
||||
self.offload_dtype = offload_dtype or computation_dtype
|
||||
self.offload_device = offload_device or computation_dtype
|
||||
self.onload_dtype = onload_dtype or computation_dtype
|
||||
self.onload_device = onload_device or computation_dtype
|
||||
self.preparing_dtype = preparing_dtype or computation_dtype
|
||||
self.preparing_device = preparing_device or computation_dtype
|
||||
self.computation_dtype = computation_dtype
|
||||
self.computation_device = computation_device
|
||||
self.vram_limit = vram_limit
|
||||
|
||||
def cast_to(self, weight, dtype, device):
|
||||
r = torch.empty_like(weight, dtype=dtype, device=device)
|
||||
r.copy_(weight)
|
||||
return r
|
||||
|
||||
def check_free_vram(self):
|
||||
device = self.computation_device if not IS_NPU_AVAILABLE else get_device_name()
|
||||
gpu_mem_state = getattr(torch, self.computation_device_type).mem_get_info(device)
|
||||
used_memory = (gpu_mem_state[1] - gpu_mem_state[0]) / (1024**3)
|
||||
return used_memory < self.vram_limit
|
||||
|
||||
def offload(self):
|
||||
if self.state != 0:
|
||||
self.to(dtype=self.offload_dtype, device=self.offload_device)
|
||||
self.state = 0
|
||||
|
||||
def onload(self):
|
||||
if self.state != 1:
|
||||
self.to(dtype=self.onload_dtype, device=self.onload_device)
|
||||
self.state = 1
|
||||
|
||||
def param_name(self, name):
|
||||
if self.name == "":
|
||||
return name
|
||||
else:
|
||||
return self.name + "." + name
|
||||
|
||||
|
||||
class AutoWrappedModule(AutoTorchModule):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
module: torch.nn.Module,
|
||||
offload_dtype: torch.dtype = None,
|
||||
offload_device: Union[str, torch.device] = None,
|
||||
onload_dtype: torch.dtype = None,
|
||||
onload_device: Union[str, torch.device] = None,
|
||||
preparing_dtype: torch.dtype = None,
|
||||
preparing_device: Union[str, torch.device] = None,
|
||||
computation_dtype: torch.dtype = None,
|
||||
computation_device: Union[str, torch.device] = None,
|
||||
vram_limit: float = None,
|
||||
name: str = "",
|
||||
disk_map: DiskMap = None,
|
||||
**kwargs
|
||||
):
|
||||
super().__init__(
|
||||
offload_dtype,
|
||||
offload_device,
|
||||
onload_dtype,
|
||||
onload_device,
|
||||
preparing_dtype,
|
||||
preparing_device,
|
||||
computation_dtype,
|
||||
computation_device,
|
||||
vram_limit,
|
||||
)
|
||||
self.module = module
|
||||
if offload_dtype == "disk":
|
||||
self.name = name
|
||||
self.disk_map = disk_map
|
||||
self.required_params = [name for name, _ in self.module.named_parameters()]
|
||||
self.disk_offload = True
|
||||
else:
|
||||
self.disk_offload = False
|
||||
|
||||
def load_from_disk(self, torch_dtype, device, copy_module=False):
|
||||
if copy_module:
|
||||
module = copy.deepcopy(self.module)
|
||||
else:
|
||||
module = self.module
|
||||
state_dict = {}
|
||||
for name in self.required_params:
|
||||
param = self.disk_map[self.param_name(name)]
|
||||
param = param.to(dtype=torch_dtype, device=device)
|
||||
state_dict[name] = param
|
||||
module.load_state_dict(state_dict, assign=True)
|
||||
module.to(dtype=torch_dtype, device=device)
|
||||
return module
|
||||
|
||||
def offload_to_disk(self, model: torch.nn.Module):
|
||||
for buf in model.buffers():
|
||||
# If there are some parameters are registed in buffers (not in state dict),
|
||||
# We cannot offload the model.
|
||||
for children in model.children():
|
||||
self.offload_to_disk(children)
|
||||
break
|
||||
else:
|
||||
model.to("meta")
|
||||
|
||||
def offload(self):
|
||||
# offload / onload / preparing -> offload
|
||||
if self.state != 0:
|
||||
if self.disk_offload:
|
||||
self.offload_to_disk(self.module)
|
||||
else:
|
||||
self.to(dtype=self.offload_dtype, device=self.offload_device)
|
||||
self.state = 0
|
||||
|
||||
def onload(self):
|
||||
# offload / onload / preparing -> onload
|
||||
if self.state < 1:
|
||||
if self.disk_offload and self.onload_device != "disk" and self.offload_device == "disk":
|
||||
self.load_from_disk(self.onload_dtype, self.onload_device)
|
||||
elif self.onload_device != "disk":
|
||||
self.to(dtype=self.onload_dtype, device=self.onload_device)
|
||||
self.state = 1
|
||||
|
||||
def preparing(self):
|
||||
# onload / preparing -> preparing
|
||||
if self.state != 2:
|
||||
if self.disk_offload and self.preparing_device != "disk" and self.onload_device == "disk":
|
||||
self.load_from_disk(self.preparing_dtype, self.preparing_device)
|
||||
elif self.preparing_device != "disk":
|
||||
self.to(dtype=self.preparing_dtype, device=self.preparing_device)
|
||||
self.state = 2
|
||||
|
||||
def cast_to(self, module, dtype, device):
|
||||
return copy.deepcopy(module).to(dtype=dtype, device=device)
|
||||
|
||||
def computation(self):
|
||||
# onload / preparing -> computation (temporary)
|
||||
if self.state == 2:
|
||||
torch_dtype, device = self.preparing_dtype, self.preparing_device
|
||||
else:
|
||||
torch_dtype, device = self.onload_dtype, self.onload_device
|
||||
if torch_dtype == self.computation_dtype and device == self.computation_device:
|
||||
module = self.module
|
||||
elif self.disk_offload and device == "disk":
|
||||
module = self.load_from_disk(self.computation_dtype, self.computation_device, copy_module=True)
|
||||
else:
|
||||
module = self.cast_to(self.module, dtype=self.computation_dtype, device=self.computation_device)
|
||||
return module
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
if self.state == 1 and (self.vram_limit is None or self.check_free_vram()):
|
||||
self.preparing()
|
||||
module = self.computation()
|
||||
return module(*args, **kwargs)
|
||||
|
||||
def __getattr__(self, name):
|
||||
if name in self.__dict__ or name == "module":
|
||||
return super().__getattr__(name)
|
||||
else:
|
||||
return getattr(self.module, name)
|
||||
|
||||
|
||||
class AutoWrappedNonRecurseModule(AutoWrappedModule):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
module: torch.nn.Module,
|
||||
offload_dtype: torch.dtype = None,
|
||||
offload_device: Union[str, torch.device] = None,
|
||||
onload_dtype: torch.dtype = None,
|
||||
onload_device: Union[str, torch.device] = None,
|
||||
preparing_dtype: torch.dtype = None,
|
||||
preparing_device: Union[str, torch.device] = None,
|
||||
computation_dtype: torch.dtype = None,
|
||||
computation_device: Union[str, torch.device] = None,
|
||||
vram_limit: float = None,
|
||||
name: str = "",
|
||||
disk_map: DiskMap = None,
|
||||
**kwargs
|
||||
):
|
||||
super().__init__(
|
||||
module,
|
||||
offload_dtype,
|
||||
offload_device,
|
||||
onload_dtype,
|
||||
onload_device,
|
||||
preparing_dtype,
|
||||
preparing_device,
|
||||
computation_dtype,
|
||||
computation_device,
|
||||
vram_limit,
|
||||
name,
|
||||
disk_map,
|
||||
**kwargs
|
||||
)
|
||||
if self.disk_offload:
|
||||
self.required_params = [name for name, _ in self.module.named_parameters(recurse=False)]
|
||||
|
||||
def load_from_disk(self, torch_dtype, device, copy_module=False):
|
||||
if copy_module:
|
||||
module = copy.deepcopy(self.module)
|
||||
else:
|
||||
module = self.module
|
||||
state_dict = {}
|
||||
for name in self.required_params:
|
||||
param = self.disk_map[self.param_name(name)]
|
||||
param = param.to(dtype=torch_dtype, device=device)
|
||||
state_dict[name] = param
|
||||
module.load_state_dict(state_dict, assign=True, strict=False)
|
||||
return module
|
||||
|
||||
def offload_to_disk(self, model: torch.nn.Module):
|
||||
for name in self.required_params:
|
||||
getattr(self, name).to("meta")
|
||||
|
||||
def cast_to(self, module, dtype, device):
|
||||
# Parameter casting is implemented in the model architecture.
|
||||
return module
|
||||
|
||||
def __getattr__(self, name):
|
||||
if name in self.__dict__ or name == "module":
|
||||
return super().__getattr__(name)
|
||||
else:
|
||||
return getattr(self.module, name)
|
||||
|
||||
|
||||
class AutoWrappedLinear(torch.nn.Linear, AutoTorchModule):
|
||||
def __init__(
|
||||
self,
|
||||
module: torch.nn.Linear,
|
||||
offload_dtype: torch.dtype = None,
|
||||
offload_device: Union[str, torch.device] = None,
|
||||
onload_dtype: torch.dtype = None,
|
||||
onload_device: Union[str, torch.device] = None,
|
||||
preparing_dtype: torch.dtype = None,
|
||||
preparing_device: Union[str, torch.device] = None,
|
||||
computation_dtype: torch.dtype = None,
|
||||
computation_device: Union[str, torch.device] = None,
|
||||
vram_limit: float = None,
|
||||
name: str = "",
|
||||
disk_map: DiskMap = None,
|
||||
**kwargs
|
||||
):
|
||||
with skip_model_initialization():
|
||||
super().__init__(
|
||||
in_features=module.in_features,
|
||||
out_features=module.out_features,
|
||||
bias=module.bias is not None,
|
||||
)
|
||||
self.set_dtype_and_device(
|
||||
offload_dtype,
|
||||
offload_device,
|
||||
onload_dtype,
|
||||
onload_device,
|
||||
preparing_dtype,
|
||||
preparing_device,
|
||||
computation_dtype,
|
||||
computation_device,
|
||||
vram_limit,
|
||||
)
|
||||
self.weight = module.weight
|
||||
self.bias = module.bias
|
||||
self.state = 0
|
||||
self.name = name
|
||||
self.lora_A_weights = []
|
||||
self.lora_B_weights = []
|
||||
self.lora_merger = None
|
||||
self.enable_fp8 = computation_dtype in [torch.float8_e4m3fn, torch.float8_e4m3fnuz]
|
||||
self.computation_device_type = parse_device_type(self.computation_device)
|
||||
|
||||
if offload_dtype == "disk":
|
||||
self.disk_map = disk_map
|
||||
self.disk_offload = True
|
||||
else:
|
||||
self.disk_offload = False
|
||||
|
||||
def fp8_linear(
|
||||
self,
|
||||
input: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
bias: torch.Tensor = None,
|
||||
) -> torch.Tensor:
|
||||
device = input.device
|
||||
origin_dtype = input.dtype
|
||||
origin_shape = input.shape
|
||||
input = input.reshape(-1, origin_shape[-1])
|
||||
|
||||
x_max = torch.max(torch.abs(input), dim=-1, keepdim=True).values
|
||||
fp8_max = 448.0
|
||||
# For float8_e4m3fnuz, the maximum representable value is half of that of e4m3fn.
|
||||
# To avoid overflow and ensure numerical compatibility during FP8 computation,
|
||||
# we scale down the input by 2.0 in advance.
|
||||
# This scaling will be compensated later during the final result scaling.
|
||||
if self.computation_dtype == torch.float8_e4m3fnuz:
|
||||
fp8_max = fp8_max / 2.0
|
||||
scale_a = torch.clamp(x_max / fp8_max, min=1.0).float().to(device=device)
|
||||
scale_b = torch.ones((weight.shape[0], 1)).to(device=device)
|
||||
input = input / (scale_a + 1e-8)
|
||||
input = input.to(self.computation_dtype)
|
||||
weight = weight.to(self.computation_dtype)
|
||||
bias = bias.to(torch.bfloat16)
|
||||
|
||||
result = torch._scaled_mm(
|
||||
input,
|
||||
weight.T,
|
||||
scale_a=scale_a,
|
||||
scale_b=scale_b.T,
|
||||
bias=bias,
|
||||
out_dtype=origin_dtype,
|
||||
)
|
||||
new_shape = origin_shape[:-1] + result.shape[-1:]
|
||||
result = result.reshape(new_shape)
|
||||
return result
|
||||
|
||||
def load_from_disk(self, torch_dtype, device, assign=True):
|
||||
weight = self.disk_map[self.name + ".weight"].to(dtype=torch_dtype, device=device)
|
||||
bias = None if self.bias is None else self.disk_map[self.name + ".bias"].to(dtype=torch_dtype, device=device)
|
||||
if assign:
|
||||
state_dict = {"weight": weight}
|
||||
if bias is not None: state_dict["bias"] = bias
|
||||
self.load_state_dict(state_dict, assign=True)
|
||||
return weight, bias
|
||||
|
||||
def offload(self):
|
||||
# offload / onload / preparing -> offload
|
||||
if self.state != 0:
|
||||
if self.disk_offload:
|
||||
self.to("meta")
|
||||
else:
|
||||
self.to(dtype=self.offload_dtype, device=self.offload_device)
|
||||
self.state = 0
|
||||
|
||||
def onload(self):
|
||||
# offload / onload / preparing -> onload
|
||||
if self.state < 1:
|
||||
if self.disk_offload and self.onload_device != "disk" and self.offload_device == "disk":
|
||||
self.load_from_disk(self.onload_dtype, self.onload_device)
|
||||
elif self.onload_device != "disk":
|
||||
self.to(dtype=self.onload_dtype, device=self.onload_device)
|
||||
self.state = 1
|
||||
|
||||
def preparing(self):
|
||||
# onload / preparing -> preparing
|
||||
if self.state != 2:
|
||||
if self.disk_offload and self.preparing_device != "disk" and self.onload_device == "disk":
|
||||
self.load_from_disk(self.preparing_dtype, self.preparing_device)
|
||||
elif self.preparing_device != "disk":
|
||||
self.to(dtype=self.preparing_dtype, device=self.preparing_device)
|
||||
self.state = 2
|
||||
|
||||
def computation(self):
|
||||
# onload / preparing -> computation (temporary)
|
||||
if self.state == 2:
|
||||
torch_dtype, device = self.preparing_dtype, self.preparing_device
|
||||
else:
|
||||
torch_dtype, device = self.onload_dtype, self.onload_device
|
||||
if torch_dtype == self.computation_dtype and device == self.computation_device:
|
||||
weight, bias = self.weight, self.bias
|
||||
elif self.disk_offload and device == "disk":
|
||||
weight, bias = self.load_from_disk(self.computation_dtype, self.computation_device, assign=False)
|
||||
else:
|
||||
weight = self.cast_to(self.weight, self.computation_dtype, self.computation_device)
|
||||
bias = None if self.bias is None else self.cast_to(self.bias, self.computation_dtype, self.computation_device)
|
||||
return weight, bias
|
||||
|
||||
def linear_forward(self, x, weight, bias):
|
||||
if self.enable_fp8:
|
||||
out = self.fp8_linear(x, weight, bias)
|
||||
else:
|
||||
out = torch.nn.functional.linear(x, weight, bias)
|
||||
return out
|
||||
|
||||
def lora_forward(self, x, out):
|
||||
if self.lora_merger is None:
|
||||
for lora_A, lora_B in zip(self.lora_A_weights, self.lora_B_weights):
|
||||
out = out + x @ lora_A.T @ lora_B.T
|
||||
else:
|
||||
lora_output = []
|
||||
for lora_A, lora_B in zip(self.lora_A_weights, self.lora_B_weights):
|
||||
lora_output.append(x @ lora_A.T @ lora_B.T)
|
||||
lora_output = torch.stack(lora_output)
|
||||
out = self.lora_merger(out, lora_output)
|
||||
return out
|
||||
|
||||
def forward(self, x, *args, **kwargs):
|
||||
if self.state == 1 and (self.vram_limit is None or self.check_free_vram()):
|
||||
self.preparing()
|
||||
weight, bias = self.computation()
|
||||
out = self.linear_forward(x, weight, bias)
|
||||
if len(self.lora_A_weights) > 0:
|
||||
out = self.lora_forward(x, out)
|
||||
return out
|
||||
|
||||
|
||||
def enable_vram_management_recursively(model: torch.nn.Module, module_map: dict, vram_config: dict, vram_limit=None, name_prefix="", disk_map=None, **kwargs):
|
||||
if isinstance(model, AutoWrappedNonRecurseModule):
|
||||
model = model.module
|
||||
for name, module in model.named_children():
|
||||
layer_name = name if name_prefix == "" else name_prefix + "." + name
|
||||
for source_module, target_module in module_map.items():
|
||||
if isinstance(module, source_module):
|
||||
module_ = target_module(module, **vram_config, vram_limit=vram_limit, name=layer_name, disk_map=disk_map, **kwargs)
|
||||
if isinstance(module_, AutoWrappedNonRecurseModule):
|
||||
enable_vram_management_recursively(module_, module_map, vram_config, vram_limit=vram_limit, name_prefix=layer_name, disk_map=disk_map, **kwargs)
|
||||
setattr(model, name, module_)
|
||||
break
|
||||
else:
|
||||
enable_vram_management_recursively(module, module_map, vram_config, vram_limit=vram_limit, name_prefix=layer_name, disk_map=disk_map, **kwargs)
|
||||
|
||||
|
||||
def fill_vram_config(model, vram_config):
|
||||
vram_config_ = vram_config.copy()
|
||||
vram_config_["onload_dtype"] = vram_config["computation_dtype"]
|
||||
vram_config_["onload_device"] = vram_config["computation_device"]
|
||||
vram_config_["preparing_dtype"] = vram_config["computation_dtype"]
|
||||
vram_config_["preparing_device"] = vram_config["computation_device"]
|
||||
for k in vram_config:
|
||||
if vram_config[k] != vram_config_[k]:
|
||||
print(f"No fine-grained VRAM configuration is provided for {model.__class__.__name__}. [`onload`, `preparing`, `computation`] will be the same state. `vram_config` is set to {vram_config_}")
|
||||
break
|
||||
return vram_config_
|
||||
|
||||
|
||||
def enable_vram_management(model: torch.nn.Module, module_map: dict, vram_config: dict, vram_limit=None, disk_map=None, **kwargs):
|
||||
for source_module, target_module in module_map.items():
|
||||
# If no fine-grained VRAM configuration is provided, the entire model will be managed uniformly.
|
||||
if isinstance(model, source_module):
|
||||
vram_config = fill_vram_config(model, vram_config)
|
||||
model = target_module(model, **vram_config, vram_limit=vram_limit, disk_map=disk_map, **kwargs)
|
||||
break
|
||||
else:
|
||||
enable_vram_management_recursively(model, module_map, vram_config, vram_limit=vram_limit, disk_map=disk_map, **kwargs)
|
||||
# `vram_management_enabled` is a flag that allows the pipeline to determine whether VRAM management is enabled.
|
||||
model.vram_management_enabled = True
|
||||
return model
|
||||
1
diffsynth/data/__init__.py
Normal file
1
diffsynth/data/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
from .video import VideoData, save_video, save_frames
|
||||
35
diffsynth/data/simple_text_image.py
Normal file
35
diffsynth/data/simple_text_image.py
Normal file
@@ -0,0 +1,35 @@
|
||||
import torch, os
|
||||
from torchvision import transforms
|
||||
import pandas as pd
|
||||
from PIL import Image
|
||||
|
||||
|
||||
|
||||
class TextImageDataset(torch.utils.data.Dataset):
|
||||
def __init__(self, dataset_path, steps_per_epoch=10000, height=1024, width=1024, center_crop=True, random_flip=False):
|
||||
self.steps_per_epoch = steps_per_epoch
|
||||
metadata = pd.read_csv(os.path.join(dataset_path, "train/metadata.csv"))
|
||||
self.path = [os.path.join(dataset_path, "train", file_name) for file_name in metadata["file_name"]]
|
||||
self.text = metadata["text"].to_list()
|
||||
self.image_processor = transforms.Compose(
|
||||
[
|
||||
transforms.Resize(max(height, width), interpolation=transforms.InterpolationMode.BILINEAR),
|
||||
transforms.CenterCrop((height, width)) if center_crop else transforms.RandomCrop((height, width)),
|
||||
transforms.RandomHorizontalFlip() if random_flip else transforms.Lambda(lambda x: x),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize([0.5], [0.5]),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def __getitem__(self, index):
|
||||
data_id = torch.randint(0, len(self.path), (1,))[0]
|
||||
data_id = (data_id + index) % len(self.path) # For fixed seed.
|
||||
text = self.text[data_id]
|
||||
image = Image.open(self.path[data_id]).convert("RGB")
|
||||
image = self.image_processor(image)
|
||||
return {"text": text, "image": image}
|
||||
|
||||
|
||||
def __len__(self):
|
||||
return self.steps_per_epoch
|
||||
@@ -2,8 +2,6 @@ import imageio, os
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
from tqdm import tqdm
|
||||
import subprocess
|
||||
import shutil
|
||||
|
||||
|
||||
class LowMemoryVideo:
|
||||
@@ -116,7 +114,7 @@ class VideoData:
|
||||
if self.height is not None and self.width is not None:
|
||||
return self.height, self.width
|
||||
else:
|
||||
width, height = self.__getitem__(0).size
|
||||
height, width, _ = self.__getitem__(0).shape
|
||||
return height, width
|
||||
|
||||
def __getitem__(self, item):
|
||||
@@ -137,8 +135,8 @@ class VideoData:
|
||||
frame.save(os.path.join(folder, f"{i}.png"))
|
||||
|
||||
|
||||
def save_video(frames, save_path, fps, quality=9, ffmpeg_params=None):
|
||||
writer = imageio.get_writer(save_path, fps=fps, quality=quality, ffmpeg_params=ffmpeg_params)
|
||||
def save_video(frames, save_path, fps, quality=9):
|
||||
writer = imageio.get_writer(save_path, fps=fps, quality=quality)
|
||||
for frame in tqdm(frames, desc="Saving video"):
|
||||
frame = np.array(frame)
|
||||
writer.append_data(frame)
|
||||
@@ -148,70 +146,3 @@ def save_frames(frames, save_path):
|
||||
os.makedirs(save_path, exist_ok=True)
|
||||
for i, frame in enumerate(tqdm(frames, desc="Saving images")):
|
||||
frame.save(os.path.join(save_path, f"{i}.png"))
|
||||
|
||||
|
||||
def merge_video_audio(video_path: str, audio_path: str):
|
||||
# TODO: may need a in-python implementation to avoid subprocess dependency
|
||||
"""
|
||||
Merge the video and audio into a new video, with the duration set to the shorter of the two,
|
||||
and overwrite the original video file.
|
||||
|
||||
Parameters:
|
||||
video_path (str): Path to the original video file
|
||||
audio_path (str): Path to the audio file
|
||||
"""
|
||||
|
||||
# check
|
||||
if not os.path.exists(video_path):
|
||||
raise FileNotFoundError(f"video file {video_path} does not exist")
|
||||
if not os.path.exists(audio_path):
|
||||
raise FileNotFoundError(f"audio file {audio_path} does not exist")
|
||||
|
||||
base, ext = os.path.splitext(video_path)
|
||||
temp_output = f"{base}_temp{ext}"
|
||||
|
||||
try:
|
||||
# create ffmpeg command
|
||||
command = [
|
||||
'ffmpeg',
|
||||
'-y', # overwrite
|
||||
'-i',
|
||||
video_path,
|
||||
'-i',
|
||||
audio_path,
|
||||
'-c:v',
|
||||
'copy', # copy video stream
|
||||
'-c:a',
|
||||
'aac', # use AAC audio encoder
|
||||
'-b:a',
|
||||
'192k', # set audio bitrate (optional)
|
||||
'-map',
|
||||
'0:v:0', # select the first video stream
|
||||
'-map',
|
||||
'1:a:0', # select the first audio stream
|
||||
'-shortest', # choose the shortest duration
|
||||
temp_output
|
||||
]
|
||||
|
||||
# execute the command
|
||||
result = subprocess.run(
|
||||
command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
|
||||
|
||||
# check result
|
||||
if result.returncode != 0:
|
||||
error_msg = f"FFmpeg execute failed: {result.stderr}"
|
||||
print(error_msg)
|
||||
raise RuntimeError(error_msg)
|
||||
|
||||
shutil.move(temp_output, video_path)
|
||||
print(f"Merge completed, saved to {video_path}")
|
||||
|
||||
except Exception as e:
|
||||
if os.path.exists(temp_output):
|
||||
os.remove(temp_output)
|
||||
print(f"merge_video_audio failed with error: {e}")
|
||||
|
||||
|
||||
def save_video_with_audio(frames, save_path, audio_path, fps=16, quality=9, ffmpeg_params=None):
|
||||
save_video(frames, save_path, fps, quality, ffmpeg_params)
|
||||
merge_video_audio(save_path, audio_path)
|
||||
@@ -1,6 +0,0 @@
|
||||
from .flow_match import FlowMatchScheduler
|
||||
from .training_module import DiffusionTrainingModule
|
||||
from .logger import ModelLogger
|
||||
from .runner import launch_training_task, launch_data_process_task
|
||||
from .parsers import *
|
||||
from .loss import *
|
||||
@@ -1,462 +0,0 @@
|
||||
from PIL import Image
|
||||
import torch
|
||||
import numpy as np
|
||||
from einops import repeat, reduce
|
||||
from typing import Union
|
||||
from ..core import AutoTorchModule, AutoWrappedLinear, load_state_dict, ModelConfig, parse_device_type
|
||||
from ..core.device.npu_compatible_device import get_device_type
|
||||
from ..utils.lora import GeneralLoRALoader
|
||||
from ..models.model_loader import ModelPool
|
||||
from ..utils.controlnet import ControlNetInput
|
||||
from ..core.device import get_device_name, IS_NPU_AVAILABLE
|
||||
|
||||
|
||||
class PipelineUnit:
|
||||
def __init__(
|
||||
self,
|
||||
seperate_cfg: bool = False,
|
||||
take_over: bool = False,
|
||||
input_params: tuple[str] = None,
|
||||
output_params: tuple[str] = None,
|
||||
input_params_posi: dict[str, str] = None,
|
||||
input_params_nega: dict[str, str] = None,
|
||||
onload_model_names: tuple[str] = None
|
||||
):
|
||||
self.seperate_cfg = seperate_cfg
|
||||
self.take_over = take_over
|
||||
self.input_params = input_params
|
||||
self.output_params = output_params
|
||||
self.input_params_posi = input_params_posi
|
||||
self.input_params_nega = input_params_nega
|
||||
self.onload_model_names = onload_model_names
|
||||
|
||||
def fetch_input_params(self):
|
||||
params = []
|
||||
if self.input_params is not None:
|
||||
for param in self.input_params:
|
||||
params.append(param)
|
||||
if self.input_params_posi is not None:
|
||||
for _, param in self.input_params_posi.items():
|
||||
params.append(param)
|
||||
if self.input_params_nega is not None:
|
||||
for _, param in self.input_params_nega.items():
|
||||
params.append(param)
|
||||
params = sorted(list(set(params)))
|
||||
return params
|
||||
|
||||
def fetch_output_params(self):
|
||||
params = []
|
||||
if self.output_params is not None:
|
||||
for param in self.output_params:
|
||||
params.append(param)
|
||||
return params
|
||||
|
||||
def process(self, pipe, **kwargs) -> dict:
|
||||
return {}
|
||||
|
||||
def post_process(self, pipe, **kwargs) -> dict:
|
||||
return {}
|
||||
|
||||
|
||||
class BasePipeline(torch.nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
device=get_device_type(), torch_dtype=torch.float16,
|
||||
height_division_factor=64, width_division_factor=64,
|
||||
time_division_factor=None, time_division_remainder=None,
|
||||
):
|
||||
super().__init__()
|
||||
# The device and torch_dtype is used for the storage of intermediate variables, not models.
|
||||
self.device = device
|
||||
self.torch_dtype = torch_dtype
|
||||
self.device_type = parse_device_type(device)
|
||||
# The following parameters are used for shape check.
|
||||
self.height_division_factor = height_division_factor
|
||||
self.width_division_factor = width_division_factor
|
||||
self.time_division_factor = time_division_factor
|
||||
self.time_division_remainder = time_division_remainder
|
||||
# VRAM management
|
||||
self.vram_management_enabled = False
|
||||
# Pipeline Unit Runner
|
||||
self.unit_runner = PipelineUnitRunner()
|
||||
# LoRA Loader
|
||||
self.lora_loader = GeneralLoRALoader
|
||||
|
||||
|
||||
def to(self, *args, **kwargs):
|
||||
device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs)
|
||||
if device is not None:
|
||||
self.device = device
|
||||
if dtype is not None:
|
||||
self.torch_dtype = dtype
|
||||
super().to(*args, **kwargs)
|
||||
return self
|
||||
|
||||
|
||||
def check_resize_height_width(self, height, width, num_frames=None, verbose=1):
|
||||
# Shape check
|
||||
if height % self.height_division_factor != 0:
|
||||
height = (height + self.height_division_factor - 1) // self.height_division_factor * self.height_division_factor
|
||||
if verbose > 0:
|
||||
print(f"height % {self.height_division_factor} != 0. We round it up to {height}.")
|
||||
if width % self.width_division_factor != 0:
|
||||
width = (width + self.width_division_factor - 1) // self.width_division_factor * self.width_division_factor
|
||||
if verbose > 0:
|
||||
print(f"width % {self.width_division_factor} != 0. We round it up to {width}.")
|
||||
if num_frames is None:
|
||||
return height, width
|
||||
else:
|
||||
if num_frames % self.time_division_factor != self.time_division_remainder:
|
||||
num_frames = (num_frames + self.time_division_factor - 1) // self.time_division_factor * self.time_division_factor + self.time_division_remainder
|
||||
if verbose > 0:
|
||||
print(f"num_frames % {self.time_division_factor} != {self.time_division_remainder}. We round it up to {num_frames}.")
|
||||
return height, width, num_frames
|
||||
|
||||
|
||||
def preprocess_image(self, image, torch_dtype=None, device=None, pattern="B C H W", min_value=-1, max_value=1):
|
||||
# Transform a PIL.Image to torch.Tensor
|
||||
image = torch.Tensor(np.array(image, dtype=np.float32))
|
||||
image = image.to(dtype=torch_dtype or self.torch_dtype, device=device or self.device)
|
||||
image = image * ((max_value - min_value) / 255) + min_value
|
||||
image = repeat(image, f"H W C -> {pattern}", **({"B": 1} if "B" in pattern else {}))
|
||||
return image
|
||||
|
||||
|
||||
def preprocess_video(self, video, torch_dtype=None, device=None, pattern="B C T H W", min_value=-1, max_value=1):
|
||||
# Transform a list of PIL.Image to torch.Tensor
|
||||
video = [self.preprocess_image(image, torch_dtype=torch_dtype, device=device, min_value=min_value, max_value=max_value) for image in video]
|
||||
video = torch.stack(video, dim=pattern.index("T") // 2)
|
||||
return video
|
||||
|
||||
|
||||
def vae_output_to_image(self, vae_output, pattern="B C H W", min_value=-1, max_value=1):
|
||||
# Transform a torch.Tensor to PIL.Image
|
||||
if pattern != "H W C":
|
||||
vae_output = reduce(vae_output, f"{pattern} -> H W C", reduction="mean")
|
||||
image = ((vae_output - min_value) * (255 / (max_value - min_value))).clip(0, 255)
|
||||
image = image.to(device="cpu", dtype=torch.uint8)
|
||||
image = Image.fromarray(image.numpy())
|
||||
return image
|
||||
|
||||
|
||||
def vae_output_to_video(self, vae_output, pattern="B C T H W", min_value=-1, max_value=1):
|
||||
# Transform a torch.Tensor to list of PIL.Image
|
||||
if pattern != "T H W C":
|
||||
vae_output = reduce(vae_output, f"{pattern} -> T H W C", reduction="mean")
|
||||
video = [self.vae_output_to_image(image, pattern="H W C", min_value=min_value, max_value=max_value) for image in vae_output]
|
||||
return video
|
||||
|
||||
|
||||
def load_models_to_device(self, model_names):
|
||||
if self.vram_management_enabled:
|
||||
# offload models
|
||||
for name, model in self.named_children():
|
||||
if name not in model_names:
|
||||
if hasattr(model, "vram_management_enabled") and model.vram_management_enabled:
|
||||
if hasattr(model, "offload"):
|
||||
model.offload()
|
||||
else:
|
||||
for module in model.modules():
|
||||
if hasattr(module, "offload"):
|
||||
module.offload()
|
||||
getattr(torch, self.device_type).empty_cache()
|
||||
# onload models
|
||||
for name, model in self.named_children():
|
||||
if name in model_names:
|
||||
if hasattr(model, "vram_management_enabled") and model.vram_management_enabled:
|
||||
if hasattr(model, "onload"):
|
||||
model.onload()
|
||||
else:
|
||||
for module in model.modules():
|
||||
if hasattr(module, "onload"):
|
||||
module.onload()
|
||||
|
||||
|
||||
def generate_noise(self, shape, seed=None, rand_device="cpu", rand_torch_dtype=torch.float32, device=None, torch_dtype=None):
|
||||
# Initialize Gaussian noise
|
||||
generator = None if seed is None else torch.Generator(rand_device).manual_seed(seed)
|
||||
noise = torch.randn(shape, generator=generator, device=rand_device, dtype=rand_torch_dtype)
|
||||
noise = noise.to(dtype=torch_dtype or self.torch_dtype, device=device or self.device)
|
||||
return noise
|
||||
|
||||
|
||||
def get_vram(self):
|
||||
device = self.device if not IS_NPU_AVAILABLE else get_device_name()
|
||||
return getattr(torch, self.device_type).mem_get_info(device)[1] / (1024 ** 3)
|
||||
|
||||
def get_module(self, model, name):
|
||||
if "." in name:
|
||||
name, suffix = name[:name.index(".")], name[name.index(".") + 1:]
|
||||
if name.isdigit():
|
||||
return self.get_module(model[int(name)], suffix)
|
||||
else:
|
||||
return self.get_module(getattr(model, name), suffix)
|
||||
else:
|
||||
return getattr(model, name)
|
||||
|
||||
def freeze_except(self, model_names):
|
||||
self.eval()
|
||||
self.requires_grad_(False)
|
||||
for name in model_names:
|
||||
module = self.get_module(self, name)
|
||||
if module is None:
|
||||
print(f"No {name} models in the pipeline. We cannot enable training on the model. If this occurs during the data processing stage, it is normal.")
|
||||
continue
|
||||
module.train()
|
||||
module.requires_grad_(True)
|
||||
|
||||
|
||||
def blend_with_mask(self, base, addition, mask):
|
||||
return base * (1 - mask) + addition * mask
|
||||
|
||||
|
||||
def step(self, scheduler, latents, progress_id, noise_pred, input_latents=None, inpaint_mask=None, **kwargs):
|
||||
timestep = scheduler.timesteps[progress_id]
|
||||
if inpaint_mask is not None:
|
||||
noise_pred_expected = scheduler.return_to_timestep(scheduler.timesteps[progress_id], latents, input_latents)
|
||||
noise_pred = self.blend_with_mask(noise_pred_expected, noise_pred, inpaint_mask)
|
||||
latents_next = scheduler.step(noise_pred, timestep, latents)
|
||||
return latents_next
|
||||
|
||||
|
||||
def split_pipeline_units(self, model_names: list[str]):
|
||||
return PipelineUnitGraph().split_pipeline_units(self.units, model_names)
|
||||
|
||||
|
||||
def flush_vram_management_device(self, device):
|
||||
for module in self.modules():
|
||||
if isinstance(module, AutoTorchModule):
|
||||
module.offload_device = device
|
||||
module.onload_device = device
|
||||
module.preparing_device = device
|
||||
module.computation_device = device
|
||||
|
||||
|
||||
def load_lora(
|
||||
self,
|
||||
module: torch.nn.Module,
|
||||
lora_config: Union[ModelConfig, str] = None,
|
||||
alpha=1,
|
||||
hotload=None,
|
||||
state_dict=None,
|
||||
verbose=1,
|
||||
):
|
||||
if state_dict is None:
|
||||
if isinstance(lora_config, str):
|
||||
lora = load_state_dict(lora_config, torch_dtype=self.torch_dtype, device=self.device)
|
||||
else:
|
||||
lora_config.download_if_necessary()
|
||||
lora = load_state_dict(lora_config.path, torch_dtype=self.torch_dtype, device=self.device)
|
||||
else:
|
||||
lora = state_dict
|
||||
lora_loader = self.lora_loader(torch_dtype=self.torch_dtype, device=self.device)
|
||||
lora = lora_loader.convert_state_dict(lora)
|
||||
if hotload is None:
|
||||
hotload = hasattr(module, "vram_management_enabled") and getattr(module, "vram_management_enabled")
|
||||
if hotload:
|
||||
if not (hasattr(module, "vram_management_enabled") and getattr(module, "vram_management_enabled")):
|
||||
raise ValueError("VRAM Management is not enabled. LoRA hotloading is not supported.")
|
||||
updated_num = 0
|
||||
for _, module in module.named_modules():
|
||||
if isinstance(module, AutoWrappedLinear):
|
||||
name = module.name
|
||||
lora_a_name = f'{name}.lora_A.weight'
|
||||
lora_b_name = f'{name}.lora_B.weight'
|
||||
if lora_a_name in lora and lora_b_name in lora:
|
||||
updated_num += 1
|
||||
module.lora_A_weights.append(lora[lora_a_name] * alpha)
|
||||
module.lora_B_weights.append(lora[lora_b_name])
|
||||
if verbose >= 1:
|
||||
print(f"{updated_num} tensors are patched by LoRA. You can use `pipe.clear_lora()` to clear all LoRA layers.")
|
||||
else:
|
||||
lora_loader.fuse_lora_to_base_model(module, lora, alpha=alpha)
|
||||
|
||||
|
||||
def clear_lora(self, verbose=1):
|
||||
cleared_num = 0
|
||||
for name, module in self.named_modules():
|
||||
if isinstance(module, AutoWrappedLinear):
|
||||
if hasattr(module, "lora_A_weights"):
|
||||
if len(module.lora_A_weights) > 0:
|
||||
cleared_num += 1
|
||||
module.lora_A_weights.clear()
|
||||
if hasattr(module, "lora_B_weights"):
|
||||
module.lora_B_weights.clear()
|
||||
if verbose >= 1:
|
||||
print(f"{cleared_num} LoRA layers are cleared.")
|
||||
|
||||
|
||||
def download_and_load_models(self, model_configs: list[ModelConfig] = [], vram_limit: float = None):
|
||||
model_pool = ModelPool()
|
||||
for model_config in model_configs:
|
||||
model_config.download_if_necessary()
|
||||
vram_config = model_config.vram_config()
|
||||
vram_config["computation_dtype"] = vram_config["computation_dtype"] or self.torch_dtype
|
||||
vram_config["computation_device"] = vram_config["computation_device"] or self.device
|
||||
model_pool.auto_load_model(
|
||||
model_config.path,
|
||||
vram_config=vram_config,
|
||||
vram_limit=vram_limit,
|
||||
clear_parameters=model_config.clear_parameters,
|
||||
state_dict=model_config.state_dict,
|
||||
)
|
||||
return model_pool
|
||||
|
||||
|
||||
def check_vram_management_state(self):
|
||||
vram_management_enabled = False
|
||||
for module in self.children():
|
||||
if hasattr(module, "vram_management_enabled") and getattr(module, "vram_management_enabled"):
|
||||
vram_management_enabled = True
|
||||
return vram_management_enabled
|
||||
|
||||
|
||||
def cfg_guided_model_fn(self, model_fn, cfg_scale, inputs_shared, inputs_posi, inputs_nega, **inputs_others):
|
||||
if inputs_shared.get("positive_only_lora", None) is not None:
|
||||
self.clear_lora(verbose=0)
|
||||
self.load_lora(self.dit, state_dict=inputs_shared["positive_only_lora"], verbose=0)
|
||||
noise_pred_posi = model_fn(**inputs_posi, **inputs_shared, **inputs_others)
|
||||
if cfg_scale != 1.0:
|
||||
if inputs_shared.get("positive_only_lora", None) is not None:
|
||||
self.clear_lora(verbose=0)
|
||||
noise_pred_nega = model_fn(**inputs_nega, **inputs_shared, **inputs_others)
|
||||
if isinstance(noise_pred_posi, tuple):
|
||||
# Separately handling different output types of latents, eg. video and audio latents.
|
||||
noise_pred = tuple(
|
||||
n_nega + cfg_scale * (n_posi - n_nega)
|
||||
for n_posi, n_nega in zip(noise_pred_posi, noise_pred_nega)
|
||||
)
|
||||
else:
|
||||
noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
|
||||
else:
|
||||
noise_pred = noise_pred_posi
|
||||
return noise_pred
|
||||
|
||||
|
||||
class PipelineUnitGraph:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def build_edges(self, units: list[PipelineUnit]):
|
||||
# Establish dependencies between units
|
||||
# to search for subsequent related computation units.
|
||||
last_compute_unit_id = {}
|
||||
edges = []
|
||||
for unit_id, unit in enumerate(units):
|
||||
for input_param in unit.fetch_input_params():
|
||||
if input_param in last_compute_unit_id:
|
||||
edges.append((last_compute_unit_id[input_param], unit_id))
|
||||
for output_param in unit.fetch_output_params():
|
||||
last_compute_unit_id[output_param] = unit_id
|
||||
return edges
|
||||
|
||||
def build_chains(self, units: list[PipelineUnit]):
|
||||
# Establish updating chains for each variable
|
||||
# to track their computation process.
|
||||
params = sum([unit.fetch_input_params() + unit.fetch_output_params() for unit in units], [])
|
||||
params = sorted(list(set(params)))
|
||||
chains = {param: [] for param in params}
|
||||
for unit_id, unit in enumerate(units):
|
||||
for param in unit.fetch_output_params():
|
||||
chains[param].append(unit_id)
|
||||
return chains
|
||||
|
||||
def search_direct_unit_ids(self, units: list[PipelineUnit], model_names: list[str]):
|
||||
# Search for units that directly participate in the model's computation.
|
||||
related_unit_ids = []
|
||||
for unit_id, unit in enumerate(units):
|
||||
for model_name in model_names:
|
||||
if unit.onload_model_names is not None and model_name in unit.onload_model_names:
|
||||
related_unit_ids.append(unit_id)
|
||||
break
|
||||
return related_unit_ids
|
||||
|
||||
def search_related_unit_ids(self, edges, start_unit_ids, direction="target"):
|
||||
# Search for subsequent related computation units.
|
||||
related_unit_ids = [unit_id for unit_id in start_unit_ids]
|
||||
while True:
|
||||
neighbors = []
|
||||
for source, target in edges:
|
||||
if direction == "target" and source in related_unit_ids and target not in related_unit_ids:
|
||||
neighbors.append(target)
|
||||
elif direction == "source" and source not in related_unit_ids and target in related_unit_ids:
|
||||
neighbors.append(source)
|
||||
neighbors = sorted(list(set(neighbors)))
|
||||
if len(neighbors) == 0:
|
||||
break
|
||||
else:
|
||||
related_unit_ids.extend(neighbors)
|
||||
related_unit_ids = sorted(list(set(related_unit_ids)))
|
||||
return related_unit_ids
|
||||
|
||||
def search_updating_unit_ids(self, units: list[PipelineUnit], chains, related_unit_ids):
|
||||
# If the input parameters of this subgraph are updated outside the subgraph,
|
||||
# search for the units where these updates occur.
|
||||
first_compute_unit_id = {}
|
||||
for unit_id in related_unit_ids:
|
||||
for param in units[unit_id].fetch_input_params():
|
||||
if param not in first_compute_unit_id:
|
||||
first_compute_unit_id[param] = unit_id
|
||||
updating_unit_ids = []
|
||||
for param in first_compute_unit_id:
|
||||
unit_id = first_compute_unit_id[param]
|
||||
chain = chains[param]
|
||||
if unit_id in chain and chain.index(unit_id) != len(chain) - 1:
|
||||
for unit_id_ in chain[chain.index(unit_id) + 1:]:
|
||||
if unit_id_ not in related_unit_ids:
|
||||
updating_unit_ids.append(unit_id_)
|
||||
related_unit_ids.extend(updating_unit_ids)
|
||||
related_unit_ids = sorted(list(set(related_unit_ids)))
|
||||
return related_unit_ids
|
||||
|
||||
def split_pipeline_units(self, units: list[PipelineUnit], model_names: list[str]):
|
||||
# Split the computation graph,
|
||||
# separating all model-related computations.
|
||||
related_unit_ids = self.search_direct_unit_ids(units, model_names)
|
||||
edges = self.build_edges(units)
|
||||
chains = self.build_chains(units)
|
||||
while True:
|
||||
num_related_unit_ids = len(related_unit_ids)
|
||||
related_unit_ids = self.search_related_unit_ids(edges, related_unit_ids, "target")
|
||||
related_unit_ids = self.search_updating_unit_ids(units, chains, related_unit_ids)
|
||||
if len(related_unit_ids) == num_related_unit_ids:
|
||||
break
|
||||
else:
|
||||
num_related_unit_ids = len(related_unit_ids)
|
||||
related_units = [units[i] for i in related_unit_ids]
|
||||
unrelated_units = [units[i] for i in range(len(units)) if i not in related_unit_ids]
|
||||
return related_units, unrelated_units
|
||||
|
||||
|
||||
class PipelineUnitRunner:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def __call__(self, unit: PipelineUnit, pipe: BasePipeline, inputs_shared: dict, inputs_posi: dict, inputs_nega: dict) -> tuple[dict, dict]:
|
||||
if unit.take_over:
|
||||
# Let the pipeline unit take over this function.
|
||||
inputs_shared, inputs_posi, inputs_nega = unit.process(pipe, inputs_shared=inputs_shared, inputs_posi=inputs_posi, inputs_nega=inputs_nega)
|
||||
elif unit.seperate_cfg:
|
||||
# Positive side
|
||||
processor_inputs = {name: inputs_posi.get(name_) for name, name_ in unit.input_params_posi.items()}
|
||||
if unit.input_params is not None:
|
||||
for name in unit.input_params:
|
||||
processor_inputs[name] = inputs_shared.get(name)
|
||||
processor_outputs = unit.process(pipe, **processor_inputs)
|
||||
inputs_posi.update(processor_outputs)
|
||||
# Negative side
|
||||
if inputs_shared["cfg_scale"] != 1:
|
||||
processor_inputs = {name: inputs_nega.get(name_) for name, name_ in unit.input_params_nega.items()}
|
||||
if unit.input_params is not None:
|
||||
for name in unit.input_params:
|
||||
processor_inputs[name] = inputs_shared.get(name)
|
||||
processor_outputs = unit.process(pipe, **processor_inputs)
|
||||
inputs_nega.update(processor_outputs)
|
||||
else:
|
||||
inputs_nega.update(processor_outputs)
|
||||
else:
|
||||
processor_inputs = {name: inputs_shared.get(name) for name in unit.input_params}
|
||||
processor_outputs = unit.process(pipe, **processor_inputs)
|
||||
inputs_shared.update(processor_outputs)
|
||||
return inputs_shared, inputs_posi, inputs_nega
|
||||
@@ -1,236 +0,0 @@
|
||||
import torch, math
|
||||
from typing_extensions import Literal
|
||||
|
||||
|
||||
class FlowMatchScheduler():
|
||||
|
||||
def __init__(self, template: Literal["FLUX.1", "Wan", "Qwen-Image", "FLUX.2", "Z-Image", "LTX-2", "Qwen-Image-Lightning"] = "FLUX.1"):
|
||||
self.set_timesteps_fn = {
|
||||
"FLUX.1": FlowMatchScheduler.set_timesteps_flux,
|
||||
"Wan": FlowMatchScheduler.set_timesteps_wan,
|
||||
"Qwen-Image": FlowMatchScheduler.set_timesteps_qwen_image,
|
||||
"FLUX.2": FlowMatchScheduler.set_timesteps_flux2,
|
||||
"Z-Image": FlowMatchScheduler.set_timesteps_z_image,
|
||||
"LTX-2": FlowMatchScheduler.set_timesteps_ltx2,
|
||||
"Qwen-Image-Lightning": FlowMatchScheduler.set_timesteps_qwen_image_lightning,
|
||||
}.get(template, FlowMatchScheduler.set_timesteps_flux)
|
||||
self.num_train_timesteps = 1000
|
||||
|
||||
@staticmethod
|
||||
def set_timesteps_flux(num_inference_steps=100, denoising_strength=1.0, shift=None):
|
||||
sigma_min = 0.003/1.002
|
||||
sigma_max = 1.0
|
||||
shift = 3 if shift is None else shift
|
||||
num_train_timesteps = 1000
|
||||
sigma_start = sigma_min + (sigma_max - sigma_min) * denoising_strength
|
||||
sigmas = torch.linspace(sigma_start, sigma_min, num_inference_steps)
|
||||
sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)
|
||||
timesteps = sigmas * num_train_timesteps
|
||||
return sigmas, timesteps
|
||||
|
||||
@staticmethod
|
||||
def set_timesteps_wan(num_inference_steps=100, denoising_strength=1.0, shift=None):
|
||||
sigma_min = 0.0
|
||||
sigma_max = 1.0
|
||||
shift = 5 if shift is None else shift
|
||||
num_train_timesteps = 1000
|
||||
sigma_start = sigma_min + (sigma_max - sigma_min) * denoising_strength
|
||||
sigmas = torch.linspace(sigma_start, sigma_min, num_inference_steps + 1)[:-1]
|
||||
sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)
|
||||
timesteps = sigmas * num_train_timesteps
|
||||
return sigmas, timesteps
|
||||
|
||||
@staticmethod
|
||||
def _calculate_shift_qwen_image(image_seq_len, base_seq_len=256, max_seq_len=8192, base_shift=0.5, max_shift=0.9):
|
||||
m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
|
||||
b = base_shift - m * base_seq_len
|
||||
mu = image_seq_len * m + b
|
||||
return mu
|
||||
|
||||
@staticmethod
|
||||
def set_timesteps_qwen_image(num_inference_steps=100, denoising_strength=1.0, exponential_shift_mu=None, dynamic_shift_len=None):
|
||||
sigma_min = 0.0
|
||||
sigma_max = 1.0
|
||||
num_train_timesteps = 1000
|
||||
shift_terminal = 0.02
|
||||
# Sigmas
|
||||
sigma_start = sigma_min + (sigma_max - sigma_min) * denoising_strength
|
||||
sigmas = torch.linspace(sigma_start, sigma_min, num_inference_steps + 1)[:-1]
|
||||
# Mu
|
||||
if exponential_shift_mu is not None:
|
||||
mu = exponential_shift_mu
|
||||
elif dynamic_shift_len is not None:
|
||||
mu = FlowMatchScheduler._calculate_shift_qwen_image(dynamic_shift_len)
|
||||
else:
|
||||
mu = 0.8
|
||||
sigmas = math.exp(mu) / (math.exp(mu) + (1 / sigmas - 1))
|
||||
# Shift terminal
|
||||
one_minus_z = 1 - sigmas
|
||||
scale_factor = one_minus_z[-1] / (1 - shift_terminal)
|
||||
sigmas = 1 - (one_minus_z / scale_factor)
|
||||
# Timesteps
|
||||
timesteps = sigmas * num_train_timesteps
|
||||
return sigmas, timesteps
|
||||
|
||||
@staticmethod
|
||||
def set_timesteps_qwen_image_lightning(num_inference_steps=100, denoising_strength=1.0, exponential_shift_mu=None, dynamic_shift_len=None):
|
||||
sigma_min = 0.0
|
||||
sigma_max = 1.0
|
||||
num_train_timesteps = 1000
|
||||
base_shift = math.log(3)
|
||||
max_shift = math.log(3)
|
||||
# Sigmas
|
||||
sigma_start = sigma_min + (sigma_max - sigma_min) * denoising_strength
|
||||
sigmas = torch.linspace(sigma_start, sigma_min, num_inference_steps + 1)[:-1]
|
||||
# Mu
|
||||
if exponential_shift_mu is not None:
|
||||
mu = exponential_shift_mu
|
||||
elif dynamic_shift_len is not None:
|
||||
mu = FlowMatchScheduler._calculate_shift_qwen_image(dynamic_shift_len, base_shift=base_shift, max_shift=max_shift)
|
||||
else:
|
||||
mu = 0.8
|
||||
sigmas = math.exp(mu) / (math.exp(mu) + (1 / sigmas - 1))
|
||||
# Timesteps
|
||||
timesteps = sigmas * num_train_timesteps
|
||||
return sigmas, timesteps
|
||||
|
||||
@staticmethod
|
||||
def compute_empirical_mu(image_seq_len, num_steps):
|
||||
a1, b1 = 8.73809524e-05, 1.89833333
|
||||
a2, b2 = 0.00016927, 0.45666666
|
||||
|
||||
if image_seq_len > 4300:
|
||||
mu = a2 * image_seq_len + b2
|
||||
return float(mu)
|
||||
|
||||
m_200 = a2 * image_seq_len + b2
|
||||
m_10 = a1 * image_seq_len + b1
|
||||
|
||||
a = (m_200 - m_10) / 190.0
|
||||
b = m_200 - 200.0 * a
|
||||
mu = a * num_steps + b
|
||||
|
||||
return float(mu)
|
||||
|
||||
@staticmethod
|
||||
def set_timesteps_flux2(num_inference_steps=100, denoising_strength=1.0, dynamic_shift_len=None):
|
||||
sigma_min = 1 / num_inference_steps
|
||||
sigma_max = 1.0
|
||||
num_train_timesteps = 1000
|
||||
sigma_start = sigma_min + (sigma_max - sigma_min) * denoising_strength
|
||||
sigmas = torch.linspace(sigma_start, sigma_min, num_inference_steps)
|
||||
if dynamic_shift_len is None:
|
||||
# If you ask me why I set mu=0.8,
|
||||
# I can only say that it yields better training results.
|
||||
mu = 0.8
|
||||
else:
|
||||
mu = FlowMatchScheduler.compute_empirical_mu(dynamic_shift_len, num_inference_steps)
|
||||
sigmas = math.exp(mu) / (math.exp(mu) + (1 / sigmas - 1))
|
||||
timesteps = sigmas * num_train_timesteps
|
||||
return sigmas, timesteps
|
||||
|
||||
@staticmethod
|
||||
def set_timesteps_z_image(num_inference_steps=100, denoising_strength=1.0, shift=None, target_timesteps=None):
|
||||
sigma_min = 0.0
|
||||
sigma_max = 1.0
|
||||
shift = 3 if shift is None else shift
|
||||
num_train_timesteps = 1000
|
||||
sigma_start = sigma_min + (sigma_max - sigma_min) * denoising_strength
|
||||
sigmas = torch.linspace(sigma_start, sigma_min, num_inference_steps + 1)[:-1]
|
||||
sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)
|
||||
timesteps = sigmas * num_train_timesteps
|
||||
if target_timesteps is not None:
|
||||
target_timesteps = target_timesteps.to(dtype=timesteps.dtype, device=timesteps.device)
|
||||
for timestep in target_timesteps:
|
||||
timestep_id = torch.argmin((timesteps - timestep).abs())
|
||||
timesteps[timestep_id] = timestep
|
||||
return sigmas, timesteps
|
||||
|
||||
@staticmethod
|
||||
def set_timesteps_ltx2(num_inference_steps=100, denoising_strength=1.0, dynamic_shift_len=None, terminal=0.1, special_case=None):
|
||||
num_train_timesteps = 1000
|
||||
if special_case == "stage2":
|
||||
sigmas = torch.Tensor([0.909375, 0.725, 0.421875])
|
||||
elif special_case == "ditilled_stage1":
|
||||
sigmas = torch.Tensor([1.0, 0.99375, 0.9875, 0.98125, 0.975, 0.909375, 0.725, 0.421875])
|
||||
else:
|
||||
dynamic_shift_len = dynamic_shift_len or 4096
|
||||
sigma_shift = FlowMatchScheduler._calculate_shift_qwen_image(
|
||||
image_seq_len=dynamic_shift_len,
|
||||
base_seq_len=1024,
|
||||
max_seq_len=4096,
|
||||
base_shift=0.95,
|
||||
max_shift=2.05,
|
||||
)
|
||||
sigma_min = 0.0
|
||||
sigma_max = 1.0
|
||||
sigma_start = sigma_min + (sigma_max - sigma_min) * denoising_strength
|
||||
sigmas = torch.linspace(sigma_start, sigma_min, num_inference_steps + 1)[:-1]
|
||||
sigmas = math.exp(sigma_shift) / (math.exp(sigma_shift) + (1 / sigmas - 1))
|
||||
# Shift terminal
|
||||
one_minus_z = 1.0 - sigmas
|
||||
scale_factor = one_minus_z[-1] / (1 - terminal)
|
||||
sigmas = 1.0 - (one_minus_z / scale_factor)
|
||||
timesteps = sigmas * num_train_timesteps
|
||||
return sigmas, timesteps
|
||||
|
||||
def set_training_weight(self):
|
||||
steps = 1000
|
||||
x = self.timesteps
|
||||
y = torch.exp(-2 * ((x - steps / 2) / steps) ** 2)
|
||||
y_shifted = y - y.min()
|
||||
bsmntw_weighing = y_shifted * (steps / y_shifted.sum())
|
||||
if len(self.timesteps) != 1000:
|
||||
# This is an empirical formula.
|
||||
bsmntw_weighing = bsmntw_weighing * (len(self.timesteps) / steps)
|
||||
bsmntw_weighing = bsmntw_weighing + bsmntw_weighing[1]
|
||||
self.linear_timesteps_weights = bsmntw_weighing
|
||||
|
||||
def set_timesteps(self, num_inference_steps=100, denoising_strength=1.0, training=False, **kwargs):
|
||||
self.sigmas, self.timesteps = self.set_timesteps_fn(
|
||||
num_inference_steps=num_inference_steps,
|
||||
denoising_strength=denoising_strength,
|
||||
**kwargs,
|
||||
)
|
||||
if training:
|
||||
self.set_training_weight()
|
||||
self.training = True
|
||||
else:
|
||||
self.training = False
|
||||
|
||||
def step(self, model_output, timestep, sample, to_final=False, **kwargs):
|
||||
if isinstance(timestep, torch.Tensor):
|
||||
timestep = timestep.cpu()
|
||||
timestep_id = torch.argmin((self.timesteps - timestep).abs())
|
||||
sigma = self.sigmas[timestep_id]
|
||||
if to_final or timestep_id + 1 >= len(self.timesteps):
|
||||
sigma_ = 0
|
||||
else:
|
||||
sigma_ = self.sigmas[timestep_id + 1]
|
||||
prev_sample = sample + model_output * (sigma_ - sigma)
|
||||
return prev_sample
|
||||
|
||||
def return_to_timestep(self, timestep, sample, sample_stablized):
|
||||
if isinstance(timestep, torch.Tensor):
|
||||
timestep = timestep.cpu()
|
||||
timestep_id = torch.argmin((self.timesteps - timestep).abs())
|
||||
sigma = self.sigmas[timestep_id]
|
||||
model_output = (sample - sample_stablized) / sigma
|
||||
return model_output
|
||||
|
||||
def add_noise(self, original_samples, noise, timestep):
|
||||
if isinstance(timestep, torch.Tensor):
|
||||
timestep = timestep.cpu()
|
||||
timestep_id = torch.argmin((self.timesteps - timestep).abs())
|
||||
sigma = self.sigmas[timestep_id]
|
||||
sample = (1 - sigma) * original_samples + sigma * noise
|
||||
return sample
|
||||
|
||||
def training_target(self, sample, noise, timestep):
|
||||
target = noise - sample
|
||||
return target
|
||||
|
||||
def training_weight(self, timestep):
|
||||
timestep_id = torch.argmin((self.timesteps - timestep.to(self.timesteps.device)).abs())
|
||||
weights = self.linear_timesteps_weights[timestep_id]
|
||||
return weights
|
||||
@@ -1,43 +0,0 @@
|
||||
import os, torch
|
||||
from accelerate import Accelerator
|
||||
|
||||
|
||||
class ModelLogger:
|
||||
def __init__(self, output_path, remove_prefix_in_ckpt=None, state_dict_converter=lambda x:x):
|
||||
self.output_path = output_path
|
||||
self.remove_prefix_in_ckpt = remove_prefix_in_ckpt
|
||||
self.state_dict_converter = state_dict_converter
|
||||
self.num_steps = 0
|
||||
|
||||
|
||||
def on_step_end(self, accelerator: Accelerator, model: torch.nn.Module, save_steps=None, **kwargs):
|
||||
self.num_steps += 1
|
||||
if save_steps is not None and self.num_steps % save_steps == 0:
|
||||
self.save_model(accelerator, model, f"step-{self.num_steps}.safetensors")
|
||||
|
||||
|
||||
def on_epoch_end(self, accelerator: Accelerator, model: torch.nn.Module, epoch_id):
|
||||
accelerator.wait_for_everyone()
|
||||
state_dict = accelerator.get_state_dict(model)
|
||||
if accelerator.is_main_process:
|
||||
state_dict = accelerator.unwrap_model(model).export_trainable_state_dict(state_dict, remove_prefix=self.remove_prefix_in_ckpt)
|
||||
state_dict = self.state_dict_converter(state_dict)
|
||||
os.makedirs(self.output_path, exist_ok=True)
|
||||
path = os.path.join(self.output_path, f"epoch-{epoch_id}.safetensors")
|
||||
accelerator.save(state_dict, path, safe_serialization=True)
|
||||
|
||||
|
||||
def on_training_end(self, accelerator: Accelerator, model: torch.nn.Module, save_steps=None):
|
||||
if save_steps is not None and self.num_steps % save_steps != 0:
|
||||
self.save_model(accelerator, model, f"step-{self.num_steps}.safetensors")
|
||||
|
||||
|
||||
def save_model(self, accelerator: Accelerator, model: torch.nn.Module, file_name):
|
||||
accelerator.wait_for_everyone()
|
||||
state_dict = accelerator.get_state_dict(model)
|
||||
if accelerator.is_main_process:
|
||||
state_dict = accelerator.unwrap_model(model).export_trainable_state_dict(state_dict, remove_prefix=self.remove_prefix_in_ckpt)
|
||||
state_dict = self.state_dict_converter(state_dict)
|
||||
os.makedirs(self.output_path, exist_ok=True)
|
||||
path = os.path.join(self.output_path, file_name)
|
||||
accelerator.save(state_dict, path, safe_serialization=True)
|
||||
@@ -1,158 +0,0 @@
|
||||
from .base_pipeline import BasePipeline
|
||||
import torch
|
||||
|
||||
|
||||
def FlowMatchSFTLoss(pipe: BasePipeline, **inputs):
|
||||
max_timestep_boundary = int(inputs.get("max_timestep_boundary", 1) * len(pipe.scheduler.timesteps))
|
||||
min_timestep_boundary = int(inputs.get("min_timestep_boundary", 0) * len(pipe.scheduler.timesteps))
|
||||
|
||||
timestep_id = torch.randint(min_timestep_boundary, max_timestep_boundary, (1,))
|
||||
timestep = pipe.scheduler.timesteps[timestep_id].to(dtype=pipe.torch_dtype, device=pipe.device)
|
||||
|
||||
noise = torch.randn_like(inputs["input_latents"])
|
||||
inputs["latents"] = pipe.scheduler.add_noise(inputs["input_latents"], noise, timestep)
|
||||
training_target = pipe.scheduler.training_target(inputs["input_latents"], noise, timestep)
|
||||
|
||||
if "first_frame_latents" in inputs:
|
||||
inputs["latents"][:, :, 0:1] = inputs["first_frame_latents"]
|
||||
|
||||
models = {name: getattr(pipe, name) for name in pipe.in_iteration_models}
|
||||
noise_pred = pipe.model_fn(**models, **inputs, timestep=timestep)
|
||||
|
||||
if "first_frame_latents" in inputs:
|
||||
noise_pred = noise_pred[:, :, 1:]
|
||||
training_target = training_target[:, :, 1:]
|
||||
|
||||
loss = torch.nn.functional.mse_loss(noise_pred.float(), training_target.float())
|
||||
loss = loss * pipe.scheduler.training_weight(timestep)
|
||||
return loss
|
||||
|
||||
|
||||
def FlowMatchSFTAudioVideoLoss(pipe: BasePipeline, **inputs):
|
||||
max_timestep_boundary = int(inputs.get("max_timestep_boundary", 1) * len(pipe.scheduler.timesteps))
|
||||
min_timestep_boundary = int(inputs.get("min_timestep_boundary", 0) * len(pipe.scheduler.timesteps))
|
||||
|
||||
timestep_id = torch.randint(min_timestep_boundary, max_timestep_boundary, (1,))
|
||||
timestep = pipe.scheduler.timesteps[timestep_id].to(dtype=pipe.torch_dtype, device=pipe.device)
|
||||
|
||||
# video
|
||||
noise = torch.randn_like(inputs["input_latents"])
|
||||
inputs["video_latents"] = pipe.scheduler.add_noise(inputs["input_latents"], noise, timestep)
|
||||
training_target = pipe.scheduler.training_target(inputs["input_latents"], noise, timestep)
|
||||
|
||||
# audio
|
||||
if inputs.get("audio_input_latents") is not None:
|
||||
audio_noise = torch.randn_like(inputs["audio_input_latents"])
|
||||
inputs["audio_latents"] = pipe.scheduler.add_noise(inputs["audio_input_latents"], audio_noise, timestep)
|
||||
training_target_audio = pipe.scheduler.training_target(inputs["audio_input_latents"], audio_noise, timestep)
|
||||
|
||||
models = {name: getattr(pipe, name) for name in pipe.in_iteration_models}
|
||||
noise_pred, noise_pred_audio = pipe.model_fn(**models, **inputs, timestep=timestep)
|
||||
|
||||
loss = torch.nn.functional.mse_loss(noise_pred.float(), training_target.float())
|
||||
loss = loss * pipe.scheduler.training_weight(timestep)
|
||||
if inputs.get("audio_input_latents") is not None:
|
||||
loss_audio = torch.nn.functional.mse_loss(noise_pred_audio.float(), training_target_audio.float())
|
||||
loss_audio = loss_audio * pipe.scheduler.training_weight(timestep)
|
||||
loss = loss + loss_audio
|
||||
return loss
|
||||
|
||||
|
||||
def DirectDistillLoss(pipe: BasePipeline, **inputs):
|
||||
pipe.scheduler.set_timesteps(inputs["num_inference_steps"])
|
||||
pipe.scheduler.training = True
|
||||
models = {name: getattr(pipe, name) for name in pipe.in_iteration_models}
|
||||
for progress_id, timestep in enumerate(pipe.scheduler.timesteps):
|
||||
timestep = timestep.unsqueeze(0).to(dtype=pipe.torch_dtype, device=pipe.device)
|
||||
noise_pred = pipe.model_fn(**models, **inputs, timestep=timestep, progress_id=progress_id)
|
||||
inputs["latents"] = pipe.step(pipe.scheduler, progress_id=progress_id, noise_pred=noise_pred, **inputs)
|
||||
loss = torch.nn.functional.mse_loss(inputs["latents"].float(), inputs["input_latents"].float())
|
||||
return loss
|
||||
|
||||
|
||||
class TrajectoryImitationLoss(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.initialized = False
|
||||
|
||||
def initialize(self, device):
|
||||
import lpips # TODO: remove it
|
||||
self.loss_fn = lpips.LPIPS(net='alex').to(device)
|
||||
self.initialized = True
|
||||
|
||||
def fetch_trajectory(self, pipe: BasePipeline, timesteps_student, inputs_shared, inputs_posi, inputs_nega, num_inference_steps, cfg_scale):
|
||||
trajectory = [inputs_shared["latents"].clone()]
|
||||
|
||||
pipe.scheduler.set_timesteps(num_inference_steps, target_timesteps=timesteps_student)
|
||||
models = {name: getattr(pipe, name) for name in pipe.in_iteration_models}
|
||||
for progress_id, timestep in enumerate(pipe.scheduler.timesteps):
|
||||
timestep = timestep.unsqueeze(0).to(dtype=pipe.torch_dtype, device=pipe.device)
|
||||
noise_pred = pipe.cfg_guided_model_fn(
|
||||
pipe.model_fn, cfg_scale,
|
||||
inputs_shared, inputs_posi, inputs_nega,
|
||||
**models, timestep=timestep, progress_id=progress_id
|
||||
)
|
||||
inputs_shared["latents"] = pipe.step(pipe.scheduler, progress_id=progress_id, noise_pred=noise_pred.detach(), **inputs_shared)
|
||||
|
||||
trajectory.append(inputs_shared["latents"].clone())
|
||||
return pipe.scheduler.timesteps, trajectory
|
||||
|
||||
def align_trajectory(self, pipe: BasePipeline, timesteps_teacher, trajectory_teacher, inputs_shared, inputs_posi, inputs_nega, num_inference_steps, cfg_scale):
|
||||
loss = 0
|
||||
pipe.scheduler.set_timesteps(num_inference_steps, training=True)
|
||||
models = {name: getattr(pipe, name) for name in pipe.in_iteration_models}
|
||||
for progress_id, timestep in enumerate(pipe.scheduler.timesteps):
|
||||
timestep = timestep.unsqueeze(0).to(dtype=pipe.torch_dtype, device=pipe.device)
|
||||
|
||||
progress_id_teacher = torch.argmin((timesteps_teacher - timestep).abs())
|
||||
inputs_shared["latents"] = trajectory_teacher[progress_id_teacher]
|
||||
|
||||
noise_pred = pipe.cfg_guided_model_fn(
|
||||
pipe.model_fn, cfg_scale,
|
||||
inputs_shared, inputs_posi, inputs_nega,
|
||||
**models, timestep=timestep, progress_id=progress_id
|
||||
)
|
||||
|
||||
sigma = pipe.scheduler.sigmas[progress_id]
|
||||
sigma_ = 0 if progress_id + 1 >= len(pipe.scheduler.timesteps) else pipe.scheduler.sigmas[progress_id + 1]
|
||||
if progress_id + 1 >= len(pipe.scheduler.timesteps):
|
||||
latents_ = trajectory_teacher[-1]
|
||||
else:
|
||||
progress_id_teacher = torch.argmin((timesteps_teacher - pipe.scheduler.timesteps[progress_id + 1]).abs())
|
||||
latents_ = trajectory_teacher[progress_id_teacher]
|
||||
|
||||
denom = sigma_ - sigma
|
||||
denom = torch.sign(denom) * torch.clamp(denom.abs(), min=1e-6)
|
||||
target = (latents_ - inputs_shared["latents"]) / denom
|
||||
loss = loss + torch.nn.functional.mse_loss(noise_pred.float(), target.float()) * pipe.scheduler.training_weight(timestep)
|
||||
return loss
|
||||
|
||||
def compute_regularization(self, pipe: BasePipeline, trajectory_teacher, inputs_shared, inputs_posi, inputs_nega, num_inference_steps, cfg_scale):
|
||||
inputs_shared["latents"] = trajectory_teacher[0]
|
||||
pipe.scheduler.set_timesteps(num_inference_steps)
|
||||
models = {name: getattr(pipe, name) for name in pipe.in_iteration_models}
|
||||
for progress_id, timestep in enumerate(pipe.scheduler.timesteps):
|
||||
timestep = timestep.unsqueeze(0).to(dtype=pipe.torch_dtype, device=pipe.device)
|
||||
noise_pred = pipe.cfg_guided_model_fn(
|
||||
pipe.model_fn, cfg_scale,
|
||||
inputs_shared, inputs_posi, inputs_nega,
|
||||
**models, timestep=timestep, progress_id=progress_id
|
||||
)
|
||||
inputs_shared["latents"] = pipe.step(pipe.scheduler, progress_id=progress_id, noise_pred=noise_pred.detach(), **inputs_shared)
|
||||
|
||||
image_pred = pipe.vae_decoder(inputs_shared["latents"])
|
||||
image_real = pipe.vae_decoder(trajectory_teacher[-1])
|
||||
loss = self.loss_fn(image_pred.float(), image_real.float())
|
||||
return loss
|
||||
|
||||
def forward(self, pipe: BasePipeline, inputs_shared, inputs_posi, inputs_nega):
|
||||
if not self.initialized:
|
||||
self.initialize(pipe.device)
|
||||
with torch.no_grad():
|
||||
pipe.scheduler.set_timesteps(8)
|
||||
timesteps_teacher, trajectory_teacher = self.fetch_trajectory(inputs_shared["teacher"], pipe.scheduler.timesteps, inputs_shared, inputs_posi, inputs_nega, 50, 2)
|
||||
timesteps_teacher = timesteps_teacher.to(dtype=pipe.torch_dtype, device=pipe.device)
|
||||
loss_1 = self.align_trajectory(pipe, timesteps_teacher, trajectory_teacher, inputs_shared, inputs_posi, inputs_nega, 8, 1)
|
||||
loss_2 = self.compute_regularization(pipe, trajectory_teacher, inputs_shared, inputs_posi, inputs_nega, 8, 1)
|
||||
loss = loss_1 + loss_2
|
||||
return loss
|
||||
@@ -1,70 +0,0 @@
|
||||
import argparse
|
||||
|
||||
|
||||
def add_dataset_base_config(parser: argparse.ArgumentParser):
|
||||
parser.add_argument("--dataset_base_path", type=str, default="", required=True, help="Base path of the dataset.")
|
||||
parser.add_argument("--dataset_metadata_path", type=str, default=None, help="Path to the metadata file of the dataset.")
|
||||
parser.add_argument("--dataset_repeat", type=int, default=1, help="Number of times to repeat the dataset per epoch.")
|
||||
parser.add_argument("--dataset_num_workers", type=int, default=0, help="Number of workers for data loading.")
|
||||
parser.add_argument("--data_file_keys", type=str, default="image,video", help="Data file keys in the metadata. Comma-separated.")
|
||||
return parser
|
||||
|
||||
def add_image_size_config(parser: argparse.ArgumentParser):
|
||||
parser.add_argument("--height", type=int, default=None, help="Height of images. Leave `height` and `width` empty to enable dynamic resolution.")
|
||||
parser.add_argument("--width", type=int, default=None, help="Width of images. Leave `height` and `width` empty to enable dynamic resolution.")
|
||||
parser.add_argument("--max_pixels", type=int, default=1024*1024, help="Maximum number of pixels per frame, used for dynamic resolution.")
|
||||
return parser
|
||||
|
||||
def add_video_size_config(parser: argparse.ArgumentParser):
|
||||
parser.add_argument("--height", type=int, default=None, help="Height of images. Leave `height` and `width` empty to enable dynamic resolution.")
|
||||
parser.add_argument("--width", type=int, default=None, help="Width of images. Leave `height` and `width` empty to enable dynamic resolution.")
|
||||
parser.add_argument("--max_pixels", type=int, default=1024*1024, help="Maximum number of pixels per frame, used for dynamic resolution.")
|
||||
parser.add_argument("--num_frames", type=int, default=81, help="Number of frames per video. Frames are sampled from the video prefix.")
|
||||
return parser
|
||||
|
||||
def add_model_config(parser: argparse.ArgumentParser):
|
||||
parser.add_argument("--model_paths", type=str, default=None, help="Paths to load models. In JSON format.")
|
||||
parser.add_argument("--model_id_with_origin_paths", type=str, default=None, help="Model ID with origin paths, e.g., Wan-AI/Wan2.1-T2V-1.3B:diffusion_pytorch_model*.safetensors. Comma-separated.")
|
||||
parser.add_argument("--extra_inputs", default=None, help="Additional model inputs, comma-separated.")
|
||||
parser.add_argument("--fp8_models", default=None, help="Models with FP8 precision, comma-separated.")
|
||||
parser.add_argument("--offload_models", default=None, help="Models with offload, comma-separated. Only used in splited training.")
|
||||
return parser
|
||||
|
||||
def add_training_config(parser: argparse.ArgumentParser):
|
||||
parser.add_argument("--learning_rate", type=float, default=1e-4, help="Learning rate.")
|
||||
parser.add_argument("--num_epochs", type=int, default=1, help="Number of epochs.")
|
||||
parser.add_argument("--trainable_models", type=str, default=None, help="Models to train, e.g., dit, vae, text_encoder.")
|
||||
parser.add_argument("--find_unused_parameters", default=False, action="store_true", help="Whether to find unused parameters in DDP.")
|
||||
parser.add_argument("--weight_decay", type=float, default=0.01, help="Weight decay.")
|
||||
parser.add_argument("--task", type=str, default="sft", required=False, help="Task type.")
|
||||
return parser
|
||||
|
||||
def add_output_config(parser: argparse.ArgumentParser):
|
||||
parser.add_argument("--output_path", type=str, default="./models", help="Output save path.")
|
||||
parser.add_argument("--remove_prefix_in_ckpt", type=str, default="pipe.dit.", help="Remove prefix in ckpt.")
|
||||
parser.add_argument("--save_steps", type=int, default=None, help="Number of checkpoint saving invervals. If None, checkpoints will be saved every epoch.")
|
||||
return parser
|
||||
|
||||
def add_lora_config(parser: argparse.ArgumentParser):
|
||||
parser.add_argument("--lora_base_model", type=str, default=None, help="Which model LoRA is added to.")
|
||||
parser.add_argument("--lora_target_modules", type=str, default="q,k,v,o,ffn.0,ffn.2", help="Which layers LoRA is added to.")
|
||||
parser.add_argument("--lora_rank", type=int, default=32, help="Rank of LoRA.")
|
||||
parser.add_argument("--lora_checkpoint", type=str, default=None, help="Path to the LoRA checkpoint. If provided, LoRA will be loaded from this checkpoint.")
|
||||
parser.add_argument("--preset_lora_path", type=str, default=None, help="Path to the preset LoRA checkpoint. If provided, this LoRA will be fused to the base model.")
|
||||
parser.add_argument("--preset_lora_model", type=str, default=None, help="Which model the preset LoRA is fused to.")
|
||||
return parser
|
||||
|
||||
def add_gradient_config(parser: argparse.ArgumentParser):
|
||||
parser.add_argument("--use_gradient_checkpointing", default=False, action="store_true", help="Whether to use gradient checkpointing.")
|
||||
parser.add_argument("--use_gradient_checkpointing_offload", default=False, action="store_true", help="Whether to offload gradient checkpointing to CPU memory.")
|
||||
parser.add_argument("--gradient_accumulation_steps", type=int, default=1, help="Gradient accumulation steps.")
|
||||
return parser
|
||||
|
||||
def add_general_config(parser: argparse.ArgumentParser):
|
||||
parser = add_dataset_base_config(parser)
|
||||
parser = add_model_config(parser)
|
||||
parser = add_training_config(parser)
|
||||
parser = add_output_config(parser)
|
||||
parser = add_lora_config(parser)
|
||||
parser = add_gradient_config(parser)
|
||||
return parser
|
||||
@@ -1,72 +0,0 @@
|
||||
import os, torch
|
||||
from tqdm import tqdm
|
||||
from accelerate import Accelerator
|
||||
from .training_module import DiffusionTrainingModule
|
||||
from .logger import ModelLogger
|
||||
|
||||
|
||||
def launch_training_task(
|
||||
accelerator: Accelerator,
|
||||
dataset: torch.utils.data.Dataset,
|
||||
model: DiffusionTrainingModule,
|
||||
model_logger: ModelLogger,
|
||||
learning_rate: float = 1e-5,
|
||||
weight_decay: float = 1e-2,
|
||||
num_workers: int = 1,
|
||||
save_steps: int = None,
|
||||
num_epochs: int = 1,
|
||||
args = None,
|
||||
):
|
||||
if args is not None:
|
||||
learning_rate = args.learning_rate
|
||||
weight_decay = args.weight_decay
|
||||
num_workers = args.dataset_num_workers
|
||||
save_steps = args.save_steps
|
||||
num_epochs = args.num_epochs
|
||||
|
||||
optimizer = torch.optim.AdamW(model.trainable_modules(), lr=learning_rate, weight_decay=weight_decay)
|
||||
scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer)
|
||||
dataloader = torch.utils.data.DataLoader(dataset, shuffle=True, collate_fn=lambda x: x[0], num_workers=num_workers)
|
||||
model.to(device=accelerator.device)
|
||||
model, optimizer, dataloader, scheduler = accelerator.prepare(model, optimizer, dataloader, scheduler)
|
||||
|
||||
for epoch_id in range(num_epochs):
|
||||
for data in tqdm(dataloader):
|
||||
with accelerator.accumulate(model):
|
||||
optimizer.zero_grad()
|
||||
if dataset.load_from_cache:
|
||||
loss = model({}, inputs=data)
|
||||
else:
|
||||
loss = model(data)
|
||||
accelerator.backward(loss)
|
||||
optimizer.step()
|
||||
model_logger.on_step_end(accelerator, model, save_steps, loss=loss)
|
||||
scheduler.step()
|
||||
if save_steps is None:
|
||||
model_logger.on_epoch_end(accelerator, model, epoch_id)
|
||||
model_logger.on_training_end(accelerator, model, save_steps)
|
||||
|
||||
|
||||
def launch_data_process_task(
|
||||
accelerator: Accelerator,
|
||||
dataset: torch.utils.data.Dataset,
|
||||
model: DiffusionTrainingModule,
|
||||
model_logger: ModelLogger,
|
||||
num_workers: int = 8,
|
||||
args = None,
|
||||
):
|
||||
if args is not None:
|
||||
num_workers = args.dataset_num_workers
|
||||
|
||||
dataloader = torch.utils.data.DataLoader(dataset, shuffle=False, collate_fn=lambda x: x[0], num_workers=num_workers)
|
||||
model.to(device=accelerator.device)
|
||||
model, dataloader = accelerator.prepare(model, dataloader)
|
||||
|
||||
for data_id, data in enumerate(tqdm(dataloader)):
|
||||
with accelerator.accumulate(model):
|
||||
with torch.no_grad():
|
||||
folder = os.path.join(model_logger.output_path, str(accelerator.process_index))
|
||||
os.makedirs(folder, exist_ok=True)
|
||||
save_path = os.path.join(model_logger.output_path, str(accelerator.process_index), f"{data_id}.pth")
|
||||
data = model(data)
|
||||
torch.save(data, save_path)
|
||||
@@ -1,263 +0,0 @@
|
||||
import torch, json, os
|
||||
from ..core import ModelConfig, load_state_dict
|
||||
from ..utils.controlnet import ControlNetInput
|
||||
from peft import LoraConfig, inject_adapter_in_model
|
||||
|
||||
|
||||
class DiffusionTrainingModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
|
||||
def to(self, *args, **kwargs):
|
||||
for name, model in self.named_children():
|
||||
model.to(*args, **kwargs)
|
||||
return self
|
||||
|
||||
|
||||
def trainable_modules(self):
|
||||
trainable_modules = filter(lambda p: p.requires_grad, self.parameters())
|
||||
return trainable_modules
|
||||
|
||||
|
||||
def trainable_param_names(self):
|
||||
trainable_param_names = list(filter(lambda named_param: named_param[1].requires_grad, self.named_parameters()))
|
||||
trainable_param_names = set([named_param[0] for named_param in trainable_param_names])
|
||||
return trainable_param_names
|
||||
|
||||
|
||||
def add_lora_to_model(self, model, target_modules, lora_rank, lora_alpha=None, upcast_dtype=None):
|
||||
if lora_alpha is None:
|
||||
lora_alpha = lora_rank
|
||||
if isinstance(target_modules, list) and len(target_modules) == 1:
|
||||
target_modules = target_modules[0]
|
||||
lora_config = LoraConfig(r=lora_rank, lora_alpha=lora_alpha, target_modules=target_modules)
|
||||
model = inject_adapter_in_model(lora_config, model)
|
||||
if upcast_dtype is not None:
|
||||
for param in model.parameters():
|
||||
if param.requires_grad:
|
||||
param.data = param.to(upcast_dtype)
|
||||
return model
|
||||
|
||||
|
||||
def mapping_lora_state_dict(self, state_dict):
|
||||
new_state_dict = {}
|
||||
for key, value in state_dict.items():
|
||||
if "lora_A.weight" in key or "lora_B.weight" in key:
|
||||
new_key = key.replace("lora_A.weight", "lora_A.default.weight").replace("lora_B.weight", "lora_B.default.weight")
|
||||
new_state_dict[new_key] = value
|
||||
elif "lora_A.default.weight" in key or "lora_B.default.weight" in key:
|
||||
new_state_dict[key] = value
|
||||
return new_state_dict
|
||||
|
||||
|
||||
def export_trainable_state_dict(self, state_dict, remove_prefix=None):
|
||||
trainable_param_names = self.trainable_param_names()
|
||||
state_dict = {name: param for name, param in state_dict.items() if name in trainable_param_names}
|
||||
if remove_prefix is not None:
|
||||
state_dict_ = {}
|
||||
for name, param in state_dict.items():
|
||||
if name.startswith(remove_prefix):
|
||||
name = name[len(remove_prefix):]
|
||||
state_dict_[name] = param
|
||||
state_dict = state_dict_
|
||||
return state_dict
|
||||
|
||||
|
||||
def transfer_data_to_device(self, data, device, torch_float_dtype=None):
|
||||
if data is None:
|
||||
return data
|
||||
elif isinstance(data, torch.Tensor):
|
||||
data = data.to(device)
|
||||
if torch_float_dtype is not None and data.dtype in [torch.float, torch.float16, torch.bfloat16]:
|
||||
data = data.to(torch_float_dtype)
|
||||
return data
|
||||
elif isinstance(data, tuple):
|
||||
data = tuple(self.transfer_data_to_device(x, device, torch_float_dtype) for x in data)
|
||||
return data
|
||||
elif isinstance(data, list):
|
||||
data = list(self.transfer_data_to_device(x, device, torch_float_dtype) for x in data)
|
||||
return data
|
||||
elif isinstance(data, dict):
|
||||
data = {i: self.transfer_data_to_device(data[i], device, torch_float_dtype) for i in data}
|
||||
return data
|
||||
else:
|
||||
return data
|
||||
|
||||
def parse_vram_config(self, fp8=False, offload=False, device="cpu"):
|
||||
if fp8:
|
||||
return {
|
||||
"offload_dtype": torch.float8_e4m3fn,
|
||||
"offload_device": device,
|
||||
"onload_dtype": torch.float8_e4m3fn,
|
||||
"onload_device": device,
|
||||
"preparing_dtype": torch.float8_e4m3fn,
|
||||
"preparing_device": device,
|
||||
"computation_dtype": torch.bfloat16,
|
||||
"computation_device": device,
|
||||
}
|
||||
elif offload:
|
||||
return {
|
||||
"offload_dtype": "disk",
|
||||
"offload_device": "disk",
|
||||
"onload_dtype": "disk",
|
||||
"onload_device": "disk",
|
||||
"preparing_dtype": torch.bfloat16,
|
||||
"preparing_device": device,
|
||||
"computation_dtype": torch.bfloat16,
|
||||
"computation_device": device,
|
||||
"clear_parameters": True,
|
||||
}
|
||||
else:
|
||||
return {}
|
||||
|
||||
def parse_model_configs(self, model_paths, model_id_with_origin_paths, fp8_models=None, offload_models=None, device="cpu"):
|
||||
fp8_models = [] if fp8_models is None else fp8_models.split(",")
|
||||
offload_models = [] if offload_models is None else offload_models.split(",")
|
||||
model_configs = []
|
||||
if model_paths is not None:
|
||||
model_paths = json.loads(model_paths)
|
||||
for path in model_paths:
|
||||
vram_config = self.parse_vram_config(
|
||||
fp8=path in fp8_models,
|
||||
offload=path in offload_models,
|
||||
device=device
|
||||
)
|
||||
model_configs.append(ModelConfig(path=path, **vram_config))
|
||||
if model_id_with_origin_paths is not None:
|
||||
model_id_with_origin_paths = model_id_with_origin_paths.split(",")
|
||||
for model_id_with_origin_path in model_id_with_origin_paths:
|
||||
vram_config = self.parse_vram_config(
|
||||
fp8=model_id_with_origin_path in fp8_models,
|
||||
offload=model_id_with_origin_path in offload_models,
|
||||
device=device
|
||||
)
|
||||
config = self.parse_path_or_model_id(model_id_with_origin_path)
|
||||
model_configs.append(ModelConfig(model_id=config.model_id, origin_file_pattern=config.origin_file_pattern, **vram_config))
|
||||
return model_configs
|
||||
|
||||
|
||||
def parse_path_or_model_id(self, model_id_with_origin_path, default_value=None):
|
||||
if model_id_with_origin_path is None:
|
||||
return default_value
|
||||
elif os.path.exists(model_id_with_origin_path):
|
||||
return ModelConfig(path=model_id_with_origin_path)
|
||||
else:
|
||||
if ":" not in model_id_with_origin_path:
|
||||
raise ValueError(f"Failed to parse model config: {model_id_with_origin_path}. This is neither a valid path nor in the format of `model_id/origin_file_pattern`.")
|
||||
split_id = model_id_with_origin_path.rfind(":")
|
||||
model_id = model_id_with_origin_path[:split_id]
|
||||
origin_file_pattern = model_id_with_origin_path[split_id + 1:]
|
||||
return ModelConfig(model_id=model_id, origin_file_pattern=origin_file_pattern)
|
||||
|
||||
|
||||
def auto_detect_lora_target_modules(
|
||||
self,
|
||||
model: torch.nn.Module,
|
||||
search_for_linear=False,
|
||||
linear_detector=lambda x: min(x.weight.shape) >= 512,
|
||||
block_list_detector=lambda x: isinstance(x, torch.nn.ModuleList) and len(x) > 1,
|
||||
name_prefix="",
|
||||
):
|
||||
lora_target_modules = []
|
||||
if search_for_linear:
|
||||
for name, module in model.named_modules():
|
||||
module_name = name_prefix + ["", "."][name_prefix != ""] + name
|
||||
if isinstance(module, torch.nn.Linear) and linear_detector(module):
|
||||
lora_target_modules.append(module_name)
|
||||
else:
|
||||
for name, module in model.named_children():
|
||||
module_name = name_prefix + ["", "."][name_prefix != ""] + name
|
||||
lora_target_modules += self.auto_detect_lora_target_modules(
|
||||
module,
|
||||
search_for_linear=block_list_detector(module),
|
||||
linear_detector=linear_detector,
|
||||
block_list_detector=block_list_detector,
|
||||
name_prefix=module_name,
|
||||
)
|
||||
return lora_target_modules
|
||||
|
||||
|
||||
def parse_lora_target_modules(self, model, lora_target_modules):
|
||||
if lora_target_modules == "":
|
||||
print("No LoRA target modules specified. The framework will automatically search for them.")
|
||||
lora_target_modules = self.auto_detect_lora_target_modules(model)
|
||||
print(f"LoRA will be patched at {lora_target_modules}.")
|
||||
else:
|
||||
lora_target_modules = lora_target_modules.split(",")
|
||||
return lora_target_modules
|
||||
|
||||
|
||||
def switch_pipe_to_training_mode(
|
||||
self,
|
||||
pipe,
|
||||
trainable_models=None,
|
||||
lora_base_model=None, lora_target_modules="", lora_rank=32, lora_checkpoint=None,
|
||||
preset_lora_path=None, preset_lora_model=None,
|
||||
task="sft",
|
||||
):
|
||||
# Scheduler
|
||||
pipe.scheduler.set_timesteps(1000, training=True)
|
||||
|
||||
# Freeze untrainable models
|
||||
pipe.freeze_except([] if trainable_models is None else trainable_models.split(","))
|
||||
|
||||
# Preset LoRA
|
||||
if preset_lora_path is not None:
|
||||
pipe.load_lora(getattr(pipe, preset_lora_model), preset_lora_path)
|
||||
|
||||
# FP8
|
||||
# FP8 relies on a model-specific memory management scheme.
|
||||
# It is delegated to the subclass.
|
||||
|
||||
# Add LoRA to the base models
|
||||
if lora_base_model is not None and not task.endswith(":data_process"):
|
||||
if (not hasattr(pipe, lora_base_model)) or getattr(pipe, lora_base_model) is None:
|
||||
print(f"No {lora_base_model} models in the pipeline. We cannot patch LoRA on the model. If this occurs during the data processing stage, it is normal.")
|
||||
return
|
||||
model = self.add_lora_to_model(
|
||||
getattr(pipe, lora_base_model),
|
||||
target_modules=self.parse_lora_target_modules(getattr(pipe, lora_base_model), lora_target_modules),
|
||||
lora_rank=lora_rank,
|
||||
upcast_dtype=pipe.torch_dtype,
|
||||
)
|
||||
if lora_checkpoint is not None:
|
||||
state_dict = load_state_dict(lora_checkpoint)
|
||||
state_dict = self.mapping_lora_state_dict(state_dict)
|
||||
load_result = model.load_state_dict(state_dict, strict=False)
|
||||
print(f"LoRA checkpoint loaded: {lora_checkpoint}, total {len(state_dict)} keys")
|
||||
if len(load_result[1]) > 0:
|
||||
print(f"Warning, LoRA key mismatch! Unexpected keys in LoRA checkpoint: {load_result[1]}")
|
||||
setattr(pipe, lora_base_model, model)
|
||||
|
||||
|
||||
def split_pipeline_units(self, task, pipe, trainable_models=None, lora_base_model=None):
|
||||
models_require_backward = []
|
||||
if trainable_models is not None:
|
||||
models_require_backward += trainable_models.split(",")
|
||||
if lora_base_model is not None:
|
||||
models_require_backward += [lora_base_model]
|
||||
if task.endswith(":data_process"):
|
||||
_, pipe.units = pipe.split_pipeline_units(models_require_backward)
|
||||
elif task.endswith(":train"):
|
||||
pipe.units, _ = pipe.split_pipeline_units(models_require_backward)
|
||||
return pipe
|
||||
|
||||
def parse_extra_inputs(self, data, extra_inputs, inputs_shared):
|
||||
controlnet_keys_map = (
|
||||
("blockwise_controlnet_", "blockwise_controlnet_inputs",),
|
||||
("controlnet_", "controlnet_inputs"),
|
||||
)
|
||||
controlnet_inputs = {}
|
||||
for extra_input in extra_inputs:
|
||||
for prefix, name in controlnet_keys_map:
|
||||
if extra_input.startswith(prefix):
|
||||
if name not in controlnet_inputs:
|
||||
controlnet_inputs[name] = {}
|
||||
controlnet_inputs[name][extra_input.replace(prefix, "")] = data[extra_input]
|
||||
break
|
||||
else:
|
||||
inputs_shared[extra_input] = data[extra_input]
|
||||
for name, params in controlnet_inputs.items():
|
||||
inputs_shared[name] = [ControlNetInput(**params)]
|
||||
return inputs_shared
|
||||
129
diffsynth/extensions/ESRGAN/__init__.py
Normal file
129
diffsynth/extensions/ESRGAN/__init__.py
Normal file
@@ -0,0 +1,129 @@
|
||||
import torch
|
||||
from einops import repeat
|
||||
from PIL import Image
|
||||
import numpy as np
|
||||
|
||||
|
||||
class ResidualDenseBlock(torch.nn.Module):
|
||||
|
||||
def __init__(self, num_feat=64, num_grow_ch=32):
|
||||
super(ResidualDenseBlock, self).__init__()
|
||||
self.conv1 = torch.nn.Conv2d(num_feat, num_grow_ch, 3, 1, 1)
|
||||
self.conv2 = torch.nn.Conv2d(num_feat + num_grow_ch, num_grow_ch, 3, 1, 1)
|
||||
self.conv3 = torch.nn.Conv2d(num_feat + 2 * num_grow_ch, num_grow_ch, 3, 1, 1)
|
||||
self.conv4 = torch.nn.Conv2d(num_feat + 3 * num_grow_ch, num_grow_ch, 3, 1, 1)
|
||||
self.conv5 = torch.nn.Conv2d(num_feat + 4 * num_grow_ch, num_feat, 3, 1, 1)
|
||||
self.lrelu = torch.nn.LeakyReLU(negative_slope=0.2, inplace=True)
|
||||
|
||||
def forward(self, x):
|
||||
x1 = self.lrelu(self.conv1(x))
|
||||
x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1)))
|
||||
x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1)))
|
||||
x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1)))
|
||||
x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
|
||||
return x5 * 0.2 + x
|
||||
|
||||
|
||||
class RRDB(torch.nn.Module):
|
||||
|
||||
def __init__(self, num_feat, num_grow_ch=32):
|
||||
super(RRDB, self).__init__()
|
||||
self.rdb1 = ResidualDenseBlock(num_feat, num_grow_ch)
|
||||
self.rdb2 = ResidualDenseBlock(num_feat, num_grow_ch)
|
||||
self.rdb3 = ResidualDenseBlock(num_feat, num_grow_ch)
|
||||
|
||||
def forward(self, x):
|
||||
out = self.rdb1(x)
|
||||
out = self.rdb2(out)
|
||||
out = self.rdb3(out)
|
||||
return out * 0.2 + x
|
||||
|
||||
|
||||
class RRDBNet(torch.nn.Module):
|
||||
|
||||
def __init__(self, num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, **kwargs):
|
||||
super(RRDBNet, self).__init__()
|
||||
self.conv_first = torch.nn.Conv2d(num_in_ch, num_feat, 3, 1, 1)
|
||||
self.body = torch.torch.nn.Sequential(*[RRDB(num_feat=num_feat, num_grow_ch=num_grow_ch) for _ in range(num_block)])
|
||||
self.conv_body = torch.nn.Conv2d(num_feat, num_feat, 3, 1, 1)
|
||||
# upsample
|
||||
self.conv_up1 = torch.nn.Conv2d(num_feat, num_feat, 3, 1, 1)
|
||||
self.conv_up2 = torch.nn.Conv2d(num_feat, num_feat, 3, 1, 1)
|
||||
self.conv_hr = torch.nn.Conv2d(num_feat, num_feat, 3, 1, 1)
|
||||
self.conv_last = torch.nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
|
||||
self.lrelu = torch.nn.LeakyReLU(negative_slope=0.2, inplace=True)
|
||||
|
||||
def forward(self, x):
|
||||
feat = x
|
||||
feat = self.conv_first(feat)
|
||||
body_feat = self.conv_body(self.body(feat))
|
||||
feat = feat + body_feat
|
||||
# upsample
|
||||
feat = repeat(feat, "B C H W -> B C (H 2) (W 2)")
|
||||
feat = self.lrelu(self.conv_up1(feat))
|
||||
feat = repeat(feat, "B C H W -> B C (H 2) (W 2)")
|
||||
feat = self.lrelu(self.conv_up2(feat))
|
||||
out = self.conv_last(self.lrelu(self.conv_hr(feat)))
|
||||
return out
|
||||
|
||||
@staticmethod
|
||||
def state_dict_converter():
|
||||
return RRDBNetStateDictConverter()
|
||||
|
||||
|
||||
class RRDBNetStateDictConverter:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def from_diffusers(self, state_dict):
|
||||
return state_dict, {"upcast_to_float32": True}
|
||||
|
||||
def from_civitai(self, state_dict):
|
||||
return state_dict, {"upcast_to_float32": True}
|
||||
|
||||
|
||||
class ESRGAN(torch.nn.Module):
|
||||
def __init__(self, model):
|
||||
super().__init__()
|
||||
self.model = model
|
||||
|
||||
@staticmethod
|
||||
def from_model_manager(model_manager):
|
||||
return ESRGAN(model_manager.fetch_model("esrgan"))
|
||||
|
||||
def process_image(self, image):
|
||||
image = torch.Tensor(np.array(image, dtype=np.float32) / 255).permute(2, 0, 1)
|
||||
return image
|
||||
|
||||
def process_images(self, images):
|
||||
images = [self.process_image(image) for image in images]
|
||||
images = torch.stack(images)
|
||||
return images
|
||||
|
||||
def decode_images(self, images):
|
||||
images = (images.permute(0, 2, 3, 1) * 255).clip(0, 255).numpy().astype(np.uint8)
|
||||
images = [Image.fromarray(image) for image in images]
|
||||
return images
|
||||
|
||||
@torch.no_grad()
|
||||
def upscale(self, images, batch_size=4, progress_bar=lambda x:x):
|
||||
# Preprocess
|
||||
input_tensor = self.process_images(images)
|
||||
|
||||
# Interpolate
|
||||
output_tensor = []
|
||||
for batch_id in progress_bar(range(0, input_tensor.shape[0], batch_size)):
|
||||
batch_id_ = min(batch_id + batch_size, input_tensor.shape[0])
|
||||
batch_input_tensor = input_tensor[batch_id: batch_id_]
|
||||
batch_input_tensor = batch_input_tensor.to(
|
||||
device=self.model.conv_first.weight.device,
|
||||
dtype=self.model.conv_first.weight.dtype)
|
||||
batch_output_tensor = self.model(batch_input_tensor)
|
||||
output_tensor.append(batch_output_tensor.cpu())
|
||||
|
||||
# Output
|
||||
output_tensor = torch.concat(output_tensor, dim=0)
|
||||
|
||||
# To images
|
||||
output_images = self.decode_images(output_tensor)
|
||||
return output_images
|
||||
63
diffsynth/extensions/FastBlend/__init__.py
Normal file
63
diffsynth/extensions/FastBlend/__init__.py
Normal file
@@ -0,0 +1,63 @@
|
||||
from .runners.fast import TableManager, PyramidPatchMatcher
|
||||
from PIL import Image
|
||||
import numpy as np
|
||||
import cupy as cp
|
||||
|
||||
|
||||
class FastBlendSmoother:
|
||||
def __init__(self):
|
||||
self.batch_size = 8
|
||||
self.window_size = 64
|
||||
self.ebsynth_config = {
|
||||
"minimum_patch_size": 5,
|
||||
"threads_per_block": 8,
|
||||
"num_iter": 5,
|
||||
"gpu_id": 0,
|
||||
"guide_weight": 10.0,
|
||||
"initialize": "identity",
|
||||
"tracking_window_size": 0,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def from_model_manager(model_manager):
|
||||
# TODO: fetch GPU ID from model_manager
|
||||
return FastBlendSmoother()
|
||||
|
||||
def run(self, frames_guide, frames_style, batch_size, window_size, ebsynth_config):
|
||||
frames_guide = [np.array(frame) for frame in frames_guide]
|
||||
frames_style = [np.array(frame) for frame in frames_style]
|
||||
table_manager = TableManager()
|
||||
patch_match_engine = PyramidPatchMatcher(
|
||||
image_height=frames_style[0].shape[0],
|
||||
image_width=frames_style[0].shape[1],
|
||||
channel=3,
|
||||
**ebsynth_config
|
||||
)
|
||||
# left part
|
||||
table_l = table_manager.build_remapping_table(frames_guide, frames_style, patch_match_engine, batch_size, desc="FastBlend Step 1/4")
|
||||
table_l = table_manager.remapping_table_to_blending_table(table_l)
|
||||
table_l = table_manager.process_window_sum(frames_guide, table_l, patch_match_engine, window_size, batch_size, desc="FastBlend Step 2/4")
|
||||
# right part
|
||||
table_r = table_manager.build_remapping_table(frames_guide[::-1], frames_style[::-1], patch_match_engine, batch_size, desc="FastBlend Step 3/4")
|
||||
table_r = table_manager.remapping_table_to_blending_table(table_r)
|
||||
table_r = table_manager.process_window_sum(frames_guide[::-1], table_r, patch_match_engine, window_size, batch_size, desc="FastBlend Step 4/4")[::-1]
|
||||
# merge
|
||||
frames = []
|
||||
for (frame_l, weight_l), frame_m, (frame_r, weight_r) in zip(table_l, frames_style, table_r):
|
||||
weight_m = -1
|
||||
weight = weight_l + weight_m + weight_r
|
||||
frame = frame_l * (weight_l / weight) + frame_m * (weight_m / weight) + frame_r * (weight_r / weight)
|
||||
frames.append(frame)
|
||||
frames = [Image.fromarray(frame.clip(0, 255).astype("uint8")) for frame in frames]
|
||||
return frames
|
||||
|
||||
def __call__(self, rendered_frames, original_frames=None, **kwargs):
|
||||
frames = self.run(
|
||||
original_frames, rendered_frames,
|
||||
self.batch_size, self.window_size, self.ebsynth_config
|
||||
)
|
||||
mempool = cp.get_default_memory_pool()
|
||||
pinned_mempool = cp.get_default_pinned_memory_pool()
|
||||
mempool.free_all_blocks()
|
||||
pinned_mempool.free_all_blocks()
|
||||
return frames
|
||||
397
diffsynth/extensions/FastBlend/api.py
Normal file
397
diffsynth/extensions/FastBlend/api.py
Normal file
@@ -0,0 +1,397 @@
|
||||
from .runners import AccurateModeRunner, FastModeRunner, BalancedModeRunner, InterpolationModeRunner, InterpolationModeSingleFrameRunner
|
||||
from .data import VideoData, get_video_fps, save_video, search_for_images
|
||||
import os
|
||||
import gradio as gr
|
||||
|
||||
|
||||
def check_input_for_blending(video_guide, video_guide_folder, video_style, video_style_folder):
|
||||
frames_guide = VideoData(video_guide, video_guide_folder)
|
||||
frames_style = VideoData(video_style, video_style_folder)
|
||||
message = ""
|
||||
if len(frames_guide) < len(frames_style):
|
||||
message += f"The number of frames mismatches. Only the first {len(frames_guide)} frames of style video will be used.\n"
|
||||
frames_style.set_length(len(frames_guide))
|
||||
elif len(frames_guide) > len(frames_style):
|
||||
message += f"The number of frames mismatches. Only the first {len(frames_style)} frames of guide video will be used.\n"
|
||||
frames_guide.set_length(len(frames_style))
|
||||
height_guide, width_guide = frames_guide.shape()
|
||||
height_style, width_style = frames_style.shape()
|
||||
if height_guide != height_style or width_guide != width_style:
|
||||
message += f"The shape of frames mismatches. The frames in style video will be resized to (height: {height_guide}, width: {width_guide})\n"
|
||||
frames_style.set_shape(height_guide, width_guide)
|
||||
return frames_guide, frames_style, message
|
||||
|
||||
|
||||
def smooth_video(
|
||||
video_guide,
|
||||
video_guide_folder,
|
||||
video_style,
|
||||
video_style_folder,
|
||||
mode,
|
||||
window_size,
|
||||
batch_size,
|
||||
tracking_window_size,
|
||||
output_path,
|
||||
fps,
|
||||
minimum_patch_size,
|
||||
num_iter,
|
||||
guide_weight,
|
||||
initialize,
|
||||
progress = None,
|
||||
):
|
||||
# input
|
||||
frames_guide, frames_style, message = check_input_for_blending(video_guide, video_guide_folder, video_style, video_style_folder)
|
||||
if len(message) > 0:
|
||||
print(message)
|
||||
# output
|
||||
if output_path == "":
|
||||
if video_style is None:
|
||||
output_path = os.path.join(video_style_folder, "output")
|
||||
else:
|
||||
output_path = os.path.join(os.path.split(video_style)[0], "output")
|
||||
os.makedirs(output_path, exist_ok=True)
|
||||
print("No valid output_path. Your video will be saved here:", output_path)
|
||||
elif not os.path.exists(output_path):
|
||||
os.makedirs(output_path, exist_ok=True)
|
||||
print("Your video will be saved here:", output_path)
|
||||
frames_path = os.path.join(output_path, "frames")
|
||||
video_path = os.path.join(output_path, "video.mp4")
|
||||
os.makedirs(frames_path, exist_ok=True)
|
||||
# process
|
||||
if mode == "Fast" or mode == "Balanced":
|
||||
tracking_window_size = 0
|
||||
ebsynth_config = {
|
||||
"minimum_patch_size": minimum_patch_size,
|
||||
"threads_per_block": 8,
|
||||
"num_iter": num_iter,
|
||||
"gpu_id": 0,
|
||||
"guide_weight": guide_weight,
|
||||
"initialize": initialize,
|
||||
"tracking_window_size": tracking_window_size,
|
||||
}
|
||||
if mode == "Fast":
|
||||
FastModeRunner().run(frames_guide, frames_style, batch_size=batch_size, window_size=window_size, ebsynth_config=ebsynth_config, save_path=frames_path)
|
||||
elif mode == "Balanced":
|
||||
BalancedModeRunner().run(frames_guide, frames_style, batch_size=batch_size, window_size=window_size, ebsynth_config=ebsynth_config, save_path=frames_path)
|
||||
elif mode == "Accurate":
|
||||
AccurateModeRunner().run(frames_guide, frames_style, batch_size=batch_size, window_size=window_size, ebsynth_config=ebsynth_config, save_path=frames_path)
|
||||
# output
|
||||
try:
|
||||
fps = int(fps)
|
||||
except:
|
||||
fps = get_video_fps(video_style) if video_style is not None else 30
|
||||
print("Fps:", fps)
|
||||
print("Saving video...")
|
||||
video_path = save_video(frames_path, video_path, num_frames=len(frames_style), fps=fps)
|
||||
print("Success!")
|
||||
print("Your frames are here:", frames_path)
|
||||
print("Your video is here:", video_path)
|
||||
return output_path, fps, video_path
|
||||
|
||||
|
||||
class KeyFrameMatcher:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def extract_number_from_filename(self, file_name):
|
||||
result = []
|
||||
number = -1
|
||||
for i in file_name:
|
||||
if ord(i)>=ord("0") and ord(i)<=ord("9"):
|
||||
if number == -1:
|
||||
number = 0
|
||||
number = number*10 + ord(i) - ord("0")
|
||||
else:
|
||||
if number != -1:
|
||||
result.append(number)
|
||||
number = -1
|
||||
if number != -1:
|
||||
result.append(number)
|
||||
result = tuple(result)
|
||||
return result
|
||||
|
||||
def extract_number_from_filenames(self, file_names):
|
||||
numbers = [self.extract_number_from_filename(file_name) for file_name in file_names]
|
||||
min_length = min(len(i) for i in numbers)
|
||||
for i in range(min_length-1, -1, -1):
|
||||
if len(set(number[i] for number in numbers))==len(file_names):
|
||||
return [number[i] for number in numbers]
|
||||
return list(range(len(file_names)))
|
||||
|
||||
def match_using_filename(self, file_names_a, file_names_b):
|
||||
file_names_b_set = set(file_names_b)
|
||||
matched_file_name = []
|
||||
for file_name in file_names_a:
|
||||
if file_name not in file_names_b_set:
|
||||
matched_file_name.append(None)
|
||||
else:
|
||||
matched_file_name.append(file_name)
|
||||
return matched_file_name
|
||||
|
||||
def match_using_numbers(self, file_names_a, file_names_b):
|
||||
numbers_a = self.extract_number_from_filenames(file_names_a)
|
||||
numbers_b = self.extract_number_from_filenames(file_names_b)
|
||||
numbers_b_dict = {number: file_name for number, file_name in zip(numbers_b, file_names_b)}
|
||||
matched_file_name = []
|
||||
for number in numbers_a:
|
||||
if number in numbers_b_dict:
|
||||
matched_file_name.append(numbers_b_dict[number])
|
||||
else:
|
||||
matched_file_name.append(None)
|
||||
return matched_file_name
|
||||
|
||||
def match_filenames(self, file_names_a, file_names_b):
|
||||
matched_file_name = self.match_using_filename(file_names_a, file_names_b)
|
||||
if sum([i is not None for i in matched_file_name]) > 0:
|
||||
return matched_file_name
|
||||
matched_file_name = self.match_using_numbers(file_names_a, file_names_b)
|
||||
return matched_file_name
|
||||
|
||||
|
||||
def detect_frames(frames_path, keyframes_path):
|
||||
if not os.path.exists(frames_path) and not os.path.exists(keyframes_path):
|
||||
return "Please input the directory of guide video and rendered frames"
|
||||
elif not os.path.exists(frames_path):
|
||||
return "Please input the directory of guide video"
|
||||
elif not os.path.exists(keyframes_path):
|
||||
return "Please input the directory of rendered frames"
|
||||
frames = [os.path.split(i)[-1] for i in search_for_images(frames_path)]
|
||||
keyframes = [os.path.split(i)[-1] for i in search_for_images(keyframes_path)]
|
||||
if len(frames)==0:
|
||||
return f"No images detected in {frames_path}"
|
||||
if len(keyframes)==0:
|
||||
return f"No images detected in {keyframes_path}"
|
||||
matched_keyframes = KeyFrameMatcher().match_filenames(frames, keyframes)
|
||||
max_filename_length = max([len(i) for i in frames])
|
||||
if sum([i is not None for i in matched_keyframes])==0:
|
||||
message = ""
|
||||
for frame, matched_keyframe in zip(frames, matched_keyframes):
|
||||
message += frame + " " * (max_filename_length - len(frame) + 1)
|
||||
message += "--> No matched keyframes\n"
|
||||
else:
|
||||
message = ""
|
||||
for frame, matched_keyframe in zip(frames, matched_keyframes):
|
||||
message += frame + " " * (max_filename_length - len(frame) + 1)
|
||||
if matched_keyframe is None:
|
||||
message += "--> [to be rendered]\n"
|
||||
else:
|
||||
message += f"--> {matched_keyframe}\n"
|
||||
return message
|
||||
|
||||
|
||||
def check_input_for_interpolating(frames_path, keyframes_path):
|
||||
# search for images
|
||||
frames = [os.path.split(i)[-1] for i in search_for_images(frames_path)]
|
||||
keyframes = [os.path.split(i)[-1] for i in search_for_images(keyframes_path)]
|
||||
# match frames
|
||||
matched_keyframes = KeyFrameMatcher().match_filenames(frames, keyframes)
|
||||
file_list = [file_name for file_name in matched_keyframes if file_name is not None]
|
||||
index_style = [i for i, file_name in enumerate(matched_keyframes) if file_name is not None]
|
||||
frames_guide = VideoData(None, frames_path)
|
||||
frames_style = VideoData(None, keyframes_path, file_list=file_list)
|
||||
# match shape
|
||||
message = ""
|
||||
height_guide, width_guide = frames_guide.shape()
|
||||
height_style, width_style = frames_style.shape()
|
||||
if height_guide != height_style or width_guide != width_style:
|
||||
message += f"The shape of frames mismatches. The rendered keyframes will be resized to (height: {height_guide}, width: {width_guide})\n"
|
||||
frames_style.set_shape(height_guide, width_guide)
|
||||
return frames_guide, frames_style, index_style, message
|
||||
|
||||
|
||||
def interpolate_video(
|
||||
frames_path,
|
||||
keyframes_path,
|
||||
output_path,
|
||||
fps,
|
||||
batch_size,
|
||||
tracking_window_size,
|
||||
minimum_patch_size,
|
||||
num_iter,
|
||||
guide_weight,
|
||||
initialize,
|
||||
progress = None,
|
||||
):
|
||||
# input
|
||||
frames_guide, frames_style, index_style, message = check_input_for_interpolating(frames_path, keyframes_path)
|
||||
if len(message) > 0:
|
||||
print(message)
|
||||
# output
|
||||
if output_path == "":
|
||||
output_path = os.path.join(keyframes_path, "output")
|
||||
os.makedirs(output_path, exist_ok=True)
|
||||
print("No valid output_path. Your video will be saved here:", output_path)
|
||||
elif not os.path.exists(output_path):
|
||||
os.makedirs(output_path, exist_ok=True)
|
||||
print("Your video will be saved here:", output_path)
|
||||
output_frames_path = os.path.join(output_path, "frames")
|
||||
output_video_path = os.path.join(output_path, "video.mp4")
|
||||
os.makedirs(output_frames_path, exist_ok=True)
|
||||
# process
|
||||
ebsynth_config = {
|
||||
"minimum_patch_size": minimum_patch_size,
|
||||
"threads_per_block": 8,
|
||||
"num_iter": num_iter,
|
||||
"gpu_id": 0,
|
||||
"guide_weight": guide_weight,
|
||||
"initialize": initialize,
|
||||
"tracking_window_size": tracking_window_size
|
||||
}
|
||||
if len(index_style)==1:
|
||||
InterpolationModeSingleFrameRunner().run(frames_guide, frames_style, index_style, batch_size=batch_size, ebsynth_config=ebsynth_config, save_path=output_frames_path)
|
||||
else:
|
||||
InterpolationModeRunner().run(frames_guide, frames_style, index_style, batch_size=batch_size, ebsynth_config=ebsynth_config, save_path=output_frames_path)
|
||||
try:
|
||||
fps = int(fps)
|
||||
except:
|
||||
fps = 30
|
||||
print("Fps:", fps)
|
||||
print("Saving video...")
|
||||
video_path = save_video(output_frames_path, output_video_path, num_frames=len(frames_guide), fps=fps)
|
||||
print("Success!")
|
||||
print("Your frames are here:", output_frames_path)
|
||||
print("Your video is here:", video_path)
|
||||
return output_path, fps, video_path
|
||||
|
||||
|
||||
def on_ui_tabs():
|
||||
with gr.Blocks(analytics_enabled=False) as ui_component:
|
||||
with gr.Tab("Blend"):
|
||||
gr.Markdown("""
|
||||
# Blend
|
||||
|
||||
Given a guide video and a style video, this algorithm will make the style video fluent according to the motion features of the guide video. Click [here](https://github.com/Artiprocher/sd-webui-fastblend/assets/35051019/208d902d-6aba-48d7-b7d5-cd120ebd306d) to see the example. Note that this extension doesn't support long videos. Please use short videos (e.g., several seconds). The algorithm is mainly designed for 512*512 resolution. Please use a larger `Minimum patch size` for higher resolution.
|
||||
""")
|
||||
with gr.Row():
|
||||
with gr.Column():
|
||||
with gr.Tab("Guide video"):
|
||||
video_guide = gr.Video(label="Guide video")
|
||||
with gr.Tab("Guide video (images format)"):
|
||||
video_guide_folder = gr.Textbox(label="Guide video (images format)", value="")
|
||||
with gr.Column():
|
||||
with gr.Tab("Style video"):
|
||||
video_style = gr.Video(label="Style video")
|
||||
with gr.Tab("Style video (images format)"):
|
||||
video_style_folder = gr.Textbox(label="Style video (images format)", value="")
|
||||
with gr.Column():
|
||||
output_path = gr.Textbox(label="Output directory", value="", placeholder="Leave empty to use the directory of style video")
|
||||
fps = gr.Textbox(label="Fps", value="", placeholder="Leave empty to use the default fps")
|
||||
video_output = gr.Video(label="Output video", interactive=False, show_share_button=True)
|
||||
btn = gr.Button(value="Blend")
|
||||
with gr.Row():
|
||||
with gr.Column():
|
||||
gr.Markdown("# Settings")
|
||||
mode = gr.Radio(["Fast", "Balanced", "Accurate"], label="Inference mode", value="Fast", interactive=True)
|
||||
window_size = gr.Slider(label="Sliding window size", value=15, minimum=1, maximum=1000, step=1, interactive=True)
|
||||
batch_size = gr.Slider(label="Batch size", value=8, minimum=1, maximum=128, step=1, interactive=True)
|
||||
tracking_window_size = gr.Slider(label="Tracking window size (only for accurate mode)", value=0, minimum=0, maximum=10, step=1, interactive=True)
|
||||
gr.Markdown("## Advanced Settings")
|
||||
minimum_patch_size = gr.Slider(label="Minimum patch size (odd number)", value=5, minimum=5, maximum=99, step=2, interactive=True)
|
||||
num_iter = gr.Slider(label="Number of iterations", value=5, minimum=1, maximum=10, step=1, interactive=True)
|
||||
guide_weight = gr.Slider(label="Guide weight", value=10.0, minimum=0.0, maximum=100.0, step=0.1, interactive=True)
|
||||
initialize = gr.Radio(["identity", "random"], label="NNF initialization", value="identity", interactive=True)
|
||||
with gr.Column():
|
||||
gr.Markdown("""
|
||||
# Reference
|
||||
|
||||
* Output directory: the directory to save the video.
|
||||
* Inference mode
|
||||
|
||||
|Mode|Time|Memory|Quality|Frame by frame output|Description|
|
||||
|-|-|-|-|-|-|
|
||||
|Fast|■|■■■|■■|No|Blend the frames using a tree-like data structure, which requires much RAM but is fast.|
|
||||
|Balanced|■■|■|■■|Yes|Blend the frames naively.|
|
||||
|Accurate|■■■|■|■■■|Yes|Blend the frames and align them together for higher video quality. When [batch size] >= [sliding window size] * 2 + 1, the performance is the best.|
|
||||
|
||||
* Sliding window size: our algorithm will blend the frames in a sliding windows. If the size is n, each frame will be blended with the last n frames and the next n frames. A large sliding window can make the video fluent but sometimes smoggy.
|
||||
* Batch size: a larger batch size makes the program faster but requires more VRAM.
|
||||
* Tracking window size (only for accurate mode): The size of window in which our algorithm tracks moving objects. Empirically, 1 is enough.
|
||||
* Advanced settings
|
||||
* Minimum patch size (odd number): the minimum patch size used for patch matching. (Default: 5)
|
||||
* Number of iterations: the number of iterations of patch matching. (Default: 5)
|
||||
* Guide weight: a parameter that determines how much motion feature applied to the style video. (Default: 10)
|
||||
* NNF initialization: how to initialize the NNF (Nearest Neighbor Field). (Default: identity)
|
||||
""")
|
||||
btn.click(
|
||||
smooth_video,
|
||||
inputs=[
|
||||
video_guide,
|
||||
video_guide_folder,
|
||||
video_style,
|
||||
video_style_folder,
|
||||
mode,
|
||||
window_size,
|
||||
batch_size,
|
||||
tracking_window_size,
|
||||
output_path,
|
||||
fps,
|
||||
minimum_patch_size,
|
||||
num_iter,
|
||||
guide_weight,
|
||||
initialize
|
||||
],
|
||||
outputs=[output_path, fps, video_output]
|
||||
)
|
||||
with gr.Tab("Interpolate"):
|
||||
gr.Markdown("""
|
||||
# Interpolate
|
||||
|
||||
Given a guide video and some rendered keyframes, this algorithm will render the remaining frames. Click [here](https://github.com/Artiprocher/sd-webui-fastblend/assets/35051019/3490c5b4-8f67-478f-86de-f9adc2ace16a) to see the example. The algorithm is experimental and is only tested for 512*512 resolution.
|
||||
""")
|
||||
with gr.Row():
|
||||
with gr.Column():
|
||||
with gr.Row():
|
||||
with gr.Column():
|
||||
video_guide_folder_ = gr.Textbox(label="Guide video (images format)", value="")
|
||||
with gr.Column():
|
||||
rendered_keyframes_ = gr.Textbox(label="Rendered keyframes (images format)", value="")
|
||||
with gr.Row():
|
||||
detected_frames = gr.Textbox(label="Detected frames", value="Please input the directory of guide video and rendered frames", lines=9, max_lines=9, interactive=False)
|
||||
video_guide_folder_.change(detect_frames, inputs=[video_guide_folder_, rendered_keyframes_], outputs=detected_frames)
|
||||
rendered_keyframes_.change(detect_frames, inputs=[video_guide_folder_, rendered_keyframes_], outputs=detected_frames)
|
||||
with gr.Column():
|
||||
output_path_ = gr.Textbox(label="Output directory", value="", placeholder="Leave empty to use the directory of rendered keyframes")
|
||||
fps_ = gr.Textbox(label="Fps", value="", placeholder="Leave empty to use the default fps")
|
||||
video_output_ = gr.Video(label="Output video", interactive=False, show_share_button=True)
|
||||
btn_ = gr.Button(value="Interpolate")
|
||||
with gr.Row():
|
||||
with gr.Column():
|
||||
gr.Markdown("# Settings")
|
||||
batch_size_ = gr.Slider(label="Batch size", value=8, minimum=1, maximum=128, step=1, interactive=True)
|
||||
tracking_window_size_ = gr.Slider(label="Tracking window size", value=0, minimum=0, maximum=10, step=1, interactive=True)
|
||||
gr.Markdown("## Advanced Settings")
|
||||
minimum_patch_size_ = gr.Slider(label="Minimum patch size (odd number, larger is better)", value=15, minimum=5, maximum=99, step=2, interactive=True)
|
||||
num_iter_ = gr.Slider(label="Number of iterations", value=5, minimum=1, maximum=10, step=1, interactive=True)
|
||||
guide_weight_ = gr.Slider(label="Guide weight", value=10.0, minimum=0.0, maximum=100.0, step=0.1, interactive=True)
|
||||
initialize_ = gr.Radio(["identity", "random"], label="NNF initialization", value="identity", interactive=True)
|
||||
with gr.Column():
|
||||
gr.Markdown("""
|
||||
# Reference
|
||||
|
||||
* Output directory: the directory to save the video.
|
||||
* Batch size: a larger batch size makes the program faster but requires more VRAM.
|
||||
* Tracking window size (only for accurate mode): The size of window in which our algorithm tracks moving objects. Empirically, 1 is enough.
|
||||
* Advanced settings
|
||||
* Minimum patch size (odd number): the minimum patch size used for patch matching. **This parameter should be larger than that in blending. (Default: 15)**
|
||||
* Number of iterations: the number of iterations of patch matching. (Default: 5)
|
||||
* Guide weight: a parameter that determines how much motion feature applied to the style video. (Default: 10)
|
||||
* NNF initialization: how to initialize the NNF (Nearest Neighbor Field). (Default: identity)
|
||||
""")
|
||||
btn_.click(
|
||||
interpolate_video,
|
||||
inputs=[
|
||||
video_guide_folder_,
|
||||
rendered_keyframes_,
|
||||
output_path_,
|
||||
fps_,
|
||||
batch_size_,
|
||||
tracking_window_size_,
|
||||
minimum_patch_size_,
|
||||
num_iter_,
|
||||
guide_weight_,
|
||||
initialize_,
|
||||
],
|
||||
outputs=[output_path_, fps_, video_output_]
|
||||
)
|
||||
|
||||
return [(ui_component, "FastBlend", "FastBlend_ui")]
|
||||
119
diffsynth/extensions/FastBlend/cupy_kernels.py
Normal file
119
diffsynth/extensions/FastBlend/cupy_kernels.py
Normal file
@@ -0,0 +1,119 @@
|
||||
import cupy as cp
|
||||
|
||||
remapping_kernel = cp.RawKernel(r'''
|
||||
extern "C" __global__
|
||||
void remap(
|
||||
const int height,
|
||||
const int width,
|
||||
const int channel,
|
||||
const int patch_size,
|
||||
const int pad_size,
|
||||
const float* source_style,
|
||||
const int* nnf,
|
||||
float* target_style
|
||||
) {
|
||||
const int r = (patch_size - 1) / 2;
|
||||
const int x = blockDim.x * blockIdx.x + threadIdx.x;
|
||||
const int y = blockDim.y * blockIdx.y + threadIdx.y;
|
||||
if (x >= height or y >= width) return;
|
||||
const int z = blockIdx.z * (height + pad_size * 2) * (width + pad_size * 2) * channel;
|
||||
const int pid = (x + pad_size) * (width + pad_size * 2) + (y + pad_size);
|
||||
const int min_px = x < r ? -x : -r;
|
||||
const int max_px = x + r > height - 1 ? height - 1 - x : r;
|
||||
const int min_py = y < r ? -y : -r;
|
||||
const int max_py = y + r > width - 1 ? width - 1 - y : r;
|
||||
int num = 0;
|
||||
for (int px = min_px; px <= max_px; px++){
|
||||
for (int py = min_py; py <= max_py; py++){
|
||||
const int nid = (x + px) * width + y + py;
|
||||
const int x_ = nnf[blockIdx.z * height * width * 2 + nid*2 + 0] - px;
|
||||
const int y_ = nnf[blockIdx.z * height * width * 2 + nid*2 + 1] - py;
|
||||
if (x_ < 0 or y_ < 0 or x_ >= height or y_ >= width)continue;
|
||||
const int pid_ = (x_ + pad_size) * (width + pad_size * 2) + (y_ + pad_size);
|
||||
num++;
|
||||
for (int c = 0; c < channel; c++){
|
||||
target_style[z + pid * channel + c] += source_style[z + pid_ * channel + c];
|
||||
}
|
||||
}
|
||||
}
|
||||
for (int c = 0; c < channel; c++){
|
||||
target_style[z + pid * channel + c] /= num;
|
||||
}
|
||||
}
|
||||
''', 'remap')
|
||||
|
||||
|
||||
patch_error_kernel = cp.RawKernel(r'''
|
||||
extern "C" __global__
|
||||
void patch_error(
|
||||
const int height,
|
||||
const int width,
|
||||
const int channel,
|
||||
const int patch_size,
|
||||
const int pad_size,
|
||||
const float* source,
|
||||
const int* nnf,
|
||||
const float* target,
|
||||
float* error
|
||||
) {
|
||||
const int r = (patch_size - 1) / 2;
|
||||
const int x = blockDim.x * blockIdx.x + threadIdx.x;
|
||||
const int y = blockDim.y * blockIdx.y + threadIdx.y;
|
||||
const int z = blockIdx.z * (height + pad_size * 2) * (width + pad_size * 2) * channel;
|
||||
if (x >= height or y >= width) return;
|
||||
const int x_ = nnf[blockIdx.z * height * width * 2 + (x * width + y)*2 + 0];
|
||||
const int y_ = nnf[blockIdx.z * height * width * 2 + (x * width + y)*2 + 1];
|
||||
float e = 0;
|
||||
for (int px = -r; px <= r; px++){
|
||||
for (int py = -r; py <= r; py++){
|
||||
const int pid = (x + pad_size + px) * (width + pad_size * 2) + y + pad_size + py;
|
||||
const int pid_ = (x_ + pad_size + px) * (width + pad_size * 2) + y_ + pad_size + py;
|
||||
for (int c = 0; c < channel; c++){
|
||||
const float diff = target[z + pid * channel + c] - source[z + pid_ * channel + c];
|
||||
e += diff * diff;
|
||||
}
|
||||
}
|
||||
}
|
||||
error[blockIdx.z * height * width + x * width + y] = e;
|
||||
}
|
||||
''', 'patch_error')
|
||||
|
||||
|
||||
pairwise_patch_error_kernel = cp.RawKernel(r'''
|
||||
extern "C" __global__
|
||||
void pairwise_patch_error(
|
||||
const int height,
|
||||
const int width,
|
||||
const int channel,
|
||||
const int patch_size,
|
||||
const int pad_size,
|
||||
const float* source_a,
|
||||
const int* nnf_a,
|
||||
const float* source_b,
|
||||
const int* nnf_b,
|
||||
float* error
|
||||
) {
|
||||
const int r = (patch_size - 1) / 2;
|
||||
const int x = blockDim.x * blockIdx.x + threadIdx.x;
|
||||
const int y = blockDim.y * blockIdx.y + threadIdx.y;
|
||||
const int z = blockIdx.z * (height + pad_size * 2) * (width + pad_size * 2) * channel;
|
||||
if (x >= height or y >= width) return;
|
||||
const int z_nnf = blockIdx.z * height * width * 2 + (x * width + y) * 2;
|
||||
const int x_a = nnf_a[z_nnf + 0];
|
||||
const int y_a = nnf_a[z_nnf + 1];
|
||||
const int x_b = nnf_b[z_nnf + 0];
|
||||
const int y_b = nnf_b[z_nnf + 1];
|
||||
float e = 0;
|
||||
for (int px = -r; px <= r; px++){
|
||||
for (int py = -r; py <= r; py++){
|
||||
const int pid_a = (x_a + pad_size + px) * (width + pad_size * 2) + y_a + pad_size + py;
|
||||
const int pid_b = (x_b + pad_size + px) * (width + pad_size * 2) + y_b + pad_size + py;
|
||||
for (int c = 0; c < channel; c++){
|
||||
const float diff = source_a[z + pid_a * channel + c] - source_b[z + pid_b * channel + c];
|
||||
e += diff * diff;
|
||||
}
|
||||
}
|
||||
}
|
||||
error[blockIdx.z * height * width + x * width + y] = e;
|
||||
}
|
||||
''', 'pairwise_patch_error')
|
||||
146
diffsynth/extensions/FastBlend/data.py
Normal file
146
diffsynth/extensions/FastBlend/data.py
Normal file
@@ -0,0 +1,146 @@
|
||||
import imageio, os
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
|
||||
def read_video(file_name):
|
||||
reader = imageio.get_reader(file_name)
|
||||
video = []
|
||||
for frame in reader:
|
||||
frame = np.array(frame)
|
||||
video.append(frame)
|
||||
reader.close()
|
||||
return video
|
||||
|
||||
|
||||
def get_video_fps(file_name):
|
||||
reader = imageio.get_reader(file_name)
|
||||
fps = reader.get_meta_data()["fps"]
|
||||
reader.close()
|
||||
return fps
|
||||
|
||||
|
||||
def save_video(frames_path, video_path, num_frames, fps):
|
||||
writer = imageio.get_writer(video_path, fps=fps, quality=9)
|
||||
for i in range(num_frames):
|
||||
frame = np.array(Image.open(os.path.join(frames_path, "%05d.png" % i)))
|
||||
writer.append_data(frame)
|
||||
writer.close()
|
||||
return video_path
|
||||
|
||||
|
||||
class LowMemoryVideo:
|
||||
def __init__(self, file_name):
|
||||
self.reader = imageio.get_reader(file_name)
|
||||
|
||||
def __len__(self):
|
||||
return self.reader.count_frames()
|
||||
|
||||
def __getitem__(self, item):
|
||||
return np.array(self.reader.get_data(item))
|
||||
|
||||
def __del__(self):
|
||||
self.reader.close()
|
||||
|
||||
|
||||
def split_file_name(file_name):
|
||||
result = []
|
||||
number = -1
|
||||
for i in file_name:
|
||||
if ord(i)>=ord("0") and ord(i)<=ord("9"):
|
||||
if number == -1:
|
||||
number = 0
|
||||
number = number*10 + ord(i) - ord("0")
|
||||
else:
|
||||
if number != -1:
|
||||
result.append(number)
|
||||
number = -1
|
||||
result.append(i)
|
||||
if number != -1:
|
||||
result.append(number)
|
||||
result = tuple(result)
|
||||
return result
|
||||
|
||||
|
||||
def search_for_images(folder):
|
||||
file_list = [i for i in os.listdir(folder) if i.endswith(".jpg") or i.endswith(".png")]
|
||||
file_list = [(split_file_name(file_name), file_name) for file_name in file_list]
|
||||
file_list = [i[1] for i in sorted(file_list)]
|
||||
file_list = [os.path.join(folder, i) for i in file_list]
|
||||
return file_list
|
||||
|
||||
|
||||
def read_images(folder):
|
||||
file_list = search_for_images(folder)
|
||||
frames = [np.array(Image.open(i)) for i in file_list]
|
||||
return frames
|
||||
|
||||
|
||||
class LowMemoryImageFolder:
|
||||
def __init__(self, folder, file_list=None):
|
||||
if file_list is None:
|
||||
self.file_list = search_for_images(folder)
|
||||
else:
|
||||
self.file_list = [os.path.join(folder, file_name) for file_name in file_list]
|
||||
|
||||
def __len__(self):
|
||||
return len(self.file_list)
|
||||
|
||||
def __getitem__(self, item):
|
||||
return np.array(Image.open(self.file_list[item]))
|
||||
|
||||
def __del__(self):
|
||||
pass
|
||||
|
||||
|
||||
class VideoData:
|
||||
def __init__(self, video_file, image_folder, **kwargs):
|
||||
if video_file is not None:
|
||||
self.data_type = "video"
|
||||
self.data = LowMemoryVideo(video_file, **kwargs)
|
||||
elif image_folder is not None:
|
||||
self.data_type = "images"
|
||||
self.data = LowMemoryImageFolder(image_folder, **kwargs)
|
||||
else:
|
||||
raise ValueError("Cannot open video or image folder")
|
||||
self.length = None
|
||||
self.height = None
|
||||
self.width = None
|
||||
|
||||
def raw_data(self):
|
||||
frames = []
|
||||
for i in range(self.__len__()):
|
||||
frames.append(self.__getitem__(i))
|
||||
return frames
|
||||
|
||||
def set_length(self, length):
|
||||
self.length = length
|
||||
|
||||
def set_shape(self, height, width):
|
||||
self.height = height
|
||||
self.width = width
|
||||
|
||||
def __len__(self):
|
||||
if self.length is None:
|
||||
return len(self.data)
|
||||
else:
|
||||
return self.length
|
||||
|
||||
def shape(self):
|
||||
if self.height is not None and self.width is not None:
|
||||
return self.height, self.width
|
||||
else:
|
||||
height, width, _ = self.__getitem__(0).shape
|
||||
return height, width
|
||||
|
||||
def __getitem__(self, item):
|
||||
frame = self.data.__getitem__(item)
|
||||
height, width, _ = frame.shape
|
||||
if self.height is not None and self.width is not None:
|
||||
if self.height != height or self.width != width:
|
||||
frame = Image.fromarray(frame).resize((self.width, self.height))
|
||||
frame = np.array(frame)
|
||||
return frame
|
||||
|
||||
def __del__(self):
|
||||
pass
|
||||
298
diffsynth/extensions/FastBlend/patch_match.py
Normal file
298
diffsynth/extensions/FastBlend/patch_match.py
Normal file
@@ -0,0 +1,298 @@
|
||||
from .cupy_kernels import remapping_kernel, patch_error_kernel, pairwise_patch_error_kernel
|
||||
import numpy as np
|
||||
import cupy as cp
|
||||
import cv2
|
||||
|
||||
|
||||
class PatchMatcher:
|
||||
def __init__(
|
||||
self, height, width, channel, minimum_patch_size,
|
||||
threads_per_block=8, num_iter=5, gpu_id=0, guide_weight=10.0,
|
||||
random_search_steps=3, random_search_range=4,
|
||||
use_mean_target_style=False, use_pairwise_patch_error=False,
|
||||
tracking_window_size=0
|
||||
):
|
||||
self.height = height
|
||||
self.width = width
|
||||
self.channel = channel
|
||||
self.minimum_patch_size = minimum_patch_size
|
||||
self.threads_per_block = threads_per_block
|
||||
self.num_iter = num_iter
|
||||
self.gpu_id = gpu_id
|
||||
self.guide_weight = guide_weight
|
||||
self.random_search_steps = random_search_steps
|
||||
self.random_search_range = random_search_range
|
||||
self.use_mean_target_style = use_mean_target_style
|
||||
self.use_pairwise_patch_error = use_pairwise_patch_error
|
||||
self.tracking_window_size = tracking_window_size
|
||||
|
||||
self.patch_size_list = [minimum_patch_size + i*2 for i in range(num_iter)][::-1]
|
||||
self.pad_size = self.patch_size_list[0] // 2
|
||||
self.grid = (
|
||||
(height + threads_per_block - 1) // threads_per_block,
|
||||
(width + threads_per_block - 1) // threads_per_block
|
||||
)
|
||||
self.block = (threads_per_block, threads_per_block)
|
||||
|
||||
def pad_image(self, image):
|
||||
return cp.pad(image, ((0, 0), (self.pad_size, self.pad_size), (self.pad_size, self.pad_size), (0, 0)))
|
||||
|
||||
def unpad_image(self, image):
|
||||
return image[:, self.pad_size: -self.pad_size, self.pad_size: -self.pad_size, :]
|
||||
|
||||
def apply_nnf_to_image(self, nnf, source):
|
||||
batch_size = source.shape[0]
|
||||
target = cp.zeros((batch_size, self.height + self.pad_size * 2, self.width + self.pad_size * 2, self.channel), dtype=cp.float32)
|
||||
remapping_kernel(
|
||||
self.grid + (batch_size,),
|
||||
self.block,
|
||||
(self.height, self.width, self.channel, self.patch_size, self.pad_size, source, nnf, target)
|
||||
)
|
||||
return target
|
||||
|
||||
def get_patch_error(self, source, nnf, target):
|
||||
batch_size = source.shape[0]
|
||||
error = cp.zeros((batch_size, self.height, self.width), dtype=cp.float32)
|
||||
patch_error_kernel(
|
||||
self.grid + (batch_size,),
|
||||
self.block,
|
||||
(self.height, self.width, self.channel, self.patch_size, self.pad_size, source, nnf, target, error)
|
||||
)
|
||||
return error
|
||||
|
||||
def get_pairwise_patch_error(self, source, nnf):
|
||||
batch_size = source.shape[0]//2
|
||||
error = cp.zeros((batch_size, self.height, self.width), dtype=cp.float32)
|
||||
source_a, nnf_a = source[0::2].copy(), nnf[0::2].copy()
|
||||
source_b, nnf_b = source[1::2].copy(), nnf[1::2].copy()
|
||||
pairwise_patch_error_kernel(
|
||||
self.grid + (batch_size,),
|
||||
self.block,
|
||||
(self.height, self.width, self.channel, self.patch_size, self.pad_size, source_a, nnf_a, source_b, nnf_b, error)
|
||||
)
|
||||
error = error.repeat(2, axis=0)
|
||||
return error
|
||||
|
||||
def get_error(self, source_guide, target_guide, source_style, target_style, nnf):
|
||||
error_guide = self.get_patch_error(source_guide, nnf, target_guide)
|
||||
if self.use_mean_target_style:
|
||||
target_style = self.apply_nnf_to_image(nnf, source_style)
|
||||
target_style = target_style.mean(axis=0, keepdims=True)
|
||||
target_style = target_style.repeat(source_guide.shape[0], axis=0)
|
||||
if self.use_pairwise_patch_error:
|
||||
error_style = self.get_pairwise_patch_error(source_style, nnf)
|
||||
else:
|
||||
error_style = self.get_patch_error(source_style, nnf, target_style)
|
||||
error = error_guide * self.guide_weight + error_style
|
||||
return error
|
||||
|
||||
def clamp_bound(self, nnf):
|
||||
nnf[:,:,:,0] = cp.clip(nnf[:,:,:,0], 0, self.height-1)
|
||||
nnf[:,:,:,1] = cp.clip(nnf[:,:,:,1], 0, self.width-1)
|
||||
return nnf
|
||||
|
||||
def random_step(self, nnf, r):
|
||||
batch_size = nnf.shape[0]
|
||||
step = cp.random.randint(-r, r+1, size=(batch_size, self.height, self.width, 2), dtype=cp.int32)
|
||||
upd_nnf = self.clamp_bound(nnf + step)
|
||||
return upd_nnf
|
||||
|
||||
def neighboor_step(self, nnf, d):
|
||||
if d==0:
|
||||
upd_nnf = cp.concatenate([nnf[:, :1, :], nnf[:, :-1, :]], axis=1)
|
||||
upd_nnf[:, :, :, 0] += 1
|
||||
elif d==1:
|
||||
upd_nnf = cp.concatenate([nnf[:, :, :1], nnf[:, :, :-1]], axis=2)
|
||||
upd_nnf[:, :, :, 1] += 1
|
||||
elif d==2:
|
||||
upd_nnf = cp.concatenate([nnf[:, 1:, :], nnf[:, -1:, :]], axis=1)
|
||||
upd_nnf[:, :, :, 0] -= 1
|
||||
elif d==3:
|
||||
upd_nnf = cp.concatenate([nnf[:, :, 1:], nnf[:, :, -1:]], axis=2)
|
||||
upd_nnf[:, :, :, 1] -= 1
|
||||
upd_nnf = self.clamp_bound(upd_nnf)
|
||||
return upd_nnf
|
||||
|
||||
def shift_nnf(self, nnf, d):
|
||||
if d>0:
|
||||
d = min(nnf.shape[0], d)
|
||||
upd_nnf = cp.concatenate([nnf[d:]] + [nnf[-1:]] * d, axis=0)
|
||||
else:
|
||||
d = max(-nnf.shape[0], d)
|
||||
upd_nnf = cp.concatenate([nnf[:1]] * (-d) + [nnf[:d]], axis=0)
|
||||
return upd_nnf
|
||||
|
||||
def track_step(self, nnf, d):
|
||||
if self.use_pairwise_patch_error:
|
||||
upd_nnf = cp.zeros_like(nnf)
|
||||
upd_nnf[0::2] = self.shift_nnf(nnf[0::2], d)
|
||||
upd_nnf[1::2] = self.shift_nnf(nnf[1::2], d)
|
||||
else:
|
||||
upd_nnf = self.shift_nnf(nnf, d)
|
||||
return upd_nnf
|
||||
|
||||
def C(self, n, m):
|
||||
# not used
|
||||
c = 1
|
||||
for i in range(1, n+1):
|
||||
c *= i
|
||||
for i in range(1, m+1):
|
||||
c //= i
|
||||
for i in range(1, n-m+1):
|
||||
c //= i
|
||||
return c
|
||||
|
||||
def bezier_step(self, nnf, r):
|
||||
# not used
|
||||
n = r * 2 - 1
|
||||
upd_nnf = cp.zeros(shape=nnf.shape, dtype=cp.float32)
|
||||
for i, d in enumerate(list(range(-r, 0)) + list(range(1, r+1))):
|
||||
if d>0:
|
||||
ctl_nnf = cp.concatenate([nnf[d:]] + [nnf[-1:]] * d, axis=0)
|
||||
elif d<0:
|
||||
ctl_nnf = cp.concatenate([nnf[:1]] * (-d) + [nnf[:d]], axis=0)
|
||||
upd_nnf += ctl_nnf * (self.C(n, i) / 2**n)
|
||||
upd_nnf = self.clamp_bound(upd_nnf).astype(nnf.dtype)
|
||||
return upd_nnf
|
||||
|
||||
def update(self, source_guide, target_guide, source_style, target_style, nnf, err, upd_nnf):
|
||||
upd_err = self.get_error(source_guide, target_guide, source_style, target_style, upd_nnf)
|
||||
upd_idx = (upd_err < err)
|
||||
nnf[upd_idx] = upd_nnf[upd_idx]
|
||||
err[upd_idx] = upd_err[upd_idx]
|
||||
return nnf, err
|
||||
|
||||
def propagation(self, source_guide, target_guide, source_style, target_style, nnf, err):
|
||||
for d in cp.random.permutation(4):
|
||||
upd_nnf = self.neighboor_step(nnf, d)
|
||||
nnf, err = self.update(source_guide, target_guide, source_style, target_style, nnf, err, upd_nnf)
|
||||
return nnf, err
|
||||
|
||||
def random_search(self, source_guide, target_guide, source_style, target_style, nnf, err):
|
||||
for i in range(self.random_search_steps):
|
||||
upd_nnf = self.random_step(nnf, self.random_search_range)
|
||||
nnf, err = self.update(source_guide, target_guide, source_style, target_style, nnf, err, upd_nnf)
|
||||
return nnf, err
|
||||
|
||||
def track(self, source_guide, target_guide, source_style, target_style, nnf, err):
|
||||
for d in range(1, self.tracking_window_size + 1):
|
||||
upd_nnf = self.track_step(nnf, d)
|
||||
nnf, err = self.update(source_guide, target_guide, source_style, target_style, nnf, err, upd_nnf)
|
||||
upd_nnf = self.track_step(nnf, -d)
|
||||
nnf, err = self.update(source_guide, target_guide, source_style, target_style, nnf, err, upd_nnf)
|
||||
return nnf, err
|
||||
|
||||
def iteration(self, source_guide, target_guide, source_style, target_style, nnf, err):
|
||||
nnf, err = self.propagation(source_guide, target_guide, source_style, target_style, nnf, err)
|
||||
nnf, err = self.random_search(source_guide, target_guide, source_style, target_style, nnf, err)
|
||||
nnf, err = self.track(source_guide, target_guide, source_style, target_style, nnf, err)
|
||||
return nnf, err
|
||||
|
||||
def estimate_nnf(self, source_guide, target_guide, source_style, nnf):
|
||||
with cp.cuda.Device(self.gpu_id):
|
||||
source_guide = self.pad_image(source_guide)
|
||||
target_guide = self.pad_image(target_guide)
|
||||
source_style = self.pad_image(source_style)
|
||||
for it in range(self.num_iter):
|
||||
self.patch_size = self.patch_size_list[it]
|
||||
target_style = self.apply_nnf_to_image(nnf, source_style)
|
||||
err = self.get_error(source_guide, target_guide, source_style, target_style, nnf)
|
||||
nnf, err = self.iteration(source_guide, target_guide, source_style, target_style, nnf, err)
|
||||
target_style = self.unpad_image(self.apply_nnf_to_image(nnf, source_style))
|
||||
return nnf, target_style
|
||||
|
||||
|
||||
class PyramidPatchMatcher:
|
||||
def __init__(
|
||||
self, image_height, image_width, channel, minimum_patch_size,
|
||||
threads_per_block=8, num_iter=5, gpu_id=0, guide_weight=10.0,
|
||||
use_mean_target_style=False, use_pairwise_patch_error=False,
|
||||
tracking_window_size=0,
|
||||
initialize="identity"
|
||||
):
|
||||
maximum_patch_size = minimum_patch_size + (num_iter - 1) * 2
|
||||
self.pyramid_level = int(np.log2(min(image_height, image_width) / maximum_patch_size))
|
||||
self.pyramid_heights = []
|
||||
self.pyramid_widths = []
|
||||
self.patch_matchers = []
|
||||
self.minimum_patch_size = minimum_patch_size
|
||||
self.num_iter = num_iter
|
||||
self.gpu_id = gpu_id
|
||||
self.initialize = initialize
|
||||
for level in range(self.pyramid_level):
|
||||
height = image_height//(2**(self.pyramid_level - 1 - level))
|
||||
width = image_width//(2**(self.pyramid_level - 1 - level))
|
||||
self.pyramid_heights.append(height)
|
||||
self.pyramid_widths.append(width)
|
||||
self.patch_matchers.append(PatchMatcher(
|
||||
height, width, channel, minimum_patch_size=minimum_patch_size,
|
||||
threads_per_block=threads_per_block, num_iter=num_iter, gpu_id=gpu_id, guide_weight=guide_weight,
|
||||
use_mean_target_style=use_mean_target_style, use_pairwise_patch_error=use_pairwise_patch_error,
|
||||
tracking_window_size=tracking_window_size
|
||||
))
|
||||
|
||||
def resample_image(self, images, level):
|
||||
height, width = self.pyramid_heights[level], self.pyramid_widths[level]
|
||||
images = images.get()
|
||||
images_resample = []
|
||||
for image in images:
|
||||
image_resample = cv2.resize(image, (width, height), interpolation=cv2.INTER_AREA)
|
||||
images_resample.append(image_resample)
|
||||
images_resample = cp.array(np.stack(images_resample), dtype=cp.float32)
|
||||
return images_resample
|
||||
|
||||
def initialize_nnf(self, batch_size):
|
||||
if self.initialize == "random":
|
||||
height, width = self.pyramid_heights[0], self.pyramid_widths[0]
|
||||
nnf = cp.stack([
|
||||
cp.random.randint(0, height, (batch_size, height, width), dtype=cp.int32),
|
||||
cp.random.randint(0, width, (batch_size, height, width), dtype=cp.int32)
|
||||
], axis=3)
|
||||
elif self.initialize == "identity":
|
||||
height, width = self.pyramid_heights[0], self.pyramid_widths[0]
|
||||
nnf = cp.stack([
|
||||
cp.repeat(cp.arange(height), width).reshape(height, width),
|
||||
cp.tile(cp.arange(width), height).reshape(height, width)
|
||||
], axis=2)
|
||||
nnf = cp.stack([nnf] * batch_size)
|
||||
else:
|
||||
raise NotImplementedError()
|
||||
return nnf
|
||||
|
||||
def update_nnf(self, nnf, level):
|
||||
# upscale
|
||||
nnf = nnf.repeat(2, axis=1).repeat(2, axis=2) * 2
|
||||
nnf[:,[i for i in range(nnf.shape[0]) if i&1],:,0] += 1
|
||||
nnf[:,:,[i for i in range(nnf.shape[0]) if i&1],1] += 1
|
||||
# check if scale is 2
|
||||
height, width = self.pyramid_heights[level], self.pyramid_widths[level]
|
||||
if height != nnf.shape[0] * 2 or width != nnf.shape[1] * 2:
|
||||
nnf = nnf.get().astype(np.float32)
|
||||
nnf = [cv2.resize(n, (width, height), interpolation=cv2.INTER_LINEAR) for n in nnf]
|
||||
nnf = cp.array(np.stack(nnf), dtype=cp.int32)
|
||||
nnf = self.patch_matchers[level].clamp_bound(nnf)
|
||||
return nnf
|
||||
|
||||
def apply_nnf_to_image(self, nnf, image):
|
||||
with cp.cuda.Device(self.gpu_id):
|
||||
image = self.patch_matchers[-1].pad_image(image)
|
||||
image = self.patch_matchers[-1].apply_nnf_to_image(nnf, image)
|
||||
return image
|
||||
|
||||
def estimate_nnf(self, source_guide, target_guide, source_style):
|
||||
with cp.cuda.Device(self.gpu_id):
|
||||
if not isinstance(source_guide, cp.ndarray):
|
||||
source_guide = cp.array(source_guide, dtype=cp.float32)
|
||||
if not isinstance(target_guide, cp.ndarray):
|
||||
target_guide = cp.array(target_guide, dtype=cp.float32)
|
||||
if not isinstance(source_style, cp.ndarray):
|
||||
source_style = cp.array(source_style, dtype=cp.float32)
|
||||
for level in range(self.pyramid_level):
|
||||
nnf = self.initialize_nnf(source_guide.shape[0]) if level==0 else self.update_nnf(nnf, level)
|
||||
source_guide_ = self.resample_image(source_guide, level)
|
||||
target_guide_ = self.resample_image(target_guide, level)
|
||||
source_style_ = self.resample_image(source_style, level)
|
||||
nnf, target_style = self.patch_matchers[level].estimate_nnf(
|
||||
source_guide_, target_guide_, source_style_, nnf
|
||||
)
|
||||
return nnf.get(), target_style.get()
|
||||
4
diffsynth/extensions/FastBlend/runners/__init__.py
Normal file
4
diffsynth/extensions/FastBlend/runners/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
from .accurate import AccurateModeRunner
|
||||
from .fast import FastModeRunner
|
||||
from .balanced import BalancedModeRunner
|
||||
from .interpolation import InterpolationModeRunner, InterpolationModeSingleFrameRunner
|
||||
35
diffsynth/extensions/FastBlend/runners/accurate.py
Normal file
35
diffsynth/extensions/FastBlend/runners/accurate.py
Normal file
@@ -0,0 +1,35 @@
|
||||
from ..patch_match import PyramidPatchMatcher
|
||||
import os
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
class AccurateModeRunner:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def run(self, frames_guide, frames_style, batch_size, window_size, ebsynth_config, desc="Accurate Mode", save_path=None):
|
||||
patch_match_engine = PyramidPatchMatcher(
|
||||
image_height=frames_style[0].shape[0],
|
||||
image_width=frames_style[0].shape[1],
|
||||
channel=3,
|
||||
use_mean_target_style=True,
|
||||
**ebsynth_config
|
||||
)
|
||||
# run
|
||||
n = len(frames_style)
|
||||
for target in tqdm(range(n), desc=desc):
|
||||
l, r = max(target - window_size, 0), min(target + window_size + 1, n)
|
||||
remapped_frames = []
|
||||
for i in range(l, r, batch_size):
|
||||
j = min(i + batch_size, r)
|
||||
source_guide = np.stack([frames_guide[source] for source in range(i, j)])
|
||||
target_guide = np.stack([frames_guide[target]] * (j - i))
|
||||
source_style = np.stack([frames_style[source] for source in range(i, j)])
|
||||
_, target_style = patch_match_engine.estimate_nnf(source_guide, target_guide, source_style)
|
||||
remapped_frames.append(target_style)
|
||||
frame = np.concatenate(remapped_frames, axis=0).mean(axis=0)
|
||||
frame = frame.clip(0, 255).astype("uint8")
|
||||
if save_path is not None:
|
||||
Image.fromarray(frame).save(os.path.join(save_path, "%05d.png" % target))
|
||||
46
diffsynth/extensions/FastBlend/runners/balanced.py
Normal file
46
diffsynth/extensions/FastBlend/runners/balanced.py
Normal file
@@ -0,0 +1,46 @@
|
||||
from ..patch_match import PyramidPatchMatcher
|
||||
import os
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
class BalancedModeRunner:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def run(self, frames_guide, frames_style, batch_size, window_size, ebsynth_config, desc="Balanced Mode", save_path=None):
|
||||
patch_match_engine = PyramidPatchMatcher(
|
||||
image_height=frames_style[0].shape[0],
|
||||
image_width=frames_style[0].shape[1],
|
||||
channel=3,
|
||||
**ebsynth_config
|
||||
)
|
||||
# tasks
|
||||
n = len(frames_style)
|
||||
tasks = []
|
||||
for target in range(n):
|
||||
for source in range(target - window_size, target + window_size + 1):
|
||||
if source >= 0 and source < n and source != target:
|
||||
tasks.append((source, target))
|
||||
# run
|
||||
frames = [(None, 1) for i in range(n)]
|
||||
for batch_id in tqdm(range(0, len(tasks), batch_size), desc=desc):
|
||||
tasks_batch = tasks[batch_id: min(batch_id+batch_size, len(tasks))]
|
||||
source_guide = np.stack([frames_guide[source] for source, target in tasks_batch])
|
||||
target_guide = np.stack([frames_guide[target] for source, target in tasks_batch])
|
||||
source_style = np.stack([frames_style[source] for source, target in tasks_batch])
|
||||
_, target_style = patch_match_engine.estimate_nnf(source_guide, target_guide, source_style)
|
||||
for (source, target), result in zip(tasks_batch, target_style):
|
||||
frame, weight = frames[target]
|
||||
if frame is None:
|
||||
frame = frames_style[target]
|
||||
frames[target] = (
|
||||
frame * (weight / (weight + 1)) + result / (weight + 1),
|
||||
weight + 1
|
||||
)
|
||||
if weight + 1 == min(n, target + window_size + 1) - max(0, target - window_size):
|
||||
frame = frame.clip(0, 255).astype("uint8")
|
||||
if save_path is not None:
|
||||
Image.fromarray(frame).save(os.path.join(save_path, "%05d.png" % target))
|
||||
frames[target] = (None, 1)
|
||||
141
diffsynth/extensions/FastBlend/runners/fast.py
Normal file
141
diffsynth/extensions/FastBlend/runners/fast.py
Normal file
@@ -0,0 +1,141 @@
|
||||
from ..patch_match import PyramidPatchMatcher
|
||||
import functools, os
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
class TableManager:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def task_list(self, n):
|
||||
tasks = []
|
||||
max_level = 1
|
||||
while (1<<max_level)<=n:
|
||||
max_level += 1
|
||||
for i in range(n):
|
||||
j = i
|
||||
for level in range(max_level):
|
||||
if i&(1<<level):
|
||||
continue
|
||||
j |= 1<<level
|
||||
if j>=n:
|
||||
break
|
||||
meta_data = {
|
||||
"source": i,
|
||||
"target": j,
|
||||
"level": level + 1
|
||||
}
|
||||
tasks.append(meta_data)
|
||||
tasks.sort(key=functools.cmp_to_key(lambda u, v: u["level"]-v["level"]))
|
||||
return tasks
|
||||
|
||||
def build_remapping_table(self, frames_guide, frames_style, patch_match_engine, batch_size, desc=""):
|
||||
n = len(frames_guide)
|
||||
tasks = self.task_list(n)
|
||||
remapping_table = [[(frames_style[i], 1)] for i in range(n)]
|
||||
for batch_id in tqdm(range(0, len(tasks), batch_size), desc=desc):
|
||||
tasks_batch = tasks[batch_id: min(batch_id+batch_size, len(tasks))]
|
||||
source_guide = np.stack([frames_guide[task["source"]] for task in tasks_batch])
|
||||
target_guide = np.stack([frames_guide[task["target"]] for task in tasks_batch])
|
||||
source_style = np.stack([frames_style[task["source"]] for task in tasks_batch])
|
||||
_, target_style = patch_match_engine.estimate_nnf(source_guide, target_guide, source_style)
|
||||
for task, result in zip(tasks_batch, target_style):
|
||||
target, level = task["target"], task["level"]
|
||||
if len(remapping_table[target])==level:
|
||||
remapping_table[target].append((result, 1))
|
||||
else:
|
||||
frame, weight = remapping_table[target][level]
|
||||
remapping_table[target][level] = (
|
||||
frame * (weight / (weight + 1)) + result / (weight + 1),
|
||||
weight + 1
|
||||
)
|
||||
return remapping_table
|
||||
|
||||
def remapping_table_to_blending_table(self, table):
|
||||
for i in range(len(table)):
|
||||
for j in range(1, len(table[i])):
|
||||
frame_1, weight_1 = table[i][j-1]
|
||||
frame_2, weight_2 = table[i][j]
|
||||
frame = (frame_1 + frame_2) / 2
|
||||
weight = weight_1 + weight_2
|
||||
table[i][j] = (frame, weight)
|
||||
return table
|
||||
|
||||
def tree_query(self, leftbound, rightbound):
|
||||
node_list = []
|
||||
node_index = rightbound
|
||||
while node_index>=leftbound:
|
||||
node_level = 0
|
||||
while (1<<node_level)&node_index and node_index-(1<<node_level+1)+1>=leftbound:
|
||||
node_level += 1
|
||||
node_list.append((node_index, node_level))
|
||||
node_index -= 1<<node_level
|
||||
return node_list
|
||||
|
||||
def process_window_sum(self, frames_guide, blending_table, patch_match_engine, window_size, batch_size, desc=""):
|
||||
n = len(blending_table)
|
||||
tasks = []
|
||||
frames_result = []
|
||||
for target in range(n):
|
||||
node_list = self.tree_query(max(target-window_size, 0), target)
|
||||
for source, level in node_list:
|
||||
if source!=target:
|
||||
meta_data = {
|
||||
"source": source,
|
||||
"target": target,
|
||||
"level": level
|
||||
}
|
||||
tasks.append(meta_data)
|
||||
else:
|
||||
frames_result.append(blending_table[target][level])
|
||||
for batch_id in tqdm(range(0, len(tasks), batch_size), desc=desc):
|
||||
tasks_batch = tasks[batch_id: min(batch_id+batch_size, len(tasks))]
|
||||
source_guide = np.stack([frames_guide[task["source"]] for task in tasks_batch])
|
||||
target_guide = np.stack([frames_guide[task["target"]] for task in tasks_batch])
|
||||
source_style = np.stack([blending_table[task["source"]][task["level"]][0] for task in tasks_batch])
|
||||
_, target_style = patch_match_engine.estimate_nnf(source_guide, target_guide, source_style)
|
||||
for task, frame_2 in zip(tasks_batch, target_style):
|
||||
source, target, level = task["source"], task["target"], task["level"]
|
||||
frame_1, weight_1 = frames_result[target]
|
||||
weight_2 = blending_table[source][level][1]
|
||||
weight = weight_1 + weight_2
|
||||
frame = frame_1 * (weight_1 / weight) + frame_2 * (weight_2 / weight)
|
||||
frames_result[target] = (frame, weight)
|
||||
return frames_result
|
||||
|
||||
|
||||
class FastModeRunner:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def run(self, frames_guide, frames_style, batch_size, window_size, ebsynth_config, save_path=None):
|
||||
frames_guide = frames_guide.raw_data()
|
||||
frames_style = frames_style.raw_data()
|
||||
table_manager = TableManager()
|
||||
patch_match_engine = PyramidPatchMatcher(
|
||||
image_height=frames_style[0].shape[0],
|
||||
image_width=frames_style[0].shape[1],
|
||||
channel=3,
|
||||
**ebsynth_config
|
||||
)
|
||||
# left part
|
||||
table_l = table_manager.build_remapping_table(frames_guide, frames_style, patch_match_engine, batch_size, desc="Fast Mode Step 1/4")
|
||||
table_l = table_manager.remapping_table_to_blending_table(table_l)
|
||||
table_l = table_manager.process_window_sum(frames_guide, table_l, patch_match_engine, window_size, batch_size, desc="Fast Mode Step 2/4")
|
||||
# right part
|
||||
table_r = table_manager.build_remapping_table(frames_guide[::-1], frames_style[::-1], patch_match_engine, batch_size, desc="Fast Mode Step 3/4")
|
||||
table_r = table_manager.remapping_table_to_blending_table(table_r)
|
||||
table_r = table_manager.process_window_sum(frames_guide[::-1], table_r, patch_match_engine, window_size, batch_size, desc="Fast Mode Step 4/4")[::-1]
|
||||
# merge
|
||||
frames = []
|
||||
for (frame_l, weight_l), frame_m, (frame_r, weight_r) in zip(table_l, frames_style, table_r):
|
||||
weight_m = -1
|
||||
weight = weight_l + weight_m + weight_r
|
||||
frame = frame_l * (weight_l / weight) + frame_m * (weight_m / weight) + frame_r * (weight_r / weight)
|
||||
frames.append(frame)
|
||||
frames = [frame.clip(0, 255).astype("uint8") for frame in frames]
|
||||
if save_path is not None:
|
||||
for target, frame in enumerate(frames):
|
||||
Image.fromarray(frame).save(os.path.join(save_path, "%05d.png" % target))
|
||||
121
diffsynth/extensions/FastBlend/runners/interpolation.py
Normal file
121
diffsynth/extensions/FastBlend/runners/interpolation.py
Normal file
@@ -0,0 +1,121 @@
|
||||
from ..patch_match import PyramidPatchMatcher
|
||||
import os
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
class InterpolationModeRunner:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def get_index_dict(self, index_style):
|
||||
index_dict = {}
|
||||
for i, index in enumerate(index_style):
|
||||
index_dict[index] = i
|
||||
return index_dict
|
||||
|
||||
def get_weight(self, l, m, r):
|
||||
weight_l, weight_r = abs(m - r), abs(m - l)
|
||||
if weight_l + weight_r == 0:
|
||||
weight_l, weight_r = 0.5, 0.5
|
||||
else:
|
||||
weight_l, weight_r = weight_l / (weight_l + weight_r), weight_r / (weight_l + weight_r)
|
||||
return weight_l, weight_r
|
||||
|
||||
def get_task_group(self, index_style, n):
|
||||
task_group = []
|
||||
index_style = sorted(index_style)
|
||||
# first frame
|
||||
if index_style[0]>0:
|
||||
tasks = []
|
||||
for m in range(index_style[0]):
|
||||
tasks.append((index_style[0], m, index_style[0]))
|
||||
task_group.append(tasks)
|
||||
# middle frames
|
||||
for l, r in zip(index_style[:-1], index_style[1:]):
|
||||
tasks = []
|
||||
for m in range(l, r):
|
||||
tasks.append((l, m, r))
|
||||
task_group.append(tasks)
|
||||
# last frame
|
||||
tasks = []
|
||||
for m in range(index_style[-1], n):
|
||||
tasks.append((index_style[-1], m, index_style[-1]))
|
||||
task_group.append(tasks)
|
||||
return task_group
|
||||
|
||||
def run(self, frames_guide, frames_style, index_style, batch_size, ebsynth_config, save_path=None):
|
||||
patch_match_engine = PyramidPatchMatcher(
|
||||
image_height=frames_style[0].shape[0],
|
||||
image_width=frames_style[0].shape[1],
|
||||
channel=3,
|
||||
use_mean_target_style=False,
|
||||
use_pairwise_patch_error=True,
|
||||
**ebsynth_config
|
||||
)
|
||||
# task
|
||||
index_dict = self.get_index_dict(index_style)
|
||||
task_group = self.get_task_group(index_style, len(frames_guide))
|
||||
# run
|
||||
for tasks in task_group:
|
||||
index_start, index_end = min([i[1] for i in tasks]), max([i[1] for i in tasks])
|
||||
for batch_id in tqdm(range(0, len(tasks), batch_size), desc=f"Rendering frames {index_start}...{index_end}"):
|
||||
tasks_batch = tasks[batch_id: min(batch_id+batch_size, len(tasks))]
|
||||
source_guide, target_guide, source_style = [], [], []
|
||||
for l, m, r in tasks_batch:
|
||||
# l -> m
|
||||
source_guide.append(frames_guide[l])
|
||||
target_guide.append(frames_guide[m])
|
||||
source_style.append(frames_style[index_dict[l]])
|
||||
# r -> m
|
||||
source_guide.append(frames_guide[r])
|
||||
target_guide.append(frames_guide[m])
|
||||
source_style.append(frames_style[index_dict[r]])
|
||||
source_guide = np.stack(source_guide)
|
||||
target_guide = np.stack(target_guide)
|
||||
source_style = np.stack(source_style)
|
||||
_, target_style = patch_match_engine.estimate_nnf(source_guide, target_guide, source_style)
|
||||
if save_path is not None:
|
||||
for frame_l, frame_r, (l, m, r) in zip(target_style[0::2], target_style[1::2], tasks_batch):
|
||||
weight_l, weight_r = self.get_weight(l, m, r)
|
||||
frame = frame_l * weight_l + frame_r * weight_r
|
||||
frame = frame.clip(0, 255).astype("uint8")
|
||||
Image.fromarray(frame).save(os.path.join(save_path, "%05d.png" % m))
|
||||
|
||||
|
||||
class InterpolationModeSingleFrameRunner:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def run(self, frames_guide, frames_style, index_style, batch_size, ebsynth_config, save_path=None):
|
||||
# check input
|
||||
tracking_window_size = ebsynth_config["tracking_window_size"]
|
||||
if tracking_window_size * 2 >= batch_size:
|
||||
raise ValueError("batch_size should be larger than track_window_size * 2")
|
||||
frame_style = frames_style[0]
|
||||
frame_guide = frames_guide[index_style[0]]
|
||||
patch_match_engine = PyramidPatchMatcher(
|
||||
image_height=frame_style.shape[0],
|
||||
image_width=frame_style.shape[1],
|
||||
channel=3,
|
||||
**ebsynth_config
|
||||
)
|
||||
# run
|
||||
frame_id, n = 0, len(frames_guide)
|
||||
for i in tqdm(range(0, n, batch_size - tracking_window_size * 2), desc=f"Rendering frames 0...{n}"):
|
||||
if i + batch_size > n:
|
||||
l, r = max(n - batch_size, 0), n
|
||||
else:
|
||||
l, r = i, i + batch_size
|
||||
source_guide = np.stack([frame_guide] * (r-l))
|
||||
target_guide = np.stack([frames_guide[i] for i in range(l, r)])
|
||||
source_style = np.stack([frame_style] * (r-l))
|
||||
_, target_style = patch_match_engine.estimate_nnf(source_guide, target_guide, source_style)
|
||||
for i, frame in zip(range(l, r), target_style):
|
||||
if i==frame_id:
|
||||
frame = frame.clip(0, 255).astype("uint8")
|
||||
Image.fromarray(frame).save(os.path.join(save_path, "%05d.png" % frame_id))
|
||||
frame_id += 1
|
||||
if r < n and r-frame_id <= tracking_window_size:
|
||||
break
|
||||
242
diffsynth/extensions/RIFE/__init__.py
Normal file
242
diffsynth/extensions/RIFE/__init__.py
Normal file
@@ -0,0 +1,242 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
|
||||
def warp(tenInput, tenFlow, device):
|
||||
backwarp_tenGrid = {}
|
||||
k = (str(tenFlow.device), str(tenFlow.size()))
|
||||
if k not in backwarp_tenGrid:
|
||||
tenHorizontal = torch.linspace(-1.0, 1.0, tenFlow.shape[3], device=device).view(
|
||||
1, 1, 1, tenFlow.shape[3]).expand(tenFlow.shape[0], -1, tenFlow.shape[2], -1)
|
||||
tenVertical = torch.linspace(-1.0, 1.0, tenFlow.shape[2], device=device).view(
|
||||
1, 1, tenFlow.shape[2], 1).expand(tenFlow.shape[0], -1, -1, tenFlow.shape[3])
|
||||
backwarp_tenGrid[k] = torch.cat(
|
||||
[tenHorizontal, tenVertical], 1).to(device)
|
||||
|
||||
tenFlow = torch.cat([tenFlow[:, 0:1, :, :] / ((tenInput.shape[3] - 1.0) / 2.0),
|
||||
tenFlow[:, 1:2, :, :] / ((tenInput.shape[2] - 1.0) / 2.0)], 1)
|
||||
|
||||
g = (backwarp_tenGrid[k] + tenFlow).permute(0, 2, 3, 1)
|
||||
return torch.nn.functional.grid_sample(input=tenInput, grid=g, mode='bilinear', padding_mode='border', align_corners=True)
|
||||
|
||||
|
||||
def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
|
||||
return nn.Sequential(
|
||||
nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride,
|
||||
padding=padding, dilation=dilation, bias=True),
|
||||
nn.PReLU(out_planes)
|
||||
)
|
||||
|
||||
|
||||
class IFBlock(nn.Module):
|
||||
def __init__(self, in_planes, c=64):
|
||||
super(IFBlock, self).__init__()
|
||||
self.conv0 = nn.Sequential(conv(in_planes, c//2, 3, 2, 1), conv(c//2, c, 3, 2, 1),)
|
||||
self.convblock0 = nn.Sequential(conv(c, c), conv(c, c))
|
||||
self.convblock1 = nn.Sequential(conv(c, c), conv(c, c))
|
||||
self.convblock2 = nn.Sequential(conv(c, c), conv(c, c))
|
||||
self.convblock3 = nn.Sequential(conv(c, c), conv(c, c))
|
||||
self.conv1 = nn.Sequential(nn.ConvTranspose2d(c, c//2, 4, 2, 1), nn.PReLU(c//2), nn.ConvTranspose2d(c//2, 4, 4, 2, 1))
|
||||
self.conv2 = nn.Sequential(nn.ConvTranspose2d(c, c//2, 4, 2, 1), nn.PReLU(c//2), nn.ConvTranspose2d(c//2, 1, 4, 2, 1))
|
||||
|
||||
def forward(self, x, flow, scale=1):
|
||||
x = F.interpolate(x, scale_factor= 1. / scale, mode="bilinear", align_corners=False, recompute_scale_factor=False)
|
||||
flow = F.interpolate(flow, scale_factor= 1. / scale, mode="bilinear", align_corners=False, recompute_scale_factor=False) * 1. / scale
|
||||
feat = self.conv0(torch.cat((x, flow), 1))
|
||||
feat = self.convblock0(feat) + feat
|
||||
feat = self.convblock1(feat) + feat
|
||||
feat = self.convblock2(feat) + feat
|
||||
feat = self.convblock3(feat) + feat
|
||||
flow = self.conv1(feat)
|
||||
mask = self.conv2(feat)
|
||||
flow = F.interpolate(flow, scale_factor=scale, mode="bilinear", align_corners=False, recompute_scale_factor=False) * scale
|
||||
mask = F.interpolate(mask, scale_factor=scale, mode="bilinear", align_corners=False, recompute_scale_factor=False)
|
||||
return flow, mask
|
||||
|
||||
|
||||
class IFNet(nn.Module):
|
||||
def __init__(self, **kwargs):
|
||||
super(IFNet, self).__init__()
|
||||
self.block0 = IFBlock(7+4, c=90)
|
||||
self.block1 = IFBlock(7+4, c=90)
|
||||
self.block2 = IFBlock(7+4, c=90)
|
||||
self.block_tea = IFBlock(10+4, c=90)
|
||||
|
||||
def forward(self, x, scale_list=[4, 2, 1], training=False):
|
||||
if training == False:
|
||||
channel = x.shape[1] // 2
|
||||
img0 = x[:, :channel]
|
||||
img1 = x[:, channel:]
|
||||
flow_list = []
|
||||
merged = []
|
||||
mask_list = []
|
||||
warped_img0 = img0
|
||||
warped_img1 = img1
|
||||
flow = (x[:, :4]).detach() * 0
|
||||
mask = (x[:, :1]).detach() * 0
|
||||
block = [self.block0, self.block1, self.block2]
|
||||
for i in range(3):
|
||||
f0, m0 = block[i](torch.cat((warped_img0[:, :3], warped_img1[:, :3], mask), 1), flow, scale=scale_list[i])
|
||||
f1, m1 = block[i](torch.cat((warped_img1[:, :3], warped_img0[:, :3], -mask), 1), torch.cat((flow[:, 2:4], flow[:, :2]), 1), scale=scale_list[i])
|
||||
flow = flow + (f0 + torch.cat((f1[:, 2:4], f1[:, :2]), 1)) / 2
|
||||
mask = mask + (m0 + (-m1)) / 2
|
||||
mask_list.append(mask)
|
||||
flow_list.append(flow)
|
||||
warped_img0 = warp(img0, flow[:, :2], device=x.device)
|
||||
warped_img1 = warp(img1, flow[:, 2:4], device=x.device)
|
||||
merged.append((warped_img0, warped_img1))
|
||||
'''
|
||||
c0 = self.contextnet(img0, flow[:, :2])
|
||||
c1 = self.contextnet(img1, flow[:, 2:4])
|
||||
tmp = self.unet(img0, img1, warped_img0, warped_img1, mask, flow, c0, c1)
|
||||
res = tmp[:, 1:4] * 2 - 1
|
||||
'''
|
||||
for i in range(3):
|
||||
mask_list[i] = torch.sigmoid(mask_list[i])
|
||||
merged[i] = merged[i][0] * mask_list[i] + merged[i][1] * (1 - mask_list[i])
|
||||
return flow_list, mask_list[2], merged
|
||||
|
||||
@staticmethod
|
||||
def state_dict_converter():
|
||||
return IFNetStateDictConverter()
|
||||
|
||||
|
||||
class IFNetStateDictConverter:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def from_diffusers(self, state_dict):
|
||||
state_dict_ = {k.replace("module.", ""): v for k, v in state_dict.items()}
|
||||
return state_dict_
|
||||
|
||||
def from_civitai(self, state_dict):
|
||||
return self.from_diffusers(state_dict), {"upcast_to_float32": True}
|
||||
|
||||
|
||||
class RIFEInterpolater:
|
||||
def __init__(self, model, device="cuda"):
|
||||
self.model = model
|
||||
self.device = device
|
||||
# IFNet only does not support float16
|
||||
self.torch_dtype = torch.float32
|
||||
|
||||
@staticmethod
|
||||
def from_model_manager(model_manager):
|
||||
return RIFEInterpolater(model_manager.fetch_model("rife"), device=model_manager.device)
|
||||
|
||||
def process_image(self, image):
|
||||
width, height = image.size
|
||||
if width % 32 != 0 or height % 32 != 0:
|
||||
width = (width + 31) // 32
|
||||
height = (height + 31) // 32
|
||||
image = image.resize((width, height))
|
||||
image = torch.Tensor(np.array(image, dtype=np.float32)[:, :, [2,1,0]] / 255).permute(2, 0, 1)
|
||||
return image
|
||||
|
||||
def process_images(self, images):
|
||||
images = [self.process_image(image) for image in images]
|
||||
images = torch.stack(images)
|
||||
return images
|
||||
|
||||
def decode_images(self, images):
|
||||
images = (images[:, [2,1,0]].permute(0, 2, 3, 1) * 255).clip(0, 255).numpy().astype(np.uint8)
|
||||
images = [Image.fromarray(image) for image in images]
|
||||
return images
|
||||
|
||||
def add_interpolated_images(self, images, interpolated_images):
|
||||
output_images = []
|
||||
for image, interpolated_image in zip(images, interpolated_images):
|
||||
output_images.append(image)
|
||||
output_images.append(interpolated_image)
|
||||
output_images.append(images[-1])
|
||||
return output_images
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def interpolate_(self, images, scale=1.0):
|
||||
input_tensor = self.process_images(images)
|
||||
input_tensor = torch.cat((input_tensor[:-1], input_tensor[1:]), dim=1)
|
||||
input_tensor = input_tensor.to(device=self.device, dtype=self.torch_dtype)
|
||||
flow, mask, merged = self.model(input_tensor, [4/scale, 2/scale, 1/scale])
|
||||
output_images = self.decode_images(merged[2].cpu())
|
||||
if output_images[0].size != images[0].size:
|
||||
output_images = [image.resize(images[0].size) for image in output_images]
|
||||
return output_images
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def interpolate(self, images, scale=1.0, batch_size=4, num_iter=1, progress_bar=lambda x:x):
|
||||
# Preprocess
|
||||
processed_images = self.process_images(images)
|
||||
|
||||
for iter in range(num_iter):
|
||||
# Input
|
||||
input_tensor = torch.cat((processed_images[:-1], processed_images[1:]), dim=1)
|
||||
|
||||
# Interpolate
|
||||
output_tensor = []
|
||||
for batch_id in progress_bar(range(0, input_tensor.shape[0], batch_size)):
|
||||
batch_id_ = min(batch_id + batch_size, input_tensor.shape[0])
|
||||
batch_input_tensor = input_tensor[batch_id: batch_id_]
|
||||
batch_input_tensor = batch_input_tensor.to(device=self.device, dtype=self.torch_dtype)
|
||||
flow, mask, merged = self.model(batch_input_tensor, [4/scale, 2/scale, 1/scale])
|
||||
output_tensor.append(merged[2].cpu())
|
||||
|
||||
# Output
|
||||
output_tensor = torch.concat(output_tensor, dim=0).clip(0, 1)
|
||||
processed_images = self.add_interpolated_images(processed_images, output_tensor)
|
||||
processed_images = torch.stack(processed_images)
|
||||
|
||||
# To images
|
||||
output_images = self.decode_images(processed_images)
|
||||
if output_images[0].size != images[0].size:
|
||||
output_images = [image.resize(images[0].size) for image in output_images]
|
||||
return output_images
|
||||
|
||||
|
||||
class RIFESmoother(RIFEInterpolater):
|
||||
def __init__(self, model, device="cuda"):
|
||||
super(RIFESmoother, self).__init__(model, device=device)
|
||||
|
||||
@staticmethod
|
||||
def from_model_manager(model_manager):
|
||||
return RIFEInterpolater(model_manager.fetch_model("rife"), device=model_manager.device)
|
||||
|
||||
def process_tensors(self, input_tensor, scale=1.0, batch_size=4):
|
||||
output_tensor = []
|
||||
for batch_id in range(0, input_tensor.shape[0], batch_size):
|
||||
batch_id_ = min(batch_id + batch_size, input_tensor.shape[0])
|
||||
batch_input_tensor = input_tensor[batch_id: batch_id_]
|
||||
batch_input_tensor = batch_input_tensor.to(device=self.device, dtype=self.torch_dtype)
|
||||
flow, mask, merged = self.model(batch_input_tensor, [4/scale, 2/scale, 1/scale])
|
||||
output_tensor.append(merged[2].cpu())
|
||||
output_tensor = torch.concat(output_tensor, dim=0)
|
||||
return output_tensor
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(self, rendered_frames, scale=1.0, batch_size=4, num_iter=1, **kwargs):
|
||||
# Preprocess
|
||||
processed_images = self.process_images(rendered_frames)
|
||||
|
||||
for iter in range(num_iter):
|
||||
# Input
|
||||
input_tensor = torch.cat((processed_images[:-2], processed_images[2:]), dim=1)
|
||||
|
||||
# Interpolate
|
||||
output_tensor = self.process_tensors(input_tensor, scale=scale, batch_size=batch_size)
|
||||
|
||||
# Blend
|
||||
input_tensor = torch.cat((processed_images[1:-1], output_tensor), dim=1)
|
||||
output_tensor = self.process_tensors(input_tensor, scale=scale, batch_size=batch_size)
|
||||
|
||||
# Add to frames
|
||||
processed_images[1:-1] = output_tensor
|
||||
|
||||
# To images
|
||||
output_images = self.decode_images(processed_images)
|
||||
if output_images[0].size != rendered_frames[0].size:
|
||||
output_images = [image.resize(rendered_frames[0].size) for image in output_images]
|
||||
return output_images
|
||||
1
diffsynth/models/__init__.py
Normal file
1
diffsynth/models/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
from .model_manager import *
|
||||
File diff suppressed because it is too large
Load Diff
89
diffsynth/models/attention.py
Normal file
89
diffsynth/models/attention.py
Normal file
@@ -0,0 +1,89 @@
|
||||
import torch
|
||||
from einops import rearrange
|
||||
|
||||
|
||||
def low_version_attention(query, key, value, attn_bias=None):
|
||||
scale = 1 / query.shape[-1] ** 0.5
|
||||
query = query * scale
|
||||
attn = torch.matmul(query, key.transpose(-2, -1))
|
||||
if attn_bias is not None:
|
||||
attn = attn + attn_bias
|
||||
attn = attn.softmax(-1)
|
||||
return attn @ value
|
||||
|
||||
|
||||
class Attention(torch.nn.Module):
|
||||
|
||||
def __init__(self, q_dim, num_heads, head_dim, kv_dim=None, bias_q=False, bias_kv=False, bias_out=False):
|
||||
super().__init__()
|
||||
dim_inner = head_dim * num_heads
|
||||
kv_dim = kv_dim if kv_dim is not None else q_dim
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = head_dim
|
||||
|
||||
self.to_q = torch.nn.Linear(q_dim, dim_inner, bias=bias_q)
|
||||
self.to_k = torch.nn.Linear(kv_dim, dim_inner, bias=bias_kv)
|
||||
self.to_v = torch.nn.Linear(kv_dim, dim_inner, bias=bias_kv)
|
||||
self.to_out = torch.nn.Linear(dim_inner, q_dim, bias=bias_out)
|
||||
|
||||
def interact_with_ipadapter(self, hidden_states, q, ip_k, ip_v, scale=1.0):
|
||||
batch_size = q.shape[0]
|
||||
ip_k = ip_k.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
ip_v = ip_v.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
ip_hidden_states = torch.nn.functional.scaled_dot_product_attention(q, ip_k, ip_v)
|
||||
hidden_states = hidden_states + scale * ip_hidden_states
|
||||
return hidden_states
|
||||
|
||||
def torch_forward(self, hidden_states, encoder_hidden_states=None, attn_mask=None, ipadapter_kwargs=None, qkv_preprocessor=None):
|
||||
if encoder_hidden_states is None:
|
||||
encoder_hidden_states = hidden_states
|
||||
|
||||
batch_size = encoder_hidden_states.shape[0]
|
||||
|
||||
q = self.to_q(hidden_states)
|
||||
k = self.to_k(encoder_hidden_states)
|
||||
v = self.to_v(encoder_hidden_states)
|
||||
|
||||
q = q.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
k = k.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
v = v.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
|
||||
if qkv_preprocessor is not None:
|
||||
q, k, v = qkv_preprocessor(q, k, v)
|
||||
|
||||
hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
|
||||
if ipadapter_kwargs is not None:
|
||||
hidden_states = self.interact_with_ipadapter(hidden_states, q, **ipadapter_kwargs)
|
||||
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim)
|
||||
hidden_states = hidden_states.to(q.dtype)
|
||||
|
||||
hidden_states = self.to_out(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
def xformers_forward(self, hidden_states, encoder_hidden_states=None, attn_mask=None):
|
||||
if encoder_hidden_states is None:
|
||||
encoder_hidden_states = hidden_states
|
||||
|
||||
q = self.to_q(hidden_states)
|
||||
k = self.to_k(encoder_hidden_states)
|
||||
v = self.to_v(encoder_hidden_states)
|
||||
|
||||
q = rearrange(q, "b f (n d) -> (b n) f d", n=self.num_heads)
|
||||
k = rearrange(k, "b f (n d) -> (b n) f d", n=self.num_heads)
|
||||
v = rearrange(v, "b f (n d) -> (b n) f d", n=self.num_heads)
|
||||
|
||||
if attn_mask is not None:
|
||||
hidden_states = low_version_attention(q, k, v, attn_bias=attn_mask)
|
||||
else:
|
||||
import xformers.ops as xops
|
||||
hidden_states = xops.memory_efficient_attention(q, k, v)
|
||||
hidden_states = rearrange(hidden_states, "(b n) f d -> b f (n d)", n=self.num_heads)
|
||||
|
||||
hidden_states = hidden_states.to(q.dtype)
|
||||
hidden_states = self.to_out(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
def forward(self, hidden_states, encoder_hidden_states=None, attn_mask=None, ipadapter_kwargs=None, qkv_preprocessor=None):
|
||||
return self.torch_forward(hidden_states, encoder_hidden_states=encoder_hidden_states, attn_mask=attn_mask, ipadapter_kwargs=ipadapter_kwargs, qkv_preprocessor=qkv_preprocessor)
|
||||
395
diffsynth/models/cog_dit.py
Normal file
395
diffsynth/models/cog_dit.py
Normal file
@@ -0,0 +1,395 @@
|
||||
import torch
|
||||
from einops import rearrange, repeat
|
||||
from .sd3_dit import TimestepEmbeddings
|
||||
from .attention import Attention
|
||||
from .utils import load_state_dict_from_folder
|
||||
from .tiler import TileWorker2Dto3D
|
||||
import numpy as np
|
||||
|
||||
|
||||
|
||||
class CogPatchify(torch.nn.Module):
|
||||
def __init__(self, dim_in, dim_out, patch_size) -> None:
|
||||
super().__init__()
|
||||
self.proj = torch.nn.Conv3d(dim_in, dim_out, kernel_size=(1, patch_size, patch_size), stride=(1, patch_size, patch_size))
|
||||
|
||||
def forward(self, hidden_states):
|
||||
hidden_states = self.proj(hidden_states)
|
||||
hidden_states = rearrange(hidden_states, "B C T H W -> B (T H W) C")
|
||||
return hidden_states
|
||||
|
||||
|
||||
|
||||
class CogAdaLayerNorm(torch.nn.Module):
|
||||
def __init__(self, dim, dim_cond, single=False):
|
||||
super().__init__()
|
||||
self.single = single
|
||||
self.linear = torch.nn.Linear(dim_cond, dim * (2 if single else 6))
|
||||
self.norm = torch.nn.LayerNorm(dim, elementwise_affine=True, eps=1e-5)
|
||||
|
||||
|
||||
def forward(self, hidden_states, prompt_emb, emb):
|
||||
emb = self.linear(torch.nn.functional.silu(emb))
|
||||
if self.single:
|
||||
shift, scale = emb.unsqueeze(1).chunk(2, dim=2)
|
||||
hidden_states = self.norm(hidden_states) * (1 + scale) + shift
|
||||
return hidden_states
|
||||
else:
|
||||
shift_a, scale_a, gate_a, shift_b, scale_b, gate_b = emb.unsqueeze(1).chunk(6, dim=2)
|
||||
hidden_states = self.norm(hidden_states) * (1 + scale_a) + shift_a
|
||||
prompt_emb = self.norm(prompt_emb) * (1 + scale_b) + shift_b
|
||||
return hidden_states, prompt_emb, gate_a, gate_b
|
||||
|
||||
|
||||
|
||||
class CogDiTBlock(torch.nn.Module):
|
||||
def __init__(self, dim, dim_cond, num_heads):
|
||||
super().__init__()
|
||||
self.norm1 = CogAdaLayerNorm(dim, dim_cond)
|
||||
self.attn1 = Attention(q_dim=dim, num_heads=48, head_dim=dim//num_heads, bias_q=True, bias_kv=True, bias_out=True)
|
||||
self.norm_q = torch.nn.LayerNorm((dim//num_heads,), eps=1e-06, elementwise_affine=True)
|
||||
self.norm_k = torch.nn.LayerNorm((dim//num_heads,), eps=1e-06, elementwise_affine=True)
|
||||
|
||||
self.norm2 = CogAdaLayerNorm(dim, dim_cond)
|
||||
self.ff = torch.nn.Sequential(
|
||||
torch.nn.Linear(dim, dim*4),
|
||||
torch.nn.GELU(approximate="tanh"),
|
||||
torch.nn.Linear(dim*4, dim)
|
||||
)
|
||||
|
||||
|
||||
def apply_rotary_emb(self, x, freqs_cis):
|
||||
cos, sin = freqs_cis # [S, D]
|
||||
cos = cos[None, None]
|
||||
sin = sin[None, None]
|
||||
cos, sin = cos.to(x.device), sin.to(x.device)
|
||||
x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, S, H, D//2]
|
||||
x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
|
||||
out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)
|
||||
return out
|
||||
|
||||
|
||||
def process_qkv(self, q, k, v, image_rotary_emb, text_seq_length):
|
||||
q = self.norm_q(q)
|
||||
k = self.norm_k(k)
|
||||
q[:, :, text_seq_length:] = self.apply_rotary_emb(q[:, :, text_seq_length:], image_rotary_emb)
|
||||
k[:, :, text_seq_length:] = self.apply_rotary_emb(k[:, :, text_seq_length:], image_rotary_emb)
|
||||
return q, k, v
|
||||
|
||||
|
||||
def forward(self, hidden_states, prompt_emb, time_emb, image_rotary_emb):
|
||||
# Attention
|
||||
norm_hidden_states, norm_encoder_hidden_states, gate_a, gate_b = self.norm1(
|
||||
hidden_states, prompt_emb, time_emb
|
||||
)
|
||||
attention_io = torch.cat([norm_encoder_hidden_states, norm_hidden_states], dim=1)
|
||||
attention_io = self.attn1(
|
||||
attention_io,
|
||||
qkv_preprocessor=lambda q, k, v: self.process_qkv(q, k, v, image_rotary_emb, prompt_emb.shape[1])
|
||||
)
|
||||
|
||||
hidden_states = hidden_states + gate_a * attention_io[:, prompt_emb.shape[1]:]
|
||||
prompt_emb = prompt_emb + gate_b * attention_io[:, :prompt_emb.shape[1]]
|
||||
|
||||
# Feed forward
|
||||
norm_hidden_states, norm_encoder_hidden_states, gate_a, gate_b = self.norm2(
|
||||
hidden_states, prompt_emb, time_emb
|
||||
)
|
||||
ff_io = torch.cat([norm_encoder_hidden_states, norm_hidden_states], dim=1)
|
||||
ff_io = self.ff(ff_io)
|
||||
|
||||
hidden_states = hidden_states + gate_a * ff_io[:, prompt_emb.shape[1]:]
|
||||
prompt_emb = prompt_emb + gate_b * ff_io[:, :prompt_emb.shape[1]]
|
||||
|
||||
return hidden_states, prompt_emb
|
||||
|
||||
|
||||
|
||||
class CogDiT(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.patchify = CogPatchify(16, 3072, 2)
|
||||
self.time_embedder = TimestepEmbeddings(3072, 512)
|
||||
self.context_embedder = torch.nn.Linear(4096, 3072)
|
||||
self.blocks = torch.nn.ModuleList([CogDiTBlock(3072, 512, 48) for _ in range(42)])
|
||||
self.norm_final = torch.nn.LayerNorm((3072,), eps=1e-05, elementwise_affine=True)
|
||||
self.norm_out = CogAdaLayerNorm(3072, 512, single=True)
|
||||
self.proj_out = torch.nn.Linear(3072, 64, bias=True)
|
||||
|
||||
|
||||
def get_resize_crop_region_for_grid(self, src, tgt_width, tgt_height):
|
||||
tw = tgt_width
|
||||
th = tgt_height
|
||||
h, w = src
|
||||
r = h / w
|
||||
if r > (th / tw):
|
||||
resize_height = th
|
||||
resize_width = int(round(th / h * w))
|
||||
else:
|
||||
resize_width = tw
|
||||
resize_height = int(round(tw / w * h))
|
||||
|
||||
crop_top = int(round((th - resize_height) / 2.0))
|
||||
crop_left = int(round((tw - resize_width) / 2.0))
|
||||
|
||||
return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width)
|
||||
|
||||
|
||||
def get_3d_rotary_pos_embed(
|
||||
self, embed_dim, crops_coords, grid_size, temporal_size, theta: int = 10000, use_real: bool = True
|
||||
):
|
||||
start, stop = crops_coords
|
||||
grid_h = np.linspace(start[0], stop[0], grid_size[0], endpoint=False, dtype=np.float32)
|
||||
grid_w = np.linspace(start[1], stop[1], grid_size[1], endpoint=False, dtype=np.float32)
|
||||
grid_t = np.linspace(0, temporal_size, temporal_size, endpoint=False, dtype=np.float32)
|
||||
|
||||
# Compute dimensions for each axis
|
||||
dim_t = embed_dim // 4
|
||||
dim_h = embed_dim // 8 * 3
|
||||
dim_w = embed_dim // 8 * 3
|
||||
|
||||
# Temporal frequencies
|
||||
freqs_t = 1.0 / (theta ** (torch.arange(0, dim_t, 2).float() / dim_t))
|
||||
grid_t = torch.from_numpy(grid_t).float()
|
||||
freqs_t = torch.einsum("n , f -> n f", grid_t, freqs_t)
|
||||
freqs_t = freqs_t.repeat_interleave(2, dim=-1)
|
||||
|
||||
# Spatial frequencies for height and width
|
||||
freqs_h = 1.0 / (theta ** (torch.arange(0, dim_h, 2).float() / dim_h))
|
||||
freqs_w = 1.0 / (theta ** (torch.arange(0, dim_w, 2).float() / dim_w))
|
||||
grid_h = torch.from_numpy(grid_h).float()
|
||||
grid_w = torch.from_numpy(grid_w).float()
|
||||
freqs_h = torch.einsum("n , f -> n f", grid_h, freqs_h)
|
||||
freqs_w = torch.einsum("n , f -> n f", grid_w, freqs_w)
|
||||
freqs_h = freqs_h.repeat_interleave(2, dim=-1)
|
||||
freqs_w = freqs_w.repeat_interleave(2, dim=-1)
|
||||
|
||||
# Broadcast and concatenate tensors along specified dimension
|
||||
def broadcast(tensors, dim=-1):
|
||||
num_tensors = len(tensors)
|
||||
shape_lens = {len(t.shape) for t in tensors}
|
||||
assert len(shape_lens) == 1, "tensors must all have the same number of dimensions"
|
||||
shape_len = list(shape_lens)[0]
|
||||
dim = (dim + shape_len) if dim < 0 else dim
|
||||
dims = list(zip(*(list(t.shape) for t in tensors)))
|
||||
expandable_dims = [(i, val) for i, val in enumerate(dims) if i != dim]
|
||||
assert all(
|
||||
[*(len(set(t[1])) <= 2 for t in expandable_dims)]
|
||||
), "invalid dimensions for broadcastable concatenation"
|
||||
max_dims = [(t[0], max(t[1])) for t in expandable_dims]
|
||||
expanded_dims = [(t[0], (t[1],) * num_tensors) for t in max_dims]
|
||||
expanded_dims.insert(dim, (dim, dims[dim]))
|
||||
expandable_shapes = list(zip(*(t[1] for t in expanded_dims)))
|
||||
tensors = [t[0].expand(*t[1]) for t in zip(tensors, expandable_shapes)]
|
||||
return torch.cat(tensors, dim=dim)
|
||||
|
||||
freqs = broadcast((freqs_t[:, None, None, :], freqs_h[None, :, None, :], freqs_w[None, None, :, :]), dim=-1)
|
||||
|
||||
t, h, w, d = freqs.shape
|
||||
freqs = freqs.view(t * h * w, d)
|
||||
|
||||
# Generate sine and cosine components
|
||||
sin = freqs.sin()
|
||||
cos = freqs.cos()
|
||||
|
||||
if use_real:
|
||||
return cos, sin
|
||||
else:
|
||||
freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
|
||||
return freqs_cis
|
||||
|
||||
|
||||
def prepare_rotary_positional_embeddings(
|
||||
self,
|
||||
height: int,
|
||||
width: int,
|
||||
num_frames: int,
|
||||
device: torch.device,
|
||||
):
|
||||
grid_height = height // 2
|
||||
grid_width = width // 2
|
||||
base_size_width = 720 // (8 * 2)
|
||||
base_size_height = 480 // (8 * 2)
|
||||
|
||||
grid_crops_coords = self.get_resize_crop_region_for_grid(
|
||||
(grid_height, grid_width), base_size_width, base_size_height
|
||||
)
|
||||
freqs_cos, freqs_sin = self.get_3d_rotary_pos_embed(
|
||||
embed_dim=64,
|
||||
crops_coords=grid_crops_coords,
|
||||
grid_size=(grid_height, grid_width),
|
||||
temporal_size=num_frames,
|
||||
use_real=True,
|
||||
)
|
||||
|
||||
freqs_cos = freqs_cos.to(device=device)
|
||||
freqs_sin = freqs_sin.to(device=device)
|
||||
return freqs_cos, freqs_sin
|
||||
|
||||
|
||||
def unpatchify(self, hidden_states, height, width):
|
||||
hidden_states = rearrange(hidden_states, "B (T H W) (C P Q) -> B C T (H P) (W Q)", P=2, Q=2, H=height//2, W=width//2)
|
||||
return hidden_states
|
||||
|
||||
|
||||
def build_mask(self, T, H, W, dtype, device, is_bound):
|
||||
t = repeat(torch.arange(T), "T -> T H W", T=T, H=H, W=W)
|
||||
h = repeat(torch.arange(H), "H -> T H W", T=T, H=H, W=W)
|
||||
w = repeat(torch.arange(W), "W -> T H W", T=T, H=H, W=W)
|
||||
border_width = (H + W) // 4
|
||||
pad = torch.ones_like(h) * border_width
|
||||
mask = torch.stack([
|
||||
pad if is_bound[0] else t + 1,
|
||||
pad if is_bound[1] else T - t,
|
||||
pad if is_bound[2] else h + 1,
|
||||
pad if is_bound[3] else H - h,
|
||||
pad if is_bound[4] else w + 1,
|
||||
pad if is_bound[5] else W - w
|
||||
]).min(dim=0).values
|
||||
mask = mask.clip(1, border_width)
|
||||
mask = (mask / border_width).to(dtype=dtype, device=device)
|
||||
mask = rearrange(mask, "T H W -> 1 1 T H W")
|
||||
return mask
|
||||
|
||||
|
||||
def tiled_forward(self, hidden_states, timestep, prompt_emb, tile_size=(60, 90), tile_stride=(30, 45)):
|
||||
B, C, T, H, W = hidden_states.shape
|
||||
value = torch.zeros((B, C, T, H, W), dtype=hidden_states.dtype, device=hidden_states.device)
|
||||
weight = torch.zeros((B, C, T, H, W), dtype=hidden_states.dtype, device=hidden_states.device)
|
||||
|
||||
# Split tasks
|
||||
tasks = []
|
||||
for h in range(0, H, tile_stride):
|
||||
for w in range(0, W, tile_stride):
|
||||
if (h-tile_stride >= 0 and h-tile_stride+tile_size >= H) or (w-tile_stride >= 0 and w-tile_stride+tile_size >= W):
|
||||
continue
|
||||
h_, w_ = h + tile_size, w + tile_size
|
||||
if h_ > H: h, h_ = max(H - tile_size, 0), H
|
||||
if w_ > W: w, w_ = max(W - tile_size, 0), W
|
||||
tasks.append((h, h_, w, w_))
|
||||
|
||||
# Run
|
||||
for hl, hr, wl, wr in tasks:
|
||||
mask = self.build_mask(
|
||||
value.shape[2], (hr-hl), (wr-wl),
|
||||
hidden_states.dtype, hidden_states.device,
|
||||
is_bound=(True, True, hl==0, hr>=H, wl==0, wr>=W)
|
||||
)
|
||||
model_output = self.forward(hidden_states[:, :, :, hl:hr, wl:wr], timestep, prompt_emb)
|
||||
value[:, :, :, hl:hr, wl:wr] += model_output * mask
|
||||
weight[:, :, :, hl:hr, wl:wr] += mask
|
||||
value = value / weight
|
||||
|
||||
return value
|
||||
|
||||
|
||||
def forward(self, hidden_states, timestep, prompt_emb, image_rotary_emb=None, tiled=False, tile_size=90, tile_stride=30):
|
||||
if tiled:
|
||||
return TileWorker2Dto3D().tiled_forward(
|
||||
forward_fn=lambda x: self.forward(x, timestep, prompt_emb),
|
||||
model_input=hidden_states,
|
||||
tile_size=tile_size, tile_stride=tile_stride,
|
||||
tile_device=hidden_states.device, tile_dtype=hidden_states.dtype,
|
||||
computation_device=self.context_embedder.weight.device, computation_dtype=self.context_embedder.weight.dtype
|
||||
)
|
||||
num_frames, height, width = hidden_states.shape[-3:]
|
||||
if image_rotary_emb is None:
|
||||
image_rotary_emb = self.prepare_rotary_positional_embeddings(height, width, num_frames, device=self.context_embedder.weight.device)
|
||||
hidden_states = self.patchify(hidden_states)
|
||||
time_emb = self.time_embedder(timestep, dtype=hidden_states.dtype)
|
||||
prompt_emb = self.context_embedder(prompt_emb)
|
||||
for block in self.blocks:
|
||||
hidden_states, prompt_emb = block(hidden_states, prompt_emb, time_emb, image_rotary_emb)
|
||||
|
||||
hidden_states = torch.cat([prompt_emb, hidden_states], dim=1)
|
||||
hidden_states = self.norm_final(hidden_states)
|
||||
hidden_states = hidden_states[:, prompt_emb.shape[1]:]
|
||||
hidden_states = self.norm_out(hidden_states, prompt_emb, time_emb)
|
||||
hidden_states = self.proj_out(hidden_states)
|
||||
hidden_states = self.unpatchify(hidden_states, height, width)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
@staticmethod
|
||||
def state_dict_converter():
|
||||
return CogDiTStateDictConverter()
|
||||
|
||||
|
||||
@staticmethod
|
||||
def from_pretrained(file_path, torch_dtype=torch.bfloat16):
|
||||
model = CogDiT().to(torch_dtype)
|
||||
state_dict = load_state_dict_from_folder(file_path, torch_dtype=torch_dtype)
|
||||
state_dict = CogDiT.state_dict_converter().from_diffusers(state_dict)
|
||||
model.load_state_dict(state_dict)
|
||||
return model
|
||||
|
||||
|
||||
|
||||
class CogDiTStateDictConverter:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
|
||||
def from_diffusers(self, state_dict):
|
||||
rename_dict = {
|
||||
"patch_embed.proj.weight": "patchify.proj.weight",
|
||||
"patch_embed.proj.bias": "patchify.proj.bias",
|
||||
"patch_embed.text_proj.weight": "context_embedder.weight",
|
||||
"patch_embed.text_proj.bias": "context_embedder.bias",
|
||||
"time_embedding.linear_1.weight": "time_embedder.timestep_embedder.0.weight",
|
||||
"time_embedding.linear_1.bias": "time_embedder.timestep_embedder.0.bias",
|
||||
"time_embedding.linear_2.weight": "time_embedder.timestep_embedder.2.weight",
|
||||
"time_embedding.linear_2.bias": "time_embedder.timestep_embedder.2.bias",
|
||||
|
||||
"norm_final.weight": "norm_final.weight",
|
||||
"norm_final.bias": "norm_final.bias",
|
||||
"norm_out.linear.weight": "norm_out.linear.weight",
|
||||
"norm_out.linear.bias": "norm_out.linear.bias",
|
||||
"norm_out.norm.weight": "norm_out.norm.weight",
|
||||
"norm_out.norm.bias": "norm_out.norm.bias",
|
||||
"proj_out.weight": "proj_out.weight",
|
||||
"proj_out.bias": "proj_out.bias",
|
||||
}
|
||||
suffix_dict = {
|
||||
"norm1.linear.weight": "norm1.linear.weight",
|
||||
"norm1.linear.bias": "norm1.linear.bias",
|
||||
"norm1.norm.weight": "norm1.norm.weight",
|
||||
"norm1.norm.bias": "norm1.norm.bias",
|
||||
"attn1.norm_q.weight": "norm_q.weight",
|
||||
"attn1.norm_q.bias": "norm_q.bias",
|
||||
"attn1.norm_k.weight": "norm_k.weight",
|
||||
"attn1.norm_k.bias": "norm_k.bias",
|
||||
"attn1.to_q.weight": "attn1.to_q.weight",
|
||||
"attn1.to_q.bias": "attn1.to_q.bias",
|
||||
"attn1.to_k.weight": "attn1.to_k.weight",
|
||||
"attn1.to_k.bias": "attn1.to_k.bias",
|
||||
"attn1.to_v.weight": "attn1.to_v.weight",
|
||||
"attn1.to_v.bias": "attn1.to_v.bias",
|
||||
"attn1.to_out.0.weight": "attn1.to_out.weight",
|
||||
"attn1.to_out.0.bias": "attn1.to_out.bias",
|
||||
"norm2.linear.weight": "norm2.linear.weight",
|
||||
"norm2.linear.bias": "norm2.linear.bias",
|
||||
"norm2.norm.weight": "norm2.norm.weight",
|
||||
"norm2.norm.bias": "norm2.norm.bias",
|
||||
"ff.net.0.proj.weight": "ff.0.weight",
|
||||
"ff.net.0.proj.bias": "ff.0.bias",
|
||||
"ff.net.2.weight": "ff.2.weight",
|
||||
"ff.net.2.bias": "ff.2.bias",
|
||||
}
|
||||
state_dict_ = {}
|
||||
for name, param in state_dict.items():
|
||||
if name in rename_dict:
|
||||
if name == "patch_embed.proj.weight":
|
||||
param = param.unsqueeze(2)
|
||||
state_dict_[rename_dict[name]] = param
|
||||
else:
|
||||
names = name.split(".")
|
||||
if names[0] == "transformer_blocks":
|
||||
suffix = ".".join(names[2:])
|
||||
state_dict_[f"blocks.{names[1]}." + suffix_dict[suffix]] = param
|
||||
return state_dict_
|
||||
|
||||
|
||||
def from_civitai(self, state_dict):
|
||||
return self.from_diffusers(state_dict)
|
||||
518
diffsynth/models/cog_vae.py
Normal file
518
diffsynth/models/cog_vae.py
Normal file
@@ -0,0 +1,518 @@
|
||||
import torch
|
||||
from einops import rearrange, repeat
|
||||
from .tiler import TileWorker2Dto3D
|
||||
|
||||
|
||||
|
||||
class Downsample3D(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
kernel_size: int = 3,
|
||||
stride: int = 2,
|
||||
padding: int = 0,
|
||||
compress_time: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.conv = torch.nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding)
|
||||
self.compress_time = compress_time
|
||||
|
||||
def forward(self, x: torch.Tensor, xq: torch.Tensor) -> torch.Tensor:
|
||||
if self.compress_time:
|
||||
batch_size, channels, frames, height, width = x.shape
|
||||
|
||||
# (batch_size, channels, frames, height, width) -> (batch_size, height, width, channels, frames) -> (batch_size * height * width, channels, frames)
|
||||
x = x.permute(0, 3, 4, 1, 2).reshape(batch_size * height * width, channels, frames)
|
||||
|
||||
if x.shape[-1] % 2 == 1:
|
||||
x_first, x_rest = x[..., 0], x[..., 1:]
|
||||
if x_rest.shape[-1] > 0:
|
||||
# (batch_size * height * width, channels, frames - 1) -> (batch_size * height * width, channels, (frames - 1) // 2)
|
||||
x_rest = torch.nn.functional.avg_pool1d(x_rest, kernel_size=2, stride=2)
|
||||
|
||||
x = torch.cat([x_first[..., None], x_rest], dim=-1)
|
||||
# (batch_size * height * width, channels, (frames // 2) + 1) -> (batch_size, height, width, channels, (frames // 2) + 1) -> (batch_size, channels, (frames // 2) + 1, height, width)
|
||||
x = x.reshape(batch_size, height, width, channels, x.shape[-1]).permute(0, 3, 4, 1, 2)
|
||||
else:
|
||||
# (batch_size * height * width, channels, frames) -> (batch_size * height * width, channels, frames // 2)
|
||||
x = torch.nn.functional.avg_pool1d(x, kernel_size=2, stride=2)
|
||||
# (batch_size * height * width, channels, frames // 2) -> (batch_size, height, width, channels, frames // 2) -> (batch_size, channels, frames // 2, height, width)
|
||||
x = x.reshape(batch_size, height, width, channels, x.shape[-1]).permute(0, 3, 4, 1, 2)
|
||||
|
||||
# Pad the tensor
|
||||
pad = (0, 1, 0, 1)
|
||||
x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
|
||||
batch_size, channels, frames, height, width = x.shape
|
||||
# (batch_size, channels, frames, height, width) -> (batch_size, frames, channels, height, width) -> (batch_size * frames, channels, height, width)
|
||||
x = x.permute(0, 2, 1, 3, 4).reshape(batch_size * frames, channels, height, width)
|
||||
x = self.conv(x)
|
||||
# (batch_size * frames, channels, height, width) -> (batch_size, frames, channels, height, width) -> (batch_size, channels, frames, height, width)
|
||||
x = x.reshape(batch_size, frames, x.shape[1], x.shape[2], x.shape[3]).permute(0, 2, 1, 3, 4)
|
||||
return x
|
||||
|
||||
|
||||
|
||||
class Upsample3D(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
kernel_size: int = 3,
|
||||
stride: int = 1,
|
||||
padding: int = 1,
|
||||
compress_time: bool = False,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.conv = torch.nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding)
|
||||
self.compress_time = compress_time
|
||||
|
||||
def forward(self, inputs: torch.Tensor, xq: torch.Tensor) -> torch.Tensor:
|
||||
if self.compress_time:
|
||||
if inputs.shape[2] > 1 and inputs.shape[2] % 2 == 1:
|
||||
# split first frame
|
||||
x_first, x_rest = inputs[:, :, 0], inputs[:, :, 1:]
|
||||
|
||||
x_first = torch.nn.functional.interpolate(x_first, scale_factor=2.0)
|
||||
x_rest = torch.nn.functional.interpolate(x_rest, scale_factor=2.0)
|
||||
x_first = x_first[:, :, None, :, :]
|
||||
inputs = torch.cat([x_first, x_rest], dim=2)
|
||||
elif inputs.shape[2] > 1:
|
||||
inputs = torch.nn.functional.interpolate(inputs, scale_factor=2.0)
|
||||
else:
|
||||
inputs = inputs.squeeze(2)
|
||||
inputs = torch.nn.functional.interpolate(inputs, scale_factor=2.0)
|
||||
inputs = inputs[:, :, None, :, :]
|
||||
else:
|
||||
# only interpolate 2D
|
||||
b, c, t, h, w = inputs.shape
|
||||
inputs = inputs.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w)
|
||||
inputs = torch.nn.functional.interpolate(inputs, scale_factor=2.0)
|
||||
inputs = inputs.reshape(b, t, c, *inputs.shape[2:]).permute(0, 2, 1, 3, 4)
|
||||
|
||||
b, c, t, h, w = inputs.shape
|
||||
inputs = inputs.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w)
|
||||
inputs = self.conv(inputs)
|
||||
inputs = inputs.reshape(b, t, *inputs.shape[1:]).permute(0, 2, 1, 3, 4)
|
||||
|
||||
return inputs
|
||||
|
||||
|
||||
|
||||
class CogVideoXSpatialNorm3D(torch.nn.Module):
|
||||
def __init__(self, f_channels, zq_channels, groups):
|
||||
super().__init__()
|
||||
self.norm_layer = torch.nn.GroupNorm(num_channels=f_channels, num_groups=groups, eps=1e-6, affine=True)
|
||||
self.conv_y = torch.nn.Conv3d(zq_channels, f_channels, kernel_size=1, stride=1)
|
||||
self.conv_b = torch.nn.Conv3d(zq_channels, f_channels, kernel_size=1, stride=1)
|
||||
|
||||
|
||||
def forward(self, f: torch.Tensor, zq: torch.Tensor) -> torch.Tensor:
|
||||
if f.shape[2] > 1 and f.shape[2] % 2 == 1:
|
||||
f_first, f_rest = f[:, :, :1], f[:, :, 1:]
|
||||
f_first_size, f_rest_size = f_first.shape[-3:], f_rest.shape[-3:]
|
||||
z_first, z_rest = zq[:, :, :1], zq[:, :, 1:]
|
||||
z_first = torch.nn.functional.interpolate(z_first, size=f_first_size)
|
||||
z_rest = torch.nn.functional.interpolate(z_rest, size=f_rest_size)
|
||||
zq = torch.cat([z_first, z_rest], dim=2)
|
||||
else:
|
||||
zq = torch.nn.functional.interpolate(zq, size=f.shape[-3:])
|
||||
|
||||
norm_f = self.norm_layer(f)
|
||||
new_f = norm_f * self.conv_y(zq) + self.conv_b(zq)
|
||||
return new_f
|
||||
|
||||
|
||||
|
||||
class Resnet3DBlock(torch.nn.Module):
|
||||
def __init__(self, in_channels, out_channels, spatial_norm_dim, groups, eps=1e-6, use_conv_shortcut=False):
|
||||
super().__init__()
|
||||
self.nonlinearity = torch.nn.SiLU()
|
||||
if spatial_norm_dim is None:
|
||||
self.norm1 = torch.nn.GroupNorm(num_channels=in_channels, num_groups=groups, eps=eps)
|
||||
self.norm2 = torch.nn.GroupNorm(num_channels=out_channels, num_groups=groups, eps=eps)
|
||||
else:
|
||||
self.norm1 = CogVideoXSpatialNorm3D(in_channels, spatial_norm_dim, groups)
|
||||
self.norm2 = CogVideoXSpatialNorm3D(out_channels, spatial_norm_dim, groups)
|
||||
|
||||
self.conv1 = CachedConv3d(in_channels, out_channels, kernel_size=3, padding=(0, 1, 1))
|
||||
|
||||
self.conv2 = CachedConv3d(out_channels, out_channels, kernel_size=3, padding=(0, 1, 1))
|
||||
|
||||
if in_channels != out_channels:
|
||||
if use_conv_shortcut:
|
||||
self.conv_shortcut = CachedConv3d(in_channels, out_channels, kernel_size=3, padding=(0, 1, 1))
|
||||
else:
|
||||
self.conv_shortcut = torch.nn.Conv3d(in_channels, out_channels, kernel_size=1)
|
||||
else:
|
||||
self.conv_shortcut = lambda x: x
|
||||
|
||||
|
||||
def forward(self, hidden_states, zq):
|
||||
residual = hidden_states
|
||||
|
||||
hidden_states = self.norm1(hidden_states, zq) if isinstance(self.norm1, CogVideoXSpatialNorm3D) else self.norm1(hidden_states)
|
||||
hidden_states = self.nonlinearity(hidden_states)
|
||||
hidden_states = self.conv1(hidden_states)
|
||||
|
||||
hidden_states = self.norm2(hidden_states, zq) if isinstance(self.norm2, CogVideoXSpatialNorm3D) else self.norm2(hidden_states)
|
||||
hidden_states = self.nonlinearity(hidden_states)
|
||||
hidden_states = self.conv2(hidden_states)
|
||||
|
||||
hidden_states = hidden_states + self.conv_shortcut(residual)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
|
||||
class CachedConv3d(torch.nn.Conv3d):
|
||||
def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0):
|
||||
super().__init__(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding)
|
||||
self.cached_tensor = None
|
||||
|
||||
|
||||
def clear_cache(self):
|
||||
self.cached_tensor = None
|
||||
|
||||
|
||||
def forward(self, input: torch.Tensor, use_cache = True) -> torch.Tensor:
|
||||
if use_cache:
|
||||
if self.cached_tensor is None:
|
||||
self.cached_tensor = torch.concat([input[:, :, :1]] * 2, dim=2)
|
||||
input = torch.concat([self.cached_tensor, input], dim=2)
|
||||
self.cached_tensor = input[:, :, -2:]
|
||||
return super().forward(input)
|
||||
|
||||
|
||||
|
||||
class CogVAEDecoder(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.scaling_factor = 0.7
|
||||
self.conv_in = CachedConv3d(16, 512, kernel_size=3, stride=1, padding=(0, 1, 1))
|
||||
|
||||
self.blocks = torch.nn.ModuleList([
|
||||
Resnet3DBlock(512, 512, 16, 32),
|
||||
Resnet3DBlock(512, 512, 16, 32),
|
||||
Resnet3DBlock(512, 512, 16, 32),
|
||||
Resnet3DBlock(512, 512, 16, 32),
|
||||
Resnet3DBlock(512, 512, 16, 32),
|
||||
Resnet3DBlock(512, 512, 16, 32),
|
||||
Upsample3D(512, 512, compress_time=True),
|
||||
Resnet3DBlock(512, 256, 16, 32),
|
||||
Resnet3DBlock(256, 256, 16, 32),
|
||||
Resnet3DBlock(256, 256, 16, 32),
|
||||
Resnet3DBlock(256, 256, 16, 32),
|
||||
Upsample3D(256, 256, compress_time=True),
|
||||
Resnet3DBlock(256, 256, 16, 32),
|
||||
Resnet3DBlock(256, 256, 16, 32),
|
||||
Resnet3DBlock(256, 256, 16, 32),
|
||||
Resnet3DBlock(256, 256, 16, 32),
|
||||
Upsample3D(256, 256, compress_time=False),
|
||||
Resnet3DBlock(256, 128, 16, 32),
|
||||
Resnet3DBlock(128, 128, 16, 32),
|
||||
Resnet3DBlock(128, 128, 16, 32),
|
||||
Resnet3DBlock(128, 128, 16, 32),
|
||||
])
|
||||
|
||||
self.norm_out = CogVideoXSpatialNorm3D(128, 16, 32)
|
||||
self.conv_act = torch.nn.SiLU()
|
||||
self.conv_out = CachedConv3d(128, 3, kernel_size=3, stride=1, padding=(0, 1, 1))
|
||||
|
||||
|
||||
def forward(self, sample):
|
||||
sample = sample / self.scaling_factor
|
||||
hidden_states = self.conv_in(sample)
|
||||
|
||||
for block in self.blocks:
|
||||
hidden_states = block(hidden_states, sample)
|
||||
|
||||
hidden_states = self.norm_out(hidden_states, sample)
|
||||
hidden_states = self.conv_act(hidden_states)
|
||||
hidden_states = self.conv_out(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
def decode_video(self, sample, tiled=True, tile_size=(60, 90), tile_stride=(30, 45), progress_bar=lambda x:x):
|
||||
if tiled:
|
||||
B, C, T, H, W = sample.shape
|
||||
return TileWorker2Dto3D().tiled_forward(
|
||||
forward_fn=lambda x: self.decode_small_video(x),
|
||||
model_input=sample,
|
||||
tile_size=tile_size, tile_stride=tile_stride,
|
||||
tile_device=sample.device, tile_dtype=sample.dtype,
|
||||
computation_device=sample.device, computation_dtype=sample.dtype,
|
||||
scales=(3/16, (T//2*8+T%2)/T, 8, 8),
|
||||
progress_bar=progress_bar
|
||||
)
|
||||
else:
|
||||
return self.decode_small_video(sample)
|
||||
|
||||
|
||||
def decode_small_video(self, sample):
|
||||
B, C, T, H, W = sample.shape
|
||||
computation_device = self.conv_in.weight.device
|
||||
computation_dtype = self.conv_in.weight.dtype
|
||||
value = []
|
||||
for i in range(T//2):
|
||||
tl = i*2 + T%2 - (T%2 and i==0)
|
||||
tr = i*2 + 2 + T%2
|
||||
model_input = sample[:, :, tl: tr, :, :].to(dtype=computation_dtype, device=computation_device)
|
||||
model_output = self.forward(model_input).to(dtype=sample.dtype, device=sample.device)
|
||||
value.append(model_output)
|
||||
value = torch.concat(value, dim=2)
|
||||
for name, module in self.named_modules():
|
||||
if isinstance(module, CachedConv3d):
|
||||
module.clear_cache()
|
||||
return value
|
||||
|
||||
|
||||
@staticmethod
|
||||
def state_dict_converter():
|
||||
return CogVAEDecoderStateDictConverter()
|
||||
|
||||
|
||||
|
||||
class CogVAEEncoder(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.scaling_factor = 0.7
|
||||
self.conv_in = CachedConv3d(3, 128, kernel_size=3, stride=1, padding=(0, 1, 1))
|
||||
|
||||
self.blocks = torch.nn.ModuleList([
|
||||
Resnet3DBlock(128, 128, None, 32),
|
||||
Resnet3DBlock(128, 128, None, 32),
|
||||
Resnet3DBlock(128, 128, None, 32),
|
||||
Downsample3D(128, 128, compress_time=True),
|
||||
Resnet3DBlock(128, 256, None, 32),
|
||||
Resnet3DBlock(256, 256, None, 32),
|
||||
Resnet3DBlock(256, 256, None, 32),
|
||||
Downsample3D(256, 256, compress_time=True),
|
||||
Resnet3DBlock(256, 256, None, 32),
|
||||
Resnet3DBlock(256, 256, None, 32),
|
||||
Resnet3DBlock(256, 256, None, 32),
|
||||
Downsample3D(256, 256, compress_time=False),
|
||||
Resnet3DBlock(256, 512, None, 32),
|
||||
Resnet3DBlock(512, 512, None, 32),
|
||||
Resnet3DBlock(512, 512, None, 32),
|
||||
Resnet3DBlock(512, 512, None, 32),
|
||||
Resnet3DBlock(512, 512, None, 32),
|
||||
])
|
||||
|
||||
self.norm_out = torch.nn.GroupNorm(32, 512, eps=1e-06, affine=True)
|
||||
self.conv_act = torch.nn.SiLU()
|
||||
self.conv_out = CachedConv3d(512, 32, kernel_size=3, stride=1, padding=(0, 1, 1))
|
||||
|
||||
|
||||
def forward(self, sample):
|
||||
hidden_states = self.conv_in(sample)
|
||||
|
||||
for block in self.blocks:
|
||||
hidden_states = block(hidden_states, sample)
|
||||
|
||||
hidden_states = self.norm_out(hidden_states)
|
||||
hidden_states = self.conv_act(hidden_states)
|
||||
hidden_states = self.conv_out(hidden_states)[:, :16]
|
||||
hidden_states = hidden_states * self.scaling_factor
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
def encode_video(self, sample, tiled=True, tile_size=(60, 90), tile_stride=(30, 45), progress_bar=lambda x:x):
|
||||
if tiled:
|
||||
B, C, T, H, W = sample.shape
|
||||
return TileWorker2Dto3D().tiled_forward(
|
||||
forward_fn=lambda x: self.encode_small_video(x),
|
||||
model_input=sample,
|
||||
tile_size=(i * 8 for i in tile_size), tile_stride=(i * 8 for i in tile_stride),
|
||||
tile_device=sample.device, tile_dtype=sample.dtype,
|
||||
computation_device=sample.device, computation_dtype=sample.dtype,
|
||||
scales=(16/3, (T//4+T%2)/T, 1/8, 1/8),
|
||||
progress_bar=progress_bar
|
||||
)
|
||||
else:
|
||||
return self.encode_small_video(sample)
|
||||
|
||||
|
||||
def encode_small_video(self, sample):
|
||||
B, C, T, H, W = sample.shape
|
||||
computation_device = self.conv_in.weight.device
|
||||
computation_dtype = self.conv_in.weight.dtype
|
||||
value = []
|
||||
for i in range(T//8):
|
||||
t = i*8 + T%2 - (T%2 and i==0)
|
||||
t_ = i*8 + 8 + T%2
|
||||
model_input = sample[:, :, t: t_, :, :].to(dtype=computation_dtype, device=computation_device)
|
||||
model_output = self.forward(model_input).to(dtype=sample.dtype, device=sample.device)
|
||||
value.append(model_output)
|
||||
value = torch.concat(value, dim=2)
|
||||
for name, module in self.named_modules():
|
||||
if isinstance(module, CachedConv3d):
|
||||
module.clear_cache()
|
||||
return value
|
||||
|
||||
|
||||
@staticmethod
|
||||
def state_dict_converter():
|
||||
return CogVAEEncoderStateDictConverter()
|
||||
|
||||
|
||||
|
||||
class CogVAEEncoderStateDictConverter:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
|
||||
def from_diffusers(self, state_dict):
|
||||
rename_dict = {
|
||||
"encoder.conv_in.conv.weight": "conv_in.weight",
|
||||
"encoder.conv_in.conv.bias": "conv_in.bias",
|
||||
"encoder.down_blocks.0.downsamplers.0.conv.weight": "blocks.3.conv.weight",
|
||||
"encoder.down_blocks.0.downsamplers.0.conv.bias": "blocks.3.conv.bias",
|
||||
"encoder.down_blocks.1.downsamplers.0.conv.weight": "blocks.7.conv.weight",
|
||||
"encoder.down_blocks.1.downsamplers.0.conv.bias": "blocks.7.conv.bias",
|
||||
"encoder.down_blocks.2.downsamplers.0.conv.weight": "blocks.11.conv.weight",
|
||||
"encoder.down_blocks.2.downsamplers.0.conv.bias": "blocks.11.conv.bias",
|
||||
"encoder.norm_out.weight": "norm_out.weight",
|
||||
"encoder.norm_out.bias": "norm_out.bias",
|
||||
"encoder.conv_out.conv.weight": "conv_out.weight",
|
||||
"encoder.conv_out.conv.bias": "conv_out.bias",
|
||||
}
|
||||
prefix_dict = {
|
||||
"encoder.down_blocks.0.resnets.0.": "blocks.0.",
|
||||
"encoder.down_blocks.0.resnets.1.": "blocks.1.",
|
||||
"encoder.down_blocks.0.resnets.2.": "blocks.2.",
|
||||
"encoder.down_blocks.1.resnets.0.": "blocks.4.",
|
||||
"encoder.down_blocks.1.resnets.1.": "blocks.5.",
|
||||
"encoder.down_blocks.1.resnets.2.": "blocks.6.",
|
||||
"encoder.down_blocks.2.resnets.0.": "blocks.8.",
|
||||
"encoder.down_blocks.2.resnets.1.": "blocks.9.",
|
||||
"encoder.down_blocks.2.resnets.2.": "blocks.10.",
|
||||
"encoder.down_blocks.3.resnets.0.": "blocks.12.",
|
||||
"encoder.down_blocks.3.resnets.1.": "blocks.13.",
|
||||
"encoder.down_blocks.3.resnets.2.": "blocks.14.",
|
||||
"encoder.mid_block.resnets.0.": "blocks.15.",
|
||||
"encoder.mid_block.resnets.1.": "blocks.16.",
|
||||
}
|
||||
suffix_dict = {
|
||||
"norm1.norm_layer.weight": "norm1.norm_layer.weight",
|
||||
"norm1.norm_layer.bias": "norm1.norm_layer.bias",
|
||||
"norm1.conv_y.conv.weight": "norm1.conv_y.weight",
|
||||
"norm1.conv_y.conv.bias": "norm1.conv_y.bias",
|
||||
"norm1.conv_b.conv.weight": "norm1.conv_b.weight",
|
||||
"norm1.conv_b.conv.bias": "norm1.conv_b.bias",
|
||||
"norm2.norm_layer.weight": "norm2.norm_layer.weight",
|
||||
"norm2.norm_layer.bias": "norm2.norm_layer.bias",
|
||||
"norm2.conv_y.conv.weight": "norm2.conv_y.weight",
|
||||
"norm2.conv_y.conv.bias": "norm2.conv_y.bias",
|
||||
"norm2.conv_b.conv.weight": "norm2.conv_b.weight",
|
||||
"norm2.conv_b.conv.bias": "norm2.conv_b.bias",
|
||||
"conv1.conv.weight": "conv1.weight",
|
||||
"conv1.conv.bias": "conv1.bias",
|
||||
"conv2.conv.weight": "conv2.weight",
|
||||
"conv2.conv.bias": "conv2.bias",
|
||||
"conv_shortcut.weight": "conv_shortcut.weight",
|
||||
"conv_shortcut.bias": "conv_shortcut.bias",
|
||||
"norm1.weight": "norm1.weight",
|
||||
"norm1.bias": "norm1.bias",
|
||||
"norm2.weight": "norm2.weight",
|
||||
"norm2.bias": "norm2.bias",
|
||||
}
|
||||
state_dict_ = {}
|
||||
for name, param in state_dict.items():
|
||||
if name in rename_dict:
|
||||
state_dict_[rename_dict[name]] = param
|
||||
else:
|
||||
for prefix in prefix_dict:
|
||||
if name.startswith(prefix):
|
||||
suffix = name[len(prefix):]
|
||||
state_dict_[prefix_dict[prefix] + suffix_dict[suffix]] = param
|
||||
return state_dict_
|
||||
|
||||
|
||||
def from_civitai(self, state_dict):
|
||||
return self.from_diffusers(state_dict)
|
||||
|
||||
|
||||
|
||||
class CogVAEDecoderStateDictConverter:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
|
||||
def from_diffusers(self, state_dict):
|
||||
rename_dict = {
|
||||
"decoder.conv_in.conv.weight": "conv_in.weight",
|
||||
"decoder.conv_in.conv.bias": "conv_in.bias",
|
||||
"decoder.up_blocks.0.upsamplers.0.conv.weight": "blocks.6.conv.weight",
|
||||
"decoder.up_blocks.0.upsamplers.0.conv.bias": "blocks.6.conv.bias",
|
||||
"decoder.up_blocks.1.upsamplers.0.conv.weight": "blocks.11.conv.weight",
|
||||
"decoder.up_blocks.1.upsamplers.0.conv.bias": "blocks.11.conv.bias",
|
||||
"decoder.up_blocks.2.upsamplers.0.conv.weight": "blocks.16.conv.weight",
|
||||
"decoder.up_blocks.2.upsamplers.0.conv.bias": "blocks.16.conv.bias",
|
||||
"decoder.norm_out.norm_layer.weight": "norm_out.norm_layer.weight",
|
||||
"decoder.norm_out.norm_layer.bias": "norm_out.norm_layer.bias",
|
||||
"decoder.norm_out.conv_y.conv.weight": "norm_out.conv_y.weight",
|
||||
"decoder.norm_out.conv_y.conv.bias": "norm_out.conv_y.bias",
|
||||
"decoder.norm_out.conv_b.conv.weight": "norm_out.conv_b.weight",
|
||||
"decoder.norm_out.conv_b.conv.bias": "norm_out.conv_b.bias",
|
||||
"decoder.conv_out.conv.weight": "conv_out.weight",
|
||||
"decoder.conv_out.conv.bias": "conv_out.bias"
|
||||
}
|
||||
prefix_dict = {
|
||||
"decoder.mid_block.resnets.0.": "blocks.0.",
|
||||
"decoder.mid_block.resnets.1.": "blocks.1.",
|
||||
"decoder.up_blocks.0.resnets.0.": "blocks.2.",
|
||||
"decoder.up_blocks.0.resnets.1.": "blocks.3.",
|
||||
"decoder.up_blocks.0.resnets.2.": "blocks.4.",
|
||||
"decoder.up_blocks.0.resnets.3.": "blocks.5.",
|
||||
"decoder.up_blocks.1.resnets.0.": "blocks.7.",
|
||||
"decoder.up_blocks.1.resnets.1.": "blocks.8.",
|
||||
"decoder.up_blocks.1.resnets.2.": "blocks.9.",
|
||||
"decoder.up_blocks.1.resnets.3.": "blocks.10.",
|
||||
"decoder.up_blocks.2.resnets.0.": "blocks.12.",
|
||||
"decoder.up_blocks.2.resnets.1.": "blocks.13.",
|
||||
"decoder.up_blocks.2.resnets.2.": "blocks.14.",
|
||||
"decoder.up_blocks.2.resnets.3.": "blocks.15.",
|
||||
"decoder.up_blocks.3.resnets.0.": "blocks.17.",
|
||||
"decoder.up_blocks.3.resnets.1.": "blocks.18.",
|
||||
"decoder.up_blocks.3.resnets.2.": "blocks.19.",
|
||||
"decoder.up_blocks.3.resnets.3.": "blocks.20.",
|
||||
}
|
||||
suffix_dict = {
|
||||
"norm1.norm_layer.weight": "norm1.norm_layer.weight",
|
||||
"norm1.norm_layer.bias": "norm1.norm_layer.bias",
|
||||
"norm1.conv_y.conv.weight": "norm1.conv_y.weight",
|
||||
"norm1.conv_y.conv.bias": "norm1.conv_y.bias",
|
||||
"norm1.conv_b.conv.weight": "norm1.conv_b.weight",
|
||||
"norm1.conv_b.conv.bias": "norm1.conv_b.bias",
|
||||
"norm2.norm_layer.weight": "norm2.norm_layer.weight",
|
||||
"norm2.norm_layer.bias": "norm2.norm_layer.bias",
|
||||
"norm2.conv_y.conv.weight": "norm2.conv_y.weight",
|
||||
"norm2.conv_y.conv.bias": "norm2.conv_y.bias",
|
||||
"norm2.conv_b.conv.weight": "norm2.conv_b.weight",
|
||||
"norm2.conv_b.conv.bias": "norm2.conv_b.bias",
|
||||
"conv1.conv.weight": "conv1.weight",
|
||||
"conv1.conv.bias": "conv1.bias",
|
||||
"conv2.conv.weight": "conv2.weight",
|
||||
"conv2.conv.bias": "conv2.bias",
|
||||
"conv_shortcut.weight": "conv_shortcut.weight",
|
||||
"conv_shortcut.bias": "conv_shortcut.bias",
|
||||
}
|
||||
state_dict_ = {}
|
||||
for name, param in state_dict.items():
|
||||
if name in rename_dict:
|
||||
state_dict_[rename_dict[name]] = param
|
||||
else:
|
||||
for prefix in prefix_dict:
|
||||
if name.startswith(prefix):
|
||||
suffix = name[len(prefix):]
|
||||
state_dict_[prefix_dict[prefix] + suffix_dict[suffix]] = param
|
||||
return state_dict_
|
||||
|
||||
|
||||
def from_civitai(self, state_dict):
|
||||
return self.from_diffusers(state_dict)
|
||||
|
||||
@@ -1,96 +0,0 @@
|
||||
from transformers import DINOv3ViTModel, DINOv3ViTImageProcessorFast
|
||||
from transformers.models.dinov3_vit.modeling_dinov3_vit import DINOv3ViTConfig
|
||||
import torch
|
||||
|
||||
from ..core.device.npu_compatible_device import get_device_type
|
||||
|
||||
|
||||
class DINOv3ImageEncoder(DINOv3ViTModel):
|
||||
def __init__(self):
|
||||
config = DINOv3ViTConfig(
|
||||
architectures = [
|
||||
"DINOv3ViTModel"
|
||||
],
|
||||
attention_dropout = 0.0,
|
||||
drop_path_rate = 0.0,
|
||||
dtype = "float32",
|
||||
hidden_act = "silu",
|
||||
hidden_size = 4096,
|
||||
image_size = 224,
|
||||
initializer_range = 0.02,
|
||||
intermediate_size = 8192,
|
||||
key_bias = False,
|
||||
layer_norm_eps = 1e-05,
|
||||
layerscale_value = 1.0,
|
||||
mlp_bias = True,
|
||||
model_type = "dinov3_vit",
|
||||
num_attention_heads = 32,
|
||||
num_channels = 3,
|
||||
num_hidden_layers = 40,
|
||||
num_register_tokens = 4,
|
||||
patch_size = 16,
|
||||
pos_embed_jitter = None,
|
||||
pos_embed_rescale = 2.0,
|
||||
pos_embed_shift = None,
|
||||
proj_bias = True,
|
||||
query_bias = False,
|
||||
rope_theta = 100.0,
|
||||
transformers_version = "4.56.1",
|
||||
use_gated_mlp = True,
|
||||
value_bias = False
|
||||
)
|
||||
super().__init__(config)
|
||||
self.processor = DINOv3ViTImageProcessorFast(
|
||||
crop_size = None,
|
||||
data_format = "channels_first",
|
||||
default_to_square = True,
|
||||
device = None,
|
||||
disable_grouping = None,
|
||||
do_center_crop = None,
|
||||
do_convert_rgb = None,
|
||||
do_normalize = True,
|
||||
do_rescale = True,
|
||||
do_resize = True,
|
||||
image_mean = [
|
||||
0.485,
|
||||
0.456,
|
||||
0.406
|
||||
],
|
||||
image_processor_type = "DINOv3ViTImageProcessorFast",
|
||||
image_std = [
|
||||
0.229,
|
||||
0.224,
|
||||
0.225
|
||||
],
|
||||
input_data_format = None,
|
||||
resample = 2,
|
||||
rescale_factor = 0.00392156862745098,
|
||||
return_tensors = None,
|
||||
size = {
|
||||
"height": 224,
|
||||
"width": 224
|
||||
}
|
||||
)
|
||||
|
||||
def forward(self, image, torch_dtype=torch.bfloat16, device=get_device_type()):
|
||||
inputs = self.processor(images=image, return_tensors="pt")
|
||||
pixel_values = inputs["pixel_values"].to(dtype=torch_dtype, device=device)
|
||||
bool_masked_pos = None
|
||||
head_mask = None
|
||||
|
||||
pixel_values = pixel_values.to(torch_dtype)
|
||||
hidden_states = self.embeddings(pixel_values, bool_masked_pos=bool_masked_pos)
|
||||
position_embeddings = self.rope_embeddings(pixel_values)
|
||||
|
||||
for i, layer_module in enumerate(self.layer):
|
||||
layer_head_mask = head_mask[i] if head_mask is not None else None
|
||||
hidden_states = layer_module(
|
||||
hidden_states,
|
||||
attention_mask=layer_head_mask,
|
||||
position_embeddings=position_embeddings,
|
||||
)
|
||||
|
||||
sequence_output = self.norm(hidden_states)
|
||||
pooled_output = sequence_output[:, 0, :]
|
||||
|
||||
return pooled_output
|
||||
66
diffsynth/models/downloader.py
Normal file
66
diffsynth/models/downloader.py
Normal file
@@ -0,0 +1,66 @@
|
||||
from huggingface_hub import hf_hub_download
|
||||
from modelscope import snapshot_download
|
||||
import os, shutil
|
||||
from typing_extensions import Literal, TypeAlias
|
||||
from typing import List
|
||||
from ..configs.model_config import preset_models_on_huggingface, preset_models_on_modelscope, Preset_model_id
|
||||
|
||||
|
||||
def download_from_modelscope(model_id, origin_file_path, local_dir):
|
||||
os.makedirs(local_dir, exist_ok=True)
|
||||
if os.path.basename(origin_file_path) in os.listdir(local_dir):
|
||||
print(f" {os.path.basename(origin_file_path)} has been already in {local_dir}.")
|
||||
return
|
||||
else:
|
||||
print(f" Start downloading {os.path.join(local_dir, os.path.basename(origin_file_path))}")
|
||||
snapshot_download(model_id, allow_file_pattern=origin_file_path, local_dir=local_dir)
|
||||
downloaded_file_path = os.path.join(local_dir, origin_file_path)
|
||||
target_file_path = os.path.join(local_dir, os.path.split(origin_file_path)[-1])
|
||||
if downloaded_file_path != target_file_path:
|
||||
shutil.move(downloaded_file_path, target_file_path)
|
||||
shutil.rmtree(os.path.join(local_dir, origin_file_path.split("/")[0]))
|
||||
|
||||
|
||||
def download_from_huggingface(model_id, origin_file_path, local_dir):
|
||||
os.makedirs(local_dir, exist_ok=True)
|
||||
if os.path.basename(origin_file_path) in os.listdir(local_dir):
|
||||
print(f" {os.path.basename(origin_file_path)} has been already in {local_dir}.")
|
||||
return
|
||||
else:
|
||||
print(f" Start downloading {os.path.join(local_dir, os.path.basename(origin_file_path))}")
|
||||
hf_hub_download(model_id, origin_file_path, local_dir=local_dir)
|
||||
|
||||
|
||||
Preset_model_website: TypeAlias = Literal[
|
||||
"HuggingFace",
|
||||
"ModelScope",
|
||||
]
|
||||
website_to_preset_models = {
|
||||
"HuggingFace": preset_models_on_huggingface,
|
||||
"ModelScope": preset_models_on_modelscope,
|
||||
}
|
||||
website_to_download_fn = {
|
||||
"HuggingFace": download_from_huggingface,
|
||||
"ModelScope": download_from_modelscope,
|
||||
}
|
||||
|
||||
|
||||
def download_models(
|
||||
model_id_list: List[Preset_model_id] = [],
|
||||
downloading_priority: List[Preset_model_website] = ["ModelScope", "HuggingFace"],
|
||||
):
|
||||
print(f"Downloading models: {model_id_list}")
|
||||
downloaded_files = []
|
||||
for model_id in model_id_list:
|
||||
for website in downloading_priority:
|
||||
if model_id in website_to_preset_models[website]:
|
||||
for model_id, origin_file_path, local_dir in website_to_preset_models[website][model_id]:
|
||||
# Check if the file is downloaded.
|
||||
file_to_download = os.path.join(local_dir, os.path.basename(origin_file_path))
|
||||
if file_to_download in downloaded_files:
|
||||
continue
|
||||
# Download
|
||||
website_to_download_fn[website](model_id, origin_file_path, local_dir)
|
||||
if os.path.basename(origin_file_path) in os.listdir(local_dir):
|
||||
downloaded_files.append(file_to_download)
|
||||
return downloaded_files
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,58 +0,0 @@
|
||||
from transformers import Mistral3ForConditionalGeneration, Mistral3Config
|
||||
|
||||
|
||||
class Flux2TextEncoder(Mistral3ForConditionalGeneration):
|
||||
def __init__(self):
|
||||
config = Mistral3Config(**{
|
||||
"architectures": [
|
||||
"Mistral3ForConditionalGeneration"
|
||||
],
|
||||
"dtype": "bfloat16",
|
||||
"image_token_index": 10,
|
||||
"model_type": "mistral3",
|
||||
"multimodal_projector_bias": False,
|
||||
"projector_hidden_act": "gelu",
|
||||
"spatial_merge_size": 2,
|
||||
"text_config": {
|
||||
"attention_dropout": 0.0,
|
||||
"dtype": "bfloat16",
|
||||
"head_dim": 128,
|
||||
"hidden_act": "silu",
|
||||
"hidden_size": 5120,
|
||||
"initializer_range": 0.02,
|
||||
"intermediate_size": 32768,
|
||||
"max_position_embeddings": 131072,
|
||||
"model_type": "mistral",
|
||||
"num_attention_heads": 32,
|
||||
"num_hidden_layers": 40,
|
||||
"num_key_value_heads": 8,
|
||||
"rms_norm_eps": 1e-05,
|
||||
"rope_theta": 1000000000.0,
|
||||
"sliding_window": None,
|
||||
"use_cache": True,
|
||||
"vocab_size": 131072
|
||||
},
|
||||
"transformers_version": "4.57.1",
|
||||
"vision_config": {
|
||||
"attention_dropout": 0.0,
|
||||
"dtype": "bfloat16",
|
||||
"head_dim": 64,
|
||||
"hidden_act": "silu",
|
||||
"hidden_size": 1024,
|
||||
"image_size": 1540,
|
||||
"initializer_range": 0.02,
|
||||
"intermediate_size": 4096,
|
||||
"model_type": "pixtral",
|
||||
"num_attention_heads": 16,
|
||||
"num_channels": 3,
|
||||
"num_hidden_layers": 24,
|
||||
"patch_size": 14,
|
||||
"rope_theta": 10000.0
|
||||
},
|
||||
"vision_feature_layer": -1
|
||||
})
|
||||
super().__init__(config)
|
||||
|
||||
def forward(self, input_ids = None, pixel_values = None, attention_mask = None, position_ids = None, past_key_values = None, inputs_embeds = None, labels = None, use_cache = None, output_attentions = None, output_hidden_states = None, return_dict = None, cache_position = None, logits_to_keep = 0, image_sizes = None, **kwargs):
|
||||
return super().forward(input_ids, pixel_values, attention_mask, position_ids, past_key_values, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict, cache_position, logits_to_keep, image_sizes, **kwargs)
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,384 +0,0 @@
|
||||
import torch
|
||||
from einops import rearrange, repeat
|
||||
from .flux_dit import RoPEEmbedding, TimestepEmbeddings, FluxJointTransformerBlock, FluxSingleTransformerBlock, RMSNorm
|
||||
# from .utils import hash_state_dict_keys, init_weights_on_device
|
||||
from contextlib import contextmanager
|
||||
|
||||
def hash_state_dict_keys(state_dict, with_shape=True):
|
||||
keys_str = convert_state_dict_keys_to_single_str(state_dict, with_shape=with_shape)
|
||||
keys_str = keys_str.encode(encoding="UTF-8")
|
||||
return hashlib.md5(keys_str).hexdigest()
|
||||
|
||||
@contextmanager
|
||||
def init_weights_on_device(device = torch.device("meta"), include_buffers :bool = False):
|
||||
|
||||
old_register_parameter = torch.nn.Module.register_parameter
|
||||
if include_buffers:
|
||||
old_register_buffer = torch.nn.Module.register_buffer
|
||||
|
||||
def register_empty_parameter(module, name, param):
|
||||
old_register_parameter(module, name, param)
|
||||
if param is not None:
|
||||
param_cls = type(module._parameters[name])
|
||||
kwargs = module._parameters[name].__dict__
|
||||
kwargs["requires_grad"] = param.requires_grad
|
||||
module._parameters[name] = param_cls(module._parameters[name].to(device), **kwargs)
|
||||
|
||||
def register_empty_buffer(module, name, buffer, persistent=True):
|
||||
old_register_buffer(module, name, buffer, persistent=persistent)
|
||||
if buffer is not None:
|
||||
module._buffers[name] = module._buffers[name].to(device)
|
||||
|
||||
def patch_tensor_constructor(fn):
|
||||
def wrapper(*args, **kwargs):
|
||||
kwargs["device"] = device
|
||||
return fn(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
if include_buffers:
|
||||
tensor_constructors_to_patch = {
|
||||
torch_function_name: getattr(torch, torch_function_name)
|
||||
for torch_function_name in ["empty", "zeros", "ones", "full"]
|
||||
}
|
||||
else:
|
||||
tensor_constructors_to_patch = {}
|
||||
|
||||
try:
|
||||
torch.nn.Module.register_parameter = register_empty_parameter
|
||||
if include_buffers:
|
||||
torch.nn.Module.register_buffer = register_empty_buffer
|
||||
for torch_function_name in tensor_constructors_to_patch.keys():
|
||||
setattr(torch, torch_function_name, patch_tensor_constructor(getattr(torch, torch_function_name)))
|
||||
yield
|
||||
finally:
|
||||
torch.nn.Module.register_parameter = old_register_parameter
|
||||
if include_buffers:
|
||||
torch.nn.Module.register_buffer = old_register_buffer
|
||||
for torch_function_name, old_torch_function in tensor_constructors_to_patch.items():
|
||||
setattr(torch, torch_function_name, old_torch_function)
|
||||
|
||||
class FluxControlNet(torch.nn.Module):
|
||||
def __init__(self, disable_guidance_embedder=False, num_joint_blocks=5, num_single_blocks=10, num_mode=0, mode_dict={}, additional_input_dim=0):
|
||||
super().__init__()
|
||||
self.pos_embedder = RoPEEmbedding(3072, 10000, [16, 56, 56])
|
||||
self.time_embedder = TimestepEmbeddings(256, 3072)
|
||||
self.guidance_embedder = None if disable_guidance_embedder else TimestepEmbeddings(256, 3072)
|
||||
self.pooled_text_embedder = torch.nn.Sequential(torch.nn.Linear(768, 3072), torch.nn.SiLU(), torch.nn.Linear(3072, 3072))
|
||||
self.context_embedder = torch.nn.Linear(4096, 3072)
|
||||
self.x_embedder = torch.nn.Linear(64, 3072)
|
||||
|
||||
self.blocks = torch.nn.ModuleList([FluxJointTransformerBlock(3072, 24) for _ in range(num_joint_blocks)])
|
||||
self.single_blocks = torch.nn.ModuleList([FluxSingleTransformerBlock(3072, 24) for _ in range(num_single_blocks)])
|
||||
|
||||
self.controlnet_blocks = torch.nn.ModuleList([torch.nn.Linear(3072, 3072) for _ in range(num_joint_blocks)])
|
||||
self.controlnet_single_blocks = torch.nn.ModuleList([torch.nn.Linear(3072, 3072) for _ in range(num_single_blocks)])
|
||||
|
||||
self.mode_dict = mode_dict
|
||||
self.controlnet_mode_embedder = torch.nn.Embedding(num_mode, 3072) if len(mode_dict) > 0 else None
|
||||
self.controlnet_x_embedder = torch.nn.Linear(64 + additional_input_dim, 3072)
|
||||
|
||||
|
||||
def prepare_image_ids(self, latents):
|
||||
batch_size, _, height, width = latents.shape
|
||||
latent_image_ids = torch.zeros(height // 2, width // 2, 3)
|
||||
latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None]
|
||||
latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :]
|
||||
|
||||
latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
|
||||
|
||||
latent_image_ids = latent_image_ids[None, :].repeat(batch_size, 1, 1, 1)
|
||||
latent_image_ids = latent_image_ids.reshape(
|
||||
batch_size, latent_image_id_height * latent_image_id_width, latent_image_id_channels
|
||||
)
|
||||
latent_image_ids = latent_image_ids.to(device=latents.device, dtype=latents.dtype)
|
||||
|
||||
return latent_image_ids
|
||||
|
||||
|
||||
def patchify(self, hidden_states):
|
||||
hidden_states = rearrange(hidden_states, "B C (H P) (W Q) -> B (H W) (C P Q)", P=2, Q=2)
|
||||
return hidden_states
|
||||
|
||||
|
||||
def align_res_stack_to_original_blocks(self, res_stack, num_blocks, hidden_states):
|
||||
if len(res_stack) == 0:
|
||||
return [torch.zeros_like(hidden_states)] * num_blocks
|
||||
interval = (num_blocks + len(res_stack) - 1) // len(res_stack)
|
||||
aligned_res_stack = [res_stack[block_id // interval] for block_id in range(num_blocks)]
|
||||
return aligned_res_stack
|
||||
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states,
|
||||
controlnet_conditioning,
|
||||
timestep, prompt_emb, pooled_prompt_emb, guidance, text_ids, image_ids=None,
|
||||
processor_id=None,
|
||||
tiled=False, tile_size=128, tile_stride=64,
|
||||
**kwargs
|
||||
):
|
||||
if image_ids is None:
|
||||
image_ids = self.prepare_image_ids(hidden_states)
|
||||
|
||||
conditioning = self.time_embedder(timestep, hidden_states.dtype) + self.pooled_text_embedder(pooled_prompt_emb)
|
||||
if self.guidance_embedder is not None:
|
||||
guidance = guidance * 1000
|
||||
conditioning = conditioning + self.guidance_embedder(guidance, hidden_states.dtype)
|
||||
prompt_emb = self.context_embedder(prompt_emb)
|
||||
if self.controlnet_mode_embedder is not None: # Different from FluxDiT
|
||||
processor_id = torch.tensor([self.mode_dict[processor_id]], dtype=torch.int)
|
||||
processor_id = repeat(processor_id, "D -> B D", B=1).to(text_ids.device)
|
||||
prompt_emb = torch.concat([self.controlnet_mode_embedder(processor_id), prompt_emb], dim=1)
|
||||
text_ids = torch.cat([text_ids[:, :1], text_ids], dim=1)
|
||||
image_rotary_emb = self.pos_embedder(torch.cat((text_ids, image_ids), dim=1))
|
||||
|
||||
hidden_states = self.patchify(hidden_states)
|
||||
hidden_states = self.x_embedder(hidden_states)
|
||||
controlnet_conditioning = self.patchify(controlnet_conditioning) # Different from FluxDiT
|
||||
hidden_states = hidden_states + self.controlnet_x_embedder(controlnet_conditioning) # Different from FluxDiT
|
||||
|
||||
controlnet_res_stack = []
|
||||
for block, controlnet_block in zip(self.blocks, self.controlnet_blocks):
|
||||
hidden_states, prompt_emb = block(hidden_states, prompt_emb, conditioning, image_rotary_emb)
|
||||
controlnet_res_stack.append(controlnet_block(hidden_states))
|
||||
|
||||
controlnet_single_res_stack = []
|
||||
hidden_states = torch.cat([prompt_emb, hidden_states], dim=1)
|
||||
for block, controlnet_block in zip(self.single_blocks, self.controlnet_single_blocks):
|
||||
hidden_states, prompt_emb = block(hidden_states, prompt_emb, conditioning, image_rotary_emb)
|
||||
controlnet_single_res_stack.append(controlnet_block(hidden_states[:, prompt_emb.shape[1]:]))
|
||||
|
||||
controlnet_res_stack = self.align_res_stack_to_original_blocks(controlnet_res_stack, 19, hidden_states[:, prompt_emb.shape[1]:])
|
||||
controlnet_single_res_stack = self.align_res_stack_to_original_blocks(controlnet_single_res_stack, 38, hidden_states[:, prompt_emb.shape[1]:])
|
||||
|
||||
return controlnet_res_stack, controlnet_single_res_stack
|
||||
|
||||
|
||||
# @staticmethod
|
||||
# def state_dict_converter():
|
||||
# return FluxControlNetStateDictConverter()
|
||||
|
||||
def quantize(self):
|
||||
def cast_to(weight, dtype=None, device=None, copy=False):
|
||||
if device is None or weight.device == device:
|
||||
if not copy:
|
||||
if dtype is None or weight.dtype == dtype:
|
||||
return weight
|
||||
return weight.to(dtype=dtype, copy=copy)
|
||||
|
||||
r = torch.empty_like(weight, dtype=dtype, device=device)
|
||||
r.copy_(weight)
|
||||
return r
|
||||
|
||||
def cast_weight(s, input=None, dtype=None, device=None):
|
||||
if input is not None:
|
||||
if dtype is None:
|
||||
dtype = input.dtype
|
||||
if device is None:
|
||||
device = input.device
|
||||
weight = cast_to(s.weight, dtype, device)
|
||||
return weight
|
||||
|
||||
def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None):
|
||||
if input is not None:
|
||||
if dtype is None:
|
||||
dtype = input.dtype
|
||||
if bias_dtype is None:
|
||||
bias_dtype = dtype
|
||||
if device is None:
|
||||
device = input.device
|
||||
bias = None
|
||||
weight = cast_to(s.weight, dtype, device)
|
||||
bias = cast_to(s.bias, bias_dtype, device)
|
||||
return weight, bias
|
||||
|
||||
class quantized_layer:
|
||||
class QLinear(torch.nn.Linear):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def forward(self,input,**kwargs):
|
||||
weight,bias= cast_bias_weight(self,input)
|
||||
return torch.nn.functional.linear(input,weight,bias)
|
||||
|
||||
class QRMSNorm(torch.nn.Module):
|
||||
def __init__(self, module):
|
||||
super().__init__()
|
||||
self.module = module
|
||||
|
||||
def forward(self,hidden_states,**kwargs):
|
||||
weight= cast_weight(self.module,hidden_states)
|
||||
input_dtype = hidden_states.dtype
|
||||
variance = hidden_states.to(torch.float32).square().mean(-1, keepdim=True)
|
||||
hidden_states = hidden_states * torch.rsqrt(variance + self.module.eps)
|
||||
hidden_states = hidden_states.to(input_dtype) * weight
|
||||
return hidden_states
|
||||
|
||||
class QEmbedding(torch.nn.Embedding):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def forward(self,input,**kwargs):
|
||||
weight= cast_weight(self,input)
|
||||
return torch.nn.functional.embedding(
|
||||
input, weight, self.padding_idx, self.max_norm,
|
||||
self.norm_type, self.scale_grad_by_freq, self.sparse)
|
||||
|
||||
def replace_layer(model):
|
||||
for name, module in model.named_children():
|
||||
if isinstance(module,quantized_layer.QRMSNorm):
|
||||
continue
|
||||
if isinstance(module, torch.nn.Linear):
|
||||
with init_weights_on_device():
|
||||
new_layer = quantized_layer.QLinear(module.in_features,module.out_features)
|
||||
new_layer.weight = module.weight
|
||||
if module.bias is not None:
|
||||
new_layer.bias = module.bias
|
||||
setattr(model, name, new_layer)
|
||||
elif isinstance(module, RMSNorm):
|
||||
if hasattr(module,"quantized"):
|
||||
continue
|
||||
module.quantized= True
|
||||
new_layer = quantized_layer.QRMSNorm(module)
|
||||
setattr(model, name, new_layer)
|
||||
elif isinstance(module,torch.nn.Embedding):
|
||||
rows, cols = module.weight.shape
|
||||
new_layer = quantized_layer.QEmbedding(
|
||||
num_embeddings=rows,
|
||||
embedding_dim=cols,
|
||||
_weight=module.weight,
|
||||
# _freeze=module.freeze,
|
||||
padding_idx=module.padding_idx,
|
||||
max_norm=module.max_norm,
|
||||
norm_type=module.norm_type,
|
||||
scale_grad_by_freq=module.scale_grad_by_freq,
|
||||
sparse=module.sparse)
|
||||
setattr(model, name, new_layer)
|
||||
else:
|
||||
replace_layer(module)
|
||||
|
||||
replace_layer(self)
|
||||
|
||||
|
||||
|
||||
class FluxControlNetStateDictConverter:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def from_diffusers(self, state_dict):
|
||||
hash_value = hash_state_dict_keys(state_dict)
|
||||
global_rename_dict = {
|
||||
"context_embedder": "context_embedder",
|
||||
"x_embedder": "x_embedder",
|
||||
"time_text_embed.timestep_embedder.linear_1": "time_embedder.timestep_embedder.0",
|
||||
"time_text_embed.timestep_embedder.linear_2": "time_embedder.timestep_embedder.2",
|
||||
"time_text_embed.guidance_embedder.linear_1": "guidance_embedder.timestep_embedder.0",
|
||||
"time_text_embed.guidance_embedder.linear_2": "guidance_embedder.timestep_embedder.2",
|
||||
"time_text_embed.text_embedder.linear_1": "pooled_text_embedder.0",
|
||||
"time_text_embed.text_embedder.linear_2": "pooled_text_embedder.2",
|
||||
"norm_out.linear": "final_norm_out.linear",
|
||||
"proj_out": "final_proj_out",
|
||||
}
|
||||
rename_dict = {
|
||||
"proj_out": "proj_out",
|
||||
"norm1.linear": "norm1_a.linear",
|
||||
"norm1_context.linear": "norm1_b.linear",
|
||||
"attn.to_q": "attn.a_to_q",
|
||||
"attn.to_k": "attn.a_to_k",
|
||||
"attn.to_v": "attn.a_to_v",
|
||||
"attn.to_out.0": "attn.a_to_out",
|
||||
"attn.add_q_proj": "attn.b_to_q",
|
||||
"attn.add_k_proj": "attn.b_to_k",
|
||||
"attn.add_v_proj": "attn.b_to_v",
|
||||
"attn.to_add_out": "attn.b_to_out",
|
||||
"ff.net.0.proj": "ff_a.0",
|
||||
"ff.net.2": "ff_a.2",
|
||||
"ff_context.net.0.proj": "ff_b.0",
|
||||
"ff_context.net.2": "ff_b.2",
|
||||
"attn.norm_q": "attn.norm_q_a",
|
||||
"attn.norm_k": "attn.norm_k_a",
|
||||
"attn.norm_added_q": "attn.norm_q_b",
|
||||
"attn.norm_added_k": "attn.norm_k_b",
|
||||
}
|
||||
rename_dict_single = {
|
||||
"attn.to_q": "a_to_q",
|
||||
"attn.to_k": "a_to_k",
|
||||
"attn.to_v": "a_to_v",
|
||||
"attn.norm_q": "norm_q_a",
|
||||
"attn.norm_k": "norm_k_a",
|
||||
"norm.linear": "norm.linear",
|
||||
"proj_mlp": "proj_in_besides_attn",
|
||||
"proj_out": "proj_out",
|
||||
}
|
||||
state_dict_ = {}
|
||||
for name, param in state_dict.items():
|
||||
if name.endswith(".weight") or name.endswith(".bias"):
|
||||
suffix = ".weight" if name.endswith(".weight") else ".bias"
|
||||
prefix = name[:-len(suffix)]
|
||||
if prefix in global_rename_dict:
|
||||
state_dict_[global_rename_dict[prefix] + suffix] = param
|
||||
elif prefix.startswith("transformer_blocks."):
|
||||
names = prefix.split(".")
|
||||
names[0] = "blocks"
|
||||
middle = ".".join(names[2:])
|
||||
if middle in rename_dict:
|
||||
name_ = ".".join(names[:2] + [rename_dict[middle]] + [suffix[1:]])
|
||||
state_dict_[name_] = param
|
||||
elif prefix.startswith("single_transformer_blocks."):
|
||||
names = prefix.split(".")
|
||||
names[0] = "single_blocks"
|
||||
middle = ".".join(names[2:])
|
||||
if middle in rename_dict_single:
|
||||
name_ = ".".join(names[:2] + [rename_dict_single[middle]] + [suffix[1:]])
|
||||
state_dict_[name_] = param
|
||||
else:
|
||||
state_dict_[name] = param
|
||||
else:
|
||||
state_dict_[name] = param
|
||||
for name in list(state_dict_.keys()):
|
||||
if ".proj_in_besides_attn." in name:
|
||||
name_ = name.replace(".proj_in_besides_attn.", ".to_qkv_mlp.")
|
||||
param = torch.concat([
|
||||
state_dict_[name.replace(".proj_in_besides_attn.", f".a_to_q.")],
|
||||
state_dict_[name.replace(".proj_in_besides_attn.", f".a_to_k.")],
|
||||
state_dict_[name.replace(".proj_in_besides_attn.", f".a_to_v.")],
|
||||
state_dict_[name],
|
||||
], dim=0)
|
||||
state_dict_[name_] = param
|
||||
state_dict_.pop(name.replace(".proj_in_besides_attn.", f".a_to_q."))
|
||||
state_dict_.pop(name.replace(".proj_in_besides_attn.", f".a_to_k."))
|
||||
state_dict_.pop(name.replace(".proj_in_besides_attn.", f".a_to_v."))
|
||||
state_dict_.pop(name)
|
||||
for name in list(state_dict_.keys()):
|
||||
for component in ["a", "b"]:
|
||||
if f".{component}_to_q." in name:
|
||||
name_ = name.replace(f".{component}_to_q.", f".{component}_to_qkv.")
|
||||
param = torch.concat([
|
||||
state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_q.")],
|
||||
state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_k.")],
|
||||
state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_v.")],
|
||||
], dim=0)
|
||||
state_dict_[name_] = param
|
||||
state_dict_.pop(name.replace(f".{component}_to_q.", f".{component}_to_q."))
|
||||
state_dict_.pop(name.replace(f".{component}_to_q.", f".{component}_to_k."))
|
||||
state_dict_.pop(name.replace(f".{component}_to_q.", f".{component}_to_v."))
|
||||
if hash_value == "78d18b9101345ff695f312e7e62538c0":
|
||||
extra_kwargs = {"num_mode": 10, "mode_dict": {"canny": 0, "tile": 1, "depth": 2, "blur": 3, "pose": 4, "gray": 5, "lq": 6}}
|
||||
elif hash_value == "b001c89139b5f053c715fe772362dd2a":
|
||||
extra_kwargs = {"num_single_blocks": 0}
|
||||
elif hash_value == "52357cb26250681367488a8954c271e8":
|
||||
extra_kwargs = {"num_joint_blocks": 6, "num_single_blocks": 0, "additional_input_dim": 4}
|
||||
elif hash_value == "0cfd1740758423a2a854d67c136d1e8c":
|
||||
extra_kwargs = {"num_joint_blocks": 4, "num_single_blocks": 1}
|
||||
elif hash_value == "7f9583eb8ba86642abb9a21a4b2c9e16":
|
||||
extra_kwargs = {"num_joint_blocks": 4, "num_single_blocks": 10}
|
||||
elif hash_value == "43ad5aaa27dd4ee01b832ed16773fa52":
|
||||
extra_kwargs = {"num_joint_blocks": 6, "num_single_blocks": 0}
|
||||
else:
|
||||
extra_kwargs = {}
|
||||
return state_dict_, extra_kwargs
|
||||
|
||||
|
||||
def from_civitai(self, state_dict):
|
||||
return self.from_diffusers(state_dict)
|
||||
@@ -1,15 +1,9 @@
|
||||
import torch
|
||||
from .general_modules import TimestepEmbeddings, AdaLayerNorm, RMSNorm
|
||||
from .sd3_dit import TimestepEmbeddings, AdaLayerNorm
|
||||
from einops import rearrange
|
||||
from .tiler import TileWorker
|
||||
|
||||
|
||||
def interact_with_ipadapter(hidden_states, q, ip_k, ip_v, scale=1.0):
|
||||
batch_size, num_tokens = hidden_states.shape[0:2]
|
||||
ip_hidden_states = torch.nn.functional.scaled_dot_product_attention(q, ip_k, ip_v)
|
||||
ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, num_tokens, -1)
|
||||
hidden_states = hidden_states + scale * ip_hidden_states
|
||||
return hidden_states
|
||||
|
||||
|
||||
class RoPEEmbedding(torch.nn.Module):
|
||||
def __init__(self, dim, theta, axes_dim):
|
||||
@@ -39,8 +33,23 @@ class RoPEEmbedding(torch.nn.Module):
|
||||
n_axes = ids.shape[-1]
|
||||
emb = torch.cat([self.rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)], dim=-3)
|
||||
return emb.unsqueeze(1)
|
||||
|
||||
|
||||
|
||||
class RMSNorm(torch.nn.Module):
|
||||
def __init__(self, dim, eps):
|
||||
super().__init__()
|
||||
self.weight = torch.nn.Parameter(torch.ones((dim,)))
|
||||
self.eps = eps
|
||||
|
||||
def forward(self, hidden_states):
|
||||
input_dtype = hidden_states.dtype
|
||||
variance = hidden_states.to(torch.float32).square().mean(-1, keepdim=True)
|
||||
hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
|
||||
hidden_states = hidden_states.to(input_dtype) * self.weight
|
||||
return hidden_states
|
||||
|
||||
|
||||
|
||||
class FluxJointAttention(torch.nn.Module):
|
||||
def __init__(self, dim_a, dim_b, num_heads, head_dim, only_out_a=False):
|
||||
@@ -69,7 +78,8 @@ class FluxJointAttention(torch.nn.Module):
|
||||
xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
|
||||
return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
|
||||
|
||||
def forward(self, hidden_states_a, hidden_states_b, image_rotary_emb, attn_mask=None, ipadapter_kwargs_list=None):
|
||||
|
||||
def forward(self, hidden_states_a, hidden_states_b, image_rotary_emb):
|
||||
batch_size = hidden_states_a.shape[0]
|
||||
|
||||
# Part A
|
||||
@@ -90,19 +100,17 @@ class FluxJointAttention(torch.nn.Module):
|
||||
|
||||
q, k = self.apply_rope(q, k, image_rotary_emb)
|
||||
|
||||
hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
|
||||
hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v)
|
||||
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim)
|
||||
hidden_states = hidden_states.to(q.dtype)
|
||||
hidden_states_b, hidden_states_a = hidden_states[:, :hidden_states_b.shape[1]], hidden_states[:, hidden_states_b.shape[1]:]
|
||||
if ipadapter_kwargs_list is not None:
|
||||
hidden_states_a = interact_with_ipadapter(hidden_states_a, q_a, **ipadapter_kwargs_list)
|
||||
hidden_states_a = self.a_to_out(hidden_states_a)
|
||||
if self.only_out_a:
|
||||
return hidden_states_a
|
||||
else:
|
||||
hidden_states_b = self.b_to_out(hidden_states_b)
|
||||
return hidden_states_a, hidden_states_b
|
||||
|
||||
|
||||
|
||||
|
||||
class FluxJointTransformerBlock(torch.nn.Module):
|
||||
@@ -128,12 +136,12 @@ class FluxJointTransformerBlock(torch.nn.Module):
|
||||
)
|
||||
|
||||
|
||||
def forward(self, hidden_states_a, hidden_states_b, temb, image_rotary_emb, attn_mask=None, ipadapter_kwargs_list=None):
|
||||
def forward(self, hidden_states_a, hidden_states_b, temb, image_rotary_emb):
|
||||
norm_hidden_states_a, gate_msa_a, shift_mlp_a, scale_mlp_a, gate_mlp_a = self.norm1_a(hidden_states_a, emb=temb)
|
||||
norm_hidden_states_b, gate_msa_b, shift_mlp_b, scale_mlp_b, gate_mlp_b = self.norm1_b(hidden_states_b, emb=temb)
|
||||
|
||||
# Attention
|
||||
attn_output_a, attn_output_b = self.attn(norm_hidden_states_a, norm_hidden_states_b, image_rotary_emb, attn_mask, ipadapter_kwargs_list)
|
||||
attn_output_a, attn_output_b = self.attn(norm_hidden_states_a, norm_hidden_states_b, image_rotary_emb)
|
||||
|
||||
# Part A
|
||||
hidden_states_a = hidden_states_a + gate_msa_a * attn_output_a
|
||||
@@ -146,7 +154,7 @@ class FluxJointTransformerBlock(torch.nn.Module):
|
||||
hidden_states_b = hidden_states_b + gate_mlp_b * self.ff_b(norm_hidden_states_b)
|
||||
|
||||
return hidden_states_a, hidden_states_b
|
||||
|
||||
|
||||
|
||||
|
||||
class FluxSingleAttention(torch.nn.Module):
|
||||
@@ -183,7 +191,7 @@ class FluxSingleAttention(torch.nn.Module):
|
||||
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim)
|
||||
hidden_states = hidden_states.to(q.dtype)
|
||||
return hidden_states
|
||||
|
||||
|
||||
|
||||
|
||||
class AdaLayerNormSingle(torch.nn.Module):
|
||||
@@ -199,7 +207,7 @@ class AdaLayerNormSingle(torch.nn.Module):
|
||||
shift_msa, scale_msa, gate_msa = emb.chunk(3, dim=1)
|
||||
x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
|
||||
return x, gate_msa
|
||||
|
||||
|
||||
|
||||
|
||||
class FluxSingleTransformerBlock(torch.nn.Module):
|
||||
@@ -224,8 +232,8 @@ class FluxSingleTransformerBlock(torch.nn.Module):
|
||||
xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
|
||||
return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
|
||||
|
||||
|
||||
def process_attention(self, hidden_states, image_rotary_emb, attn_mask=None, ipadapter_kwargs_list=None):
|
||||
|
||||
def process_attention(self, hidden_states, image_rotary_emb):
|
||||
batch_size = hidden_states.shape[0]
|
||||
|
||||
qkv = hidden_states.view(batch_size, -1, 3 * self.num_heads, self.head_dim).transpose(1, 2)
|
||||
@@ -234,29 +242,27 @@ class FluxSingleTransformerBlock(torch.nn.Module):
|
||||
|
||||
q, k = self.apply_rope(q, k, image_rotary_emb)
|
||||
|
||||
hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
|
||||
hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v)
|
||||
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim)
|
||||
hidden_states = hidden_states.to(q.dtype)
|
||||
if ipadapter_kwargs_list is not None:
|
||||
hidden_states = interact_with_ipadapter(hidden_states, q, **ipadapter_kwargs_list)
|
||||
return hidden_states
|
||||
|
||||
|
||||
def forward(self, hidden_states_a, hidden_states_b, temb, image_rotary_emb, attn_mask=None, ipadapter_kwargs_list=None):
|
||||
def forward(self, hidden_states_a, hidden_states_b, temb, image_rotary_emb):
|
||||
residual = hidden_states_a
|
||||
norm_hidden_states, gate = self.norm(hidden_states_a, emb=temb)
|
||||
hidden_states_a = self.to_qkv_mlp(norm_hidden_states)
|
||||
attn_output, mlp_hidden_states = hidden_states_a[:, :, :self.dim * 3], hidden_states_a[:, :, self.dim * 3:]
|
||||
|
||||
attn_output = self.process_attention(attn_output, image_rotary_emb, attn_mask, ipadapter_kwargs_list)
|
||||
attn_output = self.process_attention(attn_output, image_rotary_emb)
|
||||
mlp_hidden_states = torch.nn.functional.gelu(mlp_hidden_states, approximate="tanh")
|
||||
|
||||
hidden_states_a = torch.cat([attn_output, mlp_hidden_states], dim=2)
|
||||
hidden_states_a = gate.unsqueeze(1) * self.proj_out(hidden_states_a)
|
||||
hidden_states_a = residual + hidden_states_a
|
||||
|
||||
|
||||
return hidden_states_a, hidden_states_b
|
||||
|
||||
|
||||
|
||||
|
||||
class AdaLayerNormContinuous(torch.nn.Module):
|
||||
@@ -268,29 +274,27 @@ class AdaLayerNormContinuous(torch.nn.Module):
|
||||
|
||||
def forward(self, x, conditioning):
|
||||
emb = self.linear(self.silu(conditioning))
|
||||
shift, scale = torch.chunk(emb, 2, dim=1)
|
||||
scale, shift = torch.chunk(emb, 2, dim=1)
|
||||
x = self.norm(x) * (1 + scale)[:, None] + shift[:, None]
|
||||
return x
|
||||
|
||||
|
||||
|
||||
class FluxDiT(torch.nn.Module):
|
||||
def __init__(self, disable_guidance_embedder=False, input_dim=64, num_blocks=19):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.pos_embedder = RoPEEmbedding(3072, 10000, [16, 56, 56])
|
||||
self.time_embedder = TimestepEmbeddings(256, 3072)
|
||||
self.guidance_embedder = None if disable_guidance_embedder else TimestepEmbeddings(256, 3072)
|
||||
self.guidance_embedder = TimestepEmbeddings(256, 3072)
|
||||
self.pooled_text_embedder = torch.nn.Sequential(torch.nn.Linear(768, 3072), torch.nn.SiLU(), torch.nn.Linear(3072, 3072))
|
||||
self.context_embedder = torch.nn.Linear(4096, 3072)
|
||||
self.x_embedder = torch.nn.Linear(input_dim, 3072)
|
||||
self.x_embedder = torch.nn.Linear(64, 3072)
|
||||
|
||||
self.blocks = torch.nn.ModuleList([FluxJointTransformerBlock(3072, 24) for _ in range(num_blocks)])
|
||||
self.blocks = torch.nn.ModuleList([FluxJointTransformerBlock(3072, 24) for _ in range(19)])
|
||||
self.single_blocks = torch.nn.ModuleList([FluxSingleTransformerBlock(3072, 24) for _ in range(38)])
|
||||
|
||||
self.final_norm_out = AdaLayerNormContinuous(3072)
|
||||
self.final_proj_out = torch.nn.Linear(3072, 64)
|
||||
|
||||
self.input_dim = input_dim
|
||||
|
||||
|
||||
def patchify(self, hidden_states):
|
||||
@@ -301,7 +305,7 @@ class FluxDiT(torch.nn.Module):
|
||||
def unpatchify(self, hidden_states, height, width):
|
||||
hidden_states = rearrange(hidden_states, "B (H W) (C P Q) -> B C (H P) (W Q)", P=2, Q=2, H=height//2, W=width//2)
|
||||
return hidden_states
|
||||
|
||||
|
||||
|
||||
def prepare_image_ids(self, latents):
|
||||
batch_size, _, height, width = latents.shape
|
||||
@@ -318,78 +322,272 @@ class FluxDiT(torch.nn.Module):
|
||||
latent_image_ids = latent_image_ids.to(device=latents.device, dtype=latents.dtype)
|
||||
|
||||
return latent_image_ids
|
||||
|
||||
|
||||
|
||||
def construct_mask(self, entity_masks, prompt_seq_len, image_seq_len):
|
||||
N = len(entity_masks)
|
||||
batch_size = entity_masks[0].shape[0]
|
||||
total_seq_len = N * prompt_seq_len + image_seq_len
|
||||
patched_masks = [self.patchify(entity_masks[i]) for i in range(N)]
|
||||
attention_mask = torch.ones((batch_size, total_seq_len, total_seq_len), dtype=torch.bool).to(device=entity_masks[0].device)
|
||||
|
||||
image_start = N * prompt_seq_len
|
||||
image_end = N * prompt_seq_len + image_seq_len
|
||||
# prompt-image mask
|
||||
for i in range(N):
|
||||
prompt_start = i * prompt_seq_len
|
||||
prompt_end = (i + 1) * prompt_seq_len
|
||||
image_mask = torch.sum(patched_masks[i], dim=-1) > 0
|
||||
image_mask = image_mask.unsqueeze(1).repeat(1, prompt_seq_len, 1)
|
||||
# prompt update with image
|
||||
attention_mask[:, prompt_start:prompt_end, image_start:image_end] = image_mask
|
||||
# image update with prompt
|
||||
attention_mask[:, image_start:image_end, prompt_start:prompt_end] = image_mask.transpose(1, 2)
|
||||
# prompt-prompt mask
|
||||
for i in range(N):
|
||||
for j in range(N):
|
||||
if i != j:
|
||||
prompt_start_i = i * prompt_seq_len
|
||||
prompt_end_i = (i + 1) * prompt_seq_len
|
||||
prompt_start_j = j * prompt_seq_len
|
||||
prompt_end_j = (j + 1) * prompt_seq_len
|
||||
attention_mask[:, prompt_start_i:prompt_end_i, prompt_start_j:prompt_end_j] = False
|
||||
|
||||
attention_mask = attention_mask.float()
|
||||
attention_mask[attention_mask == 0] = float('-inf')
|
||||
attention_mask[attention_mask == 1] = 0
|
||||
return attention_mask
|
||||
|
||||
|
||||
def process_entity_masks(self, hidden_states, prompt_emb, entity_prompt_emb, entity_masks, text_ids, image_ids, repeat_dim):
|
||||
max_masks = 0
|
||||
attention_mask = None
|
||||
prompt_embs = [prompt_emb]
|
||||
if entity_masks is not None:
|
||||
# entity_masks
|
||||
batch_size, max_masks = entity_masks.shape[0], entity_masks.shape[1]
|
||||
entity_masks = entity_masks.repeat(1, 1, repeat_dim, 1, 1)
|
||||
entity_masks = [entity_masks[:, i, None].squeeze(1) for i in range(max_masks)]
|
||||
# global mask
|
||||
global_mask = torch.ones_like(entity_masks[0]).to(device=hidden_states.device, dtype=hidden_states.dtype)
|
||||
entity_masks = entity_masks + [global_mask] # append global to last
|
||||
# attention mask
|
||||
attention_mask = self.construct_mask(entity_masks, prompt_emb.shape[1], hidden_states.shape[1])
|
||||
attention_mask = attention_mask.to(device=hidden_states.device, dtype=hidden_states.dtype)
|
||||
attention_mask = attention_mask.unsqueeze(1)
|
||||
# embds: n_masks * b * seq * d
|
||||
local_embs = [entity_prompt_emb[:, i, None].squeeze(1) for i in range(max_masks)]
|
||||
prompt_embs = local_embs + prompt_embs # append global to last
|
||||
prompt_embs = [self.context_embedder(prompt_emb) for prompt_emb in prompt_embs]
|
||||
prompt_emb = torch.cat(prompt_embs, dim=1)
|
||||
|
||||
# positional embedding
|
||||
text_ids = torch.cat([text_ids] * (max_masks + 1), dim=1)
|
||||
image_rotary_emb = self.pos_embedder(torch.cat((text_ids, image_ids), dim=1))
|
||||
return prompt_emb, image_rotary_emb, attention_mask
|
||||
def tiled_forward(
|
||||
self,
|
||||
hidden_states,
|
||||
timestep, prompt_emb, pooled_prompt_emb, guidance, text_ids,
|
||||
tile_size=128, tile_stride=64,
|
||||
**kwargs
|
||||
):
|
||||
# Due to the global positional embedding, we cannot implement layer-wise tiled forward.
|
||||
hidden_states = TileWorker().tiled_forward(
|
||||
lambda x: self.forward(x, timestep, prompt_emb, pooled_prompt_emb, guidance, text_ids, image_ids=None),
|
||||
hidden_states,
|
||||
tile_size,
|
||||
tile_stride,
|
||||
tile_device=hidden_states.device,
|
||||
tile_dtype=hidden_states.dtype
|
||||
)
|
||||
return hidden_states
|
||||
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states,
|
||||
timestep, prompt_emb, pooled_prompt_emb, guidance, text_ids, image_ids=None,
|
||||
tiled=False, tile_size=128, tile_stride=64, entity_prompt_emb=None, entity_masks=None,
|
||||
tiled=False, tile_size=128, tile_stride=64,
|
||||
use_gradient_checkpointing=False,
|
||||
**kwargs
|
||||
):
|
||||
# (Deprecated) The real forward is in `pipelines.flux_image`.
|
||||
return None
|
||||
if tiled:
|
||||
return self.tiled_forward(
|
||||
hidden_states,
|
||||
timestep, prompt_emb, pooled_prompt_emb, guidance, text_ids,
|
||||
tile_size=tile_size, tile_stride=tile_stride,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
if image_ids is None:
|
||||
image_ids = self.prepare_image_ids(hidden_states)
|
||||
|
||||
conditioning = self.time_embedder(timestep, hidden_states.dtype)\
|
||||
+ self.guidance_embedder(guidance, hidden_states.dtype)\
|
||||
+ self.pooled_text_embedder(pooled_prompt_emb)
|
||||
prompt_emb = self.context_embedder(prompt_emb)
|
||||
image_rotary_emb = self.pos_embedder(torch.cat((text_ids, image_ids), dim=1))
|
||||
|
||||
height, width = hidden_states.shape[-2:]
|
||||
hidden_states = self.patchify(hidden_states)
|
||||
hidden_states = self.x_embedder(hidden_states)
|
||||
|
||||
def create_custom_forward(module):
|
||||
def custom_forward(*inputs):
|
||||
return module(*inputs)
|
||||
return custom_forward
|
||||
|
||||
for block in self.blocks:
|
||||
if self.training and use_gradient_checkpointing:
|
||||
hidden_states, prompt_emb = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(block),
|
||||
hidden_states, prompt_emb, conditioning, image_rotary_emb,
|
||||
use_reentrant=False,
|
||||
)
|
||||
else:
|
||||
hidden_states, prompt_emb = block(hidden_states, prompt_emb, conditioning, image_rotary_emb)
|
||||
|
||||
hidden_states = torch.cat([prompt_emb, hidden_states], dim=1)
|
||||
for block in self.single_blocks:
|
||||
if self.training and use_gradient_checkpointing:
|
||||
hidden_states, prompt_emb = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(block),
|
||||
hidden_states, prompt_emb, conditioning, image_rotary_emb,
|
||||
use_reentrant=False,
|
||||
)
|
||||
else:
|
||||
hidden_states, prompt_emb = block(hidden_states, prompt_emb, conditioning, image_rotary_emb)
|
||||
hidden_states = hidden_states[:, prompt_emb.shape[1]:]
|
||||
|
||||
hidden_states = self.final_norm_out(hidden_states, conditioning)
|
||||
hidden_states = self.final_proj_out(hidden_states)
|
||||
hidden_states = self.unpatchify(hidden_states, height, width)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
@staticmethod
|
||||
def state_dict_converter():
|
||||
return FluxDiTStateDictConverter()
|
||||
|
||||
|
||||
|
||||
class FluxDiTStateDictConverter:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def from_diffusers(self, state_dict):
|
||||
global_rename_dict = {
|
||||
"context_embedder": "context_embedder",
|
||||
"x_embedder": "x_embedder",
|
||||
"time_text_embed.timestep_embedder.linear_1": "time_embedder.timestep_embedder.0",
|
||||
"time_text_embed.timestep_embedder.linear_2": "time_embedder.timestep_embedder.2",
|
||||
"time_text_embed.guidance_embedder.linear_1": "guidance_embedder.timestep_embedder.0",
|
||||
"time_text_embed.guidance_embedder.linear_2": "guidance_embedder.timestep_embedder.2",
|
||||
"time_text_embed.text_embedder.linear_1": "pooled_text_embedder.0",
|
||||
"time_text_embed.text_embedder.linear_2": "pooled_text_embedder.2",
|
||||
"norm_out.linear": "final_norm_out.linear",
|
||||
"proj_out": "final_proj_out",
|
||||
}
|
||||
rename_dict = {
|
||||
"proj_out": "proj_out",
|
||||
"norm1.linear": "norm1_a.linear",
|
||||
"norm1_context.linear": "norm1_b.linear",
|
||||
"attn.to_q": "attn.a_to_q",
|
||||
"attn.to_k": "attn.a_to_k",
|
||||
"attn.to_v": "attn.a_to_v",
|
||||
"attn.to_out.0": "attn.a_to_out",
|
||||
"attn.add_q_proj": "attn.b_to_q",
|
||||
"attn.add_k_proj": "attn.b_to_k",
|
||||
"attn.add_v_proj": "attn.b_to_v",
|
||||
"attn.to_add_out": "attn.b_to_out",
|
||||
"ff.net.0.proj": "ff_a.0",
|
||||
"ff.net.2": "ff_a.2",
|
||||
"ff_context.net.0.proj": "ff_b.0",
|
||||
"ff_context.net.2": "ff_b.2",
|
||||
"attn.norm_q": "attn.norm_q_a",
|
||||
"attn.norm_k": "attn.norm_k_a",
|
||||
"attn.norm_added_q": "attn.norm_q_b",
|
||||
"attn.norm_added_k": "attn.norm_k_b",
|
||||
}
|
||||
rename_dict_single = {
|
||||
"attn.to_q": "a_to_q",
|
||||
"attn.to_k": "a_to_k",
|
||||
"attn.to_v": "a_to_v",
|
||||
"attn.norm_q": "norm_q_a",
|
||||
"attn.norm_k": "norm_k_a",
|
||||
"norm.linear": "norm.linear",
|
||||
"proj_mlp": "proj_in_besides_attn",
|
||||
"proj_out": "proj_out",
|
||||
}
|
||||
state_dict_ = {}
|
||||
for name, param in state_dict.items():
|
||||
if name.endswith(".weight") or name.endswith(".bias"):
|
||||
suffix = ".weight" if name.endswith(".weight") else ".bias"
|
||||
prefix = name[:-len(suffix)]
|
||||
if prefix in global_rename_dict:
|
||||
state_dict_[global_rename_dict[prefix] + suffix] = param
|
||||
elif prefix.startswith("transformer_blocks."):
|
||||
names = prefix.split(".")
|
||||
names[0] = "blocks"
|
||||
middle = ".".join(names[2:])
|
||||
if middle in rename_dict:
|
||||
name_ = ".".join(names[:2] + [rename_dict[middle]] + [suffix[1:]])
|
||||
state_dict_[name_] = param
|
||||
elif prefix.startswith("single_transformer_blocks."):
|
||||
names = prefix.split(".")
|
||||
names[0] = "single_blocks"
|
||||
middle = ".".join(names[2:])
|
||||
if middle in rename_dict_single:
|
||||
name_ = ".".join(names[:2] + [rename_dict_single[middle]] + [suffix[1:]])
|
||||
state_dict_[name_] = param
|
||||
else:
|
||||
pass
|
||||
else:
|
||||
pass
|
||||
for name in list(state_dict_.keys()):
|
||||
if ".proj_in_besides_attn." in name:
|
||||
name_ = name.replace(".proj_in_besides_attn.", ".to_qkv_mlp.")
|
||||
param = torch.concat([
|
||||
state_dict_[name.replace(".proj_in_besides_attn.", f".a_to_q.")],
|
||||
state_dict_[name.replace(".proj_in_besides_attn.", f".a_to_k.")],
|
||||
state_dict_[name.replace(".proj_in_besides_attn.", f".a_to_v.")],
|
||||
state_dict_[name],
|
||||
], dim=0)
|
||||
state_dict_[name_] = param
|
||||
state_dict_.pop(name.replace(".proj_in_besides_attn.", f".a_to_q."))
|
||||
state_dict_.pop(name.replace(".proj_in_besides_attn.", f".a_to_k."))
|
||||
state_dict_.pop(name.replace(".proj_in_besides_attn.", f".a_to_v."))
|
||||
state_dict_.pop(name)
|
||||
for name in list(state_dict_.keys()):
|
||||
for component in ["a", "b"]:
|
||||
if f".{component}_to_q." in name:
|
||||
name_ = name.replace(f".{component}_to_q.", f".{component}_to_qkv.")
|
||||
param = torch.concat([
|
||||
state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_q.")],
|
||||
state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_k.")],
|
||||
state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_v.")],
|
||||
], dim=0)
|
||||
state_dict_[name_] = param
|
||||
state_dict_.pop(name.replace(f".{component}_to_q.", f".{component}_to_q."))
|
||||
state_dict_.pop(name.replace(f".{component}_to_q.", f".{component}_to_k."))
|
||||
state_dict_.pop(name.replace(f".{component}_to_q.", f".{component}_to_v."))
|
||||
return state_dict_
|
||||
|
||||
def from_civitai(self, state_dict):
|
||||
rename_dict = {
|
||||
"time_in.in_layer.bias": "time_embedder.timestep_embedder.0.bias",
|
||||
"time_in.in_layer.weight": "time_embedder.timestep_embedder.0.weight",
|
||||
"time_in.out_layer.bias": "time_embedder.timestep_embedder.2.bias",
|
||||
"time_in.out_layer.weight": "time_embedder.timestep_embedder.2.weight",
|
||||
"txt_in.bias": "context_embedder.bias",
|
||||
"txt_in.weight": "context_embedder.weight",
|
||||
"vector_in.in_layer.bias": "pooled_text_embedder.0.bias",
|
||||
"vector_in.in_layer.weight": "pooled_text_embedder.0.weight",
|
||||
"vector_in.out_layer.bias": "pooled_text_embedder.2.bias",
|
||||
"vector_in.out_layer.weight": "pooled_text_embedder.2.weight",
|
||||
"final_layer.linear.bias": "final_proj_out.bias",
|
||||
"final_layer.linear.weight": "final_proj_out.weight",
|
||||
"guidance_in.in_layer.bias": "guidance_embedder.timestep_embedder.0.bias",
|
||||
"guidance_in.in_layer.weight": "guidance_embedder.timestep_embedder.0.weight",
|
||||
"guidance_in.out_layer.bias": "guidance_embedder.timestep_embedder.2.bias",
|
||||
"guidance_in.out_layer.weight": "guidance_embedder.timestep_embedder.2.weight",
|
||||
"img_in.bias": "x_embedder.bias",
|
||||
"img_in.weight": "x_embedder.weight",
|
||||
"final_layer.adaLN_modulation.1.weight": "final_norm_out.linear.weight",
|
||||
"final_layer.adaLN_modulation.1.bias": "final_norm_out.linear.bias",
|
||||
}
|
||||
suffix_rename_dict = {
|
||||
"img_attn.norm.key_norm.scale": "attn.norm_k_a.weight",
|
||||
"img_attn.norm.query_norm.scale": "attn.norm_q_a.weight",
|
||||
"img_attn.proj.bias": "attn.a_to_out.bias",
|
||||
"img_attn.proj.weight": "attn.a_to_out.weight",
|
||||
"img_attn.qkv.bias": "attn.a_to_qkv.bias",
|
||||
"img_attn.qkv.weight": "attn.a_to_qkv.weight",
|
||||
"img_mlp.0.bias": "ff_a.0.bias",
|
||||
"img_mlp.0.weight": "ff_a.0.weight",
|
||||
"img_mlp.2.bias": "ff_a.2.bias",
|
||||
"img_mlp.2.weight": "ff_a.2.weight",
|
||||
"img_mod.lin.bias": "norm1_a.linear.bias",
|
||||
"img_mod.lin.weight": "norm1_a.linear.weight",
|
||||
"txt_attn.norm.key_norm.scale": "attn.norm_k_b.weight",
|
||||
"txt_attn.norm.query_norm.scale": "attn.norm_q_b.weight",
|
||||
"txt_attn.proj.bias": "attn.b_to_out.bias",
|
||||
"txt_attn.proj.weight": "attn.b_to_out.weight",
|
||||
"txt_attn.qkv.bias": "attn.b_to_qkv.bias",
|
||||
"txt_attn.qkv.weight": "attn.b_to_qkv.weight",
|
||||
"txt_mlp.0.bias": "ff_b.0.bias",
|
||||
"txt_mlp.0.weight": "ff_b.0.weight",
|
||||
"txt_mlp.2.bias": "ff_b.2.bias",
|
||||
"txt_mlp.2.weight": "ff_b.2.weight",
|
||||
"txt_mod.lin.bias": "norm1_b.linear.bias",
|
||||
"txt_mod.lin.weight": "norm1_b.linear.weight",
|
||||
|
||||
"linear1.bias": "to_qkv_mlp.bias",
|
||||
"linear1.weight": "to_qkv_mlp.weight",
|
||||
"linear2.bias": "proj_out.bias",
|
||||
"linear2.weight": "proj_out.weight",
|
||||
"modulation.lin.bias": "norm.linear.bias",
|
||||
"modulation.lin.weight": "norm.linear.weight",
|
||||
"norm.key_norm.scale": "norm_k_a.weight",
|
||||
"norm.query_norm.scale": "norm_q_a.weight",
|
||||
}
|
||||
state_dict_ = {}
|
||||
for name, param in state_dict.items():
|
||||
names = name.split(".")
|
||||
if name in rename_dict:
|
||||
rename = rename_dict[name]
|
||||
if name.startswith("final_layer.adaLN_modulation.1."):
|
||||
param = torch.concat([param[3072:], param[:3072]], dim=0)
|
||||
state_dict_[rename] = param
|
||||
elif names[0] == "double_blocks":
|
||||
rename = f"blocks.{names[1]}." + suffix_rename_dict[".".join(names[2:])]
|
||||
state_dict_[rename] = param
|
||||
elif names[0] == "single_blocks":
|
||||
if ".".join(names[2:]) in suffix_rename_dict:
|
||||
rename = f"single_blocks.{names[1]}." + suffix_rename_dict[".".join(names[2:])]
|
||||
state_dict_[rename] = param
|
||||
else:
|
||||
pass
|
||||
return state_dict_
|
||||
|
||||
@@ -1,129 +0,0 @@
|
||||
import math
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
# FFN
|
||||
def FeedForward(dim, mult=4):
|
||||
inner_dim = int(dim * mult)
|
||||
return nn.Sequential(
|
||||
nn.LayerNorm(dim),
|
||||
nn.Linear(dim, inner_dim, bias=False),
|
||||
nn.GELU(),
|
||||
nn.Linear(inner_dim, dim, bias=False),
|
||||
)
|
||||
|
||||
|
||||
def reshape_tensor(x, heads):
|
||||
bs, length, width = x.shape
|
||||
#(bs, length, width) --> (bs, length, n_heads, dim_per_head)
|
||||
x = x.view(bs, length, heads, -1)
|
||||
# (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head)
|
||||
x = x.transpose(1, 2)
|
||||
# (bs, n_heads, length, dim_per_head) --> (bs*n_heads, length, dim_per_head)
|
||||
x = x.reshape(bs, heads, length, -1)
|
||||
return x
|
||||
|
||||
|
||||
class PerceiverAttention(nn.Module):
|
||||
|
||||
def __init__(self, *, dim, dim_head=64, heads=8):
|
||||
super().__init__()
|
||||
self.scale = dim_head**-0.5
|
||||
self.dim_head = dim_head
|
||||
self.heads = heads
|
||||
inner_dim = dim_head * heads
|
||||
|
||||
self.norm1 = nn.LayerNorm(dim)
|
||||
self.norm2 = nn.LayerNorm(dim)
|
||||
|
||||
self.to_q = nn.Linear(dim, inner_dim, bias=False)
|
||||
self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
|
||||
self.to_out = nn.Linear(inner_dim, dim, bias=False)
|
||||
|
||||
def forward(self, x, latents):
|
||||
"""
|
||||
Args:
|
||||
x (torch.Tensor): image features
|
||||
shape (b, n1, D)
|
||||
latent (torch.Tensor): latent features
|
||||
shape (b, n2, D)
|
||||
"""
|
||||
x = self.norm1(x)
|
||||
latents = self.norm2(latents)
|
||||
|
||||
b, l, _ = latents.shape
|
||||
|
||||
q = self.to_q(latents)
|
||||
kv_input = torch.cat((x, latents), dim=-2)
|
||||
k, v = self.to_kv(kv_input).chunk(2, dim=-1)
|
||||
|
||||
q = reshape_tensor(q, self.heads)
|
||||
k = reshape_tensor(k, self.heads)
|
||||
v = reshape_tensor(v, self.heads)
|
||||
|
||||
# attention
|
||||
scale = 1 / math.sqrt(math.sqrt(self.dim_head))
|
||||
weight = (q * scale) @ (k * scale).transpose(-2, -1) # More stable with f16 than dividing afterwards
|
||||
weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
|
||||
out = weight @ v
|
||||
|
||||
out = out.permute(0, 2, 1, 3).reshape(b, l, -1)
|
||||
|
||||
return self.to_out(out)
|
||||
|
||||
|
||||
class InfiniteYouImageProjector(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim=1280,
|
||||
depth=4,
|
||||
dim_head=64,
|
||||
heads=20,
|
||||
num_queries=8,
|
||||
embedding_dim=512,
|
||||
output_dim=4096,
|
||||
ff_mult=4,
|
||||
):
|
||||
super().__init__()
|
||||
self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim**0.5)
|
||||
self.proj_in = nn.Linear(embedding_dim, dim)
|
||||
|
||||
self.proj_out = nn.Linear(dim, output_dim)
|
||||
self.norm_out = nn.LayerNorm(output_dim)
|
||||
|
||||
self.layers = nn.ModuleList([])
|
||||
for _ in range(depth):
|
||||
self.layers.append(
|
||||
nn.ModuleList([
|
||||
PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
|
||||
FeedForward(dim=dim, mult=ff_mult),
|
||||
]))
|
||||
|
||||
def forward(self, x):
|
||||
|
||||
latents = self.latents.repeat(x.size(0), 1, 1)
|
||||
latents = latents.to(dtype=x.dtype, device=x.device)
|
||||
|
||||
x = self.proj_in(x)
|
||||
|
||||
for attn, ff in self.layers:
|
||||
latents = attn(x, latents) + latents
|
||||
latents = ff(latents) + latents
|
||||
|
||||
latents = self.proj_out(latents)
|
||||
return self.norm_out(latents)
|
||||
|
||||
@staticmethod
|
||||
def state_dict_converter():
|
||||
return FluxInfiniteYouImageProjectorStateDictConverter()
|
||||
|
||||
|
||||
class FluxInfiniteYouImageProjectorStateDictConverter:
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def from_diffusers(self, state_dict):
|
||||
return state_dict['image_proj']
|
||||
@@ -1,110 +0,0 @@
|
||||
from .general_modules import RMSNorm
|
||||
from transformers import SiglipVisionModel, SiglipVisionConfig
|
||||
import torch
|
||||
|
||||
|
||||
class SiglipVisionModelSO400M(SiglipVisionModel):
|
||||
def __init__(self):
|
||||
config = SiglipVisionConfig(
|
||||
hidden_size=1152,
|
||||
image_size=384,
|
||||
intermediate_size=4304,
|
||||
model_type="siglip_vision_model",
|
||||
num_attention_heads=16,
|
||||
num_hidden_layers=27,
|
||||
patch_size=14,
|
||||
architectures=["SiglipModel"],
|
||||
initializer_factor=1.0,
|
||||
torch_dtype="float32",
|
||||
transformers_version="4.37.0.dev0"
|
||||
)
|
||||
super().__init__(config)
|
||||
|
||||
class MLPProjModel(torch.nn.Module):
|
||||
def __init__(self, cross_attention_dim=768, id_embeddings_dim=512, num_tokens=4):
|
||||
super().__init__()
|
||||
|
||||
self.cross_attention_dim = cross_attention_dim
|
||||
self.num_tokens = num_tokens
|
||||
|
||||
self.proj = torch.nn.Sequential(
|
||||
torch.nn.Linear(id_embeddings_dim, id_embeddings_dim*2),
|
||||
torch.nn.GELU(),
|
||||
torch.nn.Linear(id_embeddings_dim*2, cross_attention_dim*num_tokens),
|
||||
)
|
||||
self.norm = torch.nn.LayerNorm(cross_attention_dim)
|
||||
|
||||
def forward(self, id_embeds):
|
||||
x = self.proj(id_embeds)
|
||||
x = x.reshape(-1, self.num_tokens, self.cross_attention_dim)
|
||||
x = self.norm(x)
|
||||
return x
|
||||
|
||||
class IpAdapterModule(torch.nn.Module):
|
||||
def __init__(self, num_attention_heads, attention_head_dim, input_dim):
|
||||
super().__init__()
|
||||
self.num_heads = num_attention_heads
|
||||
self.head_dim = attention_head_dim
|
||||
output_dim = num_attention_heads * attention_head_dim
|
||||
self.to_k_ip = torch.nn.Linear(input_dim, output_dim, bias=False)
|
||||
self.to_v_ip = torch.nn.Linear(input_dim, output_dim, bias=False)
|
||||
self.norm_added_k = RMSNorm(attention_head_dim, eps=1e-5, elementwise_affine=False)
|
||||
|
||||
|
||||
def forward(self, hidden_states):
|
||||
batch_size = hidden_states.shape[0]
|
||||
# ip_k
|
||||
ip_k = self.to_k_ip(hidden_states)
|
||||
ip_k = ip_k.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
ip_k = self.norm_added_k(ip_k)
|
||||
# ip_v
|
||||
ip_v = self.to_v_ip(hidden_states)
|
||||
ip_v = ip_v.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
return ip_k, ip_v
|
||||
|
||||
|
||||
class FluxIpAdapter(torch.nn.Module):
|
||||
def __init__(self, num_attention_heads=24, attention_head_dim=128, cross_attention_dim=4096, num_tokens=128, num_blocks=57):
|
||||
super().__init__()
|
||||
self.ipadapter_modules = torch.nn.ModuleList([IpAdapterModule(num_attention_heads, attention_head_dim, cross_attention_dim) for _ in range(num_blocks)])
|
||||
self.image_proj = MLPProjModel(cross_attention_dim=cross_attention_dim, id_embeddings_dim=1152, num_tokens=num_tokens)
|
||||
self.set_adapter()
|
||||
|
||||
def set_adapter(self):
|
||||
self.call_block_id = {i:i for i in range(len(self.ipadapter_modules))}
|
||||
|
||||
def forward(self, hidden_states, scale=1.0):
|
||||
hidden_states = self.image_proj(hidden_states)
|
||||
hidden_states = hidden_states.view(1, -1, hidden_states.shape[-1])
|
||||
ip_kv_dict = {}
|
||||
for block_id in self.call_block_id:
|
||||
ipadapter_id = self.call_block_id[block_id]
|
||||
ip_k, ip_v = self.ipadapter_modules[ipadapter_id](hidden_states)
|
||||
ip_kv_dict[block_id] = {
|
||||
"ip_k": ip_k,
|
||||
"ip_v": ip_v,
|
||||
"scale": scale
|
||||
}
|
||||
return ip_kv_dict
|
||||
|
||||
@staticmethod
|
||||
def state_dict_converter():
|
||||
return FluxIpAdapterStateDictConverter()
|
||||
|
||||
|
||||
class FluxIpAdapterStateDictConverter:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def from_diffusers(self, state_dict):
|
||||
state_dict_ = {}
|
||||
for name in state_dict["ip_adapter"]:
|
||||
name_ = 'ipadapter_modules.' + name
|
||||
state_dict_[name_] = state_dict["ip_adapter"][name]
|
||||
for name in state_dict["image_proj"]:
|
||||
name_ = "image_proj." + name
|
||||
state_dict_[name_] = state_dict["image_proj"][name]
|
||||
return state_dict_
|
||||
|
||||
def from_civitai(self, state_dict):
|
||||
return self.from_diffusers(state_dict)
|
||||
@@ -1,521 +0,0 @@
|
||||
import torch
|
||||
from einops import rearrange
|
||||
|
||||
|
||||
def low_version_attention(query, key, value, attn_bias=None):
|
||||
scale = 1 / query.shape[-1] ** 0.5
|
||||
query = query * scale
|
||||
attn = torch.matmul(query, key.transpose(-2, -1))
|
||||
if attn_bias is not None:
|
||||
attn = attn + attn_bias
|
||||
attn = attn.softmax(-1)
|
||||
return attn @ value
|
||||
|
||||
|
||||
class Attention(torch.nn.Module):
|
||||
|
||||
def __init__(self, q_dim, num_heads, head_dim, kv_dim=None, bias_q=False, bias_kv=False, bias_out=False):
|
||||
super().__init__()
|
||||
dim_inner = head_dim * num_heads
|
||||
kv_dim = kv_dim if kv_dim is not None else q_dim
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = head_dim
|
||||
|
||||
self.to_q = torch.nn.Linear(q_dim, dim_inner, bias=bias_q)
|
||||
self.to_k = torch.nn.Linear(kv_dim, dim_inner, bias=bias_kv)
|
||||
self.to_v = torch.nn.Linear(kv_dim, dim_inner, bias=bias_kv)
|
||||
self.to_out = torch.nn.Linear(dim_inner, q_dim, bias=bias_out)
|
||||
|
||||
def interact_with_ipadapter(self, hidden_states, q, ip_k, ip_v, scale=1.0):
|
||||
batch_size = q.shape[0]
|
||||
ip_k = ip_k.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
ip_v = ip_v.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
ip_hidden_states = torch.nn.functional.scaled_dot_product_attention(q, ip_k, ip_v)
|
||||
hidden_states = hidden_states + scale * ip_hidden_states
|
||||
return hidden_states
|
||||
|
||||
def torch_forward(self, hidden_states, encoder_hidden_states=None, attn_mask=None, ipadapter_kwargs=None, qkv_preprocessor=None):
|
||||
if encoder_hidden_states is None:
|
||||
encoder_hidden_states = hidden_states
|
||||
|
||||
batch_size = encoder_hidden_states.shape[0]
|
||||
|
||||
q = self.to_q(hidden_states)
|
||||
k = self.to_k(encoder_hidden_states)
|
||||
v = self.to_v(encoder_hidden_states)
|
||||
|
||||
q = q.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
k = k.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
v = v.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
|
||||
if qkv_preprocessor is not None:
|
||||
q, k, v = qkv_preprocessor(q, k, v)
|
||||
|
||||
hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
|
||||
if ipadapter_kwargs is not None:
|
||||
hidden_states = self.interact_with_ipadapter(hidden_states, q, **ipadapter_kwargs)
|
||||
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim)
|
||||
hidden_states = hidden_states.to(q.dtype)
|
||||
|
||||
hidden_states = self.to_out(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
def xformers_forward(self, hidden_states, encoder_hidden_states=None, attn_mask=None):
|
||||
if encoder_hidden_states is None:
|
||||
encoder_hidden_states = hidden_states
|
||||
|
||||
q = self.to_q(hidden_states)
|
||||
k = self.to_k(encoder_hidden_states)
|
||||
v = self.to_v(encoder_hidden_states)
|
||||
|
||||
q = rearrange(q, "b f (n d) -> (b n) f d", n=self.num_heads)
|
||||
k = rearrange(k, "b f (n d) -> (b n) f d", n=self.num_heads)
|
||||
v = rearrange(v, "b f (n d) -> (b n) f d", n=self.num_heads)
|
||||
|
||||
if attn_mask is not None:
|
||||
hidden_states = low_version_attention(q, k, v, attn_bias=attn_mask)
|
||||
else:
|
||||
import xformers.ops as xops
|
||||
hidden_states = xops.memory_efficient_attention(q, k, v)
|
||||
hidden_states = rearrange(hidden_states, "(b n) f d -> b f (n d)", n=self.num_heads)
|
||||
|
||||
hidden_states = hidden_states.to(q.dtype)
|
||||
hidden_states = self.to_out(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
def forward(self, hidden_states, encoder_hidden_states=None, attn_mask=None, ipadapter_kwargs=None, qkv_preprocessor=None):
|
||||
return self.torch_forward(hidden_states, encoder_hidden_states=encoder_hidden_states, attn_mask=attn_mask, ipadapter_kwargs=ipadapter_kwargs, qkv_preprocessor=qkv_preprocessor)
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
class CLIPEncoderLayer(torch.nn.Module):
|
||||
def __init__(self, embed_dim, intermediate_size, num_heads=12, head_dim=64, use_quick_gelu=True):
|
||||
super().__init__()
|
||||
self.attn = Attention(q_dim=embed_dim, num_heads=num_heads, head_dim=head_dim, bias_q=True, bias_kv=True, bias_out=True)
|
||||
self.layer_norm1 = torch.nn.LayerNorm(embed_dim)
|
||||
self.layer_norm2 = torch.nn.LayerNorm(embed_dim)
|
||||
self.fc1 = torch.nn.Linear(embed_dim, intermediate_size)
|
||||
self.fc2 = torch.nn.Linear(intermediate_size, embed_dim)
|
||||
|
||||
self.use_quick_gelu = use_quick_gelu
|
||||
|
||||
def quickGELU(self, x):
|
||||
return x * torch.sigmoid(1.702 * x)
|
||||
|
||||
def forward(self, hidden_states, attn_mask=None):
|
||||
residual = hidden_states
|
||||
|
||||
hidden_states = self.layer_norm1(hidden_states)
|
||||
hidden_states = self.attn(hidden_states, attn_mask=attn_mask)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
residual = hidden_states
|
||||
hidden_states = self.layer_norm2(hidden_states)
|
||||
hidden_states = self.fc1(hidden_states)
|
||||
if self.use_quick_gelu:
|
||||
hidden_states = self.quickGELU(hidden_states)
|
||||
else:
|
||||
hidden_states = torch.nn.functional.gelu(hidden_states)
|
||||
hidden_states = self.fc2(hidden_states)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class SDTextEncoder(torch.nn.Module):
|
||||
def __init__(self, embed_dim=768, vocab_size=49408, max_position_embeddings=77, num_encoder_layers=12, encoder_intermediate_size=3072):
|
||||
super().__init__()
|
||||
|
||||
# token_embedding
|
||||
self.token_embedding = torch.nn.Embedding(vocab_size, embed_dim)
|
||||
|
||||
# position_embeds (This is a fixed tensor)
|
||||
self.position_embeds = torch.nn.Parameter(torch.zeros(1, max_position_embeddings, embed_dim))
|
||||
|
||||
# encoders
|
||||
self.encoders = torch.nn.ModuleList([CLIPEncoderLayer(embed_dim, encoder_intermediate_size) for _ in range(num_encoder_layers)])
|
||||
|
||||
# attn_mask
|
||||
self.attn_mask = self.attention_mask(max_position_embeddings)
|
||||
|
||||
# final_layer_norm
|
||||
self.final_layer_norm = torch.nn.LayerNorm(embed_dim)
|
||||
|
||||
def attention_mask(self, length):
|
||||
mask = torch.empty(length, length)
|
||||
mask.fill_(float("-inf"))
|
||||
mask.triu_(1)
|
||||
return mask
|
||||
|
||||
def forward(self, input_ids, clip_skip=1):
|
||||
embeds = self.token_embedding(input_ids) + self.position_embeds
|
||||
attn_mask = self.attn_mask.to(device=embeds.device, dtype=embeds.dtype)
|
||||
for encoder_id, encoder in enumerate(self.encoders):
|
||||
embeds = encoder(embeds, attn_mask=attn_mask)
|
||||
if encoder_id + clip_skip == len(self.encoders):
|
||||
break
|
||||
embeds = self.final_layer_norm(embeds)
|
||||
return embeds
|
||||
|
||||
@staticmethod
|
||||
def state_dict_converter():
|
||||
return SDTextEncoderStateDictConverter()
|
||||
|
||||
|
||||
class SDTextEncoderStateDictConverter:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def from_diffusers(self, state_dict):
|
||||
rename_dict = {
|
||||
"text_model.embeddings.token_embedding.weight": "token_embedding.weight",
|
||||
"text_model.embeddings.position_embedding.weight": "position_embeds",
|
||||
"text_model.final_layer_norm.weight": "final_layer_norm.weight",
|
||||
"text_model.final_layer_norm.bias": "final_layer_norm.bias"
|
||||
}
|
||||
attn_rename_dict = {
|
||||
"self_attn.q_proj": "attn.to_q",
|
||||
"self_attn.k_proj": "attn.to_k",
|
||||
"self_attn.v_proj": "attn.to_v",
|
||||
"self_attn.out_proj": "attn.to_out",
|
||||
"layer_norm1": "layer_norm1",
|
||||
"layer_norm2": "layer_norm2",
|
||||
"mlp.fc1": "fc1",
|
||||
"mlp.fc2": "fc2",
|
||||
}
|
||||
state_dict_ = {}
|
||||
for name in state_dict:
|
||||
if name in rename_dict:
|
||||
param = state_dict[name]
|
||||
if name == "text_model.embeddings.position_embedding.weight":
|
||||
param = param.reshape((1, param.shape[0], param.shape[1]))
|
||||
state_dict_[rename_dict[name]] = param
|
||||
elif name.startswith("text_model.encoder.layers."):
|
||||
param = state_dict[name]
|
||||
names = name.split(".")
|
||||
layer_id, layer_type, tail = names[3], ".".join(names[4:-1]), names[-1]
|
||||
name_ = ".".join(["encoders", layer_id, attn_rename_dict[layer_type], tail])
|
||||
state_dict_[name_] = param
|
||||
return state_dict_
|
||||
|
||||
def from_civitai(self, state_dict):
|
||||
rename_dict = {
|
||||
"cond_stage_model.transformer.text_model.embeddings.token_embedding.weight": "token_embedding.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.0.layer_norm1.bias": "encoders.0.layer_norm1.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.0.layer_norm1.weight": "encoders.0.layer_norm1.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.0.layer_norm2.bias": "encoders.0.layer_norm2.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.0.layer_norm2.weight": "encoders.0.layer_norm2.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.0.mlp.fc1.bias": "encoders.0.fc1.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.0.mlp.fc1.weight": "encoders.0.fc1.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.0.mlp.fc2.bias": "encoders.0.fc2.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.0.mlp.fc2.weight": "encoders.0.fc2.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.k_proj.bias": "encoders.0.attn.to_k.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.k_proj.weight": "encoders.0.attn.to_k.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.out_proj.bias": "encoders.0.attn.to_out.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.out_proj.weight": "encoders.0.attn.to_out.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.q_proj.bias": "encoders.0.attn.to_q.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.q_proj.weight": "encoders.0.attn.to_q.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.v_proj.bias": "encoders.0.attn.to_v.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.v_proj.weight": "encoders.0.attn.to_v.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.1.layer_norm1.bias": "encoders.1.layer_norm1.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.1.layer_norm1.weight": "encoders.1.layer_norm1.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.1.layer_norm2.bias": "encoders.1.layer_norm2.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.1.layer_norm2.weight": "encoders.1.layer_norm2.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.1.mlp.fc1.bias": "encoders.1.fc1.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.1.mlp.fc1.weight": "encoders.1.fc1.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.1.mlp.fc2.bias": "encoders.1.fc2.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.1.mlp.fc2.weight": "encoders.1.fc2.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.k_proj.bias": "encoders.1.attn.to_k.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.k_proj.weight": "encoders.1.attn.to_k.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.out_proj.bias": "encoders.1.attn.to_out.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.out_proj.weight": "encoders.1.attn.to_out.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.q_proj.bias": "encoders.1.attn.to_q.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.q_proj.weight": "encoders.1.attn.to_q.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.v_proj.bias": "encoders.1.attn.to_v.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.v_proj.weight": "encoders.1.attn.to_v.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.10.layer_norm1.bias": "encoders.10.layer_norm1.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.10.layer_norm1.weight": "encoders.10.layer_norm1.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.10.layer_norm2.bias": "encoders.10.layer_norm2.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.10.layer_norm2.weight": "encoders.10.layer_norm2.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.10.mlp.fc1.bias": "encoders.10.fc1.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.10.mlp.fc1.weight": "encoders.10.fc1.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.10.mlp.fc2.bias": "encoders.10.fc2.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.10.mlp.fc2.weight": "encoders.10.fc2.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.k_proj.bias": "encoders.10.attn.to_k.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.k_proj.weight": "encoders.10.attn.to_k.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.out_proj.bias": "encoders.10.attn.to_out.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.out_proj.weight": "encoders.10.attn.to_out.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.q_proj.bias": "encoders.10.attn.to_q.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.q_proj.weight": "encoders.10.attn.to_q.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.v_proj.bias": "encoders.10.attn.to_v.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.v_proj.weight": "encoders.10.attn.to_v.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.11.layer_norm1.bias": "encoders.11.layer_norm1.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.11.layer_norm1.weight": "encoders.11.layer_norm1.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.11.layer_norm2.bias": "encoders.11.layer_norm2.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.11.layer_norm2.weight": "encoders.11.layer_norm2.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.11.mlp.fc1.bias": "encoders.11.fc1.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.11.mlp.fc1.weight": "encoders.11.fc1.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.11.mlp.fc2.bias": "encoders.11.fc2.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.11.mlp.fc2.weight": "encoders.11.fc2.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.k_proj.bias": "encoders.11.attn.to_k.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.k_proj.weight": "encoders.11.attn.to_k.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.out_proj.bias": "encoders.11.attn.to_out.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.out_proj.weight": "encoders.11.attn.to_out.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.q_proj.bias": "encoders.11.attn.to_q.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.q_proj.weight": "encoders.11.attn.to_q.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.v_proj.bias": "encoders.11.attn.to_v.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.v_proj.weight": "encoders.11.attn.to_v.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.2.layer_norm1.bias": "encoders.2.layer_norm1.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.2.layer_norm1.weight": "encoders.2.layer_norm1.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.2.layer_norm2.bias": "encoders.2.layer_norm2.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.2.layer_norm2.weight": "encoders.2.layer_norm2.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.2.mlp.fc1.bias": "encoders.2.fc1.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.2.mlp.fc1.weight": "encoders.2.fc1.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.2.mlp.fc2.bias": "encoders.2.fc2.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.2.mlp.fc2.weight": "encoders.2.fc2.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.k_proj.bias": "encoders.2.attn.to_k.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.k_proj.weight": "encoders.2.attn.to_k.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.out_proj.bias": "encoders.2.attn.to_out.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.out_proj.weight": "encoders.2.attn.to_out.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.q_proj.bias": "encoders.2.attn.to_q.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.q_proj.weight": "encoders.2.attn.to_q.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.v_proj.bias": "encoders.2.attn.to_v.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.v_proj.weight": "encoders.2.attn.to_v.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.3.layer_norm1.bias": "encoders.3.layer_norm1.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.3.layer_norm1.weight": "encoders.3.layer_norm1.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.3.layer_norm2.bias": "encoders.3.layer_norm2.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.3.layer_norm2.weight": "encoders.3.layer_norm2.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.3.mlp.fc1.bias": "encoders.3.fc1.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.3.mlp.fc1.weight": "encoders.3.fc1.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.3.mlp.fc2.bias": "encoders.3.fc2.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.3.mlp.fc2.weight": "encoders.3.fc2.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.k_proj.bias": "encoders.3.attn.to_k.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.k_proj.weight": "encoders.3.attn.to_k.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.out_proj.bias": "encoders.3.attn.to_out.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.out_proj.weight": "encoders.3.attn.to_out.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.q_proj.bias": "encoders.3.attn.to_q.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.q_proj.weight": "encoders.3.attn.to_q.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.v_proj.bias": "encoders.3.attn.to_v.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.v_proj.weight": "encoders.3.attn.to_v.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.4.layer_norm1.bias": "encoders.4.layer_norm1.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.4.layer_norm1.weight": "encoders.4.layer_norm1.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.4.layer_norm2.bias": "encoders.4.layer_norm2.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.4.layer_norm2.weight": "encoders.4.layer_norm2.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.4.mlp.fc1.bias": "encoders.4.fc1.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.4.mlp.fc1.weight": "encoders.4.fc1.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.4.mlp.fc2.bias": "encoders.4.fc2.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.4.mlp.fc2.weight": "encoders.4.fc2.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.k_proj.bias": "encoders.4.attn.to_k.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.k_proj.weight": "encoders.4.attn.to_k.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.out_proj.bias": "encoders.4.attn.to_out.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.out_proj.weight": "encoders.4.attn.to_out.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.q_proj.bias": "encoders.4.attn.to_q.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.q_proj.weight": "encoders.4.attn.to_q.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.v_proj.bias": "encoders.4.attn.to_v.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.v_proj.weight": "encoders.4.attn.to_v.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.5.layer_norm1.bias": "encoders.5.layer_norm1.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.5.layer_norm1.weight": "encoders.5.layer_norm1.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.5.layer_norm2.bias": "encoders.5.layer_norm2.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.5.layer_norm2.weight": "encoders.5.layer_norm2.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.5.mlp.fc1.bias": "encoders.5.fc1.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.5.mlp.fc1.weight": "encoders.5.fc1.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.5.mlp.fc2.bias": "encoders.5.fc2.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.5.mlp.fc2.weight": "encoders.5.fc2.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.k_proj.bias": "encoders.5.attn.to_k.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.k_proj.weight": "encoders.5.attn.to_k.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.out_proj.bias": "encoders.5.attn.to_out.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.out_proj.weight": "encoders.5.attn.to_out.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.q_proj.bias": "encoders.5.attn.to_q.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.q_proj.weight": "encoders.5.attn.to_q.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.v_proj.bias": "encoders.5.attn.to_v.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.v_proj.weight": "encoders.5.attn.to_v.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.6.layer_norm1.bias": "encoders.6.layer_norm1.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.6.layer_norm1.weight": "encoders.6.layer_norm1.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.6.layer_norm2.bias": "encoders.6.layer_norm2.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.6.layer_norm2.weight": "encoders.6.layer_norm2.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.6.mlp.fc1.bias": "encoders.6.fc1.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.6.mlp.fc1.weight": "encoders.6.fc1.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.6.mlp.fc2.bias": "encoders.6.fc2.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.6.mlp.fc2.weight": "encoders.6.fc2.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.k_proj.bias": "encoders.6.attn.to_k.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.k_proj.weight": "encoders.6.attn.to_k.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.out_proj.bias": "encoders.6.attn.to_out.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.out_proj.weight": "encoders.6.attn.to_out.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.q_proj.bias": "encoders.6.attn.to_q.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.q_proj.weight": "encoders.6.attn.to_q.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.v_proj.bias": "encoders.6.attn.to_v.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.v_proj.weight": "encoders.6.attn.to_v.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.7.layer_norm1.bias": "encoders.7.layer_norm1.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.7.layer_norm1.weight": "encoders.7.layer_norm1.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.7.layer_norm2.bias": "encoders.7.layer_norm2.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.7.layer_norm2.weight": "encoders.7.layer_norm2.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.7.mlp.fc1.bias": "encoders.7.fc1.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.7.mlp.fc1.weight": "encoders.7.fc1.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.7.mlp.fc2.bias": "encoders.7.fc2.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.7.mlp.fc2.weight": "encoders.7.fc2.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.k_proj.bias": "encoders.7.attn.to_k.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.k_proj.weight": "encoders.7.attn.to_k.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.out_proj.bias": "encoders.7.attn.to_out.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.out_proj.weight": "encoders.7.attn.to_out.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.q_proj.bias": "encoders.7.attn.to_q.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.q_proj.weight": "encoders.7.attn.to_q.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.v_proj.bias": "encoders.7.attn.to_v.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.v_proj.weight": "encoders.7.attn.to_v.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.8.layer_norm1.bias": "encoders.8.layer_norm1.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.8.layer_norm1.weight": "encoders.8.layer_norm1.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.8.layer_norm2.bias": "encoders.8.layer_norm2.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.8.layer_norm2.weight": "encoders.8.layer_norm2.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.8.mlp.fc1.bias": "encoders.8.fc1.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.8.mlp.fc1.weight": "encoders.8.fc1.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.8.mlp.fc2.bias": "encoders.8.fc2.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.8.mlp.fc2.weight": "encoders.8.fc2.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.k_proj.bias": "encoders.8.attn.to_k.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.k_proj.weight": "encoders.8.attn.to_k.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.out_proj.bias": "encoders.8.attn.to_out.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.out_proj.weight": "encoders.8.attn.to_out.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.q_proj.bias": "encoders.8.attn.to_q.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.q_proj.weight": "encoders.8.attn.to_q.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.v_proj.bias": "encoders.8.attn.to_v.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.v_proj.weight": "encoders.8.attn.to_v.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.9.layer_norm1.bias": "encoders.9.layer_norm1.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.9.layer_norm1.weight": "encoders.9.layer_norm1.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.9.layer_norm2.bias": "encoders.9.layer_norm2.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.9.layer_norm2.weight": "encoders.9.layer_norm2.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.9.mlp.fc1.bias": "encoders.9.fc1.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.9.mlp.fc1.weight": "encoders.9.fc1.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.9.mlp.fc2.bias": "encoders.9.fc2.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.9.mlp.fc2.weight": "encoders.9.fc2.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.k_proj.bias": "encoders.9.attn.to_k.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.k_proj.weight": "encoders.9.attn.to_k.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.out_proj.bias": "encoders.9.attn.to_out.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.out_proj.weight": "encoders.9.attn.to_out.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.q_proj.bias": "encoders.9.attn.to_q.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.q_proj.weight": "encoders.9.attn.to_q.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.v_proj.bias": "encoders.9.attn.to_v.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.v_proj.weight": "encoders.9.attn.to_v.weight",
|
||||
"cond_stage_model.transformer.text_model.final_layer_norm.bias": "final_layer_norm.bias",
|
||||
"cond_stage_model.transformer.text_model.final_layer_norm.weight": "final_layer_norm.weight",
|
||||
"cond_stage_model.transformer.text_model.embeddings.position_embedding.weight": "position_embeds"
|
||||
}
|
||||
state_dict_ = {}
|
||||
for name in state_dict:
|
||||
if name in rename_dict:
|
||||
param = state_dict[name]
|
||||
if name == "cond_stage_model.transformer.text_model.embeddings.position_embedding.weight":
|
||||
param = param.reshape((1, param.shape[0], param.shape[1]))
|
||||
state_dict_[rename_dict[name]] = param
|
||||
return state_dict_
|
||||
|
||||
|
||||
|
||||
class LoRALayerBlock(torch.nn.Module):
|
||||
def __init__(self, L, dim_in, dim_out):
|
||||
super().__init__()
|
||||
self.x = torch.nn.Parameter(torch.randn(1, L, dim_in))
|
||||
self.layer_norm = torch.nn.LayerNorm(dim_out)
|
||||
|
||||
def forward(self, lora_A, lora_B):
|
||||
x = self.x @ lora_A.T @ lora_B.T
|
||||
x = self.layer_norm(x)
|
||||
return x
|
||||
|
||||
|
||||
class LoRAEmbedder(torch.nn.Module):
|
||||
def __init__(self, lora_patterns=None, L=1, out_dim=2048):
|
||||
super().__init__()
|
||||
if lora_patterns is None:
|
||||
lora_patterns = self.default_lora_patterns()
|
||||
|
||||
model_dict = {}
|
||||
for lora_pattern in lora_patterns:
|
||||
name, dim = lora_pattern["name"], lora_pattern["dim"]
|
||||
model_dict[name.replace(".", "___")] = LoRALayerBlock(L, dim[0], dim[1])
|
||||
self.model_dict = torch.nn.ModuleDict(model_dict)
|
||||
|
||||
proj_dict = {}
|
||||
for lora_pattern in lora_patterns:
|
||||
layer_type, dim = lora_pattern["type"], lora_pattern["dim"]
|
||||
if layer_type not in proj_dict:
|
||||
proj_dict[layer_type.replace(".", "___")] = torch.nn.Linear(dim[1], out_dim)
|
||||
self.proj_dict = torch.nn.ModuleDict(proj_dict)
|
||||
|
||||
self.lora_patterns = lora_patterns
|
||||
|
||||
|
||||
def default_lora_patterns(self):
|
||||
lora_patterns = []
|
||||
lora_dict = {
|
||||
"attn.a_to_qkv": (3072, 9216), "attn.a_to_out": (3072, 3072), "ff_a.0": (3072, 12288), "ff_a.2": (12288, 3072), "norm1_a.linear": (3072, 18432),
|
||||
"attn.b_to_qkv": (3072, 9216), "attn.b_to_out": (3072, 3072), "ff_b.0": (3072, 12288), "ff_b.2": (12288, 3072), "norm1_b.linear": (3072, 18432),
|
||||
}
|
||||
for i in range(19):
|
||||
for suffix in lora_dict:
|
||||
lora_patterns.append({
|
||||
"name": f"blocks.{i}.{suffix}",
|
||||
"dim": lora_dict[suffix],
|
||||
"type": suffix,
|
||||
})
|
||||
lora_dict = {"to_qkv_mlp": (3072, 21504), "proj_out": (15360, 3072), "norm.linear": (3072, 9216)}
|
||||
for i in range(38):
|
||||
for suffix in lora_dict:
|
||||
lora_patterns.append({
|
||||
"name": f"single_blocks.{i}.{suffix}",
|
||||
"dim": lora_dict[suffix],
|
||||
"type": suffix,
|
||||
})
|
||||
return lora_patterns
|
||||
|
||||
def forward(self, lora):
|
||||
lora_emb = []
|
||||
for lora_pattern in self.lora_patterns:
|
||||
name, layer_type = lora_pattern["name"], lora_pattern["type"]
|
||||
lora_A = lora[name + ".lora_A.weight"]
|
||||
lora_B = lora[name + ".lora_B.weight"]
|
||||
lora_out = self.model_dict[name.replace(".", "___")](lora_A, lora_B)
|
||||
lora_out = self.proj_dict[layer_type.replace(".", "___")](lora_out)
|
||||
lora_emb.append(lora_out)
|
||||
lora_emb = torch.concat(lora_emb, dim=1)
|
||||
return lora_emb
|
||||
|
||||
|
||||
class FluxLoRAEncoder(torch.nn.Module):
|
||||
def __init__(self, embed_dim=4096, encoder_intermediate_size=8192, num_encoder_layers=1, num_embeds_per_lora=16, num_special_embeds=1):
|
||||
super().__init__()
|
||||
self.num_embeds_per_lora = num_embeds_per_lora
|
||||
# embedder
|
||||
self.embedder = LoRAEmbedder(L=num_embeds_per_lora, out_dim=embed_dim)
|
||||
|
||||
# encoders
|
||||
self.encoders = torch.nn.ModuleList([CLIPEncoderLayer(embed_dim, encoder_intermediate_size, num_heads=32, head_dim=128) for _ in range(num_encoder_layers)])
|
||||
|
||||
# special embedding
|
||||
self.special_embeds = torch.nn.Parameter(torch.randn(1, num_special_embeds, embed_dim))
|
||||
self.num_special_embeds = num_special_embeds
|
||||
|
||||
# final layer
|
||||
self.final_layer_norm = torch.nn.LayerNorm(embed_dim)
|
||||
self.final_linear = torch.nn.Linear(embed_dim, embed_dim)
|
||||
|
||||
def forward(self, lora):
|
||||
lora_embeds = self.embedder(lora)
|
||||
special_embeds = self.special_embeds.to(dtype=lora_embeds.dtype, device=lora_embeds.device)
|
||||
embeds = torch.concat([special_embeds, lora_embeds], dim=1)
|
||||
for encoder_id, encoder in enumerate(self.encoders):
|
||||
embeds = encoder(embeds)
|
||||
embeds = embeds[:, :self.num_special_embeds]
|
||||
embeds = self.final_layer_norm(embeds)
|
||||
embeds = self.final_linear(embeds)
|
||||
return embeds
|
||||
|
||||
@staticmethod
|
||||
def state_dict_converter():
|
||||
return FluxLoRAEncoderStateDictConverter()
|
||||
|
||||
|
||||
class FluxLoRAEncoderStateDictConverter:
|
||||
def from_civitai(self, state_dict):
|
||||
return state_dict
|
||||
@@ -1,306 +0,0 @@
|
||||
import torch, math
|
||||
from ..core.loader import load_state_dict
|
||||
from typing import Union
|
||||
|
||||
class GeneralLoRALoader:
|
||||
def __init__(self, device="cpu", torch_dtype=torch.float32):
|
||||
self.device = device
|
||||
self.torch_dtype = torch_dtype
|
||||
|
||||
|
||||
def get_name_dict(self, lora_state_dict):
|
||||
lora_name_dict = {}
|
||||
for key in lora_state_dict:
|
||||
if ".lora_B." not in key:
|
||||
continue
|
||||
keys = key.split(".")
|
||||
if len(keys) > keys.index("lora_B") + 2:
|
||||
keys.pop(keys.index("lora_B") + 1)
|
||||
keys.pop(keys.index("lora_B"))
|
||||
if keys[0] == "diffusion_model":
|
||||
keys.pop(0)
|
||||
keys.pop(-1)
|
||||
target_name = ".".join(keys)
|
||||
lora_name_dict[target_name] = (key, key.replace(".lora_B.", ".lora_A."))
|
||||
return lora_name_dict
|
||||
|
||||
|
||||
def load(self, model: torch.nn.Module, state_dict_lora, alpha=1.0):
|
||||
updated_num = 0
|
||||
lora_name_dict = self.get_name_dict(state_dict_lora)
|
||||
for name, module in model.named_modules():
|
||||
if name in lora_name_dict:
|
||||
weight_up = state_dict_lora[lora_name_dict[name][0]].to(device=self.device, dtype=self.torch_dtype)
|
||||
weight_down = state_dict_lora[lora_name_dict[name][1]].to(device=self.device, dtype=self.torch_dtype)
|
||||
if len(weight_up.shape) == 4:
|
||||
weight_up = weight_up.squeeze(3).squeeze(2)
|
||||
weight_down = weight_down.squeeze(3).squeeze(2)
|
||||
weight_lora = alpha * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3)
|
||||
else:
|
||||
weight_lora = alpha * torch.mm(weight_up, weight_down)
|
||||
state_dict = module.state_dict()
|
||||
state_dict["weight"] = state_dict["weight"].to(device=self.device, dtype=self.torch_dtype) + weight_lora
|
||||
module.load_state_dict(state_dict)
|
||||
updated_num += 1
|
||||
print(f"{updated_num} tensors are updated by LoRA.")
|
||||
|
||||
class FluxLoRALoader(GeneralLoRALoader):
|
||||
def __init__(self, device="cpu", torch_dtype=torch.float32):
|
||||
super().__init__(device=device, torch_dtype=torch_dtype)
|
||||
|
||||
self.diffusers_rename_dict = {
|
||||
"transformer.single_transformer_blocks.blockid.attn.to_k.lora_A.weight":"single_blocks.blockid.a_to_k.lora_A.default.weight",
|
||||
"transformer.single_transformer_blocks.blockid.attn.to_k.lora_B.weight":"single_blocks.blockid.a_to_k.lora_B.default.weight",
|
||||
"transformer.single_transformer_blocks.blockid.attn.to_q.lora_A.weight":"single_blocks.blockid.a_to_q.lora_A.default.weight",
|
||||
"transformer.single_transformer_blocks.blockid.attn.to_q.lora_B.weight":"single_blocks.blockid.a_to_q.lora_B.default.weight",
|
||||
"transformer.single_transformer_blocks.blockid.attn.to_v.lora_A.weight":"single_blocks.blockid.a_to_v.lora_A.default.weight",
|
||||
"transformer.single_transformer_blocks.blockid.attn.to_v.lora_B.weight":"single_blocks.blockid.a_to_v.lora_B.default.weight",
|
||||
"transformer.single_transformer_blocks.blockid.norm.linear.lora_A.weight":"single_blocks.blockid.norm.linear.lora_A.default.weight",
|
||||
"transformer.single_transformer_blocks.blockid.norm.linear.lora_B.weight":"single_blocks.blockid.norm.linear.lora_B.default.weight",
|
||||
"transformer.single_transformer_blocks.blockid.proj_mlp.lora_A.weight":"single_blocks.blockid.proj_in_besides_attn.lora_A.default.weight",
|
||||
"transformer.single_transformer_blocks.blockid.proj_mlp.lora_B.weight":"single_blocks.blockid.proj_in_besides_attn.lora_B.default.weight",
|
||||
"transformer.single_transformer_blocks.blockid.proj_out.lora_A.weight":"single_blocks.blockid.proj_out.lora_A.default.weight",
|
||||
"transformer.single_transformer_blocks.blockid.proj_out.lora_B.weight":"single_blocks.blockid.proj_out.lora_B.default.weight",
|
||||
"transformer.transformer_blocks.blockid.attn.add_k_proj.lora_A.weight":"blocks.blockid.attn.b_to_k.lora_A.default.weight",
|
||||
"transformer.transformer_blocks.blockid.attn.add_k_proj.lora_B.weight":"blocks.blockid.attn.b_to_k.lora_B.default.weight",
|
||||
"transformer.transformer_blocks.blockid.attn.add_q_proj.lora_A.weight":"blocks.blockid.attn.b_to_q.lora_A.default.weight",
|
||||
"transformer.transformer_blocks.blockid.attn.add_q_proj.lora_B.weight":"blocks.blockid.attn.b_to_q.lora_B.default.weight",
|
||||
"transformer.transformer_blocks.blockid.attn.add_v_proj.lora_A.weight":"blocks.blockid.attn.b_to_v.lora_A.default.weight",
|
||||
"transformer.transformer_blocks.blockid.attn.add_v_proj.lora_B.weight":"blocks.blockid.attn.b_to_v.lora_B.default.weight",
|
||||
"transformer.transformer_blocks.blockid.attn.to_add_out.lora_A.weight":"blocks.blockid.attn.b_to_out.lora_A.default.weight",
|
||||
"transformer.transformer_blocks.blockid.attn.to_add_out.lora_B.weight":"blocks.blockid.attn.b_to_out.lora_B.default.weight",
|
||||
"transformer.transformer_blocks.blockid.attn.to_k.lora_A.weight":"blocks.blockid.attn.a_to_k.lora_A.default.weight",
|
||||
"transformer.transformer_blocks.blockid.attn.to_k.lora_B.weight":"blocks.blockid.attn.a_to_k.lora_B.default.weight",
|
||||
"transformer.transformer_blocks.blockid.attn.to_out.0.lora_A.weight":"blocks.blockid.attn.a_to_out.lora_A.default.weight",
|
||||
"transformer.transformer_blocks.blockid.attn.to_out.0.lora_B.weight":"blocks.blockid.attn.a_to_out.lora_B.default.weight",
|
||||
"transformer.transformer_blocks.blockid.attn.to_q.lora_A.weight":"blocks.blockid.attn.a_to_q.lora_A.default.weight",
|
||||
"transformer.transformer_blocks.blockid.attn.to_q.lora_B.weight":"blocks.blockid.attn.a_to_q.lora_B.default.weight",
|
||||
"transformer.transformer_blocks.blockid.attn.to_v.lora_A.weight":"blocks.blockid.attn.a_to_v.lora_A.default.weight",
|
||||
"transformer.transformer_blocks.blockid.attn.to_v.lora_B.weight":"blocks.blockid.attn.a_to_v.lora_B.default.weight",
|
||||
"transformer.transformer_blocks.blockid.ff.net.0.proj.lora_A.weight":"blocks.blockid.ff_a.0.lora_A.default.weight",
|
||||
"transformer.transformer_blocks.blockid.ff.net.0.proj.lora_B.weight":"blocks.blockid.ff_a.0.lora_B.default.weight",
|
||||
"transformer.transformer_blocks.blockid.ff.net.2.lora_A.weight":"blocks.blockid.ff_a.2.lora_A.default.weight",
|
||||
"transformer.transformer_blocks.blockid.ff.net.2.lora_B.weight":"blocks.blockid.ff_a.2.lora_B.default.weight",
|
||||
"transformer.transformer_blocks.blockid.ff_context.net.0.proj.lora_A.weight":"blocks.blockid.ff_b.0.lora_A.default.weight",
|
||||
"transformer.transformer_blocks.blockid.ff_context.net.0.proj.lora_B.weight":"blocks.blockid.ff_b.0.lora_B.default.weight",
|
||||
"transformer.transformer_blocks.blockid.ff_context.net.2.lora_A.weight":"blocks.blockid.ff_b.2.lora_A.default.weight",
|
||||
"transformer.transformer_blocks.blockid.ff_context.net.2.lora_B.weight":"blocks.blockid.ff_b.2.lora_B.default.weight",
|
||||
"transformer.transformer_blocks.blockid.norm1.linear.lora_A.weight":"blocks.blockid.norm1_a.linear.lora_A.default.weight",
|
||||
"transformer.transformer_blocks.blockid.norm1.linear.lora_B.weight":"blocks.blockid.norm1_a.linear.lora_B.default.weight",
|
||||
"transformer.transformer_blocks.blockid.norm1_context.linear.lora_A.weight":"blocks.blockid.norm1_b.linear.lora_A.default.weight",
|
||||
"transformer.transformer_blocks.blockid.norm1_context.linear.lora_B.weight":"blocks.blockid.norm1_b.linear.lora_B.default.weight",
|
||||
}
|
||||
|
||||
self.civitai_rename_dict = {
|
||||
"lora_unet_double_blocks_blockid_img_mod_lin.lora_down.weight": "blocks.blockid.norm1_a.linear.lora_A.default.weight",
|
||||
"lora_unet_double_blocks_blockid_img_mod_lin.lora_up.weight": "blocks.blockid.norm1_a.linear.lora_B.default.weight",
|
||||
"lora_unet_double_blocks_blockid_txt_mod_lin.lora_down.weight": "blocks.blockid.norm1_b.linear.lora_A.default.weight",
|
||||
"lora_unet_double_blocks_blockid_txt_mod_lin.lora_up.weight": "blocks.blockid.norm1_b.linear.lora_B.default.weight",
|
||||
"lora_unet_double_blocks_blockid_img_attn_qkv.lora_down.weight": "blocks.blockid.attn.a_to_qkv.lora_A.default.weight",
|
||||
"lora_unet_double_blocks_blockid_img_attn_qkv.lora_up.weight": "blocks.blockid.attn.a_to_qkv.lora_B.default.weight",
|
||||
"lora_unet_double_blocks_blockid_txt_attn_qkv.lora_down.weight": "blocks.blockid.attn.b_to_qkv.lora_A.default.weight",
|
||||
"lora_unet_double_blocks_blockid_txt_attn_qkv.lora_up.weight": "blocks.blockid.attn.b_to_qkv.lora_B.default.weight",
|
||||
"lora_unet_double_blocks_blockid_img_attn_proj.lora_down.weight": "blocks.blockid.attn.a_to_out.lora_A.default.weight",
|
||||
"lora_unet_double_blocks_blockid_img_attn_proj.lora_up.weight": "blocks.blockid.attn.a_to_out.lora_B.default.weight",
|
||||
"lora_unet_double_blocks_blockid_txt_attn_proj.lora_down.weight": "blocks.blockid.attn.b_to_out.lora_A.default.weight",
|
||||
"lora_unet_double_blocks_blockid_txt_attn_proj.lora_up.weight": "blocks.blockid.attn.b_to_out.lora_B.default.weight",
|
||||
"lora_unet_double_blocks_blockid_img_mlp_0.lora_down.weight": "blocks.blockid.ff_a.0.lora_A.default.weight",
|
||||
"lora_unet_double_blocks_blockid_img_mlp_0.lora_up.weight": "blocks.blockid.ff_a.0.lora_B.default.weight",
|
||||
"lora_unet_double_blocks_blockid_img_mlp_2.lora_down.weight": "blocks.blockid.ff_a.2.lora_A.default.weight",
|
||||
"lora_unet_double_blocks_blockid_img_mlp_2.lora_up.weight": "blocks.blockid.ff_a.2.lora_B.default.weight",
|
||||
"lora_unet_double_blocks_blockid_txt_mlp_0.lora_down.weight": "blocks.blockid.ff_b.0.lora_A.default.weight",
|
||||
"lora_unet_double_blocks_blockid_txt_mlp_0.lora_up.weight": "blocks.blockid.ff_b.0.lora_B.default.weight",
|
||||
"lora_unet_double_blocks_blockid_txt_mlp_2.lora_down.weight": "blocks.blockid.ff_b.2.lora_A.default.weight",
|
||||
"lora_unet_double_blocks_blockid_txt_mlp_2.lora_up.weight": "blocks.blockid.ff_b.2.lora_B.default.weight",
|
||||
"lora_unet_single_blocks_blockid_modulation_lin.lora_down.weight": "single_blocks.blockid.norm.linear.lora_A.default.weight",
|
||||
"lora_unet_single_blocks_blockid_modulation_lin.lora_up.weight": "single_blocks.blockid.norm.linear.lora_B.default.weight",
|
||||
"lora_unet_single_blocks_blockid_linear1.lora_down.weight": "single_blocks.blockid.to_qkv_mlp.lora_A.default.weight",
|
||||
"lora_unet_single_blocks_blockid_linear1.lora_up.weight": "single_blocks.blockid.to_qkv_mlp.lora_B.default.weight",
|
||||
"lora_unet_single_blocks_blockid_linear2.lora_down.weight": "single_blocks.blockid.proj_out.lora_A.default.weight",
|
||||
"lora_unet_single_blocks_blockid_linear2.lora_up.weight": "single_blocks.blockid.proj_out.lora_B.default.weight",
|
||||
}
|
||||
|
||||
def load(self, model: torch.nn.Module, state_dict_lora, alpha=1.0):
|
||||
super().load(model, state_dict_lora, alpha)
|
||||
|
||||
|
||||
def convert_state_dict(self,state_dict):
|
||||
|
||||
def guess_block_id(name,model_resource):
|
||||
if model_resource == 'civitai':
|
||||
names = name.split("_")
|
||||
for i in names:
|
||||
if i.isdigit():
|
||||
return i, name.replace(f"_{i}_", "_blockid_")
|
||||
if model_resource == 'diffusers':
|
||||
names = name.split(".")
|
||||
for i in names:
|
||||
if i.isdigit():
|
||||
return i, name.replace(f"transformer_blocks.{i}.", "transformer_blocks.blockid.")
|
||||
return None, None
|
||||
|
||||
def guess_resource(state_dict):
|
||||
for k in state_dict:
|
||||
if "lora_unet_" in k:
|
||||
return 'civitai'
|
||||
elif k.startswith("transformer."):
|
||||
return 'diffusers'
|
||||
else:
|
||||
None
|
||||
|
||||
model_resource = guess_resource(state_dict)
|
||||
if model_resource is None:
|
||||
return state_dict
|
||||
|
||||
rename_dict = self.diffusers_rename_dict if model_resource == 'diffusers' else self.civitai_rename_dict
|
||||
def guess_alpha(state_dict):
|
||||
for name, param in state_dict.items():
|
||||
if ".alpha" in name:
|
||||
for suffix in [".lora_down.weight", ".lora_A.weight"]:
|
||||
name_ = name.replace(".alpha", suffix)
|
||||
if name_ in state_dict:
|
||||
lora_alpha = param.item() / state_dict[name_].shape[0]
|
||||
lora_alpha = math.sqrt(lora_alpha)
|
||||
return lora_alpha
|
||||
|
||||
return 1
|
||||
|
||||
alpha = guess_alpha(state_dict)
|
||||
|
||||
state_dict_ = {}
|
||||
for name, param in state_dict.items():
|
||||
block_id, source_name = guess_block_id(name,model_resource)
|
||||
if alpha != 1:
|
||||
param *= alpha
|
||||
if source_name in rename_dict:
|
||||
target_name = rename_dict[source_name]
|
||||
target_name = target_name.replace(".blockid.", f".{block_id}.")
|
||||
state_dict_[target_name] = param
|
||||
else:
|
||||
state_dict_[name] = param
|
||||
|
||||
if model_resource == 'diffusers':
|
||||
for name in list(state_dict_.keys()):
|
||||
if "single_blocks." in name and ".a_to_q." in name:
|
||||
mlp = state_dict_.get(name.replace(".a_to_q.", ".proj_in_besides_attn."), None)
|
||||
if mlp is None:
|
||||
dim = 4
|
||||
if 'lora_A' in name:
|
||||
dim = 1
|
||||
mlp = torch.zeros(dim * state_dict_[name].shape[0],
|
||||
*state_dict_[name].shape[1:],
|
||||
dtype=state_dict_[name].dtype)
|
||||
else:
|
||||
state_dict_.pop(name.replace(".a_to_q.", ".proj_in_besides_attn."))
|
||||
if 'lora_A' in name:
|
||||
param = torch.concat([
|
||||
state_dict_.pop(name),
|
||||
state_dict_.pop(name.replace(".a_to_q.", ".a_to_k.")),
|
||||
state_dict_.pop(name.replace(".a_to_q.", ".a_to_v.")),
|
||||
mlp,
|
||||
], dim=0)
|
||||
elif 'lora_B' in name:
|
||||
d, r = state_dict_[name].shape
|
||||
param = torch.zeros((3*d+mlp.shape[0], 3*r+mlp.shape[1]), dtype=state_dict_[name].dtype, device=state_dict_[name].device)
|
||||
param[:d, :r] = state_dict_.pop(name)
|
||||
param[d:2*d, r:2*r] = state_dict_.pop(name.replace(".a_to_q.", ".a_to_k."))
|
||||
param[2*d:3*d, 2*r:3*r] = state_dict_.pop(name.replace(".a_to_q.", ".a_to_v."))
|
||||
param[3*d:, 3*r:] = mlp
|
||||
else:
|
||||
param = torch.concat([
|
||||
state_dict_.pop(name),
|
||||
state_dict_.pop(name.replace(".a_to_q.", ".a_to_k.")),
|
||||
state_dict_.pop(name.replace(".a_to_q.", ".a_to_v.")),
|
||||
mlp,
|
||||
], dim=0)
|
||||
name_ = name.replace(".a_to_q.", ".to_qkv_mlp.")
|
||||
state_dict_[name_] = param
|
||||
for name in list(state_dict_.keys()):
|
||||
for component in ["a", "b"]:
|
||||
if f".{component}_to_q." in name:
|
||||
name_ = name.replace(f".{component}_to_q.", f".{component}_to_qkv.")
|
||||
concat_dim = 0
|
||||
if 'lora_A' in name:
|
||||
param = torch.concat([
|
||||
state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_q.")],
|
||||
state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_k.")],
|
||||
state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_v.")],
|
||||
], dim=0)
|
||||
elif 'lora_B' in name:
|
||||
origin = state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_q.")]
|
||||
d, r = origin.shape
|
||||
# print(d, r)
|
||||
param = torch.zeros((3*d, 3*r), dtype=origin.dtype, device=origin.device)
|
||||
param[:d, :r] = state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_q.")]
|
||||
param[d:2*d, r:2*r] = state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_k.")]
|
||||
param[2*d:3*d, 2*r:3*r] = state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_v.")]
|
||||
else:
|
||||
param = torch.concat([
|
||||
state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_q.")],
|
||||
state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_k.")],
|
||||
state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_v.")],
|
||||
], dim=0)
|
||||
state_dict_[name_] = param
|
||||
state_dict_.pop(name.replace(f".{component}_to_q.", f".{component}_to_q."))
|
||||
state_dict_.pop(name.replace(f".{component}_to_q.", f".{component}_to_k."))
|
||||
state_dict_.pop(name.replace(f".{component}_to_q.", f".{component}_to_v."))
|
||||
return state_dict_
|
||||
|
||||
|
||||
class LoraMerger(torch.nn.Module):
|
||||
def __init__(self, dim):
|
||||
super().__init__()
|
||||
self.weight_base = torch.nn.Parameter(torch.randn((dim,)))
|
||||
self.weight_lora = torch.nn.Parameter(torch.randn((dim,)))
|
||||
self.weight_cross = torch.nn.Parameter(torch.randn((dim,)))
|
||||
self.weight_out = torch.nn.Parameter(torch.ones((dim,)))
|
||||
self.bias = torch.nn.Parameter(torch.randn((dim,)))
|
||||
self.activation = torch.nn.Sigmoid()
|
||||
self.norm_base = torch.nn.LayerNorm(dim, eps=1e-5)
|
||||
self.norm_lora = torch.nn.LayerNorm(dim, eps=1e-5)
|
||||
|
||||
def forward(self, base_output, lora_outputs):
|
||||
norm_base_output = self.norm_base(base_output)
|
||||
norm_lora_outputs = self.norm_lora(lora_outputs)
|
||||
gate = self.activation(
|
||||
norm_base_output * self.weight_base \
|
||||
+ norm_lora_outputs * self.weight_lora \
|
||||
+ norm_base_output * norm_lora_outputs * self.weight_cross + self.bias
|
||||
)
|
||||
output = base_output + (self.weight_out * gate * lora_outputs).sum(dim=0)
|
||||
return output
|
||||
|
||||
class FluxLoraPatcher(torch.nn.Module):
|
||||
def __init__(self, lora_patterns=None):
|
||||
super().__init__()
|
||||
if lora_patterns is None:
|
||||
lora_patterns = self.default_lora_patterns()
|
||||
model_dict = {}
|
||||
for lora_pattern in lora_patterns:
|
||||
name, dim = lora_pattern["name"], lora_pattern["dim"]
|
||||
model_dict[name.replace(".", "___")] = LoraMerger(dim)
|
||||
self.model_dict = torch.nn.ModuleDict(model_dict)
|
||||
|
||||
def default_lora_patterns(self):
|
||||
lora_patterns = []
|
||||
lora_dict = {
|
||||
"attn.a_to_qkv": 9216, "attn.a_to_out": 3072, "ff_a.0": 12288, "ff_a.2": 3072, "norm1_a.linear": 18432,
|
||||
"attn.b_to_qkv": 9216, "attn.b_to_out": 3072, "ff_b.0": 12288, "ff_b.2": 3072, "norm1_b.linear": 18432,
|
||||
}
|
||||
for i in range(19):
|
||||
for suffix in lora_dict:
|
||||
lora_patterns.append({
|
||||
"name": f"blocks.{i}.{suffix}",
|
||||
"dim": lora_dict[suffix]
|
||||
})
|
||||
lora_dict = {"to_qkv_mlp": 21504, "proj_out": 3072, "norm.linear": 9216}
|
||||
for i in range(38):
|
||||
for suffix in lora_dict:
|
||||
lora_patterns.append({
|
||||
"name": f"single_blocks.{i}.{suffix}",
|
||||
"dim": lora_dict[suffix]
|
||||
})
|
||||
return lora_patterns
|
||||
|
||||
def forward(self, base_output, lora_outputs, name):
|
||||
return self.model_dict[name.replace(".", "___")](base_output, lora_outputs)
|
||||
93
diffsynth/models/flux_text_encoder.py
Normal file
93
diffsynth/models/flux_text_encoder.py
Normal file
@@ -0,0 +1,93 @@
|
||||
import torch
|
||||
from transformers import T5EncoderModel, T5Config
|
||||
from .sd_text_encoder import SDTextEncoder
|
||||
|
||||
|
||||
class FluxTextEncoder1(SDTextEncoder):
|
||||
def __init__(self, vocab_size=49408):
|
||||
super().__init__(vocab_size=vocab_size)
|
||||
|
||||
def forward(self, input_ids, clip_skip=2):
|
||||
embeds = self.token_embedding(input_ids) + self.position_embeds
|
||||
attn_mask = self.attn_mask.to(device=embeds.device, dtype=embeds.dtype)
|
||||
for encoder_id, encoder in enumerate(self.encoders):
|
||||
embeds = encoder(embeds, attn_mask=attn_mask)
|
||||
if encoder_id + clip_skip == len(self.encoders):
|
||||
hidden_states = embeds
|
||||
embeds = self.final_layer_norm(embeds)
|
||||
pooled_embeds = embeds[torch.arange(embeds.shape[0]), input_ids.to(dtype=torch.int).argmax(dim=-1)]
|
||||
return embeds, pooled_embeds
|
||||
|
||||
@staticmethod
|
||||
def state_dict_converter():
|
||||
return FluxTextEncoder1StateDictConverter()
|
||||
|
||||
|
||||
|
||||
class FluxTextEncoder2(T5EncoderModel):
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
self.eval()
|
||||
|
||||
def forward(self, input_ids):
|
||||
outputs = super().forward(input_ids=input_ids)
|
||||
prompt_emb = outputs.last_hidden_state
|
||||
return prompt_emb
|
||||
|
||||
@staticmethod
|
||||
def state_dict_converter():
|
||||
return FluxTextEncoder2StateDictConverter()
|
||||
|
||||
|
||||
|
||||
class FluxTextEncoder1StateDictConverter:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def from_diffusers(self, state_dict):
|
||||
rename_dict = {
|
||||
"text_model.embeddings.token_embedding.weight": "token_embedding.weight",
|
||||
"text_model.embeddings.position_embedding.weight": "position_embeds",
|
||||
"text_model.final_layer_norm.weight": "final_layer_norm.weight",
|
||||
"text_model.final_layer_norm.bias": "final_layer_norm.bias"
|
||||
}
|
||||
attn_rename_dict = {
|
||||
"self_attn.q_proj": "attn.to_q",
|
||||
"self_attn.k_proj": "attn.to_k",
|
||||
"self_attn.v_proj": "attn.to_v",
|
||||
"self_attn.out_proj": "attn.to_out",
|
||||
"layer_norm1": "layer_norm1",
|
||||
"layer_norm2": "layer_norm2",
|
||||
"mlp.fc1": "fc1",
|
||||
"mlp.fc2": "fc2",
|
||||
}
|
||||
state_dict_ = {}
|
||||
for name in state_dict:
|
||||
if name in rename_dict:
|
||||
param = state_dict[name]
|
||||
if name == "text_model.embeddings.position_embedding.weight":
|
||||
param = param.reshape((1, param.shape[0], param.shape[1]))
|
||||
state_dict_[rename_dict[name]] = param
|
||||
elif name.startswith("text_model.encoder.layers."):
|
||||
param = state_dict[name]
|
||||
names = name.split(".")
|
||||
layer_id, layer_type, tail = names[3], ".".join(names[4:-1]), names[-1]
|
||||
name_ = ".".join(["encoders", layer_id, attn_rename_dict[layer_type], tail])
|
||||
state_dict_[name_] = param
|
||||
return state_dict_
|
||||
|
||||
def from_civitai(self, state_dict):
|
||||
return self.from_diffusers(state_dict)
|
||||
|
||||
|
||||
|
||||
class FluxTextEncoder2StateDictConverter():
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def from_diffusers(self, state_dict):
|
||||
state_dict_ = state_dict
|
||||
return state_dict_
|
||||
|
||||
def from_civitai(self, state_dict):
|
||||
return self.from_diffusers(state_dict)
|
||||
@@ -1,112 +0,0 @@
|
||||
import torch
|
||||
|
||||
|
||||
class Attention(torch.nn.Module):
|
||||
|
||||
def __init__(self, q_dim, num_heads, head_dim, kv_dim=None, bias_q=False, bias_kv=False, bias_out=False):
|
||||
super().__init__()
|
||||
dim_inner = head_dim * num_heads
|
||||
kv_dim = kv_dim if kv_dim is not None else q_dim
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = head_dim
|
||||
|
||||
self.to_q = torch.nn.Linear(q_dim, dim_inner, bias=bias_q)
|
||||
self.to_k = torch.nn.Linear(kv_dim, dim_inner, bias=bias_kv)
|
||||
self.to_v = torch.nn.Linear(kv_dim, dim_inner, bias=bias_kv)
|
||||
self.to_out = torch.nn.Linear(dim_inner, q_dim, bias=bias_out)
|
||||
|
||||
def forward(self, hidden_states, encoder_hidden_states=None, attn_mask=None):
|
||||
if encoder_hidden_states is None:
|
||||
encoder_hidden_states = hidden_states
|
||||
|
||||
batch_size = encoder_hidden_states.shape[0]
|
||||
|
||||
q = self.to_q(hidden_states)
|
||||
k = self.to_k(encoder_hidden_states)
|
||||
v = self.to_v(encoder_hidden_states)
|
||||
|
||||
q = q.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
k = k.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
v = v.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
|
||||
hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
|
||||
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim)
|
||||
hidden_states = hidden_states.to(q.dtype)
|
||||
|
||||
hidden_states = self.to_out(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class CLIPEncoderLayer(torch.nn.Module):
|
||||
def __init__(self, embed_dim, intermediate_size, num_heads=12, head_dim=64, use_quick_gelu=True):
|
||||
super().__init__()
|
||||
self.attn = Attention(q_dim=embed_dim, num_heads=num_heads, head_dim=head_dim, bias_q=True, bias_kv=True, bias_out=True)
|
||||
self.layer_norm1 = torch.nn.LayerNorm(embed_dim)
|
||||
self.layer_norm2 = torch.nn.LayerNorm(embed_dim)
|
||||
self.fc1 = torch.nn.Linear(embed_dim, intermediate_size)
|
||||
self.fc2 = torch.nn.Linear(intermediate_size, embed_dim)
|
||||
|
||||
self.use_quick_gelu = use_quick_gelu
|
||||
|
||||
def quickGELU(self, x):
|
||||
return x * torch.sigmoid(1.702 * x)
|
||||
|
||||
def forward(self, hidden_states, attn_mask=None):
|
||||
residual = hidden_states
|
||||
|
||||
hidden_states = self.layer_norm1(hidden_states)
|
||||
hidden_states = self.attn(hidden_states, attn_mask=attn_mask)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
residual = hidden_states
|
||||
hidden_states = self.layer_norm2(hidden_states)
|
||||
hidden_states = self.fc1(hidden_states)
|
||||
if self.use_quick_gelu:
|
||||
hidden_states = self.quickGELU(hidden_states)
|
||||
else:
|
||||
hidden_states = torch.nn.functional.gelu(hidden_states)
|
||||
hidden_states = self.fc2(hidden_states)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class FluxTextEncoderClip(torch.nn.Module):
|
||||
def __init__(self, embed_dim=768, vocab_size=49408, max_position_embeddings=77, num_encoder_layers=12, encoder_intermediate_size=3072):
|
||||
super().__init__()
|
||||
|
||||
# token_embedding
|
||||
self.token_embedding = torch.nn.Embedding(vocab_size, embed_dim)
|
||||
|
||||
# position_embeds (This is a fixed tensor)
|
||||
self.position_embeds = torch.nn.Parameter(torch.zeros(1, max_position_embeddings, embed_dim))
|
||||
|
||||
# encoders
|
||||
self.encoders = torch.nn.ModuleList([CLIPEncoderLayer(embed_dim, encoder_intermediate_size) for _ in range(num_encoder_layers)])
|
||||
|
||||
# attn_mask
|
||||
self.attn_mask = self.attention_mask(max_position_embeddings)
|
||||
|
||||
# final_layer_norm
|
||||
self.final_layer_norm = torch.nn.LayerNorm(embed_dim)
|
||||
|
||||
def attention_mask(self, length):
|
||||
mask = torch.empty(length, length)
|
||||
mask.fill_(float("-inf"))
|
||||
mask.triu_(1)
|
||||
return mask
|
||||
|
||||
def forward(self, input_ids, clip_skip=2, extra_mask=None):
|
||||
embeds = self.token_embedding(input_ids)
|
||||
embeds = embeds + self.position_embeds.to(dtype=embeds.dtype, device=input_ids.device)
|
||||
attn_mask = self.attn_mask.to(device=embeds.device, dtype=embeds.dtype)
|
||||
if extra_mask is not None:
|
||||
attn_mask[:, extra_mask[0]==0] = float("-inf")
|
||||
for encoder_id, encoder in enumerate(self.encoders):
|
||||
embeds = encoder(embeds, attn_mask=attn_mask)
|
||||
if encoder_id + clip_skip == len(self.encoders):
|
||||
hidden_states = embeds
|
||||
embeds = self.final_layer_norm(embeds)
|
||||
pooled_embeds = embeds[torch.arange(embeds.shape[0]), input_ids.to(dtype=torch.int).argmax(dim=-1)]
|
||||
return pooled_embeds, hidden_states
|
||||
@@ -1,43 +0,0 @@
|
||||
import torch
|
||||
from transformers import T5EncoderModel, T5Config
|
||||
|
||||
|
||||
class FluxTextEncoderT5(T5EncoderModel):
|
||||
def __init__(self):
|
||||
config = T5Config(**{
|
||||
"architectures": [
|
||||
"T5EncoderModel"
|
||||
],
|
||||
"classifier_dropout": 0.0,
|
||||
"d_ff": 10240,
|
||||
"d_kv": 64,
|
||||
"d_model": 4096,
|
||||
"decoder_start_token_id": 0,
|
||||
"dense_act_fn": "gelu_new",
|
||||
"dropout_rate": 0.1,
|
||||
"dtype": "bfloat16",
|
||||
"eos_token_id": 1,
|
||||
"feed_forward_proj": "gated-gelu",
|
||||
"initializer_factor": 1.0,
|
||||
"is_encoder_decoder": True,
|
||||
"is_gated_act": True,
|
||||
"layer_norm_epsilon": 1e-06,
|
||||
"model_type": "t5",
|
||||
"num_decoder_layers": 24,
|
||||
"num_heads": 64,
|
||||
"num_layers": 24,
|
||||
"output_past": True,
|
||||
"pad_token_id": 0,
|
||||
"relative_attention_max_distance": 128,
|
||||
"relative_attention_num_buckets": 32,
|
||||
"tie_word_embeddings": False,
|
||||
"transformers_version": "4.57.1",
|
||||
"use_cache": True,
|
||||
"vocab_size": 32128
|
||||
})
|
||||
super().__init__(config)
|
||||
|
||||
def forward(self, input_ids):
|
||||
outputs = super().forward(input_ids=input_ids)
|
||||
prompt_emb = outputs.last_hidden_state
|
||||
return prompt_emb
|
||||
@@ -1,451 +1,303 @@
|
||||
import torch
|
||||
from einops import rearrange, repeat
|
||||
from .sd3_vae_encoder import SD3VAEEncoder, SDVAEEncoderStateDictConverter
|
||||
from .sd3_vae_decoder import SD3VAEDecoder, SDVAEDecoderStateDictConverter
|
||||
|
||||
|
||||
class TileWorker:
|
||||
class FluxVAEEncoder(SD3VAEEncoder):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.scaling_factor = 0.3611
|
||||
self.shift_factor = 0.1159
|
||||
|
||||
@staticmethod
|
||||
def state_dict_converter():
|
||||
return FluxVAEEncoderStateDictConverter()
|
||||
|
||||
|
||||
class FluxVAEDecoder(SD3VAEDecoder):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.scaling_factor = 0.3611
|
||||
self.shift_factor = 0.1159
|
||||
|
||||
@staticmethod
|
||||
def state_dict_converter():
|
||||
return FluxVAEDecoderStateDictConverter()
|
||||
|
||||
|
||||
class FluxVAEEncoderStateDictConverter(SDVAEEncoderStateDictConverter):
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
|
||||
def mask(self, height, width, border_width):
|
||||
# Create a mask with shape (height, width).
|
||||
# The centre area is filled with 1, and the border line is filled with values in range (0, 1].
|
||||
x = torch.arange(height).repeat(width, 1).T
|
||||
y = torch.arange(width).repeat(height, 1)
|
||||
mask = torch.stack([x + 1, height - x, y + 1, width - y]).min(dim=0).values
|
||||
mask = (mask / border_width).clip(0, 1)
|
||||
return mask
|
||||
|
||||
|
||||
def tile(self, model_input, tile_size, tile_stride, tile_device, tile_dtype):
|
||||
# Convert a tensor (b, c, h, w) to (b, c, tile_size, tile_size, tile_num)
|
||||
batch_size, channel, _, _ = model_input.shape
|
||||
model_input = model_input.to(device=tile_device, dtype=tile_dtype)
|
||||
unfold_operator = torch.nn.Unfold(
|
||||
kernel_size=(tile_size, tile_size),
|
||||
stride=(tile_stride, tile_stride)
|
||||
)
|
||||
model_input = unfold_operator(model_input)
|
||||
model_input = model_input.view((batch_size, channel, tile_size, tile_size, -1))
|
||||
|
||||
return model_input
|
||||
|
||||
|
||||
def tiled_inference(self, forward_fn, model_input, tile_batch_size, inference_device, inference_dtype, tile_device, tile_dtype):
|
||||
# Call y=forward_fn(x) for each tile
|
||||
tile_num = model_input.shape[-1]
|
||||
model_output_stack = []
|
||||
|
||||
for tile_id in range(0, tile_num, tile_batch_size):
|
||||
|
||||
# process input
|
||||
tile_id_ = min(tile_id + tile_batch_size, tile_num)
|
||||
x = model_input[:, :, :, :, tile_id: tile_id_]
|
||||
x = x.to(device=inference_device, dtype=inference_dtype)
|
||||
x = rearrange(x, "b c h w n -> (n b) c h w")
|
||||
|
||||
# process output
|
||||
y = forward_fn(x)
|
||||
y = rearrange(y, "(n b) c h w -> b c h w n", n=tile_id_-tile_id)
|
||||
y = y.to(device=tile_device, dtype=tile_dtype)
|
||||
model_output_stack.append(y)
|
||||
|
||||
model_output = torch.concat(model_output_stack, dim=-1)
|
||||
return model_output
|
||||
|
||||
|
||||
def io_scale(self, model_output, tile_size):
|
||||
# Determine the size modification happened in forward_fn
|
||||
# We only consider the same scale on height and width.
|
||||
io_scale = model_output.shape[2] / tile_size
|
||||
return io_scale
|
||||
def from_civitai(self, state_dict):
|
||||
rename_dict = {
|
||||
"encoder.conv_in.bias": "conv_in.bias",
|
||||
"encoder.conv_in.weight": "conv_in.weight",
|
||||
"encoder.conv_out.bias": "conv_out.bias",
|
||||
"encoder.conv_out.weight": "conv_out.weight",
|
||||
"encoder.down.0.block.0.conv1.bias": "blocks.0.conv1.bias",
|
||||
"encoder.down.0.block.0.conv1.weight": "blocks.0.conv1.weight",
|
||||
"encoder.down.0.block.0.conv2.bias": "blocks.0.conv2.bias",
|
||||
"encoder.down.0.block.0.conv2.weight": "blocks.0.conv2.weight",
|
||||
"encoder.down.0.block.0.norm1.bias": "blocks.0.norm1.bias",
|
||||
"encoder.down.0.block.0.norm1.weight": "blocks.0.norm1.weight",
|
||||
"encoder.down.0.block.0.norm2.bias": "blocks.0.norm2.bias",
|
||||
"encoder.down.0.block.0.norm2.weight": "blocks.0.norm2.weight",
|
||||
"encoder.down.0.block.1.conv1.bias": "blocks.1.conv1.bias",
|
||||
"encoder.down.0.block.1.conv1.weight": "blocks.1.conv1.weight",
|
||||
"encoder.down.0.block.1.conv2.bias": "blocks.1.conv2.bias",
|
||||
"encoder.down.0.block.1.conv2.weight": "blocks.1.conv2.weight",
|
||||
"encoder.down.0.block.1.norm1.bias": "blocks.1.norm1.bias",
|
||||
"encoder.down.0.block.1.norm1.weight": "blocks.1.norm1.weight",
|
||||
"encoder.down.0.block.1.norm2.bias": "blocks.1.norm2.bias",
|
||||
"encoder.down.0.block.1.norm2.weight": "blocks.1.norm2.weight",
|
||||
"encoder.down.0.downsample.conv.bias": "blocks.2.conv.bias",
|
||||
"encoder.down.0.downsample.conv.weight": "blocks.2.conv.weight",
|
||||
"encoder.down.1.block.0.conv1.bias": "blocks.3.conv1.bias",
|
||||
"encoder.down.1.block.0.conv1.weight": "blocks.3.conv1.weight",
|
||||
"encoder.down.1.block.0.conv2.bias": "blocks.3.conv2.bias",
|
||||
"encoder.down.1.block.0.conv2.weight": "blocks.3.conv2.weight",
|
||||
"encoder.down.1.block.0.nin_shortcut.bias": "blocks.3.conv_shortcut.bias",
|
||||
"encoder.down.1.block.0.nin_shortcut.weight": "blocks.3.conv_shortcut.weight",
|
||||
"encoder.down.1.block.0.norm1.bias": "blocks.3.norm1.bias",
|
||||
"encoder.down.1.block.0.norm1.weight": "blocks.3.norm1.weight",
|
||||
"encoder.down.1.block.0.norm2.bias": "blocks.3.norm2.bias",
|
||||
"encoder.down.1.block.0.norm2.weight": "blocks.3.norm2.weight",
|
||||
"encoder.down.1.block.1.conv1.bias": "blocks.4.conv1.bias",
|
||||
"encoder.down.1.block.1.conv1.weight": "blocks.4.conv1.weight",
|
||||
"encoder.down.1.block.1.conv2.bias": "blocks.4.conv2.bias",
|
||||
"encoder.down.1.block.1.conv2.weight": "blocks.4.conv2.weight",
|
||||
"encoder.down.1.block.1.norm1.bias": "blocks.4.norm1.bias",
|
||||
"encoder.down.1.block.1.norm1.weight": "blocks.4.norm1.weight",
|
||||
"encoder.down.1.block.1.norm2.bias": "blocks.4.norm2.bias",
|
||||
"encoder.down.1.block.1.norm2.weight": "blocks.4.norm2.weight",
|
||||
"encoder.down.1.downsample.conv.bias": "blocks.5.conv.bias",
|
||||
"encoder.down.1.downsample.conv.weight": "blocks.5.conv.weight",
|
||||
"encoder.down.2.block.0.conv1.bias": "blocks.6.conv1.bias",
|
||||
"encoder.down.2.block.0.conv1.weight": "blocks.6.conv1.weight",
|
||||
"encoder.down.2.block.0.conv2.bias": "blocks.6.conv2.bias",
|
||||
"encoder.down.2.block.0.conv2.weight": "blocks.6.conv2.weight",
|
||||
"encoder.down.2.block.0.nin_shortcut.bias": "blocks.6.conv_shortcut.bias",
|
||||
"encoder.down.2.block.0.nin_shortcut.weight": "blocks.6.conv_shortcut.weight",
|
||||
"encoder.down.2.block.0.norm1.bias": "blocks.6.norm1.bias",
|
||||
"encoder.down.2.block.0.norm1.weight": "blocks.6.norm1.weight",
|
||||
"encoder.down.2.block.0.norm2.bias": "blocks.6.norm2.bias",
|
||||
"encoder.down.2.block.0.norm2.weight": "blocks.6.norm2.weight",
|
||||
"encoder.down.2.block.1.conv1.bias": "blocks.7.conv1.bias",
|
||||
"encoder.down.2.block.1.conv1.weight": "blocks.7.conv1.weight",
|
||||
"encoder.down.2.block.1.conv2.bias": "blocks.7.conv2.bias",
|
||||
"encoder.down.2.block.1.conv2.weight": "blocks.7.conv2.weight",
|
||||
"encoder.down.2.block.1.norm1.bias": "blocks.7.norm1.bias",
|
||||
"encoder.down.2.block.1.norm1.weight": "blocks.7.norm1.weight",
|
||||
"encoder.down.2.block.1.norm2.bias": "blocks.7.norm2.bias",
|
||||
"encoder.down.2.block.1.norm2.weight": "blocks.7.norm2.weight",
|
||||
"encoder.down.2.downsample.conv.bias": "blocks.8.conv.bias",
|
||||
"encoder.down.2.downsample.conv.weight": "blocks.8.conv.weight",
|
||||
"encoder.down.3.block.0.conv1.bias": "blocks.9.conv1.bias",
|
||||
"encoder.down.3.block.0.conv1.weight": "blocks.9.conv1.weight",
|
||||
"encoder.down.3.block.0.conv2.bias": "blocks.9.conv2.bias",
|
||||
"encoder.down.3.block.0.conv2.weight": "blocks.9.conv2.weight",
|
||||
"encoder.down.3.block.0.norm1.bias": "blocks.9.norm1.bias",
|
||||
"encoder.down.3.block.0.norm1.weight": "blocks.9.norm1.weight",
|
||||
"encoder.down.3.block.0.norm2.bias": "blocks.9.norm2.bias",
|
||||
"encoder.down.3.block.0.norm2.weight": "blocks.9.norm2.weight",
|
||||
"encoder.down.3.block.1.conv1.bias": "blocks.10.conv1.bias",
|
||||
"encoder.down.3.block.1.conv1.weight": "blocks.10.conv1.weight",
|
||||
"encoder.down.3.block.1.conv2.bias": "blocks.10.conv2.bias",
|
||||
"encoder.down.3.block.1.conv2.weight": "blocks.10.conv2.weight",
|
||||
"encoder.down.3.block.1.norm1.bias": "blocks.10.norm1.bias",
|
||||
"encoder.down.3.block.1.norm1.weight": "blocks.10.norm1.weight",
|
||||
"encoder.down.3.block.1.norm2.bias": "blocks.10.norm2.bias",
|
||||
"encoder.down.3.block.1.norm2.weight": "blocks.10.norm2.weight",
|
||||
"encoder.mid.attn_1.k.bias": "blocks.12.transformer_blocks.0.to_k.bias",
|
||||
"encoder.mid.attn_1.k.weight": "blocks.12.transformer_blocks.0.to_k.weight",
|
||||
"encoder.mid.attn_1.norm.bias": "blocks.12.norm.bias",
|
||||
"encoder.mid.attn_1.norm.weight": "blocks.12.norm.weight",
|
||||
"encoder.mid.attn_1.proj_out.bias": "blocks.12.transformer_blocks.0.to_out.bias",
|
||||
"encoder.mid.attn_1.proj_out.weight": "blocks.12.transformer_blocks.0.to_out.weight",
|
||||
"encoder.mid.attn_1.q.bias": "blocks.12.transformer_blocks.0.to_q.bias",
|
||||
"encoder.mid.attn_1.q.weight": "blocks.12.transformer_blocks.0.to_q.weight",
|
||||
"encoder.mid.attn_1.v.bias": "blocks.12.transformer_blocks.0.to_v.bias",
|
||||
"encoder.mid.attn_1.v.weight": "blocks.12.transformer_blocks.0.to_v.weight",
|
||||
"encoder.mid.block_1.conv1.bias": "blocks.11.conv1.bias",
|
||||
"encoder.mid.block_1.conv1.weight": "blocks.11.conv1.weight",
|
||||
"encoder.mid.block_1.conv2.bias": "blocks.11.conv2.bias",
|
||||
"encoder.mid.block_1.conv2.weight": "blocks.11.conv2.weight",
|
||||
"encoder.mid.block_1.norm1.bias": "blocks.11.norm1.bias",
|
||||
"encoder.mid.block_1.norm1.weight": "blocks.11.norm1.weight",
|
||||
"encoder.mid.block_1.norm2.bias": "blocks.11.norm2.bias",
|
||||
"encoder.mid.block_1.norm2.weight": "blocks.11.norm2.weight",
|
||||
"encoder.mid.block_2.conv1.bias": "blocks.13.conv1.bias",
|
||||
"encoder.mid.block_2.conv1.weight": "blocks.13.conv1.weight",
|
||||
"encoder.mid.block_2.conv2.bias": "blocks.13.conv2.bias",
|
||||
"encoder.mid.block_2.conv2.weight": "blocks.13.conv2.weight",
|
||||
"encoder.mid.block_2.norm1.bias": "blocks.13.norm1.bias",
|
||||
"encoder.mid.block_2.norm1.weight": "blocks.13.norm1.weight",
|
||||
"encoder.mid.block_2.norm2.bias": "blocks.13.norm2.bias",
|
||||
"encoder.mid.block_2.norm2.weight": "blocks.13.norm2.weight",
|
||||
"encoder.norm_out.bias": "conv_norm_out.bias",
|
||||
"encoder.norm_out.weight": "conv_norm_out.weight",
|
||||
}
|
||||
state_dict_ = {}
|
||||
for name in state_dict:
|
||||
if name in rename_dict:
|
||||
param = state_dict[name]
|
||||
if "transformer_blocks" in rename_dict[name]:
|
||||
param = param.squeeze()
|
||||
state_dict_[rename_dict[name]] = param
|
||||
return state_dict_
|
||||
|
||||
|
||||
def untile(self, model_output, height, width, tile_size, tile_stride, border_width, tile_device, tile_dtype):
|
||||
# The reversed function of tile
|
||||
mask = self.mask(tile_size, tile_size, border_width)
|
||||
mask = mask.to(device=tile_device, dtype=tile_dtype)
|
||||
mask = rearrange(mask, "h w -> 1 1 h w 1")
|
||||
model_output = model_output * mask
|
||||
|
||||
fold_operator = torch.nn.Fold(
|
||||
output_size=(height, width),
|
||||
kernel_size=(tile_size, tile_size),
|
||||
stride=(tile_stride, tile_stride)
|
||||
)
|
||||
mask = repeat(mask[0, 0, :, :, 0], "h w -> 1 (h w) n", n=model_output.shape[-1])
|
||||
model_output = rearrange(model_output, "b c h w n -> b (c h w) n")
|
||||
model_output = fold_operator(model_output) / fold_operator(mask)
|
||||
|
||||
return model_output
|
||||
|
||||
|
||||
def tiled_forward(self, forward_fn, model_input, tile_size, tile_stride, tile_batch_size=1, tile_device="cpu", tile_dtype=torch.float32, border_width=None):
|
||||
# Prepare
|
||||
inference_device, inference_dtype = model_input.device, model_input.dtype
|
||||
height, width = model_input.shape[2], model_input.shape[3]
|
||||
border_width = int(tile_stride*0.5) if border_width is None else border_width
|
||||
|
||||
# tile
|
||||
model_input = self.tile(model_input, tile_size, tile_stride, tile_device, tile_dtype)
|
||||
|
||||
# inference
|
||||
model_output = self.tiled_inference(forward_fn, model_input, tile_batch_size, inference_device, inference_dtype, tile_device, tile_dtype)
|
||||
|
||||
# resize
|
||||
io_scale = self.io_scale(model_output, tile_size)
|
||||
height, width = int(height*io_scale), int(width*io_scale)
|
||||
tile_size, tile_stride = int(tile_size*io_scale), int(tile_stride*io_scale)
|
||||
border_width = int(border_width*io_scale)
|
||||
|
||||
# untile
|
||||
model_output = self.untile(model_output, height, width, tile_size, tile_stride, border_width, tile_device, tile_dtype)
|
||||
|
||||
# Done!
|
||||
model_output = model_output.to(device=inference_device, dtype=inference_dtype)
|
||||
return model_output
|
||||
|
||||
|
||||
class ConvAttention(torch.nn.Module):
|
||||
|
||||
def __init__(self, q_dim, num_heads, head_dim, kv_dim=None, bias_q=False, bias_kv=False, bias_out=False):
|
||||
super().__init__()
|
||||
dim_inner = head_dim * num_heads
|
||||
kv_dim = kv_dim if kv_dim is not None else q_dim
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = head_dim
|
||||
|
||||
self.to_q = torch.nn.Conv2d(q_dim, dim_inner, kernel_size=(1, 1), bias=bias_q)
|
||||
self.to_k = torch.nn.Conv2d(kv_dim, dim_inner, kernel_size=(1, 1), bias=bias_kv)
|
||||
self.to_v = torch.nn.Conv2d(kv_dim, dim_inner, kernel_size=(1, 1), bias=bias_kv)
|
||||
self.to_out = torch.nn.Conv2d(dim_inner, q_dim, kernel_size=(1, 1), bias=bias_out)
|
||||
|
||||
def forward(self, hidden_states, encoder_hidden_states=None, attn_mask=None):
|
||||
if encoder_hidden_states is None:
|
||||
encoder_hidden_states = hidden_states
|
||||
|
||||
batch_size = encoder_hidden_states.shape[0]
|
||||
|
||||
conv_input = rearrange(hidden_states, "B L C -> B C L 1")
|
||||
q = self.to_q(conv_input)
|
||||
q = rearrange(q[:, :, :, 0], "B C L -> B L C")
|
||||
conv_input = rearrange(encoder_hidden_states, "B L C -> B C L 1")
|
||||
k = self.to_k(conv_input)
|
||||
v = self.to_v(conv_input)
|
||||
k = rearrange(k[:, :, :, 0], "B C L -> B L C")
|
||||
v = rearrange(v[:, :, :, 0], "B C L -> B L C")
|
||||
|
||||
q = q.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
k = k.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
v = v.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
|
||||
hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
|
||||
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim)
|
||||
hidden_states = hidden_states.to(q.dtype)
|
||||
|
||||
conv_input = rearrange(hidden_states, "B L C -> B C L 1")
|
||||
hidden_states = self.to_out(conv_input)
|
||||
hidden_states = rearrange(hidden_states[:, :, :, 0], "B C L -> B L C")
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class Attention(torch.nn.Module):
|
||||
|
||||
def __init__(self, q_dim, num_heads, head_dim, kv_dim=None, bias_q=False, bias_kv=False, bias_out=False):
|
||||
super().__init__()
|
||||
dim_inner = head_dim * num_heads
|
||||
kv_dim = kv_dim if kv_dim is not None else q_dim
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = head_dim
|
||||
|
||||
self.to_q = torch.nn.Linear(q_dim, dim_inner, bias=bias_q)
|
||||
self.to_k = torch.nn.Linear(kv_dim, dim_inner, bias=bias_kv)
|
||||
self.to_v = torch.nn.Linear(kv_dim, dim_inner, bias=bias_kv)
|
||||
self.to_out = torch.nn.Linear(dim_inner, q_dim, bias=bias_out)
|
||||
|
||||
def forward(self, hidden_states, encoder_hidden_states=None, attn_mask=None):
|
||||
if encoder_hidden_states is None:
|
||||
encoder_hidden_states = hidden_states
|
||||
|
||||
batch_size = encoder_hidden_states.shape[0]
|
||||
|
||||
q = self.to_q(hidden_states)
|
||||
k = self.to_k(encoder_hidden_states)
|
||||
v = self.to_v(encoder_hidden_states)
|
||||
|
||||
q = q.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
k = k.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
v = v.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
|
||||
hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
|
||||
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim)
|
||||
hidden_states = hidden_states.to(q.dtype)
|
||||
|
||||
hidden_states = self.to_out(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class VAEAttentionBlock(torch.nn.Module):
|
||||
|
||||
def __init__(self, num_attention_heads, attention_head_dim, in_channels, num_layers=1, norm_num_groups=32, eps=1e-5, use_conv_attention=True):
|
||||
super().__init__()
|
||||
inner_dim = num_attention_heads * attention_head_dim
|
||||
|
||||
self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=eps, affine=True)
|
||||
|
||||
if use_conv_attention:
|
||||
self.transformer_blocks = torch.nn.ModuleList([
|
||||
ConvAttention(
|
||||
inner_dim,
|
||||
num_attention_heads,
|
||||
attention_head_dim,
|
||||
bias_q=True,
|
||||
bias_kv=True,
|
||||
bias_out=True
|
||||
)
|
||||
for d in range(num_layers)
|
||||
])
|
||||
else:
|
||||
self.transformer_blocks = torch.nn.ModuleList([
|
||||
Attention(
|
||||
inner_dim,
|
||||
num_attention_heads,
|
||||
attention_head_dim,
|
||||
bias_q=True,
|
||||
bias_kv=True,
|
||||
bias_out=True
|
||||
)
|
||||
for d in range(num_layers)
|
||||
])
|
||||
|
||||
def forward(self, hidden_states, time_emb, text_emb, res_stack):
|
||||
batch, _, height, width = hidden_states.shape
|
||||
residual = hidden_states
|
||||
|
||||
hidden_states = self.norm(hidden_states)
|
||||
inner_dim = hidden_states.shape[1]
|
||||
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
|
||||
|
||||
for block in self.transformer_blocks:
|
||||
hidden_states = block(hidden_states)
|
||||
|
||||
hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
|
||||
hidden_states = hidden_states + residual
|
||||
|
||||
return hidden_states, time_emb, text_emb, res_stack
|
||||
|
||||
|
||||
class ResnetBlock(torch.nn.Module):
|
||||
def __init__(self, in_channels, out_channels, temb_channels=None, groups=32, eps=1e-5):
|
||||
super().__init__()
|
||||
self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
|
||||
self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
||||
if temb_channels is not None:
|
||||
self.time_emb_proj = torch.nn.Linear(temb_channels, out_channels)
|
||||
self.norm2 = torch.nn.GroupNorm(num_groups=groups, num_channels=out_channels, eps=eps, affine=True)
|
||||
self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
||||
self.nonlinearity = torch.nn.SiLU()
|
||||
self.conv_shortcut = None
|
||||
if in_channels != out_channels:
|
||||
self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=True)
|
||||
|
||||
def forward(self, hidden_states, time_emb, text_emb, res_stack, **kwargs):
|
||||
x = hidden_states
|
||||
x = self.norm1(x)
|
||||
x = self.nonlinearity(x)
|
||||
x = self.conv1(x)
|
||||
if time_emb is not None:
|
||||
emb = self.nonlinearity(time_emb)
|
||||
emb = self.time_emb_proj(emb)[:, :, None, None]
|
||||
x = x + emb
|
||||
x = self.norm2(x)
|
||||
x = self.nonlinearity(x)
|
||||
x = self.conv2(x)
|
||||
if self.conv_shortcut is not None:
|
||||
hidden_states = self.conv_shortcut(hidden_states)
|
||||
hidden_states = hidden_states + x
|
||||
return hidden_states, time_emb, text_emb, res_stack
|
||||
|
||||
|
||||
class UpSampler(torch.nn.Module):
|
||||
def __init__(self, channels):
|
||||
super().__init__()
|
||||
self.conv = torch.nn.Conv2d(channels, channels, 3, padding=1)
|
||||
|
||||
def forward(self, hidden_states, time_emb, text_emb, res_stack, **kwargs):
|
||||
hidden_states = torch.nn.functional.interpolate(hidden_states, scale_factor=2.0, mode="nearest")
|
||||
hidden_states = self.conv(hidden_states)
|
||||
return hidden_states, time_emb, text_emb, res_stack
|
||||
|
||||
|
||||
class DownSampler(torch.nn.Module):
|
||||
def __init__(self, channels, padding=1, extra_padding=False):
|
||||
super().__init__()
|
||||
self.conv = torch.nn.Conv2d(channels, channels, 3, stride=2, padding=padding)
|
||||
self.extra_padding = extra_padding
|
||||
|
||||
def forward(self, hidden_states, time_emb, text_emb, res_stack, **kwargs):
|
||||
if self.extra_padding:
|
||||
hidden_states = torch.nn.functional.pad(hidden_states, (0, 1, 0, 1), mode="constant", value=0)
|
||||
hidden_states = self.conv(hidden_states)
|
||||
return hidden_states, time_emb, text_emb, res_stack
|
||||
|
||||
|
||||
class FluxVAEDecoder(torch.nn.Module):
|
||||
def __init__(self, use_conv_attention=True):
|
||||
super().__init__()
|
||||
self.scaling_factor = 0.3611
|
||||
self.shift_factor = 0.1159
|
||||
self.conv_in = torch.nn.Conv2d(16, 512, kernel_size=3, padding=1) # Different from SD 1.x
|
||||
|
||||
self.blocks = torch.nn.ModuleList([
|
||||
# UNetMidBlock2D
|
||||
ResnetBlock(512, 512, eps=1e-6),
|
||||
VAEAttentionBlock(1, 512, 512, 1, eps=1e-6, use_conv_attention=use_conv_attention),
|
||||
ResnetBlock(512, 512, eps=1e-6),
|
||||
# UpDecoderBlock2D
|
||||
ResnetBlock(512, 512, eps=1e-6),
|
||||
ResnetBlock(512, 512, eps=1e-6),
|
||||
ResnetBlock(512, 512, eps=1e-6),
|
||||
UpSampler(512),
|
||||
# UpDecoderBlock2D
|
||||
ResnetBlock(512, 512, eps=1e-6),
|
||||
ResnetBlock(512, 512, eps=1e-6),
|
||||
ResnetBlock(512, 512, eps=1e-6),
|
||||
UpSampler(512),
|
||||
# UpDecoderBlock2D
|
||||
ResnetBlock(512, 256, eps=1e-6),
|
||||
ResnetBlock(256, 256, eps=1e-6),
|
||||
ResnetBlock(256, 256, eps=1e-6),
|
||||
UpSampler(256),
|
||||
# UpDecoderBlock2D
|
||||
ResnetBlock(256, 128, eps=1e-6),
|
||||
ResnetBlock(128, 128, eps=1e-6),
|
||||
ResnetBlock(128, 128, eps=1e-6),
|
||||
])
|
||||
|
||||
self.conv_norm_out = torch.nn.GroupNorm(num_channels=128, num_groups=32, eps=1e-6)
|
||||
self.conv_act = torch.nn.SiLU()
|
||||
self.conv_out = torch.nn.Conv2d(128, 3, kernel_size=3, padding=1)
|
||||
|
||||
def tiled_forward(self, sample, tile_size=64, tile_stride=32):
|
||||
hidden_states = TileWorker().tiled_forward(
|
||||
lambda x: self.forward(x),
|
||||
sample,
|
||||
tile_size,
|
||||
tile_stride,
|
||||
tile_device=sample.device,
|
||||
tile_dtype=sample.dtype
|
||||
)
|
||||
return hidden_states
|
||||
|
||||
def forward(self, sample, tiled=False, tile_size=64, tile_stride=32, **kwargs):
|
||||
# For VAE Decoder, we do not need to apply the tiler on each layer.
|
||||
if tiled:
|
||||
return self.tiled_forward(sample, tile_size=tile_size, tile_stride=tile_stride)
|
||||
|
||||
# 1. pre-process
|
||||
hidden_states = sample / self.scaling_factor + self.shift_factor
|
||||
hidden_states = self.conv_in(hidden_states)
|
||||
time_emb = None
|
||||
text_emb = None
|
||||
res_stack = None
|
||||
|
||||
# 2. blocks
|
||||
for i, block in enumerate(self.blocks):
|
||||
hidden_states, time_emb, text_emb, res_stack = block(hidden_states, time_emb, text_emb, res_stack)
|
||||
|
||||
# 3. output
|
||||
hidden_states = self.conv_norm_out(hidden_states)
|
||||
hidden_states = self.conv_act(hidden_states)
|
||||
hidden_states = self.conv_out(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class FluxVAEEncoder(torch.nn.Module):
|
||||
def __init__(self, use_conv_attention=True):
|
||||
super().__init__()
|
||||
self.scaling_factor = 0.3611
|
||||
self.shift_factor = 0.1159
|
||||
self.conv_in = torch.nn.Conv2d(3, 128, kernel_size=3, padding=1)
|
||||
|
||||
self.blocks = torch.nn.ModuleList([
|
||||
# DownEncoderBlock2D
|
||||
ResnetBlock(128, 128, eps=1e-6),
|
||||
ResnetBlock(128, 128, eps=1e-6),
|
||||
DownSampler(128, padding=0, extra_padding=True),
|
||||
# DownEncoderBlock2D
|
||||
ResnetBlock(128, 256, eps=1e-6),
|
||||
ResnetBlock(256, 256, eps=1e-6),
|
||||
DownSampler(256, padding=0, extra_padding=True),
|
||||
# DownEncoderBlock2D
|
||||
ResnetBlock(256, 512, eps=1e-6),
|
||||
ResnetBlock(512, 512, eps=1e-6),
|
||||
DownSampler(512, padding=0, extra_padding=True),
|
||||
# DownEncoderBlock2D
|
||||
ResnetBlock(512, 512, eps=1e-6),
|
||||
ResnetBlock(512, 512, eps=1e-6),
|
||||
# UNetMidBlock2D
|
||||
ResnetBlock(512, 512, eps=1e-6),
|
||||
VAEAttentionBlock(1, 512, 512, 1, eps=1e-6, use_conv_attention=use_conv_attention),
|
||||
ResnetBlock(512, 512, eps=1e-6),
|
||||
])
|
||||
|
||||
self.conv_norm_out = torch.nn.GroupNorm(num_channels=512, num_groups=32, eps=1e-6)
|
||||
self.conv_act = torch.nn.SiLU()
|
||||
self.conv_out = torch.nn.Conv2d(512, 32, kernel_size=3, padding=1)
|
||||
|
||||
def tiled_forward(self, sample, tile_size=64, tile_stride=32):
|
||||
hidden_states = TileWorker().tiled_forward(
|
||||
lambda x: self.forward(x),
|
||||
sample,
|
||||
tile_size,
|
||||
tile_stride,
|
||||
tile_device=sample.device,
|
||||
tile_dtype=sample.dtype
|
||||
)
|
||||
return hidden_states
|
||||
|
||||
def forward(self, sample, tiled=False, tile_size=64, tile_stride=32, **kwargs):
|
||||
# For VAE Decoder, we do not need to apply the tiler on each layer.
|
||||
if tiled:
|
||||
return self.tiled_forward(sample, tile_size=tile_size, tile_stride=tile_stride)
|
||||
|
||||
# 1. pre-process
|
||||
hidden_states = self.conv_in(sample)
|
||||
time_emb = None
|
||||
text_emb = None
|
||||
res_stack = None
|
||||
|
||||
# 2. blocks
|
||||
for i, block in enumerate(self.blocks):
|
||||
hidden_states, time_emb, text_emb, res_stack = block(hidden_states, time_emb, text_emb, res_stack)
|
||||
|
||||
# 3. output
|
||||
hidden_states = self.conv_norm_out(hidden_states)
|
||||
hidden_states = self.conv_act(hidden_states)
|
||||
hidden_states = self.conv_out(hidden_states)
|
||||
hidden_states = hidden_states[:, :16]
|
||||
hidden_states = (hidden_states - self.shift_factor) * self.scaling_factor
|
||||
|
||||
return hidden_states
|
||||
|
||||
def encode_video(self, sample, batch_size=8):
|
||||
B = sample.shape[0]
|
||||
hidden_states = []
|
||||
|
||||
for i in range(0, sample.shape[2], batch_size):
|
||||
|
||||
j = min(i + batch_size, sample.shape[2])
|
||||
sample_batch = rearrange(sample[:,:,i:j], "B C T H W -> (B T) C H W")
|
||||
|
||||
hidden_states_batch = self(sample_batch)
|
||||
hidden_states_batch = rearrange(hidden_states_batch, "(B T) C H W -> B C T H W", B=B)
|
||||
|
||||
hidden_states.append(hidden_states_batch)
|
||||
|
||||
hidden_states = torch.concat(hidden_states, dim=2)
|
||||
return hidden_states
|
||||
class FluxVAEDecoderStateDictConverter(SDVAEDecoderStateDictConverter):
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def from_civitai(self, state_dict):
|
||||
rename_dict = {
|
||||
"decoder.conv_in.bias": "conv_in.bias",
|
||||
"decoder.conv_in.weight": "conv_in.weight",
|
||||
"decoder.conv_out.bias": "conv_out.bias",
|
||||
"decoder.conv_out.weight": "conv_out.weight",
|
||||
"decoder.mid.attn_1.k.bias": "blocks.1.transformer_blocks.0.to_k.bias",
|
||||
"decoder.mid.attn_1.k.weight": "blocks.1.transformer_blocks.0.to_k.weight",
|
||||
"decoder.mid.attn_1.norm.bias": "blocks.1.norm.bias",
|
||||
"decoder.mid.attn_1.norm.weight": "blocks.1.norm.weight",
|
||||
"decoder.mid.attn_1.proj_out.bias": "blocks.1.transformer_blocks.0.to_out.bias",
|
||||
"decoder.mid.attn_1.proj_out.weight": "blocks.1.transformer_blocks.0.to_out.weight",
|
||||
"decoder.mid.attn_1.q.bias": "blocks.1.transformer_blocks.0.to_q.bias",
|
||||
"decoder.mid.attn_1.q.weight": "blocks.1.transformer_blocks.0.to_q.weight",
|
||||
"decoder.mid.attn_1.v.bias": "blocks.1.transformer_blocks.0.to_v.bias",
|
||||
"decoder.mid.attn_1.v.weight": "blocks.1.transformer_blocks.0.to_v.weight",
|
||||
"decoder.mid.block_1.conv1.bias": "blocks.0.conv1.bias",
|
||||
"decoder.mid.block_1.conv1.weight": "blocks.0.conv1.weight",
|
||||
"decoder.mid.block_1.conv2.bias": "blocks.0.conv2.bias",
|
||||
"decoder.mid.block_1.conv2.weight": "blocks.0.conv2.weight",
|
||||
"decoder.mid.block_1.norm1.bias": "blocks.0.norm1.bias",
|
||||
"decoder.mid.block_1.norm1.weight": "blocks.0.norm1.weight",
|
||||
"decoder.mid.block_1.norm2.bias": "blocks.0.norm2.bias",
|
||||
"decoder.mid.block_1.norm2.weight": "blocks.0.norm2.weight",
|
||||
"decoder.mid.block_2.conv1.bias": "blocks.2.conv1.bias",
|
||||
"decoder.mid.block_2.conv1.weight": "blocks.2.conv1.weight",
|
||||
"decoder.mid.block_2.conv2.bias": "blocks.2.conv2.bias",
|
||||
"decoder.mid.block_2.conv2.weight": "blocks.2.conv2.weight",
|
||||
"decoder.mid.block_2.norm1.bias": "blocks.2.norm1.bias",
|
||||
"decoder.mid.block_2.norm1.weight": "blocks.2.norm1.weight",
|
||||
"decoder.mid.block_2.norm2.bias": "blocks.2.norm2.bias",
|
||||
"decoder.mid.block_2.norm2.weight": "blocks.2.norm2.weight",
|
||||
"decoder.norm_out.bias": "conv_norm_out.bias",
|
||||
"decoder.norm_out.weight": "conv_norm_out.weight",
|
||||
"decoder.up.0.block.0.conv1.bias": "blocks.15.conv1.bias",
|
||||
"decoder.up.0.block.0.conv1.weight": "blocks.15.conv1.weight",
|
||||
"decoder.up.0.block.0.conv2.bias": "blocks.15.conv2.bias",
|
||||
"decoder.up.0.block.0.conv2.weight": "blocks.15.conv2.weight",
|
||||
"decoder.up.0.block.0.nin_shortcut.bias": "blocks.15.conv_shortcut.bias",
|
||||
"decoder.up.0.block.0.nin_shortcut.weight": "blocks.15.conv_shortcut.weight",
|
||||
"decoder.up.0.block.0.norm1.bias": "blocks.15.norm1.bias",
|
||||
"decoder.up.0.block.0.norm1.weight": "blocks.15.norm1.weight",
|
||||
"decoder.up.0.block.0.norm2.bias": "blocks.15.norm2.bias",
|
||||
"decoder.up.0.block.0.norm2.weight": "blocks.15.norm2.weight",
|
||||
"decoder.up.0.block.1.conv1.bias": "blocks.16.conv1.bias",
|
||||
"decoder.up.0.block.1.conv1.weight": "blocks.16.conv1.weight",
|
||||
"decoder.up.0.block.1.conv2.bias": "blocks.16.conv2.bias",
|
||||
"decoder.up.0.block.1.conv2.weight": "blocks.16.conv2.weight",
|
||||
"decoder.up.0.block.1.norm1.bias": "blocks.16.norm1.bias",
|
||||
"decoder.up.0.block.1.norm1.weight": "blocks.16.norm1.weight",
|
||||
"decoder.up.0.block.1.norm2.bias": "blocks.16.norm2.bias",
|
||||
"decoder.up.0.block.1.norm2.weight": "blocks.16.norm2.weight",
|
||||
"decoder.up.0.block.2.conv1.bias": "blocks.17.conv1.bias",
|
||||
"decoder.up.0.block.2.conv1.weight": "blocks.17.conv1.weight",
|
||||
"decoder.up.0.block.2.conv2.bias": "blocks.17.conv2.bias",
|
||||
"decoder.up.0.block.2.conv2.weight": "blocks.17.conv2.weight",
|
||||
"decoder.up.0.block.2.norm1.bias": "blocks.17.norm1.bias",
|
||||
"decoder.up.0.block.2.norm1.weight": "blocks.17.norm1.weight",
|
||||
"decoder.up.0.block.2.norm2.bias": "blocks.17.norm2.bias",
|
||||
"decoder.up.0.block.2.norm2.weight": "blocks.17.norm2.weight",
|
||||
"decoder.up.1.block.0.conv1.bias": "blocks.11.conv1.bias",
|
||||
"decoder.up.1.block.0.conv1.weight": "blocks.11.conv1.weight",
|
||||
"decoder.up.1.block.0.conv2.bias": "blocks.11.conv2.bias",
|
||||
"decoder.up.1.block.0.conv2.weight": "blocks.11.conv2.weight",
|
||||
"decoder.up.1.block.0.nin_shortcut.bias": "blocks.11.conv_shortcut.bias",
|
||||
"decoder.up.1.block.0.nin_shortcut.weight": "blocks.11.conv_shortcut.weight",
|
||||
"decoder.up.1.block.0.norm1.bias": "blocks.11.norm1.bias",
|
||||
"decoder.up.1.block.0.norm1.weight": "blocks.11.norm1.weight",
|
||||
"decoder.up.1.block.0.norm2.bias": "blocks.11.norm2.bias",
|
||||
"decoder.up.1.block.0.norm2.weight": "blocks.11.norm2.weight",
|
||||
"decoder.up.1.block.1.conv1.bias": "blocks.12.conv1.bias",
|
||||
"decoder.up.1.block.1.conv1.weight": "blocks.12.conv1.weight",
|
||||
"decoder.up.1.block.1.conv2.bias": "blocks.12.conv2.bias",
|
||||
"decoder.up.1.block.1.conv2.weight": "blocks.12.conv2.weight",
|
||||
"decoder.up.1.block.1.norm1.bias": "blocks.12.norm1.bias",
|
||||
"decoder.up.1.block.1.norm1.weight": "blocks.12.norm1.weight",
|
||||
"decoder.up.1.block.1.norm2.bias": "blocks.12.norm2.bias",
|
||||
"decoder.up.1.block.1.norm2.weight": "blocks.12.norm2.weight",
|
||||
"decoder.up.1.block.2.conv1.bias": "blocks.13.conv1.bias",
|
||||
"decoder.up.1.block.2.conv1.weight": "blocks.13.conv1.weight",
|
||||
"decoder.up.1.block.2.conv2.bias": "blocks.13.conv2.bias",
|
||||
"decoder.up.1.block.2.conv2.weight": "blocks.13.conv2.weight",
|
||||
"decoder.up.1.block.2.norm1.bias": "blocks.13.norm1.bias",
|
||||
"decoder.up.1.block.2.norm1.weight": "blocks.13.norm1.weight",
|
||||
"decoder.up.1.block.2.norm2.bias": "blocks.13.norm2.bias",
|
||||
"decoder.up.1.block.2.norm2.weight": "blocks.13.norm2.weight",
|
||||
"decoder.up.1.upsample.conv.bias": "blocks.14.conv.bias",
|
||||
"decoder.up.1.upsample.conv.weight": "blocks.14.conv.weight",
|
||||
"decoder.up.2.block.0.conv1.bias": "blocks.7.conv1.bias",
|
||||
"decoder.up.2.block.0.conv1.weight": "blocks.7.conv1.weight",
|
||||
"decoder.up.2.block.0.conv2.bias": "blocks.7.conv2.bias",
|
||||
"decoder.up.2.block.0.conv2.weight": "blocks.7.conv2.weight",
|
||||
"decoder.up.2.block.0.norm1.bias": "blocks.7.norm1.bias",
|
||||
"decoder.up.2.block.0.norm1.weight": "blocks.7.norm1.weight",
|
||||
"decoder.up.2.block.0.norm2.bias": "blocks.7.norm2.bias",
|
||||
"decoder.up.2.block.0.norm2.weight": "blocks.7.norm2.weight",
|
||||
"decoder.up.2.block.1.conv1.bias": "blocks.8.conv1.bias",
|
||||
"decoder.up.2.block.1.conv1.weight": "blocks.8.conv1.weight",
|
||||
"decoder.up.2.block.1.conv2.bias": "blocks.8.conv2.bias",
|
||||
"decoder.up.2.block.1.conv2.weight": "blocks.8.conv2.weight",
|
||||
"decoder.up.2.block.1.norm1.bias": "blocks.8.norm1.bias",
|
||||
"decoder.up.2.block.1.norm1.weight": "blocks.8.norm1.weight",
|
||||
"decoder.up.2.block.1.norm2.bias": "blocks.8.norm2.bias",
|
||||
"decoder.up.2.block.1.norm2.weight": "blocks.8.norm2.weight",
|
||||
"decoder.up.2.block.2.conv1.bias": "blocks.9.conv1.bias",
|
||||
"decoder.up.2.block.2.conv1.weight": "blocks.9.conv1.weight",
|
||||
"decoder.up.2.block.2.conv2.bias": "blocks.9.conv2.bias",
|
||||
"decoder.up.2.block.2.conv2.weight": "blocks.9.conv2.weight",
|
||||
"decoder.up.2.block.2.norm1.bias": "blocks.9.norm1.bias",
|
||||
"decoder.up.2.block.2.norm1.weight": "blocks.9.norm1.weight",
|
||||
"decoder.up.2.block.2.norm2.bias": "blocks.9.norm2.bias",
|
||||
"decoder.up.2.block.2.norm2.weight": "blocks.9.norm2.weight",
|
||||
"decoder.up.2.upsample.conv.bias": "blocks.10.conv.bias",
|
||||
"decoder.up.2.upsample.conv.weight": "blocks.10.conv.weight",
|
||||
"decoder.up.3.block.0.conv1.bias": "blocks.3.conv1.bias",
|
||||
"decoder.up.3.block.0.conv1.weight": "blocks.3.conv1.weight",
|
||||
"decoder.up.3.block.0.conv2.bias": "blocks.3.conv2.bias",
|
||||
"decoder.up.3.block.0.conv2.weight": "blocks.3.conv2.weight",
|
||||
"decoder.up.3.block.0.norm1.bias": "blocks.3.norm1.bias",
|
||||
"decoder.up.3.block.0.norm1.weight": "blocks.3.norm1.weight",
|
||||
"decoder.up.3.block.0.norm2.bias": "blocks.3.norm2.bias",
|
||||
"decoder.up.3.block.0.norm2.weight": "blocks.3.norm2.weight",
|
||||
"decoder.up.3.block.1.conv1.bias": "blocks.4.conv1.bias",
|
||||
"decoder.up.3.block.1.conv1.weight": "blocks.4.conv1.weight",
|
||||
"decoder.up.3.block.1.conv2.bias": "blocks.4.conv2.bias",
|
||||
"decoder.up.3.block.1.conv2.weight": "blocks.4.conv2.weight",
|
||||
"decoder.up.3.block.1.norm1.bias": "blocks.4.norm1.bias",
|
||||
"decoder.up.3.block.1.norm1.weight": "blocks.4.norm1.weight",
|
||||
"decoder.up.3.block.1.norm2.bias": "blocks.4.norm2.bias",
|
||||
"decoder.up.3.block.1.norm2.weight": "blocks.4.norm2.weight",
|
||||
"decoder.up.3.block.2.conv1.bias": "blocks.5.conv1.bias",
|
||||
"decoder.up.3.block.2.conv1.weight": "blocks.5.conv1.weight",
|
||||
"decoder.up.3.block.2.conv2.bias": "blocks.5.conv2.bias",
|
||||
"decoder.up.3.block.2.conv2.weight": "blocks.5.conv2.weight",
|
||||
"decoder.up.3.block.2.norm1.bias": "blocks.5.norm1.bias",
|
||||
"decoder.up.3.block.2.norm1.weight": "blocks.5.norm1.weight",
|
||||
"decoder.up.3.block.2.norm2.bias": "blocks.5.norm2.bias",
|
||||
"decoder.up.3.block.2.norm2.weight": "blocks.5.norm2.weight",
|
||||
"decoder.up.3.upsample.conv.bias": "blocks.6.conv.bias",
|
||||
"decoder.up.3.upsample.conv.weight": "blocks.6.conv.weight",
|
||||
}
|
||||
state_dict_ = {}
|
||||
for name in state_dict:
|
||||
if name in rename_dict:
|
||||
param = state_dict[name]
|
||||
if "transformer_blocks" in rename_dict[name]:
|
||||
param = param.squeeze()
|
||||
state_dict_[rename_dict[name]] = param
|
||||
return state_dict_
|
||||
@@ -1,56 +0,0 @@
|
||||
import torch
|
||||
from .general_modules import TemporalTimesteps
|
||||
|
||||
|
||||
class MultiValueEncoder(torch.nn.Module):
|
||||
def __init__(self, encoders=()):
|
||||
super().__init__()
|
||||
if not isinstance(encoders, list):
|
||||
encoders = [encoders]
|
||||
self.encoders = torch.nn.ModuleList(encoders)
|
||||
|
||||
def __call__(self, values, dtype):
|
||||
emb = []
|
||||
for encoder, value in zip(self.encoders, values):
|
||||
if value is not None:
|
||||
value = value.unsqueeze(0)
|
||||
emb.append(encoder(value, dtype))
|
||||
emb = torch.concat(emb, dim=0)
|
||||
return emb
|
||||
|
||||
|
||||
class SingleValueEncoder(torch.nn.Module):
|
||||
def __init__(self, dim_in=256, dim_out=4096, prefer_len=32, computation_device=None):
|
||||
super().__init__()
|
||||
self.prefer_len = prefer_len
|
||||
self.prefer_proj = TemporalTimesteps(num_channels=dim_in, flip_sin_to_cos=True, downscale_freq_shift=0, computation_device=computation_device)
|
||||
self.prefer_value_embedder = torch.nn.Sequential(
|
||||
torch.nn.Linear(dim_in, dim_out), torch.nn.SiLU(), torch.nn.Linear(dim_out, dim_out)
|
||||
)
|
||||
self.positional_embedding = torch.nn.Parameter(
|
||||
torch.randn(self.prefer_len, dim_out)
|
||||
)
|
||||
|
||||
def forward(self, value, dtype):
|
||||
value = value * 1000
|
||||
emb = self.prefer_proj(value).to(dtype)
|
||||
emb = self.prefer_value_embedder(emb).squeeze(0)
|
||||
base_embeddings = emb.expand(self.prefer_len, -1)
|
||||
positional_embedding = self.positional_embedding.to(dtype=base_embeddings.dtype, device=base_embeddings.device)
|
||||
learned_embeddings = base_embeddings + positional_embedding
|
||||
return learned_embeddings
|
||||
|
||||
@staticmethod
|
||||
def state_dict_converter():
|
||||
return SingleValueEncoderStateDictConverter()
|
||||
|
||||
|
||||
class SingleValueEncoderStateDictConverter:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def from_diffusers(self, state_dict):
|
||||
return state_dict
|
||||
|
||||
def from_civitai(self, state_dict):
|
||||
return state_dict
|
||||
@@ -1,146 +0,0 @@
|
||||
import torch, math
|
||||
|
||||
|
||||
def get_timestep_embedding(
|
||||
timesteps: torch.Tensor,
|
||||
embedding_dim: int,
|
||||
flip_sin_to_cos: bool = False,
|
||||
downscale_freq_shift: float = 1,
|
||||
scale: float = 1,
|
||||
max_period: int = 10000,
|
||||
computation_device = None,
|
||||
align_dtype_to_timestep = False,
|
||||
):
|
||||
assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"
|
||||
|
||||
half_dim = embedding_dim // 2
|
||||
exponent = -math.log(max_period) * torch.arange(
|
||||
start=0, end=half_dim, dtype=torch.float32, device=timesteps.device if computation_device is None else computation_device
|
||||
)
|
||||
exponent = exponent / (half_dim - downscale_freq_shift)
|
||||
|
||||
emb = torch.exp(exponent)
|
||||
if align_dtype_to_timestep:
|
||||
emb = emb.to(timesteps.dtype)
|
||||
emb = timesteps[:, None].float() * emb[None, :]
|
||||
|
||||
# scale embeddings
|
||||
emb = scale * emb
|
||||
|
||||
# concat sine and cosine embeddings
|
||||
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)
|
||||
|
||||
# flip sine and cosine embeddings
|
||||
if flip_sin_to_cos:
|
||||
emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)
|
||||
|
||||
# zero pad
|
||||
if embedding_dim % 2 == 1:
|
||||
emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
|
||||
return emb
|
||||
|
||||
|
||||
class TemporalTimesteps(torch.nn.Module):
|
||||
def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float, computation_device = None, scale=1, align_dtype_to_timestep=False):
|
||||
super().__init__()
|
||||
self.num_channels = num_channels
|
||||
self.flip_sin_to_cos = flip_sin_to_cos
|
||||
self.downscale_freq_shift = downscale_freq_shift
|
||||
self.computation_device = computation_device
|
||||
self.scale = scale
|
||||
self.align_dtype_to_timestep = align_dtype_to_timestep
|
||||
|
||||
def forward(self, timesteps):
|
||||
t_emb = get_timestep_embedding(
|
||||
timesteps,
|
||||
self.num_channels,
|
||||
flip_sin_to_cos=self.flip_sin_to_cos,
|
||||
downscale_freq_shift=self.downscale_freq_shift,
|
||||
computation_device=self.computation_device,
|
||||
scale=self.scale,
|
||||
align_dtype_to_timestep=self.align_dtype_to_timestep,
|
||||
)
|
||||
return t_emb
|
||||
|
||||
|
||||
class DiffusersCompatibleTimestepProj(torch.nn.Module):
|
||||
def __init__(self, dim_in, dim_out):
|
||||
super().__init__()
|
||||
self.linear_1 = torch.nn.Linear(dim_in, dim_out)
|
||||
self.act = torch.nn.SiLU()
|
||||
self.linear_2 = torch.nn.Linear(dim_out, dim_out)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.linear_1(x)
|
||||
x = self.act(x)
|
||||
x = self.linear_2(x)
|
||||
return x
|
||||
|
||||
|
||||
class TimestepEmbeddings(torch.nn.Module):
|
||||
def __init__(self, dim_in, dim_out, computation_device=None, diffusers_compatible_format=False, scale=1, align_dtype_to_timestep=False, use_additional_t_cond=False):
|
||||
super().__init__()
|
||||
self.time_proj = TemporalTimesteps(num_channels=dim_in, flip_sin_to_cos=True, downscale_freq_shift=0, computation_device=computation_device, scale=scale, align_dtype_to_timestep=align_dtype_to_timestep)
|
||||
if diffusers_compatible_format:
|
||||
self.timestep_embedder = DiffusersCompatibleTimestepProj(dim_in, dim_out)
|
||||
else:
|
||||
self.timestep_embedder = torch.nn.Sequential(
|
||||
torch.nn.Linear(dim_in, dim_out), torch.nn.SiLU(), torch.nn.Linear(dim_out, dim_out)
|
||||
)
|
||||
self.use_additional_t_cond = use_additional_t_cond
|
||||
if use_additional_t_cond:
|
||||
self.addition_t_embedding = torch.nn.Embedding(2, dim_out)
|
||||
|
||||
def forward(self, timestep, dtype, addition_t_cond=None):
|
||||
time_emb = self.time_proj(timestep).to(dtype)
|
||||
time_emb = self.timestep_embedder(time_emb)
|
||||
if addition_t_cond is not None:
|
||||
addition_t_emb = self.addition_t_embedding(addition_t_cond)
|
||||
addition_t_emb = addition_t_emb.to(dtype=dtype)
|
||||
time_emb = time_emb + addition_t_emb
|
||||
return time_emb
|
||||
|
||||
|
||||
class RMSNorm(torch.nn.Module):
|
||||
def __init__(self, dim, eps, elementwise_affine=True):
|
||||
super().__init__()
|
||||
self.eps = eps
|
||||
if elementwise_affine:
|
||||
self.weight = torch.nn.Parameter(torch.ones((dim,)))
|
||||
else:
|
||||
self.weight = None
|
||||
|
||||
def forward(self, hidden_states):
|
||||
input_dtype = hidden_states.dtype
|
||||
variance = hidden_states.to(torch.float32).square().mean(-1, keepdim=True)
|
||||
hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
|
||||
hidden_states = hidden_states.to(input_dtype)
|
||||
if self.weight is not None:
|
||||
hidden_states = hidden_states * self.weight
|
||||
return hidden_states
|
||||
|
||||
|
||||
class AdaLayerNorm(torch.nn.Module):
|
||||
def __init__(self, dim, single=False, dual=False):
|
||||
super().__init__()
|
||||
self.single = single
|
||||
self.dual = dual
|
||||
self.linear = torch.nn.Linear(dim, dim * [[6, 2][single], 9][dual])
|
||||
self.norm = torch.nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
|
||||
|
||||
def forward(self, x, emb):
|
||||
emb = self.linear(torch.nn.functional.silu(emb))
|
||||
if self.single:
|
||||
scale, shift = emb.unsqueeze(1).chunk(2, dim=2)
|
||||
x = self.norm(x) * (1 + scale) + shift
|
||||
return x
|
||||
elif self.dual:
|
||||
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp, shift_msa2, scale_msa2, gate_msa2 = emb.unsqueeze(1).chunk(9, dim=2)
|
||||
norm_x = self.norm(x)
|
||||
x = norm_x * (1 + scale_msa) + shift_msa
|
||||
norm_x2 = norm_x * (1 + scale_msa2) + shift_msa2
|
||||
return x, gate_msa, shift_mlp, scale_mlp, gate_mlp, norm_x2, gate_msa2
|
||||
else:
|
||||
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.unsqueeze(1).chunk(6, dim=2)
|
||||
x = self.norm(x) * (1 + scale_msa) + shift_msa
|
||||
return x, gate_msa, shift_mlp, scale_mlp, gate_mlp
|
||||
451
diffsynth/models/hunyuan_dit.py
Normal file
451
diffsynth/models/hunyuan_dit.py
Normal file
@@ -0,0 +1,451 @@
|
||||
from .attention import Attention
|
||||
from einops import repeat, rearrange
|
||||
import math
|
||||
import torch
|
||||
|
||||
|
||||
class HunyuanDiTRotaryEmbedding(torch.nn.Module):
|
||||
|
||||
def __init__(self, q_norm_shape=88, k_norm_shape=88, rotary_emb_on_k=True):
|
||||
super().__init__()
|
||||
self.q_norm = torch.nn.LayerNorm((q_norm_shape,), elementwise_affine=True, eps=1e-06)
|
||||
self.k_norm = torch.nn.LayerNorm((k_norm_shape,), elementwise_affine=True, eps=1e-06)
|
||||
self.rotary_emb_on_k = rotary_emb_on_k
|
||||
self.k_cache, self.v_cache = [], []
|
||||
|
||||
def reshape_for_broadcast(self, freqs_cis, x):
|
||||
ndim = x.ndim
|
||||
shape = [d if i == ndim - 2 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
|
||||
return freqs_cis[0].view(*shape), freqs_cis[1].view(*shape)
|
||||
|
||||
def rotate_half(self, x):
|
||||
x_real, x_imag = x.float().reshape(*x.shape[:-1], -1, 2).unbind(-1)
|
||||
return torch.stack([-x_imag, x_real], dim=-1).flatten(3)
|
||||
|
||||
def apply_rotary_emb(self, xq, xk, freqs_cis):
|
||||
xk_out = None
|
||||
cos, sin = self.reshape_for_broadcast(freqs_cis, xq)
|
||||
cos, sin = cos.to(xq.device), sin.to(xq.device)
|
||||
xq_out = (xq.float() * cos + self.rotate_half(xq.float()) * sin).type_as(xq)
|
||||
if xk is not None:
|
||||
xk_out = (xk.float() * cos + self.rotate_half(xk.float()) * sin).type_as(xk)
|
||||
return xq_out, xk_out
|
||||
|
||||
def forward(self, q, k, v, freqs_cis_img, to_cache=False):
|
||||
# norm
|
||||
q = self.q_norm(q)
|
||||
k = self.k_norm(k)
|
||||
|
||||
# RoPE
|
||||
if self.rotary_emb_on_k:
|
||||
q, k = self.apply_rotary_emb(q, k, freqs_cis_img)
|
||||
else:
|
||||
q, _ = self.apply_rotary_emb(q, None, freqs_cis_img)
|
||||
|
||||
if to_cache:
|
||||
self.k_cache.append(k)
|
||||
self.v_cache.append(v)
|
||||
elif len(self.k_cache) > 0 and len(self.v_cache) > 0:
|
||||
k = torch.concat([k] + self.k_cache, dim=2)
|
||||
v = torch.concat([v] + self.v_cache, dim=2)
|
||||
self.k_cache, self.v_cache = [], []
|
||||
return q, k, v
|
||||
|
||||
|
||||
class FP32_Layernorm(torch.nn.LayerNorm):
|
||||
def forward(self, inputs):
|
||||
origin_dtype = inputs.dtype
|
||||
return torch.nn.functional.layer_norm(inputs.float(), self.normalized_shape, self.weight.float(), self.bias.float(), self.eps).to(origin_dtype)
|
||||
|
||||
|
||||
class FP32_SiLU(torch.nn.SiLU):
|
||||
def forward(self, inputs):
|
||||
origin_dtype = inputs.dtype
|
||||
return torch.nn.functional.silu(inputs.float(), inplace=False).to(origin_dtype)
|
||||
|
||||
|
||||
class HunyuanDiTFinalLayer(torch.nn.Module):
|
||||
def __init__(self, final_hidden_size=1408, condition_dim=1408, patch_size=2, out_channels=8):
|
||||
super().__init__()
|
||||
self.norm_final = torch.nn.LayerNorm(final_hidden_size, elementwise_affine=False, eps=1e-6)
|
||||
self.linear = torch.nn.Linear(final_hidden_size, patch_size * patch_size * out_channels, bias=True)
|
||||
self.adaLN_modulation = torch.nn.Sequential(
|
||||
FP32_SiLU(),
|
||||
torch.nn.Linear(condition_dim, 2 * final_hidden_size, bias=True)
|
||||
)
|
||||
|
||||
def modulate(self, x, shift, scale):
|
||||
return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
|
||||
|
||||
def forward(self, hidden_states, condition_emb):
|
||||
shift, scale = self.adaLN_modulation(condition_emb).chunk(2, dim=1)
|
||||
hidden_states = self.modulate(self.norm_final(hidden_states), shift, scale)
|
||||
hidden_states = self.linear(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class HunyuanDiTBlock(torch.nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_dim=1408,
|
||||
condition_dim=1408,
|
||||
num_heads=16,
|
||||
mlp_ratio=4.3637,
|
||||
text_dim=1024,
|
||||
skip_connection=False
|
||||
):
|
||||
super().__init__()
|
||||
self.norm1 = FP32_Layernorm((hidden_dim,), eps=1e-6, elementwise_affine=True)
|
||||
self.rota1 = HunyuanDiTRotaryEmbedding(hidden_dim//num_heads, hidden_dim//num_heads)
|
||||
self.attn1 = Attention(hidden_dim, num_heads, hidden_dim//num_heads, bias_q=True, bias_kv=True, bias_out=True)
|
||||
self.norm2 = FP32_Layernorm((hidden_dim,), eps=1e-6, elementwise_affine=True)
|
||||
self.rota2 = HunyuanDiTRotaryEmbedding(hidden_dim//num_heads, hidden_dim//num_heads, rotary_emb_on_k=False)
|
||||
self.attn2 = Attention(hidden_dim, num_heads, hidden_dim//num_heads, kv_dim=text_dim, bias_q=True, bias_kv=True, bias_out=True)
|
||||
self.norm3 = FP32_Layernorm((hidden_dim,), eps=1e-6, elementwise_affine=True)
|
||||
self.modulation = torch.nn.Sequential(FP32_SiLU(), torch.nn.Linear(condition_dim, hidden_dim, bias=True))
|
||||
self.mlp = torch.nn.Sequential(
|
||||
torch.nn.Linear(hidden_dim, int(hidden_dim*mlp_ratio), bias=True),
|
||||
torch.nn.GELU(approximate="tanh"),
|
||||
torch.nn.Linear(int(hidden_dim*mlp_ratio), hidden_dim, bias=True)
|
||||
)
|
||||
if skip_connection:
|
||||
self.skip_norm = FP32_Layernorm((hidden_dim * 2,), eps=1e-6, elementwise_affine=True)
|
||||
self.skip_linear = torch.nn.Linear(hidden_dim * 2, hidden_dim, bias=True)
|
||||
else:
|
||||
self.skip_norm, self.skip_linear = None, None
|
||||
|
||||
def forward(self, hidden_states, condition_emb, text_emb, freq_cis_img, residual=None, to_cache=False):
|
||||
# Long Skip Connection
|
||||
if self.skip_norm is not None and self.skip_linear is not None:
|
||||
hidden_states = torch.cat([hidden_states, residual], dim=-1)
|
||||
hidden_states = self.skip_norm(hidden_states)
|
||||
hidden_states = self.skip_linear(hidden_states)
|
||||
|
||||
# Self-Attention
|
||||
shift_msa = self.modulation(condition_emb).unsqueeze(dim=1)
|
||||
attn_input = self.norm1(hidden_states) + shift_msa
|
||||
hidden_states = hidden_states + self.attn1(attn_input, qkv_preprocessor=lambda q, k, v: self.rota1(q, k, v, freq_cis_img, to_cache=to_cache))
|
||||
|
||||
# Cross-Attention
|
||||
attn_input = self.norm3(hidden_states)
|
||||
hidden_states = hidden_states + self.attn2(attn_input, text_emb, qkv_preprocessor=lambda q, k, v: self.rota2(q, k, v, freq_cis_img))
|
||||
|
||||
# FFN Layer
|
||||
mlp_input = self.norm2(hidden_states)
|
||||
hidden_states = hidden_states + self.mlp(mlp_input)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class AttentionPool(torch.nn.Module):
|
||||
def __init__(self, spacial_dim, embed_dim, num_heads, output_dim = None):
|
||||
super().__init__()
|
||||
self.positional_embedding = torch.nn.Parameter(torch.randn(spacial_dim + 1, embed_dim) / embed_dim ** 0.5)
|
||||
self.k_proj = torch.nn.Linear(embed_dim, embed_dim)
|
||||
self.q_proj = torch.nn.Linear(embed_dim, embed_dim)
|
||||
self.v_proj = torch.nn.Linear(embed_dim, embed_dim)
|
||||
self.c_proj = torch.nn.Linear(embed_dim, output_dim or embed_dim)
|
||||
self.num_heads = num_heads
|
||||
|
||||
def forward(self, x):
|
||||
x = x.permute(1, 0, 2) # NLC -> LNC
|
||||
x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (L+1)NC
|
||||
x = x + self.positional_embedding[:, None, :].to(x.dtype) # (L+1)NC
|
||||
x, _ = torch.nn.functional.multi_head_attention_forward(
|
||||
query=x[:1], key=x, value=x,
|
||||
embed_dim_to_check=x.shape[-1],
|
||||
num_heads=self.num_heads,
|
||||
q_proj_weight=self.q_proj.weight,
|
||||
k_proj_weight=self.k_proj.weight,
|
||||
v_proj_weight=self.v_proj.weight,
|
||||
in_proj_weight=None,
|
||||
in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
|
||||
bias_k=None,
|
||||
bias_v=None,
|
||||
add_zero_attn=False,
|
||||
dropout_p=0,
|
||||
out_proj_weight=self.c_proj.weight,
|
||||
out_proj_bias=self.c_proj.bias,
|
||||
use_separate_proj_weight=True,
|
||||
training=self.training,
|
||||
need_weights=False
|
||||
)
|
||||
return x.squeeze(0)
|
||||
|
||||
|
||||
class PatchEmbed(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
patch_size=(2, 2),
|
||||
in_chans=4,
|
||||
embed_dim=1408,
|
||||
bias=True,
|
||||
):
|
||||
super().__init__()
|
||||
self.proj = torch.nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.proj(x)
|
||||
x = x.flatten(2).transpose(1, 2) # BCHW -> BNC
|
||||
return x
|
||||
|
||||
|
||||
def timestep_embedding(t, dim, max_period=10000, repeat_only=False):
|
||||
# https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
|
||||
if not repeat_only:
|
||||
half = dim // 2
|
||||
freqs = torch.exp(
|
||||
-math.log(max_period)
|
||||
* torch.arange(start=0, end=half, dtype=torch.float32)
|
||||
/ half
|
||||
).to(device=t.device) # size: [dim/2], 一个指数衰减的曲线
|
||||
args = t[:, None].float() * freqs[None]
|
||||
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
|
||||
if dim % 2:
|
||||
embedding = torch.cat(
|
||||
[embedding, torch.zeros_like(embedding[:, :1])], dim=-1
|
||||
)
|
||||
else:
|
||||
embedding = repeat(t, "b -> b d", d=dim)
|
||||
return embedding
|
||||
|
||||
|
||||
class TimestepEmbedder(torch.nn.Module):
|
||||
def __init__(self, hidden_size=1408, frequency_embedding_size=256):
|
||||
super().__init__()
|
||||
self.mlp = torch.nn.Sequential(
|
||||
torch.nn.Linear(frequency_embedding_size, hidden_size, bias=True),
|
||||
torch.nn.SiLU(),
|
||||
torch.nn.Linear(hidden_size, hidden_size, bias=True),
|
||||
)
|
||||
self.frequency_embedding_size = frequency_embedding_size
|
||||
|
||||
def forward(self, t):
|
||||
t_freq = timestep_embedding(t, self.frequency_embedding_size).type(self.mlp[0].weight.dtype)
|
||||
t_emb = self.mlp(t_freq)
|
||||
return t_emb
|
||||
|
||||
|
||||
class HunyuanDiT(torch.nn.Module):
|
||||
def __init__(self, num_layers_down=21, num_layers_up=19, in_channels=4, out_channels=8, hidden_dim=1408, text_dim=1024, t5_dim=2048, text_length=77, t5_length=256):
|
||||
super().__init__()
|
||||
|
||||
# Embedders
|
||||
self.text_emb_padding = torch.nn.Parameter(torch.randn(text_length + t5_length, text_dim, dtype=torch.float32))
|
||||
self.t5_embedder = torch.nn.Sequential(
|
||||
torch.nn.Linear(t5_dim, t5_dim * 4, bias=True),
|
||||
FP32_SiLU(),
|
||||
torch.nn.Linear(t5_dim * 4, text_dim, bias=True),
|
||||
)
|
||||
self.t5_pooler = AttentionPool(t5_length, t5_dim, num_heads=8, output_dim=1024)
|
||||
self.style_embedder = torch.nn.Parameter(torch.randn(hidden_dim))
|
||||
self.patch_embedder = PatchEmbed(in_chans=in_channels)
|
||||
self.timestep_embedder = TimestepEmbedder()
|
||||
self.extra_embedder = torch.nn.Sequential(
|
||||
torch.nn.Linear(256 * 6 + 1024 + hidden_dim, hidden_dim * 4),
|
||||
FP32_SiLU(),
|
||||
torch.nn.Linear(hidden_dim * 4, hidden_dim),
|
||||
)
|
||||
|
||||
# Transformer blocks
|
||||
self.num_layers_down = num_layers_down
|
||||
self.num_layers_up = num_layers_up
|
||||
self.blocks = torch.nn.ModuleList(
|
||||
[HunyuanDiTBlock(skip_connection=False) for _ in range(num_layers_down)] + \
|
||||
[HunyuanDiTBlock(skip_connection=True) for _ in range(num_layers_up)]
|
||||
)
|
||||
|
||||
# Output layers
|
||||
self.final_layer = HunyuanDiTFinalLayer()
|
||||
self.out_channels = out_channels
|
||||
|
||||
def prepare_text_emb(self, text_emb, text_emb_t5, text_emb_mask, text_emb_mask_t5):
|
||||
text_emb_mask = text_emb_mask.bool()
|
||||
text_emb_mask_t5 = text_emb_mask_t5.bool()
|
||||
text_emb_t5 = self.t5_embedder(text_emb_t5)
|
||||
text_emb = torch.cat([text_emb, text_emb_t5], dim=1)
|
||||
text_emb_mask = torch.cat([text_emb_mask, text_emb_mask_t5], dim=-1)
|
||||
text_emb = torch.where(text_emb_mask.unsqueeze(2), text_emb, self.text_emb_padding.to(text_emb))
|
||||
return text_emb
|
||||
|
||||
def prepare_extra_emb(self, text_emb_t5, timestep, size_emb, dtype, batch_size):
|
||||
# Text embedding
|
||||
pooled_text_emb_t5 = self.t5_pooler(text_emb_t5)
|
||||
|
||||
# Timestep embedding
|
||||
timestep_emb = self.timestep_embedder(timestep)
|
||||
|
||||
# Size embedding
|
||||
size_emb = timestep_embedding(size_emb.view(-1), 256).to(dtype)
|
||||
size_emb = size_emb.view(-1, 6 * 256)
|
||||
|
||||
# Style embedding
|
||||
style_emb = repeat(self.style_embedder, "D -> B D", B=batch_size)
|
||||
|
||||
# Concatenate all extra vectors
|
||||
extra_emb = torch.cat([pooled_text_emb_t5, size_emb, style_emb], dim=1)
|
||||
condition_emb = timestep_emb + self.extra_embedder(extra_emb)
|
||||
|
||||
return condition_emb
|
||||
|
||||
def unpatchify(self, x, h, w):
|
||||
return rearrange(x, "B (H W) (P Q C) -> B C (H P) (W Q)", H=h, W=w, P=2, Q=2)
|
||||
|
||||
def build_mask(self, data, is_bound):
|
||||
_, _, H, W = data.shape
|
||||
h = repeat(torch.arange(H), "H -> H W", H=H, W=W)
|
||||
w = repeat(torch.arange(W), "W -> H W", H=H, W=W)
|
||||
border_width = (H + W) // 4
|
||||
pad = torch.ones_like(h) * border_width
|
||||
mask = torch.stack([
|
||||
pad if is_bound[0] else h + 1,
|
||||
pad if is_bound[1] else H - h,
|
||||
pad if is_bound[2] else w + 1,
|
||||
pad if is_bound[3] else W - w
|
||||
]).min(dim=0).values
|
||||
mask = mask.clip(1, border_width)
|
||||
mask = (mask / border_width).to(dtype=data.dtype, device=data.device)
|
||||
mask = rearrange(mask, "H W -> 1 H W")
|
||||
return mask
|
||||
|
||||
def tiled_block_forward(self, block, hidden_states, condition_emb, text_emb, freq_cis_img, residual, torch_dtype, data_device, computation_device, tile_size, tile_stride):
|
||||
B, C, H, W = hidden_states.shape
|
||||
|
||||
weight = torch.zeros((1, 1, H, W), dtype=torch_dtype, device=data_device)
|
||||
values = torch.zeros((B, C, H, W), dtype=torch_dtype, device=data_device)
|
||||
|
||||
# Split tasks
|
||||
tasks = []
|
||||
for h in range(0, H, tile_stride):
|
||||
for w in range(0, W, tile_stride):
|
||||
if (h-tile_stride >= 0 and h-tile_stride+tile_size >= H) or (w-tile_stride >= 0 and w-tile_stride+tile_size >= W):
|
||||
continue
|
||||
h_, w_ = h + tile_size, w + tile_size
|
||||
if h_ > H: h, h_ = H - tile_size, H
|
||||
if w_ > W: w, w_ = W - tile_size, W
|
||||
tasks.append((h, h_, w, w_))
|
||||
|
||||
# Run
|
||||
for hl, hr, wl, wr in tasks:
|
||||
hidden_states_batch = hidden_states[:, :, hl:hr, wl:wr].to(computation_device)
|
||||
hidden_states_batch = rearrange(hidden_states_batch, "B C H W -> B (H W) C")
|
||||
if residual is not None:
|
||||
residual_batch = residual[:, :, hl:hr, wl:wr].to(computation_device)
|
||||
residual_batch = rearrange(residual_batch, "B C H W -> B (H W) C")
|
||||
else:
|
||||
residual_batch = None
|
||||
|
||||
# Forward
|
||||
hidden_states_batch = block(hidden_states_batch, condition_emb, text_emb, freq_cis_img, residual_batch).to(data_device)
|
||||
hidden_states_batch = rearrange(hidden_states_batch, "B (H W) C -> B C H W", H=hr-hl)
|
||||
|
||||
mask = self.build_mask(hidden_states_batch, is_bound=(hl==0, hr>=H, wl==0, wr>=W))
|
||||
values[:, :, hl:hr, wl:wr] += hidden_states_batch * mask
|
||||
weight[:, :, hl:hr, wl:wr] += mask
|
||||
values /= weight
|
||||
return values
|
||||
|
||||
def forward(
|
||||
self, hidden_states, text_emb, text_emb_t5, text_emb_mask, text_emb_mask_t5, timestep, size_emb, freq_cis_img,
|
||||
tiled=False, tile_size=64, tile_stride=32,
|
||||
to_cache=False,
|
||||
use_gradient_checkpointing=False,
|
||||
):
|
||||
# Embeddings
|
||||
text_emb = self.prepare_text_emb(text_emb, text_emb_t5, text_emb_mask, text_emb_mask_t5)
|
||||
condition_emb = self.prepare_extra_emb(text_emb_t5, timestep, size_emb, hidden_states.dtype, hidden_states.shape[0])
|
||||
|
||||
# Input
|
||||
height, width = hidden_states.shape[-2], hidden_states.shape[-1]
|
||||
hidden_states = self.patch_embedder(hidden_states)
|
||||
|
||||
# Blocks
|
||||
def create_custom_forward(module):
|
||||
def custom_forward(*inputs):
|
||||
return module(*inputs)
|
||||
return custom_forward
|
||||
if tiled:
|
||||
hidden_states = rearrange(hidden_states, "B (H W) C -> B C H W", H=height//2)
|
||||
residuals = []
|
||||
for block_id, block in enumerate(self.blocks):
|
||||
residual = residuals.pop() if block_id >= self.num_layers_down else None
|
||||
hidden_states = self.tiled_block_forward(
|
||||
block, hidden_states, condition_emb, text_emb, freq_cis_img, residual,
|
||||
torch_dtype=hidden_states.dtype, data_device=hidden_states.device, computation_device=hidden_states.device,
|
||||
tile_size=tile_size, tile_stride=tile_stride
|
||||
)
|
||||
if block_id < self.num_layers_down - 2:
|
||||
residuals.append(hidden_states)
|
||||
hidden_states = rearrange(hidden_states, "B C H W -> B (H W) C")
|
||||
else:
|
||||
residuals = []
|
||||
for block_id, block in enumerate(self.blocks):
|
||||
residual = residuals.pop() if block_id >= self.num_layers_down else None
|
||||
if self.training and use_gradient_checkpointing:
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(block),
|
||||
hidden_states, condition_emb, text_emb, freq_cis_img, residual,
|
||||
use_reentrant=False,
|
||||
)
|
||||
else:
|
||||
hidden_states = block(hidden_states, condition_emb, text_emb, freq_cis_img, residual, to_cache=to_cache)
|
||||
if block_id < self.num_layers_down - 2:
|
||||
residuals.append(hidden_states)
|
||||
|
||||
# Output
|
||||
hidden_states = self.final_layer(hidden_states, condition_emb)
|
||||
hidden_states = self.unpatchify(hidden_states, height//2, width//2)
|
||||
hidden_states, _ = hidden_states.chunk(2, dim=1)
|
||||
return hidden_states
|
||||
|
||||
@staticmethod
|
||||
def state_dict_converter():
|
||||
return HunyuanDiTStateDictConverter()
|
||||
|
||||
|
||||
|
||||
class HunyuanDiTStateDictConverter():
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def from_diffusers(self, state_dict):
|
||||
state_dict_ = {}
|
||||
for name, param in state_dict.items():
|
||||
name_ = name
|
||||
name_ = name_.replace(".default_modulation.", ".modulation.")
|
||||
name_ = name_.replace(".mlp.fc1.", ".mlp.0.")
|
||||
name_ = name_.replace(".mlp.fc2.", ".mlp.2.")
|
||||
name_ = name_.replace(".attn1.q_norm.", ".rota1.q_norm.")
|
||||
name_ = name_.replace(".attn2.q_norm.", ".rota2.q_norm.")
|
||||
name_ = name_.replace(".attn1.k_norm.", ".rota1.k_norm.")
|
||||
name_ = name_.replace(".attn2.k_norm.", ".rota2.k_norm.")
|
||||
name_ = name_.replace(".q_proj.", ".to_q.")
|
||||
name_ = name_.replace(".out_proj.", ".to_out.")
|
||||
name_ = name_.replace("text_embedding_padding", "text_emb_padding")
|
||||
name_ = name_.replace("mlp_t5.0.", "t5_embedder.0.")
|
||||
name_ = name_.replace("mlp_t5.2.", "t5_embedder.2.")
|
||||
name_ = name_.replace("pooler.", "t5_pooler.")
|
||||
name_ = name_.replace("x_embedder.", "patch_embedder.")
|
||||
name_ = name_.replace("t_embedder.", "timestep_embedder.")
|
||||
name_ = name_.replace("t5_pooler.to_q.", "t5_pooler.q_proj.")
|
||||
name_ = name_.replace("style_embedder.weight", "style_embedder")
|
||||
if ".kv_proj." in name_:
|
||||
param_k = param[:param.shape[0]//2]
|
||||
param_v = param[param.shape[0]//2:]
|
||||
state_dict_[name_.replace(".kv_proj.", ".to_k.")] = param_k
|
||||
state_dict_[name_.replace(".kv_proj.", ".to_v.")] = param_v
|
||||
elif ".Wqkv." in name_:
|
||||
param_q = param[:param.shape[0]//3]
|
||||
param_k = param[param.shape[0]//3:param.shape[0]//3*2]
|
||||
param_v = param[param.shape[0]//3*2:]
|
||||
state_dict_[name_.replace(".Wqkv.", ".to_q.")] = param_q
|
||||
state_dict_[name_.replace(".Wqkv.", ".to_k.")] = param_k
|
||||
state_dict_[name_.replace(".Wqkv.", ".to_v.")] = param_v
|
||||
elif "style_embedder" in name_:
|
||||
state_dict_[name_] = param.squeeze()
|
||||
else:
|
||||
state_dict_[name_] = param
|
||||
return state_dict_
|
||||
|
||||
def from_civitai(self, state_dict):
|
||||
return self.from_diffusers(state_dict)
|
||||
163
diffsynth/models/hunyuan_dit_text_encoder.py
Normal file
163
diffsynth/models/hunyuan_dit_text_encoder.py
Normal file
@@ -0,0 +1,163 @@
|
||||
from transformers import BertModel, BertConfig, T5EncoderModel, T5Config
|
||||
import torch
|
||||
|
||||
|
||||
|
||||
class HunyuanDiTCLIPTextEncoder(BertModel):
|
||||
def __init__(self):
|
||||
config = BertConfig(
|
||||
_name_or_path = "",
|
||||
architectures = ["BertModel"],
|
||||
attention_probs_dropout_prob = 0.1,
|
||||
bos_token_id = 0,
|
||||
classifier_dropout = None,
|
||||
directionality = "bidi",
|
||||
eos_token_id = 2,
|
||||
hidden_act = "gelu",
|
||||
hidden_dropout_prob = 0.1,
|
||||
hidden_size = 1024,
|
||||
initializer_range = 0.02,
|
||||
intermediate_size = 4096,
|
||||
layer_norm_eps = 1e-12,
|
||||
max_position_embeddings = 512,
|
||||
model_type = "bert",
|
||||
num_attention_heads = 16,
|
||||
num_hidden_layers = 24,
|
||||
output_past = True,
|
||||
pad_token_id = 0,
|
||||
pooler_fc_size = 768,
|
||||
pooler_num_attention_heads = 12,
|
||||
pooler_num_fc_layers = 3,
|
||||
pooler_size_per_head = 128,
|
||||
pooler_type = "first_token_transform",
|
||||
position_embedding_type = "absolute",
|
||||
torch_dtype = "float32",
|
||||
transformers_version = "4.37.2",
|
||||
type_vocab_size = 2,
|
||||
use_cache = True,
|
||||
vocab_size = 47020
|
||||
)
|
||||
super().__init__(config, add_pooling_layer=False)
|
||||
self.eval()
|
||||
|
||||
def forward(self, input_ids, attention_mask, clip_skip=1):
|
||||
input_shape = input_ids.size()
|
||||
|
||||
batch_size, seq_length = input_shape
|
||||
device = input_ids.device
|
||||
|
||||
past_key_values_length = 0
|
||||
|
||||
if attention_mask is None:
|
||||
attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)
|
||||
|
||||
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)
|
||||
|
||||
embedding_output = self.embeddings(
|
||||
input_ids=input_ids,
|
||||
position_ids=None,
|
||||
token_type_ids=None,
|
||||
inputs_embeds=None,
|
||||
past_key_values_length=0,
|
||||
)
|
||||
encoder_outputs = self.encoder(
|
||||
embedding_output,
|
||||
attention_mask=extended_attention_mask,
|
||||
head_mask=None,
|
||||
encoder_hidden_states=None,
|
||||
encoder_attention_mask=None,
|
||||
past_key_values=None,
|
||||
use_cache=False,
|
||||
output_attentions=False,
|
||||
output_hidden_states=True,
|
||||
return_dict=True,
|
||||
)
|
||||
all_hidden_states = encoder_outputs.hidden_states
|
||||
prompt_emb = all_hidden_states[-clip_skip]
|
||||
if clip_skip > 1:
|
||||
mean, std = all_hidden_states[-1].mean(), all_hidden_states[-1].std()
|
||||
prompt_emb = (prompt_emb - prompt_emb.mean()) / prompt_emb.std() * std + mean
|
||||
return prompt_emb
|
||||
|
||||
@staticmethod
|
||||
def state_dict_converter():
|
||||
return HunyuanDiTCLIPTextEncoderStateDictConverter()
|
||||
|
||||
|
||||
|
||||
class HunyuanDiTT5TextEncoder(T5EncoderModel):
|
||||
def __init__(self):
|
||||
config = T5Config(
|
||||
_name_or_path = "../HunyuanDiT/t2i/mt5",
|
||||
architectures = ["MT5ForConditionalGeneration"],
|
||||
classifier_dropout = 0.0,
|
||||
d_ff = 5120,
|
||||
d_kv = 64,
|
||||
d_model = 2048,
|
||||
decoder_start_token_id = 0,
|
||||
dense_act_fn = "gelu_new",
|
||||
dropout_rate = 0.1,
|
||||
eos_token_id = 1,
|
||||
feed_forward_proj = "gated-gelu",
|
||||
initializer_factor = 1.0,
|
||||
is_encoder_decoder = True,
|
||||
is_gated_act = True,
|
||||
layer_norm_epsilon = 1e-06,
|
||||
model_type = "t5",
|
||||
num_decoder_layers = 24,
|
||||
num_heads = 32,
|
||||
num_layers = 24,
|
||||
output_past = True,
|
||||
pad_token_id = 0,
|
||||
relative_attention_max_distance = 128,
|
||||
relative_attention_num_buckets = 32,
|
||||
tie_word_embeddings = False,
|
||||
tokenizer_class = "T5Tokenizer",
|
||||
transformers_version = "4.37.2",
|
||||
use_cache = True,
|
||||
vocab_size = 250112
|
||||
)
|
||||
super().__init__(config)
|
||||
self.eval()
|
||||
|
||||
def forward(self, input_ids, attention_mask, clip_skip=1):
|
||||
outputs = super().forward(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
output_hidden_states=True,
|
||||
)
|
||||
prompt_emb = outputs.hidden_states[-clip_skip]
|
||||
if clip_skip > 1:
|
||||
mean, std = outputs.hidden_states[-1].mean(), outputs.hidden_states[-1].std()
|
||||
prompt_emb = (prompt_emb - prompt_emb.mean()) / prompt_emb.std() * std + mean
|
||||
return prompt_emb
|
||||
|
||||
@staticmethod
|
||||
def state_dict_converter():
|
||||
return HunyuanDiTT5TextEncoderStateDictConverter()
|
||||
|
||||
|
||||
|
||||
class HunyuanDiTCLIPTextEncoderStateDictConverter():
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def from_diffusers(self, state_dict):
|
||||
state_dict_ = {name[5:]: param for name, param in state_dict.items() if name.startswith("bert.")}
|
||||
return state_dict_
|
||||
|
||||
def from_civitai(self, state_dict):
|
||||
return self.from_diffusers(state_dict)
|
||||
|
||||
|
||||
class HunyuanDiTT5TextEncoderStateDictConverter():
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def from_diffusers(self, state_dict):
|
||||
state_dict_ = {name: param for name, param in state_dict.items() if name.startswith("encoder.")}
|
||||
state_dict_["shared.weight"] = state_dict["shared.weight"]
|
||||
return state_dict_
|
||||
|
||||
def from_civitai(self, state_dict):
|
||||
return self.from_diffusers(state_dict)
|
||||
1552
diffsynth/models/kolors_text_encoder.py
Normal file
1552
diffsynth/models/kolors_text_encoder.py
Normal file
File diff suppressed because one or more lines are too long
@@ -1,902 +0,0 @@
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import math
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.amp as amp
|
||||
|
||||
import numpy as np
|
||||
import torch.nn.functional as F
|
||||
from einops import rearrange, repeat
|
||||
from .wan_video_dit import flash_attention
|
||||
from ..core.device.npu_compatible_device import get_device_type
|
||||
from ..core.gradient import gradient_checkpoint_forward
|
||||
|
||||
|
||||
class RMSNorm_FP32(torch.nn.Module):
|
||||
def __init__(self, dim: int, eps: float):
|
||||
super().__init__()
|
||||
self.eps = eps
|
||||
self.weight = nn.Parameter(torch.ones(dim))
|
||||
|
||||
def _norm(self, x):
|
||||
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
|
||||
|
||||
def forward(self, x):
|
||||
output = self._norm(x.float()).type_as(x)
|
||||
return output * self.weight
|
||||
|
||||
|
||||
def broadcat(tensors, dim=-1):
|
||||
num_tensors = len(tensors)
|
||||
shape_lens = set(list(map(lambda t: len(t.shape), tensors)))
|
||||
assert len(shape_lens) == 1, "tensors must all have the same number of dimensions"
|
||||
shape_len = list(shape_lens)[0]
|
||||
dim = (dim + shape_len) if dim < 0 else dim
|
||||
dims = list(zip(*map(lambda t: list(t.shape), tensors)))
|
||||
expandable_dims = [(i, val) for i, val in enumerate(dims) if i != dim]
|
||||
assert all(
|
||||
[*map(lambda t: len(set(t[1])) <= 2, expandable_dims)]
|
||||
), "invalid dimensions for broadcastable concatentation"
|
||||
max_dims = list(map(lambda t: (t[0], max(t[1])), expandable_dims))
|
||||
expanded_dims = list(map(lambda t: (t[0], (t[1],) * num_tensors), max_dims))
|
||||
expanded_dims.insert(dim, (dim, dims[dim]))
|
||||
expandable_shapes = list(zip(*map(lambda t: t[1], expanded_dims)))
|
||||
tensors = list(map(lambda t: t[0].expand(*t[1]), zip(tensors, expandable_shapes)))
|
||||
return torch.cat(tensors, dim=dim)
|
||||
|
||||
|
||||
def rotate_half(x):
|
||||
x = rearrange(x, "... (d r) -> ... d r", r=2)
|
||||
x1, x2 = x.unbind(dim=-1)
|
||||
x = torch.stack((-x2, x1), dim=-1)
|
||||
return rearrange(x, "... d r -> ... (d r)")
|
||||
|
||||
|
||||
class RotaryPositionalEmbedding(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
head_dim,
|
||||
cp_split_hw=None
|
||||
):
|
||||
"""Rotary positional embedding for 3D
|
||||
Reference : https://blog.eleuther.ai/rotary-embeddings/
|
||||
Paper: https://arxiv.org/pdf/2104.09864.pdf
|
||||
Args:
|
||||
dim: Dimension of embedding
|
||||
base: Base value for exponential
|
||||
"""
|
||||
super().__init__()
|
||||
self.head_dim = head_dim
|
||||
assert self.head_dim % 8 == 0, 'Dim must be a multiply of 8 for 3D RoPE.'
|
||||
self.cp_split_hw = cp_split_hw
|
||||
# We take the assumption that the longest side of grid will not larger than 512, i.e, 512 * 8 = 4098 input pixels
|
||||
self.base = 10000
|
||||
self.freqs_dict = {}
|
||||
|
||||
def register_grid_size(self, grid_size):
|
||||
if grid_size not in self.freqs_dict:
|
||||
self.freqs_dict.update({
|
||||
grid_size: self.precompute_freqs_cis_3d(grid_size)
|
||||
})
|
||||
|
||||
def precompute_freqs_cis_3d(self, grid_size):
|
||||
num_frames, height, width = grid_size
|
||||
dim_t = self.head_dim - 4 * (self.head_dim // 6)
|
||||
dim_h = 2 * (self.head_dim // 6)
|
||||
dim_w = 2 * (self.head_dim // 6)
|
||||
freqs_t = 1.0 / (self.base ** (torch.arange(0, dim_t, 2)[: (dim_t // 2)].float() / dim_t))
|
||||
freqs_h = 1.0 / (self.base ** (torch.arange(0, dim_h, 2)[: (dim_h // 2)].float() / dim_h))
|
||||
freqs_w = 1.0 / (self.base ** (torch.arange(0, dim_w, 2)[: (dim_w // 2)].float() / dim_w))
|
||||
grid_t = np.linspace(0, num_frames, num_frames, endpoint=False, dtype=np.float32)
|
||||
grid_h = np.linspace(0, height, height, endpoint=False, dtype=np.float32)
|
||||
grid_w = np.linspace(0, width, width, endpoint=False, dtype=np.float32)
|
||||
grid_t = torch.from_numpy(grid_t).float()
|
||||
grid_h = torch.from_numpy(grid_h).float()
|
||||
grid_w = torch.from_numpy(grid_w).float()
|
||||
freqs_t = torch.einsum("..., f -> ... f", grid_t, freqs_t)
|
||||
freqs_h = torch.einsum("..., f -> ... f", grid_h, freqs_h)
|
||||
freqs_w = torch.einsum("..., f -> ... f", grid_w, freqs_w)
|
||||
freqs_t = repeat(freqs_t, "... n -> ... (n r)", r=2)
|
||||
freqs_h = repeat(freqs_h, "... n -> ... (n r)", r=2)
|
||||
freqs_w = repeat(freqs_w, "... n -> ... (n r)", r=2)
|
||||
freqs = broadcat((freqs_t[:, None, None, :], freqs_h[None, :, None, :], freqs_w[None, None, :, :]), dim=-1)
|
||||
# (T H W D)
|
||||
freqs = rearrange(freqs, "T H W D -> (T H W) D")
|
||||
# if self.cp_split_hw[0] * self.cp_split_hw[1] > 1:
|
||||
# with torch.no_grad():
|
||||
# freqs = rearrange(freqs, "(T H W) D -> T H W D", T=num_frames, H=height, W=width)
|
||||
# freqs = context_parallel_util.split_cp_2d(freqs, seq_dim_hw=(1, 2), split_hw=self.cp_split_hw)
|
||||
# freqs = rearrange(freqs, "T H W D -> (T H W) D")
|
||||
|
||||
return freqs
|
||||
|
||||
def forward(self, q, k, grid_size):
|
||||
"""3D RoPE.
|
||||
|
||||
Args:
|
||||
query: [B, head, seq, head_dim]
|
||||
key: [B, head, seq, head_dim]
|
||||
Returns:
|
||||
query and key with the same shape as input.
|
||||
"""
|
||||
|
||||
if grid_size not in self.freqs_dict:
|
||||
self.register_grid_size(grid_size)
|
||||
|
||||
freqs_cis = self.freqs_dict[grid_size].to(q.device)
|
||||
q_, k_ = q.float(), k.float()
|
||||
freqs_cis = freqs_cis.float().to(q.device)
|
||||
cos, sin = freqs_cis.cos(), freqs_cis.sin()
|
||||
cos, sin = rearrange(cos, 'n d -> 1 1 n d'), rearrange(sin, 'n d -> 1 1 n d')
|
||||
q_ = (q_ * cos) + (rotate_half(q_) * sin)
|
||||
k_ = (k_ * cos) + (rotate_half(k_) * sin)
|
||||
|
||||
return q_.type_as(q), k_.type_as(k)
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
num_heads: int,
|
||||
enable_flashattn3: bool = False,
|
||||
enable_flashattn2: bool = False,
|
||||
enable_xformers: bool = False,
|
||||
enable_bsa: bool = False,
|
||||
bsa_params: dict = None,
|
||||
cp_split_hw: Optional[List[int]] = None
|
||||
) -> None:
|
||||
super().__init__()
|
||||
assert dim % num_heads == 0, "dim should be divisible by num_heads"
|
||||
self.dim = dim
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = dim // num_heads
|
||||
self.scale = self.head_dim**-0.5
|
||||
self.enable_flashattn3 = enable_flashattn3
|
||||
self.enable_flashattn2 = enable_flashattn2
|
||||
self.enable_xformers = enable_xformers
|
||||
self.enable_bsa = enable_bsa
|
||||
self.bsa_params = bsa_params
|
||||
self.cp_split_hw = cp_split_hw
|
||||
|
||||
self.qkv = nn.Linear(dim, dim * 3, bias=True)
|
||||
self.q_norm = RMSNorm_FP32(self.head_dim, eps=1e-6)
|
||||
self.k_norm = RMSNorm_FP32(self.head_dim, eps=1e-6)
|
||||
self.proj = nn.Linear(dim, dim)
|
||||
|
||||
self.rope_3d = RotaryPositionalEmbedding(
|
||||
self.head_dim,
|
||||
cp_split_hw=cp_split_hw
|
||||
)
|
||||
|
||||
def _process_attn(self, q, k, v, shape):
|
||||
q = rearrange(q, "B H S D -> B S (H D)")
|
||||
k = rearrange(k, "B H S D -> B S (H D)")
|
||||
v = rearrange(v, "B H S D -> B S (H D)")
|
||||
x = flash_attention(q, k, v, num_heads=self.num_heads)
|
||||
x = rearrange(x, "B S (H D) -> B H S D", H=self.num_heads)
|
||||
return x
|
||||
|
||||
def forward(self, x: torch.Tensor, shape=None, num_cond_latents=None, return_kv=False) -> torch.Tensor:
|
||||
"""
|
||||
"""
|
||||
B, N, C = x.shape
|
||||
qkv = self.qkv(x)
|
||||
|
||||
qkv_shape = (B, N, 3, self.num_heads, self.head_dim)
|
||||
qkv = qkv.view(qkv_shape).permute((2, 0, 3, 1, 4)) # [3, B, H, N, D]
|
||||
q, k, v = qkv.unbind(0)
|
||||
q, k = self.q_norm(q), self.k_norm(k)
|
||||
|
||||
if return_kv:
|
||||
k_cache, v_cache = k.clone(), v.clone()
|
||||
|
||||
q, k = self.rope_3d(q, k, shape)
|
||||
|
||||
# cond mode
|
||||
if num_cond_latents is not None and num_cond_latents > 0:
|
||||
num_cond_latents_thw = num_cond_latents * (N // shape[0])
|
||||
# process the condition tokens
|
||||
q_cond = q[:, :, :num_cond_latents_thw].contiguous()
|
||||
k_cond = k[:, :, :num_cond_latents_thw].contiguous()
|
||||
v_cond = v[:, :, :num_cond_latents_thw].contiguous()
|
||||
x_cond = self._process_attn(q_cond, k_cond, v_cond, shape)
|
||||
# process the noise tokens
|
||||
q_noise = q[:, :, num_cond_latents_thw:].contiguous()
|
||||
x_noise = self._process_attn(q_noise, k, v, shape)
|
||||
# merge x_cond and x_noise
|
||||
x = torch.cat([x_cond, x_noise], dim=2).contiguous()
|
||||
else:
|
||||
x = self._process_attn(q, k, v, shape)
|
||||
|
||||
x_output_shape = (B, N, C)
|
||||
x = x.transpose(1, 2) # [B, H, N, D] --> [B, N, H, D]
|
||||
x = x.reshape(x_output_shape) # [B, N, H, D] --> [B, N, C]
|
||||
x = self.proj(x)
|
||||
|
||||
if return_kv:
|
||||
return x, (k_cache, v_cache)
|
||||
else:
|
||||
return x
|
||||
|
||||
def forward_with_kv_cache(self, x: torch.Tensor, shape=None, num_cond_latents=None, kv_cache=None) -> torch.Tensor:
|
||||
"""
|
||||
"""
|
||||
B, N, C = x.shape
|
||||
qkv = self.qkv(x)
|
||||
|
||||
qkv_shape = (B, N, 3, self.num_heads, self.head_dim)
|
||||
qkv = qkv.view(qkv_shape).permute((2, 0, 3, 1, 4)) # [3, B, H, N, D]
|
||||
q, k, v = qkv.unbind(0)
|
||||
q, k = self.q_norm(q), self.k_norm(k)
|
||||
|
||||
T, H, W = shape
|
||||
k_cache, v_cache = kv_cache
|
||||
assert k_cache.shape[0] == v_cache.shape[0] and k_cache.shape[0] in [1, B]
|
||||
if k_cache.shape[0] == 1:
|
||||
k_cache = k_cache.repeat(B, 1, 1, 1)
|
||||
v_cache = v_cache.repeat(B, 1, 1, 1)
|
||||
|
||||
if num_cond_latents is not None and num_cond_latents > 0:
|
||||
k_full = torch.cat([k_cache, k], dim=2).contiguous()
|
||||
v_full = torch.cat([v_cache, v], dim=2).contiguous()
|
||||
q_padding = torch.cat([torch.empty_like(k_cache), q], dim=2).contiguous()
|
||||
q_padding, k_full = self.rope_3d(q_padding, k_full, (T + num_cond_latents, H, W))
|
||||
q = q_padding[:, :, -N:].contiguous()
|
||||
|
||||
x = self._process_attn(q, k_full, v_full, shape)
|
||||
|
||||
x_output_shape = (B, N, C)
|
||||
x = x.transpose(1, 2) # [B, H, N, D] --> [B, N, H, D]
|
||||
x = x.reshape(x_output_shape) # [B, N, H, D] --> [B, N, C]
|
||||
x = self.proj(x)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class MultiHeadCrossAttention(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
num_heads,
|
||||
enable_flashattn3=False,
|
||||
enable_flashattn2=False,
|
||||
enable_xformers=False,
|
||||
):
|
||||
super(MultiHeadCrossAttention, self).__init__()
|
||||
assert dim % num_heads == 0, "d_model must be divisible by num_heads"
|
||||
|
||||
self.dim = dim
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = dim // num_heads
|
||||
|
||||
self.q_linear = nn.Linear(dim, dim)
|
||||
self.kv_linear = nn.Linear(dim, dim * 2)
|
||||
self.proj = nn.Linear(dim, dim)
|
||||
|
||||
self.q_norm = RMSNorm_FP32(self.head_dim, eps=1e-6)
|
||||
self.k_norm = RMSNorm_FP32(self.head_dim, eps=1e-6)
|
||||
|
||||
self.enable_flashattn3 = enable_flashattn3
|
||||
self.enable_flashattn2 = enable_flashattn2
|
||||
self.enable_xformers = enable_xformers
|
||||
|
||||
def _process_cross_attn(self, x, cond, kv_seqlen):
|
||||
B, N, C = x.shape
|
||||
assert C == self.dim and cond.shape[2] == self.dim
|
||||
|
||||
q = self.q_linear(x).view(1, -1, self.num_heads, self.head_dim)
|
||||
kv = self.kv_linear(cond).view(1, -1, 2, self.num_heads, self.head_dim)
|
||||
k, v = kv.unbind(2)
|
||||
|
||||
q, k = self.q_norm(q), self.k_norm(k)
|
||||
|
||||
q = rearrange(q, "B S H D -> B S (H D)")
|
||||
k = rearrange(k, "B S H D -> B S (H D)")
|
||||
v = rearrange(v, "B S H D -> B S (H D)")
|
||||
x = flash_attention(q, k, v, num_heads=self.num_heads)
|
||||
|
||||
x = x.view(B, -1, C)
|
||||
x = self.proj(x)
|
||||
return x
|
||||
|
||||
def forward(self, x, cond, kv_seqlen, num_cond_latents=None, shape=None):
|
||||
"""
|
||||
x: [B, N, C]
|
||||
cond: [B, M, C]
|
||||
"""
|
||||
if num_cond_latents is None or num_cond_latents == 0:
|
||||
return self._process_cross_attn(x, cond, kv_seqlen)
|
||||
else:
|
||||
B, N, C = x.shape
|
||||
if num_cond_latents is not None and num_cond_latents > 0:
|
||||
assert shape is not None, "SHOULD pass in the shape"
|
||||
num_cond_latents_thw = num_cond_latents * (N // shape[0])
|
||||
x_noise = x[:, num_cond_latents_thw:] # [B, N_noise, C]
|
||||
output_noise = self._process_cross_attn(x_noise, cond, kv_seqlen) # [B, N_noise, C]
|
||||
output = torch.cat([
|
||||
torch.zeros((B, num_cond_latents_thw, C), dtype=output_noise.dtype, device=output_noise.device),
|
||||
output_noise
|
||||
], dim=1).contiguous()
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
return output
|
||||
|
||||
|
||||
class LayerNorm_FP32(nn.LayerNorm):
|
||||
def __init__(self, dim, eps, elementwise_affine):
|
||||
super().__init__(dim, eps=eps, elementwise_affine=elementwise_affine)
|
||||
|
||||
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
|
||||
origin_dtype = inputs.dtype
|
||||
out = F.layer_norm(
|
||||
inputs.float(),
|
||||
self.normalized_shape,
|
||||
None if self.weight is None else self.weight.float(),
|
||||
None if self.bias is None else self.bias.float() ,
|
||||
self.eps
|
||||
).to(origin_dtype)
|
||||
return out
|
||||
|
||||
|
||||
def modulate_fp32(norm_func, x, shift, scale):
|
||||
# Suppose x is (B, N, D), shift is (B, -1, D), scale is (B, -1, D)
|
||||
# ensure the modulation params be fp32
|
||||
assert shift.dtype == torch.float32, scale.dtype == torch.float32
|
||||
dtype = x.dtype
|
||||
x = norm_func(x.to(torch.float32))
|
||||
x = x * (scale + 1) + shift
|
||||
x = x.to(dtype)
|
||||
return x
|
||||
|
||||
|
||||
class FinalLayer_FP32(nn.Module):
|
||||
"""
|
||||
The final layer of DiT.
|
||||
"""
|
||||
|
||||
def __init__(self, hidden_size, num_patch, out_channels, adaln_tembed_dim):
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
self.num_patch = num_patch
|
||||
self.out_channels = out_channels
|
||||
self.adaln_tembed_dim = adaln_tembed_dim
|
||||
|
||||
self.norm_final = LayerNorm_FP32(hidden_size, elementwise_affine=False, eps=1e-6)
|
||||
self.linear = nn.Linear(hidden_size, num_patch * out_channels, bias=True)
|
||||
self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(adaln_tembed_dim, 2 * hidden_size, bias=True))
|
||||
|
||||
def forward(self, x, t, latent_shape):
|
||||
# timestep shape: [B, T, C]
|
||||
assert t.dtype == torch.float32
|
||||
B, N, C = x.shape
|
||||
T, _, _ = latent_shape
|
||||
|
||||
with amp.autocast(get_device_type(), dtype=torch.float32):
|
||||
shift, scale = self.adaLN_modulation(t).unsqueeze(2).chunk(2, dim=-1) # [B, T, 1, C]
|
||||
x = modulate_fp32(self.norm_final, x.view(B, T, -1, C), shift, scale).view(B, N, C)
|
||||
x = self.linear(x)
|
||||
return x
|
||||
|
||||
|
||||
class FeedForwardSwiGLU(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
hidden_dim: int,
|
||||
multiple_of: int = 256,
|
||||
ffn_dim_multiplier: Optional[float] = None,
|
||||
):
|
||||
super().__init__()
|
||||
hidden_dim = int(2 * hidden_dim / 3)
|
||||
# custom dim factor multiplier
|
||||
if ffn_dim_multiplier is not None:
|
||||
hidden_dim = int(ffn_dim_multiplier * hidden_dim)
|
||||
hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
|
||||
|
||||
self.dim = dim
|
||||
self.hidden_dim = hidden_dim
|
||||
self.w1 = nn.Linear(dim, hidden_dim, bias=False)
|
||||
self.w2 = nn.Linear(hidden_dim, dim, bias=False)
|
||||
self.w3 = nn.Linear(dim, hidden_dim, bias=False)
|
||||
|
||||
def forward(self, x):
|
||||
return self.w2(F.silu(self.w1(x)) * self.w3(x))
|
||||
|
||||
|
||||
class TimestepEmbedder(nn.Module):
|
||||
"""
|
||||
Embeds scalar timesteps into vector representations.
|
||||
"""
|
||||
|
||||
def __init__(self, t_embed_dim, frequency_embedding_size=256):
|
||||
super().__init__()
|
||||
self.t_embed_dim = t_embed_dim
|
||||
self.frequency_embedding_size = frequency_embedding_size
|
||||
self.mlp = nn.Sequential(
|
||||
nn.Linear(frequency_embedding_size, t_embed_dim, bias=True),
|
||||
nn.SiLU(),
|
||||
nn.Linear(t_embed_dim, t_embed_dim, bias=True),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def timestep_embedding(t, dim, max_period=10000):
|
||||
"""
|
||||
Create sinusoidal timestep embeddings.
|
||||
:param t: a 1-D Tensor of N indices, one per batch element.
|
||||
These may be fractional.
|
||||
:param dim: the dimension of the output.
|
||||
:param max_period: controls the minimum frequency of the embeddings.
|
||||
:return: an (N, D) Tensor of positional embeddings.
|
||||
"""
|
||||
half = dim // 2
|
||||
freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half)
|
||||
freqs = freqs.to(device=t.device)
|
||||
args = t[:, None].float() * freqs[None]
|
||||
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
|
||||
if dim % 2:
|
||||
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
|
||||
return embedding
|
||||
|
||||
def forward(self, t, dtype):
|
||||
t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
|
||||
if t_freq.dtype != dtype:
|
||||
t_freq = t_freq.to(dtype)
|
||||
t_emb = self.mlp(t_freq)
|
||||
return t_emb
|
||||
|
||||
|
||||
class CaptionEmbedder(nn.Module):
|
||||
"""
|
||||
Embeds class labels into vector representations.
|
||||
"""
|
||||
|
||||
def __init__(self, in_channels, hidden_size):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
self.hidden_size = hidden_size
|
||||
self.y_proj = nn.Sequential(
|
||||
nn.Linear(in_channels, hidden_size, bias=True),
|
||||
nn.GELU(approximate="tanh"),
|
||||
nn.Linear(hidden_size, hidden_size, bias=True),
|
||||
)
|
||||
|
||||
def forward(self, caption):
|
||||
B, _, N, C = caption.shape
|
||||
caption = self.y_proj(caption)
|
||||
return caption
|
||||
|
||||
|
||||
class PatchEmbed3D(nn.Module):
|
||||
"""Video to Patch Embedding.
|
||||
|
||||
Args:
|
||||
patch_size (int): Patch token size. Default: (2,4,4).
|
||||
in_chans (int): Number of input video channels. Default: 3.
|
||||
embed_dim (int): Number of linear projection output channels. Default: 96.
|
||||
norm_layer (nn.Module, optional): Normalization layer. Default: None
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
patch_size=(2, 4, 4),
|
||||
in_chans=3,
|
||||
embed_dim=96,
|
||||
norm_layer=None,
|
||||
flatten=True,
|
||||
):
|
||||
super().__init__()
|
||||
self.patch_size = patch_size
|
||||
self.flatten = flatten
|
||||
|
||||
self.in_chans = in_chans
|
||||
self.embed_dim = embed_dim
|
||||
|
||||
self.proj = nn.Conv3d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
|
||||
if norm_layer is not None:
|
||||
self.norm = norm_layer(embed_dim)
|
||||
else:
|
||||
self.norm = None
|
||||
|
||||
def forward(self, x):
|
||||
"""Forward function."""
|
||||
# padding
|
||||
_, _, D, H, W = x.size()
|
||||
if W % self.patch_size[2] != 0:
|
||||
x = F.pad(x, (0, self.patch_size[2] - W % self.patch_size[2]))
|
||||
if H % self.patch_size[1] != 0:
|
||||
x = F.pad(x, (0, 0, 0, self.patch_size[1] - H % self.patch_size[1]))
|
||||
if D % self.patch_size[0] != 0:
|
||||
x = F.pad(x, (0, 0, 0, 0, 0, self.patch_size[0] - D % self.patch_size[0]))
|
||||
|
||||
B, C, T, H, W = x.shape
|
||||
x = self.proj(x) # (B C T H W)
|
||||
if self.norm is not None:
|
||||
D, Wh, Ww = x.size(2), x.size(3), x.size(4)
|
||||
x = x.flatten(2).transpose(1, 2)
|
||||
x = self.norm(x)
|
||||
x = x.transpose(1, 2).view(-1, self.embed_dim, D, Wh, Ww)
|
||||
if self.flatten:
|
||||
x = x.flatten(2).transpose(1, 2) # BCTHW -> BNC
|
||||
return x
|
||||
|
||||
|
||||
class LongCatSingleStreamBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int,
|
||||
num_heads: int,
|
||||
mlp_ratio: int,
|
||||
adaln_tembed_dim: int,
|
||||
enable_flashattn3: bool = False,
|
||||
enable_flashattn2: bool = False,
|
||||
enable_xformers: bool = False,
|
||||
enable_bsa: bool = False,
|
||||
bsa_params=None,
|
||||
cp_split_hw=None
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.hidden_size = hidden_size
|
||||
|
||||
# scale and gate modulation
|
||||
self.adaLN_modulation = nn.Sequential(
|
||||
nn.SiLU(),
|
||||
nn.Linear(adaln_tembed_dim, 6 * hidden_size, bias=True)
|
||||
)
|
||||
|
||||
self.mod_norm_attn = LayerNorm_FP32(hidden_size, eps=1e-6, elementwise_affine=False)
|
||||
self.mod_norm_ffn = LayerNorm_FP32(hidden_size, eps=1e-6, elementwise_affine=False)
|
||||
self.pre_crs_attn_norm = LayerNorm_FP32(hidden_size, eps=1e-6, elementwise_affine=True)
|
||||
|
||||
self.attn = Attention(
|
||||
dim=hidden_size,
|
||||
num_heads=num_heads,
|
||||
enable_flashattn3=enable_flashattn3,
|
||||
enable_flashattn2=enable_flashattn2,
|
||||
enable_xformers=enable_xformers,
|
||||
enable_bsa=enable_bsa,
|
||||
bsa_params=bsa_params,
|
||||
cp_split_hw=cp_split_hw
|
||||
)
|
||||
self.cross_attn = MultiHeadCrossAttention(
|
||||
dim=hidden_size,
|
||||
num_heads=num_heads,
|
||||
enable_flashattn3=enable_flashattn3,
|
||||
enable_flashattn2=enable_flashattn2,
|
||||
enable_xformers=enable_xformers,
|
||||
)
|
||||
self.ffn = FeedForwardSwiGLU(dim=hidden_size, hidden_dim=int(hidden_size * mlp_ratio))
|
||||
|
||||
def forward(self, x, y, t, y_seqlen, latent_shape, num_cond_latents=None, return_kv=False, kv_cache=None, skip_crs_attn=False):
|
||||
"""
|
||||
x: [B, N, C]
|
||||
y: [1, N_valid_tokens, C]
|
||||
t: [B, T, C_t]
|
||||
y_seqlen: [B]; type of a list
|
||||
latent_shape: latent shape of a single item
|
||||
"""
|
||||
x_dtype = x.dtype
|
||||
|
||||
B, N, C = x.shape
|
||||
T, _, _ = latent_shape # S != T*H*W in case of CP split on H*W.
|
||||
|
||||
# compute modulation params in fp32
|
||||
with amp.autocast(device_type=get_device_type(), dtype=torch.float32):
|
||||
shift_msa, scale_msa, gate_msa, \
|
||||
shift_mlp, scale_mlp, gate_mlp = \
|
||||
self.adaLN_modulation(t).unsqueeze(2).chunk(6, dim=-1) # [B, T, 1, C]
|
||||
|
||||
# self attn with modulation
|
||||
x_m = modulate_fp32(self.mod_norm_attn, x.view(B, T, -1, C), shift_msa, scale_msa).view(B, N, C)
|
||||
|
||||
if kv_cache is not None:
|
||||
kv_cache = (kv_cache[0].to(x.device), kv_cache[1].to(x.device))
|
||||
attn_outputs = self.attn.forward_with_kv_cache(x_m, shape=latent_shape, num_cond_latents=num_cond_latents, kv_cache=kv_cache)
|
||||
else:
|
||||
attn_outputs = self.attn(x_m, shape=latent_shape, num_cond_latents=num_cond_latents, return_kv=return_kv)
|
||||
|
||||
if return_kv:
|
||||
x_s, kv_cache = attn_outputs
|
||||
else:
|
||||
x_s = attn_outputs
|
||||
|
||||
with amp.autocast(device_type=get_device_type(), dtype=torch.float32):
|
||||
x = x + (gate_msa * x_s.view(B, -1, N//T, C)).view(B, -1, C) # [B, N, C]
|
||||
x = x.to(x_dtype)
|
||||
|
||||
# cross attn
|
||||
if not skip_crs_attn:
|
||||
if kv_cache is not None:
|
||||
num_cond_latents = None
|
||||
x = x + self.cross_attn(self.pre_crs_attn_norm(x), y, y_seqlen, num_cond_latents=num_cond_latents, shape=latent_shape)
|
||||
|
||||
# ffn with modulation
|
||||
x_m = modulate_fp32(self.mod_norm_ffn, x.view(B, -1, N//T, C), shift_mlp, scale_mlp).view(B, -1, C)
|
||||
x_s = self.ffn(x_m)
|
||||
with amp.autocast(device_type=get_device_type(), dtype=torch.float32):
|
||||
x = x + (gate_mlp * x_s.view(B, -1, N//T, C)).view(B, -1, C) # [B, N, C]
|
||||
x = x.to(x_dtype)
|
||||
|
||||
if return_kv:
|
||||
return x, kv_cache
|
||||
else:
|
||||
return x
|
||||
|
||||
|
||||
class LongCatVideoTransformer3DModel(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int = 16,
|
||||
out_channels: int = 16,
|
||||
hidden_size: int = 4096,
|
||||
depth: int = 48,
|
||||
num_heads: int = 32,
|
||||
caption_channels: int = 4096,
|
||||
mlp_ratio: int = 4,
|
||||
adaln_tembed_dim: int = 512,
|
||||
frequency_embedding_size: int = 256,
|
||||
# default params
|
||||
patch_size: Tuple[int] = (1, 2, 2),
|
||||
# attention config
|
||||
enable_flashattn3: bool = False,
|
||||
enable_flashattn2: bool = True,
|
||||
enable_xformers: bool = False,
|
||||
enable_bsa: bool = False,
|
||||
bsa_params: dict = {'sparsity': 0.9375, 'chunk_3d_shape_q': [4, 4, 4], 'chunk_3d_shape_k': [4, 4, 4]},
|
||||
cp_split_hw: Optional[List[int]] = [1, 1],
|
||||
text_tokens_zero_pad: bool = True,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.patch_size = patch_size
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = out_channels
|
||||
self.cp_split_hw = cp_split_hw
|
||||
|
||||
self.x_embedder = PatchEmbed3D(patch_size, in_channels, hidden_size)
|
||||
self.t_embedder = TimestepEmbedder(t_embed_dim=adaln_tembed_dim, frequency_embedding_size=frequency_embedding_size)
|
||||
self.y_embedder = CaptionEmbedder(
|
||||
in_channels=caption_channels,
|
||||
hidden_size=hidden_size,
|
||||
)
|
||||
|
||||
self.blocks = nn.ModuleList(
|
||||
[
|
||||
LongCatSingleStreamBlock(
|
||||
hidden_size=hidden_size,
|
||||
num_heads=num_heads,
|
||||
mlp_ratio=mlp_ratio,
|
||||
adaln_tembed_dim=adaln_tembed_dim,
|
||||
enable_flashattn3=enable_flashattn3,
|
||||
enable_flashattn2=enable_flashattn2,
|
||||
enable_xformers=enable_xformers,
|
||||
enable_bsa=enable_bsa,
|
||||
bsa_params=bsa_params,
|
||||
cp_split_hw=cp_split_hw
|
||||
)
|
||||
for i in range(depth)
|
||||
]
|
||||
)
|
||||
|
||||
self.final_layer = FinalLayer_FP32(
|
||||
hidden_size,
|
||||
np.prod(self.patch_size),
|
||||
out_channels,
|
||||
adaln_tembed_dim,
|
||||
)
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
self.text_tokens_zero_pad = text_tokens_zero_pad
|
||||
|
||||
self.lora_dict = {}
|
||||
self.active_loras = []
|
||||
|
||||
def enable_loras(self, lora_key_list=[]):
|
||||
self.disable_all_loras()
|
||||
|
||||
module_loras = {} # {module_name: [lora1, lora2, ...]}
|
||||
model_device = next(self.parameters()).device
|
||||
model_dtype = next(self.parameters()).dtype
|
||||
|
||||
for lora_key in lora_key_list:
|
||||
if lora_key in self.lora_dict:
|
||||
for lora in self.lora_dict[lora_key].loras:
|
||||
lora.to(model_device, dtype=model_dtype, non_blocking=True)
|
||||
module_name = lora.lora_name.replace("lora___lorahyphen___", "").replace("___lorahyphen___", ".")
|
||||
if module_name not in module_loras:
|
||||
module_loras[module_name] = []
|
||||
module_loras[module_name].append(lora)
|
||||
self.active_loras.append(lora_key)
|
||||
|
||||
for module_name, loras in module_loras.items():
|
||||
module = self._get_module_by_name(module_name)
|
||||
if not hasattr(module, 'org_forward'):
|
||||
module.org_forward = module.forward
|
||||
module.forward = self._create_multi_lora_forward(module, loras)
|
||||
|
||||
def _create_multi_lora_forward(self, module, loras):
|
||||
def multi_lora_forward(x, *args, **kwargs):
|
||||
weight_dtype = x.dtype
|
||||
org_output = module.org_forward(x, *args, **kwargs)
|
||||
|
||||
total_lora_output = 0
|
||||
for lora in loras:
|
||||
if lora.use_lora:
|
||||
lx = lora.lora_down(x.to(lora.lora_down.weight.dtype))
|
||||
lx = lora.lora_up(lx)
|
||||
lora_output = lx.to(weight_dtype) * lora.multiplier * lora.alpha_scale
|
||||
total_lora_output += lora_output
|
||||
|
||||
return org_output + total_lora_output
|
||||
|
||||
return multi_lora_forward
|
||||
|
||||
def _get_module_by_name(self, module_name):
|
||||
try:
|
||||
module = self
|
||||
for part in module_name.split('.'):
|
||||
module = getattr(module, part)
|
||||
return module
|
||||
except AttributeError as e:
|
||||
raise ValueError(f"Cannot find module: {module_name}, error: {e}")
|
||||
|
||||
def disable_all_loras(self):
|
||||
for name, module in self.named_modules():
|
||||
if hasattr(module, 'org_forward'):
|
||||
module.forward = module.org_forward
|
||||
delattr(module, 'org_forward')
|
||||
|
||||
for lora_key, lora_network in self.lora_dict.items():
|
||||
for lora in lora_network.loras:
|
||||
lora.to("cpu")
|
||||
|
||||
self.active_loras.clear()
|
||||
|
||||
def enable_bsa(self,):
|
||||
for block in self.blocks:
|
||||
block.attn.enable_bsa = True
|
||||
|
||||
def disable_bsa(self,):
|
||||
for block in self.blocks:
|
||||
block.attn.enable_bsa = False
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states,
|
||||
timestep,
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask=None,
|
||||
num_cond_latents=0,
|
||||
return_kv=False,
|
||||
kv_cache_dict={},
|
||||
skip_crs_attn=False,
|
||||
offload_kv_cache=False,
|
||||
use_gradient_checkpointing=False,
|
||||
use_gradient_checkpointing_offload=False,
|
||||
):
|
||||
|
||||
B, _, T, H, W = hidden_states.shape
|
||||
|
||||
N_t = T // self.patch_size[0]
|
||||
N_h = H // self.patch_size[1]
|
||||
N_w = W // self.patch_size[2]
|
||||
|
||||
assert self.patch_size[0]==1, "Currently, 3D x_embedder should not compress the temporal dimension."
|
||||
|
||||
# expand the shape of timestep from [B] to [B, T]
|
||||
if len(timestep.shape) == 1:
|
||||
timestep = timestep.unsqueeze(1).expand(-1, N_t).clone() # [B, T]
|
||||
timestep[:, :num_cond_latents] = 0
|
||||
|
||||
dtype = hidden_states.dtype
|
||||
hidden_states = hidden_states.to(dtype)
|
||||
timestep = timestep.to(dtype)
|
||||
encoder_hidden_states = encoder_hidden_states.to(dtype)
|
||||
|
||||
hidden_states = self.x_embedder(hidden_states) # [B, N, C]
|
||||
|
||||
with amp.autocast(device_type=get_device_type(), dtype=torch.float32):
|
||||
t = self.t_embedder(timestep.float().flatten(), dtype=torch.float32).reshape(B, N_t, -1) # [B, T, C_t]
|
||||
|
||||
encoder_hidden_states = self.y_embedder(encoder_hidden_states) # [B, 1, N_token, C]
|
||||
|
||||
if self.text_tokens_zero_pad and encoder_attention_mask is not None:
|
||||
encoder_hidden_states = encoder_hidden_states * encoder_attention_mask[:, None, :, None]
|
||||
encoder_attention_mask = (encoder_attention_mask * 0 + 1).to(encoder_attention_mask.dtype)
|
||||
|
||||
if encoder_attention_mask is not None:
|
||||
encoder_attention_mask = encoder_attention_mask.squeeze(1).squeeze(1)
|
||||
encoder_hidden_states = encoder_hidden_states.squeeze(1).masked_select(encoder_attention_mask.unsqueeze(-1) != 0).view(1, -1, hidden_states.shape[-1]) # [1, N_valid_tokens, C]
|
||||
y_seqlens = encoder_attention_mask.sum(dim=1).tolist() # [B]
|
||||
else:
|
||||
y_seqlens = [encoder_hidden_states.shape[2]] * encoder_hidden_states.shape[0]
|
||||
encoder_hidden_states = encoder_hidden_states.squeeze(1).view(1, -1, hidden_states.shape[-1])
|
||||
|
||||
# if self.cp_split_hw[0] * self.cp_split_hw[1] > 1:
|
||||
# hidden_states = rearrange(hidden_states, "B (T H W) C -> B T H W C", T=N_t, H=N_h, W=N_w)
|
||||
# hidden_states = context_parallel_util.split_cp_2d(hidden_states, seq_dim_hw=(2, 3), split_hw=self.cp_split_hw)
|
||||
# hidden_states = rearrange(hidden_states, "B T H W C -> B (T H W) C")
|
||||
|
||||
# blocks
|
||||
kv_cache_dict_ret = {}
|
||||
for i, block in enumerate(self.blocks):
|
||||
block_outputs = gradient_checkpoint_forward(
|
||||
block,
|
||||
use_gradient_checkpointing=use_gradient_checkpointing,
|
||||
use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
|
||||
x=hidden_states,
|
||||
y=encoder_hidden_states,
|
||||
t=t,
|
||||
y_seqlen=y_seqlens,
|
||||
latent_shape=(N_t, N_h, N_w),
|
||||
num_cond_latents=num_cond_latents,
|
||||
return_kv=return_kv,
|
||||
kv_cache=kv_cache_dict.get(i, None),
|
||||
skip_crs_attn=skip_crs_attn,
|
||||
)
|
||||
|
||||
if return_kv:
|
||||
hidden_states, kv_cache = block_outputs
|
||||
if offload_kv_cache:
|
||||
kv_cache_dict_ret[i] = (kv_cache[0].cpu(), kv_cache[1].cpu())
|
||||
else:
|
||||
kv_cache_dict_ret[i] = (kv_cache[0].contiguous(), kv_cache[1].contiguous())
|
||||
else:
|
||||
hidden_states = block_outputs
|
||||
|
||||
hidden_states = self.final_layer(hidden_states, t, (N_t, N_h, N_w)) # [B, N, C=T_p*H_p*W_p*C_out]
|
||||
|
||||
# if self.cp_split_hw[0] * self.cp_split_hw[1] > 1:
|
||||
# hidden_states = context_parallel_util.gather_cp_2d(hidden_states, shape=(N_t, N_h, N_w), split_hw=self.cp_split_hw)
|
||||
|
||||
hidden_states = self.unpatchify(hidden_states, N_t, N_h, N_w) # [B, C_out, H, W]
|
||||
|
||||
# cast to float32 for better accuracy
|
||||
hidden_states = hidden_states.to(torch.float32)
|
||||
|
||||
if return_kv:
|
||||
return hidden_states, kv_cache_dict_ret
|
||||
else:
|
||||
return hidden_states
|
||||
|
||||
|
||||
def unpatchify(self, x, N_t, N_h, N_w):
|
||||
"""
|
||||
Args:
|
||||
x (torch.Tensor): of shape [B, N, C]
|
||||
|
||||
Return:
|
||||
x (torch.Tensor): of shape [B, C_out, T, H, W]
|
||||
"""
|
||||
T_p, H_p, W_p = self.patch_size
|
||||
x = rearrange(
|
||||
x,
|
||||
"B (N_t N_h N_w) (T_p H_p W_p C_out) -> B C_out (N_t T_p) (N_h H_p) (N_w W_p)",
|
||||
N_t=N_t,
|
||||
N_h=N_h,
|
||||
N_w=N_w,
|
||||
T_p=T_p,
|
||||
H_p=H_p,
|
||||
W_p=W_p,
|
||||
C_out=self.out_channels,
|
||||
)
|
||||
return x
|
||||
|
||||
@staticmethod
|
||||
def state_dict_converter():
|
||||
return LongCatVideoTransformer3DModelDictConverter()
|
||||
|
||||
|
||||
class LongCatVideoTransformer3DModelDictConverter:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def from_diffusers(self, state_dict):
|
||||
return state_dict
|
||||
|
||||
def from_civitai(self, state_dict):
|
||||
return state_dict
|
||||
|
||||
252
diffsynth/models/lora.py
Normal file
252
diffsynth/models/lora.py
Normal file
@@ -0,0 +1,252 @@
|
||||
import torch
|
||||
from .sd_unet import SDUNet
|
||||
from .sdxl_unet import SDXLUNet
|
||||
from .sd_text_encoder import SDTextEncoder
|
||||
from .sdxl_text_encoder import SDXLTextEncoder, SDXLTextEncoder2
|
||||
from .sd3_dit import SD3DiT
|
||||
from .flux_dit import FluxDiT
|
||||
from .hunyuan_dit import HunyuanDiT
|
||||
|
||||
|
||||
|
||||
class LoRAFromCivitai:
|
||||
def __init__(self):
|
||||
self.supported_model_classes = []
|
||||
self.lora_prefix = []
|
||||
self.renamed_lora_prefix = {}
|
||||
self.special_keys = {}
|
||||
|
||||
|
||||
def convert_state_dict(self, state_dict, lora_prefix="lora_unet_", alpha=1.0):
|
||||
for key in state_dict:
|
||||
if ".lora_up" in key:
|
||||
return self.convert_state_dict_up_down(state_dict, lora_prefix, alpha)
|
||||
return self.convert_state_dict_AB(state_dict, lora_prefix, alpha)
|
||||
|
||||
|
||||
def convert_state_dict_up_down(self, state_dict, lora_prefix="lora_unet_", alpha=1.0):
|
||||
renamed_lora_prefix = self.renamed_lora_prefix.get(lora_prefix, "")
|
||||
state_dict_ = {}
|
||||
for key in state_dict:
|
||||
if ".lora_up" not in key:
|
||||
continue
|
||||
if not key.startswith(lora_prefix):
|
||||
continue
|
||||
weight_up = state_dict[key].to(device="cuda", dtype=torch.float16)
|
||||
weight_down = state_dict[key.replace(".lora_up", ".lora_down")].to(device="cuda", dtype=torch.float16)
|
||||
if len(weight_up.shape) == 4:
|
||||
weight_up = weight_up.squeeze(3).squeeze(2).to(torch.float32)
|
||||
weight_down = weight_down.squeeze(3).squeeze(2).to(torch.float32)
|
||||
lora_weight = alpha * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3)
|
||||
else:
|
||||
lora_weight = alpha * torch.mm(weight_up, weight_down)
|
||||
target_name = key.split(".")[0].replace(lora_prefix, renamed_lora_prefix).replace("_", ".") + ".weight"
|
||||
for special_key in self.special_keys:
|
||||
target_name = target_name.replace(special_key, self.special_keys[special_key])
|
||||
state_dict_[target_name] = lora_weight.cpu()
|
||||
return state_dict_
|
||||
|
||||
|
||||
def convert_state_dict_AB(self, state_dict, lora_prefix="", alpha=1.0, device="cuda", torch_dtype=torch.float16):
|
||||
state_dict_ = {}
|
||||
for key in state_dict:
|
||||
if ".lora_B." not in key:
|
||||
continue
|
||||
if not key.startswith(lora_prefix):
|
||||
continue
|
||||
weight_up = state_dict[key].to(device=device, dtype=torch_dtype)
|
||||
weight_down = state_dict[key.replace(".lora_B.", ".lora_A.")].to(device=device, dtype=torch_dtype)
|
||||
if len(weight_up.shape) == 4:
|
||||
weight_up = weight_up.squeeze(3).squeeze(2)
|
||||
weight_down = weight_down.squeeze(3).squeeze(2)
|
||||
lora_weight = alpha * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3)
|
||||
else:
|
||||
lora_weight = alpha * torch.mm(weight_up, weight_down)
|
||||
keys = key.split(".")
|
||||
keys.pop(keys.index("lora_B"))
|
||||
target_name = ".".join(keys)
|
||||
target_name = target_name[len(lora_prefix):]
|
||||
state_dict_[target_name] = lora_weight.cpu()
|
||||
return state_dict_
|
||||
|
||||
|
||||
def load(self, model, state_dict_lora, lora_prefix, alpha=1.0, model_resource=None):
|
||||
state_dict_model = model.state_dict()
|
||||
state_dict_lora = self.convert_state_dict(state_dict_lora, lora_prefix=lora_prefix, alpha=alpha)
|
||||
if model_resource == "diffusers":
|
||||
state_dict_lora = model.__class__.state_dict_converter().from_diffusers(state_dict_lora)
|
||||
elif model_resource == "civitai":
|
||||
state_dict_lora = model.__class__.state_dict_converter().from_civitai(state_dict_lora)
|
||||
if len(state_dict_lora) > 0:
|
||||
print(f" {len(state_dict_lora)} tensors are updated.")
|
||||
for name in state_dict_lora:
|
||||
state_dict_model[name] += state_dict_lora[name].to(
|
||||
dtype=state_dict_model[name].dtype, device=state_dict_model[name].device)
|
||||
model.load_state_dict(state_dict_model)
|
||||
|
||||
|
||||
def match(self, model, state_dict_lora):
|
||||
for lora_prefix, model_class in zip(self.lora_prefix, self.supported_model_classes):
|
||||
if not isinstance(model, model_class):
|
||||
continue
|
||||
state_dict_model = model.state_dict()
|
||||
for model_resource in ["diffusers", "civitai"]:
|
||||
try:
|
||||
state_dict_lora_ = self.convert_state_dict(state_dict_lora, lora_prefix=lora_prefix, alpha=1.0)
|
||||
converter_fn = model.__class__.state_dict_converter().from_diffusers if model_resource == "diffusers" \
|
||||
else model.__class__.state_dict_converter().from_civitai
|
||||
state_dict_lora_ = converter_fn(state_dict_lora_)
|
||||
if len(state_dict_lora_) == 0:
|
||||
continue
|
||||
for name in state_dict_lora_:
|
||||
if name not in state_dict_model:
|
||||
break
|
||||
else:
|
||||
return lora_prefix, model_resource
|
||||
except:
|
||||
pass
|
||||
return None
|
||||
|
||||
|
||||
|
||||
class SDLoRAFromCivitai(LoRAFromCivitai):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.supported_model_classes = [SDUNet, SDTextEncoder]
|
||||
self.lora_prefix = ["lora_unet_", "lora_te_"]
|
||||
self.special_keys = {
|
||||
"down.blocks": "down_blocks",
|
||||
"up.blocks": "up_blocks",
|
||||
"mid.block": "mid_block",
|
||||
"proj.in": "proj_in",
|
||||
"proj.out": "proj_out",
|
||||
"transformer.blocks": "transformer_blocks",
|
||||
"to.q": "to_q",
|
||||
"to.k": "to_k",
|
||||
"to.v": "to_v",
|
||||
"to.out": "to_out",
|
||||
"text.model": "text_model",
|
||||
"self.attn.q.proj": "self_attn.q_proj",
|
||||
"self.attn.k.proj": "self_attn.k_proj",
|
||||
"self.attn.v.proj": "self_attn.v_proj",
|
||||
"self.attn.out.proj": "self_attn.out_proj",
|
||||
"input.blocks": "model.diffusion_model.input_blocks",
|
||||
"middle.block": "model.diffusion_model.middle_block",
|
||||
"output.blocks": "model.diffusion_model.output_blocks",
|
||||
}
|
||||
|
||||
|
||||
class SDXLLoRAFromCivitai(LoRAFromCivitai):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.supported_model_classes = [SDXLUNet, SDXLTextEncoder, SDXLTextEncoder2]
|
||||
self.lora_prefix = ["lora_unet_", "lora_te1_", "lora_te2_"]
|
||||
self.renamed_lora_prefix = {"lora_te2_": "2"}
|
||||
self.special_keys = {
|
||||
"down.blocks": "down_blocks",
|
||||
"up.blocks": "up_blocks",
|
||||
"mid.block": "mid_block",
|
||||
"proj.in": "proj_in",
|
||||
"proj.out": "proj_out",
|
||||
"transformer.blocks": "transformer_blocks",
|
||||
"to.q": "to_q",
|
||||
"to.k": "to_k",
|
||||
"to.v": "to_v",
|
||||
"to.out": "to_out",
|
||||
"text.model": "conditioner.embedders.0.transformer.text_model",
|
||||
"self.attn.q.proj": "self_attn.q_proj",
|
||||
"self.attn.k.proj": "self_attn.k_proj",
|
||||
"self.attn.v.proj": "self_attn.v_proj",
|
||||
"self.attn.out.proj": "self_attn.out_proj",
|
||||
"input.blocks": "model.diffusion_model.input_blocks",
|
||||
"middle.block": "model.diffusion_model.middle_block",
|
||||
"output.blocks": "model.diffusion_model.output_blocks",
|
||||
"2conditioner.embedders.0.transformer.text_model.encoder.layers": "text_model.encoder.layers"
|
||||
}
|
||||
|
||||
|
||||
class FluxLoRAFromCivitai(LoRAFromCivitai):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.supported_model_classes = [FluxDiT, FluxDiT]
|
||||
self.lora_prefix = ["lora_unet_", "transformer."]
|
||||
self.renamed_lora_prefix = {}
|
||||
self.special_keys = {
|
||||
"single.blocks": "single_blocks",
|
||||
"double.blocks": "double_blocks",
|
||||
"img.attn": "img_attn",
|
||||
"img.mlp": "img_mlp",
|
||||
"img.mod": "img_mod",
|
||||
"txt.attn": "txt_attn",
|
||||
"txt.mlp": "txt_mlp",
|
||||
"txt.mod": "txt_mod",
|
||||
}
|
||||
|
||||
|
||||
class GeneralLoRAFromPeft:
|
||||
def __init__(self):
|
||||
self.supported_model_classes = [SDUNet, SDXLUNet, SD3DiT, HunyuanDiT, FluxDiT]
|
||||
|
||||
|
||||
def fetch_device_dtype_from_state_dict(self, state_dict):
|
||||
device, torch_dtype = None, None
|
||||
for name, param in state_dict.items():
|
||||
device, torch_dtype = param.device, param.dtype
|
||||
break
|
||||
return device, torch_dtype
|
||||
|
||||
|
||||
def convert_state_dict(self, state_dict, alpha=1.0, target_state_dict={}):
|
||||
device, torch_dtype = self.fetch_device_dtype_from_state_dict(target_state_dict)
|
||||
state_dict_ = {}
|
||||
for key in state_dict:
|
||||
if ".lora_B." not in key:
|
||||
continue
|
||||
weight_up = state_dict[key].to(device=device, dtype=torch_dtype)
|
||||
weight_down = state_dict[key.replace(".lora_B.", ".lora_A.")].to(device=device, dtype=torch_dtype)
|
||||
if len(weight_up.shape) == 4:
|
||||
weight_up = weight_up.squeeze(3).squeeze(2)
|
||||
weight_down = weight_down.squeeze(3).squeeze(2)
|
||||
lora_weight = alpha * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3)
|
||||
else:
|
||||
lora_weight = alpha * torch.mm(weight_up, weight_down)
|
||||
keys = key.split(".")
|
||||
if len(keys) > keys.index("lora_B") + 2:
|
||||
keys.pop(keys.index("lora_B") + 1)
|
||||
keys.pop(keys.index("lora_B"))
|
||||
target_name = ".".join(keys)
|
||||
if target_name not in target_state_dict:
|
||||
return {}
|
||||
state_dict_[target_name] = lora_weight.cpu()
|
||||
return state_dict_
|
||||
|
||||
|
||||
def load(self, model, state_dict_lora, lora_prefix="", alpha=1.0, model_resource=""):
|
||||
state_dict_model = model.state_dict()
|
||||
state_dict_lora = self.convert_state_dict(state_dict_lora, alpha=alpha, target_state_dict=state_dict_model)
|
||||
if len(state_dict_lora) > 0:
|
||||
print(f" {len(state_dict_lora)} tensors are updated.")
|
||||
for name in state_dict_lora:
|
||||
state_dict_model[name] += state_dict_lora[name].to(
|
||||
dtype=state_dict_model[name].dtype,
|
||||
device=state_dict_model[name].device
|
||||
)
|
||||
model.load_state_dict(state_dict_model)
|
||||
|
||||
|
||||
def match(self, model, state_dict_lora):
|
||||
for model_class in self.supported_model_classes:
|
||||
if not isinstance(model, model_class):
|
||||
continue
|
||||
state_dict_model = model.state_dict()
|
||||
try:
|
||||
state_dict_lora_ = self.convert_state_dict(state_dict_lora, alpha=1.0, target_state_dict=state_dict_model)
|
||||
if len(state_dict_lora_) > 0:
|
||||
return "", ""
|
||||
except:
|
||||
pass
|
||||
return None
|
||||
|
||||
|
||||
def get_lora_loaders():
|
||||
return [SDLoRAFromCivitai(), SDXLLoRAFromCivitai(), GeneralLoRAFromPeft(), FluxLoRAFromCivitai()]
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,371 +0,0 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import NamedTuple, Protocol, Tuple
|
||||
import torch
|
||||
from torch import nn
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class VideoPixelShape(NamedTuple):
|
||||
"""
|
||||
Shape of the tensor representing the video pixel array. Assumes BGR channel format.
|
||||
"""
|
||||
|
||||
batch: int
|
||||
frames: int
|
||||
height: int
|
||||
width: int
|
||||
fps: float
|
||||
|
||||
|
||||
class SpatioTemporalScaleFactors(NamedTuple):
|
||||
"""
|
||||
Describes the spatiotemporal downscaling between decoded video space and
|
||||
the corresponding VAE latent grid.
|
||||
"""
|
||||
|
||||
time: int
|
||||
width: int
|
||||
height: int
|
||||
|
||||
@classmethod
|
||||
def default(cls) -> "SpatioTemporalScaleFactors":
|
||||
return cls(time=8, width=32, height=32)
|
||||
|
||||
|
||||
VIDEO_SCALE_FACTORS = SpatioTemporalScaleFactors.default()
|
||||
|
||||
|
||||
class VideoLatentShape(NamedTuple):
|
||||
"""
|
||||
Shape of the tensor representing video in VAE latent space.
|
||||
The latent representation is a 5D tensor with dimensions ordered as
|
||||
(batch, channels, frames, height, width). Spatial and temporal dimensions
|
||||
are downscaled relative to pixel space according to the VAE's scale factors.
|
||||
"""
|
||||
|
||||
batch: int
|
||||
channels: int
|
||||
frames: int
|
||||
height: int
|
||||
width: int
|
||||
|
||||
def to_torch_shape(self) -> torch.Size:
|
||||
return torch.Size([self.batch, self.channels, self.frames, self.height, self.width])
|
||||
|
||||
@staticmethod
|
||||
def from_torch_shape(shape: torch.Size) -> "VideoLatentShape":
|
||||
return VideoLatentShape(
|
||||
batch=shape[0],
|
||||
channels=shape[1],
|
||||
frames=shape[2],
|
||||
height=shape[3],
|
||||
width=shape[4],
|
||||
)
|
||||
|
||||
def mask_shape(self) -> "VideoLatentShape":
|
||||
return self._replace(channels=1)
|
||||
|
||||
@staticmethod
|
||||
def from_pixel_shape(
|
||||
shape: VideoPixelShape,
|
||||
latent_channels: int = 128,
|
||||
scale_factors: SpatioTemporalScaleFactors = VIDEO_SCALE_FACTORS,
|
||||
) -> "VideoLatentShape":
|
||||
frames = (shape.frames - 1) // scale_factors[0] + 1
|
||||
height = shape.height // scale_factors[1]
|
||||
width = shape.width // scale_factors[2]
|
||||
|
||||
return VideoLatentShape(
|
||||
batch=shape.batch,
|
||||
channels=latent_channels,
|
||||
frames=frames,
|
||||
height=height,
|
||||
width=width,
|
||||
)
|
||||
|
||||
def upscale(self, scale_factors: SpatioTemporalScaleFactors = VIDEO_SCALE_FACTORS) -> "VideoLatentShape":
|
||||
return self._replace(
|
||||
channels=3,
|
||||
frames=(self.frames - 1) * scale_factors.time + 1,
|
||||
height=self.height * scale_factors.height,
|
||||
width=self.width * scale_factors.width,
|
||||
)
|
||||
|
||||
|
||||
class AudioLatentShape(NamedTuple):
|
||||
"""
|
||||
Shape of audio in VAE latent space: (batch, channels, frames, mel_bins).
|
||||
mel_bins is the number of frequency bins from the mel-spectrogram encoding.
|
||||
"""
|
||||
|
||||
batch: int
|
||||
channels: int
|
||||
frames: int
|
||||
mel_bins: int
|
||||
|
||||
def to_torch_shape(self) -> torch.Size:
|
||||
return torch.Size([self.batch, self.channels, self.frames, self.mel_bins])
|
||||
|
||||
def mask_shape(self) -> "AudioLatentShape":
|
||||
return self._replace(channels=1, mel_bins=1)
|
||||
|
||||
@staticmethod
|
||||
def from_torch_shape(shape: torch.Size) -> "AudioLatentShape":
|
||||
return AudioLatentShape(
|
||||
batch=shape[0],
|
||||
channels=shape[1],
|
||||
frames=shape[2],
|
||||
mel_bins=shape[3],
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def from_duration(
|
||||
batch: int,
|
||||
duration: float,
|
||||
channels: int = 8,
|
||||
mel_bins: int = 16,
|
||||
sample_rate: int = 16000,
|
||||
hop_length: int = 160,
|
||||
audio_latent_downsample_factor: int = 4,
|
||||
) -> "AudioLatentShape":
|
||||
latents_per_second = float(sample_rate) / float(hop_length) / float(audio_latent_downsample_factor)
|
||||
|
||||
return AudioLatentShape(
|
||||
batch=batch,
|
||||
channels=channels,
|
||||
frames=round(duration * latents_per_second),
|
||||
mel_bins=mel_bins,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def from_video_pixel_shape(
|
||||
shape: VideoPixelShape,
|
||||
channels: int = 8,
|
||||
mel_bins: int = 16,
|
||||
sample_rate: int = 16000,
|
||||
hop_length: int = 160,
|
||||
audio_latent_downsample_factor: int = 4,
|
||||
) -> "AudioLatentShape":
|
||||
return AudioLatentShape.from_duration(
|
||||
batch=shape.batch,
|
||||
duration=float(shape.frames) / float(shape.fps),
|
||||
channels=channels,
|
||||
mel_bins=mel_bins,
|
||||
sample_rate=sample_rate,
|
||||
hop_length=hop_length,
|
||||
audio_latent_downsample_factor=audio_latent_downsample_factor,
|
||||
)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class LatentState:
|
||||
"""
|
||||
State of latents during the diffusion denoising process.
|
||||
Attributes:
|
||||
latent: The current noisy latent tensor being denoised.
|
||||
denoise_mask: Mask encoding the denoising strength for each token (1 = full denoising, 0 = no denoising).
|
||||
positions: Positional indices for each latent element, used for positional embeddings.
|
||||
clean_latent: Initial state of the latent before denoising, may include conditioning latents.
|
||||
"""
|
||||
|
||||
latent: torch.Tensor
|
||||
denoise_mask: torch.Tensor
|
||||
positions: torch.Tensor
|
||||
clean_latent: torch.Tensor
|
||||
|
||||
def clone(self) -> "LatentState":
|
||||
return LatentState(
|
||||
latent=self.latent.clone(),
|
||||
denoise_mask=self.denoise_mask.clone(),
|
||||
positions=self.positions.clone(),
|
||||
clean_latent=self.clean_latent.clone(),
|
||||
)
|
||||
|
||||
|
||||
class NormType(Enum):
|
||||
"""Normalization layer types: GROUP (GroupNorm) or PIXEL (per-location RMS norm)."""
|
||||
|
||||
GROUP = "group"
|
||||
PIXEL = "pixel"
|
||||
|
||||
|
||||
class PixelNorm(nn.Module):
|
||||
"""
|
||||
Per-pixel (per-location) RMS normalization layer.
|
||||
For each element along the chosen dimension, this layer normalizes the tensor
|
||||
by the root-mean-square of its values across that dimension:
|
||||
y = x / sqrt(mean(x^2, dim=dim, keepdim=True) + eps)
|
||||
"""
|
||||
|
||||
def __init__(self, dim: int = 1, eps: float = 1e-8) -> None:
|
||||
"""
|
||||
Args:
|
||||
dim: Dimension along which to compute the RMS (typically channels).
|
||||
eps: Small constant added for numerical stability.
|
||||
"""
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.eps = eps
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Apply RMS normalization along the configured dimension.
|
||||
"""
|
||||
# Compute mean of squared values along `dim`, keep dimensions for broadcasting.
|
||||
mean_sq = torch.mean(x**2, dim=self.dim, keepdim=True)
|
||||
# Normalize by the root-mean-square (RMS).
|
||||
rms = torch.sqrt(mean_sq + self.eps)
|
||||
return x / rms
|
||||
|
||||
|
||||
def build_normalization_layer(
|
||||
in_channels: int, *, num_groups: int = 32, normtype: NormType = NormType.GROUP
|
||||
) -> nn.Module:
|
||||
"""
|
||||
Create a normalization layer based on the normalization type.
|
||||
Args:
|
||||
in_channels: Number of input channels
|
||||
num_groups: Number of groups for group normalization
|
||||
normtype: Type of normalization: "group" or "pixel"
|
||||
Returns:
|
||||
A normalization layer
|
||||
"""
|
||||
if normtype == NormType.GROUP:
|
||||
return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)
|
||||
if normtype == NormType.PIXEL:
|
||||
return PixelNorm(dim=1, eps=1e-6)
|
||||
raise ValueError(f"Invalid normalization type: {normtype}")
|
||||
|
||||
|
||||
def rms_norm(x: torch.Tensor, weight: torch.Tensor | None = None, eps: float = 1e-6) -> torch.Tensor:
|
||||
"""Root-mean-square (RMS) normalize `x` over its last dimension.
|
||||
Thin wrapper around `torch.nn.functional.rms_norm` that infers the normalized
|
||||
shape and forwards `weight` and `eps`.
|
||||
"""
|
||||
return torch.nn.functional.rms_norm(x, (x.shape[-1],), weight=weight, eps=eps)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class Modality:
|
||||
"""
|
||||
Input data for a single modality (video or audio) in the transformer.
|
||||
Bundles the latent tokens, timestep embeddings, positional information,
|
||||
and text conditioning context for processing by the diffusion transformer.
|
||||
"""
|
||||
|
||||
latent: (
|
||||
torch.Tensor
|
||||
) # Shape: (B, T, D) where B is the batch size, T is the number of tokens, and D is input dimension
|
||||
timesteps: torch.Tensor # Shape: (B, T) where T is the number of timesteps
|
||||
positions: (
|
||||
torch.Tensor
|
||||
) # Shape: (B, 3, T) for video, where 3 is the number of dimensions and T is the number of tokens
|
||||
context: torch.Tensor
|
||||
enabled: bool = True
|
||||
context_mask: torch.Tensor | None = None
|
||||
|
||||
|
||||
def to_denoised(
|
||||
sample: torch.Tensor,
|
||||
velocity: torch.Tensor,
|
||||
sigma: float | torch.Tensor,
|
||||
calc_dtype: torch.dtype = torch.float32,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Convert the sample and its denoising velocity to denoised sample.
|
||||
Returns:
|
||||
Denoised sample
|
||||
"""
|
||||
if isinstance(sigma, torch.Tensor):
|
||||
sigma = sigma.to(calc_dtype)
|
||||
return (sample.to(calc_dtype) - velocity.to(calc_dtype) * sigma).to(sample.dtype)
|
||||
|
||||
|
||||
|
||||
class Patchifier(Protocol):
|
||||
"""
|
||||
Protocol for patchifiers that convert latent tensors into patches and assemble them back.
|
||||
"""
|
||||
|
||||
def patchify(
|
||||
self,
|
||||
latents: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
...
|
||||
"""
|
||||
Convert latent tensors into flattened patch tokens.
|
||||
Args:
|
||||
latents: Latent tensor to patchify.
|
||||
Returns:
|
||||
Flattened patch tokens tensor.
|
||||
"""
|
||||
|
||||
def unpatchify(
|
||||
self,
|
||||
latents: torch.Tensor,
|
||||
output_shape: AudioLatentShape | VideoLatentShape,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Converts latent tensors between spatio-temporal formats and flattened sequence representations.
|
||||
Args:
|
||||
latents: Patch tokens that must be rearranged back into the latent grid constructed by `patchify`.
|
||||
output_shape: Shape of the output tensor. Note that output_shape is either AudioLatentShape or
|
||||
VideoLatentShape.
|
||||
Returns:
|
||||
Dense latent tensor restored from the flattened representation.
|
||||
"""
|
||||
|
||||
@property
|
||||
def patch_size(self) -> Tuple[int, int, int]:
|
||||
...
|
||||
"""
|
||||
Returns the patch size as a tuple of (temporal, height, width) dimensions
|
||||
"""
|
||||
|
||||
def get_patch_grid_bounds(
|
||||
self,
|
||||
output_shape: AudioLatentShape | VideoLatentShape,
|
||||
device: torch.device | None = None,
|
||||
) -> torch.Tensor:
|
||||
...
|
||||
"""
|
||||
Compute metadata describing where each latent patch resides within the
|
||||
grid specified by `output_shape`.
|
||||
Args:
|
||||
output_shape: Target grid layout for the patches.
|
||||
device: Target device for the returned tensor.
|
||||
Returns:
|
||||
Tensor containing patch coordinate metadata such as spatial or temporal intervals.
|
||||
"""
|
||||
|
||||
|
||||
def get_pixel_coords(
|
||||
latent_coords: torch.Tensor,
|
||||
scale_factors: SpatioTemporalScaleFactors,
|
||||
causal_fix: bool = False,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Map latent-space `[start, end)` coordinates to their pixel-space equivalents by scaling
|
||||
each axis (frame/time, height, width) with the corresponding VAE downsampling factors.
|
||||
Optionally compensate for causal encoding that keeps the first frame at unit temporal scale.
|
||||
Args:
|
||||
latent_coords: Tensor of latent bounds shaped `(batch, 3, num_patches, 2)`.
|
||||
scale_factors: SpatioTemporalScaleFactors tuple `(temporal, height, width)` with integer scale factors applied
|
||||
per axis.
|
||||
causal_fix: When True, rewrites the temporal axis of the first frame so causal VAEs
|
||||
that treat frame zero differently still yield non-negative timestamps.
|
||||
"""
|
||||
# Broadcast the VAE scale factors so they align with the `(batch, axis, patch, bound)` layout.
|
||||
broadcast_shape = [1] * latent_coords.ndim
|
||||
broadcast_shape[1] = -1 # axis dimension corresponds to (frame/time, height, width)
|
||||
scale_tensor = torch.tensor(scale_factors, device=latent_coords.device).view(*broadcast_shape)
|
||||
|
||||
# Apply per-axis scaling to convert latent bounds into pixel-space coordinates.
|
||||
pixel_coords = latent_coords * scale_tensor
|
||||
|
||||
if causal_fix:
|
||||
# VAE temporal stride for the very first frame is 1 instead of `scale_factors[0]`.
|
||||
# Shift and clamp to keep the first-frame timestamps causal and non-negative.
|
||||
pixel_coords[:, 0, ...] = (pixel_coords[:, 0, ...] + 1 - scale_factors[0]).clamp(min=0)
|
||||
|
||||
return pixel_coords
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,366 +0,0 @@
|
||||
import torch
|
||||
from transformers import Gemma3ForConditionalGeneration, Gemma3Config, AutoTokenizer
|
||||
from .ltx2_dit import (LTXRopeType, generate_freq_grid_np, generate_freq_grid_pytorch, precompute_freqs_cis, Attention,
|
||||
FeedForward)
|
||||
from .ltx2_common import rms_norm
|
||||
|
||||
|
||||
class LTX2TextEncoder(Gemma3ForConditionalGeneration):
|
||||
def __init__(self):
|
||||
config = Gemma3Config(
|
||||
**{
|
||||
"architectures": ["Gemma3ForConditionalGeneration"],
|
||||
"boi_token_index": 255999,
|
||||
"dtype": "bfloat16",
|
||||
"eoi_token_index": 256000,
|
||||
"eos_token_id": [1, 106],
|
||||
"image_token_index": 262144,
|
||||
"initializer_range": 0.02,
|
||||
"mm_tokens_per_image": 256,
|
||||
"model_type": "gemma3",
|
||||
"text_config": {
|
||||
"_sliding_window_pattern": 6,
|
||||
"attention_bias": False,
|
||||
"attention_dropout": 0.0,
|
||||
"attn_logit_softcapping": None,
|
||||
"cache_implementation": "hybrid",
|
||||
"dtype": "bfloat16",
|
||||
"final_logit_softcapping": None,
|
||||
"head_dim": 256,
|
||||
"hidden_activation": "gelu_pytorch_tanh",
|
||||
"hidden_size": 3840,
|
||||
"initializer_range": 0.02,
|
||||
"intermediate_size": 15360,
|
||||
"layer_types": [
|
||||
"sliding_attention", "sliding_attention", "sliding_attention", "sliding_attention",
|
||||
"sliding_attention", "full_attention", "sliding_attention", "sliding_attention",
|
||||
"sliding_attention", "sliding_attention", "sliding_attention", "full_attention",
|
||||
"sliding_attention", "sliding_attention", "sliding_attention", "sliding_attention",
|
||||
"sliding_attention", "full_attention", "sliding_attention", "sliding_attention",
|
||||
"sliding_attention", "sliding_attention", "sliding_attention", "full_attention",
|
||||
"sliding_attention", "sliding_attention", "sliding_attention", "sliding_attention",
|
||||
"sliding_attention", "full_attention", "sliding_attention", "sliding_attention",
|
||||
"sliding_attention", "sliding_attention", "sliding_attention", "full_attention",
|
||||
"sliding_attention", "sliding_attention", "sliding_attention", "sliding_attention",
|
||||
"sliding_attention", "full_attention", "sliding_attention", "sliding_attention",
|
||||
"sliding_attention", "sliding_attention", "sliding_attention", "full_attention"
|
||||
],
|
||||
"max_position_embeddings": 131072,
|
||||
"model_type": "gemma3_text",
|
||||
"num_attention_heads": 16,
|
||||
"num_hidden_layers": 48,
|
||||
"num_key_value_heads": 8,
|
||||
"query_pre_attn_scalar": 256,
|
||||
"rms_norm_eps": 1e-06,
|
||||
"rope_local_base_freq": 10000,
|
||||
"rope_scaling": {
|
||||
"factor": 8.0,
|
||||
"rope_type": "linear"
|
||||
},
|
||||
"rope_theta": 1000000,
|
||||
"sliding_window": 1024,
|
||||
"sliding_window_pattern": 6,
|
||||
"use_bidirectional_attention": False,
|
||||
"use_cache": True,
|
||||
"vocab_size": 262208
|
||||
},
|
||||
"transformers_version": "4.57.3",
|
||||
"vision_config": {
|
||||
"attention_dropout": 0.0,
|
||||
"dtype": "bfloat16",
|
||||
"hidden_act": "gelu_pytorch_tanh",
|
||||
"hidden_size": 1152,
|
||||
"image_size": 896,
|
||||
"intermediate_size": 4304,
|
||||
"layer_norm_eps": 1e-06,
|
||||
"model_type": "siglip_vision_model",
|
||||
"num_attention_heads": 16,
|
||||
"num_channels": 3,
|
||||
"num_hidden_layers": 27,
|
||||
"patch_size": 14,
|
||||
"vision_use_head": False
|
||||
}
|
||||
})
|
||||
super().__init__(config)
|
||||
|
||||
|
||||
class LTXVGemmaTokenizer:
|
||||
"""
|
||||
Tokenizer wrapper for Gemma models compatible with LTXV processes.
|
||||
This class wraps HuggingFace's `AutoTokenizer` for use with Gemma text encoders,
|
||||
ensuring correct settings and output formatting for downstream consumption.
|
||||
"""
|
||||
|
||||
def __init__(self, tokenizer_path: str, max_length: int = 1024):
|
||||
"""
|
||||
Initialize the tokenizer.
|
||||
Args:
|
||||
tokenizer_path (str): Path to the pretrained tokenizer files or model directory.
|
||||
max_length (int, optional): Max sequence length for encoding. Defaults to 256.
|
||||
"""
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(
|
||||
tokenizer_path, local_files_only=True, model_max_length=max_length
|
||||
)
|
||||
# Gemma expects left padding for chat-style prompts; for plain text it doesn't matter much.
|
||||
self.tokenizer.padding_side = "left"
|
||||
if self.tokenizer.pad_token is None:
|
||||
self.tokenizer.pad_token = self.tokenizer.eos_token
|
||||
|
||||
self.max_length = max_length
|
||||
|
||||
def tokenize_with_weights(self, text: str, return_word_ids: bool = False) -> dict[str, list[tuple[int, int]]]:
|
||||
"""
|
||||
Tokenize the given text and return token IDs and attention weights.
|
||||
Args:
|
||||
text (str): The input string to tokenize.
|
||||
return_word_ids (bool, optional): If True, includes the token's position (index) in the output tuples.
|
||||
If False (default), omits the indices.
|
||||
Returns:
|
||||
dict[str, list[tuple[int, int]]] OR dict[str, list[tuple[int, int, int]]]:
|
||||
A dictionary with a "gemma" key mapping to:
|
||||
- a list of (token_id, attention_mask) tuples if return_word_ids is False;
|
||||
- a list of (token_id, attention_mask, index) tuples if return_word_ids is True.
|
||||
Example:
|
||||
>>> tokenizer = LTXVGemmaTokenizer("path/to/tokenizer", max_length=8)
|
||||
>>> tokenizer.tokenize_with_weights("hello world")
|
||||
{'gemma': [(1234, 1), (5678, 1), (2, 0), ...]}
|
||||
"""
|
||||
text = text.strip()
|
||||
encoded = self.tokenizer(
|
||||
text,
|
||||
padding="max_length",
|
||||
max_length=self.max_length,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
input_ids = encoded.input_ids
|
||||
attention_mask = encoded.attention_mask
|
||||
tuples = [
|
||||
(token_id, attn, i) for i, (token_id, attn) in enumerate(zip(input_ids[0], attention_mask[0], strict=True))
|
||||
]
|
||||
out = {"gemma": tuples}
|
||||
|
||||
if not return_word_ids:
|
||||
# Return only (token_id, attention_mask) pairs, omitting token position
|
||||
out = {k: [(t, w) for t, w, _ in v] for k, v in out.items()}
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class GemmaFeaturesExtractorProjLinear(torch.nn.Module):
|
||||
"""
|
||||
Feature extractor module for Gemma models.
|
||||
This module applies a single linear projection to the input tensor.
|
||||
It expects a flattened feature tensor of shape (batch_size, 3840*49).
|
||||
The linear layer maps this to a (batch_size, 3840) embedding.
|
||||
Attributes:
|
||||
aggregate_embed (torch.nn.Linear): Linear projection layer.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""
|
||||
Initialize the GemmaFeaturesExtractorProjLinear module.
|
||||
The input dimension is expected to be 3840 * 49, and the output is 3840.
|
||||
"""
|
||||
super().__init__()
|
||||
self.aggregate_embed = torch.nn.Linear(3840 * 49, 3840, bias=False)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Forward pass for the feature extractor.
|
||||
Args:
|
||||
x (torch.Tensor): Input tensor of shape (batch_size, 3840 * 49).
|
||||
Returns:
|
||||
torch.Tensor: Output tensor of shape (batch_size, 3840).
|
||||
"""
|
||||
return self.aggregate_embed(x)
|
||||
|
||||
|
||||
class _BasicTransformerBlock1D(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
heads: int,
|
||||
dim_head: int,
|
||||
rope_type: LTXRopeType = LTXRopeType.INTERLEAVED,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.attn1 = Attention(
|
||||
query_dim=dim,
|
||||
heads=heads,
|
||||
dim_head=dim_head,
|
||||
rope_type=rope_type,
|
||||
)
|
||||
|
||||
self.ff = FeedForward(
|
||||
dim,
|
||||
dim_out=dim,
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: torch.Tensor | None = None,
|
||||
pe: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
# Notice that normalization is always applied before the real computation in the following blocks.
|
||||
|
||||
# 1. Normalization Before Self-Attention
|
||||
norm_hidden_states = rms_norm(hidden_states)
|
||||
|
||||
norm_hidden_states = norm_hidden_states.squeeze(1)
|
||||
|
||||
# 2. Self-Attention
|
||||
attn_output = self.attn1(norm_hidden_states, mask=attention_mask, pe=pe)
|
||||
|
||||
hidden_states = attn_output + hidden_states
|
||||
if hidden_states.ndim == 4:
|
||||
hidden_states = hidden_states.squeeze(1)
|
||||
|
||||
# 3. Normalization before Feed-Forward
|
||||
norm_hidden_states = rms_norm(hidden_states)
|
||||
|
||||
# 4. Feed-forward
|
||||
ff_output = self.ff(norm_hidden_states)
|
||||
|
||||
hidden_states = ff_output + hidden_states
|
||||
if hidden_states.ndim == 4:
|
||||
hidden_states = hidden_states.squeeze(1)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class Embeddings1DConnector(torch.nn.Module):
|
||||
"""
|
||||
Embeddings1DConnector applies a 1D transformer-based processing to sequential embeddings (e.g., for video, audio, or
|
||||
other modalities). It supports rotary positional encoding (rope), optional causal temporal positioning, and can
|
||||
substitute padded positions with learnable registers. The module is highly configurable for head size, number of
|
||||
layers, and register usage.
|
||||
Args:
|
||||
attention_head_dim (int): Dimension of each attention head (default=128).
|
||||
num_attention_heads (int): Number of attention heads (default=30).
|
||||
num_layers (int): Number of transformer layers (default=2).
|
||||
positional_embedding_theta (float): Scaling factor for position embedding (default=10000.0).
|
||||
positional_embedding_max_pos (list[int] | None): Max positions for positional embeddings (default=[1]).
|
||||
causal_temporal_positioning (bool): If True, uses causal attention (default=False).
|
||||
num_learnable_registers (int | None): Number of learnable registers to replace padded tokens. If None, disables
|
||||
register replacement. (default=128)
|
||||
rope_type (LTXRopeType): The RoPE variant to use (default=DEFAULT_ROPE_TYPE).
|
||||
double_precision_rope (bool): Use double precision rope calculation (default=False).
|
||||
"""
|
||||
|
||||
_supports_gradient_checkpointing = True
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
attention_head_dim: int = 128,
|
||||
num_attention_heads: int = 30,
|
||||
num_layers: int = 2,
|
||||
positional_embedding_theta: float = 10000.0,
|
||||
positional_embedding_max_pos: list[int] | None = [4096],
|
||||
causal_temporal_positioning: bool = False,
|
||||
num_learnable_registers: int | None = 128,
|
||||
rope_type: LTXRopeType = LTXRopeType.SPLIT,
|
||||
double_precision_rope: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.inner_dim = num_attention_heads * attention_head_dim
|
||||
self.causal_temporal_positioning = causal_temporal_positioning
|
||||
self.positional_embedding_theta = positional_embedding_theta
|
||||
self.positional_embedding_max_pos = (
|
||||
positional_embedding_max_pos if positional_embedding_max_pos is not None else [1]
|
||||
)
|
||||
self.rope_type = rope_type
|
||||
self.double_precision_rope = double_precision_rope
|
||||
self.transformer_1d_blocks = torch.nn.ModuleList(
|
||||
[
|
||||
_BasicTransformerBlock1D(
|
||||
dim=self.inner_dim,
|
||||
heads=num_attention_heads,
|
||||
dim_head=attention_head_dim,
|
||||
rope_type=rope_type,
|
||||
)
|
||||
for _ in range(num_layers)
|
||||
]
|
||||
)
|
||||
|
||||
self.num_learnable_registers = num_learnable_registers
|
||||
if self.num_learnable_registers:
|
||||
self.learnable_registers = torch.nn.Parameter(
|
||||
torch.rand(self.num_learnable_registers, self.inner_dim, dtype=torch.bfloat16) * 2.0 - 1.0
|
||||
)
|
||||
|
||||
def _replace_padded_with_learnable_registers(
|
||||
self, hidden_states: torch.Tensor, attention_mask: torch.Tensor
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
assert hidden_states.shape[1] % self.num_learnable_registers == 0, (
|
||||
f"Hidden states sequence length {hidden_states.shape[1]} must be divisible by num_learnable_registers "
|
||||
f"{self.num_learnable_registers}."
|
||||
)
|
||||
|
||||
num_registers_duplications = hidden_states.shape[1] // self.num_learnable_registers
|
||||
learnable_registers = torch.tile(self.learnable_registers, (num_registers_duplications, 1))
|
||||
attention_mask_binary = (attention_mask.squeeze(1).squeeze(1).unsqueeze(-1) >= -9000.0).int()
|
||||
|
||||
non_zero_hidden_states = hidden_states[:, attention_mask_binary.squeeze().bool(), :]
|
||||
non_zero_nums = non_zero_hidden_states.shape[1]
|
||||
pad_length = hidden_states.shape[1] - non_zero_nums
|
||||
adjusted_hidden_states = torch.nn.functional.pad(non_zero_hidden_states, pad=(0, 0, 0, pad_length), value=0)
|
||||
flipped_mask = torch.flip(attention_mask_binary, dims=[1])
|
||||
hidden_states = flipped_mask * adjusted_hidden_states + (1 - flipped_mask) * learnable_registers
|
||||
|
||||
attention_mask = torch.full_like(
|
||||
attention_mask,
|
||||
0.0,
|
||||
dtype=attention_mask.dtype,
|
||||
device=attention_mask.device,
|
||||
)
|
||||
|
||||
return hidden_states, attention_mask
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: torch.Tensor | None = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Forward pass of Embeddings1DConnector.
|
||||
Args:
|
||||
hidden_states (torch.Tensor): Input tensor of embeddings (shape [batch, seq_len, feature_dim]).
|
||||
attention_mask (torch.Tensor|None): Optional mask for valid tokens (shape compatible with hidden_states).
|
||||
Returns:
|
||||
tuple[torch.Tensor, torch.Tensor]: Processed features and the corresponding (possibly modified) mask.
|
||||
"""
|
||||
if self.num_learnable_registers:
|
||||
hidden_states, attention_mask = self._replace_padded_with_learnable_registers(hidden_states, attention_mask)
|
||||
|
||||
indices_grid = torch.arange(hidden_states.shape[1], dtype=torch.float32, device=hidden_states.device)
|
||||
indices_grid = indices_grid[None, None, :]
|
||||
freq_grid_generator = generate_freq_grid_np if self.double_precision_rope else generate_freq_grid_pytorch
|
||||
freqs_cis = precompute_freqs_cis(
|
||||
indices_grid=indices_grid,
|
||||
dim=self.inner_dim,
|
||||
out_dtype=hidden_states.dtype,
|
||||
theta=self.positional_embedding_theta,
|
||||
max_pos=self.positional_embedding_max_pos,
|
||||
num_attention_heads=self.num_attention_heads,
|
||||
rope_type=self.rope_type,
|
||||
freq_grid_generator=freq_grid_generator,
|
||||
)
|
||||
|
||||
for block in self.transformer_1d_blocks:
|
||||
hidden_states = block(hidden_states, attention_mask=attention_mask, pe=freqs_cis)
|
||||
|
||||
hidden_states = rms_norm(hidden_states)
|
||||
|
||||
return hidden_states, attention_mask
|
||||
|
||||
|
||||
class LTX2TextEncoderPostModules(torch.nn.Module):
|
||||
def __init__(self,):
|
||||
super().__init__()
|
||||
self.feature_extractor_linear = GemmaFeaturesExtractorProjLinear()
|
||||
self.embeddings_connector = Embeddings1DConnector()
|
||||
self.audio_embeddings_connector = Embeddings1DConnector()
|
||||
@@ -1,313 +0,0 @@
|
||||
import math
|
||||
from typing import Optional, Tuple
|
||||
import torch
|
||||
from einops import rearrange
|
||||
import torch.nn.functional as F
|
||||
from .ltx2_video_vae import LTX2VideoEncoder
|
||||
|
||||
class PixelShuffleND(torch.nn.Module):
|
||||
"""
|
||||
N-dimensional pixel shuffle operation for upsampling tensors.
|
||||
Args:
|
||||
dims (int): Number of dimensions to apply pixel shuffle to.
|
||||
- 1: Temporal (e.g., frames)
|
||||
- 2: Spatial (e.g., height and width)
|
||||
- 3: Spatiotemporal (e.g., depth, height, width)
|
||||
upscale_factors (tuple[int, int, int], optional): Upscaling factors for each dimension.
|
||||
For dims=1, only the first value is used.
|
||||
For dims=2, the first two values are used.
|
||||
For dims=3, all three values are used.
|
||||
The input tensor is rearranged so that the channel dimension is split into
|
||||
smaller channels and upscaling factors, and the upscaling factors are moved
|
||||
into the corresponding spatial/temporal dimensions.
|
||||
Note:
|
||||
This operation is equivalent to the patchifier operation in for the models. Consider
|
||||
using this class instead.
|
||||
"""
|
||||
|
||||
def __init__(self, dims: int, upscale_factors: tuple[int, int, int] = (2, 2, 2)):
|
||||
super().__init__()
|
||||
assert dims in [1, 2, 3], "dims must be 1, 2, or 3"
|
||||
self.dims = dims
|
||||
self.upscale_factors = upscale_factors
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
if self.dims == 3:
|
||||
return rearrange(
|
||||
x,
|
||||
"b (c p1 p2 p3) d h w -> b c (d p1) (h p2) (w p3)",
|
||||
p1=self.upscale_factors[0],
|
||||
p2=self.upscale_factors[1],
|
||||
p3=self.upscale_factors[2],
|
||||
)
|
||||
elif self.dims == 2:
|
||||
return rearrange(
|
||||
x,
|
||||
"b (c p1 p2) h w -> b c (h p1) (w p2)",
|
||||
p1=self.upscale_factors[0],
|
||||
p2=self.upscale_factors[1],
|
||||
)
|
||||
elif self.dims == 1:
|
||||
return rearrange(
|
||||
x,
|
||||
"b (c p1) f h w -> b c (f p1) h w",
|
||||
p1=self.upscale_factors[0],
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported dims: {self.dims}")
|
||||
|
||||
|
||||
class ResBlock(torch.nn.Module):
|
||||
"""
|
||||
Residual block with two convolutional layers, group normalization, and SiLU activation.
|
||||
Args:
|
||||
channels (int): Number of input and output channels.
|
||||
mid_channels (Optional[int]): Number of channels in the intermediate convolution layer. Defaults to `channels`
|
||||
if not specified.
|
||||
dims (int): Dimensionality of the convolution (2 for Conv2d, 3 for Conv3d). Defaults to 3.
|
||||
"""
|
||||
|
||||
def __init__(self, channels: int, mid_channels: Optional[int] = None, dims: int = 3):
|
||||
super().__init__()
|
||||
if mid_channels is None:
|
||||
mid_channels = channels
|
||||
|
||||
conv = torch.nn.Conv2d if dims == 2 else torch.nn.Conv3d
|
||||
|
||||
self.conv1 = conv(channels, mid_channels, kernel_size=3, padding=1)
|
||||
self.norm1 = torch.nn.GroupNorm(32, mid_channels)
|
||||
self.conv2 = conv(mid_channels, channels, kernel_size=3, padding=1)
|
||||
self.norm2 = torch.nn.GroupNorm(32, channels)
|
||||
self.activation = torch.nn.SiLU()
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
residual = x
|
||||
x = self.conv1(x)
|
||||
x = self.norm1(x)
|
||||
x = self.activation(x)
|
||||
x = self.conv2(x)
|
||||
x = self.norm2(x)
|
||||
x = self.activation(x + residual)
|
||||
return x
|
||||
|
||||
|
||||
class BlurDownsample(torch.nn.Module):
|
||||
"""
|
||||
Anti-aliased spatial downsampling by integer stride using a fixed separable binomial kernel.
|
||||
Applies only on H,W. Works for dims=2 or dims=3 (per-frame).
|
||||
"""
|
||||
|
||||
def __init__(self, dims: int, stride: int, kernel_size: int = 5) -> None:
|
||||
super().__init__()
|
||||
assert dims in (2, 3)
|
||||
assert isinstance(stride, int)
|
||||
assert stride >= 1
|
||||
assert kernel_size >= 3
|
||||
assert kernel_size % 2 == 1
|
||||
self.dims = dims
|
||||
self.stride = stride
|
||||
self.kernel_size = kernel_size
|
||||
|
||||
# 5x5 separable binomial kernel using binomial coefficients [1, 4, 6, 4, 1] from
|
||||
# the 4th row of Pascal's triangle. This kernel is used for anti-aliasing and
|
||||
# provides a smooth approximation of a Gaussian filter (often called a "binomial filter").
|
||||
# The 2D kernel is constructed as the outer product and normalized.
|
||||
k = torch.tensor([math.comb(kernel_size - 1, k) for k in range(kernel_size)])
|
||||
k2d = k[:, None] @ k[None, :]
|
||||
k2d = (k2d / k2d.sum()).float() # shape (kernel_size, kernel_size)
|
||||
self.register_buffer("kernel", k2d[None, None, :, :]) # (1, 1, kernel_size, kernel_size)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
if self.stride == 1:
|
||||
return x
|
||||
|
||||
if self.dims == 2:
|
||||
return self._apply_2d(x)
|
||||
else:
|
||||
# dims == 3: apply per-frame on H,W
|
||||
b, _, f, _, _ = x.shape
|
||||
x = rearrange(x, "b c f h w -> (b f) c h w")
|
||||
x = self._apply_2d(x)
|
||||
h2, w2 = x.shape[-2:]
|
||||
x = rearrange(x, "(b f) c h w -> b c f h w", b=b, f=f, h=h2, w=w2)
|
||||
return x
|
||||
|
||||
def _apply_2d(self, x2d: torch.Tensor) -> torch.Tensor:
|
||||
c = x2d.shape[1]
|
||||
weight = self.kernel.expand(c, 1, self.kernel_size, self.kernel_size) # depthwise
|
||||
x2d = F.conv2d(x2d, weight=weight, bias=None, stride=self.stride, padding=self.kernel_size // 2, groups=c)
|
||||
return x2d
|
||||
|
||||
|
||||
def _rational_for_scale(scale: float) -> Tuple[int, int]:
|
||||
mapping = {0.75: (3, 4), 1.5: (3, 2), 2.0: (2, 1), 4.0: (4, 1)}
|
||||
if float(scale) not in mapping:
|
||||
raise ValueError(f"Unsupported scale {scale}. Choose from {list(mapping.keys())}")
|
||||
return mapping[float(scale)]
|
||||
|
||||
|
||||
class SpatialRationalResampler(torch.nn.Module):
|
||||
"""
|
||||
Fully-learned rational spatial scaling: up by 'num' via PixelShuffle, then anti-aliased
|
||||
downsample by 'den' using fixed blur + stride. Operates on H,W only.
|
||||
For dims==3, work per-frame for spatial scaling (temporal axis untouched).
|
||||
Args:
|
||||
mid_channels (`int`): Number of intermediate channels for the convolution layer
|
||||
scale (`float`): Spatial scaling factor. Supported values are:
|
||||
- 0.75: Downsample by 3/4 (reduce spatial size)
|
||||
- 1.5: Upsample by 3/2 (increase spatial size)
|
||||
- 2.0: Upsample by 2x (double spatial size)
|
||||
- 4.0: Upsample by 4x (quadruple spatial size)
|
||||
Any other value will raise a ValueError.
|
||||
"""
|
||||
|
||||
def __init__(self, mid_channels: int, scale: float):
|
||||
super().__init__()
|
||||
self.scale = float(scale)
|
||||
self.num, self.den = _rational_for_scale(self.scale)
|
||||
self.conv = torch.nn.Conv2d(mid_channels, (self.num**2) * mid_channels, kernel_size=3, padding=1)
|
||||
self.pixel_shuffle = PixelShuffleND(2, upscale_factors=(self.num, self.num))
|
||||
self.blur_down = BlurDownsample(dims=2, stride=self.den)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
b, _, f, _, _ = x.shape
|
||||
x = rearrange(x, "b c f h w -> (b f) c h w")
|
||||
x = self.conv(x)
|
||||
x = self.pixel_shuffle(x)
|
||||
x = self.blur_down(x)
|
||||
x = rearrange(x, "(b f) c h w -> b c f h w", b=b, f=f)
|
||||
return x
|
||||
|
||||
|
||||
class LTX2LatentUpsampler(torch.nn.Module):
|
||||
"""
|
||||
Model to upsample VAE latents spatially and/or temporally.
|
||||
Args:
|
||||
in_channels (`int`): Number of channels in the input latent
|
||||
mid_channels (`int`): Number of channels in the middle layers
|
||||
num_blocks_per_stage (`int`): Number of ResBlocks to use in each stage (pre/post upsampling)
|
||||
dims (`int`): Number of dimensions for convolutions (2 or 3)
|
||||
spatial_upsample (`bool`): Whether to spatially upsample the latent
|
||||
temporal_upsample (`bool`): Whether to temporally upsample the latent
|
||||
spatial_scale (`float`): Scale factor for spatial upsampling
|
||||
rational_resampler (`bool`): Whether to use a rational resampler for spatial upsampling
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int = 128,
|
||||
mid_channels: int = 1024,
|
||||
num_blocks_per_stage: int = 4,
|
||||
dims: int = 3,
|
||||
spatial_upsample: bool = True,
|
||||
temporal_upsample: bool = False,
|
||||
spatial_scale: float = 2.0,
|
||||
rational_resampler: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.in_channels = in_channels
|
||||
self.mid_channels = mid_channels
|
||||
self.num_blocks_per_stage = num_blocks_per_stage
|
||||
self.dims = dims
|
||||
self.spatial_upsample = spatial_upsample
|
||||
self.temporal_upsample = temporal_upsample
|
||||
self.spatial_scale = float(spatial_scale)
|
||||
self.rational_resampler = rational_resampler
|
||||
|
||||
conv = torch.nn.Conv2d if dims == 2 else torch.nn.Conv3d
|
||||
|
||||
self.initial_conv = conv(in_channels, mid_channels, kernel_size=3, padding=1)
|
||||
self.initial_norm = torch.nn.GroupNorm(32, mid_channels)
|
||||
self.initial_activation = torch.nn.SiLU()
|
||||
|
||||
self.res_blocks = torch.nn.ModuleList([ResBlock(mid_channels, dims=dims) for _ in range(num_blocks_per_stage)])
|
||||
|
||||
if spatial_upsample and temporal_upsample:
|
||||
self.upsampler = torch.nn.Sequential(
|
||||
torch.nn.Conv3d(mid_channels, 8 * mid_channels, kernel_size=3, padding=1),
|
||||
PixelShuffleND(3),
|
||||
)
|
||||
elif spatial_upsample:
|
||||
if rational_resampler:
|
||||
self.upsampler = SpatialRationalResampler(mid_channels=mid_channels, scale=self.spatial_scale)
|
||||
else:
|
||||
self.upsampler = torch.nn.Sequential(
|
||||
torch.nn.Conv2d(mid_channels, 4 * mid_channels, kernel_size=3, padding=1),
|
||||
PixelShuffleND(2),
|
||||
)
|
||||
elif temporal_upsample:
|
||||
self.upsampler = torch.nn.Sequential(
|
||||
torch.nn.Conv3d(mid_channels, 2 * mid_channels, kernel_size=3, padding=1),
|
||||
PixelShuffleND(1),
|
||||
)
|
||||
else:
|
||||
raise ValueError("Either spatial_upsample or temporal_upsample must be True")
|
||||
|
||||
self.post_upsample_res_blocks = torch.nn.ModuleList(
|
||||
[ResBlock(mid_channels, dims=dims) for _ in range(num_blocks_per_stage)]
|
||||
)
|
||||
|
||||
self.final_conv = conv(mid_channels, in_channels, kernel_size=3, padding=1)
|
||||
|
||||
def forward(self, latent: torch.Tensor) -> torch.Tensor:
|
||||
b, _, f, _, _ = latent.shape
|
||||
|
||||
if self.dims == 2:
|
||||
x = rearrange(latent, "b c f h w -> (b f) c h w")
|
||||
x = self.initial_conv(x)
|
||||
x = self.initial_norm(x)
|
||||
x = self.initial_activation(x)
|
||||
|
||||
for block in self.res_blocks:
|
||||
x = block(x)
|
||||
|
||||
x = self.upsampler(x)
|
||||
|
||||
for block in self.post_upsample_res_blocks:
|
||||
x = block(x)
|
||||
|
||||
x = self.final_conv(x)
|
||||
x = rearrange(x, "(b f) c h w -> b c f h w", b=b, f=f)
|
||||
else:
|
||||
x = self.initial_conv(latent)
|
||||
x = self.initial_norm(x)
|
||||
x = self.initial_activation(x)
|
||||
|
||||
for block in self.res_blocks:
|
||||
x = block(x)
|
||||
|
||||
if self.temporal_upsample:
|
||||
x = self.upsampler(x)
|
||||
# remove the first frame after upsampling.
|
||||
# This is done because the first frame encodes one pixel frame.
|
||||
x = x[:, :, 1:, :, :]
|
||||
elif isinstance(self.upsampler, SpatialRationalResampler):
|
||||
x = self.upsampler(x)
|
||||
else:
|
||||
x = rearrange(x, "b c f h w -> (b f) c h w")
|
||||
x = self.upsampler(x)
|
||||
x = rearrange(x, "(b f) c h w -> b c f h w", b=b, f=f)
|
||||
|
||||
for block in self.post_upsample_res_blocks:
|
||||
x = block(x)
|
||||
|
||||
x = self.final_conv(x)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
def upsample_video(latent: torch.Tensor, video_encoder: LTX2VideoEncoder, upsampler: "LTX2LatentUpsampler") -> torch.Tensor:
|
||||
"""
|
||||
Apply upsampling to the latent representation using the provided upsampler,
|
||||
with normalization and un-normalization based on the video encoder's per-channel statistics.
|
||||
Args:
|
||||
latent: Input latent tensor of shape [B, C, F, H, W].
|
||||
video_encoder: VideoEncoder with per_channel_statistics for normalization.
|
||||
upsampler: LTX2LatentUpsampler module to perform upsampling.
|
||||
Returns:
|
||||
torch.Tensor: Upsampled and re-normalized latent tensor.
|
||||
"""
|
||||
latent = video_encoder.per_channel_statistics.un_normalize(latent)
|
||||
latent = upsampler(latent)
|
||||
latent = video_encoder.per_channel_statistics.normalize(latent)
|
||||
return latent
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,113 +0,0 @@
|
||||
from ..core.loader import load_model, hash_model_file
|
||||
from ..core.vram import AutoWrappedModule
|
||||
from ..configs import MODEL_CONFIGS, VRAM_MANAGEMENT_MODULE_MAPS, VERSION_CHECKER_MAPS
|
||||
import importlib, json, torch
|
||||
|
||||
|
||||
class ModelPool:
|
||||
def __init__(self):
|
||||
self.model = []
|
||||
self.model_name = []
|
||||
self.model_path = []
|
||||
|
||||
def import_model_class(self, model_class):
|
||||
split = model_class.rfind(".")
|
||||
model_resource, model_class = model_class[:split], model_class[split+1:]
|
||||
model_class = importlib.import_module(model_resource).__getattribute__(model_class)
|
||||
return model_class
|
||||
|
||||
def need_to_enable_vram_management(self, vram_config):
|
||||
return vram_config["offload_dtype"] is not None and vram_config["offload_device"] is not None
|
||||
|
||||
def fetch_module_map(self, model_class, vram_config):
|
||||
if self.need_to_enable_vram_management(vram_config):
|
||||
if model_class in VRAM_MANAGEMENT_MODULE_MAPS:
|
||||
vram_module_map = VRAM_MANAGEMENT_MODULE_MAPS[model_class] if model_class not in VERSION_CHECKER_MAPS else VERSION_CHECKER_MAPS[model_class]()
|
||||
module_map = {self.import_model_class(source): self.import_model_class(target) for source, target in vram_module_map.items()}
|
||||
else:
|
||||
module_map = {self.import_model_class(model_class): AutoWrappedModule}
|
||||
else:
|
||||
module_map = None
|
||||
return module_map
|
||||
|
||||
def load_model_file(self, config, path, vram_config, vram_limit=None, state_dict=None):
|
||||
model_class = self.import_model_class(config["model_class"])
|
||||
model_config = config.get("extra_kwargs", {})
|
||||
if "state_dict_converter" in config:
|
||||
state_dict_converter = self.import_model_class(config["state_dict_converter"])
|
||||
else:
|
||||
state_dict_converter = None
|
||||
module_map = self.fetch_module_map(config["model_class"], vram_config)
|
||||
model = load_model(
|
||||
model_class, path, model_config,
|
||||
vram_config["computation_dtype"], vram_config["computation_device"],
|
||||
state_dict_converter,
|
||||
use_disk_map=True,
|
||||
vram_config=vram_config, module_map=module_map, vram_limit=vram_limit,
|
||||
state_dict=state_dict,
|
||||
)
|
||||
return model
|
||||
|
||||
def default_vram_config(self):
|
||||
vram_config = {
|
||||
"offload_dtype": None,
|
||||
"offload_device": None,
|
||||
"onload_dtype": torch.bfloat16,
|
||||
"onload_device": "cpu",
|
||||
"preparing_dtype": torch.bfloat16,
|
||||
"preparing_device": "cpu",
|
||||
"computation_dtype": torch.bfloat16,
|
||||
"computation_device": "cpu",
|
||||
}
|
||||
return vram_config
|
||||
|
||||
def auto_load_model(self, path, vram_config=None, vram_limit=None, clear_parameters=False, state_dict=None):
|
||||
print(f"Loading models from: {json.dumps(path, indent=4)}")
|
||||
if vram_config is None:
|
||||
vram_config = self.default_vram_config()
|
||||
model_hash = hash_model_file(path)
|
||||
loaded = False
|
||||
for config in MODEL_CONFIGS:
|
||||
if config["model_hash"] == model_hash:
|
||||
model = self.load_model_file(config, path, vram_config, vram_limit=vram_limit, state_dict=state_dict)
|
||||
if clear_parameters: self.clear_parameters(model)
|
||||
self.model.append(model)
|
||||
model_name = config["model_name"]
|
||||
self.model_name.append(model_name)
|
||||
self.model_path.append(path)
|
||||
model_info = {"model_name": model_name, "model_class": config["model_class"], "extra_kwargs": config.get("extra_kwargs")}
|
||||
print(f"Loaded model: {json.dumps(model_info, indent=4)}")
|
||||
loaded = True
|
||||
if not loaded:
|
||||
raise ValueError(f"Cannot detect the model type. File: {path}. Model hash: {model_hash}")
|
||||
|
||||
def fetch_model(self, model_name, index=None):
|
||||
fetched_models = []
|
||||
fetched_model_paths = []
|
||||
for model, model_path, model_name_ in zip(self.model, self.model_path, self.model_name):
|
||||
if model_name == model_name_:
|
||||
fetched_models.append(model)
|
||||
fetched_model_paths.append(model_path)
|
||||
if len(fetched_models) == 0:
|
||||
print(f"No {model_name} models available. This is not an error.")
|
||||
model = None
|
||||
elif len(fetched_models) == 1:
|
||||
print(f"Using {model_name} from {json.dumps(fetched_model_paths[0], indent=4)}.")
|
||||
model = fetched_models[0]
|
||||
else:
|
||||
if index is None:
|
||||
model = fetched_models[0]
|
||||
print(f"More than one {model_name} models are loaded: {fetched_model_paths}. Using {model_name} from {json.dumps(fetched_model_paths[0], indent=4)}.")
|
||||
elif isinstance(index, int):
|
||||
model = fetched_models[:index]
|
||||
print(f"More than one {model_name} models are loaded: {fetched_model_paths}. Using {model_name} from {json.dumps(fetched_model_paths[:index], indent=4)}.")
|
||||
else:
|
||||
model = fetched_models
|
||||
print(f"More than one {model_name} models are loaded: {fetched_model_paths}. Using {model_name} from {json.dumps(fetched_model_paths, indent=4)}.")
|
||||
return model
|
||||
|
||||
def clear_parameters(self, model: torch.nn.Module):
|
||||
for name, module in model.named_children():
|
||||
self.clear_parameters(module)
|
||||
for name, param in model.named_parameters(recurse=False):
|
||||
setattr(model, name, None)
|
||||
471
diffsynth/models/model_manager.py
Normal file
471
diffsynth/models/model_manager.py
Normal file
@@ -0,0 +1,471 @@
|
||||
import os, torch, hashlib, json, importlib
|
||||
from safetensors import safe_open
|
||||
from torch import Tensor
|
||||
from typing_extensions import Literal, TypeAlias
|
||||
from typing import List
|
||||
|
||||
from .downloader import download_models, Preset_model_id, Preset_model_website
|
||||
|
||||
from .sd_text_encoder import SDTextEncoder
|
||||
from .sd_unet import SDUNet
|
||||
from .sd_vae_encoder import SDVAEEncoder
|
||||
from .sd_vae_decoder import SDVAEDecoder
|
||||
from .lora import get_lora_loaders
|
||||
|
||||
from .sdxl_text_encoder import SDXLTextEncoder, SDXLTextEncoder2
|
||||
from .sdxl_unet import SDXLUNet
|
||||
from .sdxl_vae_decoder import SDXLVAEDecoder
|
||||
from .sdxl_vae_encoder import SDXLVAEEncoder
|
||||
|
||||
from .sd3_text_encoder import SD3TextEncoder1, SD3TextEncoder2, SD3TextEncoder3
|
||||
from .sd3_dit import SD3DiT
|
||||
from .sd3_vae_decoder import SD3VAEDecoder
|
||||
from .sd3_vae_encoder import SD3VAEEncoder
|
||||
|
||||
from .sd_controlnet import SDControlNet
|
||||
from .sdxl_controlnet import SDXLControlNetUnion
|
||||
|
||||
from .sd_motion import SDMotionModel
|
||||
from .sdxl_motion import SDXLMotionModel
|
||||
|
||||
from .svd_image_encoder import SVDImageEncoder
|
||||
from .svd_unet import SVDUNet
|
||||
from .svd_vae_decoder import SVDVAEDecoder
|
||||
from .svd_vae_encoder import SVDVAEEncoder
|
||||
|
||||
from .sd_ipadapter import SDIpAdapter, IpAdapterCLIPImageEmbedder
|
||||
from .sdxl_ipadapter import SDXLIpAdapter, IpAdapterXLCLIPImageEmbedder
|
||||
|
||||
from .hunyuan_dit_text_encoder import HunyuanDiTCLIPTextEncoder, HunyuanDiTT5TextEncoder
|
||||
from .hunyuan_dit import HunyuanDiT
|
||||
|
||||
from .flux_dit import FluxDiT
|
||||
from .flux_text_encoder import FluxTextEncoder1, FluxTextEncoder2
|
||||
from .flux_vae import FluxVAEEncoder, FluxVAEDecoder
|
||||
|
||||
from .cog_vae import CogVAEEncoder, CogVAEDecoder
|
||||
from .cog_dit import CogDiT
|
||||
|
||||
from ..extensions.RIFE import IFNet
|
||||
from ..extensions.ESRGAN import RRDBNet
|
||||
|
||||
from ..configs.model_config import model_loader_configs, huggingface_model_loader_configs, patch_model_loader_configs
|
||||
from .utils import load_state_dict
|
||||
|
||||
|
||||
|
||||
def convert_state_dict_keys_to_single_str(state_dict, with_shape=True):
|
||||
keys = []
|
||||
for key, value in state_dict.items():
|
||||
if isinstance(key, str):
|
||||
if isinstance(value, Tensor):
|
||||
if with_shape:
|
||||
shape = "_".join(map(str, list(value.shape)))
|
||||
keys.append(key + ":" + shape)
|
||||
keys.append(key)
|
||||
elif isinstance(value, dict):
|
||||
keys.append(key + "|" + convert_state_dict_keys_to_single_str(value, with_shape=with_shape))
|
||||
keys.sort()
|
||||
keys_str = ",".join(keys)
|
||||
return keys_str
|
||||
|
||||
|
||||
def split_state_dict_with_prefix(state_dict):
|
||||
keys = sorted([key for key in state_dict if isinstance(key, str)])
|
||||
prefix_dict = {}
|
||||
for key in keys:
|
||||
prefix = key if "." not in key else key.split(".")[0]
|
||||
if prefix not in prefix_dict:
|
||||
prefix_dict[prefix] = []
|
||||
prefix_dict[prefix].append(key)
|
||||
state_dicts = []
|
||||
for prefix, keys in prefix_dict.items():
|
||||
sub_state_dict = {key: state_dict[key] for key in keys}
|
||||
state_dicts.append(sub_state_dict)
|
||||
return state_dicts
|
||||
|
||||
|
||||
def hash_state_dict_keys(state_dict, with_shape=True):
|
||||
keys_str = convert_state_dict_keys_to_single_str(state_dict, with_shape=with_shape)
|
||||
keys_str = keys_str.encode(encoding="UTF-8")
|
||||
return hashlib.md5(keys_str).hexdigest()
|
||||
|
||||
|
||||
def load_model_from_single_file(state_dict, model_names, model_classes, model_resource, torch_dtype, device):
|
||||
loaded_model_names, loaded_models = [], []
|
||||
for model_name, model_class in zip(model_names, model_classes):
|
||||
print(f" model_name: {model_name} model_class: {model_class.__name__}")
|
||||
state_dict_converter = model_class.state_dict_converter()
|
||||
if model_resource == "civitai":
|
||||
state_dict_results = state_dict_converter.from_civitai(state_dict)
|
||||
elif model_resource == "diffusers":
|
||||
state_dict_results = state_dict_converter.from_diffusers(state_dict)
|
||||
if isinstance(state_dict_results, tuple):
|
||||
model_state_dict, extra_kwargs = state_dict_results
|
||||
print(f" This model is initialized with extra kwargs: {extra_kwargs}")
|
||||
else:
|
||||
model_state_dict, extra_kwargs = state_dict_results, {}
|
||||
torch_dtype = torch.float32 if extra_kwargs.get("upcast_to_float32", False) else torch_dtype
|
||||
model = model_class(**extra_kwargs).to(dtype=torch_dtype, device=device)
|
||||
model.load_state_dict(model_state_dict)
|
||||
loaded_model_names.append(model_name)
|
||||
loaded_models.append(model)
|
||||
return loaded_model_names, loaded_models
|
||||
|
||||
|
||||
def load_model_from_huggingface_folder(file_path, model_names, model_classes, torch_dtype, device):
|
||||
loaded_model_names, loaded_models = [], []
|
||||
for model_name, model_class in zip(model_names, model_classes):
|
||||
model = model_class.from_pretrained(file_path, torch_dtype=torch_dtype).eval()
|
||||
if torch_dtype == torch.float16 and hasattr(model, "half"):
|
||||
model = model.half()
|
||||
try:
|
||||
model = model.to(device=device)
|
||||
except:
|
||||
pass
|
||||
loaded_model_names.append(model_name)
|
||||
loaded_models.append(model)
|
||||
return loaded_model_names, loaded_models
|
||||
|
||||
|
||||
def load_single_patch_model_from_single_file(state_dict, model_name, model_class, base_model, extra_kwargs, torch_dtype, device):
|
||||
print(f" model_name: {model_name} model_class: {model_class.__name__} extra_kwargs: {extra_kwargs}")
|
||||
base_state_dict = base_model.state_dict()
|
||||
base_model.to("cpu")
|
||||
del base_model
|
||||
model = model_class(**extra_kwargs)
|
||||
model.load_state_dict(base_state_dict, strict=False)
|
||||
model.load_state_dict(state_dict, strict=False)
|
||||
model.to(dtype=torch_dtype, device=device)
|
||||
return model
|
||||
|
||||
|
||||
def load_patch_model_from_single_file(state_dict, model_names, model_classes, extra_kwargs, model_manager, torch_dtype, device):
|
||||
loaded_model_names, loaded_models = [], []
|
||||
for model_name, model_class in zip(model_names, model_classes):
|
||||
while True:
|
||||
for model_id in range(len(model_manager.model)):
|
||||
base_model_name = model_manager.model_name[model_id]
|
||||
if base_model_name == model_name:
|
||||
base_model_path = model_manager.model_path[model_id]
|
||||
base_model = model_manager.model[model_id]
|
||||
print(f" Adding patch model to {base_model_name} ({base_model_path})")
|
||||
patched_model = load_single_patch_model_from_single_file(
|
||||
state_dict, model_name, model_class, base_model, extra_kwargs, torch_dtype, device)
|
||||
loaded_model_names.append(base_model_name)
|
||||
loaded_models.append(patched_model)
|
||||
model_manager.model.pop(model_id)
|
||||
model_manager.model_path.pop(model_id)
|
||||
model_manager.model_name.pop(model_id)
|
||||
break
|
||||
else:
|
||||
break
|
||||
return loaded_model_names, loaded_models
|
||||
|
||||
|
||||
|
||||
class ModelDetectorTemplate:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def match(self, file_path="", state_dict={}):
|
||||
return False
|
||||
|
||||
def load(self, file_path="", state_dict={}, device="cuda", torch_dtype=torch.float16, **kwargs):
|
||||
return [], []
|
||||
|
||||
|
||||
|
||||
class ModelDetectorFromSingleFile:
|
||||
def __init__(self, model_loader_configs=[]):
|
||||
self.keys_hash_with_shape_dict = {}
|
||||
self.keys_hash_dict = {}
|
||||
for metadata in model_loader_configs:
|
||||
self.add_model_metadata(*metadata)
|
||||
|
||||
|
||||
def add_model_metadata(self, keys_hash, keys_hash_with_shape, model_names, model_classes, model_resource):
|
||||
self.keys_hash_with_shape_dict[keys_hash_with_shape] = (model_names, model_classes, model_resource)
|
||||
if keys_hash is not None:
|
||||
self.keys_hash_dict[keys_hash] = (model_names, model_classes, model_resource)
|
||||
|
||||
|
||||
def match(self, file_path="", state_dict={}):
|
||||
if os.path.isdir(file_path):
|
||||
return False
|
||||
if len(state_dict) == 0:
|
||||
state_dict = load_state_dict(file_path)
|
||||
keys_hash_with_shape = hash_state_dict_keys(state_dict, with_shape=True)
|
||||
if keys_hash_with_shape in self.keys_hash_with_shape_dict:
|
||||
return True
|
||||
keys_hash = hash_state_dict_keys(state_dict, with_shape=False)
|
||||
if keys_hash in self.keys_hash_dict:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def load(self, file_path="", state_dict={}, device="cuda", torch_dtype=torch.float16, **kwargs):
|
||||
if len(state_dict) == 0:
|
||||
state_dict = load_state_dict(file_path)
|
||||
|
||||
# Load models with strict matching
|
||||
keys_hash_with_shape = hash_state_dict_keys(state_dict, with_shape=True)
|
||||
if keys_hash_with_shape in self.keys_hash_with_shape_dict:
|
||||
model_names, model_classes, model_resource = self.keys_hash_with_shape_dict[keys_hash_with_shape]
|
||||
loaded_model_names, loaded_models = load_model_from_single_file(state_dict, model_names, model_classes, model_resource, torch_dtype, device)
|
||||
return loaded_model_names, loaded_models
|
||||
|
||||
# Load models without strict matching
|
||||
# (the shape of parameters may be inconsistent, and the state_dict_converter will modify the model architecture)
|
||||
keys_hash = hash_state_dict_keys(state_dict, with_shape=False)
|
||||
if keys_hash in self.keys_hash_dict:
|
||||
model_names, model_classes, model_resource = self.keys_hash_dict[keys_hash]
|
||||
loaded_model_names, loaded_models = load_model_from_single_file(state_dict, model_names, model_classes, model_resource, torch_dtype, device)
|
||||
return loaded_model_names, loaded_models
|
||||
|
||||
return loaded_model_names, loaded_models
|
||||
|
||||
|
||||
|
||||
class ModelDetectorFromSplitedSingleFile(ModelDetectorFromSingleFile):
|
||||
def __init__(self, model_loader_configs=[]):
|
||||
super().__init__(model_loader_configs)
|
||||
|
||||
|
||||
def match(self, file_path="", state_dict={}):
|
||||
if os.path.isdir(file_path):
|
||||
return False
|
||||
if len(state_dict) == 0:
|
||||
state_dict = load_state_dict(file_path)
|
||||
splited_state_dict = split_state_dict_with_prefix(state_dict)
|
||||
for sub_state_dict in splited_state_dict:
|
||||
if super().match(file_path, sub_state_dict):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def load(self, file_path="", state_dict={}, device="cuda", torch_dtype=torch.float16, **kwargs):
|
||||
# Split the state_dict and load from each component
|
||||
splited_state_dict = split_state_dict_with_prefix(state_dict)
|
||||
valid_state_dict = {}
|
||||
for sub_state_dict in splited_state_dict:
|
||||
if super().match(file_path, sub_state_dict):
|
||||
valid_state_dict.update(sub_state_dict)
|
||||
if super().match(file_path, valid_state_dict):
|
||||
loaded_model_names, loaded_models = super().load(file_path, valid_state_dict, device, torch_dtype)
|
||||
else:
|
||||
loaded_model_names, loaded_models = [], []
|
||||
for sub_state_dict in splited_state_dict:
|
||||
if super().match(file_path, sub_state_dict):
|
||||
loaded_model_names_, loaded_models_ = super().load(file_path, valid_state_dict, device, torch_dtype)
|
||||
loaded_model_names += loaded_model_names_
|
||||
loaded_models += loaded_models_
|
||||
return loaded_model_names, loaded_models
|
||||
|
||||
|
||||
|
||||
class ModelDetectorFromHuggingfaceFolder:
|
||||
def __init__(self, model_loader_configs=[]):
|
||||
self.architecture_dict = {}
|
||||
for metadata in model_loader_configs:
|
||||
self.add_model_metadata(*metadata)
|
||||
|
||||
|
||||
def add_model_metadata(self, architecture, huggingface_lib, model_name, redirected_architecture):
|
||||
self.architecture_dict[architecture] = (huggingface_lib, model_name, redirected_architecture)
|
||||
|
||||
|
||||
def match(self, file_path="", state_dict={}):
|
||||
if os.path.isfile(file_path):
|
||||
return False
|
||||
file_list = os.listdir(file_path)
|
||||
if "config.json" not in file_list:
|
||||
return False
|
||||
with open(os.path.join(file_path, "config.json"), "r") as f:
|
||||
config = json.load(f)
|
||||
if "architectures" not in config and "_class_name" not in config:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def load(self, file_path="", state_dict={}, device="cuda", torch_dtype=torch.float16, **kwargs):
|
||||
with open(os.path.join(file_path, "config.json"), "r") as f:
|
||||
config = json.load(f)
|
||||
loaded_model_names, loaded_models = [], []
|
||||
architectures = config["architectures"] if "architectures" in config else [config["_class_name"]]
|
||||
for architecture in architectures:
|
||||
huggingface_lib, model_name, redirected_architecture = self.architecture_dict[architecture]
|
||||
if redirected_architecture is not None:
|
||||
architecture = redirected_architecture
|
||||
model_class = importlib.import_module(huggingface_lib).__getattribute__(architecture)
|
||||
loaded_model_names_, loaded_models_ = load_model_from_huggingface_folder(file_path, [model_name], [model_class], torch_dtype, device)
|
||||
loaded_model_names += loaded_model_names_
|
||||
loaded_models += loaded_models_
|
||||
return loaded_model_names, loaded_models
|
||||
|
||||
|
||||
|
||||
class ModelDetectorFromPatchedSingleFile:
|
||||
def __init__(self, model_loader_configs=[]):
|
||||
self.keys_hash_with_shape_dict = {}
|
||||
for metadata in model_loader_configs:
|
||||
self.add_model_metadata(*metadata)
|
||||
|
||||
|
||||
def add_model_metadata(self, keys_hash_with_shape, model_name, model_class, extra_kwargs):
|
||||
self.keys_hash_with_shape_dict[keys_hash_with_shape] = (model_name, model_class, extra_kwargs)
|
||||
|
||||
|
||||
def match(self, file_path="", state_dict={}):
|
||||
if os.path.isdir(file_path):
|
||||
return False
|
||||
if len(state_dict) == 0:
|
||||
state_dict = load_state_dict(file_path)
|
||||
keys_hash_with_shape = hash_state_dict_keys(state_dict, with_shape=True)
|
||||
if keys_hash_with_shape in self.keys_hash_with_shape_dict:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def load(self, file_path="", state_dict={}, device="cuda", torch_dtype=torch.float16, model_manager=None, **kwargs):
|
||||
if len(state_dict) == 0:
|
||||
state_dict = load_state_dict(file_path)
|
||||
|
||||
# Load models with strict matching
|
||||
loaded_model_names, loaded_models = [], []
|
||||
keys_hash_with_shape = hash_state_dict_keys(state_dict, with_shape=True)
|
||||
if keys_hash_with_shape in self.keys_hash_with_shape_dict:
|
||||
model_names, model_classes, extra_kwargs = self.keys_hash_with_shape_dict[keys_hash_with_shape]
|
||||
loaded_model_names_, loaded_models_ = load_patch_model_from_single_file(
|
||||
state_dict, model_names, model_classes, extra_kwargs, model_manager, torch_dtype, device)
|
||||
loaded_model_names += loaded_model_names_
|
||||
loaded_models += loaded_models_
|
||||
return loaded_model_names, loaded_models
|
||||
|
||||
|
||||
|
||||
class ModelManager:
|
||||
def __init__(
|
||||
self,
|
||||
torch_dtype=torch.float16,
|
||||
device="cuda",
|
||||
model_id_list: List[Preset_model_id] = [],
|
||||
downloading_priority: List[Preset_model_website] = ["ModelScope", "HuggingFace"],
|
||||
file_path_list: List[str] = [],
|
||||
):
|
||||
self.torch_dtype = torch_dtype
|
||||
self.device = device
|
||||
self.model = []
|
||||
self.model_path = []
|
||||
self.model_name = []
|
||||
downloaded_files = download_models(model_id_list, downloading_priority) if len(model_id_list) > 0 else []
|
||||
self.model_detector = [
|
||||
ModelDetectorFromSingleFile(model_loader_configs),
|
||||
ModelDetectorFromSplitedSingleFile(model_loader_configs),
|
||||
ModelDetectorFromHuggingfaceFolder(huggingface_model_loader_configs),
|
||||
ModelDetectorFromPatchedSingleFile(patch_model_loader_configs),
|
||||
]
|
||||
self.load_models(downloaded_files + file_path_list)
|
||||
|
||||
|
||||
def load_model_from_single_file(self, file_path="", state_dict={}, model_names=[], model_classes=[], model_resource=None):
|
||||
print(f"Loading models from file: {file_path}")
|
||||
if len(state_dict) == 0:
|
||||
state_dict = load_state_dict(file_path)
|
||||
model_names, models = load_model_from_single_file(state_dict, model_names, model_classes, model_resource, self.torch_dtype, self.device)
|
||||
for model_name, model in zip(model_names, models):
|
||||
self.model.append(model)
|
||||
self.model_path.append(file_path)
|
||||
self.model_name.append(model_name)
|
||||
print(f" The following models are loaded: {model_names}.")
|
||||
|
||||
|
||||
def load_model_from_huggingface_folder(self, file_path="", model_names=[], model_classes=[]):
|
||||
print(f"Loading models from folder: {file_path}")
|
||||
model_names, models = load_model_from_huggingface_folder(file_path, model_names, model_classes, self.torch_dtype, self.device)
|
||||
for model_name, model in zip(model_names, models):
|
||||
self.model.append(model)
|
||||
self.model_path.append(file_path)
|
||||
self.model_name.append(model_name)
|
||||
print(f" The following models are loaded: {model_names}.")
|
||||
|
||||
|
||||
def load_patch_model_from_single_file(self, file_path="", state_dict={}, model_names=[], model_classes=[], extra_kwargs={}):
|
||||
print(f"Loading patch models from file: {file_path}")
|
||||
model_names, models = load_patch_model_from_single_file(
|
||||
state_dict, model_names, model_classes, extra_kwargs, self, self.torch_dtype, self.device)
|
||||
for model_name, model in zip(model_names, models):
|
||||
self.model.append(model)
|
||||
self.model_path.append(file_path)
|
||||
self.model_name.append(model_name)
|
||||
print(f" The following patched models are loaded: {model_names}.")
|
||||
|
||||
|
||||
def load_lora(self, file_path="", state_dict={}, lora_alpha=1.0):
|
||||
print(f"Loading LoRA models from file: {file_path}")
|
||||
if len(state_dict) == 0:
|
||||
state_dict = load_state_dict(file_path)
|
||||
for model_name, model, model_path in zip(self.model_name, self.model, self.model_path):
|
||||
for lora in get_lora_loaders():
|
||||
match_results = lora.match(model, state_dict)
|
||||
if match_results is not None:
|
||||
print(f" Adding LoRA to {model_name} ({model_path}).")
|
||||
lora_prefix, model_resource = match_results
|
||||
lora.load(model, state_dict, lora_prefix, alpha=lora_alpha, model_resource=model_resource)
|
||||
break
|
||||
|
||||
|
||||
def load_model(self, file_path, model_names=None):
|
||||
print(f"Loading models from: {file_path}")
|
||||
if os.path.isfile(file_path):
|
||||
state_dict = load_state_dict(file_path)
|
||||
else:
|
||||
state_dict = None
|
||||
for model_detector in self.model_detector:
|
||||
if model_detector.match(file_path, state_dict):
|
||||
model_names, models = model_detector.load(
|
||||
file_path, state_dict,
|
||||
device=self.device, torch_dtype=self.torch_dtype,
|
||||
allowed_model_names=model_names, model_manager=self
|
||||
)
|
||||
for model_name, model in zip(model_names, models):
|
||||
self.model.append(model)
|
||||
self.model_path.append(file_path)
|
||||
self.model_name.append(model_name)
|
||||
print(f" The following models are loaded: {model_names}.")
|
||||
break
|
||||
else:
|
||||
print(f" We cannot detect the model type. No models are loaded.")
|
||||
|
||||
|
||||
def load_models(self, file_path_list, model_names=None):
|
||||
for file_path in file_path_list:
|
||||
self.load_model(file_path, model_names)
|
||||
|
||||
|
||||
def fetch_model(self, model_name, file_path=None, require_model_path=False):
|
||||
fetched_models = []
|
||||
fetched_model_paths = []
|
||||
for model, model_path, model_name_ in zip(self.model, self.model_path, self.model_name):
|
||||
if file_path is not None and file_path != model_path:
|
||||
continue
|
||||
if model_name == model_name_:
|
||||
fetched_models.append(model)
|
||||
fetched_model_paths.append(model_path)
|
||||
if len(fetched_models) == 0:
|
||||
print(f"No {model_name} models available.")
|
||||
return None
|
||||
if len(fetched_models) == 1:
|
||||
print(f"Using {model_name} from {fetched_model_paths[0]}.")
|
||||
else:
|
||||
print(f"More than one {model_name} models are loaded in model manager: {fetched_model_paths}. Using {model_name} from {fetched_model_paths[0]}.")
|
||||
if require_model_path:
|
||||
return fetched_models[0], fetched_model_paths[0]
|
||||
else:
|
||||
return fetched_models[0]
|
||||
|
||||
|
||||
def to(self, device):
|
||||
for model in self.model:
|
||||
model.to(device)
|
||||
|
||||
@@ -1,161 +0,0 @@
|
||||
import torch
|
||||
from PIL import Image
|
||||
|
||||
|
||||
class NexusGenAutoregressiveModel(torch.nn.Module):
|
||||
def __init__(self, max_length=1024, max_pixels=262640):
|
||||
super(NexusGenAutoregressiveModel, self).__init__()
|
||||
from .nexus_gen_ar_model import Qwen2_5_VLForConditionalGeneration
|
||||
from transformers import Qwen2_5_VLConfig
|
||||
self.max_length = max_length
|
||||
self.max_pixels = max_pixels
|
||||
model_config = Qwen2_5_VLConfig(**{
|
||||
"_name_or_path": "DiffSynth-Studio/Nexus-GenV2",
|
||||
"architectures": [
|
||||
"Qwen2_5_VLForConditionalGeneration"
|
||||
],
|
||||
"attention_dropout": 0.0,
|
||||
"auto_map": {
|
||||
"AutoConfig": "configuration_qwen2_5_vl.Qwen2_5_VLConfig",
|
||||
"AutoModel": "modeling_qwen2_5_vl.Qwen2_5_VLModel",
|
||||
"AutoModelForCausalLM": "modeling_qwen2_5_vl.Qwen2_5_VLForConditionalGeneration"
|
||||
},
|
||||
"bos_token_id": 151643,
|
||||
"eos_token_id": 151645,
|
||||
"hidden_act": "silu",
|
||||
"hidden_size": 3584,
|
||||
"image_token_id": 151655,
|
||||
"initializer_range": 0.02,
|
||||
"intermediate_size": 18944,
|
||||
"max_position_embeddings": 128000,
|
||||
"max_window_layers": 28,
|
||||
"model_type": "qwen2_5_vl",
|
||||
"num_attention_heads": 28,
|
||||
"num_hidden_layers": 28,
|
||||
"num_key_value_heads": 4,
|
||||
"pad_token_id": 151643,
|
||||
"rms_norm_eps": 1e-06,
|
||||
"rope_scaling": {
|
||||
"mrope_section": [
|
||||
16,
|
||||
24,
|
||||
24
|
||||
],
|
||||
"rope_type": "default",
|
||||
"type": "default"
|
||||
},
|
||||
"rope_theta": 1000000.0,
|
||||
"sliding_window": 32768,
|
||||
"tie_word_embeddings": False,
|
||||
"torch_dtype": "bfloat16",
|
||||
"transformers_version": "4.49.0",
|
||||
"use_cache": False,
|
||||
"use_sliding_window": False,
|
||||
"video_token_id": 151656,
|
||||
"vision_config": {
|
||||
"hidden_size": 1280,
|
||||
"in_chans": 3,
|
||||
"model_type": "qwen2_5_vl",
|
||||
"spatial_patch_size": 14,
|
||||
"tokens_per_second": 2,
|
||||
"torch_dtype": "bfloat16"
|
||||
},
|
||||
"vision_end_token_id": 151653,
|
||||
"vision_start_token_id": 151652,
|
||||
"vision_token_id": 151654,
|
||||
"vocab_size": 152064
|
||||
})
|
||||
self.model = Qwen2_5_VLForConditionalGeneration(model_config)
|
||||
self.processor = None
|
||||
|
||||
|
||||
def load_processor(self, path):
|
||||
from .nexus_gen_ar_model import Qwen2_5_VLProcessor
|
||||
self.processor = Qwen2_5_VLProcessor.from_pretrained(path)
|
||||
|
||||
|
||||
@staticmethod
|
||||
def state_dict_converter():
|
||||
return NexusGenAutoregressiveModelStateDictConverter()
|
||||
|
||||
def bound_image(self, image, max_pixels=262640):
|
||||
from qwen_vl_utils import smart_resize
|
||||
resized_height, resized_width = smart_resize(
|
||||
image.height,
|
||||
image.width,
|
||||
max_pixels=max_pixels,
|
||||
)
|
||||
return image.resize((resized_width, resized_height))
|
||||
|
||||
def get_editing_msg(self, instruction):
|
||||
if '<image>' not in instruction:
|
||||
instruction = '<image> ' + instruction
|
||||
messages = [{"role":"user", "content":instruction}, {"role":"assistant", "content":"Here is the image: <image>"}]
|
||||
return messages
|
||||
|
||||
def get_generation_msg(self, instruction):
|
||||
instruction = "Generate an image according to the following description: {}".format(instruction)
|
||||
messages = [{"role":"user", "content":instruction}, {"role":"assistant", "content":"Here is an image based on the description: <image>"}]
|
||||
return messages
|
||||
|
||||
def forward(self, instruction, ref_image=None, num_img_tokens=81):
|
||||
"""
|
||||
Generate target embeddings for the given instruction and reference image.
|
||||
"""
|
||||
if ref_image is not None:
|
||||
messages = self.get_editing_msg(instruction)
|
||||
images = [self.bound_image(ref_image)] + [Image.new(mode='RGB', size=(252, 252), color=(255, 255, 255))]
|
||||
output_image_embeddings = self.get_target_embeddings(images, messages, self.processor, self.model, num_img_tokens)
|
||||
else:
|
||||
messages = self.get_generation_msg(instruction)
|
||||
images = [Image.new(mode='RGB', size=(252, 252), color=(255, 255, 255))]
|
||||
output_image_embeddings = self.get_target_embeddings(images, messages, self.processor, self.model, num_img_tokens)
|
||||
|
||||
return output_image_embeddings
|
||||
|
||||
def get_target_embeddings(self, images, messages, processor, model, num_img_tokens=81):
|
||||
text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=False)
|
||||
text = text.replace('<image>', '<|vision_start|><|image_pad|><|vision_end|>')
|
||||
inputs = processor(
|
||||
text=[text],
|
||||
images=images,
|
||||
padding=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
inputs = inputs.to(model.device)
|
||||
|
||||
input_embeds = model.model.embed_tokens(inputs['input_ids'])
|
||||
image_embeds = model.visual(inputs['pixel_values'], grid_thw=inputs['image_grid_thw'])
|
||||
ground_truth_image_embeds = image_embeds[-num_img_tokens:]
|
||||
input_image_embeds = image_embeds[:-num_img_tokens]
|
||||
|
||||
image_mask = inputs['input_ids'] == model.config.image_token_id
|
||||
indices = image_mask.cumsum(dim=1)
|
||||
input_image_mask = torch.logical_and(indices <= (image_embeds.shape[0] - ground_truth_image_embeds.shape[0]), image_mask)
|
||||
gt_image_mask = torch.logical_and(image_mask, ~input_image_mask)
|
||||
input_image_mask = input_image_mask.unsqueeze(-1).expand_as(input_embeds)
|
||||
input_embeds = input_embeds.masked_scatter(input_image_mask, input_image_embeds)
|
||||
|
||||
image_prefill_embeds = model.image_prefill_embeds(
|
||||
torch.arange(81, device=model.device).long()
|
||||
)
|
||||
input_embeds = input_embeds.masked_scatter(gt_image_mask.unsqueeze(-1).expand_as(input_embeds), image_prefill_embeds)
|
||||
|
||||
position_ids, _ = model.get_rope_index(
|
||||
inputs['input_ids'],
|
||||
inputs['image_grid_thw'],
|
||||
attention_mask=inputs['attention_mask'])
|
||||
position_ids = position_ids.contiguous()
|
||||
outputs = model(inputs_embeds=input_embeds, position_ids=position_ids, attention_mask=inputs['attention_mask'], return_dict=True)
|
||||
output_image_embeddings = outputs.image_embeddings[:, :-1, :]
|
||||
output_image_embeddings = output_image_embeddings[gt_image_mask[:, 1:]]
|
||||
return output_image_embeddings, input_image_embeds, inputs['image_grid_thw']
|
||||
|
||||
|
||||
class NexusGenAutoregressiveModelStateDictConverter:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def from_civitai(self, state_dict):
|
||||
state_dict = {"model." + key: value for key, value in state_dict.items()}
|
||||
return state_dict
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,417 +0,0 @@
|
||||
import math
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from typing import Optional, Tuple
|
||||
|
||||
|
||||
|
||||
def rotate_half(x):
|
||||
"""Rotates half the hidden dims of the input."""
|
||||
x1 = x[..., : x.shape[-1] // 2]
|
||||
x2 = x[..., x.shape[-1] // 2 :]
|
||||
return torch.cat((-x2, x1), dim=-1)
|
||||
|
||||
|
||||
def apply_multimodal_rotary_pos_emb(q, k, cos, sin, mrope_section, unsqueeze_dim=1):
|
||||
mrope_section = mrope_section * 2
|
||||
cos = torch.cat([m[i % 3] for i, m in enumerate(cos.split(mrope_section, dim=-1))], dim=-1).unsqueeze(
|
||||
unsqueeze_dim
|
||||
)
|
||||
sin = torch.cat([m[i % 3] for i, m in enumerate(sin.split(mrope_section, dim=-1))], dim=-1).unsqueeze(
|
||||
unsqueeze_dim
|
||||
)
|
||||
|
||||
q_embed = (q * cos) + (rotate_half(q) * sin)
|
||||
k_embed = (k * cos) + (rotate_half(k) * sin)
|
||||
return q_embed, k_embed
|
||||
|
||||
|
||||
class Qwen2_5_VLRotaryEmbedding(nn.Module):
|
||||
def __init__(self, config, device=None):
|
||||
super().__init__()
|
||||
# BC: "rope_type" was originally "type"
|
||||
if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
|
||||
self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
|
||||
else:
|
||||
self.rope_type = "default"
|
||||
self.max_seq_len_cached = config.max_position_embeddings
|
||||
self.original_max_seq_len = config.max_position_embeddings
|
||||
|
||||
self.config = config
|
||||
from transformers.modeling_rope_utils import _compute_default_rope_parameters
|
||||
self.rope_init_fn = _compute_default_rope_parameters
|
||||
|
||||
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
|
||||
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
||||
self.original_inv_freq = self.inv_freq
|
||||
|
||||
|
||||
def _dynamic_frequency_update(self, position_ids, device):
|
||||
"""
|
||||
dynamic RoPE layers should recompute `inv_freq` in the following situations:
|
||||
1 - growing beyond the cached sequence length (allow scaling)
|
||||
2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
|
||||
"""
|
||||
seq_len = torch.max(position_ids) + 1
|
||||
if seq_len > self.max_seq_len_cached: # growth
|
||||
inv_freq, self.attention_scaling = self.rope_init_fn(
|
||||
self.config, device, seq_len=seq_len, **self.rope_kwargs
|
||||
)
|
||||
self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation
|
||||
self.max_seq_len_cached = seq_len
|
||||
|
||||
if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset
|
||||
self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
|
||||
self.max_seq_len_cached = self.original_max_seq_len
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def forward(self, x, position_ids):
|
||||
if "dynamic" in self.rope_type:
|
||||
self._dynamic_frequency_update(position_ids, device=x.device)
|
||||
|
||||
# Core RoPE block. In contrast to other models, Qwen2_5_VL has different position ids for the grids
|
||||
# So we expand the inv_freq to shape (3, ...)
|
||||
inv_freq_expanded = self.inv_freq[None, None, :, None].float().expand(3, position_ids.shape[1], -1, 1)
|
||||
position_ids_expanded = position_ids[:, :, None, :].float() # shape (3, bs, 1, positions)
|
||||
# Force float32 (see https://github.com/huggingface/transformers/pull/29285)
|
||||
device_type = x.device.type
|
||||
device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
|
||||
with torch.autocast(device_type=device_type, enabled=False):
|
||||
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(2, 3)
|
||||
emb = torch.cat((freqs, freqs), dim=-1)
|
||||
cos = emb.cos()
|
||||
sin = emb.sin()
|
||||
|
||||
# Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention
|
||||
cos = cos * self.attention_scaling
|
||||
sin = sin * self.attention_scaling
|
||||
|
||||
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
|
||||
|
||||
|
||||
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
|
||||
"""
|
||||
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
|
||||
num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
|
||||
"""
|
||||
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
|
||||
if n_rep == 1:
|
||||
return hidden_states
|
||||
hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
|
||||
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
|
||||
|
||||
|
||||
class Qwen2_5_VLAttention(nn.Module):
|
||||
def __init__(self, config, layer_idx: Optional[int] = None):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.layer_idx = layer_idx
|
||||
|
||||
self.hidden_size = config.hidden_size
|
||||
self.num_heads = config.num_attention_heads
|
||||
self.head_dim = self.hidden_size // self.num_heads
|
||||
self.num_key_value_heads = config.num_key_value_heads
|
||||
self.num_key_value_groups = self.num_heads // self.num_key_value_heads
|
||||
self.is_causal = True
|
||||
self.attention_dropout = config.attention_dropout
|
||||
self.rope_scaling = config.rope_scaling
|
||||
|
||||
if (self.head_dim * self.num_heads) != self.hidden_size:
|
||||
raise ValueError(
|
||||
f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
|
||||
f" and `num_heads`: {self.num_heads})."
|
||||
)
|
||||
self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=True)
|
||||
self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
|
||||
self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
|
||||
self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
|
||||
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
bsz, q_len, _ = hidden_states.size()
|
||||
|
||||
query_states = self.q_proj(hidden_states)
|
||||
key_states = self.k_proj(hidden_states)
|
||||
value_states = self.v_proj(hidden_states)
|
||||
|
||||
query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
|
||||
key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
|
||||
value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
|
||||
|
||||
cos, sin = position_embeddings
|
||||
query_states, key_states = apply_multimodal_rotary_pos_emb(
|
||||
query_states, key_states, cos, sin, self.rope_scaling["mrope_section"]
|
||||
)
|
||||
|
||||
# repeat k/v heads if n_kv_heads < n_heads
|
||||
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
||||
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
||||
|
||||
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
|
||||
|
||||
# Fix precision issues in Qwen2-VL float16 inference
|
||||
# Replace inf values with zeros in attention weights to prevent NaN propagation
|
||||
if query_states.dtype == torch.float16:
|
||||
attn_weights = torch.where(torch.isinf(attn_weights), torch.zeros_like(attn_weights), attn_weights)
|
||||
|
||||
# upcast attention to fp32
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
|
||||
attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
|
||||
attn_output = torch.matmul(attn_weights, value_states)
|
||||
|
||||
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
|
||||
raise ValueError(
|
||||
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
|
||||
f" {attn_output.size()}"
|
||||
)
|
||||
|
||||
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||
attn_output = attn_output.reshape(bsz, q_len, -1)
|
||||
|
||||
attn_output = self.o_proj(attn_output)
|
||||
|
||||
return attn_output
|
||||
|
||||
|
||||
class Qwen2MLP(nn.Module):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
from transformers.activations import ACT2FN
|
||||
self.config = config
|
||||
self.hidden_size = config.hidden_size
|
||||
self.intermediate_size = config.intermediate_size
|
||||
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
|
||||
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
|
||||
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
|
||||
self.act_fn = ACT2FN[config.hidden_act]
|
||||
|
||||
def forward(self, x):
|
||||
down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
|
||||
return down_proj
|
||||
|
||||
|
||||
class Qwen2RMSNorm(nn.Module):
|
||||
def __init__(self, hidden_size, eps=1e-6):
|
||||
"""
|
||||
Qwen2RMSNorm is equivalent to T5LayerNorm
|
||||
"""
|
||||
super().__init__()
|
||||
self.weight = nn.Parameter(torch.ones(hidden_size))
|
||||
self.variance_epsilon = eps
|
||||
|
||||
def forward(self, hidden_states):
|
||||
input_dtype = hidden_states.dtype
|
||||
hidden_states = hidden_states.to(torch.float32)
|
||||
variance = hidden_states.pow(2).mean(-1, keepdim=True)
|
||||
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
|
||||
return self.weight * hidden_states.to(input_dtype)
|
||||
|
||||
def extra_repr(self):
|
||||
return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
|
||||
|
||||
|
||||
class Qwen2_5_VLDecoderLayer(nn.Module):
|
||||
def __init__(self, config, layer_idx):
|
||||
super().__init__()
|
||||
self.hidden_size = config.hidden_size
|
||||
|
||||
self.self_attn = Qwen2_5_VLAttention(config, layer_idx)
|
||||
|
||||
self.mlp = Qwen2MLP(config)
|
||||
self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.post_attention_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
|
||||
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
||||
|
||||
residual = hidden_states
|
||||
|
||||
hidden_states = self.input_layernorm(hidden_states)
|
||||
|
||||
# Self Attention
|
||||
hidden_states = self.self_attn(
|
||||
hidden_states=hidden_states,
|
||||
position_embeddings=position_embeddings,
|
||||
)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
# Fully Connected
|
||||
residual = hidden_states
|
||||
hidden_states = self.post_attention_layernorm(hidden_states)
|
||||
hidden_states = self.mlp(hidden_states)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class NexusGenImageEmbeddingMerger(nn.Module):
|
||||
def __init__(self, num_layers=1, out_channel=4096, expand_ratio=4, device='cpu'):
|
||||
super().__init__()
|
||||
from transformers import Qwen2_5_VLConfig
|
||||
from transformers.activations import ACT2FN
|
||||
config = Qwen2_5_VLConfig(**{
|
||||
"_name_or_path": "DiffSynth-Studio/Nexus-GenV2",
|
||||
"architectures": [
|
||||
"Qwen2_5_VLForConditionalGeneration"
|
||||
],
|
||||
"attention_dropout": 0.0,
|
||||
"auto_map": {
|
||||
"AutoConfig": "configuration_qwen2_5_vl.Qwen2_5_VLConfig",
|
||||
"AutoModel": "modeling_qwen2_5_vl.Qwen2_5_VLModel",
|
||||
"AutoModelForCausalLM": "modeling_qwen2_5_vl.Qwen2_5_VLForConditionalGeneration"
|
||||
},
|
||||
"bos_token_id": 151643,
|
||||
"eos_token_id": 151645,
|
||||
"hidden_act": "silu",
|
||||
"hidden_size": 3584,
|
||||
"image_token_id": 151655,
|
||||
"initializer_range": 0.02,
|
||||
"intermediate_size": 18944,
|
||||
"max_position_embeddings": 128000,
|
||||
"max_window_layers": 28,
|
||||
"model_type": "qwen2_5_vl",
|
||||
"num_attention_heads": 28,
|
||||
"num_hidden_layers": 28,
|
||||
"num_key_value_heads": 4,
|
||||
"pad_token_id": 151643,
|
||||
"rms_norm_eps": 1e-06,
|
||||
"rope_scaling": {
|
||||
"mrope_section": [
|
||||
16,
|
||||
24,
|
||||
24
|
||||
],
|
||||
"rope_type": "default",
|
||||
"type": "default"
|
||||
},
|
||||
"rope_theta": 1000000.0,
|
||||
"sliding_window": 32768,
|
||||
"tie_word_embeddings": False,
|
||||
"torch_dtype": "bfloat16",
|
||||
"transformers_version": "4.49.0",
|
||||
"use_cache": False,
|
||||
"use_sliding_window": False,
|
||||
"video_token_id": 151656,
|
||||
"vision_config": {
|
||||
"hidden_size": 1280,
|
||||
"in_chans": 3,
|
||||
"model_type": "qwen2_5_vl",
|
||||
"spatial_patch_size": 14,
|
||||
"tokens_per_second": 2,
|
||||
"torch_dtype": "bfloat16"
|
||||
},
|
||||
"vision_end_token_id": 151653,
|
||||
"vision_start_token_id": 151652,
|
||||
"vision_token_id": 151654,
|
||||
"vocab_size": 152064
|
||||
})
|
||||
self.config = config
|
||||
self.num_layers = num_layers
|
||||
self.layers = nn.ModuleList([Qwen2_5_VLDecoderLayer(config, layer_idx) for layer_idx in range(num_layers)])
|
||||
self.projector = nn.Sequential(Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps),
|
||||
nn.Linear(config.hidden_size, out_channel * expand_ratio),
|
||||
Qwen2RMSNorm(out_channel * expand_ratio, eps=config.rms_norm_eps),
|
||||
ACT2FN[config.hidden_act], nn.Linear(out_channel * expand_ratio, out_channel),
|
||||
Qwen2RMSNorm(out_channel, eps=config.rms_norm_eps))
|
||||
self.base_grid = torch.tensor([[1, 72, 72]], device=device)
|
||||
self.rotary_emb = Qwen2_5_VLRotaryEmbedding(config=config, device=device)
|
||||
|
||||
def get_position_ids(self, image_grid_thw):
|
||||
"""
|
||||
Generates position ids for the input embeddings grid.
|
||||
modified from the qwen2_vl mrope.
|
||||
"""
|
||||
batch_size = image_grid_thw.shape[0]
|
||||
spatial_merge_size = self.config.vision_config.spatial_merge_size
|
||||
t, h, w = (
|
||||
image_grid_thw[0][0],
|
||||
image_grid_thw[0][1],
|
||||
image_grid_thw[0][2],
|
||||
)
|
||||
llm_grid_t, llm_grid_h, llm_grid_w = (
|
||||
t.item(),
|
||||
h.item() // spatial_merge_size,
|
||||
w.item() // spatial_merge_size,
|
||||
)
|
||||
scale_h = self.base_grid[0][1].item() / h.item()
|
||||
scale_w = self.base_grid[0][2].item() / w.item()
|
||||
|
||||
range_tensor = torch.arange(llm_grid_t).view(-1, 1)
|
||||
expanded_range = range_tensor.expand(-1, llm_grid_h * llm_grid_w)
|
||||
time_tensor = expanded_range * self.config.vision_config.tokens_per_second
|
||||
t_index = time_tensor.long().flatten().to(image_grid_thw.device)
|
||||
h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(llm_grid_t, -1, llm_grid_w).flatten().to(image_grid_thw.device) * scale_h
|
||||
w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(llm_grid_t, llm_grid_h, -1).flatten().to(image_grid_thw.device) * scale_w
|
||||
# 3, B, L
|
||||
position_ids = torch.stack([t_index, h_index, w_index]).unsqueeze(0).repeat(batch_size, 1, 1).permute(1, 0, 2)
|
||||
return position_ids
|
||||
|
||||
def forward(self, embeds, embeds_grid, ref_embeds=None, ref_embeds_grid=None):
|
||||
position_ids = self.get_position_ids(embeds_grid)
|
||||
hidden_states = embeds
|
||||
if ref_embeds is not None:
|
||||
position_ids_ref_embeds = self.get_position_ids(ref_embeds_grid)
|
||||
position_ids = torch.cat((position_ids, position_ids_ref_embeds), dim=-1)
|
||||
hidden_states = torch.cat((embeds, ref_embeds), dim=1)
|
||||
|
||||
position_embeddings = self.rotary_emb(hidden_states, position_ids)
|
||||
for layer in self.layers:
|
||||
hidden_states = layer(hidden_states, position_embeddings)
|
||||
|
||||
hidden_states = self.projector(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
@staticmethod
|
||||
def state_dict_converter():
|
||||
return NexusGenMergerStateDictConverter()
|
||||
|
||||
|
||||
class NexusGenMergerStateDictConverter:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def from_diffusers(self, state_dict):
|
||||
return state_dict
|
||||
|
||||
def from_civitai(self, state_dict):
|
||||
merger_state_dict = {key.replace("embedding_merger.", ""): value for key, value in state_dict.items() if key.startswith('embedding_merger.')}
|
||||
return merger_state_dict
|
||||
|
||||
|
||||
class NexusGenAdapter(nn.Module):
|
||||
"""
|
||||
Adapter for Nexus-Gen generation decoder.
|
||||
"""
|
||||
def __init__(self, input_dim=3584, output_dim=4096):
|
||||
super(NexusGenAdapter, self).__init__()
|
||||
self.adapter = nn.Sequential(nn.Linear(input_dim, output_dim),
|
||||
nn.LayerNorm(output_dim), nn.ReLU(),
|
||||
nn.Linear(output_dim, output_dim),
|
||||
nn.LayerNorm(output_dim))
|
||||
|
||||
def forward(self, x):
|
||||
return self.adapter(x)
|
||||
|
||||
@staticmethod
|
||||
def state_dict_converter():
|
||||
return NexusGenAdapterStateDictConverter()
|
||||
|
||||
|
||||
class NexusGenAdapterStateDictConverter:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def from_diffusers(self, state_dict):
|
||||
return state_dict
|
||||
|
||||
def from_civitai(self, state_dict):
|
||||
adapter_state_dict = {key: value for key, value in state_dict.items() if key.startswith('adapter.')}
|
||||
return adapter_state_dict
|
||||
@@ -1,56 +0,0 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from .general_modules import RMSNorm
|
||||
|
||||
|
||||
class BlockWiseControlBlock(torch.nn.Module):
|
||||
# [linear, gelu, linear]
|
||||
def __init__(self, dim: int = 3072):
|
||||
super().__init__()
|
||||
self.x_rms = RMSNorm(dim, eps=1e-6)
|
||||
self.y_rms = RMSNorm(dim, eps=1e-6)
|
||||
self.input_proj = nn.Linear(dim, dim)
|
||||
self.act = nn.GELU()
|
||||
self.output_proj = nn.Linear(dim, dim)
|
||||
|
||||
def forward(self, x, y):
|
||||
x, y = self.x_rms(x), self.y_rms(y)
|
||||
x = self.input_proj(x + y)
|
||||
x = self.act(x)
|
||||
x = self.output_proj(x)
|
||||
return x
|
||||
|
||||
def init_weights(self):
|
||||
# zero initialize output_proj
|
||||
nn.init.zeros_(self.output_proj.weight)
|
||||
nn.init.zeros_(self.output_proj.bias)
|
||||
|
||||
|
||||
class QwenImageBlockWiseControlNet(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
num_layers: int = 60,
|
||||
in_dim: int = 64,
|
||||
additional_in_dim: int = 0,
|
||||
dim: int = 3072,
|
||||
):
|
||||
super().__init__()
|
||||
self.img_in = nn.Linear(in_dim + additional_in_dim, dim)
|
||||
self.controlnet_blocks = nn.ModuleList(
|
||||
[
|
||||
BlockWiseControlBlock(dim)
|
||||
for _ in range(num_layers)
|
||||
]
|
||||
)
|
||||
|
||||
def init_weight(self):
|
||||
nn.init.zeros_(self.img_in.weight)
|
||||
nn.init.zeros_(self.img_in.bias)
|
||||
for block in self.controlnet_blocks:
|
||||
block.init_weights()
|
||||
|
||||
def process_controlnet_conditioning(self, controlnet_conditioning):
|
||||
return self.img_in(controlnet_conditioning)
|
||||
|
||||
def blockwise_forward(self, img, controlnet_conditioning, block_id):
|
||||
return self.controlnet_blocks[block_id](img, controlnet_conditioning)
|
||||
@@ -1,685 +0,0 @@
|
||||
import torch, math, functools
|
||||
import torch.nn as nn
|
||||
from typing import Tuple, Optional, Union, List
|
||||
from einops import rearrange
|
||||
from .general_modules import TimestepEmbeddings, RMSNorm, AdaLayerNorm
|
||||
|
||||
try:
|
||||
import flash_attn_interface
|
||||
FLASH_ATTN_3_AVAILABLE = True
|
||||
except ModuleNotFoundError:
|
||||
FLASH_ATTN_3_AVAILABLE = False
|
||||
|
||||
|
||||
def qwen_image_flash_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, num_heads: int, attention_mask = None, enable_fp8_attention: bool = False):
|
||||
if FLASH_ATTN_3_AVAILABLE and attention_mask is None:
|
||||
if not enable_fp8_attention:
|
||||
q = rearrange(q, "b n s d -> b s n d", n=num_heads)
|
||||
k = rearrange(k, "b n s d -> b s n d", n=num_heads)
|
||||
v = rearrange(v, "b n s d -> b s n d", n=num_heads)
|
||||
x = flash_attn_interface.flash_attn_func(q, k, v)
|
||||
if isinstance(x, tuple):
|
||||
x = x[0]
|
||||
x = rearrange(x, "b s n d -> b s (n d)", n=num_heads)
|
||||
else:
|
||||
origin_dtype = q.dtype
|
||||
q_std, k_std, v_std = q.std(), k.std(), v.std()
|
||||
q, k, v = (q / q_std).to(torch.float8_e4m3fn), (k / k_std).to(torch.float8_e4m3fn), (v / v_std).to(torch.float8_e4m3fn)
|
||||
q = rearrange(q, "b n s d -> b s n d", n=num_heads)
|
||||
k = rearrange(k, "b n s d -> b s n d", n=num_heads)
|
||||
v = rearrange(v, "b n s d -> b s n d", n=num_heads)
|
||||
x = flash_attn_interface.flash_attn_func(q, k, v, softmax_scale=q_std * k_std / math.sqrt(q.size(-1)))
|
||||
if isinstance(x, tuple):
|
||||
x = x[0]
|
||||
x = x.to(origin_dtype) * v_std
|
||||
x = rearrange(x, "b s n d -> b s (n d)", n=num_heads)
|
||||
else:
|
||||
x = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attention_mask)
|
||||
x = rearrange(x, "b n s d -> b s (n d)", n=num_heads)
|
||||
return x
|
||||
|
||||
|
||||
class ApproximateGELU(nn.Module):
|
||||
def __init__(self, dim_in: int, dim_out: int, bias: bool = True):
|
||||
super().__init__()
|
||||
self.proj = nn.Linear(dim_in, dim_out, bias=bias)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x = self.proj(x)
|
||||
return x * torch.sigmoid(1.702 * x)
|
||||
|
||||
def apply_rotary_emb_qwen(
|
||||
x: torch.Tensor,
|
||||
freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]]
|
||||
):
|
||||
x_rotated = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
|
||||
x_out = torch.view_as_real(x_rotated * freqs_cis).flatten(3)
|
||||
return x_out.type_as(x)
|
||||
|
||||
|
||||
class QwenEmbedRope(nn.Module):
|
||||
def __init__(self, theta: int, axes_dim: list[int], scale_rope=False):
|
||||
super().__init__()
|
||||
self.theta = theta
|
||||
self.axes_dim = axes_dim
|
||||
pos_index = torch.arange(4096)
|
||||
neg_index = torch.arange(4096).flip(0) * -1 - 1
|
||||
self.pos_freqs = torch.cat([
|
||||
self.rope_params(pos_index, self.axes_dim[0], self.theta),
|
||||
self.rope_params(pos_index, self.axes_dim[1], self.theta),
|
||||
self.rope_params(pos_index, self.axes_dim[2], self.theta),
|
||||
], dim=1)
|
||||
self.neg_freqs = torch.cat([
|
||||
self.rope_params(neg_index, self.axes_dim[0], self.theta),
|
||||
self.rope_params(neg_index, self.axes_dim[1], self.theta),
|
||||
self.rope_params(neg_index, self.axes_dim[2], self.theta),
|
||||
], dim=1)
|
||||
self.rope_cache = {}
|
||||
self.scale_rope = scale_rope
|
||||
|
||||
def rope_params(self, index, dim, theta=10000):
|
||||
"""
|
||||
Args:
|
||||
index: [0, 1, 2, 3] 1D Tensor representing the position index of the token
|
||||
"""
|
||||
assert dim % 2 == 0
|
||||
freqs = torch.outer(
|
||||
index,
|
||||
1.0 / torch.pow(theta, torch.arange(0, dim, 2).to(torch.float32).div(dim))
|
||||
)
|
||||
freqs = torch.polar(torch.ones_like(freqs), freqs)
|
||||
return freqs
|
||||
|
||||
|
||||
def _expand_pos_freqs_if_needed(self, video_fhw, txt_seq_lens):
|
||||
if isinstance(video_fhw, list):
|
||||
video_fhw = tuple(max([i[j] for i in video_fhw]) for j in range(3))
|
||||
_, height, width = video_fhw
|
||||
if self.scale_rope:
|
||||
max_vid_index = max(height // 2, width // 2)
|
||||
else:
|
||||
max_vid_index = max(height, width)
|
||||
required_len = max_vid_index + max(txt_seq_lens)
|
||||
cur_max_len = self.pos_freqs.shape[0]
|
||||
if required_len <= cur_max_len:
|
||||
return
|
||||
|
||||
new_max_len = math.ceil(required_len / 512) * 512
|
||||
pos_index = torch.arange(new_max_len)
|
||||
neg_index = torch.arange(new_max_len).flip(0) * -1 - 1
|
||||
self.pos_freqs = torch.cat([
|
||||
self.rope_params(pos_index, self.axes_dim[0], self.theta),
|
||||
self.rope_params(pos_index, self.axes_dim[1], self.theta),
|
||||
self.rope_params(pos_index, self.axes_dim[2], self.theta),
|
||||
], dim=1)
|
||||
self.neg_freqs = torch.cat([
|
||||
self.rope_params(neg_index, self.axes_dim[0], self.theta),
|
||||
self.rope_params(neg_index, self.axes_dim[1], self.theta),
|
||||
self.rope_params(neg_index, self.axes_dim[2], self.theta),
|
||||
], dim=1)
|
||||
return
|
||||
|
||||
|
||||
def forward(self, video_fhw, txt_seq_lens, device):
|
||||
self._expand_pos_freqs_if_needed(video_fhw, txt_seq_lens)
|
||||
if self.pos_freqs.device != device:
|
||||
self.pos_freqs = self.pos_freqs.to(device)
|
||||
self.neg_freqs = self.neg_freqs.to(device)
|
||||
|
||||
vid_freqs = []
|
||||
max_vid_index = 0
|
||||
for idx, fhw in enumerate(video_fhw):
|
||||
frame, height, width = fhw
|
||||
rope_key = f"{idx}_{height}_{width}"
|
||||
|
||||
if rope_key not in self.rope_cache:
|
||||
seq_lens = frame * height * width
|
||||
freqs_pos = self.pos_freqs.split([x // 2 for x in self.axes_dim], dim=1)
|
||||
freqs_neg = self.neg_freqs.split([x // 2 for x in self.axes_dim], dim=1)
|
||||
freqs_frame = freqs_pos[0][idx : idx + frame].view(frame, 1, 1, -1).expand(frame, height, width, -1)
|
||||
if self.scale_rope:
|
||||
freqs_height = torch.cat(
|
||||
[freqs_neg[1][-(height - height // 2) :], freqs_pos[1][: height // 2]], dim=0
|
||||
)
|
||||
freqs_height = freqs_height.view(1, height, 1, -1).expand(frame, height, width, -1)
|
||||
freqs_width = torch.cat([freqs_neg[2][-(width - width // 2) :], freqs_pos[2][: width // 2]], dim=0)
|
||||
freqs_width = freqs_width.view(1, 1, width, -1).expand(frame, height, width, -1)
|
||||
|
||||
else:
|
||||
freqs_height = freqs_pos[1][:height].view(1, height, 1, -1).expand(frame, height, width, -1)
|
||||
freqs_width = freqs_pos[2][:width].view(1, 1, width, -1).expand(frame, height, width, -1)
|
||||
|
||||
freqs = torch.cat([freqs_frame, freqs_height, freqs_width], dim=-1).reshape(seq_lens, -1)
|
||||
self.rope_cache[rope_key] = freqs.clone().contiguous()
|
||||
vid_freqs.append(self.rope_cache[rope_key])
|
||||
|
||||
if self.scale_rope:
|
||||
max_vid_index = max(height // 2, width // 2, max_vid_index)
|
||||
else:
|
||||
max_vid_index = max(height, width, max_vid_index)
|
||||
|
||||
max_len = max(txt_seq_lens)
|
||||
txt_freqs = self.pos_freqs[max_vid_index : max_vid_index + max_len, ...]
|
||||
vid_freqs = torch.cat(vid_freqs, dim=0)
|
||||
|
||||
return vid_freqs, txt_freqs
|
||||
|
||||
|
||||
def forward_sampling(self, video_fhw, txt_seq_lens, device):
|
||||
self._expand_pos_freqs_if_needed(video_fhw, txt_seq_lens)
|
||||
if self.pos_freqs.device != device:
|
||||
self.pos_freqs = self.pos_freqs.to(device)
|
||||
self.neg_freqs = self.neg_freqs.to(device)
|
||||
|
||||
vid_freqs = []
|
||||
max_vid_index = 0
|
||||
for idx, fhw in enumerate(video_fhw):
|
||||
frame, height, width = fhw
|
||||
rope_key = f"{idx}_{height}_{width}"
|
||||
if idx > 0 and f"{0}_{height}_{width}" not in self.rope_cache:
|
||||
frame_0, height_0, width_0 = video_fhw[0]
|
||||
|
||||
rope_key_0 = f"0_{height_0}_{width_0}"
|
||||
spatial_freqs_0 = self.rope_cache[rope_key_0].reshape(frame_0, height_0, width_0, -1)
|
||||
h_indices = torch.linspace(0, height_0 - 1, height).long()
|
||||
w_indices = torch.linspace(0, width_0 - 1, width).long()
|
||||
h_grid, w_grid = torch.meshgrid(h_indices, w_indices, indexing='ij')
|
||||
sampled_rope = spatial_freqs_0[:, h_grid, w_grid, :]
|
||||
|
||||
freqs_pos = self.pos_freqs.split([x // 2 for x in self.axes_dim], dim=1)
|
||||
freqs_frame = freqs_pos[0][idx : idx + frame].view(frame, 1, 1, -1).expand(frame, height, width, -1)
|
||||
sampled_rope[:, :, :, :freqs_frame.shape[-1]] = freqs_frame
|
||||
|
||||
seq_lens = frame * height * width
|
||||
self.rope_cache[rope_key] = sampled_rope.reshape(seq_lens, -1).clone()
|
||||
if rope_key not in self.rope_cache:
|
||||
seq_lens = frame * height * width
|
||||
freqs_pos = self.pos_freqs.split([x // 2 for x in self.axes_dim], dim=1)
|
||||
freqs_neg = self.neg_freqs.split([x // 2 for x in self.axes_dim], dim=1)
|
||||
freqs_frame = freqs_pos[0][idx : idx + frame].view(frame, 1, 1, -1).expand(frame, height, width, -1)
|
||||
if self.scale_rope:
|
||||
freqs_height = torch.cat(
|
||||
[freqs_neg[1][-(height - height // 2) :], freqs_pos[1][: height // 2]], dim=0
|
||||
)
|
||||
freqs_height = freqs_height.view(1, height, 1, -1).expand(frame, height, width, -1)
|
||||
freqs_width = torch.cat([freqs_neg[2][-(width - width // 2) :], freqs_pos[2][: width // 2]], dim=0)
|
||||
freqs_width = freqs_width.view(1, 1, width, -1).expand(frame, height, width, -1)
|
||||
|
||||
else:
|
||||
freqs_height = freqs_pos[1][:height].view(1, height, 1, -1).expand(frame, height, width, -1)
|
||||
freqs_width = freqs_pos[2][:width].view(1, 1, width, -1).expand(frame, height, width, -1)
|
||||
|
||||
freqs = torch.cat([freqs_frame, freqs_height, freqs_width], dim=-1).reshape(seq_lens, -1)
|
||||
self.rope_cache[rope_key] = freqs.clone()
|
||||
vid_freqs.append(self.rope_cache[rope_key].contiguous())
|
||||
|
||||
if self.scale_rope:
|
||||
max_vid_index = max(height // 2, width // 2, max_vid_index)
|
||||
else:
|
||||
max_vid_index = max(height, width, max_vid_index)
|
||||
|
||||
max_len = max(txt_seq_lens)
|
||||
txt_freqs = self.pos_freqs[max_vid_index : max_vid_index + max_len, ...]
|
||||
vid_freqs = torch.cat(vid_freqs, dim=0)
|
||||
|
||||
return vid_freqs, txt_freqs
|
||||
|
||||
|
||||
class QwenEmbedLayer3DRope(nn.Module):
|
||||
def __init__(self, theta: int, axes_dim: List[int], scale_rope=False):
|
||||
super().__init__()
|
||||
self.theta = theta
|
||||
self.axes_dim = axes_dim
|
||||
pos_index = torch.arange(4096)
|
||||
neg_index = torch.arange(4096).flip(0) * -1 - 1
|
||||
self.pos_freqs = torch.cat(
|
||||
[
|
||||
self.rope_params(pos_index, self.axes_dim[0], self.theta),
|
||||
self.rope_params(pos_index, self.axes_dim[1], self.theta),
|
||||
self.rope_params(pos_index, self.axes_dim[2], self.theta),
|
||||
],
|
||||
dim=1,
|
||||
)
|
||||
self.neg_freqs = torch.cat(
|
||||
[
|
||||
self.rope_params(neg_index, self.axes_dim[0], self.theta),
|
||||
self.rope_params(neg_index, self.axes_dim[1], self.theta),
|
||||
self.rope_params(neg_index, self.axes_dim[2], self.theta),
|
||||
],
|
||||
dim=1,
|
||||
)
|
||||
|
||||
self.scale_rope = scale_rope
|
||||
|
||||
def rope_params(self, index, dim, theta=10000):
|
||||
"""
|
||||
Args:
|
||||
index: [0, 1, 2, 3] 1D Tensor representing the position index of the token
|
||||
"""
|
||||
assert dim % 2 == 0
|
||||
freqs = torch.outer(index, 1.0 / torch.pow(theta, torch.arange(0, dim, 2).to(torch.float32).div(dim)))
|
||||
freqs = torch.polar(torch.ones_like(freqs), freqs)
|
||||
return freqs
|
||||
|
||||
def forward(self, video_fhw, txt_seq_lens, device):
|
||||
"""
|
||||
Args: video_fhw: [frame, height, width] a list of 3 integers representing the shape of the video Args:
|
||||
txt_length: [bs] a list of 1 integers representing the length of the text
|
||||
"""
|
||||
if self.pos_freqs.device != device:
|
||||
self.pos_freqs = self.pos_freqs.to(device)
|
||||
self.neg_freqs = self.neg_freqs.to(device)
|
||||
|
||||
video_fhw = [video_fhw]
|
||||
if isinstance(video_fhw, list):
|
||||
video_fhw = video_fhw[0]
|
||||
if not isinstance(video_fhw, list):
|
||||
video_fhw = [video_fhw]
|
||||
|
||||
vid_freqs = []
|
||||
max_vid_index = 0
|
||||
layer_num = len(video_fhw) - 1
|
||||
for idx, fhw in enumerate(video_fhw):
|
||||
frame, height, width = fhw
|
||||
if idx != layer_num:
|
||||
video_freq = self._compute_video_freqs(frame, height, width, idx)
|
||||
else:
|
||||
### For the condition image, we set the layer index to -1
|
||||
video_freq = self._compute_condition_freqs(frame, height, width)
|
||||
video_freq = video_freq.to(device)
|
||||
vid_freqs.append(video_freq)
|
||||
|
||||
if self.scale_rope:
|
||||
max_vid_index = max(height // 2, width // 2, max_vid_index)
|
||||
else:
|
||||
max_vid_index = max(height, width, max_vid_index)
|
||||
|
||||
max_vid_index = max(max_vid_index, layer_num)
|
||||
max_len = max(txt_seq_lens)
|
||||
txt_freqs = self.pos_freqs[max_vid_index : max_vid_index + max_len, ...]
|
||||
vid_freqs = torch.cat(vid_freqs, dim=0)
|
||||
|
||||
return vid_freqs, txt_freqs
|
||||
|
||||
@functools.lru_cache(maxsize=None)
|
||||
def _compute_video_freqs(self, frame, height, width, idx=0):
|
||||
seq_lens = frame * height * width
|
||||
freqs_pos = self.pos_freqs.split([x // 2 for x in self.axes_dim], dim=1)
|
||||
freqs_neg = self.neg_freqs.split([x // 2 for x in self.axes_dim], dim=1)
|
||||
|
||||
freqs_frame = freqs_pos[0][idx : idx + frame].view(frame, 1, 1, -1).expand(frame, height, width, -1)
|
||||
if self.scale_rope:
|
||||
freqs_height = torch.cat([freqs_neg[1][-(height - height // 2) :], freqs_pos[1][: height // 2]], dim=0)
|
||||
freqs_height = freqs_height.view(1, height, 1, -1).expand(frame, height, width, -1)
|
||||
freqs_width = torch.cat([freqs_neg[2][-(width - width // 2) :], freqs_pos[2][: width // 2]], dim=0)
|
||||
freqs_width = freqs_width.view(1, 1, width, -1).expand(frame, height, width, -1)
|
||||
else:
|
||||
freqs_height = freqs_pos[1][:height].view(1, height, 1, -1).expand(frame, height, width, -1)
|
||||
freqs_width = freqs_pos[2][:width].view(1, 1, width, -1).expand(frame, height, width, -1)
|
||||
|
||||
freqs = torch.cat([freqs_frame, freqs_height, freqs_width], dim=-1).reshape(seq_lens, -1)
|
||||
return freqs.clone().contiguous()
|
||||
|
||||
@functools.lru_cache(maxsize=None)
|
||||
def _compute_condition_freqs(self, frame, height, width):
|
||||
seq_lens = frame * height * width
|
||||
freqs_pos = self.pos_freqs.split([x // 2 for x in self.axes_dim], dim=1)
|
||||
freqs_neg = self.neg_freqs.split([x // 2 for x in self.axes_dim], dim=1)
|
||||
|
||||
freqs_frame = freqs_neg[0][-1:].view(frame, 1, 1, -1).expand(frame, height, width, -1)
|
||||
if self.scale_rope:
|
||||
freqs_height = torch.cat([freqs_neg[1][-(height - height // 2) :], freqs_pos[1][: height // 2]], dim=0)
|
||||
freqs_height = freqs_height.view(1, height, 1, -1).expand(frame, height, width, -1)
|
||||
freqs_width = torch.cat([freqs_neg[2][-(width - width // 2) :], freqs_pos[2][: width // 2]], dim=0)
|
||||
freqs_width = freqs_width.view(1, 1, width, -1).expand(frame, height, width, -1)
|
||||
else:
|
||||
freqs_height = freqs_pos[1][:height].view(1, height, 1, -1).expand(frame, height, width, -1)
|
||||
freqs_width = freqs_pos[2][:width].view(1, 1, width, -1).expand(frame, height, width, -1)
|
||||
|
||||
freqs = torch.cat([freqs_frame, freqs_height, freqs_width], dim=-1).reshape(seq_lens, -1)
|
||||
return freqs.clone().contiguous()
|
||||
|
||||
|
||||
class QwenFeedForward(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
dim_out: Optional[int] = None,
|
||||
dropout: float = 0.0,
|
||||
):
|
||||
super().__init__()
|
||||
inner_dim = int(dim * 4)
|
||||
self.net = nn.ModuleList([])
|
||||
self.net.append(ApproximateGELU(dim, inner_dim))
|
||||
self.net.append(nn.Dropout(dropout))
|
||||
self.net.append(nn.Linear(inner_dim, dim_out))
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor, *args, **kwargs) -> torch.Tensor:
|
||||
for module in self.net:
|
||||
hidden_states = module(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
class QwenDoubleStreamAttention(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim_a,
|
||||
dim_b,
|
||||
num_heads,
|
||||
head_dim,
|
||||
):
|
||||
super().__init__()
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = head_dim
|
||||
|
||||
self.to_q = nn.Linear(dim_a, dim_a)
|
||||
self.to_k = nn.Linear(dim_a, dim_a)
|
||||
self.to_v = nn.Linear(dim_a, dim_a)
|
||||
self.norm_q = RMSNorm(head_dim, eps=1e-6)
|
||||
self.norm_k = RMSNorm(head_dim, eps=1e-6)
|
||||
|
||||
self.add_q_proj = nn.Linear(dim_b, dim_b)
|
||||
self.add_k_proj = nn.Linear(dim_b, dim_b)
|
||||
self.add_v_proj = nn.Linear(dim_b, dim_b)
|
||||
self.norm_added_q = RMSNorm(head_dim, eps=1e-6)
|
||||
self.norm_added_k = RMSNorm(head_dim, eps=1e-6)
|
||||
|
||||
self.to_out = torch.nn.Sequential(nn.Linear(dim_a, dim_a))
|
||||
self.to_add_out = nn.Linear(dim_b, dim_b)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
image: torch.FloatTensor,
|
||||
text: torch.FloatTensor,
|
||||
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
attention_mask: Optional[torch.FloatTensor] = None,
|
||||
enable_fp8_attention: bool = False,
|
||||
) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
|
||||
img_q, img_k, img_v = self.to_q(image), self.to_k(image), self.to_v(image)
|
||||
txt_q, txt_k, txt_v = self.add_q_proj(text), self.add_k_proj(text), self.add_v_proj(text)
|
||||
seq_txt = txt_q.shape[1]
|
||||
|
||||
img_q = rearrange(img_q, 'b s (h d) -> b h s d', h=self.num_heads)
|
||||
img_k = rearrange(img_k, 'b s (h d) -> b h s d', h=self.num_heads)
|
||||
img_v = rearrange(img_v, 'b s (h d) -> b h s d', h=self.num_heads)
|
||||
|
||||
txt_q = rearrange(txt_q, 'b s (h d) -> b h s d', h=self.num_heads)
|
||||
txt_k = rearrange(txt_k, 'b s (h d) -> b h s d', h=self.num_heads)
|
||||
txt_v = rearrange(txt_v, 'b s (h d) -> b h s d', h=self.num_heads)
|
||||
|
||||
img_q, img_k = self.norm_q(img_q), self.norm_k(img_k)
|
||||
txt_q, txt_k = self.norm_added_q(txt_q), self.norm_added_k(txt_k)
|
||||
|
||||
if image_rotary_emb is not None:
|
||||
img_freqs, txt_freqs = image_rotary_emb
|
||||
img_q = apply_rotary_emb_qwen(img_q, img_freqs)
|
||||
img_k = apply_rotary_emb_qwen(img_k, img_freqs)
|
||||
txt_q = apply_rotary_emb_qwen(txt_q, txt_freqs)
|
||||
txt_k = apply_rotary_emb_qwen(txt_k, txt_freqs)
|
||||
|
||||
joint_q = torch.cat([txt_q, img_q], dim=2)
|
||||
joint_k = torch.cat([txt_k, img_k], dim=2)
|
||||
joint_v = torch.cat([txt_v, img_v], dim=2)
|
||||
|
||||
joint_attn_out = qwen_image_flash_attention(joint_q, joint_k, joint_v, num_heads=joint_q.shape[1], attention_mask=attention_mask, enable_fp8_attention=enable_fp8_attention).to(joint_q.dtype)
|
||||
|
||||
txt_attn_output = joint_attn_out[:, :seq_txt, :]
|
||||
img_attn_output = joint_attn_out[:, seq_txt:, :]
|
||||
|
||||
img_attn_output = self.to_out(img_attn_output)
|
||||
txt_attn_output = self.to_add_out(txt_attn_output)
|
||||
|
||||
return img_attn_output, txt_attn_output
|
||||
|
||||
|
||||
class QwenImageTransformerBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
num_attention_heads: int,
|
||||
attention_head_dim: int,
|
||||
eps: float = 1e-6,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.dim = dim
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.attention_head_dim = attention_head_dim
|
||||
|
||||
self.img_mod = nn.Sequential(
|
||||
nn.SiLU(),
|
||||
nn.Linear(dim, 6 * dim),
|
||||
)
|
||||
self.img_norm1 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps)
|
||||
self.attn = QwenDoubleStreamAttention(
|
||||
dim_a=dim,
|
||||
dim_b=dim,
|
||||
num_heads=num_attention_heads,
|
||||
head_dim=attention_head_dim,
|
||||
)
|
||||
self.img_norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps)
|
||||
self.img_mlp = QwenFeedForward(dim=dim, dim_out=dim)
|
||||
|
||||
self.txt_mod = nn.Sequential(
|
||||
nn.SiLU(),
|
||||
nn.Linear(dim, 6 * dim, bias=True),
|
||||
)
|
||||
self.txt_norm1 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps)
|
||||
self.txt_norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps)
|
||||
self.txt_mlp = QwenFeedForward(dim=dim, dim_out=dim)
|
||||
|
||||
def _modulate(self, x, mod_params, index=None):
|
||||
shift, scale, gate = mod_params.chunk(3, dim=-1)
|
||||
if index is not None:
|
||||
# Assuming mod_params batch dim is 2*actual_batch (chunked into 2 parts)
|
||||
# So shift, scale, gate have shape [2*actual_batch, d]
|
||||
actual_batch = shift.size(0) // 2
|
||||
shift_0, shift_1 = shift[:actual_batch], shift[actual_batch:] # each: [actual_batch, d]
|
||||
scale_0, scale_1 = scale[:actual_batch], scale[actual_batch:]
|
||||
gate_0, gate_1 = gate[:actual_batch], gate[actual_batch:]
|
||||
|
||||
# index: [b, l] where b is actual batch size
|
||||
# Expand to [b, l, 1] to match feature dimension
|
||||
index_expanded = index.unsqueeze(-1) # [b, l, 1]
|
||||
|
||||
# Expand chunks to [b, 1, d] then broadcast to [b, l, d]
|
||||
shift_0_exp = shift_0.unsqueeze(1) # [b, 1, d]
|
||||
shift_1_exp = shift_1.unsqueeze(1) # [b, 1, d]
|
||||
scale_0_exp = scale_0.unsqueeze(1)
|
||||
scale_1_exp = scale_1.unsqueeze(1)
|
||||
gate_0_exp = gate_0.unsqueeze(1)
|
||||
gate_1_exp = gate_1.unsqueeze(1)
|
||||
|
||||
# Use torch.where to select based on index
|
||||
shift_result = torch.where(index_expanded == 0, shift_0_exp, shift_1_exp)
|
||||
scale_result = torch.where(index_expanded == 0, scale_0_exp, scale_1_exp)
|
||||
gate_result = torch.where(index_expanded == 0, gate_0_exp, gate_1_exp)
|
||||
else:
|
||||
shift_result = shift.unsqueeze(1)
|
||||
scale_result = scale.unsqueeze(1)
|
||||
gate_result = gate.unsqueeze(1)
|
||||
|
||||
return x * (1 + scale_result) + shift_result, gate_result
|
||||
|
||||
def forward(
|
||||
self,
|
||||
image: torch.Tensor,
|
||||
text: torch.Tensor,
|
||||
temb: torch.Tensor,
|
||||
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
enable_fp8_attention = False,
|
||||
modulate_index: Optional[List[int]] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
|
||||
img_mod_attn, img_mod_mlp = self.img_mod(temb).chunk(2, dim=-1) # [B, 3*dim] each
|
||||
if modulate_index is not None:
|
||||
temb = torch.chunk(temb, 2, dim=0)[0]
|
||||
txt_mod_attn, txt_mod_mlp = self.txt_mod(temb).chunk(2, dim=-1) # [B, 3*dim] each
|
||||
|
||||
img_normed = self.img_norm1(image)
|
||||
img_modulated, img_gate = self._modulate(img_normed, img_mod_attn, index=modulate_index)
|
||||
|
||||
txt_normed = self.txt_norm1(text)
|
||||
txt_modulated, txt_gate = self._modulate(txt_normed, txt_mod_attn)
|
||||
|
||||
img_attn_out, txt_attn_out = self.attn(
|
||||
image=img_modulated,
|
||||
text=txt_modulated,
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
attention_mask=attention_mask,
|
||||
enable_fp8_attention=enable_fp8_attention,
|
||||
)
|
||||
|
||||
image = image + img_gate * img_attn_out
|
||||
text = text + txt_gate * txt_attn_out
|
||||
|
||||
img_normed_2 = self.img_norm2(image)
|
||||
img_modulated_2, img_gate_2 = self._modulate(img_normed_2, img_mod_mlp, index=modulate_index)
|
||||
|
||||
txt_normed_2 = self.txt_norm2(text)
|
||||
txt_modulated_2, txt_gate_2 = self._modulate(txt_normed_2, txt_mod_mlp)
|
||||
|
||||
img_mlp_out = self.img_mlp(img_modulated_2)
|
||||
txt_mlp_out = self.txt_mlp(txt_modulated_2)
|
||||
|
||||
image = image + img_gate_2 * img_mlp_out
|
||||
text = text + txt_gate_2 * txt_mlp_out
|
||||
|
||||
return text, image
|
||||
|
||||
|
||||
class QwenImageDiT(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
num_layers: int = 60,
|
||||
use_layer3d_rope: bool = False,
|
||||
use_additional_t_cond: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
if not use_layer3d_rope:
|
||||
self.pos_embed = QwenEmbedRope(theta=10000, axes_dim=[16,56,56], scale_rope=True)
|
||||
else:
|
||||
self.pos_embed = QwenEmbedLayer3DRope(theta=10000, axes_dim=[16,56,56], scale_rope=True)
|
||||
|
||||
self.time_text_embed = TimestepEmbeddings(256, 3072, diffusers_compatible_format=True, scale=1000, align_dtype_to_timestep=False, use_additional_t_cond=use_additional_t_cond)
|
||||
self.txt_norm = RMSNorm(3584, eps=1e-6)
|
||||
|
||||
self.img_in = nn.Linear(64, 3072)
|
||||
self.txt_in = nn.Linear(3584, 3072)
|
||||
|
||||
self.transformer_blocks = nn.ModuleList(
|
||||
[
|
||||
QwenImageTransformerBlock(
|
||||
dim=3072,
|
||||
num_attention_heads=24,
|
||||
attention_head_dim=128,
|
||||
)
|
||||
for _ in range(num_layers)
|
||||
]
|
||||
)
|
||||
self.norm_out = AdaLayerNorm(3072, single=True)
|
||||
self.proj_out = nn.Linear(3072, 64)
|
||||
|
||||
|
||||
def process_entity_masks(self, latents, prompt_emb, prompt_emb_mask, entity_prompt_emb, entity_prompt_emb_mask, entity_masks, height, width, image, img_shapes):
|
||||
# prompt_emb
|
||||
all_prompt_emb = entity_prompt_emb + [prompt_emb]
|
||||
all_prompt_emb = [self.txt_in(self.txt_norm(local_prompt_emb)) for local_prompt_emb in all_prompt_emb]
|
||||
all_prompt_emb = torch.cat(all_prompt_emb, dim=1)
|
||||
|
||||
# image_rotary_emb
|
||||
txt_seq_lens = prompt_emb_mask.sum(dim=1).tolist()
|
||||
image_rotary_emb = self.pos_embed(img_shapes, txt_seq_lens, device=latents.device)
|
||||
entity_seq_lens = [emb_mask.sum(dim=1).tolist() for emb_mask in entity_prompt_emb_mask]
|
||||
entity_rotary_emb = [self.pos_embed(img_shapes, entity_seq_len, device=latents.device)[1] for entity_seq_len in entity_seq_lens]
|
||||
txt_rotary_emb = torch.cat(entity_rotary_emb + [image_rotary_emb[1]], dim=0)
|
||||
image_rotary_emb = (image_rotary_emb[0], txt_rotary_emb)
|
||||
|
||||
# attention_mask
|
||||
repeat_dim = latents.shape[1]
|
||||
max_masks = entity_masks.shape[1]
|
||||
entity_masks = entity_masks.repeat(1, 1, repeat_dim, 1, 1)
|
||||
entity_masks = [entity_masks[:, i, None].squeeze(1) for i in range(max_masks)]
|
||||
global_mask = torch.ones_like(entity_masks[0]).to(device=latents.device, dtype=latents.dtype)
|
||||
entity_masks = entity_masks + [global_mask]
|
||||
|
||||
N = len(entity_masks)
|
||||
batch_size = entity_masks[0].shape[0]
|
||||
seq_lens = [mask_.sum(dim=1).item() for mask_ in entity_prompt_emb_mask] + [prompt_emb_mask.sum(dim=1).item()]
|
||||
total_seq_len = sum(seq_lens) + image.shape[1]
|
||||
patched_masks = []
|
||||
for i in range(N):
|
||||
patched_mask = rearrange(entity_masks[i], "B C (H P) (W Q) -> B (H W) (C P Q)", H=height//16, W=width//16, P=2, Q=2)
|
||||
patched_masks.append(patched_mask)
|
||||
attention_mask = torch.ones((batch_size, total_seq_len, total_seq_len), dtype=torch.bool).to(device=entity_masks[0].device)
|
||||
|
||||
# prompt-image attention mask
|
||||
image_start = sum(seq_lens)
|
||||
image_end = total_seq_len
|
||||
cumsum = [0]
|
||||
single_image_seq = image_end - image_start
|
||||
for length in seq_lens:
|
||||
cumsum.append(cumsum[-1] + length)
|
||||
for i in range(N):
|
||||
prompt_start = cumsum[i]
|
||||
prompt_end = cumsum[i+1]
|
||||
image_mask = torch.sum(patched_masks[i], dim=-1) > 0
|
||||
image_mask = image_mask.unsqueeze(1).repeat(1, seq_lens[i], 1)
|
||||
# repeat image mask to match the single image sequence length
|
||||
repeat_time = single_image_seq // image_mask.shape[-1]
|
||||
image_mask = image_mask.repeat(1, 1, repeat_time)
|
||||
# prompt update with image
|
||||
attention_mask[:, prompt_start:prompt_end, image_start:image_end] = image_mask
|
||||
# image update with prompt
|
||||
attention_mask[:, image_start:image_end, prompt_start:prompt_end] = image_mask.transpose(1, 2)
|
||||
# prompt-prompt attention mask, let the prompt tokens not attend to each other
|
||||
for i in range(N):
|
||||
for j in range(N):
|
||||
if i == j:
|
||||
continue
|
||||
start_i, end_i = cumsum[i], cumsum[i+1]
|
||||
start_j, end_j = cumsum[j], cumsum[j+1]
|
||||
attention_mask[:, start_i:end_i, start_j:end_j] = False
|
||||
|
||||
attention_mask = attention_mask.float()
|
||||
attention_mask[attention_mask == 0] = float('-inf')
|
||||
attention_mask[attention_mask == 1] = 0
|
||||
attention_mask = attention_mask.to(device=latents.device, dtype=latents.dtype).unsqueeze(1)
|
||||
|
||||
return all_prompt_emb, image_rotary_emb, attention_mask
|
||||
|
||||
|
||||
def forward(
|
||||
self,
|
||||
latents=None,
|
||||
timestep=None,
|
||||
prompt_emb=None,
|
||||
prompt_emb_mask=None,
|
||||
height=None,
|
||||
width=None,
|
||||
):
|
||||
img_shapes = [(latents.shape[0], latents.shape[2]//2, latents.shape[3]//2)]
|
||||
txt_seq_lens = prompt_emb_mask.sum(dim=1).tolist()
|
||||
|
||||
image = rearrange(latents, "B C (H P) (W Q) -> B (H W) (C P Q)", H=height//16, W=width//16, P=2, Q=2)
|
||||
image = self.img_in(image)
|
||||
text = self.txt_in(self.txt_norm(prompt_emb))
|
||||
|
||||
conditioning = self.time_text_embed(timestep, image.dtype)
|
||||
|
||||
image_rotary_emb = self.pos_embed(img_shapes, txt_seq_lens, device=latents.device)
|
||||
|
||||
for block in self.transformer_blocks:
|
||||
text, image = block(
|
||||
image=image,
|
||||
text=text,
|
||||
temb=conditioning,
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
)
|
||||
|
||||
image = self.norm_out(image, conditioning)
|
||||
image = self.proj_out(image)
|
||||
|
||||
latents = rearrange(image, "B (H W) (C P Q) -> B C (H P) (W Q)", H=height//16, W=width//16, P=2, Q=2)
|
||||
return image
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user