Compare commits

..

5 Commits

Author SHA1 Message Date
mi804
2cefc20ed6 wanx tiled encode 2025-02-21 12:58:45 +08:00
mi804
02a4c8df9f wanx vae tile decode 2025-02-21 11:27:30 +08:00
mi804
582e33ad51 save_video 2025-02-20 17:57:38 +08:00
mi804
491bbf5369 support wanxvae 2025-02-20 17:44:20 +08:00
mi804
0c92f3b2cc support wanx prompter 2025-02-20 16:08:22 +08:00
999 changed files with 2202373 additions and 67906 deletions

Binary file not shown.

Before

Width:  |  Height:  |  Size: 146 KiB

View File

@@ -20,9 +20,9 @@ jobs:
with:
python-version: '3.10'
- name: Install wheel
run: pip install wheel==0.44.0 && pip install -r requirements.txt
run: pip install wheel && pip install -r requirements.txt
- name: Build DiffSynth
run: python -m build
run: python setup.py sdist bdist_wheel
- name: Publish package to PyPI
run: |
pip install twine

175
.gitignore vendored
View File

@@ -1,175 +0,0 @@
/data
/models
/scripts
/diffusers
*.pkl
*.safetensors
*.pth
*.ckpt
*.pt
*.bin
*.DS_Store
*.msc
*.mv
log*.txt
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
.pybuilder/
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock
# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# pytype static type analyzer
.pytype/
# Cython debug symbols
cython_debug/
# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/

1078
README.md

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,252 @@
import gradio as gr
from diffsynth import ModelManager, SDImagePipeline, SDXLImagePipeline, SD3ImagePipeline, HunyuanDiTImagePipeline, FluxImagePipeline
import os, torch
from PIL import Image
import numpy as np
config = {
"model_config": {
"Stable Diffusion": {
"model_folder": "models/stable_diffusion",
"pipeline_class": SDImagePipeline,
"default_parameters": {
"cfg_scale": 7.0,
"height": 512,
"width": 512,
}
},
"Stable Diffusion XL": {
"model_folder": "models/stable_diffusion_xl",
"pipeline_class": SDXLImagePipeline,
"default_parameters": {
"cfg_scale": 7.0,
}
},
"Stable Diffusion 3": {
"model_folder": "models/stable_diffusion_3",
"pipeline_class": SD3ImagePipeline,
"default_parameters": {
"cfg_scale": 7.0,
}
},
"Stable Diffusion XL Turbo": {
"model_folder": "models/stable_diffusion_xl_turbo",
"pipeline_class": SDXLImagePipeline,
"default_parameters": {
"negative_prompt": "",
"cfg_scale": 1.0,
"num_inference_steps": 1,
"height": 512,
"width": 512,
}
},
"Kolors": {
"model_folder": "models/kolors",
"pipeline_class": SDXLImagePipeline,
"default_parameters": {
"cfg_scale": 7.0,
}
},
"HunyuanDiT": {
"model_folder": "models/HunyuanDiT",
"pipeline_class": HunyuanDiTImagePipeline,
"default_parameters": {
"cfg_scale": 7.0,
}
},
"FLUX": {
"model_folder": "models/FLUX",
"pipeline_class": FluxImagePipeline,
"default_parameters": {
"cfg_scale": 1.0,
}
}
},
"max_num_painter_layers": 8,
"max_num_model_cache": 1,
}
def load_model_list(model_type):
if model_type is None:
return []
folder = config["model_config"][model_type]["model_folder"]
file_list = [i for i in os.listdir(folder) if i.endswith(".safetensors")]
if model_type in ["HunyuanDiT", "Kolors", "FLUX"]:
file_list += [i for i in os.listdir(folder) if os.path.isdir(os.path.join(folder, i))]
file_list = sorted(file_list)
return file_list
def load_model(model_type, model_path):
global model_dict
model_key = f"{model_type}:{model_path}"
if model_key in model_dict:
return model_dict[model_key]
model_path = os.path.join(config["model_config"][model_type]["model_folder"], model_path)
model_manager = ModelManager()
if model_type == "HunyuanDiT":
model_manager.load_models([
os.path.join(model_path, "clip_text_encoder/pytorch_model.bin"),
os.path.join(model_path, "mt5/pytorch_model.bin"),
os.path.join(model_path, "model/pytorch_model_ema.pt"),
os.path.join(model_path, "sdxl-vae-fp16-fix/diffusion_pytorch_model.bin"),
])
elif model_type == "Kolors":
model_manager.load_models([
os.path.join(model_path, "text_encoder"),
os.path.join(model_path, "unet/diffusion_pytorch_model.safetensors"),
os.path.join(model_path, "vae/diffusion_pytorch_model.safetensors"),
])
elif model_type == "FLUX":
model_manager.torch_dtype = torch.bfloat16
file_list = [
os.path.join(model_path, "text_encoder/model.safetensors"),
os.path.join(model_path, "text_encoder_2"),
]
for file_name in os.listdir(model_path):
if file_name.endswith(".safetensors"):
file_list.append(os.path.join(model_path, file_name))
model_manager.load_models(file_list)
else:
model_manager.load_model(model_path)
pipe = config["model_config"][model_type]["pipeline_class"].from_model_manager(model_manager)
while len(model_dict) + 1 > config["max_num_model_cache"]:
key = next(iter(model_dict.keys()))
model_manager_to_release, _ = model_dict[key]
model_manager_to_release.to("cpu")
del model_dict[key]
torch.cuda.empty_cache()
model_dict[model_key] = model_manager, pipe
return model_manager, pipe
model_dict = {}
with gr.Blocks() as app:
gr.Markdown("# DiffSynth-Studio Painter")
with gr.Row():
with gr.Column(scale=382, min_width=100):
with gr.Accordion(label="Model"):
model_type = gr.Dropdown(choices=[i for i in config["model_config"]], label="Model type")
model_path = gr.Dropdown(choices=[], interactive=True, label="Model path")
@gr.on(inputs=model_type, outputs=model_path, triggers=model_type.change)
def model_type_to_model_path(model_type):
return gr.Dropdown(choices=load_model_list(model_type))
with gr.Accordion(label="Prompt"):
prompt = gr.Textbox(label="Prompt", lines=3)
negative_prompt = gr.Textbox(label="Negative prompt", lines=1)
cfg_scale = gr.Slider(minimum=1.0, maximum=10.0, value=7.0, step=0.1, interactive=True, label="Classifier-free guidance scale")
embedded_guidance = gr.Slider(minimum=0.0, maximum=10.0, value=0.0, step=0.1, interactive=True, label="Embedded guidance scale (only for FLUX)")
with gr.Accordion(label="Image"):
num_inference_steps = gr.Slider(minimum=1, maximum=100, value=20, step=1, interactive=True, label="Inference steps")
height = gr.Slider(minimum=64, maximum=2048, value=1024, step=64, interactive=True, label="Height")
width = gr.Slider(minimum=64, maximum=2048, value=1024, step=64, interactive=True, label="Width")
with gr.Column():
use_fixed_seed = gr.Checkbox(value=True, interactive=False, label="Use fixed seed")
seed = gr.Number(minimum=0, maximum=10**9, value=0, interactive=True, label="Random seed", show_label=False)
@gr.on(
inputs=[model_type, model_path, prompt, negative_prompt, cfg_scale, embedded_guidance, num_inference_steps, height, width],
outputs=[prompt, negative_prompt, cfg_scale, embedded_guidance, num_inference_steps, height, width],
triggers=model_path.change
)
def model_path_to_default_params(model_type, model_path, prompt, negative_prompt, cfg_scale, embedded_guidance, num_inference_steps, height, width):
load_model(model_type, model_path)
cfg_scale = config["model_config"][model_type]["default_parameters"].get("cfg_scale", cfg_scale)
embedded_guidance = config["model_config"][model_type]["default_parameters"].get("embedded_guidance", embedded_guidance)
num_inference_steps = config["model_config"][model_type]["default_parameters"].get("num_inference_steps", num_inference_steps)
height = config["model_config"][model_type]["default_parameters"].get("height", height)
width = config["model_config"][model_type]["default_parameters"].get("width", width)
return prompt, negative_prompt, cfg_scale, embedded_guidance, num_inference_steps, height, width
with gr.Column(scale=618, min_width=100):
with gr.Accordion(label="Painter"):
enable_local_prompt_list = []
local_prompt_list = []
mask_scale_list = []
canvas_list = []
for painter_layer_id in range(config["max_num_painter_layers"]):
with gr.Tab(label=f"Layer {painter_layer_id}"):
enable_local_prompt = gr.Checkbox(label="Enable", value=False, key=f"enable_local_prompt_{painter_layer_id}")
local_prompt = gr.Textbox(label="Local prompt", key=f"local_prompt_{painter_layer_id}")
mask_scale = gr.Slider(minimum=0.0, maximum=5.0, value=1.0, step=0.1, interactive=True, label="Mask scale", key=f"mask_scale_{painter_layer_id}")
canvas = gr.ImageEditor(canvas_size=(512, 1), sources=None, layers=False, interactive=True, image_mode="RGBA",
brush=gr.Brush(default_size=100, default_color="#000000", colors=["#000000"]),
label="Painter", key=f"canvas_{painter_layer_id}")
@gr.on(inputs=[height, width, canvas], outputs=canvas, triggers=[height.change, width.change, canvas.clear, enable_local_prompt.change], show_progress="hidden")
def resize_canvas(height, width, canvas):
h, w = canvas["background"].shape[:2]
if h != height or width != w:
return np.ones((height, width, 3), dtype=np.uint8) * 255
else:
return canvas
enable_local_prompt_list.append(enable_local_prompt)
local_prompt_list.append(local_prompt)
mask_scale_list.append(mask_scale)
canvas_list.append(canvas)
with gr.Accordion(label="Results"):
run_button = gr.Button(value="Generate", variant="primary")
output_image = gr.Image(sources=None, show_label=False, interactive=False, type="pil")
with gr.Row():
with gr.Column():
output_to_painter_button = gr.Button(value="Set as painter's background")
with gr.Column():
output_to_input_button = gr.Button(value="Set as input image")
painter_background = gr.State(None)
input_background = gr.State(None)
@gr.on(
inputs=[model_type, model_path, prompt, negative_prompt, cfg_scale, embedded_guidance, num_inference_steps, height, width, seed] + enable_local_prompt_list + local_prompt_list + mask_scale_list + canvas_list,
outputs=[output_image],
triggers=run_button.click
)
def generate_image(model_type, model_path, prompt, negative_prompt, cfg_scale, embedded_guidance, num_inference_steps, height, width, seed, *args, progress=gr.Progress()):
_, pipe = load_model(model_type, model_path)
input_params = {
"prompt": prompt,
"negative_prompt": negative_prompt,
"cfg_scale": cfg_scale,
"num_inference_steps": num_inference_steps,
"height": height,
"width": width,
"progress_bar_cmd": progress.tqdm,
}
if isinstance(pipe, FluxImagePipeline):
input_params["embedded_guidance"] = embedded_guidance
enable_local_prompt_list, local_prompt_list, mask_scale_list, canvas_list = (
args[0 * config["max_num_painter_layers"]: 1 * config["max_num_painter_layers"]],
args[1 * config["max_num_painter_layers"]: 2 * config["max_num_painter_layers"]],
args[2 * config["max_num_painter_layers"]: 3 * config["max_num_painter_layers"]],
args[3 * config["max_num_painter_layers"]: 4 * config["max_num_painter_layers"]]
)
local_prompts, masks, mask_scales = [], [], []
for enable_local_prompt, local_prompt, mask_scale, canvas in zip(
enable_local_prompt_list, local_prompt_list, mask_scale_list, canvas_list
):
if enable_local_prompt:
local_prompts.append(local_prompt)
masks.append(Image.fromarray(canvas["layers"][0][:, :, -1]).convert("RGB"))
mask_scales.append(mask_scale)
input_params.update({
"local_prompts": local_prompts,
"masks": masks,
"mask_scales": mask_scales,
})
torch.manual_seed(seed)
image = pipe(**input_params)
return image
@gr.on(inputs=[output_image] + canvas_list, outputs=canvas_list, triggers=output_to_painter_button.click)
def send_output_to_painter_background(output_image, *canvas_list):
for canvas in canvas_list:
h, w = canvas["background"].shape[:2]
canvas["background"] = output_image.resize((w, h))
return tuple(canvas_list)
app.launch()

View File

@@ -0,0 +1,390 @@
import os
import torch
import numpy as np
from PIL import Image, ImageDraw, ImageFont
import random
import json
import gradio as gr
from diffsynth import ModelManager, FluxImagePipeline, download_customized_models
from modelscope import dataset_snapshot_download
dataset_snapshot_download(dataset_id="DiffSynth-Studio/examples_in_diffsynth", local_dir="./", allow_file_pattern=f"data/examples/eligen/entity_control/*")
example_json = 'data/examples/eligen/entity_control/ui_examples.json'
with open(example_json, 'r') as f:
examples = json.load(f)['examples']
for idx in range(len(examples)):
example_id = examples[idx]['example_id']
entity_prompts = examples[idx]['local_prompt_list']
examples[idx]['mask_lists'] = [Image.open(f"data/examples/eligen/entity_control/example_{example_id}/{i}.png").convert('RGB') for i in range(len(entity_prompts))]
def create_canvas_data(background, masks):
if background.shape[-1] == 3:
background = np.dstack([background, np.full(background.shape[:2], 255, dtype=np.uint8)])
layers = []
for mask in masks:
if mask is not None:
mask_single_channel = mask if mask.ndim == 2 else mask[..., 0]
layer = np.zeros((mask_single_channel.shape[0], mask_single_channel.shape[1], 4), dtype=np.uint8)
layer[..., -1] = mask_single_channel
layers.append(layer)
else:
layers.append(np.zeros_like(background))
composite = background.copy()
for layer in layers:
if layer.size > 0:
composite = np.where(layer[..., -1:] > 0, layer, composite)
return {
"background": background,
"layers": layers,
"composite": composite,
}
def load_example(load_example_button):
example_idx = int(load_example_button.split()[-1]) - 1
example = examples[example_idx]
result = [
50,
example["global_prompt"],
example["negative_prompt"],
example["seed"],
*example["local_prompt_list"],
]
num_entities = len(example["local_prompt_list"])
result += [""] * (config["max_num_painter_layers"] - num_entities)
masks = []
for mask in example["mask_lists"]:
mask_single_channel = np.array(mask.convert("L"))
masks.append(mask_single_channel)
for _ in range(config["max_num_painter_layers"] - len(masks)):
blank_mask = np.zeros_like(masks[0]) if masks else np.zeros((512, 512), dtype=np.uint8)
masks.append(blank_mask)
background = np.ones((masks[0].shape[0], masks[0].shape[1], 4), dtype=np.uint8) * 255
canvas_data_list = []
for mask in masks:
canvas_data = create_canvas_data(background, [mask])
canvas_data_list.append(canvas_data)
result.extend(canvas_data_list)
return result
def save_mask_prompts(masks, mask_prompts, global_prompt, seed=0, random_dir='0000000'):
save_dir = os.path.join('workdirs/tmp_mask', random_dir)
print(f'save to {save_dir}')
os.makedirs(save_dir, exist_ok=True)
for i, mask in enumerate(masks):
save_path = os.path.join(save_dir, f'{i}.png')
mask.save(save_path)
sample = {
"global_prompt": global_prompt,
"mask_prompts": mask_prompts,
"seed": seed,
}
with open(os.path.join(save_dir, f"prompts.json"), 'w') as f:
json.dump(sample, f, indent=4)
def visualize_masks(image, masks, mask_prompts, font_size=35, use_random_colors=False):
# Create a blank image for overlays
overlay = Image.new('RGBA', image.size, (0, 0, 0, 0))
colors = [
(165, 238, 173, 80),
(76, 102, 221, 80),
(221, 160, 77, 80),
(204, 93, 71, 80),
(145, 187, 149, 80),
(134, 141, 172, 80),
(157, 137, 109, 80),
(153, 104, 95, 80),
(165, 238, 173, 80),
(76, 102, 221, 80),
(221, 160, 77, 80),
(204, 93, 71, 80),
(145, 187, 149, 80),
(134, 141, 172, 80),
(157, 137, 109, 80),
(153, 104, 95, 80),
]
# Generate random colors for each mask
if use_random_colors:
colors = [(random.randint(0, 255), random.randint(0, 255), random.randint(0, 255), 80) for _ in range(len(masks))]
# Font settings
try:
font = ImageFont.truetype("arial", font_size) # Adjust as needed
except IOError:
font = ImageFont.load_default(font_size)
# Overlay each mask onto the overlay image
for mask, mask_prompt, color in zip(masks, mask_prompts, colors):
if mask is None:
continue
# Convert mask to RGBA mode
mask_rgba = mask.convert('RGBA')
mask_data = mask_rgba.getdata()
new_data = [(color if item[:3] == (255, 255, 255) else (0, 0, 0, 0)) for item in mask_data]
mask_rgba.putdata(new_data)
# Draw the mask prompt text on the mask
draw = ImageDraw.Draw(mask_rgba)
mask_bbox = mask.getbbox() # Get the bounding box of the mask
if mask_bbox is None:
continue
text_position = (mask_bbox[0] + 10, mask_bbox[1] + 10) # Adjust text position based on mask position
draw.text(text_position, mask_prompt, fill=(255, 255, 255, 255), font=font)
# Alpha composite the overlay with this mask
overlay = Image.alpha_composite(overlay, mask_rgba)
# Composite the overlay onto the original image
result = Image.alpha_composite(image.convert('RGBA'), overlay)
return result
config = {
"model_config": {
"FLUX": {
"model_folder": "models/FLUX",
"pipeline_class": FluxImagePipeline,
"default_parameters": {
"cfg_scale": 3.0,
"embedded_guidance": 3.5,
"num_inference_steps": 30,
}
},
},
"max_num_painter_layers": 8,
"max_num_model_cache": 1,
}
model_dict = {}
def load_model(model_type='FLUX', model_path='FLUX.1-dev'):
global model_dict
model_key = f"{model_type}:{model_path}"
if model_key in model_dict:
return model_dict[model_key]
model_path = os.path.join(config["model_config"][model_type]["model_folder"], model_path)
model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cuda", model_id_list=["FLUX.1-dev"])
model_manager.load_lora(
download_customized_models(
model_id="DiffSynth-Studio/Eligen",
origin_file_path="model_bf16.safetensors",
local_dir="models/lora/entity_control",
),
lora_alpha=1,
)
pipe = config["model_config"][model_type]["pipeline_class"].from_model_manager(model_manager)
model_dict[model_key] = model_manager, pipe
return model_manager, pipe
with gr.Blocks() as app:
gr.Markdown(
"""## EliGen: Entity-Level Controllable Text-to-Image Model
1. On the left, input the **global prompt** for the overall image, such as "a person stands by the river."
2. On the right, input the **local prompt** for each entity, such as "person," and draw the corresponding mask in the **Entity Mask Painter**. Generally, solid rectangular masks yield better results.
3. Click the **Generate** button to create the image. By selecting different **random seeds**, you can generate diverse images.
4. **You can directly click the "Load Example" button on any sample at the bottom to load example inputs.**
"""
)
loading_status = gr.Textbox(label="Loading Model...", value="Loading model... Please wait...", visible=True)
main_interface = gr.Column(visible=False)
def initialize_model():
try:
load_model()
return {
loading_status: gr.update(value="Model loaded successfully!", visible=False),
main_interface: gr.update(visible=True),
}
except Exception as e:
print(f'Failed to load model with error: {e}')
return {
loading_status: gr.update(value=f"Failed to load model: {str(e)}", visible=True),
main_interface: gr.update(visible=True),
}
app.load(initialize_model, inputs=None, outputs=[loading_status, main_interface])
with main_interface:
with gr.Row():
local_prompt_list = []
canvas_list = []
random_mask_dir = gr.State(f'{random.randint(0, 1000000):08d}')
with gr.Column(scale=382, min_width=100):
model_type = gr.State('FLUX')
model_path = gr.State('FLUX.1-dev')
with gr.Accordion(label="Global prompt"):
prompt = gr.Textbox(label="Global Prompt", lines=3)
negative_prompt = gr.Textbox(label="Negative prompt", value="worst quality, low quality, monochrome, zombie, interlocked fingers, Aissist, cleavage, nsfw, blur,", lines=3)
with gr.Accordion(label="Inference Options", open=True):
seed = gr.Number(minimum=0, maximum=10**9, value=42, interactive=True, label="Random seed", show_label=True)
num_inference_steps = gr.Slider(minimum=1, maximum=100, value=30, step=1, interactive=True, label="Inference steps")
cfg_scale = gr.Slider(minimum=2.0, maximum=10.0, value=3.0, step=0.1, interactive=True, label="Classifier-free guidance scale")
embedded_guidance = gr.Slider(minimum=0.0, maximum=10.0, value=3.5, step=0.1, interactive=True, label="Embedded guidance scale")
height = gr.Slider(minimum=64, maximum=2048, value=1024, step=64, interactive=True, label="Height")
width = gr.Slider(minimum=64, maximum=2048, value=1024, step=64, interactive=True, label="Width")
with gr.Accordion(label="Inpaint Input Image", open=False):
input_image = gr.Image(sources=None, show_label=False, interactive=True, type="pil")
background_weight = gr.Slider(minimum=0.0, maximum=1000., value=0., step=1, interactive=False, label="background_weight", visible=False)
with gr.Column():
reset_input_button = gr.Button(value="Reset Inpaint Input")
send_input_to_painter = gr.Button(value="Set as painter's background")
@gr.on(inputs=[input_image], outputs=[input_image], triggers=reset_input_button.click)
def reset_input_image(input_image):
return None
with gr.Column(scale=618, min_width=100):
with gr.Accordion(label="Entity Painter"):
for painter_layer_id in range(config["max_num_painter_layers"]):
with gr.Tab(label=f"Entity {painter_layer_id}"):
local_prompt = gr.Textbox(label="Local prompt", key=f"local_prompt_{painter_layer_id}")
canvas = gr.ImageEditor(
canvas_size=(512, 512),
sources=None,
layers=False,
interactive=True,
image_mode="RGBA",
brush=gr.Brush(
default_size=50,
default_color="#000000",
colors=["#000000"],
),
label="Entity Mask Painter",
key=f"canvas_{painter_layer_id}",
width=width,
height=height,
)
@gr.on(inputs=[height, width, canvas], outputs=canvas, triggers=[height.change, width.change, canvas.clear], show_progress="hidden")
def resize_canvas(height, width, canvas):
h, w = canvas["background"].shape[:2]
if h != height or width != w:
return np.ones((height, width, 3), dtype=np.uint8) * 255
else:
return canvas
local_prompt_list.append(local_prompt)
canvas_list.append(canvas)
with gr.Accordion(label="Results"):
run_button = gr.Button(value="Generate", variant="primary")
output_image = gr.Image(sources=None, show_label=False, interactive=False, type="pil")
with gr.Row():
with gr.Column():
output_to_painter_button = gr.Button(value="Set as painter's background")
with gr.Column():
return_with_mask = gr.Checkbox(value=False, interactive=True, label="show result with mask painting")
output_to_input_button = gr.Button(value="Set as input image", visible=False, interactive=False)
real_output = gr.State(None)
mask_out = gr.State(None)
@gr.on(
inputs=[model_type, model_path, prompt, negative_prompt, cfg_scale, embedded_guidance, num_inference_steps, height, width, return_with_mask, seed, input_image, background_weight, random_mask_dir] + local_prompt_list + canvas_list,
outputs=[output_image, real_output, mask_out],
triggers=run_button.click
)
def generate_image(model_type, model_path, prompt, negative_prompt, cfg_scale, embedded_guidance, num_inference_steps, height, width, return_with_mask, seed, input_image, background_weight, random_mask_dir, *args, progress=gr.Progress()):
_, pipe = load_model(model_type, model_path)
input_params = {
"prompt": prompt,
"negative_prompt": negative_prompt,
"cfg_scale": cfg_scale,
"num_inference_steps": num_inference_steps,
"height": height,
"width": width,
"progress_bar_cmd": progress.tqdm,
}
if isinstance(pipe, FluxImagePipeline):
input_params["embedded_guidance"] = embedded_guidance
if input_image is not None:
input_params["input_image"] = input_image.resize((width, height)).convert("RGB")
input_params["enable_eligen_inpaint"] = True
local_prompt_list, canvas_list = (
args[0 * config["max_num_painter_layers"]: 1 * config["max_num_painter_layers"]],
args[1 * config["max_num_painter_layers"]: 2 * config["max_num_painter_layers"]],
)
local_prompts, masks = [], []
for local_prompt, canvas in zip(local_prompt_list, canvas_list):
if isinstance(local_prompt, str) and len(local_prompt) > 0:
local_prompts.append(local_prompt)
masks.append(Image.fromarray(canvas["layers"][0][:, :, -1]).convert("RGB"))
entity_masks = None if len(masks) == 0 else masks
entity_prompts = None if len(local_prompts) == 0 else local_prompts
input_params.update({
"eligen_entity_prompts": entity_prompts,
"eligen_entity_masks": entity_masks,
})
torch.manual_seed(seed)
# save_mask_prompts(masks, local_prompts, prompt, seed, random_mask_dir)
image = pipe(**input_params)
masks = [mask.resize(image.size) for mask in masks]
image_with_mask = visualize_masks(image, masks, local_prompts)
real_output = gr.State(image)
mask_out = gr.State(image_with_mask)
if return_with_mask:
return image_with_mask, real_output, mask_out
return image, real_output, mask_out
@gr.on(inputs=[input_image] + canvas_list, outputs=canvas_list, triggers=send_input_to_painter.click)
def send_input_to_painter_background(input_image, *canvas_list):
if input_image is None:
return tuple(canvas_list)
for canvas in canvas_list:
h, w = canvas["background"].shape[:2]
canvas["background"] = input_image.resize((w, h))
return tuple(canvas_list)
@gr.on(inputs=[real_output] + canvas_list, outputs=canvas_list, triggers=output_to_painter_button.click)
def send_output_to_painter_background(real_output, *canvas_list):
if real_output is None:
return tuple(canvas_list)
for canvas in canvas_list:
h, w = canvas["background"].shape[:2]
canvas["background"] = real_output.value.resize((w, h))
return tuple(canvas_list)
@gr.on(inputs=[return_with_mask, real_output, mask_out], outputs=[output_image], triggers=[return_with_mask.change], show_progress="hidden")
def show_output(return_with_mask, real_output, mask_out):
if return_with_mask:
return mask_out.value
else:
return real_output.value
@gr.on(inputs=[real_output], outputs=[input_image], triggers=output_to_input_button.click)
def send_output_to_pipe_input(real_output):
return real_output.value
with gr.Column():
gr.Markdown("## Examples")
for i in range(0, len(examples), 2):
with gr.Row():
if i < len(examples):
example = examples[i]
with gr.Column():
example_image = gr.Image(
value=f"data/examples/eligen/entity_control/example_{example['example_id']}/example_image.png",
label=example["description"],
interactive=False,
width=1024,
height=512
)
load_example_button = gr.Button(value=f"Load Example {example['example_id']}")
load_example_button.click(
load_example,
inputs=[load_example_button],
outputs=[num_inference_steps, prompt, negative_prompt, seed] + local_prompt_list + canvas_list
)
if i + 1 < len(examples):
example = examples[i + 1]
with gr.Column():
example_image = gr.Image(
value=f"data/examples/eligen/entity_control/example_{example['example_id']}/example_image.png",
label=example["description"],
interactive=False,
width=1024,
height=512
)
load_example_button = gr.Button(value=f"Load Example {example['example_id']}")
load_example_button.click(
load_example,
inputs=[load_example_button],
outputs=[num_inference_steps, prompt, negative_prompt, seed] + local_prompt_list + canvas_list
)
app.config["show_progress"] = "hidden"
app.launch()

View File

@@ -0,0 +1,15 @@
# Set web page format
import streamlit as st
st.set_page_config(layout="wide")
# Diasble virtual VRAM on windows system
import torch
torch.cuda.set_per_process_memory_fraction(0.999, 0)
st.markdown("""
# DiffSynth Studio
[Source Code](https://github.com/Artiprocher/DiffSynth-Studio)
Welcome to DiffSynth Studio.
""")

View File

@@ -0,0 +1,362 @@
import torch, os, io, json, time
import numpy as np
from PIL import Image
import streamlit as st
st.set_page_config(layout="wide")
from streamlit_drawable_canvas import st_canvas
from diffsynth.models import ModelManager
from diffsynth.pipelines import SDImagePipeline, SDXLImagePipeline, SD3ImagePipeline, HunyuanDiTImagePipeline, FluxImagePipeline
from diffsynth.data.video import crop_and_resize
config = {
"Stable Diffusion": {
"model_folder": "models/stable_diffusion",
"pipeline_class": SDImagePipeline,
"fixed_parameters": {}
},
"Stable Diffusion XL": {
"model_folder": "models/stable_diffusion_xl",
"pipeline_class": SDXLImagePipeline,
"fixed_parameters": {}
},
"Stable Diffusion 3": {
"model_folder": "models/stable_diffusion_3",
"pipeline_class": SD3ImagePipeline,
"fixed_parameters": {}
},
"Stable Diffusion XL Turbo": {
"model_folder": "models/stable_diffusion_xl_turbo",
"pipeline_class": SDXLImagePipeline,
"fixed_parameters": {
"negative_prompt": "",
"cfg_scale": 1.0,
"num_inference_steps": 1,
"height": 512,
"width": 512,
}
},
"Kolors": {
"model_folder": "models/kolors",
"pipeline_class": SDXLImagePipeline,
"fixed_parameters": {}
},
"HunyuanDiT": {
"model_folder": "models/HunyuanDiT",
"pipeline_class": HunyuanDiTImagePipeline,
"fixed_parameters": {
"height": 1024,
"width": 1024,
}
},
"FLUX": {
"model_folder": "models/FLUX",
"pipeline_class": FluxImagePipeline,
"fixed_parameters": {
"cfg_scale": 1.0,
}
}
}
def load_model_list(model_type):
folder = config[model_type]["model_folder"]
file_list = [i for i in os.listdir(folder) if i.endswith(".safetensors")]
if model_type in ["HunyuanDiT", "Kolors", "FLUX"]:
file_list += [i for i in os.listdir(folder) if os.path.isdir(os.path.join(folder, i))]
file_list = sorted(file_list)
return file_list
def release_model():
if "model_manager" in st.session_state:
st.session_state["model_manager"].to("cpu")
del st.session_state["loaded_model_path"]
del st.session_state["model_manager"]
del st.session_state["pipeline"]
torch.cuda.empty_cache()
def load_model(model_type, model_path):
model_manager = ModelManager()
if model_type == "HunyuanDiT":
model_manager.load_models([
os.path.join(model_path, "clip_text_encoder/pytorch_model.bin"),
os.path.join(model_path, "mt5/pytorch_model.bin"),
os.path.join(model_path, "model/pytorch_model_ema.pt"),
os.path.join(model_path, "sdxl-vae-fp16-fix/diffusion_pytorch_model.bin"),
])
elif model_type == "Kolors":
model_manager.load_models([
os.path.join(model_path, "text_encoder"),
os.path.join(model_path, "unet/diffusion_pytorch_model.safetensors"),
os.path.join(model_path, "vae/diffusion_pytorch_model.safetensors"),
])
elif model_type == "FLUX":
model_manager.torch_dtype = torch.bfloat16
file_list = [
os.path.join(model_path, "text_encoder/model.safetensors"),
os.path.join(model_path, "text_encoder_2"),
]
for file_name in os.listdir(model_path):
if file_name.endswith(".safetensors"):
file_list.append(os.path.join(model_path, file_name))
model_manager.load_models(file_list)
else:
model_manager.load_model(model_path)
pipeline = config[model_type]["pipeline_class"].from_model_manager(model_manager)
st.session_state.loaded_model_path = model_path
st.session_state.model_manager = model_manager
st.session_state.pipeline = pipeline
return model_manager, pipeline
def use_output_image_as_input(update=True):
# Search for input image
output_image_id = 0
selected_output_image = None
while True:
if f"use_output_as_input_{output_image_id}" not in st.session_state:
break
if st.session_state[f"use_output_as_input_{output_image_id}"]:
selected_output_image = st.session_state["output_images"][output_image_id]
break
output_image_id += 1
if update and selected_output_image is not None:
st.session_state["input_image"] = selected_output_image
return selected_output_image is not None
def apply_stroke_to_image(stroke_image, image):
image = np.array(image.convert("RGB")).astype(np.float32)
height, width, _ = image.shape
stroke_image = np.array(Image.fromarray(stroke_image).resize((width, height))).astype(np.float32)
weight = stroke_image[:, :, -1:] / 255
stroke_image = stroke_image[:, :, :-1]
image = stroke_image * weight + image * (1 - weight)
image = np.clip(image, 0, 255).astype(np.uint8)
image = Image.fromarray(image)
return image
@st.cache_data
def image2bits(image):
image_byte = io.BytesIO()
image.save(image_byte, format="PNG")
image_byte = image_byte.getvalue()
return image_byte
def show_output_image(image):
st.image(image, use_column_width="always")
st.button("Use it as input image", key=f"use_output_as_input_{image_id}")
st.download_button("Download", data=image2bits(image), file_name="image.png", mime="image/png", key=f"download_output_{image_id}")
column_input, column_output = st.columns(2)
with st.sidebar:
# Select a model
with st.expander("Model", expanded=True):
model_type = st.selectbox("Model type", [model_type_ for model_type_ in config])
fixed_parameters = config[model_type]["fixed_parameters"]
model_path_list = ["None"] + load_model_list(model_type)
model_path = st.selectbox("Model path", model_path_list)
# Load the model
if model_path == "None":
# No models are selected. Release VRAM.
st.markdown("No models are selected.")
release_model()
else:
# A model is selected.
model_path = os.path.join(config[model_type]["model_folder"], model_path)
if st.session_state.get("loaded_model_path", "") != model_path:
# The loaded model is not the selected model. Reload it.
st.markdown(f"Loading model at {model_path}.")
st.markdown("Please wait a moment...")
release_model()
model_manager, pipeline = load_model(model_type, model_path)
st.markdown("Done.")
else:
# The loaded model is not the selected model. Fetch it from `st.session_state`.
st.markdown(f"Loading model at {model_path}.")
st.markdown("Please wait a moment...")
model_manager, pipeline = st.session_state.model_manager, st.session_state.pipeline
st.markdown("Done.")
# Show parameters
with st.expander("Prompt", expanded=True):
prompt = st.text_area("Positive prompt")
if "negative_prompt" in fixed_parameters:
negative_prompt = fixed_parameters["negative_prompt"]
else:
negative_prompt = st.text_area("Negative prompt")
if "cfg_scale" in fixed_parameters:
cfg_scale = fixed_parameters["cfg_scale"]
else:
cfg_scale = st.slider("Classifier-free guidance scale", min_value=1.0, max_value=10.0, value=7.5)
with st.expander("Image", expanded=True):
if "num_inference_steps" in fixed_parameters:
num_inference_steps = fixed_parameters["num_inference_steps"]
else:
num_inference_steps = st.slider("Inference steps", min_value=1, max_value=100, value=20)
if "height" in fixed_parameters:
height = fixed_parameters["height"]
else:
height = st.select_slider("Height", options=[256, 512, 768, 1024, 2048], value=512)
if "width" in fixed_parameters:
width = fixed_parameters["width"]
else:
width = st.select_slider("Width", options=[256, 512, 768, 1024, 2048], value=512)
num_images = st.number_input("Number of images", value=2)
use_fixed_seed = st.checkbox("Use fixed seed", value=False)
if use_fixed_seed:
seed = st.number_input("Random seed", min_value=0, max_value=10**9, step=1, value=0)
# Other fixed parameters
denoising_strength = 1.0
repetition = 1
# Show input image
with column_input:
with st.expander("Input image (Optional)", expanded=True):
with st.container(border=True):
column_white_board, column_upload_image = st.columns([1, 2])
with column_white_board:
create_white_board = st.button("Create white board")
delete_input_image = st.button("Delete input image")
with column_upload_image:
upload_image = st.file_uploader("Upload image", type=["png", "jpg"], key="upload_image")
if upload_image is not None:
st.session_state["input_image"] = crop_and_resize(Image.open(upload_image), height, width)
elif create_white_board:
st.session_state["input_image"] = Image.fromarray(np.ones((height, width, 3), dtype=np.uint8) * 255)
else:
use_output_image_as_input()
if delete_input_image and "input_image" in st.session_state:
del st.session_state.input_image
if delete_input_image and "upload_image" in st.session_state:
del st.session_state.upload_image
input_image = st.session_state.get("input_image", None)
if input_image is not None:
with st.container(border=True):
column_drawing_mode, column_color_1, column_color_2 = st.columns([4, 1, 1])
with column_drawing_mode:
drawing_mode = st.radio("Drawing tool", ["transform", "freedraw", "line", "rect"], horizontal=True, index=1)
with column_color_1:
stroke_color = st.color_picker("Stroke color")
with column_color_2:
fill_color = st.color_picker("Fill color")
stroke_width = st.slider("Stroke width", min_value=1, max_value=50, value=10)
with st.container(border=True):
denoising_strength = st.slider("Denoising strength", min_value=0.0, max_value=1.0, value=0.7)
repetition = st.slider("Repetition", min_value=1, max_value=8, value=1)
with st.container(border=True):
input_width, input_height = input_image.size
canvas_result = st_canvas(
fill_color=fill_color,
stroke_width=stroke_width,
stroke_color=stroke_color,
background_color="rgba(255, 255, 255, 0)",
background_image=input_image,
update_streamlit=True,
height=int(512 / input_width * input_height),
width=512,
drawing_mode=drawing_mode,
key="canvas"
)
num_painter_layer = st.number_input("Number of painter layers", min_value=0, max_value=10, step=1, value=0)
local_prompts, masks, mask_scales = [], [], []
white_board = Image.fromarray(np.ones((512, 512, 3), dtype=np.uint8) * 255)
painter_layers_json_data = []
for painter_tab_id in range(num_painter_layer):
with st.expander(f"Painter layer {painter_tab_id}", expanded=True):
enable_local_prompt = st.checkbox(f"Enable prompt {painter_tab_id}", value=True)
local_prompt = st.text_area(f"Prompt {painter_tab_id}")
mask_scale = st.slider(f"Mask scale {painter_tab_id}", min_value=0.0, max_value=3.0, value=1.0)
stroke_width = st.slider(f"Stroke width {painter_tab_id}", min_value=1, max_value=300, value=100)
canvas_result_local = st_canvas(
fill_color="#000000",
stroke_width=stroke_width,
stroke_color="#000000",
background_color="rgba(255, 255, 255, 0)",
background_image=white_board,
update_streamlit=True,
height=512,
width=512,
drawing_mode="freedraw",
key=f"canvas_{painter_tab_id}"
)
if canvas_result_local.json_data is not None:
painter_layers_json_data.append(canvas_result_local.json_data.copy())
painter_layers_json_data[-1]["prompt"] = local_prompt
if enable_local_prompt:
local_prompts.append(local_prompt)
if canvas_result_local.image_data is not None:
mask = apply_stroke_to_image(canvas_result_local.image_data, white_board)
else:
mask = white_board
mask = Image.fromarray(255 - np.array(mask))
masks.append(mask)
mask_scales.append(mask_scale)
save_painter_layers = st.button("Save painter layers")
if save_painter_layers:
os.makedirs("data/painter_layers", exist_ok=True)
json_file_path = f"data/painter_layers/{time.time_ns()}.json"
with open(json_file_path, "w") as f:
json.dump(painter_layers_json_data, f, indent=4)
st.markdown(f"Painter layers are saved in {json_file_path}.")
with column_output:
run_button = st.button("Generate image", type="primary")
auto_update = st.checkbox("Auto update", value=False)
num_image_columns = st.slider("Columns", min_value=1, max_value=8, value=2)
image_columns = st.columns(num_image_columns)
# Run
if (run_button or auto_update) and model_path != "None":
if input_image is not None:
input_image = input_image.resize((width, height))
if canvas_result.image_data is not None:
input_image = apply_stroke_to_image(canvas_result.image_data, input_image)
output_images = []
for image_id in range(num_images * repetition):
if use_fixed_seed:
torch.manual_seed(seed + image_id)
else:
torch.manual_seed(np.random.randint(0, 10**9))
if image_id >= num_images:
input_image = output_images[image_id - num_images]
with image_columns[image_id % num_image_columns]:
progress_bar_st = st.progress(0.0)
image = pipeline(
prompt, negative_prompt=negative_prompt,
local_prompts=local_prompts, masks=masks, mask_scales=mask_scales,
cfg_scale=cfg_scale, num_inference_steps=num_inference_steps,
height=height, width=width,
input_image=input_image, denoising_strength=denoising_strength,
progress_bar_st=progress_bar_st
)
output_images.append(image)
progress_bar_st.progress(1.0)
show_output_image(image)
st.session_state["output_images"] = output_images
elif "output_images" in st.session_state:
for image_id in range(len(st.session_state.output_images)):
with image_columns[image_id % num_image_columns]:
image = st.session_state.output_images[image_id]
progress_bar = st.progress(1.0)
show_output_image(image)
if "upload_image" in st.session_state and use_output_image_as_input(update=False):
st.markdown("If you want to use an output image as input image, please delete the uploaded image manually.")

View File

@@ -0,0 +1,197 @@
import streamlit as st
st.set_page_config(layout="wide")
from diffsynth import SDVideoPipelineRunner
import os
import numpy as np
def load_model_list(folder):
file_list = os.listdir(folder)
file_list = [i for i in file_list if i.endswith(".safetensors") or i.endswith(".pth") or i.endswith(".ckpt")]
file_list = sorted(file_list)
return file_list
def match_processor_id(model_name, supported_processor_id_list):
sorted_processor_id = [i[1] for i in sorted([(-len(i), i) for i in supported_processor_id_list])]
for processor_id in sorted_processor_id:
if processor_id in model_name:
return supported_processor_id_list.index(processor_id) + 1
return 0
config = {
"models": {
"model_list": [],
"textual_inversion_folder": "models/textual_inversion",
"device": "cuda",
"lora_alphas": [],
"controlnet_units": []
},
"data": {
"input_frames": None,
"controlnet_frames": [],
"output_folder": "output",
"fps": 60
},
"pipeline": {
"seed": 0,
"pipeline_inputs": {}
}
}
with st.expander("Model", expanded=True):
stable_diffusion_ckpt = st.selectbox("Stable Diffusion", ["None"] + load_model_list("models/stable_diffusion"))
if stable_diffusion_ckpt != "None":
config["models"]["model_list"].append(os.path.join("models/stable_diffusion", stable_diffusion_ckpt))
animatediff_ckpt = st.selectbox("AnimateDiff", ["None"] + load_model_list("models/AnimateDiff"))
if animatediff_ckpt != "None":
config["models"]["model_list"].append(os.path.join("models/AnimateDiff", animatediff_ckpt))
column_lora, column_lora_alpha = st.columns([2, 1])
with column_lora:
sd_lora_ckpt = st.selectbox("LoRA", ["None"] + load_model_list("models/lora"))
with column_lora_alpha:
lora_alpha = st.slider("LoRA Alpha", min_value=-4.0, max_value=4.0, value=1.0, step=0.1)
if sd_lora_ckpt != "None":
config["models"]["model_list"].append(os.path.join("models/lora", sd_lora_ckpt))
config["models"]["lora_alphas"].append(lora_alpha)
with st.expander("Data", expanded=True):
with st.container(border=True):
input_video = st.text_input("Input Video File Path (e.g., data/your_video.mp4)", value="")
column_height, column_width, column_start_frame_index, column_end_frame_index = st.columns([2, 2, 1, 1])
with column_height:
height = st.select_slider("Height", options=[256, 512, 768, 1024, 1536, 2048], value=1024)
with column_width:
width = st.select_slider("Width", options=[256, 512, 768, 1024, 1536, 2048], value=1024)
with column_start_frame_index:
start_frame_id = st.number_input("Start Frame id", value=0)
with column_end_frame_index:
end_frame_id = st.number_input("End Frame id", value=16)
if input_video != "":
config["data"]["input_frames"] = {
"video_file": input_video,
"image_folder": None,
"height": height,
"width": width,
"start_frame_id": start_frame_id,
"end_frame_id": end_frame_id
}
with st.container(border=True):
output_video = st.text_input("Output Video File Path (e.g., data/a_folder_to_save_something)", value="output")
fps = st.number_input("FPS", value=60)
config["data"]["output_folder"] = output_video
config["data"]["fps"] = fps
with st.expander("ControlNet Units", expanded=True):
supported_processor_id_list = ["canny", "depth", "softedge", "lineart", "lineart_anime", "openpose", "tile"]
controlnet_units = st.tabs(["ControlNet Unit 0", "ControlNet Unit 1", "ControlNet Unit 2"])
for controlnet_id in range(len(controlnet_units)):
with controlnet_units[controlnet_id]:
controlnet_ckpt = st.selectbox("ControlNet", ["None"] + load_model_list("models/ControlNet"),
key=f"controlnet_ckpt_{controlnet_id}")
processor_id = st.selectbox("Processor", ["None"] + supported_processor_id_list,
index=match_processor_id(controlnet_ckpt, supported_processor_id_list),
disabled=controlnet_ckpt == "None", key=f"processor_id_{controlnet_id}")
controlnet_scale = st.slider("Scale", min_value=0.0, max_value=1.0, step=0.01, value=0.5,
disabled=controlnet_ckpt == "None", key=f"controlnet_scale_{controlnet_id}")
use_input_video_as_controlnet_input = st.checkbox("Use input video as ControlNet input", value=True,
disabled=controlnet_ckpt == "None",
key=f"use_input_video_as_controlnet_input_{controlnet_id}")
if not use_input_video_as_controlnet_input:
controlnet_input_video = st.text_input("ControlNet Input Video File Path", value="",
disabled=controlnet_ckpt == "None", key=f"controlnet_input_video_{controlnet_id}")
column_height, column_width, column_start_frame_index, column_end_frame_index = st.columns([2, 2, 1, 1])
with column_height:
height = st.select_slider("Height", options=[256, 512, 768, 1024, 1536, 2048], value=1024,
disabled=controlnet_ckpt == "None", key=f"controlnet_height_{controlnet_id}")
with column_width:
width = st.select_slider("Width", options=[256, 512, 768, 1024, 1536, 2048], value=1024,
disabled=controlnet_ckpt == "None", key=f"controlnet_width_{controlnet_id}")
with column_start_frame_index:
start_frame_id = st.number_input("Start Frame id", value=0,
disabled=controlnet_ckpt == "None", key=f"controlnet_start_frame_id_{controlnet_id}")
with column_end_frame_index:
end_frame_id = st.number_input("End Frame id", value=16,
disabled=controlnet_ckpt == "None", key=f"controlnet_end_frame_id_{controlnet_id}")
if input_video != "":
config["data"]["input_video"] = {
"video_file": input_video,
"image_folder": None,
"height": height,
"width": width,
"start_frame_id": start_frame_id,
"end_frame_id": end_frame_id
}
if controlnet_ckpt != "None":
config["models"]["model_list"].append(os.path.join("models/ControlNet", controlnet_ckpt))
config["models"]["controlnet_units"].append({
"processor_id": processor_id,
"model_path": os.path.join("models/ControlNet", controlnet_ckpt),
"scale": controlnet_scale,
})
if use_input_video_as_controlnet_input:
config["data"]["controlnet_frames"].append(config["data"]["input_frames"])
else:
config["data"]["controlnet_frames"].append({
"video_file": input_video,
"image_folder": None,
"height": height,
"width": width,
"start_frame_id": start_frame_id,
"end_frame_id": end_frame_id
})
with st.container(border=True):
with st.expander("Seed", expanded=True):
use_fixed_seed = st.checkbox("Use fixed seed", value=False)
if use_fixed_seed:
seed = st.number_input("Random seed", min_value=0, max_value=10**9, step=1, value=0)
else:
seed = np.random.randint(0, 10**9)
with st.expander("Textual Guidance", expanded=True):
prompt = st.text_area("Positive prompt")
negative_prompt = st.text_area("Negative prompt")
column_cfg_scale, column_clip_skip = st.columns(2)
with column_cfg_scale:
cfg_scale = st.slider("Classifier-free guidance scale", min_value=1.0, max_value=10.0, value=7.0)
with column_clip_skip:
clip_skip = st.slider("Clip Skip", min_value=1, max_value=4, value=1)
with st.expander("Denoising", expanded=True):
column_num_inference_steps, column_denoising_strength = st.columns(2)
with column_num_inference_steps:
num_inference_steps = st.slider("Inference steps", min_value=1, max_value=100, value=10)
with column_denoising_strength:
denoising_strength = st.slider("Denoising strength", min_value=0.0, max_value=1.0, value=1.0)
with st.expander("Efficiency", expanded=False):
animatediff_batch_size = st.slider("Animatediff batch size (sliding window size)", min_value=1, max_value=32, value=16, step=1)
animatediff_stride = st.slider("Animatediff stride",
min_value=1,
max_value=max(2, animatediff_batch_size),
value=max(1, animatediff_batch_size // 2),
step=1)
unet_batch_size = st.slider("UNet batch size", min_value=1, max_value=32, value=1, step=1)
controlnet_batch_size = st.slider("ControlNet batch size", min_value=1, max_value=32, value=1, step=1)
cross_frame_attention = st.checkbox("Enable Cross-Frame Attention", value=False)
config["pipeline"]["seed"] = seed
config["pipeline"]["pipeline_inputs"] = {
"prompt": prompt,
"negative_prompt": negative_prompt,
"cfg_scale": cfg_scale,
"clip_skip": clip_skip,
"denoising_strength": denoising_strength,
"num_inference_steps": num_inference_steps,
"animatediff_batch_size": animatediff_batch_size,
"animatediff_stride": animatediff_stride,
"unet_batch_size": unet_batch_size,
"controlnet_batch_size": controlnet_batch_size,
"cross_frame_attention": cross_frame_attention,
}
run_button = st.button("Run☢", type="primary")
if run_button:
SDVideoPipelineRunner(in_streamlit=True).run(config)

View File

@@ -1 +1,6 @@
from .core import *
from .data import *
from .models import *
from .prompters import *
from .schedulers import *
from .pipelines import *
from .controlnets import *

View File

@@ -1,2 +0,0 @@
from .model_configs import MODEL_CONFIGS
from .vram_management_module_maps import VRAM_MANAGEMENT_MODULE_MAPS, VERSION_CHECKER_MAPS

View File

@@ -0,0 +1,744 @@
from typing_extensions import Literal, TypeAlias
from ..models.sd_text_encoder import SDTextEncoder
from ..models.sd_unet import SDUNet
from ..models.sd_vae_encoder import SDVAEEncoder
from ..models.sd_vae_decoder import SDVAEDecoder
from ..models.sdxl_text_encoder import SDXLTextEncoder, SDXLTextEncoder2
from ..models.sdxl_unet import SDXLUNet
from ..models.sdxl_vae_decoder import SDXLVAEDecoder
from ..models.sdxl_vae_encoder import SDXLVAEEncoder
from ..models.sd3_text_encoder import SD3TextEncoder1, SD3TextEncoder2, SD3TextEncoder3
from ..models.sd3_dit import SD3DiT
from ..models.sd3_vae_decoder import SD3VAEDecoder
from ..models.sd3_vae_encoder import SD3VAEEncoder
from ..models.sd_controlnet import SDControlNet
from ..models.sdxl_controlnet import SDXLControlNetUnion
from ..models.sd_motion import SDMotionModel
from ..models.sdxl_motion import SDXLMotionModel
from ..models.svd_image_encoder import SVDImageEncoder
from ..models.svd_unet import SVDUNet
from ..models.svd_vae_decoder import SVDVAEDecoder
from ..models.svd_vae_encoder import SVDVAEEncoder
from ..models.sd_ipadapter import SDIpAdapter, IpAdapterCLIPImageEmbedder
from ..models.sdxl_ipadapter import SDXLIpAdapter, IpAdapterXLCLIPImageEmbedder
from ..models.hunyuan_dit_text_encoder import HunyuanDiTCLIPTextEncoder, HunyuanDiTT5TextEncoder
from ..models.hunyuan_dit import HunyuanDiT
from ..models.flux_dit import FluxDiT
from ..models.flux_text_encoder import FluxTextEncoder2
from ..models.flux_vae import FluxVAEEncoder, FluxVAEDecoder
from ..models.flux_controlnet import FluxControlNet
from ..models.flux_ipadapter import FluxIpAdapter
from ..models.cog_vae import CogVAEEncoder, CogVAEDecoder
from ..models.cog_dit import CogDiT
from ..models.omnigen import OmniGenTransformer
from ..models.hunyuan_video_vae_decoder import HunyuanVideoVAEDecoder
from ..models.hunyuan_video_vae_encoder import HunyuanVideoVAEEncoder
from ..extensions.RIFE import IFNet
from ..extensions.ESRGAN import RRDBNet
from ..models.hunyuan_video_dit import HunyuanVideoDiT
from ..models.stepvideo_vae import StepVideoVAE
from ..models.stepvideo_dit import StepVideoModel
from ..models.wanx_vae import WanXVideoVAE
model_loader_configs = [
# These configs are provided for detecting model type automatically.
# The format is (state_dict_keys_hash, state_dict_keys_hash_with_shape, model_names, model_classes, model_resource)
(None, "091b0e30e77c76626b3ba62acdf95343", ["sd_controlnet"], [SDControlNet], "civitai"),
(None, "4a6c8306a27d916dea81263c8c88f450", ["hunyuan_dit_clip_text_encoder"], [HunyuanDiTCLIPTextEncoder], "civitai"),
(None, "f4aec400fe394297961218c768004521", ["hunyuan_dit"], [HunyuanDiT], "civitai"),
(None, "9e6e58043a5a2e332803ed42f6ee7181", ["hunyuan_dit_t5_text_encoder"], [HunyuanDiTT5TextEncoder], "civitai"),
(None, "13115dd45a6e1c39860f91ab073b8a78", ["sdxl_vae_encoder", "sdxl_vae_decoder"], [SDXLVAEEncoder, SDXLVAEDecoder], "diffusers"),
(None, "d78aa6797382a6d455362358a3295ea9", ["sd_ipadapter_clip_image_encoder"], [IpAdapterCLIPImageEmbedder], "diffusers"),
(None, "e291636cc15e803186b47404262ef812", ["sd_ipadapter"], [SDIpAdapter], "civitai"),
(None, "399c81f2f8de8d1843d0127a00f3c224", ["sdxl_ipadapter_clip_image_encoder"], [IpAdapterXLCLIPImageEmbedder], "diffusers"),
(None, "a64eac9aa0db4b9602213bc0131281c7", ["sdxl_ipadapter"], [SDXLIpAdapter], "civitai"),
(None, "52817e4fdd89df154f02749ca6f692ac", ["sdxl_unet"], [SDXLUNet], "diffusers"),
(None, "03343c606f16d834d6411d0902b53636", ["sd_text_encoder", "sd_unet", "sd_vae_decoder", "sd_vae_encoder"], [SDTextEncoder, SDUNet, SDVAEDecoder, SDVAEEncoder], "civitai"),
(None, "d4ba77a7ece070679b4a987f58f201e9", ["sd_text_encoder"], [SDTextEncoder], "civitai"),
(None, "d0c89e55c5a57cf3981def0cb1c9e65a", ["sd_vae_decoder", "sd_vae_encoder"], [SDVAEDecoder, SDVAEEncoder], "civitai"),
(None, "3926bf373b39a67eeafd7901478a47a7", ["sd_unet"], [SDUNet], "civitai"),
(None, "1e0c39ec176b9007c05f76d52b554a4d", ["sd3_text_encoder_1", "sd3_text_encoder_2", "sd3_dit", "sd3_vae_encoder", "sd3_vae_decoder"], [SD3TextEncoder1, SD3TextEncoder2, SD3DiT, SD3VAEEncoder, SD3VAEDecoder], "civitai"),
(None, "d9e0290829ba8d98e28e1a2b1407db4a", ["sd3_text_encoder_1", "sd3_text_encoder_2", "sd3_text_encoder_3", "sd3_dit", "sd3_vae_encoder", "sd3_vae_decoder"], [SD3TextEncoder1, SD3TextEncoder2, SD3TextEncoder3, SD3DiT, SD3VAEEncoder, SD3VAEDecoder], "civitai"),
(None, "5072d0b24e406b49507abe861cf97691", ["sd3_text_encoder_3"], [SD3TextEncoder3], "civitai"),
(None, "4cf64a799d04260df438c6f33c9a047e", ["sdxl_text_encoder", "sdxl_text_encoder_2", "sdxl_unet", "sdxl_vae_decoder", "sdxl_vae_encoder"], [SDXLTextEncoder, SDXLTextEncoder2, SDXLUNet, SDXLVAEDecoder, SDXLVAEEncoder], "civitai"),
(None, "d9b008a867c498ab12ad24042eff8e3f", ["sdxl_text_encoder", "sdxl_text_encoder_2", "sdxl_unet", "sdxl_vae_decoder", "sdxl_vae_encoder"], [SDXLTextEncoder, SDXLTextEncoder2, SDXLUNet, SDXLVAEDecoder, SDXLVAEEncoder], "civitai"), # SDXL-Turbo
(None, "025bb7452e531a3853d951d77c63f032", ["sdxl_text_encoder", "sdxl_text_encoder_2"], [SDXLTextEncoder, SDXLTextEncoder2], "civitai"),
(None, "298997b403a4245c04102c9f36aac348", ["sdxl_unet"], [SDXLUNet], "civitai"),
(None, "2a07abce74b4bdc696b76254ab474da6", ["svd_image_encoder", "svd_unet", "svd_vae_decoder", "svd_vae_encoder"], [SVDImageEncoder, SVDUNet, SVDVAEDecoder, SVDVAEEncoder], "civitai"),
(None, "c96a285a6888465f87de22a984d049fb", ["sd_motion_modules"], [SDMotionModel], "civitai"),
(None, "72907b92caed19bdb2adb89aa4063fe2", ["sdxl_motion_modules"], [SDXLMotionModel], "civitai"),
(None, "31d2d9614fba60511fc9bf2604aa01f7", ["sdxl_controlnet"], [SDXLControlNetUnion], "diffusers"),
(None, "94eefa3dac9cec93cb1ebaf1747d7b78", ["sd3_text_encoder_1"], [SD3TextEncoder1], "diffusers"),
(None, "1aafa3cc91716fb6b300cc1cd51b85a3", ["flux_vae_encoder", "flux_vae_decoder"], [FluxVAEEncoder, FluxVAEDecoder], "diffusers"),
(None, "21ea55f476dfc4fd135587abb59dfe5d", ["flux_vae_encoder", "flux_vae_decoder"], [FluxVAEEncoder, FluxVAEDecoder], "civitai"),
(None, "a29710fea6dddb0314663ee823598e50", ["flux_dit"], [FluxDiT], "civitai"),
(None, "57b02550baab820169365b3ee3afa2c9", ["flux_dit"], [FluxDiT], "civitai"),
(None, "3394f306c4cbf04334b712bf5aaed95f", ["flux_dit"], [FluxDiT], "civitai"),
(None, "023f054d918a84ccf503481fd1e3379e", ["flux_dit"], [FluxDiT], "civitai"),
(None, "280189ee084bca10f70907bf6ce1649d", ["cog_vae_encoder", "cog_vae_decoder"], [CogVAEEncoder, CogVAEDecoder], "diffusers"),
(None, "9b9313d104ac4df27991352fec013fd4", ["rife"], [IFNet], "civitai"),
(None, "6b7116078c4170bfbeaedc8fe71f6649", ["esrgan"], [RRDBNet], "civitai"),
(None, "61cbcbc7ac11f169c5949223efa960d1", ["omnigen_transformer"], [OmniGenTransformer], "diffusers"),
(None, "78d18b9101345ff695f312e7e62538c0", ["flux_controlnet"], [FluxControlNet], "diffusers"),
(None, "b001c89139b5f053c715fe772362dd2a", ["flux_controlnet"], [FluxControlNet], "diffusers"),
(None, "52357cb26250681367488a8954c271e8", ["flux_controlnet"], [FluxControlNet], "diffusers"),
(None, "0cfd1740758423a2a854d67c136d1e8c", ["flux_controlnet"], [FluxControlNet], "diffusers"),
(None, "4daaa66cc656a8fe369908693dad0a35", ["flux_ipadapter"], [FluxIpAdapter], "diffusers"),
(None, "51aed3d27d482fceb5e0739b03060e8f", ["sd3_dit", "sd3_vae_encoder", "sd3_vae_decoder"], [SD3DiT, SD3VAEEncoder, SD3VAEDecoder], "civitai"),
(None, "98cc34ccc5b54ae0e56bdea8688dcd5a", ["sd3_text_encoder_2"], [SD3TextEncoder2], "civitai"),
(None, "77ff18050dbc23f50382e45d51a779fe", ["sd3_dit", "sd3_vae_encoder", "sd3_vae_decoder"], [SD3DiT, SD3VAEEncoder, SD3VAEDecoder], "civitai"),
(None, "5da81baee73198a7c19e6d2fe8b5148e", ["sd3_text_encoder_1"], [SD3TextEncoder1], "diffusers"),
(None, "aeb82dce778a03dcb4d726cb03f3c43f", ["hunyuan_video_vae_decoder", "hunyuan_video_vae_encoder"], [HunyuanVideoVAEDecoder, HunyuanVideoVAEEncoder], "diffusers"),
(None, "b9588f02e78f5ccafc9d7c0294e46308", ["hunyuan_video_dit"], [HunyuanVideoDiT], "civitai"),
(None, "84ef4bd4757f60e906b54aa6a7815dc6", ["hunyuan_video_dit"], [HunyuanVideoDiT], "civitai"),
(None, "68beaf8429b7c11aa8ca05b1bd0058bd", ["stepvideo_vae"], [StepVideoVAE], "civitai"),
(None, "5c0216a2132b082c10cb7a0e0377e681", ["stepvideo_dit"], [StepVideoModel], "civitai"),
(None, "1378ea763357eea97acdef78e65d6d96", ["wanxvideo_vae"], [WanXVideoVAE], "civitai")
]
huggingface_model_loader_configs = [
# These configs are provided for detecting model type automatically.
# The format is (architecture_in_huggingface_config, huggingface_lib, model_name, redirected_architecture)
("ChatGLMModel", "diffsynth.models.kolors_text_encoder", "kolors_text_encoder", None),
("MarianMTModel", "transformers.models.marian.modeling_marian", "translator", None),
("BloomForCausalLM", "transformers.models.bloom.modeling_bloom", "beautiful_prompt", None),
("Qwen2ForCausalLM", "transformers.models.qwen2.modeling_qwen2", "qwen_prompt", None),
# ("LlamaForCausalLM", "transformers.models.llama.modeling_llama", "omost_prompt", None),
("T5EncoderModel", "diffsynth.models.flux_text_encoder", "flux_text_encoder_2", "FluxTextEncoder2"),
("CogVideoXTransformer3DModel", "diffsynth.models.cog_dit", "cog_dit", "CogDiT"),
("SiglipModel", "transformers.models.siglip.modeling_siglip", "siglip_vision_model", "SiglipVisionModel"),
("LlamaForCausalLM", "diffsynth.models.hunyuan_video_text_encoder", "hunyuan_video_text_encoder_2", "HunyuanVideoLLMEncoder"),
("Step1Model", "diffsynth.models.stepvideo_text_encoder", "stepvideo_text_encoder_2", "STEP1TextEncoder"),
]
patch_model_loader_configs = [
# These configs are provided for detecting model type automatically.
# The format is (state_dict_keys_hash_with_shape, model_name, model_class, extra_kwargs)
("9a4ab6869ac9b7d6e31f9854e397c867", ["svd_unet"], [SVDUNet], {"add_positional_conv": 128}),
]
preset_models_on_huggingface = {
"HunyuanDiT": [
("Tencent-Hunyuan/HunyuanDiT", "t2i/clip_text_encoder/pytorch_model.bin", "models/HunyuanDiT/t2i/clip_text_encoder"),
("Tencent-Hunyuan/HunyuanDiT", "t2i/mt5/pytorch_model.bin", "models/HunyuanDiT/t2i/mt5"),
("Tencent-Hunyuan/HunyuanDiT", "t2i/model/pytorch_model_ema.pt", "models/HunyuanDiT/t2i/model"),
("Tencent-Hunyuan/HunyuanDiT", "t2i/sdxl-vae-fp16-fix/diffusion_pytorch_model.bin", "models/HunyuanDiT/t2i/sdxl-vae-fp16-fix"),
],
"stable-video-diffusion-img2vid-xt": [
("stabilityai/stable-video-diffusion-img2vid-xt", "svd_xt.safetensors", "models/stable_video_diffusion"),
],
"ExVideo-SVD-128f-v1": [
("ECNU-CILab/ExVideo-SVD-128f-v1", "model.fp16.safetensors", "models/stable_video_diffusion"),
],
# Stable Diffusion
"StableDiffusion_v15": [
("benjamin-paine/stable-diffusion-v1-5", "v1-5-pruned-emaonly.safetensors", "models/stable_diffusion"),
],
"DreamShaper_8": [
("Yntec/Dreamshaper8", "dreamshaper_8.safetensors", "models/stable_diffusion"),
],
# Textual Inversion
"TextualInversion_VeryBadImageNegative_v1.3": [
("gemasai/verybadimagenegative_v1.3", "verybadimagenegative_v1.3.pt", "models/textual_inversion"),
],
# Stable Diffusion XL
"StableDiffusionXL_v1": [
("stabilityai/stable-diffusion-xl-base-1.0", "sd_xl_base_1.0.safetensors", "models/stable_diffusion_xl"),
],
"BluePencilXL_v200": [
("frankjoshua/bluePencilXL_v200", "bluePencilXL_v200.safetensors", "models/stable_diffusion_xl"),
],
"StableDiffusionXL_Turbo": [
("stabilityai/sdxl-turbo", "sd_xl_turbo_1.0_fp16.safetensors", "models/stable_diffusion_xl_turbo"),
],
# Stable Diffusion 3
"StableDiffusion3": [
("stabilityai/stable-diffusion-3-medium", "sd3_medium_incl_clips_t5xxlfp16.safetensors", "models/stable_diffusion_3"),
],
"StableDiffusion3_without_T5": [
("stabilityai/stable-diffusion-3-medium", "sd3_medium_incl_clips.safetensors", "models/stable_diffusion_3"),
],
# ControlNet
"ControlNet_v11f1p_sd15_depth": [
("lllyasviel/ControlNet-v1-1", "control_v11f1p_sd15_depth.pth", "models/ControlNet"),
("lllyasviel/Annotators", "dpt_hybrid-midas-501f0c75.pt", "models/Annotators")
],
"ControlNet_v11p_sd15_softedge": [
("lllyasviel/ControlNet-v1-1", "control_v11p_sd15_softedge.pth", "models/ControlNet"),
("lllyasviel/Annotators", "ControlNetHED.pth", "models/Annotators")
],
"ControlNet_v11f1e_sd15_tile": [
("lllyasviel/ControlNet-v1-1", "control_v11f1e_sd15_tile.pth", "models/ControlNet")
],
"ControlNet_v11p_sd15_lineart": [
("lllyasviel/ControlNet-v1-1", "control_v11p_sd15_lineart.pth", "models/ControlNet"),
("lllyasviel/Annotators", "sk_model.pth", "models/Annotators"),
("lllyasviel/Annotators", "sk_model2.pth", "models/Annotators")
],
"ControlNet_union_sdxl_promax": [
("xinsir/controlnet-union-sdxl-1.0", "diffusion_pytorch_model_promax.safetensors", "models/ControlNet/controlnet_union"),
("lllyasviel/Annotators", "dpt_hybrid-midas-501f0c75.pt", "models/Annotators")
],
# AnimateDiff
"AnimateDiff_v2": [
("guoyww/animatediff", "mm_sd_v15_v2.ckpt", "models/AnimateDiff"),
],
"AnimateDiff_xl_beta": [
("guoyww/animatediff", "mm_sdxl_v10_beta.ckpt", "models/AnimateDiff"),
],
# Qwen Prompt
"QwenPrompt": [
("Qwen/Qwen2-1.5B-Instruct", "config.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
("Qwen/Qwen2-1.5B-Instruct", "generation_config.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
("Qwen/Qwen2-1.5B-Instruct", "model.safetensors", "models/QwenPrompt/qwen2-1.5b-instruct"),
("Qwen/Qwen2-1.5B-Instruct", "special_tokens_map.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
("Qwen/Qwen2-1.5B-Instruct", "tokenizer.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
("Qwen/Qwen2-1.5B-Instruct", "tokenizer_config.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
("Qwen/Qwen2-1.5B-Instruct", "merges.txt", "models/QwenPrompt/qwen2-1.5b-instruct"),
("Qwen/Qwen2-1.5B-Instruct", "vocab.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
],
# Beautiful Prompt
"BeautifulPrompt": [
("alibaba-pai/pai-bloom-1b1-text2prompt-sd", "config.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
("alibaba-pai/pai-bloom-1b1-text2prompt-sd", "generation_config.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
("alibaba-pai/pai-bloom-1b1-text2prompt-sd", "model.safetensors", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
("alibaba-pai/pai-bloom-1b1-text2prompt-sd", "special_tokens_map.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
("alibaba-pai/pai-bloom-1b1-text2prompt-sd", "tokenizer.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
("alibaba-pai/pai-bloom-1b1-text2prompt-sd", "tokenizer_config.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
],
# Omost prompt
"OmostPrompt":[
("lllyasviel/omost-llama-3-8b-4bits", "model-00001-of-00002.safetensors", "models/OmostPrompt/omost-llama-3-8b-4bits"),
("lllyasviel/omost-llama-3-8b-4bits", "model-00002-of-00002.safetensors", "models/OmostPrompt/omost-llama-3-8b-4bits"),
("lllyasviel/omost-llama-3-8b-4bits", "tokenizer.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
("lllyasviel/omost-llama-3-8b-4bits", "tokenizer_config.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
("lllyasviel/omost-llama-3-8b-4bits", "config.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
("lllyasviel/omost-llama-3-8b-4bits", "generation_config.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
("lllyasviel/omost-llama-3-8b-4bits", "model.safetensors.index.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
("lllyasviel/omost-llama-3-8b-4bits", "special_tokens_map.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
],
# Translator
"opus-mt-zh-en": [
("Helsinki-NLP/opus-mt-zh-en", "config.json", "models/translator/opus-mt-zh-en"),
("Helsinki-NLP/opus-mt-zh-en", "generation_config.json", "models/translator/opus-mt-zh-en"),
("Helsinki-NLP/opus-mt-zh-en", "metadata.json", "models/translator/opus-mt-zh-en"),
("Helsinki-NLP/opus-mt-zh-en", "pytorch_model.bin", "models/translator/opus-mt-zh-en"),
("Helsinki-NLP/opus-mt-zh-en", "source.spm", "models/translator/opus-mt-zh-en"),
("Helsinki-NLP/opus-mt-zh-en", "target.spm", "models/translator/opus-mt-zh-en"),
("Helsinki-NLP/opus-mt-zh-en", "tokenizer_config.json", "models/translator/opus-mt-zh-en"),
("Helsinki-NLP/opus-mt-zh-en", "vocab.json", "models/translator/opus-mt-zh-en"),
],
# IP-Adapter
"IP-Adapter-SD": [
("h94/IP-Adapter", "models/image_encoder/model.safetensors", "models/IpAdapter/stable_diffusion/image_encoder"),
("h94/IP-Adapter", "models/ip-adapter_sd15.bin", "models/IpAdapter/stable_diffusion"),
],
"IP-Adapter-SDXL": [
("h94/IP-Adapter", "sdxl_models/image_encoder/model.safetensors", "models/IpAdapter/stable_diffusion_xl/image_encoder"),
("h94/IP-Adapter", "sdxl_models/ip-adapter_sdxl.bin", "models/IpAdapter/stable_diffusion_xl"),
],
"SDXL-vae-fp16-fix": [
("madebyollin/sdxl-vae-fp16-fix", "diffusion_pytorch_model.safetensors", "models/sdxl-vae-fp16-fix")
],
# Kolors
"Kolors": [
("Kwai-Kolors/Kolors", "text_encoder/config.json", "models/kolors/Kolors/text_encoder"),
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model.bin.index.json", "models/kolors/Kolors/text_encoder"),
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00001-of-00007.bin", "models/kolors/Kolors/text_encoder"),
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00002-of-00007.bin", "models/kolors/Kolors/text_encoder"),
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00003-of-00007.bin", "models/kolors/Kolors/text_encoder"),
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00004-of-00007.bin", "models/kolors/Kolors/text_encoder"),
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00005-of-00007.bin", "models/kolors/Kolors/text_encoder"),
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00006-of-00007.bin", "models/kolors/Kolors/text_encoder"),
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00007-of-00007.bin", "models/kolors/Kolors/text_encoder"),
("Kwai-Kolors/Kolors", "unet/diffusion_pytorch_model.safetensors", "models/kolors/Kolors/unet"),
("Kwai-Kolors/Kolors", "vae/diffusion_pytorch_model.safetensors", "models/kolors/Kolors/vae"),
],
# FLUX
"FLUX.1-dev": [
("black-forest-labs/FLUX.1-dev", "text_encoder/model.safetensors", "models/FLUX/FLUX.1-dev/text_encoder"),
("black-forest-labs/FLUX.1-dev", "text_encoder_2/config.json", "models/FLUX/FLUX.1-dev/text_encoder_2"),
("black-forest-labs/FLUX.1-dev", "text_encoder_2/model-00001-of-00002.safetensors", "models/FLUX/FLUX.1-dev/text_encoder_2"),
("black-forest-labs/FLUX.1-dev", "text_encoder_2/model-00002-of-00002.safetensors", "models/FLUX/FLUX.1-dev/text_encoder_2"),
("black-forest-labs/FLUX.1-dev", "text_encoder_2/model.safetensors.index.json", "models/FLUX/FLUX.1-dev/text_encoder_2"),
("black-forest-labs/FLUX.1-dev", "ae.safetensors", "models/FLUX/FLUX.1-dev"),
("black-forest-labs/FLUX.1-dev", "flux1-dev.safetensors", "models/FLUX/FLUX.1-dev"),
],
"InstantX/FLUX.1-dev-IP-Adapter": {
"file_list": [
("InstantX/FLUX.1-dev-IP-Adapter", "ip-adapter.bin", "models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter"),
("google/siglip-so400m-patch14-384", "model.safetensors", "models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/image_encoder"),
("google/siglip-so400m-patch14-384", "config.json", "models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/image_encoder"),
],
"load_path": [
"models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/ip-adapter.bin",
"models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/image_encoder",
],
},
# RIFE
"RIFE": [
("AlexWortega/RIFE", "flownet.pkl", "models/RIFE"),
],
# CogVideo
"CogVideoX-5B": [
("THUDM/CogVideoX-5b", "text_encoder/config.json", "models/CogVideo/CogVideoX-5b/text_encoder"),
("THUDM/CogVideoX-5b", "text_encoder/model.safetensors.index.json", "models/CogVideo/CogVideoX-5b/text_encoder"),
("THUDM/CogVideoX-5b", "text_encoder/model-00001-of-00002.safetensors", "models/CogVideo/CogVideoX-5b/text_encoder"),
("THUDM/CogVideoX-5b", "text_encoder/model-00002-of-00002.safetensors", "models/CogVideo/CogVideoX-5b/text_encoder"),
("THUDM/CogVideoX-5b", "transformer/config.json", "models/CogVideo/CogVideoX-5b/transformer"),
("THUDM/CogVideoX-5b", "transformer/diffusion_pytorch_model.safetensors.index.json", "models/CogVideo/CogVideoX-5b/transformer"),
("THUDM/CogVideoX-5b", "transformer/diffusion_pytorch_model-00001-of-00002.safetensors", "models/CogVideo/CogVideoX-5b/transformer"),
("THUDM/CogVideoX-5b", "transformer/diffusion_pytorch_model-00002-of-00002.safetensors", "models/CogVideo/CogVideoX-5b/transformer"),
("THUDM/CogVideoX-5b", "vae/diffusion_pytorch_model.safetensors", "models/CogVideo/CogVideoX-5b/vae"),
],
# Stable Diffusion 3.5
"StableDiffusion3.5-large": [
("stabilityai/stable-diffusion-3.5-large", "sd3.5_large.safetensors", "models/stable_diffusion_3"),
("stabilityai/stable-diffusion-3.5-large", "text_encoders/clip_l.safetensors", "models/stable_diffusion_3/text_encoders"),
("stabilityai/stable-diffusion-3.5-large", "text_encoders/clip_g.safetensors", "models/stable_diffusion_3/text_encoders"),
("stabilityai/stable-diffusion-3.5-large", "text_encoders/t5xxl_fp16.safetensors", "models/stable_diffusion_3/text_encoders"),
],
}
preset_models_on_modelscope = {
# Hunyuan DiT
"HunyuanDiT": [
("modelscope/HunyuanDiT", "t2i/clip_text_encoder/pytorch_model.bin", "models/HunyuanDiT/t2i/clip_text_encoder"),
("modelscope/HunyuanDiT", "t2i/mt5/pytorch_model.bin", "models/HunyuanDiT/t2i/mt5"),
("modelscope/HunyuanDiT", "t2i/model/pytorch_model_ema.pt", "models/HunyuanDiT/t2i/model"),
("modelscope/HunyuanDiT", "t2i/sdxl-vae-fp16-fix/diffusion_pytorch_model.bin", "models/HunyuanDiT/t2i/sdxl-vae-fp16-fix"),
],
# Stable Video Diffusion
"stable-video-diffusion-img2vid-xt": [
("AI-ModelScope/stable-video-diffusion-img2vid-xt", "svd_xt.safetensors", "models/stable_video_diffusion"),
],
# ExVideo
"ExVideo-SVD-128f-v1": [
("ECNU-CILab/ExVideo-SVD-128f-v1", "model.fp16.safetensors", "models/stable_video_diffusion"),
],
"ExVideo-CogVideoX-LoRA-129f-v1": [
("ECNU-CILab/ExVideo-CogVideoX-LoRA-129f-v1", "ExVideo-CogVideoX-LoRA-129f-v1.safetensors", "models/lora"),
],
# Stable Diffusion
"StableDiffusion_v15": [
("AI-ModelScope/stable-diffusion-v1-5", "v1-5-pruned-emaonly.safetensors", "models/stable_diffusion"),
],
"DreamShaper_8": [
("sd_lora/dreamshaper_8", "dreamshaper_8.safetensors", "models/stable_diffusion"),
],
"AingDiffusion_v12": [
("sd_lora/aingdiffusion_v12", "aingdiffusion_v12.safetensors", "models/stable_diffusion"),
],
"Flat2DAnimerge_v45Sharp": [
("sd_lora/Flat-2D-Animerge", "flat2DAnimerge_v45Sharp.safetensors", "models/stable_diffusion"),
],
# Textual Inversion
"TextualInversion_VeryBadImageNegative_v1.3": [
("sd_lora/verybadimagenegative_v1.3", "verybadimagenegative_v1.3.pt", "models/textual_inversion"),
],
# Stable Diffusion XL
"StableDiffusionXL_v1": [
("AI-ModelScope/stable-diffusion-xl-base-1.0", "sd_xl_base_1.0.safetensors", "models/stable_diffusion_xl"),
],
"BluePencilXL_v200": [
("sd_lora/bluePencilXL_v200", "bluePencilXL_v200.safetensors", "models/stable_diffusion_xl"),
],
"StableDiffusionXL_Turbo": [
("AI-ModelScope/sdxl-turbo", "sd_xl_turbo_1.0_fp16.safetensors", "models/stable_diffusion_xl_turbo"),
],
"SDXL_lora_zyd232_ChineseInkStyle_SDXL_v1_0": [
("sd_lora/zyd232_ChineseInkStyle_SDXL_v1_0", "zyd232_ChineseInkStyle_SDXL_v1_0.safetensors", "models/lora"),
],
# Stable Diffusion 3
"StableDiffusion3": [
("AI-ModelScope/stable-diffusion-3-medium", "sd3_medium_incl_clips_t5xxlfp16.safetensors", "models/stable_diffusion_3"),
],
"StableDiffusion3_without_T5": [
("AI-ModelScope/stable-diffusion-3-medium", "sd3_medium_incl_clips.safetensors", "models/stable_diffusion_3"),
],
# ControlNet
"ControlNet_v11f1p_sd15_depth": [
("AI-ModelScope/ControlNet-v1-1", "control_v11f1p_sd15_depth.pth", "models/ControlNet"),
("sd_lora/Annotators", "dpt_hybrid-midas-501f0c75.pt", "models/Annotators")
],
"ControlNet_v11p_sd15_softedge": [
("AI-ModelScope/ControlNet-v1-1", "control_v11p_sd15_softedge.pth", "models/ControlNet"),
("sd_lora/Annotators", "ControlNetHED.pth", "models/Annotators")
],
"ControlNet_v11f1e_sd15_tile": [
("AI-ModelScope/ControlNet-v1-1", "control_v11f1e_sd15_tile.pth", "models/ControlNet")
],
"ControlNet_v11p_sd15_lineart": [
("AI-ModelScope/ControlNet-v1-1", "control_v11p_sd15_lineart.pth", "models/ControlNet"),
("sd_lora/Annotators", "sk_model.pth", "models/Annotators"),
("sd_lora/Annotators", "sk_model2.pth", "models/Annotators")
],
"ControlNet_union_sdxl_promax": [
("AI-ModelScope/controlnet-union-sdxl-1.0", "diffusion_pytorch_model_promax.safetensors", "models/ControlNet/controlnet_union"),
("sd_lora/Annotators", "dpt_hybrid-midas-501f0c75.pt", "models/Annotators")
],
"Annotators:Depth": [
("sd_lora/Annotators", "dpt_hybrid-midas-501f0c75.pt", "models/Annotators"),
],
"Annotators:Softedge": [
("sd_lora/Annotators", "ControlNetHED.pth", "models/Annotators"),
],
"Annotators:Lineart": [
("sd_lora/Annotators", "sk_model.pth", "models/Annotators"),
("sd_lora/Annotators", "sk_model2.pth", "models/Annotators"),
],
"Annotators:Normal": [
("sd_lora/Annotators", "scannet.pt", "models/Annotators"),
],
"Annotators:Openpose": [
("sd_lora/Annotators", "body_pose_model.pth", "models/Annotators"),
("sd_lora/Annotators", "facenet.pth", "models/Annotators"),
("sd_lora/Annotators", "hand_pose_model.pth", "models/Annotators"),
],
# AnimateDiff
"AnimateDiff_v2": [
("Shanghai_AI_Laboratory/animatediff", "mm_sd_v15_v2.ckpt", "models/AnimateDiff"),
],
"AnimateDiff_xl_beta": [
("Shanghai_AI_Laboratory/animatediff", "mm_sdxl_v10_beta.ckpt", "models/AnimateDiff"),
],
# RIFE
"RIFE": [
("Damo_XR_Lab/cv_rife_video-frame-interpolation", "flownet.pkl", "models/RIFE"),
],
# Qwen Prompt
"QwenPrompt": {
"file_list": [
("qwen/Qwen2-1.5B-Instruct", "config.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
("qwen/Qwen2-1.5B-Instruct", "generation_config.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
("qwen/Qwen2-1.5B-Instruct", "model.safetensors", "models/QwenPrompt/qwen2-1.5b-instruct"),
("qwen/Qwen2-1.5B-Instruct", "special_tokens_map.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
("qwen/Qwen2-1.5B-Instruct", "tokenizer.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
("qwen/Qwen2-1.5B-Instruct", "tokenizer_config.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
("qwen/Qwen2-1.5B-Instruct", "merges.txt", "models/QwenPrompt/qwen2-1.5b-instruct"),
("qwen/Qwen2-1.5B-Instruct", "vocab.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
],
"load_path": [
"models/QwenPrompt/qwen2-1.5b-instruct",
],
},
# Beautiful Prompt
"BeautifulPrompt": {
"file_list": [
("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "config.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "generation_config.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "model.safetensors", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "special_tokens_map.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "tokenizer.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "tokenizer_config.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
],
"load_path": [
"models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd",
],
},
# Omost prompt
"OmostPrompt": {
"file_list": [
("Omost/omost-llama-3-8b-4bits", "model-00001-of-00002.safetensors", "models/OmostPrompt/omost-llama-3-8b-4bits"),
("Omost/omost-llama-3-8b-4bits", "model-00002-of-00002.safetensors", "models/OmostPrompt/omost-llama-3-8b-4bits"),
("Omost/omost-llama-3-8b-4bits", "tokenizer.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
("Omost/omost-llama-3-8b-4bits", "tokenizer_config.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
("Omost/omost-llama-3-8b-4bits", "config.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
("Omost/omost-llama-3-8b-4bits", "generation_config.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
("Omost/omost-llama-3-8b-4bits", "model.safetensors.index.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
("Omost/omost-llama-3-8b-4bits", "special_tokens_map.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
],
"load_path": [
"models/OmostPrompt/omost-llama-3-8b-4bits",
],
},
# Translator
"opus-mt-zh-en": {
"file_list": [
("moxying/opus-mt-zh-en", "config.json", "models/translator/opus-mt-zh-en"),
("moxying/opus-mt-zh-en", "generation_config.json", "models/translator/opus-mt-zh-en"),
("moxying/opus-mt-zh-en", "metadata.json", "models/translator/opus-mt-zh-en"),
("moxying/opus-mt-zh-en", "pytorch_model.bin", "models/translator/opus-mt-zh-en"),
("moxying/opus-mt-zh-en", "source.spm", "models/translator/opus-mt-zh-en"),
("moxying/opus-mt-zh-en", "target.spm", "models/translator/opus-mt-zh-en"),
("moxying/opus-mt-zh-en", "tokenizer_config.json", "models/translator/opus-mt-zh-en"),
("moxying/opus-mt-zh-en", "vocab.json", "models/translator/opus-mt-zh-en"),
],
"load_path": [
"models/translator/opus-mt-zh-en",
],
},
# IP-Adapter
"IP-Adapter-SD": [
("AI-ModelScope/IP-Adapter", "models/image_encoder/model.safetensors", "models/IpAdapter/stable_diffusion/image_encoder"),
("AI-ModelScope/IP-Adapter", "models/ip-adapter_sd15.bin", "models/IpAdapter/stable_diffusion"),
],
"IP-Adapter-SDXL": [
("AI-ModelScope/IP-Adapter", "sdxl_models/image_encoder/model.safetensors", "models/IpAdapter/stable_diffusion_xl/image_encoder"),
("AI-ModelScope/IP-Adapter", "sdxl_models/ip-adapter_sdxl.bin", "models/IpAdapter/stable_diffusion_xl"),
],
# Kolors
"Kolors": {
"file_list": [
("Kwai-Kolors/Kolors", "text_encoder/config.json", "models/kolors/Kolors/text_encoder"),
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model.bin.index.json", "models/kolors/Kolors/text_encoder"),
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00001-of-00007.bin", "models/kolors/Kolors/text_encoder"),
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00002-of-00007.bin", "models/kolors/Kolors/text_encoder"),
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00003-of-00007.bin", "models/kolors/Kolors/text_encoder"),
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00004-of-00007.bin", "models/kolors/Kolors/text_encoder"),
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00005-of-00007.bin", "models/kolors/Kolors/text_encoder"),
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00006-of-00007.bin", "models/kolors/Kolors/text_encoder"),
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00007-of-00007.bin", "models/kolors/Kolors/text_encoder"),
("Kwai-Kolors/Kolors", "unet/diffusion_pytorch_model.safetensors", "models/kolors/Kolors/unet"),
("Kwai-Kolors/Kolors", "vae/diffusion_pytorch_model.safetensors", "models/kolors/Kolors/vae"),
],
"load_path": [
"models/kolors/Kolors/text_encoder",
"models/kolors/Kolors/unet/diffusion_pytorch_model.safetensors",
"models/kolors/Kolors/vae/diffusion_pytorch_model.safetensors",
],
},
"SDXL-vae-fp16-fix": [
("AI-ModelScope/sdxl-vae-fp16-fix", "diffusion_pytorch_model.safetensors", "models/sdxl-vae-fp16-fix")
],
# FLUX
"FLUX.1-dev": {
"file_list": [
("AI-ModelScope/FLUX.1-dev", "text_encoder/model.safetensors", "models/FLUX/FLUX.1-dev/text_encoder"),
("AI-ModelScope/FLUX.1-dev", "text_encoder_2/config.json", "models/FLUX/FLUX.1-dev/text_encoder_2"),
("AI-ModelScope/FLUX.1-dev", "text_encoder_2/model-00001-of-00002.safetensors", "models/FLUX/FLUX.1-dev/text_encoder_2"),
("AI-ModelScope/FLUX.1-dev", "text_encoder_2/model-00002-of-00002.safetensors", "models/FLUX/FLUX.1-dev/text_encoder_2"),
("AI-ModelScope/FLUX.1-dev", "text_encoder_2/model.safetensors.index.json", "models/FLUX/FLUX.1-dev/text_encoder_2"),
("AI-ModelScope/FLUX.1-dev", "ae.safetensors", "models/FLUX/FLUX.1-dev"),
("AI-ModelScope/FLUX.1-dev", "flux1-dev.safetensors", "models/FLUX/FLUX.1-dev"),
],
"load_path": [
"models/FLUX/FLUX.1-dev/text_encoder/model.safetensors",
"models/FLUX/FLUX.1-dev/text_encoder_2",
"models/FLUX/FLUX.1-dev/ae.safetensors",
"models/FLUX/FLUX.1-dev/flux1-dev.safetensors"
],
},
"FLUX.1-schnell": {
"file_list": [
("AI-ModelScope/FLUX.1-dev", "text_encoder/model.safetensors", "models/FLUX/FLUX.1-dev/text_encoder"),
("AI-ModelScope/FLUX.1-dev", "text_encoder_2/config.json", "models/FLUX/FLUX.1-dev/text_encoder_2"),
("AI-ModelScope/FLUX.1-dev", "text_encoder_2/model-00001-of-00002.safetensors", "models/FLUX/FLUX.1-dev/text_encoder_2"),
("AI-ModelScope/FLUX.1-dev", "text_encoder_2/model-00002-of-00002.safetensors", "models/FLUX/FLUX.1-dev/text_encoder_2"),
("AI-ModelScope/FLUX.1-dev", "text_encoder_2/model.safetensors.index.json", "models/FLUX/FLUX.1-dev/text_encoder_2"),
("AI-ModelScope/FLUX.1-dev", "ae.safetensors", "models/FLUX/FLUX.1-dev"),
("AI-ModelScope/FLUX.1-schnell", "flux1-schnell.safetensors", "models/FLUX/FLUX.1-schnell"),
],
"load_path": [
"models/FLUX/FLUX.1-dev/text_encoder/model.safetensors",
"models/FLUX/FLUX.1-dev/text_encoder_2",
"models/FLUX/FLUX.1-dev/ae.safetensors",
"models/FLUX/FLUX.1-schnell/flux1-schnell.safetensors"
],
},
"InstantX/FLUX.1-dev-Controlnet-Union-alpha": [
("InstantX/FLUX.1-dev-Controlnet-Union-alpha", "diffusion_pytorch_model.safetensors", "models/ControlNet/InstantX/FLUX.1-dev-Controlnet-Union-alpha"),
],
"jasperai/Flux.1-dev-Controlnet-Depth": [
("jasperai/Flux.1-dev-Controlnet-Depth", "diffusion_pytorch_model.safetensors", "models/ControlNet/jasperai/Flux.1-dev-Controlnet-Depth"),
],
"jasperai/Flux.1-dev-Controlnet-Surface-Normals": [
("jasperai/Flux.1-dev-Controlnet-Surface-Normals", "diffusion_pytorch_model.safetensors", "models/ControlNet/jasperai/Flux.1-dev-Controlnet-Surface-Normals"),
],
"jasperai/Flux.1-dev-Controlnet-Upscaler": [
("jasperai/Flux.1-dev-Controlnet-Upscaler", "diffusion_pytorch_model.safetensors", "models/ControlNet/jasperai/Flux.1-dev-Controlnet-Upscaler"),
],
"alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Alpha": [
("alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Alpha", "diffusion_pytorch_model.safetensors", "models/ControlNet/alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Alpha"),
],
"alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta": [
("alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta", "diffusion_pytorch_model.safetensors", "models/ControlNet/alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta"),
],
"Shakker-Labs/FLUX.1-dev-ControlNet-Depth": [
("Shakker-Labs/FLUX.1-dev-ControlNet-Depth", "diffusion_pytorch_model.safetensors", "models/ControlNet/Shakker-Labs/FLUX.1-dev-ControlNet-Depth"),
],
"Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro": [
("Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro", "diffusion_pytorch_model.safetensors", "models/ControlNet/Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro"),
],
"InstantX/FLUX.1-dev-IP-Adapter": {
"file_list": [
("InstantX/FLUX.1-dev-IP-Adapter", "ip-adapter.bin", "models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter"),
("AI-ModelScope/siglip-so400m-patch14-384", "model.safetensors", "models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/image_encoder"),
("AI-ModelScope/siglip-so400m-patch14-384", "config.json", "models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/image_encoder"),
],
"load_path": [
"models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/ip-adapter.bin",
"models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/image_encoder",
],
},
# ESRGAN
"ESRGAN_x4": [
("AI-ModelScope/Real-ESRGAN", "RealESRGAN_x4.pth", "models/ESRGAN"),
],
# RIFE
"RIFE": [
("AI-ModelScope/RIFE", "flownet.pkl", "models/RIFE"),
],
# Omnigen
"OmniGen-v1": {
"file_list": [
("BAAI/OmniGen-v1", "vae/diffusion_pytorch_model.safetensors", "models/OmniGen/OmniGen-v1/vae"),
("BAAI/OmniGen-v1", "model.safetensors", "models/OmniGen/OmniGen-v1"),
("BAAI/OmniGen-v1", "config.json", "models/OmniGen/OmniGen-v1"),
("BAAI/OmniGen-v1", "special_tokens_map.json", "models/OmniGen/OmniGen-v1"),
("BAAI/OmniGen-v1", "tokenizer_config.json", "models/OmniGen/OmniGen-v1"),
("BAAI/OmniGen-v1", "tokenizer.json", "models/OmniGen/OmniGen-v1"),
],
"load_path": [
"models/OmniGen/OmniGen-v1/vae/diffusion_pytorch_model.safetensors",
"models/OmniGen/OmniGen-v1/model.safetensors",
]
},
# CogVideo
"CogVideoX-5B": {
"file_list": [
("ZhipuAI/CogVideoX-5b", "text_encoder/config.json", "models/CogVideo/CogVideoX-5b/text_encoder"),
("ZhipuAI/CogVideoX-5b", "text_encoder/model.safetensors.index.json", "models/CogVideo/CogVideoX-5b/text_encoder"),
("ZhipuAI/CogVideoX-5b", "text_encoder/model-00001-of-00002.safetensors", "models/CogVideo/CogVideoX-5b/text_encoder"),
("ZhipuAI/CogVideoX-5b", "text_encoder/model-00002-of-00002.safetensors", "models/CogVideo/CogVideoX-5b/text_encoder"),
("ZhipuAI/CogVideoX-5b", "transformer/config.json", "models/CogVideo/CogVideoX-5b/transformer"),
("ZhipuAI/CogVideoX-5b", "transformer/diffusion_pytorch_model.safetensors.index.json", "models/CogVideo/CogVideoX-5b/transformer"),
("ZhipuAI/CogVideoX-5b", "transformer/diffusion_pytorch_model-00001-of-00002.safetensors", "models/CogVideo/CogVideoX-5b/transformer"),
("ZhipuAI/CogVideoX-5b", "transformer/diffusion_pytorch_model-00002-of-00002.safetensors", "models/CogVideo/CogVideoX-5b/transformer"),
("ZhipuAI/CogVideoX-5b", "vae/diffusion_pytorch_model.safetensors", "models/CogVideo/CogVideoX-5b/vae"),
],
"load_path": [
"models/CogVideo/CogVideoX-5b/text_encoder",
"models/CogVideo/CogVideoX-5b/transformer",
"models/CogVideo/CogVideoX-5b/vae/diffusion_pytorch_model.safetensors",
],
},
# Stable Diffusion 3.5
"StableDiffusion3.5-large": [
("AI-ModelScope/stable-diffusion-3.5-large", "sd3.5_large.safetensors", "models/stable_diffusion_3"),
("AI-ModelScope/stable-diffusion-3.5-large", "text_encoders/clip_l.safetensors", "models/stable_diffusion_3/text_encoders"),
("AI-ModelScope/stable-diffusion-3.5-large", "text_encoders/clip_g.safetensors", "models/stable_diffusion_3/text_encoders"),
("AI-ModelScope/stable-diffusion-3.5-large", "text_encoders/t5xxl_fp16.safetensors", "models/stable_diffusion_3/text_encoders"),
],
"StableDiffusion3.5-medium": [
("AI-ModelScope/stable-diffusion-3.5-medium", "sd3.5_medium.safetensors", "models/stable_diffusion_3"),
("AI-ModelScope/stable-diffusion-3.5-large", "text_encoders/clip_l.safetensors", "models/stable_diffusion_3/text_encoders"),
("AI-ModelScope/stable-diffusion-3.5-large", "text_encoders/clip_g.safetensors", "models/stable_diffusion_3/text_encoders"),
("AI-ModelScope/stable-diffusion-3.5-large", "text_encoders/t5xxl_fp16.safetensors", "models/stable_diffusion_3/text_encoders"),
],
"StableDiffusion3.5-large-turbo": [
("AI-ModelScope/stable-diffusion-3.5-large-turbo", "sd3.5_large_turbo.safetensors", "models/stable_diffusion_3"),
("AI-ModelScope/stable-diffusion-3.5-large", "text_encoders/clip_l.safetensors", "models/stable_diffusion_3/text_encoders"),
("AI-ModelScope/stable-diffusion-3.5-large", "text_encoders/clip_g.safetensors", "models/stable_diffusion_3/text_encoders"),
("AI-ModelScope/stable-diffusion-3.5-large", "text_encoders/t5xxl_fp16.safetensors", "models/stable_diffusion_3/text_encoders"),
],
"HunyuanVideo":{
"file_list": [
("AI-ModelScope/clip-vit-large-patch14", "model.safetensors", "models/HunyuanVideo/text_encoder"),
("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "model-00001-of-00004.safetensors", "models/HunyuanVideo/text_encoder_2"),
("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "model-00002-of-00004.safetensors", "models/HunyuanVideo/text_encoder_2"),
("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "model-00003-of-00004.safetensors", "models/HunyuanVideo/text_encoder_2"),
("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "model-00004-of-00004.safetensors", "models/HunyuanVideo/text_encoder_2"),
("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "config.json", "models/HunyuanVideo/text_encoder_2"),
("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "model.safetensors.index.json", "models/HunyuanVideo/text_encoder_2"),
("AI-ModelScope/HunyuanVideo", "hunyuan-video-t2v-720p/vae/pytorch_model.pt", "models/HunyuanVideo/vae"),
("AI-ModelScope/HunyuanVideo", "hunyuan-video-t2v-720p/transformers/mp_rank_00_model_states.pt", "models/HunyuanVideo/transformers")
],
"load_path": [
"models/HunyuanVideo/text_encoder/model.safetensors",
"models/HunyuanVideo/text_encoder_2",
"models/HunyuanVideo/vae/pytorch_model.pt",
"models/HunyuanVideo/transformers/mp_rank_00_model_states.pt"
],
},
"HunyuanVideo-fp8":{
"file_list": [
("AI-ModelScope/clip-vit-large-patch14", "model.safetensors", "models/HunyuanVideo/text_encoder"),
("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "model-00001-of-00004.safetensors", "models/HunyuanVideo/text_encoder_2"),
("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "model-00002-of-00004.safetensors", "models/HunyuanVideo/text_encoder_2"),
("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "model-00003-of-00004.safetensors", "models/HunyuanVideo/text_encoder_2"),
("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "model-00004-of-00004.safetensors", "models/HunyuanVideo/text_encoder_2"),
("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "config.json", "models/HunyuanVideo/text_encoder_2"),
("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "model.safetensors.index.json", "models/HunyuanVideo/text_encoder_2"),
("AI-ModelScope/HunyuanVideo", "hunyuan-video-t2v-720p/vae/pytorch_model.pt", "models/HunyuanVideo/vae"),
("DiffSynth-Studio/HunyuanVideo-safetensors", "model.fp8.safetensors", "models/HunyuanVideo/transformers")
],
"load_path": [
"models/HunyuanVideo/text_encoder/model.safetensors",
"models/HunyuanVideo/text_encoder_2",
"models/HunyuanVideo/vae/pytorch_model.pt",
"models/HunyuanVideo/transformers/model.fp8.safetensors"
],
},
}
Preset_model_id: TypeAlias = Literal[
"HunyuanDiT",
"stable-video-diffusion-img2vid-xt",
"ExVideo-SVD-128f-v1",
"ExVideo-CogVideoX-LoRA-129f-v1",
"StableDiffusion_v15",
"DreamShaper_8",
"AingDiffusion_v12",
"Flat2DAnimerge_v45Sharp",
"TextualInversion_VeryBadImageNegative_v1.3",
"StableDiffusionXL_v1",
"BluePencilXL_v200",
"StableDiffusionXL_Turbo",
"ControlNet_v11f1p_sd15_depth",
"ControlNet_v11p_sd15_softedge",
"ControlNet_v11f1e_sd15_tile",
"ControlNet_v11p_sd15_lineart",
"AnimateDiff_v2",
"AnimateDiff_xl_beta",
"RIFE",
"BeautifulPrompt",
"opus-mt-zh-en",
"IP-Adapter-SD",
"IP-Adapter-SDXL",
"StableDiffusion3",
"StableDiffusion3_without_T5",
"Kolors",
"SDXL-vae-fp16-fix",
"ControlNet_union_sdxl_promax",
"FLUX.1-dev",
"FLUX.1-schnell",
"InstantX/FLUX.1-dev-Controlnet-Union-alpha",
"jasperai/Flux.1-dev-Controlnet-Depth",
"jasperai/Flux.1-dev-Controlnet-Surface-Normals",
"jasperai/Flux.1-dev-Controlnet-Upscaler",
"alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Alpha",
"alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta",
"Shakker-Labs/FLUX.1-dev-ControlNet-Depth",
"Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro",
"InstantX/FLUX.1-dev-IP-Adapter",
"SDXL_lora_zyd232_ChineseInkStyle_SDXL_v1_0",
"QwenPrompt",
"OmostPrompt",
"ESRGAN_x4",
"RIFE",
"OmniGen-v1",
"CogVideoX-5B",
"Annotators:Depth",
"Annotators:Softedge",
"Annotators:Lineart",
"Annotators:Normal",
"Annotators:Openpose",
"StableDiffusion3.5-large",
"StableDiffusion3.5-medium",
"HunyuanVideo",
"HunyuanVideo-fp8",
]

View File

@@ -1,738 +0,0 @@
qwen_image_series = [
{
# Example: ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors")
"model_hash": "0319a1cb19835fb510907dd3367c95ff",
"model_name": "qwen_image_dit",
"model_class": "diffsynth.models.qwen_image_dit.QwenImageDiT",
},
{
# Example: ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors")
"model_hash": "8004730443f55db63092006dd9f7110e",
"model_name": "qwen_image_text_encoder",
"model_class": "diffsynth.models.qwen_image_text_encoder.QwenImageTextEncoder",
"state_dict_converter": "diffsynth.utils.state_dict_converters.qwen_image_text_encoder.QwenImageTextEncoderStateDictConverter",
},
{
# Example: ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors")
"model_hash": "ed4ea5824d55ec3107b09815e318123a",
"model_name": "qwen_image_vae",
"model_class": "diffsynth.models.qwen_image_vae.QwenImageVAE",
},
{
# Example: ModelConfig(model_id="DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth", origin_file_pattern="model.safetensors")
"model_hash": "073bce9cf969e317e5662cd570c3e79c",
"model_name": "qwen_image_blockwise_controlnet",
"model_class": "diffsynth.models.qwen_image_controlnet.QwenImageBlockWiseControlNet",
},
{
# Example: ModelConfig(model_id="DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint", origin_file_pattern="model.safetensors")
"model_hash": "a9e54e480a628f0b956a688a81c33bab",
"model_name": "qwen_image_blockwise_controlnet",
"model_class": "diffsynth.models.qwen_image_controlnet.QwenImageBlockWiseControlNet",
"extra_kwargs": {"additional_in_dim": 4},
},
{
# Example: ModelConfig(model_id="DiffSynth-Studio/General-Image-Encoders", origin_file_pattern="SigLIP2-G384/model.safetensors")
"model_hash": "469c78b61e3e31bc9eec0d0af3d3f2f8",
"model_name": "siglip2_image_encoder",
"model_class": "diffsynth.models.siglip2_image_encoder.Siglip2ImageEncoder",
},
{
# Example: ModelConfig(model_id="DiffSynth-Studio/General-Image-Encoders", origin_file_pattern="DINOv3-7B/model.safetensors")
"model_hash": "5722b5c873720009de96422993b15682",
"model_name": "dinov3_image_encoder",
"model_class": "diffsynth.models.dinov3_image_encoder.DINOv3ImageEncoder",
},
{
# Example:
"model_hash": "a166c33455cdbd89c0888a3645ca5c0f",
"model_name": "qwen_image_image2lora_coarse",
"model_class": "diffsynth.models.qwen_image_image2lora.QwenImageImage2LoRAModel",
},
{
# Example:
"model_hash": "a5476e691767a4da6d3a6634a10f7408",
"model_name": "qwen_image_image2lora_fine",
"model_class": "diffsynth.models.qwen_image_image2lora.QwenImageImage2LoRAModel",
"extra_kwargs": {"residual_length": 37*37+7, "residual_mid_dim": 64}
},
{
# Example:
"model_hash": "0aad514690602ecaff932c701cb4b0bb",
"model_name": "qwen_image_image2lora_style",
"model_class": "diffsynth.models.qwen_image_image2lora.QwenImageImage2LoRAModel",
"extra_kwargs": {"compress_dim": 64, "use_residual": False}
},
{
# Example: ModelConfig(model_id="Qwen/Qwen-Image-Layered", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors")
"model_hash": "8dc8cda05de16c73afa755e2c1ce2839",
"model_name": "qwen_image_dit",
"model_class": "diffsynth.models.qwen_image_dit.QwenImageDiT",
"extra_kwargs": {"use_layer3d_rope": True, "use_additional_t_cond": True}
},
{
# Example: ModelConfig(model_id="Qwen/Qwen-Image-Layered", origin_file_pattern="vae/diffusion_pytorch_model.safetensors")
"model_hash": "44b39ddc499e027cfb24f7878d7416b9",
"model_name": "qwen_image_vae",
"model_class": "diffsynth.models.qwen_image_vae.QwenImageVAE",
"extra_kwargs": {"image_channels": 4}
},
]
wan_series = [
{
# Example: ModelConfig(model_id="krea/krea-realtime-video", origin_file_pattern="krea-realtime-video-14b.safetensors")
"model_hash": "5ec04e02b42d2580483ad69f4e76346a",
"model_name": "wan_video_dit",
"model_class": "diffsynth.models.wan_video_dit.WanModel",
"extra_kwargs": {'has_image_input': False, 'patch_size': [1, 2, 2], 'in_dim': 16, 'dim': 5120, 'ffn_dim': 13824, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 40, 'num_layers': 40, 'eps': 1e-06},
"state_dict_converter": "diffsynth.utils.state_dict_converters.wan_video_dit.WanVideoDiTStateDictConverter",
},
{
# Example: ModelConfig(model_id="Wan-AI/Wan2.1-T2V-14B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth")
"model_hash": "9c8818c2cbea55eca56c7b447df170da",
"model_name": "wan_video_text_encoder",
"model_class": "diffsynth.models.wan_video_text_encoder.WanTextEncoder",
},
{
# Example: ModelConfig(model_id="Wan-AI/Wan2.1-T2V-14B", origin_file_pattern="Wan2.1_VAE.pth")
"model_hash": "ccc42284ea13e1ad04693284c7a09be6",
"model_name": "wan_video_vae",
"model_class": "diffsynth.models.wan_video_vae.WanVideoVAE",
"state_dict_converter": "diffsynth.utils.state_dict_converters.wan_video_vae.WanVideoVAEStateDictConverter",
},
{
# Example: ModelConfig(model_id="meituan-longcat/LongCat-Video", origin_file_pattern="dit/diffusion_pytorch_model*.safetensors")
"model_hash": "8b27900f680d7251ce44e2dc8ae1ffef",
"model_name": "wan_video_dit",
"model_class": "diffsynth.models.longcat_video_dit.LongCatVideoTransformer3DModel",
},
{
# Example: ModelConfig(model_id="ByteDance/Video-As-Prompt-Wan2.1-14B", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors")
"model_hash": "5f90e66a0672219f12d9a626c8c21f61",
"model_name": "wan_video_dit",
"model_class": "diffsynth.models.wan_video_dit.WanModel",
"extra_kwargs": {'has_image_input': True, 'patch_size': [1, 2, 2], 'in_dim': 36, 'dim': 5120, 'ffn_dim': 13824, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 40, 'num_layers': 40, 'eps': 1e-06},
"state_dict_converter": "diffsynth.utils.state_dict_converters.wan_video_dit.WanVideoDiTFromDiffusers"
},
{
# Example: ModelConfig(model_id="ByteDance/Video-As-Prompt-Wan2.1-14B", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors")
"model_hash": "5f90e66a0672219f12d9a626c8c21f61",
"model_name": "wan_video_vap",
"model_class": "diffsynth.models.wan_video_mot.MotWanModel",
"state_dict_converter": "diffsynth.utils.state_dict_converters.wan_video_mot.WanVideoMotStateDictConverter"
},
{
# Example: ModelConfig(model_id="Wan-AI/Wan2.1-I2V-14B-480P", origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth")
"model_hash": "5941c53e207d62f20f9025686193c40b",
"model_name": "wan_video_image_encoder",
"model_class": "diffsynth.models.wan_video_image_encoder.WanImageEncoder",
"state_dict_converter": "diffsynth.utils.state_dict_converters.wan_video_image_encoder.WanImageEncoderStateDictConverter"
},
{
# Example: ModelConfig(model_id="DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1", origin_file_pattern="model.safetensors")
"model_hash": "dbd5ec76bbf977983f972c151d545389",
"model_name": "wan_video_motion_controller",
"model_class": "diffsynth.models.wan_video_motion_controller.WanMotionControllerModel",
},
{
# Example: ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="diffusion_pytorch_model*.safetensors")
"model_hash": "9269f8db9040a9d860eaca435be61814",
"model_name": "wan_video_dit",
"model_class": "diffsynth.models.wan_video_dit.WanModel",
"extra_kwargs": {'has_image_input': False, 'patch_size': [1, 2, 2], 'in_dim': 16, 'dim': 1536, 'ffn_dim': 8960, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 12, 'num_layers': 30, 'eps': 1e-06}
},
{
# Example: ModelConfig(model_id="Wan-AI/Wan2.1-FLF2V-14B-720P", origin_file_pattern="diffusion_pytorch_model*.safetensors")
"model_hash": "3ef3b1f8e1dab83d5b71fd7b617f859f",
"model_name": "wan_video_dit",
"model_class": "diffsynth.models.wan_video_dit.WanModel",
"extra_kwargs": {'has_image_input': True, 'patch_size': [1, 2, 2], 'in_dim': 36, 'dim': 5120, 'ffn_dim': 13824, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 40, 'num_layers': 40, 'eps': 1e-06, 'has_image_pos_emb': True}
},
{
# Example: ModelConfig(model_id="PAI/Wan2.1-Fun-1.3B-Control", origin_file_pattern="diffusion_pytorch_model*.safetensors")
"model_hash": "349723183fc063b2bfc10bb2835cf677",
"model_name": "wan_video_dit",
"model_class": "diffsynth.models.wan_video_dit.WanModel",
"extra_kwargs": {'has_image_input': True, 'patch_size': [1, 2, 2], 'in_dim': 48, 'dim': 1536, 'ffn_dim': 8960, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 12, 'num_layers': 30, 'eps': 1e-06}
},
{
# Example: ModelConfig(model_id="PAI/Wan2.1-Fun-1.3B-InP", origin_file_pattern="diffusion_pytorch_model*.safetensors")
"model_hash": "6d6ccde6845b95ad9114ab993d917893",
"model_name": "wan_video_dit",
"model_class": "diffsynth.models.wan_video_dit.WanModel",
"extra_kwargs": {'has_image_input': True, 'patch_size': [1, 2, 2], 'in_dim': 36, 'dim': 1536, 'ffn_dim': 8960, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 12, 'num_layers': 30, 'eps': 1e-06}
},
{
# Example: ModelConfig(model_id="PAI/Wan2.1-Fun-14B-Control", origin_file_pattern="diffusion_pytorch_model*.safetensors")
"model_hash": "efa44cddf936c70abd0ea28b6cbe946c",
"model_name": "wan_video_dit",
"model_class": "diffsynth.models.wan_video_dit.WanModel",
"extra_kwargs": {'has_image_input': True, 'patch_size': [1, 2, 2], 'in_dim': 48, 'dim': 5120, 'ffn_dim': 13824, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 40, 'num_layers': 40, 'eps': 1e-06}
},
{
# Example: ModelConfig(model_id="PAI/Wan2.1-Fun-14B-InP", origin_file_pattern="diffusion_pytorch_model*.safetensors")
"model_hash": "6bfcfb3b342cb286ce886889d519a77e",
"model_name": "wan_video_dit",
"model_class": "diffsynth.models.wan_video_dit.WanModel",
"extra_kwargs": {'has_image_input': True, 'patch_size': [1, 2, 2], 'in_dim': 36, 'dim': 5120, 'ffn_dim': 13824, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 40, 'num_layers': 40, 'eps': 1e-06}
},
{
# Example: ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera", origin_file_pattern="diffusion_pytorch_model*.safetensors")
"model_hash": "ac6a5aa74f4a0aab6f64eb9a72f19901",
"model_name": "wan_video_dit",
"model_class": "diffsynth.models.wan_video_dit.WanModel",
"extra_kwargs": {'has_image_input': True, 'patch_size': [1, 2, 2], 'in_dim': 32, 'dim': 1536, 'ffn_dim': 8960, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 12, 'num_layers': 30, 'eps': 1e-06, 'has_ref_conv': False, 'add_control_adapter': True, 'in_dim_control_adapter': 24}
},
{
# Example: ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-1.3B-Control", origin_file_pattern="diffusion_pytorch_model*.safetensors")
"model_hash": "70ddad9d3a133785da5ea371aae09504",
"model_name": "wan_video_dit",
"model_class": "diffsynth.models.wan_video_dit.WanModel",
"extra_kwargs": {'has_image_input': True, 'patch_size': [1, 2, 2], 'in_dim': 48, 'dim': 1536, 'ffn_dim': 8960, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 12, 'num_layers': 30, 'eps': 1e-06, 'has_ref_conv': True}
},
{
# Example: ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-14B-Control-Camera", origin_file_pattern="diffusion_pytorch_model*.safetensors")
"model_hash": "b61c605c2adbd23124d152ed28e049ae",
"model_name": "wan_video_dit",
"model_class": "diffsynth.models.wan_video_dit.WanModel",
"extra_kwargs": {'has_image_input': True, 'patch_size': [1, 2, 2], 'in_dim': 32, 'dim': 5120, 'ffn_dim': 13824, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 40, 'num_layers': 40, 'eps': 1e-06, 'has_ref_conv': False, 'add_control_adapter': True, 'in_dim_control_adapter': 24}
},
{
# Example: ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-14B-Control", origin_file_pattern="diffusion_pytorch_model*.safetensors")
"model_hash": "26bde73488a92e64cc20b0a7485b9e5b",
"model_name": "wan_video_dit",
"model_class": "diffsynth.models.wan_video_dit.WanModel",
"extra_kwargs": {'has_image_input': True, 'patch_size': [1, 2, 2], 'in_dim': 48, 'dim': 5120, 'ffn_dim': 13824, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 40, 'num_layers': 40, 'eps': 1e-06, 'has_ref_conv': True}
},
{
# Example: ModelConfig(model_id="Wan-AI/Wan2.1-T2V-14B", origin_file_pattern="diffusion_pytorch_model*.safetensors")
"model_hash": "aafcfd9672c3a2456dc46e1cb6e52c70",
"model_name": "wan_video_dit",
"model_class": "diffsynth.models.wan_video_dit.WanModel",
"extra_kwargs": {'has_image_input': False, 'patch_size': [1, 2, 2], 'in_dim': 16, 'dim': 5120, 'ffn_dim': 13824, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 40, 'num_layers': 40, 'eps': 1e-06}
},
{
# Example: ModelConfig(model_id="iic/VACE-Wan2.1-1.3B-Preview", origin_file_pattern="diffusion_pytorch_model*.safetensors")
"model_hash": "a61453409b67cd3246cf0c3bebad47ba",
"model_name": "wan_video_dit",
"model_class": "diffsynth.models.wan_video_dit.WanModel",
"extra_kwargs": {'has_image_input': False, 'patch_size': [1, 2, 2], 'in_dim': 16, 'dim': 1536, 'ffn_dim': 8960, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 12, 'num_layers': 30, 'eps': 1e-06},
"state_dict_converter": "diffsynth.utils.state_dict_converters.wan_video_dit.WanVideoDiTStateDictConverter",
},
{
# Example: ModelConfig(model_id="iic/VACE-Wan2.1-1.3B-Preview", origin_file_pattern="diffusion_pytorch_model*.safetensors")
"model_hash": "a61453409b67cd3246cf0c3bebad47ba",
"model_name": "wan_video_vace",
"model_class": "diffsynth.models.wan_video_vace.VaceWanModel",
"state_dict_converter": "diffsynth.utils.state_dict_converters.wan_video_vace.VaceWanModelDictConverter"
},
{
# Example: ModelConfig(model_id="Wan-AI/Wan2.1-VACE-14B", origin_file_pattern="diffusion_pytorch_model*.safetensors")
"model_hash": "7a513e1f257a861512b1afd387a8ecd9",
"model_name": "wan_video_dit",
"model_class": "diffsynth.models.wan_video_dit.WanModel",
"extra_kwargs": {'has_image_input': False, 'patch_size': [1, 2, 2], 'in_dim': 16, 'dim': 5120, 'ffn_dim': 13824, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 40, 'num_layers': 40, 'eps': 1e-06},
"state_dict_converter": "diffsynth.utils.state_dict_converters.wan_video_dit.WanVideoDiTStateDictConverter",
},
{
# Example: ModelConfig(model_id="Wan-AI/Wan2.1-VACE-14B", origin_file_pattern="diffusion_pytorch_model*.safetensors")
"model_hash": "7a513e1f257a861512b1afd387a8ecd9",
"model_name": "wan_video_vace",
"model_class": "diffsynth.models.wan_video_vace.VaceWanModel",
"extra_kwargs": {'vace_layers': (0, 5, 10, 15, 20, 25, 30, 35), 'vace_in_dim': 96, 'patch_size': (1, 2, 2), 'has_image_input': False, 'dim': 5120, 'num_heads': 40, 'ffn_dim': 13824, 'eps': 1e-06},
"state_dict_converter": "diffsynth.utils.state_dict_converters.wan_video_vace.VaceWanModelDictConverter"
},
{
# Example: ModelConfig(model_id="Wan-AI/Wan2.2-Animate-14B", origin_file_pattern="diffusion_pytorch_model*.safetensors")
"model_hash": "31fa352acb8a1b1d33cd8764273d80a2",
"model_name": "wan_video_dit",
"model_class": "diffsynth.models.wan_video_dit.WanModel",
"extra_kwargs": {'has_image_input': True, 'patch_size': [1, 2, 2], 'in_dim': 36, 'dim': 5120, 'ffn_dim': 13824, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 40, 'num_layers': 40, 'eps': 1e-06},
"state_dict_converter": "diffsynth.utils.state_dict_converters.wan_video_dit.WanVideoDiTStateDictConverter"
},
{
# Example: ModelConfig(model_id="Wan-AI/Wan2.2-Animate-14B", origin_file_pattern="diffusion_pytorch_model*.safetensors")
"model_hash": "31fa352acb8a1b1d33cd8764273d80a2",
"model_name": "wan_video_animate_adapter",
"model_class": "diffsynth.models.wan_video_animate_adapter.WanAnimateAdapter",
"state_dict_converter": "diffsynth.utils.state_dict_converters.wan_video_animate_adapter.WanAnimateAdapterStateDictConverter"
},
{
# Example: ModelConfig(model_id="PAI/Wan2.2-Fun-A14B-Control-Camera", origin_file_pattern="high_noise_model/diffusion_pytorch_model*.safetensors")
"model_hash": "47dbeab5e560db3180adf51dc0232fb1",
"model_name": "wan_video_dit",
"model_class": "diffsynth.models.wan_video_dit.WanModel",
"extra_kwargs": {'has_image_input': False, 'patch_size': [1, 2, 2], 'in_dim': 36, 'dim': 5120, 'ffn_dim': 13824, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 40, 'num_layers': 40, 'eps': 1e-06, 'has_ref_conv': False, 'add_control_adapter': True, 'in_dim_control_adapter': 24, 'require_clip_embedding': False}
},
{
# Example: ModelConfig(model_id="PAI/Wan2.2-Fun-A14B-Control", origin_file_pattern="high_noise_model/diffusion_pytorch_model*.safetensors")
"model_hash": "2267d489f0ceb9f21836532952852ee5",
"model_name": "wan_video_dit",
"model_class": "diffsynth.models.wan_video_dit.WanModel",
"extra_kwargs": {'has_image_input': False, 'patch_size': [1, 2, 2], 'in_dim': 52, 'dim': 5120, 'ffn_dim': 13824, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 40, 'num_layers': 40, 'eps': 1e-06, 'has_ref_conv': True, 'require_clip_embedding': False},
},
{
# Example: ModelConfig(model_id="Wan-AI/Wan2.2-I2V-A14B", origin_file_pattern="high_noise_model/diffusion_pytorch_model*.safetensors")
"model_hash": "5b013604280dd715f8457c6ed6d6a626",
"model_name": "wan_video_dit",
"model_class": "diffsynth.models.wan_video_dit.WanModel",
"extra_kwargs": {'has_image_input': False, 'patch_size': [1, 2, 2], 'in_dim': 36, 'dim': 5120, 'ffn_dim': 13824, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 40, 'num_layers': 40, 'eps': 1e-06, 'require_clip_embedding': False}
},
{
# Example: ModelConfig(model_id="Wan-AI/Wan2.2-S2V-14B", origin_file_pattern="diffusion_pytorch_model*.safetensors")
"model_hash": "966cffdcc52f9c46c391768b27637614",
"model_name": "wan_video_dit",
"model_class": "diffsynth.models.wan_video_dit_s2v.WanS2VModel",
"extra_kwargs": {'dim': 5120, 'in_dim': 16, 'ffn_dim': 13824, 'out_dim': 16, 'text_dim': 4096, 'freq_dim': 256, 'eps': 1e-06, 'patch_size': (1, 2, 2), 'num_heads': 40, 'num_layers': 40, 'cond_dim': 16, 'audio_dim': 1024, 'num_audio_token': 4}
},
{
# Example: ModelConfig(model_id="Wan-AI/Wan2.2-TI2V-5B", origin_file_pattern="diffusion_pytorch_model*.safetensors")
"model_hash": "1f5ab7703c6fc803fdded85ff040c316",
"model_name": "wan_video_dit",
"model_class": "diffsynth.models.wan_video_dit.WanModel",
"extra_kwargs": {'has_image_input': False, 'patch_size': [1, 2, 2], 'in_dim': 48, 'dim': 3072, 'ffn_dim': 14336, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 48, 'num_heads': 24, 'num_layers': 30, 'eps': 1e-06, 'seperated_timestep': True, 'require_clip_embedding': False, 'require_vae_embedding': False, 'fuse_vae_embedding_in_latents': True}
},
{
# Example: ModelConfig(model_id="Wan-AI/Wan2.2-TI2V-5B", origin_file_pattern="Wan2.2_VAE.pth")
"model_hash": "e1de6c02cdac79f8b739f4d3698cd216",
"model_name": "wan_video_vae",
"model_class": "diffsynth.models.wan_video_vae.WanVideoVAE38",
"state_dict_converter": "diffsynth.utils.state_dict_converters.wan_video_vae.WanVideoVAEStateDictConverter",
},
{
# Example: ModelConfig(model_id="Wan-AI/Wan2.2-S2V-14B", origin_file_pattern="wav2vec2-large-xlsr-53-english/model.safetensors")
"model_hash": "06be60f3a4526586d8431cd038a71486",
"model_name": "wans2v_audio_encoder",
"model_class": "diffsynth.models.wav2vec.WanS2VAudioEncoder",
"state_dict_converter": "diffsynth.utils.state_dict_converters.wans2v_audio_encoder.WanS2VAudioEncoderStateDictConverter",
},
]
flux_series = [
{
# Example: ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors")
"model_hash": "a29710fea6dddb0314663ee823598e50",
"model_name": "flux_dit",
"model_class": "diffsynth.models.flux_dit.FluxDiT",
"state_dict_converter": "diffsynth.utils.state_dict_converters.flux_dit.FluxDiTStateDictConverter",
},
{
# Supported due to historical reasons.
"model_hash": "605c56eab23e9e2af863ad8f0813a25d",
"model_name": "flux_dit",
"model_class": "diffsynth.models.flux_dit.FluxDiT",
"state_dict_converter": "diffsynth.utils.state_dict_converters.flux_dit.FluxDiTStateDictConverterFromDiffusers",
},
{
# Example: ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors")
"model_hash": "94eefa3dac9cec93cb1ebaf1747d7b78",
"model_name": "flux_text_encoder_clip",
"model_class": "diffsynth.models.flux_text_encoder_clip.FluxTextEncoderClip",
"state_dict_converter": "diffsynth.utils.state_dict_converters.flux_text_encoder_clip.FluxTextEncoderClipStateDictConverter",
},
{
# Example: ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/*.safetensors")
"model_hash": "22540b49eaedbc2f2784b2091a234c7c",
"model_name": "flux_text_encoder_t5",
"model_class": "diffsynth.models.flux_text_encoder_t5.FluxTextEncoderT5",
"state_dict_converter": "diffsynth.utils.state_dict_converters.flux_text_encoder_t5.FluxTextEncoderT5StateDictConverter",
},
{
# Example: ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors")
"model_hash": "21ea55f476dfc4fd135587abb59dfe5d",
"model_name": "flux_vae_encoder",
"model_class": "diffsynth.models.flux_vae.FluxVAEEncoder",
"state_dict_converter": "diffsynth.utils.state_dict_converters.flux_vae.FluxVAEEncoderStateDictConverter",
},
{
# Example: ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors")
"model_hash": "21ea55f476dfc4fd135587abb59dfe5d",
"model_name": "flux_vae_decoder",
"model_class": "diffsynth.models.flux_vae.FluxVAEDecoder",
"state_dict_converter": "diffsynth.utils.state_dict_converters.flux_vae.FluxVAEDecoderStateDictConverter",
},
{
# Example: ModelConfig(model_id="ostris/Flex.2-preview", origin_file_pattern="Flex.2-preview.safetensors")
"model_hash": "d02f41c13549fa5093d3521f62a5570a",
"model_name": "flux_dit",
"model_class": "diffsynth.models.flux_dit.FluxDiT",
"extra_kwargs": {'input_dim': 196, 'num_blocks': 8},
"state_dict_converter": "diffsynth.utils.state_dict_converters.flux_dit.FluxDiTStateDictConverter",
},
{
# Example: ModelConfig(model_id="DiffSynth-Studio/AttriCtrl-FLUX.1-Dev", origin_file_pattern="models/brightness.safetensors")
"model_hash": "0629116fce1472503a66992f96f3eb1a",
"model_name": "flux_value_controller",
"model_class": "diffsynth.models.flux_value_control.SingleValueEncoder",
},
{
# Example: ModelConfig(model_id="alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta", origin_file_pattern="diffusion_pytorch_model.safetensors")
"model_hash": "52357cb26250681367488a8954c271e8",
"model_name": "flux_controlnet",
"model_class": "diffsynth.models.flux_controlnet.FluxControlNet",
"state_dict_converter": "diffsynth.utils.state_dict_converters.flux_controlnet.FluxControlNetStateDictConverter",
"extra_kwargs": {"num_joint_blocks": 6, "num_single_blocks": 0, "additional_input_dim": 4},
},
{
# Example: ModelConfig(model_id="InstantX/FLUX.1-dev-Controlnet-Union-alpha", origin_file_pattern="diffusion_pytorch_model.safetensors")
"model_hash": "78d18b9101345ff695f312e7e62538c0",
"model_name": "flux_controlnet",
"model_class": "diffsynth.models.flux_controlnet.FluxControlNet",
"state_dict_converter": "diffsynth.utils.state_dict_converters.flux_controlnet.FluxControlNetStateDictConverter",
"extra_kwargs": {"num_mode": 10, "mode_dict": {"canny": 0, "tile": 1, "depth": 2, "blur": 3, "pose": 4, "gray": 5, "lq": 6}},
},
{
# Example: ModelConfig(model_id="jasperai/Flux.1-dev-Controlnet-Upscaler", origin_file_pattern="diffusion_pytorch_model.safetensors")
"model_hash": "b001c89139b5f053c715fe772362dd2a",
"model_name": "flux_controlnet",
"model_class": "diffsynth.models.flux_controlnet.FluxControlNet",
"state_dict_converter": "diffsynth.utils.state_dict_converters.flux_controlnet.FluxControlNetStateDictConverter",
"extra_kwargs": {"num_single_blocks": 0},
},
{
# Example: ModelConfig(model_id="ByteDance/InfiniteYou", origin_file_pattern="infu_flux_v1.0/aes_stage2/image_proj_model.bin")
"model_hash": "c07c0f04f5ff55e86b4e937c7a40d481",
"model_name": "infiniteyou_image_projector",
"model_class": "diffsynth.models.flux_infiniteyou.InfiniteYouImageProjector",
"state_dict_converter": "diffsynth.utils.state_dict_converters.flux_infiniteyou.FluxInfiniteYouImageProjectorStateDictConverter",
},
{
# Example: ModelConfig(model_id="ByteDance/InfiniteYou", origin_file_pattern="infu_flux_v1.0/aes_stage2/InfuseNetModel/*.safetensors")
"model_hash": "7f9583eb8ba86642abb9a21a4b2c9e16",
"model_name": "flux_controlnet",
"model_class": "diffsynth.models.flux_controlnet.FluxControlNet",
"state_dict_converter": "diffsynth.utils.state_dict_converters.flux_controlnet.FluxControlNetStateDictConverter",
"extra_kwargs": {"num_joint_blocks": 4, "num_single_blocks": 10},
},
{
# Example: ModelConfig(model_id="DiffSynth-Studio/LoRA-Encoder-FLUX.1-Dev", origin_file_pattern="model.safetensors")
"model_hash": "77c2e4dd2440269eb33bfaa0d004f6ab",
"model_name": "flux_lora_encoder",
"model_class": "diffsynth.models.flux_lora_encoder.FluxLoRAEncoder",
},
{
# Example: ModelConfig(model_id="DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev", origin_file_pattern="model.safetensors")
"model_hash": "30143afb2dea73d1ac580e0787628f8c",
"model_name": "flux_lora_patcher",
"model_class": "diffsynth.models.flux_lora_patcher.FluxLoraPatcher",
},
{
# Example: ModelConfig(model_id="DiffSynth-Studio/Nexus-GenV2", origin_file_pattern="model*.safetensors")
"model_hash": "2bd19e845116e4f875a0a048e27fc219",
"model_name": "nexus_gen_llm",
"model_class": "diffsynth.models.nexus_gen.NexusGenAutoregressiveModel",
"state_dict_converter": "diffsynth.utils.state_dict_converters.nexus_gen.NexusGenAutoregressiveModelStateDictConverter",
},
{
# Example: ModelConfig(model_id="DiffSynth-Studio/Nexus-GenV2", origin_file_pattern="edit_decoder.bin")
"model_hash": "63c969fd37cce769a90aa781fbff5f81",
"model_name": "nexus_gen_editing_adapter",
"model_class": "diffsynth.models.nexus_gen_projector.NexusGenImageEmbeddingMerger",
"state_dict_converter": "diffsynth.utils.state_dict_converters.nexus_gen_projector.NexusGenMergerStateDictConverter",
},
{
# Example: ModelConfig(model_id="DiffSynth-Studio/Nexus-GenV2", origin_file_pattern="edit_decoder.bin")
"model_hash": "63c969fd37cce769a90aa781fbff5f81",
"model_name": "flux_dit",
"model_class": "diffsynth.models.flux_dit.FluxDiT",
"state_dict_converter": "diffsynth.utils.state_dict_converters.flux_dit.FluxDiTStateDictConverter",
},
{
# Example: ModelConfig(model_id="DiffSynth-Studio/Nexus-GenV2", origin_file_pattern="generation_decoder.bin")
"model_hash": "3e6c61b0f9471135fc9c6d6a98e98b6d",
"model_name": "nexus_gen_generation_adapter",
"model_class": "diffsynth.models.nexus_gen_projector.NexusGenAdapter",
"state_dict_converter": "diffsynth.utils.state_dict_converters.nexus_gen_projector.NexusGenAdapterStateDictConverter",
},
{
# Example: ModelConfig(model_id="DiffSynth-Studio/Nexus-GenV2", origin_file_pattern="generation_decoder.bin")
"model_hash": "3e6c61b0f9471135fc9c6d6a98e98b6d",
"model_name": "flux_dit",
"model_class": "diffsynth.models.flux_dit.FluxDiT",
"state_dict_converter": "diffsynth.utils.state_dict_converters.flux_dit.FluxDiTStateDictConverter",
},
{
# Example: ModelConfig(model_id="InstantX/FLUX.1-dev-IP-Adapter", origin_file_pattern="ip-adapter.bin")
"model_hash": "4daaa66cc656a8fe369908693dad0a35",
"model_name": "flux_ipadapter",
"model_class": "diffsynth.models.flux_ipadapter.FluxIpAdapter",
"state_dict_converter": "diffsynth.utils.state_dict_converters.flux_ipadapter.FluxIpAdapterStateDictConverter",
},
{
# Example: ModelConfig(model_id="google/siglip-so400m-patch14-384", origin_file_pattern="model.safetensors")
"model_hash": "04d8c1e20a1f1b25f7434f111992a33f",
"model_name": "siglip_vision_model",
"model_class": "diffsynth.models.flux_ipadapter.SiglipVisionModelSO400M",
"state_dict_converter": "diffsynth.utils.state_dict_converters.flux_ipadapter.SiglipStateDictConverter",
},
{
# Example: ModelConfig(model_id="stepfun-ai/Step1X-Edit", origin_file_pattern="step1x-edit-i1258.safetensors"),
"model_hash": "d30fb9e02b1dbf4e509142f05cf7dd50",
"model_name": "step1x_connector",
"model_class": "diffsynth.models.step1x_connector.Qwen2Connector",
"state_dict_converter": "diffsynth.utils.state_dict_converters.step1x_connector.Qwen2ConnectorStateDictConverter",
},
{
# Example: ModelConfig(model_id="stepfun-ai/Step1X-Edit", origin_file_pattern="step1x-edit-i1258.safetensors"),
"model_hash": "d30fb9e02b1dbf4e509142f05cf7dd50",
"model_name": "flux_dit",
"model_class": "diffsynth.models.flux_dit.FluxDiT",
"state_dict_converter": "diffsynth.utils.state_dict_converters.flux_dit.FluxDiTStateDictConverter",
"extra_kwargs": {"disable_guidance_embedder": True},
},
{
# Example: ModelConfig(model_id="MAILAND/majicflus_v1", origin_file_pattern="majicflus_v134.safetensors")
"model_hash": "3394f306c4cbf04334b712bf5aaed95f",
"model_name": "flux_dit",
"model_class": "diffsynth.models.flux_dit.FluxDiT",
"state_dict_converter": "diffsynth.utils.state_dict_converters.flux_dit.FluxDiTStateDictConverter",
},
]
flux2_series = [
{
# Example: ModelConfig(model_id="black-forest-labs/FLUX.2-dev", origin_file_pattern="text_encoder/*.safetensors")
"model_hash": "28fca3d8e5bf2a2d1271748a773f6757",
"model_name": "flux2_text_encoder",
"model_class": "diffsynth.models.flux2_text_encoder.Flux2TextEncoder",
"state_dict_converter": "diffsynth.utils.state_dict_converters.flux2_text_encoder.Flux2TextEncoderStateDictConverter",
},
{
# Example: ModelConfig(model_id="black-forest-labs/FLUX.2-dev", origin_file_pattern="transformer/*.safetensors")
"model_hash": "d38e1d5c5aec3b0a11e79327ac6e3b0f",
"model_name": "flux2_dit",
"model_class": "diffsynth.models.flux2_dit.Flux2DiT",
},
{
# Example: ModelConfig(model_id="black-forest-labs/FLUX.2-dev", origin_file_pattern="vae/diffusion_pytorch_model.safetensors")
"model_hash": "c54288e3ee12ca215898840682337b95",
"model_name": "flux2_vae",
"model_class": "diffsynth.models.flux2_vae.Flux2VAE",
},
{
# Example: ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="transformer/*.safetensors")
"model_hash": "3bde7b817fec8143028b6825a63180df",
"model_name": "flux2_dit",
"model_class": "diffsynth.models.flux2_dit.Flux2DiT",
"extra_kwargs": {"guidance_embeds": False, "joint_attention_dim": 7680, "num_attention_heads": 24, "num_layers": 5, "num_single_layers": 20}
},
{
# Example: ModelConfig(model_id="black-forest-labs/FLUX.2-klein-9B", origin_file_pattern="text_encoder/*.safetensors")
"model_hash": "9195f3ea256fcd0ae6d929c203470754",
"model_name": "z_image_text_encoder",
"model_class": "diffsynth.models.z_image_text_encoder.ZImageTextEncoder",
"extra_kwargs": {"model_size": "8B"},
"state_dict_converter": "diffsynth.utils.state_dict_converters.z_image_text_encoder.ZImageTextEncoderStateDictConverter",
},
{
# Example: ModelConfig(model_id="black-forest-labs/FLUX.2-klein-9B", origin_file_pattern="transformer/*.safetensors")
"model_hash": "39c6fc48f07bebecedbbaa971ff466c8",
"model_name": "flux2_dit",
"model_class": "diffsynth.models.flux2_dit.Flux2DiT",
"extra_kwargs": {"guidance_embeds": False, "joint_attention_dim": 12288, "num_attention_heads": 32, "num_layers": 8, "num_single_layers": 24}
},
]
z_image_series = [
{
# Example: ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="transformer/*.safetensors")
"model_hash": "fc3a8a1247fe185ce116ccbe0e426c28",
"model_name": "z_image_dit",
"model_class": "diffsynth.models.z_image_dit.ZImageDiT",
},
{
# Example: ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="text_encoder/*.safetensors")
"model_hash": "0f050f62a88876fea6eae0a18dac5a2e",
"model_name": "z_image_text_encoder",
"model_class": "diffsynth.models.z_image_text_encoder.ZImageTextEncoder",
},
{
# Example: ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="vae/vae/diffusion_pytorch_model.safetensors")
"model_hash": "1aafa3cc91716fb6b300cc1cd51b85a3",
"model_name": "flux_vae_encoder",
"model_class": "diffsynth.models.flux_vae.FluxVAEEncoder",
"state_dict_converter": "diffsynth.utils.state_dict_converters.flux_vae.FluxVAEEncoderStateDictConverterDiffusers",
"extra_kwargs": {"use_conv_attention": False},
},
{
# Example: ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="vae/vae/diffusion_pytorch_model.safetensors")
"model_hash": "1aafa3cc91716fb6b300cc1cd51b85a3",
"model_name": "flux_vae_decoder",
"model_class": "diffsynth.models.flux_vae.FluxVAEDecoder",
"state_dict_converter": "diffsynth.utils.state_dict_converters.flux_vae.FluxVAEDecoderStateDictConverterDiffusers",
"extra_kwargs": {"use_conv_attention": False},
},
{
# Example: ModelConfig(model_id="Tongyi-MAI/Z-Image-Omni-Base", origin_file_pattern="transformer/*.safetensors")
"model_hash": "aa3563718e5c3ecde3dfbb020ca61180",
"model_name": "z_image_dit",
"model_class": "diffsynth.models.z_image_dit.ZImageDiT",
"extra_kwargs": {"siglip_feat_dim": 1152},
},
{
# Example: ModelConfig(model_id="Tongyi-MAI/Z-Image-Omni-Base", origin_file_pattern="siglip/model.safetensors")
"model_hash": "89d48e420f45cff95115a9f3e698d44a",
"model_name": "siglip_vision_model_428m",
"model_class": "diffsynth.models.siglip2_image_encoder.Siglip2ImageEncoder428M",
},
{
# Example: ModelConfig(model_id="PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1", origin_file_pattern="Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.safetensors")
"model_hash": "1677708d40029ab380a95f6c731a57d7",
"model_name": "z_image_controlnet",
"model_class": "diffsynth.models.z_image_controlnet.ZImageControlNet",
},
{
# Example: ???
"model_hash": "9510cb8cd1dd34ee0e4f111c24905510",
"model_name": "z_image_image2lora_style",
"model_class": "diffsynth.models.z_image_image2lora.ZImageImage2LoRAModel",
"extra_kwargs": {"compress_dim": 128},
},
{
# Example: ModelConfig(model_id="Qwen/Qwen3-0.6B", origin_file_pattern="model.safetensors")
"model_hash": "1392adecee344136041e70553f875f31",
"model_name": "z_image_text_encoder",
"model_class": "diffsynth.models.z_image_text_encoder.ZImageTextEncoder",
"extra_kwargs": {"model_size": "0.6B"},
"state_dict_converter": "diffsynth.utils.state_dict_converters.z_image_text_encoder.ZImageTextEncoderStateDictConverter",
},
]
"""
Offical model repo: https://www.modelscope.cn/models/Lightricks/LTX-2
Repackaged model repo: https://www.modelscope.cn/models/DiffSynth-Studio/LTX-2-Repackage
For base models of LTX-2, offical checkpoint (with model config ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors"))
and repackaged checkpoints (with model config ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="*.safetensors")) are both supported.
We have repackeged the official checkpoints in DiffSynth-Studio/LTX-2-Repackage repo to support separate loading of different submodules,
and avoid redundant memory usage when users only want to use part of the model.
"""
ltx2_series = [
{
# Example: ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors")
"model_hash": "aca7b0bbf8415e9c98360750268915fc",
"model_name": "ltx2_dit",
"model_class": "diffsynth.models.ltx2_dit.LTXModel",
"state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_dit.LTXModelStateDictConverter",
},
{
# Example: ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="transformer.safetensors")
"model_hash": "c567aaa37d5ed7454c73aa6024458661",
"model_name": "ltx2_dit",
"model_class": "diffsynth.models.ltx2_dit.LTXModel",
"state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_dit.LTXModelStateDictConverter",
},
{
# Example: ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors")
"model_hash": "aca7b0bbf8415e9c98360750268915fc",
"model_name": "ltx2_video_vae_encoder",
"model_class": "diffsynth.models.ltx2_video_vae.LTX2VideoEncoder",
"state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_video_vae.LTX2VideoEncoderStateDictConverter",
},
{
# Example: ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="video_vae_encoder.safetensors")
"model_hash": "7f7e904a53260ec0351b05f32153754b",
"model_name": "ltx2_video_vae_encoder",
"model_class": "diffsynth.models.ltx2_video_vae.LTX2VideoEncoder",
"state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_video_vae.LTX2VideoEncoderStateDictConverter",
},
{
# Example: ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors")
"model_hash": "aca7b0bbf8415e9c98360750268915fc",
"model_name": "ltx2_video_vae_decoder",
"model_class": "diffsynth.models.ltx2_video_vae.LTX2VideoDecoder",
"state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_video_vae.LTX2VideoDecoderStateDictConverter",
},
{
# Example: ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="video_vae_decoder.safetensors")
"model_hash": "dc6029ca2825147872b45e35a2dc3a97",
"model_name": "ltx2_video_vae_decoder",
"model_class": "diffsynth.models.ltx2_video_vae.LTX2VideoDecoder",
"state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_video_vae.LTX2VideoDecoderStateDictConverter",
},
{
# Example: ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors")
"model_hash": "aca7b0bbf8415e9c98360750268915fc",
"model_name": "ltx2_audio_vae_decoder",
"model_class": "diffsynth.models.ltx2_audio_vae.LTX2AudioDecoder",
"state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_audio_vae.LTX2AudioDecoderStateDictConverter",
},
{
# Example: ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vae_decoder.safetensors")
"model_hash": "7d7823dde8f1ea0b50fb07ac329dd4cb",
"model_name": "ltx2_audio_vae_decoder",
"model_class": "diffsynth.models.ltx2_audio_vae.LTX2AudioDecoder",
"state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_audio_vae.LTX2AudioDecoderStateDictConverter",
},
{
# Example: ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors")
"model_hash": "aca7b0bbf8415e9c98360750268915fc",
"model_name": "ltx2_audio_vocoder",
"model_class": "diffsynth.models.ltx2_audio_vae.LTX2Vocoder",
"state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_audio_vae.LTX2VocoderStateDictConverter",
},
{
# Example: ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vocoder.safetensors")
"model_hash": "f471360f6b24bef702ab73133d9f8bb9",
"model_name": "ltx2_audio_vocoder",
"model_class": "diffsynth.models.ltx2_audio_vae.LTX2Vocoder",
"state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_audio_vae.LTX2VocoderStateDictConverter",
},
{
# Example: ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors")
"model_hash": "aca7b0bbf8415e9c98360750268915fc",
"model_name": "ltx2_audio_vae_encoder",
"model_class": "diffsynth.models.ltx2_audio_vae.LTX2AudioEncoder",
"state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_audio_vae.LTX2AudioEncoderStateDictConverter",
},
{
# Example: ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vae_encoder.safetensors")
"model_hash": "29338f3b95e7e312a3460a482e4f4554",
"model_name": "ltx2_audio_vae_encoder",
"model_class": "diffsynth.models.ltx2_audio_vae.LTX2AudioEncoder",
"state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_audio_vae.LTX2AudioEncoderStateDictConverter",
},
{
# Example: ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors")
"model_hash": "aca7b0bbf8415e9c98360750268915fc",
"model_name": "ltx2_text_encoder_post_modules",
"model_class": "diffsynth.models.ltx2_text_encoder.LTX2TextEncoderPostModules",
"state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_text_encoder.LTX2TextEncoderPostModulesStateDictConverter",
},
{
# Example: ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="text_encoder_post_modules.safetensors")
"model_hash": "981629689c8be92a712ab3c5eb4fc3f6",
"model_name": "ltx2_text_encoder_post_modules",
"model_class": "diffsynth.models.ltx2_text_encoder.LTX2TextEncoderPostModules",
"state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_text_encoder.LTX2TextEncoderPostModulesStateDictConverter",
},
{
# Example: ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors")
"model_hash": "33917f31c4a79196171154cca39f165e",
"model_name": "ltx2_text_encoder",
"model_class": "diffsynth.models.ltx2_text_encoder.LTX2TextEncoder",
"state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_text_encoder.LTX2TextEncoderStateDictConverter",
},
{
# Example: ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors")
"model_hash": "c79c458c6e99e0e14d47e676761732d2",
"model_name": "ltx2_latent_upsampler",
"model_class": "diffsynth.models.ltx2_upsampler.LTX2LatentUpsampler",
},
]
anima_series = [
{
# Example: ModelConfig(model_id="circlestone-labs/Anima", origin_file_pattern="split_files/vae/qwen_image_vae.safetensors")
"model_hash": "a9995952c2d8e63cf82e115005eb61b9",
"model_name": "z_image_text_encoder",
"model_class": "diffsynth.models.z_image_text_encoder.ZImageTextEncoder",
"extra_kwargs": {"model_size": "0.6B"},
},
{
# Example: ModelConfig(model_id="circlestone-labs/Anima", origin_file_pattern="split_files/diffusion_models/anima-preview.safetensors")
"model_hash": "417673936471e79e31ed4d186d7a3f4a",
"model_name": "anima_dit",
"model_class": "diffsynth.models.anima_dit.AnimaDiT",
"state_dict_converter": "diffsynth.utils.state_dict_converters.anima_dit.AnimaDiTStateDictConverter",
}
]
MODEL_CONFIGS = qwen_image_series + wan_series + flux_series + flux2_series + z_image_series + ltx2_series + anima_series

View File

@@ -1,266 +0,0 @@
flux_general_vram_config = {
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
"torch.nn.Embedding": "diffsynth.core.vram.layers.AutoWrappedModule",
"torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
"torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule",
"torch.nn.GroupNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
"diffsynth.models.general_modules.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
"diffsynth.models.flux_lora_encoder.LoRALayerBlock": "diffsynth.core.vram.layers.AutoWrappedModule",
"diffsynth.models.flux_lora_patcher.LoraMerger": "diffsynth.core.vram.layers.AutoWrappedModule",
}
VRAM_MANAGEMENT_MODULE_MAPS = {
"diffsynth.models.qwen_image_dit.QwenImageDiT": {
"diffsynth.models.qwen_image_dit.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
"torch.nn.Embedding": "diffsynth.core.vram.layers.AutoWrappedModule",
},
"diffsynth.models.qwen_image_text_encoder.QwenImageTextEncoder": {
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
"torch.nn.Embedding": "diffsynth.core.vram.layers.AutoWrappedModule",
"transformers.models.qwen2_5_vl.modeling_qwen2_5_vl.Qwen2_5_VLRotaryEmbedding": "diffsynth.core.vram.layers.AutoWrappedModule",
"transformers.models.qwen2_5_vl.modeling_qwen2_5_vl.Qwen2RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
"transformers.models.qwen2_5_vl.modeling_qwen2_5_vl.Qwen2_5_VisionPatchEmbed": "diffsynth.core.vram.layers.AutoWrappedModule",
"transformers.models.qwen2_5_vl.modeling_qwen2_5_vl.Qwen2_5_VisionRotaryEmbedding": "diffsynth.core.vram.layers.AutoWrappedModule",
},
"diffsynth.models.qwen_image_vae.QwenImageVAE": {
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
"torch.nn.Conv3d": "diffsynth.core.vram.layers.AutoWrappedModule",
"torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule",
"diffsynth.models.qwen_image_vae.QwenImageRMS_norm": "diffsynth.core.vram.layers.AutoWrappedModule",
},
"diffsynth.models.qwen_image_controlnet.BlockWiseControlBlock": {
"diffsynth.models.qwen_image_dit.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
},
"diffsynth.models.siglip2_image_encoder.Siglip2ImageEncoder": {
"transformers.models.siglip.modeling_siglip.SiglipVisionEmbeddings": "diffsynth.core.vram.layers.AutoWrappedModule",
"transformers.models.siglip.modeling_siglip.SiglipMultiheadAttentionPoolingHead": "diffsynth.core.vram.layers.AutoWrappedModule",
"torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule",
"torch.nn.Embedding": "diffsynth.core.vram.layers.AutoWrappedModule",
"torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
},
"diffsynth.models.dinov3_image_encoder.DINOv3ImageEncoder": {
"transformers.models.dinov3_vit.modeling_dinov3_vit.DINOv3ViTLayerScale": "diffsynth.core.vram.layers.AutoWrappedModule",
"transformers.models.dinov3_vit.modeling_dinov3_vit.DINOv3ViTRopePositionEmbedding": "diffsynth.core.vram.layers.AutoWrappedModule",
"transformers.models.dinov3_vit.modeling_dinov3_vit.DINOv3ViTEmbeddings": "diffsynth.core.vram.layers.AutoWrappedModule",
"torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule",
"torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
},
"diffsynth.models.qwen_image_image2lora.QwenImageImage2LoRAModel": {
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
},
"diffsynth.models.wan_video_animate_adapter.WanAnimateAdapter": {
"diffsynth.models.wan_video_animate_adapter.FaceEncoder": "diffsynth.core.vram.layers.AutoWrappedModule",
"diffsynth.models.wan_video_animate_adapter.EqualLinear": "diffsynth.core.vram.layers.AutoWrappedModule",
"diffsynth.models.wan_video_animate_adapter.ConvLayer": "diffsynth.core.vram.layers.AutoWrappedModule",
"diffsynth.models.wan_video_animate_adapter.FusedLeakyReLU": "diffsynth.core.vram.layers.AutoWrappedModule",
"diffsynth.models.wan_video_animate_adapter.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
"torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
"torch.nn.Conv1d": "diffsynth.core.vram.layers.AutoWrappedModule",
"torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule",
"torch.nn.Conv3d": "diffsynth.core.vram.layers.AutoWrappedModule",
},
"diffsynth.models.wan_video_dit_s2v.WanS2VModel": {
"diffsynth.models.wan_video_dit.Head": "diffsynth.core.vram.layers.AutoWrappedModule",
"diffsynth.models.wan_video_dit_s2v.WanS2VDiTBlock": "diffsynth.core.vram.layers.AutoWrappedModule",
"diffsynth.models.wan_video_dit_s2v.CausalAudioEncoder": "diffsynth.core.vram.layers.AutoWrappedModule",
"torch.nn.Embedding": "diffsynth.core.vram.layers.AutoWrappedModule",
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
"torch.nn.Conv3d": "diffsynth.core.vram.layers.AutoWrappedModule",
"torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
"diffsynth.models.wan_video_dit.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
"torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule",
},
"diffsynth.models.wan_video_dit.WanModel": {
"diffsynth.models.wan_video_dit.MLP": "diffsynth.core.vram.layers.AutoWrappedModule",
"diffsynth.models.wan_video_dit.DiTBlock": "diffsynth.core.vram.layers.AutoWrappedNonRecurseModule",
"diffsynth.models.wan_video_dit.Head": "diffsynth.core.vram.layers.AutoWrappedModule",
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
"torch.nn.Conv3d": "diffsynth.core.vram.layers.AutoWrappedModule",
"torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
"diffsynth.models.wan_video_dit.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
"torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule",
},
"diffsynth.models.wan_video_image_encoder.WanImageEncoder": {
"diffsynth.models.wan_video_image_encoder.VisionTransformer": "diffsynth.core.vram.layers.AutoWrappedModule",
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
"torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule",
"torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
},
"diffsynth.models.wan_video_mot.MotWanModel": {
"diffsynth.models.wan_video_mot.MotWanAttentionBlock": "diffsynth.core.vram.layers.AutoWrappedModule",
"torch.nn.Conv3d": "diffsynth.core.vram.layers.AutoWrappedModule",
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
"torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
},
"diffsynth.models.wan_video_motion_controller.WanMotionControllerModel": {
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
},
"diffsynth.models.wan_video_text_encoder.WanTextEncoder": {
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
"torch.nn.Embedding": "diffsynth.core.vram.layers.AutoWrappedModule",
"diffsynth.models.wan_video_text_encoder.T5RelativeEmbedding": "diffsynth.core.vram.layers.AutoWrappedModule",
"diffsynth.models.wan_video_text_encoder.T5LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
},
"diffsynth.models.wan_video_vace.VaceWanModel": {
"diffsynth.models.wan_video_dit.DiTBlock": "diffsynth.core.vram.layers.AutoWrappedModule",
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
"torch.nn.Conv3d": "diffsynth.core.vram.layers.AutoWrappedModule",
"torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
"diffsynth.models.wan_video_dit.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
},
"diffsynth.models.wan_video_vae.WanVideoVAE": {
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
"torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule",
"diffsynth.models.wan_video_vae.RMS_norm": "diffsynth.core.vram.layers.AutoWrappedModule",
"diffsynth.models.wan_video_vae.CausalConv3d": "diffsynth.core.vram.layers.AutoWrappedModule",
"diffsynth.models.wan_video_vae.Upsample": "diffsynth.core.vram.layers.AutoWrappedModule",
"torch.nn.SiLU": "diffsynth.core.vram.layers.AutoWrappedModule",
"torch.nn.Dropout": "diffsynth.core.vram.layers.AutoWrappedModule",
},
"diffsynth.models.wan_video_vae.WanVideoVAE38": {
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
"torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule",
"diffsynth.models.wan_video_vae.RMS_norm": "diffsynth.core.vram.layers.AutoWrappedModule",
"diffsynth.models.wan_video_vae.CausalConv3d": "diffsynth.core.vram.layers.AutoWrappedModule",
"diffsynth.models.wan_video_vae.Upsample": "diffsynth.core.vram.layers.AutoWrappedModule",
"torch.nn.SiLU": "diffsynth.core.vram.layers.AutoWrappedModule",
"torch.nn.Dropout": "diffsynth.core.vram.layers.AutoWrappedModule",
},
"diffsynth.models.wav2vec.WanS2VAudioEncoder": {
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
"torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
"torch.nn.Conv1d": "diffsynth.core.vram.layers.AutoWrappedModule",
},
"diffsynth.models.longcat_video_dit.LongCatVideoTransformer3DModel": {
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
"torch.nn.Conv3d": "diffsynth.core.vram.layers.AutoWrappedModule",
"diffsynth.models.longcat_video_dit.RMSNorm_FP32": "diffsynth.core.vram.layers.AutoWrappedModule",
"diffsynth.models.longcat_video_dit.LayerNorm_FP32": "diffsynth.core.vram.layers.AutoWrappedModule",
},
"diffsynth.models.flux_dit.FluxDiT": {
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
"diffsynth.models.flux_dit.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
},
"diffsynth.models.flux_text_encoder_clip.FluxTextEncoderClip": flux_general_vram_config,
"diffsynth.models.flux_vae.FluxVAEEncoder": flux_general_vram_config,
"diffsynth.models.flux_vae.FluxVAEDecoder": flux_general_vram_config,
"diffsynth.models.flux_controlnet.FluxControlNet": flux_general_vram_config,
"diffsynth.models.flux_infiniteyou.InfiniteYouImageProjector": flux_general_vram_config,
"diffsynth.models.flux_ipadapter.FluxIpAdapter": flux_general_vram_config,
"diffsynth.models.flux_lora_patcher.FluxLoraPatcher": flux_general_vram_config,
"diffsynth.models.step1x_connector.Qwen2Connector": flux_general_vram_config,
"diffsynth.models.flux_lora_encoder.FluxLoRAEncoder": flux_general_vram_config,
"diffsynth.models.flux_text_encoder_t5.FluxTextEncoderT5": {
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
"torch.nn.Embedding": "diffsynth.core.vram.layers.AutoWrappedModule",
"transformers.models.t5.modeling_t5.T5LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
"transformers.models.t5.modeling_t5.T5DenseActDense": "diffsynth.core.vram.layers.AutoWrappedModule",
"transformers.models.t5.modeling_t5.T5DenseGatedActDense": "diffsynth.core.vram.layers.AutoWrappedModule",
},
"diffsynth.models.flux_ipadapter.SiglipVisionModelSO400M": {
"transformers.models.siglip.modeling_siglip.SiglipVisionEmbeddings": "diffsynth.core.vram.layers.AutoWrappedModule",
"transformers.models.siglip.modeling_siglip.SiglipEncoder": "diffsynth.core.vram.layers.AutoWrappedModule",
"transformers.models.siglip.modeling_siglip.SiglipMultiheadAttentionPoolingHead": "diffsynth.core.vram.layers.AutoWrappedModule",
"torch.nn.MultiheadAttention": "diffsynth.core.vram.layers.AutoWrappedModule",
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
"torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
},
"diffsynth.models.flux2_dit.Flux2DiT": {
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
"torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
"torch.nn.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
},
"diffsynth.models.flux2_text_encoder.Flux2TextEncoder": {
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
"torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule",
"torch.nn.Embedding": "diffsynth.core.vram.layers.AutoWrappedModule",
"transformers.models.mistral.modeling_mistral.MistralRMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
},
"diffsynth.models.flux2_vae.Flux2VAE": {
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
"torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule",
"torch.nn.GroupNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
},
"diffsynth.models.z_image_text_encoder.ZImageTextEncoder": {
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
"transformers.models.qwen3.modeling_qwen3.Qwen3RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
"torch.nn.Embedding": "diffsynth.core.vram.layers.AutoWrappedModule",
},
"diffsynth.models.z_image_dit.ZImageDiT": {
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
"diffsynth.models.z_image_dit.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
},
"diffsynth.models.z_image_controlnet.ZImageControlNet": {
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
"diffsynth.models.z_image_dit.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
},
"diffsynth.models.z_image_image2lora.ZImageImage2LoRAModel": {
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
},
"diffsynth.models.siglip2_image_encoder.Siglip2ImageEncoder428M": {
"transformers.models.siglip2.modeling_siglip2.Siglip2VisionEmbeddings": "diffsynth.core.vram.layers.AutoWrappedModule",
"transformers.models.siglip2.modeling_siglip2.Siglip2MultiheadAttentionPoolingHead": "diffsynth.core.vram.layers.AutoWrappedModule",
"torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule",
"torch.nn.Embedding": "diffsynth.core.vram.layers.AutoWrappedModule",
"torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
},
"diffsynth.models.ltx2_dit.LTXModel": {
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
"torch.nn.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
},
"diffsynth.models.ltx2_upsampler.LTX2LatentUpsampler": {
"torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule",
"torch.nn.Conv3d": "diffsynth.core.vram.layers.AutoWrappedModule",
"torch.nn.GroupNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
},
"diffsynth.models.ltx2_video_vae.LTX2VideoEncoder": {
"torch.nn.Conv3d": "diffsynth.core.vram.layers.AutoWrappedModule",
},
"diffsynth.models.ltx2_video_vae.LTX2VideoDecoder": {
"torch.nn.Conv3d": "diffsynth.core.vram.layers.AutoWrappedModule",
},
"diffsynth.models.ltx2_audio_vae.LTX2AudioDecoder": {
"torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule",
},
"diffsynth.models.ltx2_audio_vae.LTX2Vocoder": {
"torch.nn.Conv1d": "diffsynth.core.vram.layers.AutoWrappedModule",
"torch.nn.ConvTranspose1d": "diffsynth.core.vram.layers.AutoWrappedModule",
},
"diffsynth.models.ltx2_text_encoder.LTX2TextEncoderPostModules": {
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
"torch.nn.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
"diffsynth.models.ltx2_text_encoder.Embeddings1DConnector": "diffsynth.core.vram.layers.AutoWrappedModule",
},
"diffsynth.models.ltx2_text_encoder.LTX2TextEncoder": {
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
"transformers.models.gemma3.modeling_gemma3.Gemma3MultiModalProjector": "diffsynth.core.vram.layers.AutoWrappedModule",
"transformers.models.gemma3.modeling_gemma3.Gemma3RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
"transformers.models.gemma3.modeling_gemma3.Gemma3TextScaledWordEmbedding": "diffsynth.core.vram.layers.AutoWrappedModule",
},
"diffsynth.models.anima_dit.AnimaDiT": {
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
"torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
"torch.nn.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
"torch.nn.Embedding": "diffsynth.core.vram.layers.AutoWrappedModule",
},
}
def QwenImageTextEncoder_Module_Map_Updater():
current = VRAM_MANAGEMENT_MODULE_MAPS["diffsynth.models.qwen_image_text_encoder.QwenImageTextEncoder"]
from packaging import version
import transformers
if version.parse(transformers.__version__) >= version.parse("5.2.0"):
# The Qwen2RMSNorm in transformers 5.2.0+ has been renamed to Qwen2_5_VLRMSNorm, so we need to update the module map accordingly
current.pop("transformers.models.qwen2_5_vl.modeling_qwen2_5_vl.Qwen2RMSNorm", None)
current["transformers.models.qwen2_5_vl.modeling_qwen2_5_vl.Qwen2_5_VLRMSNorm"] = "diffsynth.core.vram.layers.AutoWrappedModule"
return current
VERSION_CHECKER_MAPS = {
"diffsynth.models.qwen_image_text_encoder.QwenImageTextEncoder": QwenImageTextEncoder_Module_Map_Updater,
}

View File

@@ -0,0 +1,2 @@
from .controlnet_unit import ControlNetConfigUnit, ControlNetUnit, MultiControlNetManager, FluxMultiControlNetManager
from .processors import Annotator

View File

@@ -0,0 +1,91 @@
import torch
import numpy as np
from .processors import Processor_id
class ControlNetConfigUnit:
def __init__(self, processor_id: Processor_id, model_path, scale=1.0, skip_processor=False):
self.processor_id = processor_id
self.model_path = model_path
self.scale = scale
self.skip_processor = skip_processor
class ControlNetUnit:
def __init__(self, processor, model, scale=1.0):
self.processor = processor
self.model = model
self.scale = scale
class MultiControlNetManager:
def __init__(self, controlnet_units=[]):
self.processors = [unit.processor for unit in controlnet_units]
self.models = [unit.model for unit in controlnet_units]
self.scales = [unit.scale for unit in controlnet_units]
def cpu(self):
for model in self.models:
model.cpu()
def to(self, device):
for model in self.models:
model.to(device)
for processor in self.processors:
processor.to(device)
def process_image(self, image, processor_id=None):
if processor_id is None:
processed_image = [processor(image) for processor in self.processors]
else:
processed_image = [self.processors[processor_id](image)]
processed_image = torch.concat([
torch.Tensor(np.array(image_, dtype=np.float32) / 255).permute(2, 0, 1).unsqueeze(0)
for image_ in processed_image
], dim=0)
return processed_image
def __call__(
self,
sample, timestep, encoder_hidden_states, conditionings,
tiled=False, tile_size=64, tile_stride=32, **kwargs
):
res_stack = None
for processor, conditioning, model, scale in zip(self.processors, conditionings, self.models, self.scales):
res_stack_ = model(
sample, timestep, encoder_hidden_states, conditioning, **kwargs,
tiled=tiled, tile_size=tile_size, tile_stride=tile_stride,
processor_id=processor.processor_id
)
res_stack_ = [res * scale for res in res_stack_]
if res_stack is None:
res_stack = res_stack_
else:
res_stack = [i + j for i, j in zip(res_stack, res_stack_)]
return res_stack
class FluxMultiControlNetManager(MultiControlNetManager):
def __init__(self, controlnet_units=[]):
super().__init__(controlnet_units=controlnet_units)
def process_image(self, image, processor_id=None):
if processor_id is None:
processed_image = [processor(image) for processor in self.processors]
else:
processed_image = [self.processors[processor_id](image)]
return processed_image
def __call__(self, conditionings, **kwargs):
res_stack, single_res_stack = None, None
for processor, conditioning, model, scale in zip(self.processors, conditionings, self.models, self.scales):
res_stack_, single_res_stack_ = model(controlnet_conditioning=conditioning, processor_id=processor.processor_id, **kwargs)
res_stack_ = [res * scale for res in res_stack_]
single_res_stack_ = [res * scale for res in single_res_stack_]
if res_stack is None:
res_stack = res_stack_
single_res_stack = single_res_stack_
else:
res_stack = [i + j for i, j in zip(res_stack, res_stack_)]
single_res_stack = [i + j for i, j in zip(single_res_stack, single_res_stack_)]
return res_stack, single_res_stack

View File

@@ -1,34 +1,32 @@
from typing_extensions import Literal, TypeAlias
import warnings
with warnings.catch_warnings():
warnings.simplefilter("ignore")
from controlnet_aux.processor import (
CannyDetector, MidasDetector, HEDdetector, LineartDetector, LineartAnimeDetector, OpenposeDetector, NormalBaeDetector
)
from diffsynth.core.device.npu_compatible_device import get_device_type
Processor_id: TypeAlias = Literal[
"canny", "depth", "softedge", "lineart", "lineart_anime", "openpose", "normal", "tile", "none", "inpaint"
]
class Annotator:
def __init__(self, processor_id: Processor_id, model_path="models/Annotators", detect_resolution=None, device=get_device_type(), skip_processor=False):
def __init__(self, processor_id: Processor_id, model_path="models/Annotators", detect_resolution=None, device='cuda', skip_processor=False):
if not skip_processor:
if processor_id == "canny":
from controlnet_aux.processor import CannyDetector
self.processor = CannyDetector()
elif processor_id == "depth":
from controlnet_aux.processor import MidasDetector
self.processor = MidasDetector.from_pretrained(model_path).to(device)
elif processor_id == "softedge":
from controlnet_aux.processor import HEDdetector
self.processor = HEDdetector.from_pretrained(model_path).to(device)
elif processor_id == "lineart":
from controlnet_aux.processor import LineartDetector
self.processor = LineartDetector.from_pretrained(model_path).to(device)
elif processor_id == "lineart_anime":
from controlnet_aux.processor import LineartAnimeDetector
self.processor = LineartAnimeDetector.from_pretrained(model_path).to(device)
elif processor_id == "openpose":
from controlnet_aux.processor import OpenposeDetector
self.processor = OpenposeDetector.from_pretrained(model_path).to(device)
elif processor_id == "normal":
from controlnet_aux.processor import NormalBaeDetector
self.processor = NormalBaeDetector.from_pretrained(model_path).to(device)
elif processor_id == "tile" or processor_id == "none" or processor_id == "inpaint":
self.processor = None

View File

@@ -1,6 +0,0 @@
from .attention import *
from .data import *
from .gradient import *
from .loader import *
from .vram import *
from .device import *

View File

@@ -1 +0,0 @@
from .attention import attention_forward

View File

@@ -1,121 +0,0 @@
import torch, os
from einops import rearrange
try:
import flash_attn_interface
FLASH_ATTN_3_AVAILABLE = True
except ModuleNotFoundError:
FLASH_ATTN_3_AVAILABLE = False
try:
import flash_attn
FLASH_ATTN_2_AVAILABLE = True
except ModuleNotFoundError:
FLASH_ATTN_2_AVAILABLE = False
try:
from sageattention import sageattn
SAGE_ATTN_AVAILABLE = True
except ModuleNotFoundError:
SAGE_ATTN_AVAILABLE = False
try:
import xformers.ops as xops
XFORMERS_AVAILABLE = True
except ModuleNotFoundError:
XFORMERS_AVAILABLE = False
def initialize_attention_priority():
if os.environ.get('DIFFSYNTH_ATTENTION_IMPLEMENTATION') is not None:
return os.environ.get('DIFFSYNTH_ATTENTION_IMPLEMENTATION').lower()
elif FLASH_ATTN_3_AVAILABLE:
return "flash_attention_3"
elif FLASH_ATTN_2_AVAILABLE:
return "flash_attention_2"
elif SAGE_ATTN_AVAILABLE:
return "sage_attention"
elif XFORMERS_AVAILABLE:
return "xformers"
else:
return "torch"
ATTENTION_IMPLEMENTATION = initialize_attention_priority()
def rearrange_qkv(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, q_pattern="b n s d", k_pattern="b n s d", v_pattern="b n s d", required_in_pattern="b n s d", dims=None):
dims = {} if dims is None else dims
if q_pattern != required_in_pattern:
q = rearrange(q, f"{q_pattern} -> {required_in_pattern}", **dims)
if k_pattern != required_in_pattern:
k = rearrange(k, f"{k_pattern} -> {required_in_pattern}", **dims)
if v_pattern != required_in_pattern:
v = rearrange(v, f"{v_pattern} -> {required_in_pattern}", **dims)
return q, k, v
def rearrange_out(out: torch.Tensor, out_pattern="b n s d", required_out_pattern="b n s d", dims=None):
dims = {} if dims is None else dims
if out_pattern != required_out_pattern:
out = rearrange(out, f"{required_out_pattern} -> {out_pattern}", **dims)
return out
def torch_sdpa(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, q_pattern="b n s d", k_pattern="b n s d", v_pattern="b n s d", out_pattern="b n s d", dims=None, attn_mask=None, scale=None):
required_in_pattern, required_out_pattern= "b n s d", "b n s d"
q, k, v = rearrange_qkv(q, k, v, q_pattern, k_pattern, v_pattern, required_in_pattern, dims)
out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask, scale=scale)
out = rearrange_out(out, out_pattern, required_out_pattern, dims)
return out
def flash_attention_3(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, q_pattern="b n s d", k_pattern="b n s d", v_pattern="b n s d", out_pattern="b n s d", dims=None, scale=None):
required_in_pattern, required_out_pattern= "b s n d", "b s n d"
q, k, v = rearrange_qkv(q, k, v, q_pattern, k_pattern, v_pattern, required_in_pattern, dims)
out = flash_attn_interface.flash_attn_func(q, k, v, softmax_scale=scale)
if isinstance(out, tuple):
out = out[0]
out = rearrange_out(out, out_pattern, required_out_pattern, dims)
return out
def flash_attention_2(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, q_pattern="b n s d", k_pattern="b n s d", v_pattern="b n s d", out_pattern="b n s d", dims=None, scale=None):
required_in_pattern, required_out_pattern= "b s n d", "b s n d"
q, k, v = rearrange_qkv(q, k, v, q_pattern, k_pattern, v_pattern, required_in_pattern, dims)
out = flash_attn.flash_attn_func(q, k, v, softmax_scale=scale)
out = rearrange_out(out, out_pattern, required_out_pattern, dims)
return out
def sage_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, q_pattern="b n s d", k_pattern="b n s d", v_pattern="b n s d", out_pattern="b n s d", dims=None, scale=None):
required_in_pattern, required_out_pattern= "b n s d", "b n s d"
q, k, v = rearrange_qkv(q, k, v, q_pattern, k_pattern, v_pattern, required_in_pattern, dims)
out = sageattn(q, k, v, sm_scale=scale)
out = rearrange_out(out, out_pattern, required_out_pattern, dims)
return out
def xformers_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, q_pattern="b n s d", k_pattern="b n s d", v_pattern="b n s d", out_pattern="b n s d", dims=None, scale=None):
required_in_pattern, required_out_pattern= "b s n d", "b s n d"
q, k, v = rearrange_qkv(q, k, v, q_pattern, k_pattern, v_pattern, required_in_pattern, dims)
out = xops.memory_efficient_attention(q, k, v, scale=scale)
out = rearrange_out(out, out_pattern, required_out_pattern, dims)
return out
def attention_forward(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, q_pattern="b n s d", k_pattern="b n s d", v_pattern="b n s d", out_pattern="b n s d", dims=None, attn_mask=None, scale=None, compatibility_mode=False):
if compatibility_mode or (attn_mask is not None):
return torch_sdpa(q, k, v, q_pattern, k_pattern, v_pattern, out_pattern, dims, attn_mask=attn_mask, scale=scale)
else:
if ATTENTION_IMPLEMENTATION == "flash_attention_3":
return flash_attention_3(q, k, v, q_pattern, k_pattern, v_pattern, out_pattern, dims, scale=scale)
elif ATTENTION_IMPLEMENTATION == "flash_attention_2":
return flash_attention_2(q, k, v, q_pattern, k_pattern, v_pattern, out_pattern, dims, scale=scale)
elif ATTENTION_IMPLEMENTATION == "sage_attention":
return sage_attention(q, k, v, q_pattern, k_pattern, v_pattern, out_pattern, dims, scale=scale)
elif ATTENTION_IMPLEMENTATION == "xformers":
return xformers_attention(q, k, v, q_pattern, k_pattern, v_pattern, out_pattern, dims, scale=scale)
else:
return torch_sdpa(q, k, v, q_pattern, k_pattern, v_pattern, out_pattern, dims, scale=scale)

View File

@@ -1 +0,0 @@
from .unified_dataset import UnifiedDataset

View File

@@ -1,237 +0,0 @@
import torch, torchvision, imageio, os
import imageio.v3 as iio
from PIL import Image
class DataProcessingPipeline:
def __init__(self, operators=None):
self.operators: list[DataProcessingOperator] = [] if operators is None else operators
def __call__(self, data):
for operator in self.operators:
data = operator(data)
return data
def __rshift__(self, pipe):
if isinstance(pipe, DataProcessingOperator):
pipe = DataProcessingPipeline([pipe])
return DataProcessingPipeline(self.operators + pipe.operators)
class DataProcessingOperator:
def __call__(self, data):
raise NotImplementedError("DataProcessingOperator cannot be called directly.")
def __rshift__(self, pipe):
if isinstance(pipe, DataProcessingOperator):
pipe = DataProcessingPipeline([pipe])
return DataProcessingPipeline([self]).__rshift__(pipe)
class DataProcessingOperatorRaw(DataProcessingOperator):
def __call__(self, data):
return data
class ToInt(DataProcessingOperator):
def __call__(self, data):
return int(data)
class ToFloat(DataProcessingOperator):
def __call__(self, data):
return float(data)
class ToStr(DataProcessingOperator):
def __init__(self, none_value=""):
self.none_value = none_value
def __call__(self, data):
if data is None: data = self.none_value
return str(data)
class LoadImage(DataProcessingOperator):
def __init__(self, convert_RGB=True, convert_RGBA=False):
self.convert_RGB = convert_RGB
self.convert_RGBA = convert_RGBA
def __call__(self, data: str):
image = Image.open(data)
if self.convert_RGB: image = image.convert("RGB")
if self.convert_RGBA: image = image.convert("RGBA")
return image
class ImageCropAndResize(DataProcessingOperator):
def __init__(self, height=None, width=None, max_pixels=None, height_division_factor=1, width_division_factor=1):
self.height = height
self.width = width
self.max_pixels = max_pixels
self.height_division_factor = height_division_factor
self.width_division_factor = width_division_factor
def crop_and_resize(self, image, target_height, target_width):
width, height = image.size
scale = max(target_width / width, target_height / height)
image = torchvision.transforms.functional.resize(
image,
(round(height*scale), round(width*scale)),
interpolation=torchvision.transforms.InterpolationMode.BILINEAR
)
image = torchvision.transforms.functional.center_crop(image, (target_height, target_width))
return image
def get_height_width(self, image):
if self.height is None or self.width is None:
width, height = image.size
if width * height > self.max_pixels:
scale = (width * height / self.max_pixels) ** 0.5
height, width = int(height / scale), int(width / scale)
height = height // self.height_division_factor * self.height_division_factor
width = width // self.width_division_factor * self.width_division_factor
else:
height, width = self.height, self.width
return height, width
def __call__(self, data: Image.Image):
image = self.crop_and_resize(data, *self.get_height_width(data))
return image
class ToList(DataProcessingOperator):
def __call__(self, data):
return [data]
class LoadVideo(DataProcessingOperator):
def __init__(self, num_frames=81, time_division_factor=4, time_division_remainder=1, frame_processor=lambda x: x):
self.num_frames = num_frames
self.time_division_factor = time_division_factor
self.time_division_remainder = time_division_remainder
# frame_processor is build in the video loader for high efficiency.
self.frame_processor = frame_processor
def get_num_frames(self, reader):
num_frames = self.num_frames
if int(reader.count_frames()) < num_frames:
num_frames = int(reader.count_frames())
while num_frames > 1 and num_frames % self.time_division_factor != self.time_division_remainder:
num_frames -= 1
return num_frames
def __call__(self, data: str):
reader = imageio.get_reader(data)
num_frames = self.get_num_frames(reader)
frames = []
for frame_id in range(num_frames):
frame = reader.get_data(frame_id)
frame = Image.fromarray(frame)
frame = self.frame_processor(frame)
frames.append(frame)
reader.close()
return frames
class SequencialProcess(DataProcessingOperator):
def __init__(self, operator=lambda x: x):
self.operator = operator
def __call__(self, data):
return [self.operator(i) for i in data]
class LoadGIF(DataProcessingOperator):
def __init__(self, num_frames=81, time_division_factor=4, time_division_remainder=1, frame_processor=lambda x: x):
self.num_frames = num_frames
self.time_division_factor = time_division_factor
self.time_division_remainder = time_division_remainder
# frame_processor is build in the video loader for high efficiency.
self.frame_processor = frame_processor
def get_num_frames(self, path):
num_frames = self.num_frames
images = iio.imread(path, mode="RGB")
if len(images) < num_frames:
num_frames = len(images)
while num_frames > 1 and num_frames % self.time_division_factor != self.time_division_remainder:
num_frames -= 1
return num_frames
def __call__(self, data: str):
num_frames = self.get_num_frames(data)
frames = []
images = iio.imread(data, mode="RGB")
for img in images:
frame = Image.fromarray(img)
frame = self.frame_processor(frame)
frames.append(frame)
if len(frames) >= num_frames:
break
return frames
class RouteByExtensionName(DataProcessingOperator):
def __init__(self, operator_map):
self.operator_map = operator_map
def __call__(self, data: str):
file_ext_name = data.split(".")[-1].lower()
for ext_names, operator in self.operator_map:
if ext_names is None or file_ext_name in ext_names:
return operator(data)
raise ValueError(f"Unsupported file: {data}")
class RouteByType(DataProcessingOperator):
def __init__(self, operator_map):
self.operator_map = operator_map
def __call__(self, data):
for dtype, operator in self.operator_map:
if dtype is None or isinstance(data, dtype):
return operator(data)
raise ValueError(f"Unsupported data: {data}")
class LoadTorchPickle(DataProcessingOperator):
def __init__(self, map_location="cpu"):
self.map_location = map_location
def __call__(self, data):
return torch.load(data, map_location=self.map_location, weights_only=False)
class ToAbsolutePath(DataProcessingOperator):
def __init__(self, base_path=""):
self.base_path = base_path
def __call__(self, data):
return os.path.join(self.base_path, data)
class LoadAudio(DataProcessingOperator):
def __init__(self, sr=16000):
self.sr = sr
def __call__(self, data: str):
import librosa
input_audio, sample_rate = librosa.load(data, sr=self.sr)
return input_audio
class LoadAudioWithTorchaudio(DataProcessingOperator):
def __init__(self, duration=5):
self.duration = duration
def __call__(self, data: str):
import torchaudio
waveform, sample_rate = torchaudio.load(data)
target_samples = int(self.duration * sample_rate)
current_samples = waveform.shape[-1]
if current_samples > target_samples:
waveform = waveform[..., :target_samples]
elif current_samples < target_samples:
padding = target_samples - current_samples
waveform = torch.nn.functional.pad(waveform, (0, padding))
return waveform, sample_rate

View File

@@ -1,116 +0,0 @@
from .operators import *
import torch, json, pandas
class UnifiedDataset(torch.utils.data.Dataset):
def __init__(
self,
base_path=None, metadata_path=None,
repeat=1,
data_file_keys=tuple(),
main_data_operator=lambda x: x,
special_operator_map=None,
max_data_items=None,
):
self.base_path = base_path
self.metadata_path = metadata_path
self.repeat = repeat
self.data_file_keys = data_file_keys
self.main_data_operator = main_data_operator
self.cached_data_operator = LoadTorchPickle()
self.special_operator_map = {} if special_operator_map is None else special_operator_map
self.max_data_items = max_data_items
self.data = []
self.cached_data = []
self.load_from_cache = metadata_path is None
self.load_metadata(metadata_path)
@staticmethod
def default_image_operator(
base_path="",
max_pixels=1920*1080, height=None, width=None,
height_division_factor=16, width_division_factor=16,
):
return RouteByType(operator_map=[
(str, ToAbsolutePath(base_path) >> LoadImage() >> ImageCropAndResize(height, width, max_pixels, height_division_factor, width_division_factor)),
(list, SequencialProcess(ToAbsolutePath(base_path) >> LoadImage() >> ImageCropAndResize(height, width, max_pixels, height_division_factor, width_division_factor))),
])
@staticmethod
def default_video_operator(
base_path="",
max_pixels=1920*1080, height=None, width=None,
height_division_factor=16, width_division_factor=16,
num_frames=81, time_division_factor=4, time_division_remainder=1,
):
return RouteByType(operator_map=[
(str, ToAbsolutePath(base_path) >> RouteByExtensionName(operator_map=[
(("jpg", "jpeg", "png", "webp"), LoadImage() >> ImageCropAndResize(height, width, max_pixels, height_division_factor, width_division_factor) >> ToList()),
(("gif",), LoadGIF(
num_frames, time_division_factor, time_division_remainder,
frame_processor=ImageCropAndResize(height, width, max_pixels, height_division_factor, width_division_factor),
)),
(("mp4", "avi", "mov", "wmv", "mkv", "flv", "webm"), LoadVideo(
num_frames, time_division_factor, time_division_remainder,
frame_processor=ImageCropAndResize(height, width, max_pixels, height_division_factor, width_division_factor),
)),
])),
])
def search_for_cached_data_files(self, path):
for file_name in os.listdir(path):
subpath = os.path.join(path, file_name)
if os.path.isdir(subpath):
self.search_for_cached_data_files(subpath)
elif subpath.endswith(".pth"):
self.cached_data.append(subpath)
def load_metadata(self, metadata_path):
if metadata_path is None:
print("No metadata_path. Searching for cached data files.")
self.search_for_cached_data_files(self.base_path)
print(f"{len(self.cached_data)} cached data files found.")
elif metadata_path.endswith(".json"):
with open(metadata_path, "r") as f:
metadata = json.load(f)
self.data = metadata
elif metadata_path.endswith(".jsonl"):
metadata = []
with open(metadata_path, 'r') as f:
for line in f:
metadata.append(json.loads(line.strip()))
self.data = metadata
else:
metadata = pandas.read_csv(metadata_path)
self.data = [metadata.iloc[i].to_dict() for i in range(len(metadata))]
def __getitem__(self, data_id):
if self.load_from_cache:
data = self.cached_data[data_id % len(self.cached_data)]
data = self.cached_data_operator(data)
else:
data = self.data[data_id % len(self.data)].copy()
for key in self.data_file_keys:
if key in data:
if key in self.special_operator_map:
data[key] = self.special_operator_map[key](data[key])
elif key in self.data_file_keys:
data[key] = self.main_data_operator(data[key])
return data
def __len__(self):
if self.max_data_items is not None:
return self.max_data_items
elif self.load_from_cache:
return len(self.cached_data) * self.repeat
else:
return len(self.data) * self.repeat
def check_data_equal(self, data1, data2):
# Debug only
if len(data1) != len(data2):
return False
for k in data1:
if data1[k] != data2[k]:
return False
return True

View File

@@ -1,2 +0,0 @@
from .npu_compatible_device import parse_device_type, parse_nccl_backend, get_available_device_type, get_device_name
from .npu_compatible_device import IS_NPU_AVAILABLE, IS_CUDA_AVAILABLE

View File

@@ -1,107 +0,0 @@
import importlib
import torch
from typing import Any
def is_torch_npu_available():
return importlib.util.find_spec("torch_npu") is not None
IS_CUDA_AVAILABLE = torch.cuda.is_available()
IS_NPU_AVAILABLE = is_torch_npu_available() and torch.npu.is_available()
if IS_NPU_AVAILABLE:
import torch_npu
torch.npu.config.allow_internal_format = False
def get_device_type() -> str:
"""Get device type based on current machine, currently only support CPU, CUDA, NPU."""
if IS_CUDA_AVAILABLE:
device = "cuda"
elif IS_NPU_AVAILABLE:
device = "npu"
else:
device = "cpu"
return device
def get_torch_device() -> Any:
"""Get torch attribute based on device type, e.g. torch.cuda or torch.npu"""
device_name = get_device_type()
try:
return getattr(torch, device_name)
except AttributeError:
print(f"Device namespace '{device_name}' not found in torch, try to load 'torch.cuda'.")
return torch.cuda
def get_device_id() -> int:
"""Get current device id based on device type."""
return get_torch_device().current_device()
def get_device_name() -> str:
"""Get current device name based on device type."""
return f"{get_device_type()}:{get_device_id()}"
def synchronize() -> None:
"""Execute torch synchronize operation."""
get_torch_device().synchronize()
def empty_cache() -> None:
"""Execute torch empty cache operation."""
get_torch_device().empty_cache()
def get_nccl_backend() -> str:
"""Return distributed communication backend type based on device type."""
if IS_CUDA_AVAILABLE:
return "nccl"
elif IS_NPU_AVAILABLE:
return "hccl"
else:
raise RuntimeError(f"No available distributed communication backend found on device type {get_device_type()}.")
def enable_high_precision_for_bf16():
"""
Set high accumulation dtype for matmul and reduction.
"""
if IS_CUDA_AVAILABLE:
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = False
if IS_NPU_AVAILABLE:
torch.npu.matmul.allow_tf32 = False
torch.npu.matmul.allow_bf16_reduced_precision_reduction = False
def parse_device_type(device):
if isinstance(device, str):
if device.startswith("cuda"):
return "cuda"
elif device.startswith("npu"):
return "npu"
else:
return "cpu"
elif isinstance(device, torch.device):
return device.type
def parse_nccl_backend(device_type):
if device_type == "cuda":
return "nccl"
elif device_type == "npu":
return "hccl"
else:
raise RuntimeError(f"No available distributed communication backend found on device type {device_type}.")
def get_available_device_type():
return get_device_type()

View File

@@ -1 +0,0 @@
from .gradient_checkpoint import gradient_checkpoint_forward

View File

@@ -1,34 +0,0 @@
import torch
def create_custom_forward(module):
def custom_forward(*inputs, **kwargs):
return module(*inputs, **kwargs)
return custom_forward
def gradient_checkpoint_forward(
model,
use_gradient_checkpointing,
use_gradient_checkpointing_offload,
*args,
**kwargs,
):
if use_gradient_checkpointing_offload:
with torch.autograd.graph.save_on_cpu():
model_output = torch.utils.checkpoint.checkpoint(
create_custom_forward(model),
*args,
**kwargs,
use_reentrant=False,
)
elif use_gradient_checkpointing:
model_output = torch.utils.checkpoint.checkpoint(
create_custom_forward(model),
*args,
**kwargs,
use_reentrant=False,
)
else:
model_output = model(*args, **kwargs)
return model_output

View File

@@ -1,3 +0,0 @@
from .file import load_state_dict, hash_state_dict_keys, hash_model_file
from .model import load_model, load_model_with_disk_offload
from .config import ModelConfig

View File

@@ -1,119 +0,0 @@
import torch, glob, os
from typing import Optional, Union, Dict
from dataclasses import dataclass
from modelscope import snapshot_download
from huggingface_hub import snapshot_download as hf_snapshot_download
from typing import Optional
@dataclass
class ModelConfig:
path: Union[str, list[str]] = None
model_id: str = None
origin_file_pattern: Union[str, list[str]] = None
download_source: str = None
local_model_path: str = None
skip_download: bool = None
offload_device: Optional[Union[str, torch.device]] = None
offload_dtype: Optional[torch.dtype] = None
onload_device: Optional[Union[str, torch.device]] = None
onload_dtype: Optional[torch.dtype] = None
preparing_device: Optional[Union[str, torch.device]] = None
preparing_dtype: Optional[torch.dtype] = None
computation_device: Optional[Union[str, torch.device]] = None
computation_dtype: Optional[torch.dtype] = None
clear_parameters: bool = False
state_dict: Dict[str, torch.Tensor] = None
def check_input(self):
if self.path is None and self.model_id is None:
raise ValueError(f"""No valid model files. Please use `ModelConfig(path="xxx")` or `ModelConfig(model_id="xxx/yyy", origin_file_pattern="zzz")`. `skip_download=True` only supports the first one.""")
def parse_original_file_pattern(self):
if self.origin_file_pattern in [None, "", "./"]:
return "*"
elif self.origin_file_pattern.endswith("/"):
return self.origin_file_pattern + "*"
else:
return self.origin_file_pattern
def parse_download_source(self):
if self.download_source is None:
if os.environ.get('DIFFSYNTH_DOWNLOAD_SOURCE') is not None:
return os.environ.get('DIFFSYNTH_DOWNLOAD_SOURCE')
else:
return "modelscope"
else:
return self.download_source
def parse_skip_download(self):
if self.skip_download is None:
if os.environ.get('DIFFSYNTH_SKIP_DOWNLOAD') is not None:
if os.environ.get('DIFFSYNTH_SKIP_DOWNLOAD').lower() == "true":
return True
elif os.environ.get('DIFFSYNTH_SKIP_DOWNLOAD').lower() == "false":
return False
else:
return False
else:
return self.skip_download
def download(self):
origin_file_pattern = self.parse_original_file_pattern()
downloaded_files = glob.glob(origin_file_pattern, root_dir=os.path.join(self.local_model_path, self.model_id))
download_source = self.parse_download_source()
if download_source.lower() == "modelscope":
snapshot_download(
self.model_id,
local_dir=os.path.join(self.local_model_path, self.model_id),
allow_file_pattern=origin_file_pattern,
ignore_file_pattern=downloaded_files,
local_files_only=False
)
elif download_source.lower() == "huggingface":
hf_snapshot_download(
self.model_id,
local_dir=os.path.join(self.local_model_path, self.model_id),
allow_patterns=origin_file_pattern,
ignore_patterns=downloaded_files,
local_files_only=False
)
else:
raise ValueError("`download_source` should be `modelscope` or `huggingface`.")
def require_downloading(self):
if self.path is not None:
return False
skip_download = self.parse_skip_download()
return not skip_download
def reset_local_model_path(self):
if os.environ.get('DIFFSYNTH_MODEL_BASE_PATH') is not None:
self.local_model_path = os.environ.get('DIFFSYNTH_MODEL_BASE_PATH')
elif self.local_model_path is None:
self.local_model_path = "./models"
def download_if_necessary(self):
self.check_input()
self.reset_local_model_path()
if self.require_downloading():
self.download()
if self.path is None:
if self.origin_file_pattern in [None, "", "./"]:
self.path = os.path.join(self.local_model_path, self.model_id)
else:
self.path = glob.glob(os.path.join(self.local_model_path, self.model_id, self.origin_file_pattern))
if isinstance(self.path, list) and len(self.path) == 1:
self.path = self.path[0]
def vram_config(self):
return {
"offload_device": self.offload_device,
"offload_dtype": self.offload_dtype,
"onload_device": self.onload_device,
"onload_dtype": self.onload_dtype,
"preparing_device": self.preparing_device,
"preparing_dtype": self.preparing_dtype,
"computation_device": self.computation_device,
"computation_dtype": self.computation_dtype,
}

View File

@@ -1,130 +0,0 @@
from safetensors import safe_open
import torch, hashlib
def load_state_dict(file_path, torch_dtype=None, device="cpu", pin_memory=False, verbose=0):
if isinstance(file_path, list):
state_dict = {}
for file_path_ in file_path:
state_dict.update(load_state_dict(file_path_, torch_dtype, device, pin_memory=pin_memory, verbose=verbose))
else:
if verbose >= 1:
print(f"Loading file [started]: {file_path}")
if file_path.endswith(".safetensors"):
state_dict = load_state_dict_from_safetensors(file_path, torch_dtype=torch_dtype, device=device)
else:
state_dict = load_state_dict_from_bin(file_path, torch_dtype=torch_dtype, device=device)
# If load state dict in CPU memory, `pin_memory=True` will make `model.to("cuda")` faster.
if pin_memory:
for i in state_dict:
state_dict[i] = state_dict[i].pin_memory()
if verbose >= 1:
print(f"Loading file [done]: {file_path}")
return state_dict
def load_state_dict_from_safetensors(file_path, torch_dtype=None, device="cpu"):
state_dict = {}
with safe_open(file_path, framework="pt", device=str(device)) as f:
for k in f.keys():
state_dict[k] = f.get_tensor(k)
if torch_dtype is not None:
state_dict[k] = state_dict[k].to(torch_dtype)
return state_dict
def load_state_dict_from_bin(file_path, torch_dtype=None, device="cpu"):
state_dict = torch.load(file_path, map_location=device, weights_only=True)
if len(state_dict) == 1:
if "state_dict" in state_dict:
state_dict = state_dict["state_dict"]
elif "module" in state_dict:
state_dict = state_dict["module"]
elif "model_state" in state_dict:
state_dict = state_dict["model_state"]
if torch_dtype is not None:
for i in state_dict:
if isinstance(state_dict[i], torch.Tensor):
state_dict[i] = state_dict[i].to(torch_dtype)
return state_dict
def convert_state_dict_keys_to_single_str(state_dict, with_shape=True):
keys = []
for key, value in state_dict.items():
if isinstance(key, str):
if isinstance(value, torch.Tensor):
if with_shape:
shape = "_".join(map(str, list(value.shape)))
keys.append(key + ":" + shape)
keys.append(key)
elif isinstance(value, dict):
keys.append(key + "|" + convert_state_dict_keys_to_single_str(value, with_shape=with_shape))
keys.sort()
keys_str = ",".join(keys)
return keys_str
def hash_state_dict_keys(state_dict, with_shape=True):
keys_str = convert_state_dict_keys_to_single_str(state_dict, with_shape=with_shape)
keys_str = keys_str.encode(encoding="UTF-8")
return hashlib.md5(keys_str).hexdigest()
def load_keys_dict(file_path):
if isinstance(file_path, list):
state_dict = {}
for file_path_ in file_path:
state_dict.update(load_keys_dict(file_path_))
return state_dict
if file_path.endswith(".safetensors"):
return load_keys_dict_from_safetensors(file_path)
else:
return load_keys_dict_from_bin(file_path)
def load_keys_dict_from_safetensors(file_path):
keys_dict = {}
with safe_open(file_path, framework="pt", device="cpu") as f:
for k in f.keys():
keys_dict[k] = f.get_slice(k).get_shape()
return keys_dict
def convert_state_dict_to_keys_dict(state_dict):
keys_dict = {}
for k, v in state_dict.items():
if isinstance(v, torch.Tensor):
keys_dict[k] = list(v.shape)
else:
keys_dict[k] = convert_state_dict_to_keys_dict(v)
return keys_dict
def load_keys_dict_from_bin(file_path):
state_dict = load_state_dict_from_bin(file_path)
keys_dict = convert_state_dict_to_keys_dict(state_dict)
return keys_dict
def convert_keys_dict_to_single_str(state_dict, with_shape=True):
keys = []
for key, value in state_dict.items():
if isinstance(key, str):
if isinstance(value, dict):
keys.append(key + "|" + convert_keys_dict_to_single_str(value, with_shape=with_shape))
else:
if with_shape:
shape = "_".join(map(str, list(value)))
keys.append(key + ":" + shape)
keys.append(key)
keys.sort()
keys_str = ",".join(keys)
return keys_str
def hash_model_file(path, with_shape=True):
keys_dict = load_keys_dict(path)
keys_str = convert_keys_dict_to_single_str(keys_dict, with_shape=with_shape)
keys_str = keys_str.encode(encoding="UTF-8")
return hashlib.md5(keys_str).hexdigest()

View File

@@ -1,105 +0,0 @@
from ..vram.initialization import skip_model_initialization
from ..vram.disk_map import DiskMap
from ..vram.layers import enable_vram_management
from .file import load_state_dict
import torch
from contextlib import contextmanager
from transformers.integrations import is_deepspeed_zero3_enabled
from transformers.utils import ContextManagers
def load_model(model_class, path, config=None, torch_dtype=torch.bfloat16, device="cpu", state_dict_converter=None, use_disk_map=False, module_map=None, vram_config=None, vram_limit=None, state_dict=None):
config = {} if config is None else config
with ContextManagers(get_init_context(torch_dtype=torch_dtype, device=device)):
model = model_class(**config)
# What is `module_map`?
# This is a module mapping table for VRAM management.
if module_map is not None:
devices = [vram_config["offload_device"], vram_config["onload_device"], vram_config["preparing_device"], vram_config["computation_device"]]
device = [d for d in devices if d != "disk"][0]
dtypes = [vram_config["offload_dtype"], vram_config["onload_dtype"], vram_config["preparing_dtype"], vram_config["computation_dtype"]]
dtype = [d for d in dtypes if d != "disk"][0]
if vram_config["offload_device"] != "disk":
if state_dict is None: state_dict = DiskMap(path, device, torch_dtype=dtype)
if state_dict_converter is not None:
state_dict = state_dict_converter(state_dict)
else:
state_dict = {i: state_dict[i] for i in state_dict}
model.load_state_dict(state_dict, assign=True)
model = enable_vram_management(model, module_map, vram_config=vram_config, disk_map=None, vram_limit=vram_limit)
else:
disk_map = DiskMap(path, device, state_dict_converter=state_dict_converter)
model = enable_vram_management(model, module_map, vram_config=vram_config, disk_map=disk_map, vram_limit=vram_limit)
else:
# Why do we use `DiskMap`?
# Sometimes a model file contains multiple models,
# and DiskMap can load only the parameters of a single model,
# avoiding the need to load all parameters in the file.
if state_dict is not None:
pass
elif use_disk_map:
state_dict = DiskMap(path, device, torch_dtype=torch_dtype)
else:
state_dict = load_state_dict(path, torch_dtype, device)
# Why do we use `state_dict_converter`?
# Some models are saved in complex formats,
# and we need to convert the state dict into the appropriate format.
if state_dict_converter is not None:
state_dict = state_dict_converter(state_dict)
else:
state_dict = {i: state_dict[i] for i in state_dict}
# Why does DeepSpeed ZeRO Stage 3 need to be handled separately?
# Because at this stage, model parameters are partitioned across multiple GPUs.
# Loading them directly could lead to excessive GPU memory consumption.
if is_deepspeed_zero3_enabled():
from transformers.integrations.deepspeed import _load_state_dict_into_zero3_model
_load_state_dict_into_zero3_model(model, state_dict)
else:
model.load_state_dict(state_dict, assign=True)
# Why do we call `to()`?
# Because some models override the behavior of `to()`,
# especially those from libraries like Transformers.
model = model.to(dtype=torch_dtype, device=device)
if hasattr(model, "eval"):
model = model.eval()
return model
def load_model_with_disk_offload(model_class, path, config=None, torch_dtype=torch.bfloat16, device="cpu", state_dict_converter=None, module_map=None):
if isinstance(path, str):
path = [path]
config = {} if config is None else config
with skip_model_initialization():
model = model_class(**config)
if hasattr(model, "eval"):
model = model.eval()
disk_map = DiskMap(path, device, state_dict_converter=state_dict_converter)
vram_config = {
"offload_dtype": "disk",
"offload_device": "disk",
"onload_dtype": "disk",
"onload_device": "disk",
"preparing_dtype": torch.float8_e4m3fn,
"preparing_device": device,
"computation_dtype": torch_dtype,
"computation_device": device,
}
enable_vram_management(model, module_map, vram_config=vram_config, disk_map=disk_map, vram_limit=80)
return model
def get_init_context(torch_dtype, device):
if is_deepspeed_zero3_enabled():
from transformers.modeling_utils import set_zero3_state
import deepspeed
# Why do we use "deepspeed.zero.Init"?
# Weight segmentation of the model can be performed on the CPU side
# and loading the segmented weights onto the computing card
init_contexts = [deepspeed.zero.Init(remote_device=device, dtype=torch_dtype), set_zero3_state()]
else:
# Why do we use `skip_model_initialization`?
# It skips the random initialization of model parameters,
# thereby speeding up model loading and avoiding excessive memory usage.
init_contexts = [skip_model_initialization()]
return init_contexts

View File

@@ -1,30 +0,0 @@
import torch
from ..device.npu_compatible_device import get_device_type
try:
import torch_npu
except:
pass
def rms_norm_forward_npu(self, hidden_states):
"npu rms fused operator for RMSNorm.forward from diffsynth\models\general_modules.py"
if hidden_states.dtype != self.weight.dtype:
hidden_states = hidden_states.to(self.weight.dtype)
return torch_npu.npu_rms_norm(hidden_states, self.weight, self.eps)[0]
def rms_norm_forward_transformers_npu(self, hidden_states):
"npu rms fused operator for transformers"
if hidden_states.dtype != self.weight.dtype:
hidden_states = hidden_states.to(self.weight.dtype)
return torch_npu.npu_rms_norm(hidden_states, self.weight, self.variance_epsilon)[0]
def rotary_emb_Zimage_npu(self, x_in: torch.Tensor, freqs_cis: torch.Tensor):
"npu rope fused operator for Zimage"
with torch.amp.autocast(get_device_type(), enabled=False):
freqs_cis = freqs_cis.unsqueeze(2)
cos, sin = torch.chunk(torch.view_as_real(freqs_cis), 2, dim=-1)
cos = cos.expand(-1, -1, -1, -1, 2).flatten(-2)
sin = sin.expand(-1, -1, -1, -1, 2).flatten(-2)
return torch_npu.npu_rotary_mul(x_in, cos, sin, rotary_mode="interleave").to(x_in)

View File

@@ -1,2 +0,0 @@
from .initialization import skip_model_initialization
from .layers import *

View File

@@ -1,93 +0,0 @@
from safetensors import safe_open
import torch, os
class SafetensorsCompatibleTensor:
def __init__(self, tensor):
self.tensor = tensor
def get_shape(self):
return list(self.tensor.shape)
class SafetensorsCompatibleBinaryLoader:
def __init__(self, path, device):
print("Detected non-safetensors files, which may cause slower loading. It's recommended to convert it to a safetensors file.")
self.state_dict = torch.load(path, weights_only=True, map_location=device)
def keys(self):
return self.state_dict.keys()
def get_tensor(self, name):
return self.state_dict[name]
def get_slice(self, name):
return SafetensorsCompatibleTensor(self.state_dict[name])
class DiskMap:
def __init__(self, path, device, torch_dtype=None, state_dict_converter=None, buffer_size=10**9):
self.path = path if isinstance(path, list) else [path]
self.device = device
self.torch_dtype = torch_dtype
if os.environ.get('DIFFSYNTH_DISK_MAP_BUFFER_SIZE') is not None:
self.buffer_size = int(os.environ.get('DIFFSYNTH_DISK_MAP_BUFFER_SIZE'))
else:
self.buffer_size = buffer_size
self.files = []
self.flush_files()
self.name_map = {}
for file_id, file in enumerate(self.files):
for name in file.keys():
self.name_map[name] = file_id
self.rename_dict = self.fetch_rename_dict(state_dict_converter)
def flush_files(self):
if len(self.files) == 0:
for path in self.path:
if path.endswith(".safetensors"):
self.files.append(safe_open(path, framework="pt", device=str(self.device)))
else:
self.files.append(SafetensorsCompatibleBinaryLoader(path, device=self.device))
else:
for i, path in enumerate(self.path):
if path.endswith(".safetensors"):
self.files[i] = safe_open(path, framework="pt", device=str(self.device))
self.num_params = 0
def __getitem__(self, name):
if self.rename_dict is not None: name = self.rename_dict[name]
file_id = self.name_map[name]
param = self.files[file_id].get_tensor(name)
if self.torch_dtype is not None and isinstance(param, torch.Tensor):
param = param.to(self.torch_dtype)
if isinstance(param, torch.Tensor) and param.device == "cpu":
param = param.clone()
if isinstance(param, torch.Tensor):
self.num_params += param.numel()
if self.num_params > self.buffer_size:
self.flush_files()
return param
def fetch_rename_dict(self, state_dict_converter):
if state_dict_converter is None:
return None
state_dict = {}
for file in self.files:
for name in file.keys():
state_dict[name] = name
state_dict = state_dict_converter(state_dict)
return state_dict
def __iter__(self):
if self.rename_dict is not None:
return self.rename_dict.__iter__()
else:
return self.name_map.__iter__()
def __contains__(self, x):
if self.rename_dict is not None:
return x in self.rename_dict
else:
return x in self.name_map

View File

@@ -1,21 +0,0 @@
import torch
from contextlib import contextmanager
@contextmanager
def skip_model_initialization(device=torch.device("meta")):
def register_empty_parameter(module, name, param):
old_register_parameter(module, name, param)
if param is not None:
param_cls = type(module._parameters[name])
kwargs = module._parameters[name].__dict__
kwargs["requires_grad"] = param.requires_grad
module._parameters[name] = param_cls(module._parameters[name].to(device), **kwargs)
old_register_parameter = torch.nn.Module.register_parameter
torch.nn.Module.register_parameter = register_empty_parameter
try:
yield
finally:
torch.nn.Module.register_parameter = old_register_parameter

View File

@@ -1,479 +0,0 @@
import torch, copy
from typing import Union
from .initialization import skip_model_initialization
from .disk_map import DiskMap
from ..device import parse_device_type, get_device_name, IS_NPU_AVAILABLE
class AutoTorchModule(torch.nn.Module):
def __init__(
self,
offload_dtype: torch.dtype = None,
offload_device: Union[str, torch.device] = None,
onload_dtype: torch.dtype = None,
onload_device: Union[str, torch.device] = None,
preparing_dtype: torch.dtype = None,
preparing_device: Union[str, torch.device] = None,
computation_dtype: torch.dtype = None,
computation_device: Union[str, torch.device] = None,
vram_limit: float = None,
):
super().__init__()
self.set_dtype_and_device(
offload_dtype,
offload_device,
onload_dtype,
onload_device,
preparing_dtype,
preparing_device,
computation_dtype,
computation_device,
vram_limit,
)
self.state = 0
self.name = ""
self.computation_device_type = parse_device_type(self.computation_device)
def set_dtype_and_device(
self,
offload_dtype: torch.dtype = None,
offload_device: Union[str, torch.device] = None,
onload_dtype: torch.dtype = None,
onload_device: Union[str, torch.device] = None,
preparing_dtype: torch.dtype = None,
preparing_device: Union[str, torch.device] = None,
computation_dtype: torch.dtype = None,
computation_device: Union[str, torch.device] = None,
vram_limit: float = None,
):
self.offload_dtype = offload_dtype or computation_dtype
self.offload_device = offload_device or computation_dtype
self.onload_dtype = onload_dtype or computation_dtype
self.onload_device = onload_device or computation_dtype
self.preparing_dtype = preparing_dtype or computation_dtype
self.preparing_device = preparing_device or computation_dtype
self.computation_dtype = computation_dtype
self.computation_device = computation_device
self.vram_limit = vram_limit
def cast_to(self, weight, dtype, device):
r = torch.empty_like(weight, dtype=dtype, device=device)
r.copy_(weight)
return r
def check_free_vram(self):
device = self.computation_device if not IS_NPU_AVAILABLE else get_device_name()
gpu_mem_state = getattr(torch, self.computation_device_type).mem_get_info(device)
used_memory = (gpu_mem_state[1] - gpu_mem_state[0]) / (1024**3)
return used_memory < self.vram_limit
def offload(self):
if self.state != 0:
self.to(dtype=self.offload_dtype, device=self.offload_device)
self.state = 0
def onload(self):
if self.state != 1:
self.to(dtype=self.onload_dtype, device=self.onload_device)
self.state = 1
def param_name(self, name):
if self.name == "":
return name
else:
return self.name + "." + name
class AutoWrappedModule(AutoTorchModule):
def __init__(
self,
module: torch.nn.Module,
offload_dtype: torch.dtype = None,
offload_device: Union[str, torch.device] = None,
onload_dtype: torch.dtype = None,
onload_device: Union[str, torch.device] = None,
preparing_dtype: torch.dtype = None,
preparing_device: Union[str, torch.device] = None,
computation_dtype: torch.dtype = None,
computation_device: Union[str, torch.device] = None,
vram_limit: float = None,
name: str = "",
disk_map: DiskMap = None,
**kwargs
):
super().__init__(
offload_dtype,
offload_device,
onload_dtype,
onload_device,
preparing_dtype,
preparing_device,
computation_dtype,
computation_device,
vram_limit,
)
self.module = module
if offload_dtype == "disk":
self.name = name
self.disk_map = disk_map
self.required_params = [name for name, _ in self.module.named_parameters()]
self.disk_offload = True
else:
self.disk_offload = False
def load_from_disk(self, torch_dtype, device, copy_module=False):
if copy_module:
module = copy.deepcopy(self.module)
else:
module = self.module
state_dict = {}
for name in self.required_params:
param = self.disk_map[self.param_name(name)]
param = param.to(dtype=torch_dtype, device=device)
state_dict[name] = param
module.load_state_dict(state_dict, assign=True)
module.to(dtype=torch_dtype, device=device)
return module
def offload_to_disk(self, model: torch.nn.Module):
for buf in model.buffers():
# If there are some parameters are registed in buffers (not in state dict),
# We cannot offload the model.
for children in model.children():
self.offload_to_disk(children)
break
else:
model.to("meta")
def offload(self):
# offload / onload / preparing -> offload
if self.state != 0:
if self.disk_offload:
self.offload_to_disk(self.module)
else:
self.to(dtype=self.offload_dtype, device=self.offload_device)
self.state = 0
def onload(self):
# offload / onload / preparing -> onload
if self.state < 1:
if self.disk_offload and self.onload_device != "disk" and self.offload_device == "disk":
self.load_from_disk(self.onload_dtype, self.onload_device)
elif self.onload_device != "disk":
self.to(dtype=self.onload_dtype, device=self.onload_device)
self.state = 1
def preparing(self):
# onload / preparing -> preparing
if self.state != 2:
if self.disk_offload and self.preparing_device != "disk" and self.onload_device == "disk":
self.load_from_disk(self.preparing_dtype, self.preparing_device)
elif self.preparing_device != "disk":
self.to(dtype=self.preparing_dtype, device=self.preparing_device)
self.state = 2
def cast_to(self, module, dtype, device):
return copy.deepcopy(module).to(dtype=dtype, device=device)
def computation(self):
# onload / preparing -> computation (temporary)
if self.state == 2:
torch_dtype, device = self.preparing_dtype, self.preparing_device
else:
torch_dtype, device = self.onload_dtype, self.onload_device
if torch_dtype == self.computation_dtype and device == self.computation_device:
module = self.module
elif self.disk_offload and device == "disk":
module = self.load_from_disk(self.computation_dtype, self.computation_device, copy_module=True)
else:
module = self.cast_to(self.module, dtype=self.computation_dtype, device=self.computation_device)
return module
def forward(self, *args, **kwargs):
if self.state == 1 and (self.vram_limit is None or self.check_free_vram()):
self.preparing()
module = self.computation()
return module(*args, **kwargs)
def __getattr__(self, name):
if name in self.__dict__ or name == "module":
return super().__getattr__(name)
else:
return getattr(self.module, name)
class AutoWrappedNonRecurseModule(AutoWrappedModule):
def __init__(
self,
module: torch.nn.Module,
offload_dtype: torch.dtype = None,
offload_device: Union[str, torch.device] = None,
onload_dtype: torch.dtype = None,
onload_device: Union[str, torch.device] = None,
preparing_dtype: torch.dtype = None,
preparing_device: Union[str, torch.device] = None,
computation_dtype: torch.dtype = None,
computation_device: Union[str, torch.device] = None,
vram_limit: float = None,
name: str = "",
disk_map: DiskMap = None,
**kwargs
):
super().__init__(
module,
offload_dtype,
offload_device,
onload_dtype,
onload_device,
preparing_dtype,
preparing_device,
computation_dtype,
computation_device,
vram_limit,
name,
disk_map,
**kwargs
)
if self.disk_offload:
self.required_params = [name for name, _ in self.module.named_parameters(recurse=False)]
def load_from_disk(self, torch_dtype, device, copy_module=False):
if copy_module:
module = copy.deepcopy(self.module)
else:
module = self.module
state_dict = {}
for name in self.required_params:
param = self.disk_map[self.param_name(name)]
param = param.to(dtype=torch_dtype, device=device)
state_dict[name] = param
module.load_state_dict(state_dict, assign=True, strict=False)
return module
def offload_to_disk(self, model: torch.nn.Module):
for name in self.required_params:
getattr(self, name).to("meta")
def cast_to(self, module, dtype, device):
# Parameter casting is implemented in the model architecture.
return module
def __getattr__(self, name):
if name in self.__dict__ or name == "module":
return super().__getattr__(name)
else:
return getattr(self.module, name)
class AutoWrappedLinear(torch.nn.Linear, AutoTorchModule):
def __init__(
self,
module: torch.nn.Linear,
offload_dtype: torch.dtype = None,
offload_device: Union[str, torch.device] = None,
onload_dtype: torch.dtype = None,
onload_device: Union[str, torch.device] = None,
preparing_dtype: torch.dtype = None,
preparing_device: Union[str, torch.device] = None,
computation_dtype: torch.dtype = None,
computation_device: Union[str, torch.device] = None,
vram_limit: float = None,
name: str = "",
disk_map: DiskMap = None,
**kwargs
):
with skip_model_initialization():
super().__init__(
in_features=module.in_features,
out_features=module.out_features,
bias=module.bias is not None,
)
self.set_dtype_and_device(
offload_dtype,
offload_device,
onload_dtype,
onload_device,
preparing_dtype,
preparing_device,
computation_dtype,
computation_device,
vram_limit,
)
self.weight = module.weight
self.bias = module.bias
self.state = 0
self.name = name
self.lora_A_weights = []
self.lora_B_weights = []
self.lora_merger = None
self.enable_fp8 = computation_dtype in [torch.float8_e4m3fn, torch.float8_e4m3fnuz]
self.computation_device_type = parse_device_type(self.computation_device)
if offload_dtype == "disk":
self.disk_map = disk_map
self.disk_offload = True
else:
self.disk_offload = False
def fp8_linear(
self,
input: torch.Tensor,
weight: torch.Tensor,
bias: torch.Tensor = None,
) -> torch.Tensor:
device = input.device
origin_dtype = input.dtype
origin_shape = input.shape
input = input.reshape(-1, origin_shape[-1])
x_max = torch.max(torch.abs(input), dim=-1, keepdim=True).values
fp8_max = 448.0
# For float8_e4m3fnuz, the maximum representable value is half of that of e4m3fn.
# To avoid overflow and ensure numerical compatibility during FP8 computation,
# we scale down the input by 2.0 in advance.
# This scaling will be compensated later during the final result scaling.
if self.computation_dtype == torch.float8_e4m3fnuz:
fp8_max = fp8_max / 2.0
scale_a = torch.clamp(x_max / fp8_max, min=1.0).float().to(device=device)
scale_b = torch.ones((weight.shape[0], 1)).to(device=device)
input = input / (scale_a + 1e-8)
input = input.to(self.computation_dtype)
weight = weight.to(self.computation_dtype)
bias = bias.to(torch.bfloat16)
result = torch._scaled_mm(
input,
weight.T,
scale_a=scale_a,
scale_b=scale_b.T,
bias=bias,
out_dtype=origin_dtype,
)
new_shape = origin_shape[:-1] + result.shape[-1:]
result = result.reshape(new_shape)
return result
def load_from_disk(self, torch_dtype, device, assign=True):
weight = self.disk_map[self.name + ".weight"].to(dtype=torch_dtype, device=device)
bias = None if self.bias is None else self.disk_map[self.name + ".bias"].to(dtype=torch_dtype, device=device)
if assign:
state_dict = {"weight": weight}
if bias is not None: state_dict["bias"] = bias
self.load_state_dict(state_dict, assign=True)
return weight, bias
def offload(self):
# offload / onload / preparing -> offload
if self.state != 0:
if self.disk_offload:
self.to("meta")
else:
self.to(dtype=self.offload_dtype, device=self.offload_device)
self.state = 0
def onload(self):
# offload / onload / preparing -> onload
if self.state < 1:
if self.disk_offload and self.onload_device != "disk" and self.offload_device == "disk":
self.load_from_disk(self.onload_dtype, self.onload_device)
elif self.onload_device != "disk":
self.to(dtype=self.onload_dtype, device=self.onload_device)
self.state = 1
def preparing(self):
# onload / preparing -> preparing
if self.state != 2:
if self.disk_offload and self.preparing_device != "disk" and self.onload_device == "disk":
self.load_from_disk(self.preparing_dtype, self.preparing_device)
elif self.preparing_device != "disk":
self.to(dtype=self.preparing_dtype, device=self.preparing_device)
self.state = 2
def computation(self):
# onload / preparing -> computation (temporary)
if self.state == 2:
torch_dtype, device = self.preparing_dtype, self.preparing_device
else:
torch_dtype, device = self.onload_dtype, self.onload_device
if torch_dtype == self.computation_dtype and device == self.computation_device:
weight, bias = self.weight, self.bias
elif self.disk_offload and device == "disk":
weight, bias = self.load_from_disk(self.computation_dtype, self.computation_device, assign=False)
else:
weight = self.cast_to(self.weight, self.computation_dtype, self.computation_device)
bias = None if self.bias is None else self.cast_to(self.bias, self.computation_dtype, self.computation_device)
return weight, bias
def linear_forward(self, x, weight, bias):
if self.enable_fp8:
out = self.fp8_linear(x, weight, bias)
else:
out = torch.nn.functional.linear(x, weight, bias)
return out
def lora_forward(self, x, out):
if self.lora_merger is None:
for lora_A, lora_B in zip(self.lora_A_weights, self.lora_B_weights):
out = out + x @ lora_A.T @ lora_B.T
else:
lora_output = []
for lora_A, lora_B in zip(self.lora_A_weights, self.lora_B_weights):
lora_output.append(x @ lora_A.T @ lora_B.T)
lora_output = torch.stack(lora_output)
out = self.lora_merger(out, lora_output)
return out
def forward(self, x, *args, **kwargs):
if self.state == 1 and (self.vram_limit is None or self.check_free_vram()):
self.preparing()
weight, bias = self.computation()
out = self.linear_forward(x, weight, bias)
if len(self.lora_A_weights) > 0:
out = self.lora_forward(x, out)
return out
def enable_vram_management_recursively(model: torch.nn.Module, module_map: dict, vram_config: dict, vram_limit=None, name_prefix="", disk_map=None, **kwargs):
if isinstance(model, AutoWrappedNonRecurseModule):
model = model.module
for name, module in model.named_children():
layer_name = name if name_prefix == "" else name_prefix + "." + name
for source_module, target_module in module_map.items():
if isinstance(module, source_module):
module_ = target_module(module, **vram_config, vram_limit=vram_limit, name=layer_name, disk_map=disk_map, **kwargs)
if isinstance(module_, AutoWrappedNonRecurseModule):
enable_vram_management_recursively(module_, module_map, vram_config, vram_limit=vram_limit, name_prefix=layer_name, disk_map=disk_map, **kwargs)
setattr(model, name, module_)
break
else:
enable_vram_management_recursively(module, module_map, vram_config, vram_limit=vram_limit, name_prefix=layer_name, disk_map=disk_map, **kwargs)
def fill_vram_config(model, vram_config):
vram_config_ = vram_config.copy()
vram_config_["onload_dtype"] = vram_config["computation_dtype"]
vram_config_["onload_device"] = vram_config["computation_device"]
vram_config_["preparing_dtype"] = vram_config["computation_dtype"]
vram_config_["preparing_device"] = vram_config["computation_device"]
for k in vram_config:
if vram_config[k] != vram_config_[k]:
print(f"No fine-grained VRAM configuration is provided for {model.__class__.__name__}. [`onload`, `preparing`, `computation`] will be the same state. `vram_config` is set to {vram_config_}")
break
return vram_config_
def enable_vram_management(model: torch.nn.Module, module_map: dict, vram_config: dict, vram_limit=None, disk_map=None, **kwargs):
for source_module, target_module in module_map.items():
# If no fine-grained VRAM configuration is provided, the entire model will be managed uniformly.
if isinstance(model, source_module):
vram_config = fill_vram_config(model, vram_config)
model = target_module(model, **vram_config, vram_limit=vram_limit, disk_map=disk_map, **kwargs)
break
else:
enable_vram_management_recursively(model, module_map, vram_config, vram_limit=vram_limit, disk_map=disk_map, **kwargs)
# `vram_management_enabled` is a flag that allows the pipeline to determine whether VRAM management is enabled.
model.vram_management_enabled = True
return model

View File

@@ -0,0 +1 @@
from .video import VideoData, save_video, save_frames

View File

@@ -0,0 +1,41 @@
import torch, os, torchvision
from torchvision import transforms
import pandas as pd
from PIL import Image
class TextImageDataset(torch.utils.data.Dataset):
def __init__(self, dataset_path, steps_per_epoch=10000, height=1024, width=1024, center_crop=True, random_flip=False):
self.steps_per_epoch = steps_per_epoch
metadata = pd.read_csv(os.path.join(dataset_path, "train/metadata.csv"))
self.path = [os.path.join(dataset_path, "train", file_name) for file_name in metadata["file_name"]]
self.text = metadata["text"].to_list()
self.height = height
self.width = width
self.image_processor = transforms.Compose(
[
transforms.CenterCrop((height, width)) if center_crop else transforms.RandomCrop((height, width)),
transforms.RandomHorizontalFlip() if random_flip else transforms.Lambda(lambda x: x),
transforms.ToTensor(),
transforms.Normalize([0.5], [0.5]),
]
)
def __getitem__(self, index):
data_id = torch.randint(0, len(self.path), (1,))[0]
data_id = (data_id + index) % len(self.path) # For fixed seed.
text = self.text[data_id]
image = Image.open(self.path[data_id]).convert("RGB")
target_height, target_width = self.height, self.width
width, height = image.size
scale = max(target_width / width, target_height / height)
shape = [round(height*scale),round(width*scale)]
image = torchvision.transforms.functional.resize(image,shape,interpolation=transforms.InterpolationMode.BILINEAR)
image = self.image_processor(image)
return {"text": text, "image": image}
def __len__(self):
return self.steps_per_epoch

View File

@@ -2,8 +2,6 @@ import imageio, os
import numpy as np
from PIL import Image
from tqdm import tqdm
import subprocess
import shutil
class LowMemoryVideo:
@@ -116,7 +114,7 @@ class VideoData:
if self.height is not None and self.width is not None:
return self.height, self.width
else:
width, height = self.__getitem__(0).size
height, width, _ = self.__getitem__(0).shape
return height, width
def __getitem__(self, item):
@@ -148,70 +146,3 @@ def save_frames(frames, save_path):
os.makedirs(save_path, exist_ok=True)
for i, frame in enumerate(tqdm(frames, desc="Saving images")):
frame.save(os.path.join(save_path, f"{i}.png"))
def merge_video_audio(video_path: str, audio_path: str):
# TODO: may need a in-python implementation to avoid subprocess dependency
"""
Merge the video and audio into a new video, with the duration set to the shorter of the two,
and overwrite the original video file.
Parameters:
video_path (str): Path to the original video file
audio_path (str): Path to the audio file
"""
# check
if not os.path.exists(video_path):
raise FileNotFoundError(f"video file {video_path} does not exist")
if not os.path.exists(audio_path):
raise FileNotFoundError(f"audio file {audio_path} does not exist")
base, ext = os.path.splitext(video_path)
temp_output = f"{base}_temp{ext}"
try:
# create ffmpeg command
command = [
'ffmpeg',
'-y', # overwrite
'-i',
video_path,
'-i',
audio_path,
'-c:v',
'copy', # copy video stream
'-c:a',
'aac', # use AAC audio encoder
'-b:a',
'192k', # set audio bitrate (optional)
'-map',
'0:v:0', # select the first video stream
'-map',
'1:a:0', # select the first audio stream
'-shortest', # choose the shortest duration
temp_output
]
# execute the command
result = subprocess.run(
command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
# check result
if result.returncode != 0:
error_msg = f"FFmpeg execute failed: {result.stderr}"
print(error_msg)
raise RuntimeError(error_msg)
shutil.move(temp_output, video_path)
print(f"Merge completed, saved to {video_path}")
except Exception as e:
if os.path.exists(temp_output):
os.remove(temp_output)
print(f"merge_video_audio failed with error: {e}")
def save_video_with_audio(frames, save_path, audio_path, fps=16, quality=9, ffmpeg_params=None):
save_video(frames, save_path, fps, quality, ffmpeg_params)
merge_video_audio(save_path, audio_path)

View File

@@ -1,6 +0,0 @@
from .flow_match import FlowMatchScheduler
from .training_module import DiffusionTrainingModule
from .logger import ModelLogger
from .runner import launch_training_task, launch_data_process_task
from .parsers import *
from .loss import *

View File

@@ -1,462 +0,0 @@
from PIL import Image
import torch
import numpy as np
from einops import repeat, reduce
from typing import Union
from ..core import AutoTorchModule, AutoWrappedLinear, load_state_dict, ModelConfig, parse_device_type
from ..core.device.npu_compatible_device import get_device_type
from ..utils.lora import GeneralLoRALoader
from ..models.model_loader import ModelPool
from ..utils.controlnet import ControlNetInput
from ..core.device import get_device_name, IS_NPU_AVAILABLE
class PipelineUnit:
def __init__(
self,
seperate_cfg: bool = False,
take_over: bool = False,
input_params: tuple[str] = None,
output_params: tuple[str] = None,
input_params_posi: dict[str, str] = None,
input_params_nega: dict[str, str] = None,
onload_model_names: tuple[str] = None
):
self.seperate_cfg = seperate_cfg
self.take_over = take_over
self.input_params = input_params
self.output_params = output_params
self.input_params_posi = input_params_posi
self.input_params_nega = input_params_nega
self.onload_model_names = onload_model_names
def fetch_input_params(self):
params = []
if self.input_params is not None:
for param in self.input_params:
params.append(param)
if self.input_params_posi is not None:
for _, param in self.input_params_posi.items():
params.append(param)
if self.input_params_nega is not None:
for _, param in self.input_params_nega.items():
params.append(param)
params = sorted(list(set(params)))
return params
def fetch_output_params(self):
params = []
if self.output_params is not None:
for param in self.output_params:
params.append(param)
return params
def process(self, pipe, **kwargs) -> dict:
return {}
def post_process(self, pipe, **kwargs) -> dict:
return {}
class BasePipeline(torch.nn.Module):
def __init__(
self,
device=get_device_type(), torch_dtype=torch.float16,
height_division_factor=64, width_division_factor=64,
time_division_factor=None, time_division_remainder=None,
):
super().__init__()
# The device and torch_dtype is used for the storage of intermediate variables, not models.
self.device = device
self.torch_dtype = torch_dtype
self.device_type = parse_device_type(device)
# The following parameters are used for shape check.
self.height_division_factor = height_division_factor
self.width_division_factor = width_division_factor
self.time_division_factor = time_division_factor
self.time_division_remainder = time_division_remainder
# VRAM management
self.vram_management_enabled = False
# Pipeline Unit Runner
self.unit_runner = PipelineUnitRunner()
# LoRA Loader
self.lora_loader = GeneralLoRALoader
def to(self, *args, **kwargs):
device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs)
if device is not None:
self.device = device
if dtype is not None:
self.torch_dtype = dtype
super().to(*args, **kwargs)
return self
def check_resize_height_width(self, height, width, num_frames=None, verbose=1):
# Shape check
if height % self.height_division_factor != 0:
height = (height + self.height_division_factor - 1) // self.height_division_factor * self.height_division_factor
if verbose > 0:
print(f"height % {self.height_division_factor} != 0. We round it up to {height}.")
if width % self.width_division_factor != 0:
width = (width + self.width_division_factor - 1) // self.width_division_factor * self.width_division_factor
if verbose > 0:
print(f"width % {self.width_division_factor} != 0. We round it up to {width}.")
if num_frames is None:
return height, width
else:
if num_frames % self.time_division_factor != self.time_division_remainder:
num_frames = (num_frames + self.time_division_factor - 1) // self.time_division_factor * self.time_division_factor + self.time_division_remainder
if verbose > 0:
print(f"num_frames % {self.time_division_factor} != {self.time_division_remainder}. We round it up to {num_frames}.")
return height, width, num_frames
def preprocess_image(self, image, torch_dtype=None, device=None, pattern="B C H W", min_value=-1, max_value=1):
# Transform a PIL.Image to torch.Tensor
image = torch.Tensor(np.array(image, dtype=np.float32))
image = image.to(dtype=torch_dtype or self.torch_dtype, device=device or self.device)
image = image * ((max_value - min_value) / 255) + min_value
image = repeat(image, f"H W C -> {pattern}", **({"B": 1} if "B" in pattern else {}))
return image
def preprocess_video(self, video, torch_dtype=None, device=None, pattern="B C T H W", min_value=-1, max_value=1):
# Transform a list of PIL.Image to torch.Tensor
video = [self.preprocess_image(image, torch_dtype=torch_dtype, device=device, min_value=min_value, max_value=max_value) for image in video]
video = torch.stack(video, dim=pattern.index("T") // 2)
return video
def vae_output_to_image(self, vae_output, pattern="B C H W", min_value=-1, max_value=1):
# Transform a torch.Tensor to PIL.Image
if pattern != "H W C":
vae_output = reduce(vae_output, f"{pattern} -> H W C", reduction="mean")
image = ((vae_output - min_value) * (255 / (max_value - min_value))).clip(0, 255)
image = image.to(device="cpu", dtype=torch.uint8)
image = Image.fromarray(image.numpy())
return image
def vae_output_to_video(self, vae_output, pattern="B C T H W", min_value=-1, max_value=1):
# Transform a torch.Tensor to list of PIL.Image
if pattern != "T H W C":
vae_output = reduce(vae_output, f"{pattern} -> T H W C", reduction="mean")
video = [self.vae_output_to_image(image, pattern="H W C", min_value=min_value, max_value=max_value) for image in vae_output]
return video
def load_models_to_device(self, model_names):
if self.vram_management_enabled:
# offload models
for name, model in self.named_children():
if name not in model_names:
if hasattr(model, "vram_management_enabled") and model.vram_management_enabled:
if hasattr(model, "offload"):
model.offload()
else:
for module in model.modules():
if hasattr(module, "offload"):
module.offload()
getattr(torch, self.device_type).empty_cache()
# onload models
for name, model in self.named_children():
if name in model_names:
if hasattr(model, "vram_management_enabled") and model.vram_management_enabled:
if hasattr(model, "onload"):
model.onload()
else:
for module in model.modules():
if hasattr(module, "onload"):
module.onload()
def generate_noise(self, shape, seed=None, rand_device="cpu", rand_torch_dtype=torch.float32, device=None, torch_dtype=None):
# Initialize Gaussian noise
generator = None if seed is None else torch.Generator(rand_device).manual_seed(seed)
noise = torch.randn(shape, generator=generator, device=rand_device, dtype=rand_torch_dtype)
noise = noise.to(dtype=torch_dtype or self.torch_dtype, device=device or self.device)
return noise
def get_vram(self):
device = self.device if not IS_NPU_AVAILABLE else get_device_name()
return getattr(torch, self.device_type).mem_get_info(device)[1] / (1024 ** 3)
def get_module(self, model, name):
if "." in name:
name, suffix = name[:name.index(".")], name[name.index(".") + 1:]
if name.isdigit():
return self.get_module(model[int(name)], suffix)
else:
return self.get_module(getattr(model, name), suffix)
else:
return getattr(model, name)
def freeze_except(self, model_names):
self.eval()
self.requires_grad_(False)
for name in model_names:
module = self.get_module(self, name)
if module is None:
print(f"No {name} models in the pipeline. We cannot enable training on the model. If this occurs during the data processing stage, it is normal.")
continue
module.train()
module.requires_grad_(True)
def blend_with_mask(self, base, addition, mask):
return base * (1 - mask) + addition * mask
def step(self, scheduler, latents, progress_id, noise_pred, input_latents=None, inpaint_mask=None, **kwargs):
timestep = scheduler.timesteps[progress_id]
if inpaint_mask is not None:
noise_pred_expected = scheduler.return_to_timestep(scheduler.timesteps[progress_id], latents, input_latents)
noise_pred = self.blend_with_mask(noise_pred_expected, noise_pred, inpaint_mask)
latents_next = scheduler.step(noise_pred, timestep, latents)
return latents_next
def split_pipeline_units(self, model_names: list[str]):
return PipelineUnitGraph().split_pipeline_units(self.units, model_names)
def flush_vram_management_device(self, device):
for module in self.modules():
if isinstance(module, AutoTorchModule):
module.offload_device = device
module.onload_device = device
module.preparing_device = device
module.computation_device = device
def load_lora(
self,
module: torch.nn.Module,
lora_config: Union[ModelConfig, str] = None,
alpha=1,
hotload=None,
state_dict=None,
verbose=1,
):
if state_dict is None:
if isinstance(lora_config, str):
lora = load_state_dict(lora_config, torch_dtype=self.torch_dtype, device=self.device)
else:
lora_config.download_if_necessary()
lora = load_state_dict(lora_config.path, torch_dtype=self.torch_dtype, device=self.device)
else:
lora = state_dict
lora_loader = self.lora_loader(torch_dtype=self.torch_dtype, device=self.device)
lora = lora_loader.convert_state_dict(lora)
if hotload is None:
hotload = hasattr(module, "vram_management_enabled") and getattr(module, "vram_management_enabled")
if hotload:
if not (hasattr(module, "vram_management_enabled") and getattr(module, "vram_management_enabled")):
raise ValueError("VRAM Management is not enabled. LoRA hotloading is not supported.")
updated_num = 0
for _, module in module.named_modules():
if isinstance(module, AutoWrappedLinear):
name = module.name
lora_a_name = f'{name}.lora_A.weight'
lora_b_name = f'{name}.lora_B.weight'
if lora_a_name in lora and lora_b_name in lora:
updated_num += 1
module.lora_A_weights.append(lora[lora_a_name] * alpha)
module.lora_B_weights.append(lora[lora_b_name])
if verbose >= 1:
print(f"{updated_num} tensors are patched by LoRA. You can use `pipe.clear_lora()` to clear all LoRA layers.")
else:
lora_loader.fuse_lora_to_base_model(module, lora, alpha=alpha)
def clear_lora(self, verbose=1):
cleared_num = 0
for name, module in self.named_modules():
if isinstance(module, AutoWrappedLinear):
if hasattr(module, "lora_A_weights"):
if len(module.lora_A_weights) > 0:
cleared_num += 1
module.lora_A_weights.clear()
if hasattr(module, "lora_B_weights"):
module.lora_B_weights.clear()
if verbose >= 1:
print(f"{cleared_num} LoRA layers are cleared.")
def download_and_load_models(self, model_configs: list[ModelConfig] = [], vram_limit: float = None):
model_pool = ModelPool()
for model_config in model_configs:
model_config.download_if_necessary()
vram_config = model_config.vram_config()
vram_config["computation_dtype"] = vram_config["computation_dtype"] or self.torch_dtype
vram_config["computation_device"] = vram_config["computation_device"] or self.device
model_pool.auto_load_model(
model_config.path,
vram_config=vram_config,
vram_limit=vram_limit,
clear_parameters=model_config.clear_parameters,
state_dict=model_config.state_dict,
)
return model_pool
def check_vram_management_state(self):
vram_management_enabled = False
for module in self.children():
if hasattr(module, "vram_management_enabled") and getattr(module, "vram_management_enabled"):
vram_management_enabled = True
return vram_management_enabled
def cfg_guided_model_fn(self, model_fn, cfg_scale, inputs_shared, inputs_posi, inputs_nega, **inputs_others):
if inputs_shared.get("positive_only_lora", None) is not None:
self.clear_lora(verbose=0)
self.load_lora(self.dit, state_dict=inputs_shared["positive_only_lora"], verbose=0)
noise_pred_posi = model_fn(**inputs_posi, **inputs_shared, **inputs_others)
if cfg_scale != 1.0:
if inputs_shared.get("positive_only_lora", None) is not None:
self.clear_lora(verbose=0)
noise_pred_nega = model_fn(**inputs_nega, **inputs_shared, **inputs_others)
if isinstance(noise_pred_posi, tuple):
# Separately handling different output types of latents, eg. video and audio latents.
noise_pred = tuple(
n_nega + cfg_scale * (n_posi - n_nega)
for n_posi, n_nega in zip(noise_pred_posi, noise_pred_nega)
)
else:
noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
else:
noise_pred = noise_pred_posi
return noise_pred
class PipelineUnitGraph:
def __init__(self):
pass
def build_edges(self, units: list[PipelineUnit]):
# Establish dependencies between units
# to search for subsequent related computation units.
last_compute_unit_id = {}
edges = []
for unit_id, unit in enumerate(units):
for input_param in unit.fetch_input_params():
if input_param in last_compute_unit_id:
edges.append((last_compute_unit_id[input_param], unit_id))
for output_param in unit.fetch_output_params():
last_compute_unit_id[output_param] = unit_id
return edges
def build_chains(self, units: list[PipelineUnit]):
# Establish updating chains for each variable
# to track their computation process.
params = sum([unit.fetch_input_params() + unit.fetch_output_params() for unit in units], [])
params = sorted(list(set(params)))
chains = {param: [] for param in params}
for unit_id, unit in enumerate(units):
for param in unit.fetch_output_params():
chains[param].append(unit_id)
return chains
def search_direct_unit_ids(self, units: list[PipelineUnit], model_names: list[str]):
# Search for units that directly participate in the model's computation.
related_unit_ids = []
for unit_id, unit in enumerate(units):
for model_name in model_names:
if unit.onload_model_names is not None and model_name in unit.onload_model_names:
related_unit_ids.append(unit_id)
break
return related_unit_ids
def search_related_unit_ids(self, edges, start_unit_ids, direction="target"):
# Search for subsequent related computation units.
related_unit_ids = [unit_id for unit_id in start_unit_ids]
while True:
neighbors = []
for source, target in edges:
if direction == "target" and source in related_unit_ids and target not in related_unit_ids:
neighbors.append(target)
elif direction == "source" and source not in related_unit_ids and target in related_unit_ids:
neighbors.append(source)
neighbors = sorted(list(set(neighbors)))
if len(neighbors) == 0:
break
else:
related_unit_ids.extend(neighbors)
related_unit_ids = sorted(list(set(related_unit_ids)))
return related_unit_ids
def search_updating_unit_ids(self, units: list[PipelineUnit], chains, related_unit_ids):
# If the input parameters of this subgraph are updated outside the subgraph,
# search for the units where these updates occur.
first_compute_unit_id = {}
for unit_id in related_unit_ids:
for param in units[unit_id].fetch_input_params():
if param not in first_compute_unit_id:
first_compute_unit_id[param] = unit_id
updating_unit_ids = []
for param in first_compute_unit_id:
unit_id = first_compute_unit_id[param]
chain = chains[param]
if unit_id in chain and chain.index(unit_id) != len(chain) - 1:
for unit_id_ in chain[chain.index(unit_id) + 1:]:
if unit_id_ not in related_unit_ids:
updating_unit_ids.append(unit_id_)
related_unit_ids.extend(updating_unit_ids)
related_unit_ids = sorted(list(set(related_unit_ids)))
return related_unit_ids
def split_pipeline_units(self, units: list[PipelineUnit], model_names: list[str]):
# Split the computation graph,
# separating all model-related computations.
related_unit_ids = self.search_direct_unit_ids(units, model_names)
edges = self.build_edges(units)
chains = self.build_chains(units)
while True:
num_related_unit_ids = len(related_unit_ids)
related_unit_ids = self.search_related_unit_ids(edges, related_unit_ids, "target")
related_unit_ids = self.search_updating_unit_ids(units, chains, related_unit_ids)
if len(related_unit_ids) == num_related_unit_ids:
break
else:
num_related_unit_ids = len(related_unit_ids)
related_units = [units[i] for i in related_unit_ids]
unrelated_units = [units[i] for i in range(len(units)) if i not in related_unit_ids]
return related_units, unrelated_units
class PipelineUnitRunner:
def __init__(self):
pass
def __call__(self, unit: PipelineUnit, pipe: BasePipeline, inputs_shared: dict, inputs_posi: dict, inputs_nega: dict) -> tuple[dict, dict]:
if unit.take_over:
# Let the pipeline unit take over this function.
inputs_shared, inputs_posi, inputs_nega = unit.process(pipe, inputs_shared=inputs_shared, inputs_posi=inputs_posi, inputs_nega=inputs_nega)
elif unit.seperate_cfg:
# Positive side
processor_inputs = {name: inputs_posi.get(name_) for name, name_ in unit.input_params_posi.items()}
if unit.input_params is not None:
for name in unit.input_params:
processor_inputs[name] = inputs_shared.get(name)
processor_outputs = unit.process(pipe, **processor_inputs)
inputs_posi.update(processor_outputs)
# Negative side
if inputs_shared["cfg_scale"] != 1:
processor_inputs = {name: inputs_nega.get(name_) for name, name_ in unit.input_params_nega.items()}
if unit.input_params is not None:
for name in unit.input_params:
processor_inputs[name] = inputs_shared.get(name)
processor_outputs = unit.process(pipe, **processor_inputs)
inputs_nega.update(processor_outputs)
else:
inputs_nega.update(processor_outputs)
else:
processor_inputs = {name: inputs_shared.get(name) for name in unit.input_params}
processor_outputs = unit.process(pipe, **processor_inputs)
inputs_shared.update(processor_outputs)
return inputs_shared, inputs_posi, inputs_nega

View File

@@ -1,236 +0,0 @@
import torch, math
from typing_extensions import Literal
class FlowMatchScheduler():
def __init__(self, template: Literal["FLUX.1", "Wan", "Qwen-Image", "FLUX.2", "Z-Image", "LTX-2", "Qwen-Image-Lightning"] = "FLUX.1"):
self.set_timesteps_fn = {
"FLUX.1": FlowMatchScheduler.set_timesteps_flux,
"Wan": FlowMatchScheduler.set_timesteps_wan,
"Qwen-Image": FlowMatchScheduler.set_timesteps_qwen_image,
"FLUX.2": FlowMatchScheduler.set_timesteps_flux2,
"Z-Image": FlowMatchScheduler.set_timesteps_z_image,
"LTX-2": FlowMatchScheduler.set_timesteps_ltx2,
"Qwen-Image-Lightning": FlowMatchScheduler.set_timesteps_qwen_image_lightning,
}.get(template, FlowMatchScheduler.set_timesteps_flux)
self.num_train_timesteps = 1000
@staticmethod
def set_timesteps_flux(num_inference_steps=100, denoising_strength=1.0, shift=None):
sigma_min = 0.003/1.002
sigma_max = 1.0
shift = 3 if shift is None else shift
num_train_timesteps = 1000
sigma_start = sigma_min + (sigma_max - sigma_min) * denoising_strength
sigmas = torch.linspace(sigma_start, sigma_min, num_inference_steps)
sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)
timesteps = sigmas * num_train_timesteps
return sigmas, timesteps
@staticmethod
def set_timesteps_wan(num_inference_steps=100, denoising_strength=1.0, shift=None):
sigma_min = 0.0
sigma_max = 1.0
shift = 5 if shift is None else shift
num_train_timesteps = 1000
sigma_start = sigma_min + (sigma_max - sigma_min) * denoising_strength
sigmas = torch.linspace(sigma_start, sigma_min, num_inference_steps + 1)[:-1]
sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)
timesteps = sigmas * num_train_timesteps
return sigmas, timesteps
@staticmethod
def _calculate_shift_qwen_image(image_seq_len, base_seq_len=256, max_seq_len=8192, base_shift=0.5, max_shift=0.9):
m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
b = base_shift - m * base_seq_len
mu = image_seq_len * m + b
return mu
@staticmethod
def set_timesteps_qwen_image(num_inference_steps=100, denoising_strength=1.0, exponential_shift_mu=None, dynamic_shift_len=None):
sigma_min = 0.0
sigma_max = 1.0
num_train_timesteps = 1000
shift_terminal = 0.02
# Sigmas
sigma_start = sigma_min + (sigma_max - sigma_min) * denoising_strength
sigmas = torch.linspace(sigma_start, sigma_min, num_inference_steps + 1)[:-1]
# Mu
if exponential_shift_mu is not None:
mu = exponential_shift_mu
elif dynamic_shift_len is not None:
mu = FlowMatchScheduler._calculate_shift_qwen_image(dynamic_shift_len)
else:
mu = 0.8
sigmas = math.exp(mu) / (math.exp(mu) + (1 / sigmas - 1))
# Shift terminal
one_minus_z = 1 - sigmas
scale_factor = one_minus_z[-1] / (1 - shift_terminal)
sigmas = 1 - (one_minus_z / scale_factor)
# Timesteps
timesteps = sigmas * num_train_timesteps
return sigmas, timesteps
@staticmethod
def set_timesteps_qwen_image_lightning(num_inference_steps=100, denoising_strength=1.0, exponential_shift_mu=None, dynamic_shift_len=None):
sigma_min = 0.0
sigma_max = 1.0
num_train_timesteps = 1000
base_shift = math.log(3)
max_shift = math.log(3)
# Sigmas
sigma_start = sigma_min + (sigma_max - sigma_min) * denoising_strength
sigmas = torch.linspace(sigma_start, sigma_min, num_inference_steps + 1)[:-1]
# Mu
if exponential_shift_mu is not None:
mu = exponential_shift_mu
elif dynamic_shift_len is not None:
mu = FlowMatchScheduler._calculate_shift_qwen_image(dynamic_shift_len, base_shift=base_shift, max_shift=max_shift)
else:
mu = 0.8
sigmas = math.exp(mu) / (math.exp(mu) + (1 / sigmas - 1))
# Timesteps
timesteps = sigmas * num_train_timesteps
return sigmas, timesteps
@staticmethod
def compute_empirical_mu(image_seq_len, num_steps):
a1, b1 = 8.73809524e-05, 1.89833333
a2, b2 = 0.00016927, 0.45666666
if image_seq_len > 4300:
mu = a2 * image_seq_len + b2
return float(mu)
m_200 = a2 * image_seq_len + b2
m_10 = a1 * image_seq_len + b1
a = (m_200 - m_10) / 190.0
b = m_200 - 200.0 * a
mu = a * num_steps + b
return float(mu)
@staticmethod
def set_timesteps_flux2(num_inference_steps=100, denoising_strength=1.0, dynamic_shift_len=None):
sigma_min = 1 / num_inference_steps
sigma_max = 1.0
num_train_timesteps = 1000
sigma_start = sigma_min + (sigma_max - sigma_min) * denoising_strength
sigmas = torch.linspace(sigma_start, sigma_min, num_inference_steps)
if dynamic_shift_len is None:
# If you ask me why I set mu=0.8,
# I can only say that it yields better training results.
mu = 0.8
else:
mu = FlowMatchScheduler.compute_empirical_mu(dynamic_shift_len, num_inference_steps)
sigmas = math.exp(mu) / (math.exp(mu) + (1 / sigmas - 1))
timesteps = sigmas * num_train_timesteps
return sigmas, timesteps
@staticmethod
def set_timesteps_z_image(num_inference_steps=100, denoising_strength=1.0, shift=None, target_timesteps=None):
sigma_min = 0.0
sigma_max = 1.0
shift = 3 if shift is None else shift
num_train_timesteps = 1000
sigma_start = sigma_min + (sigma_max - sigma_min) * denoising_strength
sigmas = torch.linspace(sigma_start, sigma_min, num_inference_steps + 1)[:-1]
sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)
timesteps = sigmas * num_train_timesteps
if target_timesteps is not None:
target_timesteps = target_timesteps.to(dtype=timesteps.dtype, device=timesteps.device)
for timestep in target_timesteps:
timestep_id = torch.argmin((timesteps - timestep).abs())
timesteps[timestep_id] = timestep
return sigmas, timesteps
@staticmethod
def set_timesteps_ltx2(num_inference_steps=100, denoising_strength=1.0, dynamic_shift_len=None, terminal=0.1, special_case=None):
num_train_timesteps = 1000
if special_case == "stage2":
sigmas = torch.Tensor([0.909375, 0.725, 0.421875])
elif special_case == "ditilled_stage1":
sigmas = torch.Tensor([1.0, 0.99375, 0.9875, 0.98125, 0.975, 0.909375, 0.725, 0.421875])
else:
dynamic_shift_len = dynamic_shift_len or 4096
sigma_shift = FlowMatchScheduler._calculate_shift_qwen_image(
image_seq_len=dynamic_shift_len,
base_seq_len=1024,
max_seq_len=4096,
base_shift=0.95,
max_shift=2.05,
)
sigma_min = 0.0
sigma_max = 1.0
sigma_start = sigma_min + (sigma_max - sigma_min) * denoising_strength
sigmas = torch.linspace(sigma_start, sigma_min, num_inference_steps + 1)[:-1]
sigmas = math.exp(sigma_shift) / (math.exp(sigma_shift) + (1 / sigmas - 1))
# Shift terminal
one_minus_z = 1.0 - sigmas
scale_factor = one_minus_z[-1] / (1 - terminal)
sigmas = 1.0 - (one_minus_z / scale_factor)
timesteps = sigmas * num_train_timesteps
return sigmas, timesteps
def set_training_weight(self):
steps = 1000
x = self.timesteps
y = torch.exp(-2 * ((x - steps / 2) / steps) ** 2)
y_shifted = y - y.min()
bsmntw_weighing = y_shifted * (steps / y_shifted.sum())
if len(self.timesteps) != 1000:
# This is an empirical formula.
bsmntw_weighing = bsmntw_weighing * (len(self.timesteps) / steps)
bsmntw_weighing = bsmntw_weighing + bsmntw_weighing[1]
self.linear_timesteps_weights = bsmntw_weighing
def set_timesteps(self, num_inference_steps=100, denoising_strength=1.0, training=False, **kwargs):
self.sigmas, self.timesteps = self.set_timesteps_fn(
num_inference_steps=num_inference_steps,
denoising_strength=denoising_strength,
**kwargs,
)
if training:
self.set_training_weight()
self.training = True
else:
self.training = False
def step(self, model_output, timestep, sample, to_final=False, **kwargs):
if isinstance(timestep, torch.Tensor):
timestep = timestep.cpu()
timestep_id = torch.argmin((self.timesteps - timestep).abs())
sigma = self.sigmas[timestep_id]
if to_final or timestep_id + 1 >= len(self.timesteps):
sigma_ = 0
else:
sigma_ = self.sigmas[timestep_id + 1]
prev_sample = sample + model_output * (sigma_ - sigma)
return prev_sample
def return_to_timestep(self, timestep, sample, sample_stablized):
if isinstance(timestep, torch.Tensor):
timestep = timestep.cpu()
timestep_id = torch.argmin((self.timesteps - timestep).abs())
sigma = self.sigmas[timestep_id]
model_output = (sample - sample_stablized) / sigma
return model_output
def add_noise(self, original_samples, noise, timestep):
if isinstance(timestep, torch.Tensor):
timestep = timestep.cpu()
timestep_id = torch.argmin((self.timesteps - timestep).abs())
sigma = self.sigmas[timestep_id]
sample = (1 - sigma) * original_samples + sigma * noise
return sample
def training_target(self, sample, noise, timestep):
target = noise - sample
return target
def training_weight(self, timestep):
timestep_id = torch.argmin((self.timesteps - timestep.to(self.timesteps.device)).abs())
weights = self.linear_timesteps_weights[timestep_id]
return weights

View File

@@ -1,43 +0,0 @@
import os, torch
from accelerate import Accelerator
class ModelLogger:
def __init__(self, output_path, remove_prefix_in_ckpt=None, state_dict_converter=lambda x:x):
self.output_path = output_path
self.remove_prefix_in_ckpt = remove_prefix_in_ckpt
self.state_dict_converter = state_dict_converter
self.num_steps = 0
def on_step_end(self, accelerator: Accelerator, model: torch.nn.Module, save_steps=None, **kwargs):
self.num_steps += 1
if save_steps is not None and self.num_steps % save_steps == 0:
self.save_model(accelerator, model, f"step-{self.num_steps}.safetensors")
def on_epoch_end(self, accelerator: Accelerator, model: torch.nn.Module, epoch_id):
accelerator.wait_for_everyone()
state_dict = accelerator.get_state_dict(model)
if accelerator.is_main_process:
state_dict = accelerator.unwrap_model(model).export_trainable_state_dict(state_dict, remove_prefix=self.remove_prefix_in_ckpt)
state_dict = self.state_dict_converter(state_dict)
os.makedirs(self.output_path, exist_ok=True)
path = os.path.join(self.output_path, f"epoch-{epoch_id}.safetensors")
accelerator.save(state_dict, path, safe_serialization=True)
def on_training_end(self, accelerator: Accelerator, model: torch.nn.Module, save_steps=None):
if save_steps is not None and self.num_steps % save_steps != 0:
self.save_model(accelerator, model, f"step-{self.num_steps}.safetensors")
def save_model(self, accelerator: Accelerator, model: torch.nn.Module, file_name):
accelerator.wait_for_everyone()
state_dict = accelerator.get_state_dict(model)
if accelerator.is_main_process:
state_dict = accelerator.unwrap_model(model).export_trainable_state_dict(state_dict, remove_prefix=self.remove_prefix_in_ckpt)
state_dict = self.state_dict_converter(state_dict)
os.makedirs(self.output_path, exist_ok=True)
path = os.path.join(self.output_path, file_name)
accelerator.save(state_dict, path, safe_serialization=True)

View File

@@ -1,158 +0,0 @@
from .base_pipeline import BasePipeline
import torch
def FlowMatchSFTLoss(pipe: BasePipeline, **inputs):
max_timestep_boundary = int(inputs.get("max_timestep_boundary", 1) * len(pipe.scheduler.timesteps))
min_timestep_boundary = int(inputs.get("min_timestep_boundary", 0) * len(pipe.scheduler.timesteps))
timestep_id = torch.randint(min_timestep_boundary, max_timestep_boundary, (1,))
timestep = pipe.scheduler.timesteps[timestep_id].to(dtype=pipe.torch_dtype, device=pipe.device)
noise = torch.randn_like(inputs["input_latents"])
inputs["latents"] = pipe.scheduler.add_noise(inputs["input_latents"], noise, timestep)
training_target = pipe.scheduler.training_target(inputs["input_latents"], noise, timestep)
if "first_frame_latents" in inputs:
inputs["latents"][:, :, 0:1] = inputs["first_frame_latents"]
models = {name: getattr(pipe, name) for name in pipe.in_iteration_models}
noise_pred = pipe.model_fn(**models, **inputs, timestep=timestep)
if "first_frame_latents" in inputs:
noise_pred = noise_pred[:, :, 1:]
training_target = training_target[:, :, 1:]
loss = torch.nn.functional.mse_loss(noise_pred.float(), training_target.float())
loss = loss * pipe.scheduler.training_weight(timestep)
return loss
def FlowMatchSFTAudioVideoLoss(pipe: BasePipeline, **inputs):
max_timestep_boundary = int(inputs.get("max_timestep_boundary", 1) * len(pipe.scheduler.timesteps))
min_timestep_boundary = int(inputs.get("min_timestep_boundary", 0) * len(pipe.scheduler.timesteps))
timestep_id = torch.randint(min_timestep_boundary, max_timestep_boundary, (1,))
timestep = pipe.scheduler.timesteps[timestep_id].to(dtype=pipe.torch_dtype, device=pipe.device)
# video
noise = torch.randn_like(inputs["input_latents"])
inputs["video_latents"] = pipe.scheduler.add_noise(inputs["input_latents"], noise, timestep)
training_target = pipe.scheduler.training_target(inputs["input_latents"], noise, timestep)
# audio
if inputs.get("audio_input_latents") is not None:
audio_noise = torch.randn_like(inputs["audio_input_latents"])
inputs["audio_latents"] = pipe.scheduler.add_noise(inputs["audio_input_latents"], audio_noise, timestep)
training_target_audio = pipe.scheduler.training_target(inputs["audio_input_latents"], audio_noise, timestep)
models = {name: getattr(pipe, name) for name in pipe.in_iteration_models}
noise_pred, noise_pred_audio = pipe.model_fn(**models, **inputs, timestep=timestep)
loss = torch.nn.functional.mse_loss(noise_pred.float(), training_target.float())
loss = loss * pipe.scheduler.training_weight(timestep)
if inputs.get("audio_input_latents") is not None:
loss_audio = torch.nn.functional.mse_loss(noise_pred_audio.float(), training_target_audio.float())
loss_audio = loss_audio * pipe.scheduler.training_weight(timestep)
loss = loss + loss_audio
return loss
def DirectDistillLoss(pipe: BasePipeline, **inputs):
pipe.scheduler.set_timesteps(inputs["num_inference_steps"])
pipe.scheduler.training = True
models = {name: getattr(pipe, name) for name in pipe.in_iteration_models}
for progress_id, timestep in enumerate(pipe.scheduler.timesteps):
timestep = timestep.unsqueeze(0).to(dtype=pipe.torch_dtype, device=pipe.device)
noise_pred = pipe.model_fn(**models, **inputs, timestep=timestep, progress_id=progress_id)
inputs["latents"] = pipe.step(pipe.scheduler, progress_id=progress_id, noise_pred=noise_pred, **inputs)
loss = torch.nn.functional.mse_loss(inputs["latents"].float(), inputs["input_latents"].float())
return loss
class TrajectoryImitationLoss(torch.nn.Module):
def __init__(self):
super().__init__()
self.initialized = False
def initialize(self, device):
import lpips # TODO: remove it
self.loss_fn = lpips.LPIPS(net='alex').to(device)
self.initialized = True
def fetch_trajectory(self, pipe: BasePipeline, timesteps_student, inputs_shared, inputs_posi, inputs_nega, num_inference_steps, cfg_scale):
trajectory = [inputs_shared["latents"].clone()]
pipe.scheduler.set_timesteps(num_inference_steps, target_timesteps=timesteps_student)
models = {name: getattr(pipe, name) for name in pipe.in_iteration_models}
for progress_id, timestep in enumerate(pipe.scheduler.timesteps):
timestep = timestep.unsqueeze(0).to(dtype=pipe.torch_dtype, device=pipe.device)
noise_pred = pipe.cfg_guided_model_fn(
pipe.model_fn, cfg_scale,
inputs_shared, inputs_posi, inputs_nega,
**models, timestep=timestep, progress_id=progress_id
)
inputs_shared["latents"] = pipe.step(pipe.scheduler, progress_id=progress_id, noise_pred=noise_pred.detach(), **inputs_shared)
trajectory.append(inputs_shared["latents"].clone())
return pipe.scheduler.timesteps, trajectory
def align_trajectory(self, pipe: BasePipeline, timesteps_teacher, trajectory_teacher, inputs_shared, inputs_posi, inputs_nega, num_inference_steps, cfg_scale):
loss = 0
pipe.scheduler.set_timesteps(num_inference_steps, training=True)
models = {name: getattr(pipe, name) for name in pipe.in_iteration_models}
for progress_id, timestep in enumerate(pipe.scheduler.timesteps):
timestep = timestep.unsqueeze(0).to(dtype=pipe.torch_dtype, device=pipe.device)
progress_id_teacher = torch.argmin((timesteps_teacher - timestep).abs())
inputs_shared["latents"] = trajectory_teacher[progress_id_teacher]
noise_pred = pipe.cfg_guided_model_fn(
pipe.model_fn, cfg_scale,
inputs_shared, inputs_posi, inputs_nega,
**models, timestep=timestep, progress_id=progress_id
)
sigma = pipe.scheduler.sigmas[progress_id]
sigma_ = 0 if progress_id + 1 >= len(pipe.scheduler.timesteps) else pipe.scheduler.sigmas[progress_id + 1]
if progress_id + 1 >= len(pipe.scheduler.timesteps):
latents_ = trajectory_teacher[-1]
else:
progress_id_teacher = torch.argmin((timesteps_teacher - pipe.scheduler.timesteps[progress_id + 1]).abs())
latents_ = trajectory_teacher[progress_id_teacher]
denom = sigma_ - sigma
denom = torch.sign(denom) * torch.clamp(denom.abs(), min=1e-6)
target = (latents_ - inputs_shared["latents"]) / denom
loss = loss + torch.nn.functional.mse_loss(noise_pred.float(), target.float()) * pipe.scheduler.training_weight(timestep)
return loss
def compute_regularization(self, pipe: BasePipeline, trajectory_teacher, inputs_shared, inputs_posi, inputs_nega, num_inference_steps, cfg_scale):
inputs_shared["latents"] = trajectory_teacher[0]
pipe.scheduler.set_timesteps(num_inference_steps)
models = {name: getattr(pipe, name) for name in pipe.in_iteration_models}
for progress_id, timestep in enumerate(pipe.scheduler.timesteps):
timestep = timestep.unsqueeze(0).to(dtype=pipe.torch_dtype, device=pipe.device)
noise_pred = pipe.cfg_guided_model_fn(
pipe.model_fn, cfg_scale,
inputs_shared, inputs_posi, inputs_nega,
**models, timestep=timestep, progress_id=progress_id
)
inputs_shared["latents"] = pipe.step(pipe.scheduler, progress_id=progress_id, noise_pred=noise_pred.detach(), **inputs_shared)
image_pred = pipe.vae_decoder(inputs_shared["latents"])
image_real = pipe.vae_decoder(trajectory_teacher[-1])
loss = self.loss_fn(image_pred.float(), image_real.float())
return loss
def forward(self, pipe: BasePipeline, inputs_shared, inputs_posi, inputs_nega):
if not self.initialized:
self.initialize(pipe.device)
with torch.no_grad():
pipe.scheduler.set_timesteps(8)
timesteps_teacher, trajectory_teacher = self.fetch_trajectory(inputs_shared["teacher"], pipe.scheduler.timesteps, inputs_shared, inputs_posi, inputs_nega, 50, 2)
timesteps_teacher = timesteps_teacher.to(dtype=pipe.torch_dtype, device=pipe.device)
loss_1 = self.align_trajectory(pipe, timesteps_teacher, trajectory_teacher, inputs_shared, inputs_posi, inputs_nega, 8, 1)
loss_2 = self.compute_regularization(pipe, trajectory_teacher, inputs_shared, inputs_posi, inputs_nega, 8, 1)
loss = loss_1 + loss_2
return loss

View File

@@ -1,70 +0,0 @@
import argparse
def add_dataset_base_config(parser: argparse.ArgumentParser):
parser.add_argument("--dataset_base_path", type=str, default="", required=True, help="Base path of the dataset.")
parser.add_argument("--dataset_metadata_path", type=str, default=None, help="Path to the metadata file of the dataset.")
parser.add_argument("--dataset_repeat", type=int, default=1, help="Number of times to repeat the dataset per epoch.")
parser.add_argument("--dataset_num_workers", type=int, default=0, help="Number of workers for data loading.")
parser.add_argument("--data_file_keys", type=str, default="image,video", help="Data file keys in the metadata. Comma-separated.")
return parser
def add_image_size_config(parser: argparse.ArgumentParser):
parser.add_argument("--height", type=int, default=None, help="Height of images. Leave `height` and `width` empty to enable dynamic resolution.")
parser.add_argument("--width", type=int, default=None, help="Width of images. Leave `height` and `width` empty to enable dynamic resolution.")
parser.add_argument("--max_pixels", type=int, default=1024*1024, help="Maximum number of pixels per frame, used for dynamic resolution.")
return parser
def add_video_size_config(parser: argparse.ArgumentParser):
parser.add_argument("--height", type=int, default=None, help="Height of images. Leave `height` and `width` empty to enable dynamic resolution.")
parser.add_argument("--width", type=int, default=None, help="Width of images. Leave `height` and `width` empty to enable dynamic resolution.")
parser.add_argument("--max_pixels", type=int, default=1024*1024, help="Maximum number of pixels per frame, used for dynamic resolution.")
parser.add_argument("--num_frames", type=int, default=81, help="Number of frames per video. Frames are sampled from the video prefix.")
return parser
def add_model_config(parser: argparse.ArgumentParser):
parser.add_argument("--model_paths", type=str, default=None, help="Paths to load models. In JSON format.")
parser.add_argument("--model_id_with_origin_paths", type=str, default=None, help="Model ID with origin paths, e.g., Wan-AI/Wan2.1-T2V-1.3B:diffusion_pytorch_model*.safetensors. Comma-separated.")
parser.add_argument("--extra_inputs", default=None, help="Additional model inputs, comma-separated.")
parser.add_argument("--fp8_models", default=None, help="Models with FP8 precision, comma-separated.")
parser.add_argument("--offload_models", default=None, help="Models with offload, comma-separated. Only used in splited training.")
return parser
def add_training_config(parser: argparse.ArgumentParser):
parser.add_argument("--learning_rate", type=float, default=1e-4, help="Learning rate.")
parser.add_argument("--num_epochs", type=int, default=1, help="Number of epochs.")
parser.add_argument("--trainable_models", type=str, default=None, help="Models to train, e.g., dit, vae, text_encoder.")
parser.add_argument("--find_unused_parameters", default=False, action="store_true", help="Whether to find unused parameters in DDP.")
parser.add_argument("--weight_decay", type=float, default=0.01, help="Weight decay.")
parser.add_argument("--task", type=str, default="sft", required=False, help="Task type.")
return parser
def add_output_config(parser: argparse.ArgumentParser):
parser.add_argument("--output_path", type=str, default="./models", help="Output save path.")
parser.add_argument("--remove_prefix_in_ckpt", type=str, default="pipe.dit.", help="Remove prefix in ckpt.")
parser.add_argument("--save_steps", type=int, default=None, help="Number of checkpoint saving invervals. If None, checkpoints will be saved every epoch.")
return parser
def add_lora_config(parser: argparse.ArgumentParser):
parser.add_argument("--lora_base_model", type=str, default=None, help="Which model LoRA is added to.")
parser.add_argument("--lora_target_modules", type=str, default="q,k,v,o,ffn.0,ffn.2", help="Which layers LoRA is added to.")
parser.add_argument("--lora_rank", type=int, default=32, help="Rank of LoRA.")
parser.add_argument("--lora_checkpoint", type=str, default=None, help="Path to the LoRA checkpoint. If provided, LoRA will be loaded from this checkpoint.")
parser.add_argument("--preset_lora_path", type=str, default=None, help="Path to the preset LoRA checkpoint. If provided, this LoRA will be fused to the base model.")
parser.add_argument("--preset_lora_model", type=str, default=None, help="Which model the preset LoRA is fused to.")
return parser
def add_gradient_config(parser: argparse.ArgumentParser):
parser.add_argument("--use_gradient_checkpointing", default=False, action="store_true", help="Whether to use gradient checkpointing.")
parser.add_argument("--use_gradient_checkpointing_offload", default=False, action="store_true", help="Whether to offload gradient checkpointing to CPU memory.")
parser.add_argument("--gradient_accumulation_steps", type=int, default=1, help="Gradient accumulation steps.")
return parser
def add_general_config(parser: argparse.ArgumentParser):
parser = add_dataset_base_config(parser)
parser = add_model_config(parser)
parser = add_training_config(parser)
parser = add_output_config(parser)
parser = add_lora_config(parser)
parser = add_gradient_config(parser)
return parser

View File

@@ -1,72 +0,0 @@
import os, torch
from tqdm import tqdm
from accelerate import Accelerator
from .training_module import DiffusionTrainingModule
from .logger import ModelLogger
def launch_training_task(
accelerator: Accelerator,
dataset: torch.utils.data.Dataset,
model: DiffusionTrainingModule,
model_logger: ModelLogger,
learning_rate: float = 1e-5,
weight_decay: float = 1e-2,
num_workers: int = 1,
save_steps: int = None,
num_epochs: int = 1,
args = None,
):
if args is not None:
learning_rate = args.learning_rate
weight_decay = args.weight_decay
num_workers = args.dataset_num_workers
save_steps = args.save_steps
num_epochs = args.num_epochs
optimizer = torch.optim.AdamW(model.trainable_modules(), lr=learning_rate, weight_decay=weight_decay)
scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer)
dataloader = torch.utils.data.DataLoader(dataset, shuffle=True, collate_fn=lambda x: x[0], num_workers=num_workers)
model.to(device=accelerator.device)
model, optimizer, dataloader, scheduler = accelerator.prepare(model, optimizer, dataloader, scheduler)
for epoch_id in range(num_epochs):
for data in tqdm(dataloader):
with accelerator.accumulate(model):
optimizer.zero_grad()
if dataset.load_from_cache:
loss = model({}, inputs=data)
else:
loss = model(data)
accelerator.backward(loss)
optimizer.step()
model_logger.on_step_end(accelerator, model, save_steps, loss=loss)
scheduler.step()
if save_steps is None:
model_logger.on_epoch_end(accelerator, model, epoch_id)
model_logger.on_training_end(accelerator, model, save_steps)
def launch_data_process_task(
accelerator: Accelerator,
dataset: torch.utils.data.Dataset,
model: DiffusionTrainingModule,
model_logger: ModelLogger,
num_workers: int = 8,
args = None,
):
if args is not None:
num_workers = args.dataset_num_workers
dataloader = torch.utils.data.DataLoader(dataset, shuffle=False, collate_fn=lambda x: x[0], num_workers=num_workers)
model.to(device=accelerator.device)
model, dataloader = accelerator.prepare(model, dataloader)
for data_id, data in enumerate(tqdm(dataloader)):
with accelerator.accumulate(model):
with torch.no_grad():
folder = os.path.join(model_logger.output_path, str(accelerator.process_index))
os.makedirs(folder, exist_ok=True)
save_path = os.path.join(model_logger.output_path, str(accelerator.process_index), f"{data_id}.pth")
data = model(data)
torch.save(data, save_path)

View File

@@ -1,263 +0,0 @@
import torch, json, os
from ..core import ModelConfig, load_state_dict
from ..utils.controlnet import ControlNetInput
from peft import LoraConfig, inject_adapter_in_model
class DiffusionTrainingModule(torch.nn.Module):
def __init__(self):
super().__init__()
def to(self, *args, **kwargs):
for name, model in self.named_children():
model.to(*args, **kwargs)
return self
def trainable_modules(self):
trainable_modules = filter(lambda p: p.requires_grad, self.parameters())
return trainable_modules
def trainable_param_names(self):
trainable_param_names = list(filter(lambda named_param: named_param[1].requires_grad, self.named_parameters()))
trainable_param_names = set([named_param[0] for named_param in trainable_param_names])
return trainable_param_names
def add_lora_to_model(self, model, target_modules, lora_rank, lora_alpha=None, upcast_dtype=None):
if lora_alpha is None:
lora_alpha = lora_rank
if isinstance(target_modules, list) and len(target_modules) == 1:
target_modules = target_modules[0]
lora_config = LoraConfig(r=lora_rank, lora_alpha=lora_alpha, target_modules=target_modules)
model = inject_adapter_in_model(lora_config, model)
if upcast_dtype is not None:
for param in model.parameters():
if param.requires_grad:
param.data = param.to(upcast_dtype)
return model
def mapping_lora_state_dict(self, state_dict):
new_state_dict = {}
for key, value in state_dict.items():
if "lora_A.weight" in key or "lora_B.weight" in key:
new_key = key.replace("lora_A.weight", "lora_A.default.weight").replace("lora_B.weight", "lora_B.default.weight")
new_state_dict[new_key] = value
elif "lora_A.default.weight" in key or "lora_B.default.weight" in key:
new_state_dict[key] = value
return new_state_dict
def export_trainable_state_dict(self, state_dict, remove_prefix=None):
trainable_param_names = self.trainable_param_names()
state_dict = {name: param for name, param in state_dict.items() if name in trainable_param_names}
if remove_prefix is not None:
state_dict_ = {}
for name, param in state_dict.items():
if name.startswith(remove_prefix):
name = name[len(remove_prefix):]
state_dict_[name] = param
state_dict = state_dict_
return state_dict
def transfer_data_to_device(self, data, device, torch_float_dtype=None):
if data is None:
return data
elif isinstance(data, torch.Tensor):
data = data.to(device)
if torch_float_dtype is not None and data.dtype in [torch.float, torch.float16, torch.bfloat16]:
data = data.to(torch_float_dtype)
return data
elif isinstance(data, tuple):
data = tuple(self.transfer_data_to_device(x, device, torch_float_dtype) for x in data)
return data
elif isinstance(data, list):
data = list(self.transfer_data_to_device(x, device, torch_float_dtype) for x in data)
return data
elif isinstance(data, dict):
data = {i: self.transfer_data_to_device(data[i], device, torch_float_dtype) for i in data}
return data
else:
return data
def parse_vram_config(self, fp8=False, offload=False, device="cpu"):
if fp8:
return {
"offload_dtype": torch.float8_e4m3fn,
"offload_device": device,
"onload_dtype": torch.float8_e4m3fn,
"onload_device": device,
"preparing_dtype": torch.float8_e4m3fn,
"preparing_device": device,
"computation_dtype": torch.bfloat16,
"computation_device": device,
}
elif offload:
return {
"offload_dtype": "disk",
"offload_device": "disk",
"onload_dtype": "disk",
"onload_device": "disk",
"preparing_dtype": torch.bfloat16,
"preparing_device": device,
"computation_dtype": torch.bfloat16,
"computation_device": device,
"clear_parameters": True,
}
else:
return {}
def parse_model_configs(self, model_paths, model_id_with_origin_paths, fp8_models=None, offload_models=None, device="cpu"):
fp8_models = [] if fp8_models is None else fp8_models.split(",")
offload_models = [] if offload_models is None else offload_models.split(",")
model_configs = []
if model_paths is not None:
model_paths = json.loads(model_paths)
for path in model_paths:
vram_config = self.parse_vram_config(
fp8=path in fp8_models,
offload=path in offload_models,
device=device
)
model_configs.append(ModelConfig(path=path, **vram_config))
if model_id_with_origin_paths is not None:
model_id_with_origin_paths = model_id_with_origin_paths.split(",")
for model_id_with_origin_path in model_id_with_origin_paths:
vram_config = self.parse_vram_config(
fp8=model_id_with_origin_path in fp8_models,
offload=model_id_with_origin_path in offload_models,
device=device
)
config = self.parse_path_or_model_id(model_id_with_origin_path)
model_configs.append(ModelConfig(model_id=config.model_id, origin_file_pattern=config.origin_file_pattern, **vram_config))
return model_configs
def parse_path_or_model_id(self, model_id_with_origin_path, default_value=None):
if model_id_with_origin_path is None:
return default_value
elif os.path.exists(model_id_with_origin_path):
return ModelConfig(path=model_id_with_origin_path)
else:
if ":" not in model_id_with_origin_path:
raise ValueError(f"Failed to parse model config: {model_id_with_origin_path}. This is neither a valid path nor in the format of `model_id/origin_file_pattern`.")
split_id = model_id_with_origin_path.rfind(":")
model_id = model_id_with_origin_path[:split_id]
origin_file_pattern = model_id_with_origin_path[split_id + 1:]
return ModelConfig(model_id=model_id, origin_file_pattern=origin_file_pattern)
def auto_detect_lora_target_modules(
self,
model: torch.nn.Module,
search_for_linear=False,
linear_detector=lambda x: min(x.weight.shape) >= 512,
block_list_detector=lambda x: isinstance(x, torch.nn.ModuleList) and len(x) > 1,
name_prefix="",
):
lora_target_modules = []
if search_for_linear:
for name, module in model.named_modules():
module_name = name_prefix + ["", "."][name_prefix != ""] + name
if isinstance(module, torch.nn.Linear) and linear_detector(module):
lora_target_modules.append(module_name)
else:
for name, module in model.named_children():
module_name = name_prefix + ["", "."][name_prefix != ""] + name
lora_target_modules += self.auto_detect_lora_target_modules(
module,
search_for_linear=block_list_detector(module),
linear_detector=linear_detector,
block_list_detector=block_list_detector,
name_prefix=module_name,
)
return lora_target_modules
def parse_lora_target_modules(self, model, lora_target_modules):
if lora_target_modules == "":
print("No LoRA target modules specified. The framework will automatically search for them.")
lora_target_modules = self.auto_detect_lora_target_modules(model)
print(f"LoRA will be patched at {lora_target_modules}.")
else:
lora_target_modules = lora_target_modules.split(",")
return lora_target_modules
def switch_pipe_to_training_mode(
self,
pipe,
trainable_models=None,
lora_base_model=None, lora_target_modules="", lora_rank=32, lora_checkpoint=None,
preset_lora_path=None, preset_lora_model=None,
task="sft",
):
# Scheduler
pipe.scheduler.set_timesteps(1000, training=True)
# Freeze untrainable models
pipe.freeze_except([] if trainable_models is None else trainable_models.split(","))
# Preset LoRA
if preset_lora_path is not None:
pipe.load_lora(getattr(pipe, preset_lora_model), preset_lora_path)
# FP8
# FP8 relies on a model-specific memory management scheme.
# It is delegated to the subclass.
# Add LoRA to the base models
if lora_base_model is not None and not task.endswith(":data_process"):
if (not hasattr(pipe, lora_base_model)) or getattr(pipe, lora_base_model) is None:
print(f"No {lora_base_model} models in the pipeline. We cannot patch LoRA on the model. If this occurs during the data processing stage, it is normal.")
return
model = self.add_lora_to_model(
getattr(pipe, lora_base_model),
target_modules=self.parse_lora_target_modules(getattr(pipe, lora_base_model), lora_target_modules),
lora_rank=lora_rank,
upcast_dtype=pipe.torch_dtype,
)
if lora_checkpoint is not None:
state_dict = load_state_dict(lora_checkpoint)
state_dict = self.mapping_lora_state_dict(state_dict)
load_result = model.load_state_dict(state_dict, strict=False)
print(f"LoRA checkpoint loaded: {lora_checkpoint}, total {len(state_dict)} keys")
if len(load_result[1]) > 0:
print(f"Warning, LoRA key mismatch! Unexpected keys in LoRA checkpoint: {load_result[1]}")
setattr(pipe, lora_base_model, model)
def split_pipeline_units(self, task, pipe, trainable_models=None, lora_base_model=None):
models_require_backward = []
if trainable_models is not None:
models_require_backward += trainable_models.split(",")
if lora_base_model is not None:
models_require_backward += [lora_base_model]
if task.endswith(":data_process"):
_, pipe.units = pipe.split_pipeline_units(models_require_backward)
elif task.endswith(":train"):
pipe.units, _ = pipe.split_pipeline_units(models_require_backward)
return pipe
def parse_extra_inputs(self, data, extra_inputs, inputs_shared):
controlnet_keys_map = (
("blockwise_controlnet_", "blockwise_controlnet_inputs",),
("controlnet_", "controlnet_inputs"),
)
controlnet_inputs = {}
for extra_input in extra_inputs:
for prefix, name in controlnet_keys_map:
if extra_input.startswith(prefix):
if name not in controlnet_inputs:
controlnet_inputs[name] = {}
controlnet_inputs[name][extra_input.replace(prefix, "")] = data[extra_input]
break
else:
inputs_shared[extra_input] = data[extra_input]
for name, params in controlnet_inputs.items():
inputs_shared[name] = [ControlNetInput(**params)]
return inputs_shared

View File

@@ -0,0 +1,137 @@
import torch
from einops import repeat
from PIL import Image
import numpy as np
class ResidualDenseBlock(torch.nn.Module):
def __init__(self, num_feat=64, num_grow_ch=32):
super(ResidualDenseBlock, self).__init__()
self.conv1 = torch.nn.Conv2d(num_feat, num_grow_ch, 3, 1, 1)
self.conv2 = torch.nn.Conv2d(num_feat + num_grow_ch, num_grow_ch, 3, 1, 1)
self.conv3 = torch.nn.Conv2d(num_feat + 2 * num_grow_ch, num_grow_ch, 3, 1, 1)
self.conv4 = torch.nn.Conv2d(num_feat + 3 * num_grow_ch, num_grow_ch, 3, 1, 1)
self.conv5 = torch.nn.Conv2d(num_feat + 4 * num_grow_ch, num_feat, 3, 1, 1)
self.lrelu = torch.nn.LeakyReLU(negative_slope=0.2, inplace=True)
def forward(self, x):
x1 = self.lrelu(self.conv1(x))
x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1)))
x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1)))
x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1)))
x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
return x5 * 0.2 + x
class RRDB(torch.nn.Module):
def __init__(self, num_feat, num_grow_ch=32):
super(RRDB, self).__init__()
self.rdb1 = ResidualDenseBlock(num_feat, num_grow_ch)
self.rdb2 = ResidualDenseBlock(num_feat, num_grow_ch)
self.rdb3 = ResidualDenseBlock(num_feat, num_grow_ch)
def forward(self, x):
out = self.rdb1(x)
out = self.rdb2(out)
out = self.rdb3(out)
return out * 0.2 + x
class RRDBNet(torch.nn.Module):
def __init__(self, num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, **kwargs):
super(RRDBNet, self).__init__()
self.conv_first = torch.nn.Conv2d(num_in_ch, num_feat, 3, 1, 1)
self.body = torch.torch.nn.Sequential(*[RRDB(num_feat=num_feat, num_grow_ch=num_grow_ch) for _ in range(num_block)])
self.conv_body = torch.nn.Conv2d(num_feat, num_feat, 3, 1, 1)
# upsample
self.conv_up1 = torch.nn.Conv2d(num_feat, num_feat, 3, 1, 1)
self.conv_up2 = torch.nn.Conv2d(num_feat, num_feat, 3, 1, 1)
self.conv_hr = torch.nn.Conv2d(num_feat, num_feat, 3, 1, 1)
self.conv_last = torch.nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
self.lrelu = torch.nn.LeakyReLU(negative_slope=0.2, inplace=True)
def forward(self, x):
feat = x
feat = self.conv_first(feat)
body_feat = self.conv_body(self.body(feat))
feat = feat + body_feat
# upsample
feat = repeat(feat, "B C H W -> B C (H 2) (W 2)")
feat = self.lrelu(self.conv_up1(feat))
feat = repeat(feat, "B C H W -> B C (H 2) (W 2)")
feat = self.lrelu(self.conv_up2(feat))
out = self.conv_last(self.lrelu(self.conv_hr(feat)))
return out
@staticmethod
def state_dict_converter():
return RRDBNetStateDictConverter()
class RRDBNetStateDictConverter:
def __init__(self):
pass
def from_diffusers(self, state_dict):
return state_dict, {"upcast_to_float32": True}
def from_civitai(self, state_dict):
return state_dict, {"upcast_to_float32": True}
class ESRGAN(torch.nn.Module):
def __init__(self, model):
super().__init__()
self.model = model
@staticmethod
def from_model_manager(model_manager):
return ESRGAN(model_manager.fetch_model("esrgan"))
def process_image(self, image):
image = torch.Tensor(np.array(image, dtype=np.float32) / 255).permute(2, 0, 1)
return image
def process_images(self, images):
images = [self.process_image(image) for image in images]
images = torch.stack(images)
return images
def decode_images(self, images):
images = (images.permute(0, 2, 3, 1) * 255).clip(0, 255).numpy().astype(np.uint8)
images = [Image.fromarray(image) for image in images]
return images
@torch.no_grad()
def upscale(self, images, batch_size=4, progress_bar=lambda x:x):
if not isinstance(images, list):
images = [images]
is_single_image = True
else:
is_single_image = False
# Preprocess
input_tensor = self.process_images(images)
# Interpolate
output_tensor = []
for batch_id in progress_bar(range(0, input_tensor.shape[0], batch_size)):
batch_id_ = min(batch_id + batch_size, input_tensor.shape[0])
batch_input_tensor = input_tensor[batch_id: batch_id_]
batch_input_tensor = batch_input_tensor.to(
device=self.model.conv_first.weight.device,
dtype=self.model.conv_first.weight.dtype)
batch_output_tensor = self.model(batch_input_tensor)
output_tensor.append(batch_output_tensor.cpu())
# Output
output_tensor = torch.concat(output_tensor, dim=0)
# To images
output_images = self.decode_images(output_tensor)
if is_single_image:
output_images = output_images[0]
return output_images

View File

@@ -0,0 +1,63 @@
from .runners.fast import TableManager, PyramidPatchMatcher
from PIL import Image
import numpy as np
import cupy as cp
class FastBlendSmoother:
def __init__(self):
self.batch_size = 8
self.window_size = 64
self.ebsynth_config = {
"minimum_patch_size": 5,
"threads_per_block": 8,
"num_iter": 5,
"gpu_id": 0,
"guide_weight": 10.0,
"initialize": "identity",
"tracking_window_size": 0,
}
@staticmethod
def from_model_manager(model_manager):
# TODO: fetch GPU ID from model_manager
return FastBlendSmoother()
def run(self, frames_guide, frames_style, batch_size, window_size, ebsynth_config):
frames_guide = [np.array(frame) for frame in frames_guide]
frames_style = [np.array(frame) for frame in frames_style]
table_manager = TableManager()
patch_match_engine = PyramidPatchMatcher(
image_height=frames_style[0].shape[0],
image_width=frames_style[0].shape[1],
channel=3,
**ebsynth_config
)
# left part
table_l = table_manager.build_remapping_table(frames_guide, frames_style, patch_match_engine, batch_size, desc="FastBlend Step 1/4")
table_l = table_manager.remapping_table_to_blending_table(table_l)
table_l = table_manager.process_window_sum(frames_guide, table_l, patch_match_engine, window_size, batch_size, desc="FastBlend Step 2/4")
# right part
table_r = table_manager.build_remapping_table(frames_guide[::-1], frames_style[::-1], patch_match_engine, batch_size, desc="FastBlend Step 3/4")
table_r = table_manager.remapping_table_to_blending_table(table_r)
table_r = table_manager.process_window_sum(frames_guide[::-1], table_r, patch_match_engine, window_size, batch_size, desc="FastBlend Step 4/4")[::-1]
# merge
frames = []
for (frame_l, weight_l), frame_m, (frame_r, weight_r) in zip(table_l, frames_style, table_r):
weight_m = -1
weight = weight_l + weight_m + weight_r
frame = frame_l * (weight_l / weight) + frame_m * (weight_m / weight) + frame_r * (weight_r / weight)
frames.append(frame)
frames = [Image.fromarray(frame.clip(0, 255).astype("uint8")) for frame in frames]
return frames
def __call__(self, rendered_frames, original_frames=None, **kwargs):
frames = self.run(
original_frames, rendered_frames,
self.batch_size, self.window_size, self.ebsynth_config
)
mempool = cp.get_default_memory_pool()
pinned_mempool = cp.get_default_pinned_memory_pool()
mempool.free_all_blocks()
pinned_mempool.free_all_blocks()
return frames

View File

@@ -0,0 +1,397 @@
from .runners import AccurateModeRunner, FastModeRunner, BalancedModeRunner, InterpolationModeRunner, InterpolationModeSingleFrameRunner
from .data import VideoData, get_video_fps, save_video, search_for_images
import os
import gradio as gr
def check_input_for_blending(video_guide, video_guide_folder, video_style, video_style_folder):
frames_guide = VideoData(video_guide, video_guide_folder)
frames_style = VideoData(video_style, video_style_folder)
message = ""
if len(frames_guide) < len(frames_style):
message += f"The number of frames mismatches. Only the first {len(frames_guide)} frames of style video will be used.\n"
frames_style.set_length(len(frames_guide))
elif len(frames_guide) > len(frames_style):
message += f"The number of frames mismatches. Only the first {len(frames_style)} frames of guide video will be used.\n"
frames_guide.set_length(len(frames_style))
height_guide, width_guide = frames_guide.shape()
height_style, width_style = frames_style.shape()
if height_guide != height_style or width_guide != width_style:
message += f"The shape of frames mismatches. The frames in style video will be resized to (height: {height_guide}, width: {width_guide})\n"
frames_style.set_shape(height_guide, width_guide)
return frames_guide, frames_style, message
def smooth_video(
video_guide,
video_guide_folder,
video_style,
video_style_folder,
mode,
window_size,
batch_size,
tracking_window_size,
output_path,
fps,
minimum_patch_size,
num_iter,
guide_weight,
initialize,
progress = None,
):
# input
frames_guide, frames_style, message = check_input_for_blending(video_guide, video_guide_folder, video_style, video_style_folder)
if len(message) > 0:
print(message)
# output
if output_path == "":
if video_style is None:
output_path = os.path.join(video_style_folder, "output")
else:
output_path = os.path.join(os.path.split(video_style)[0], "output")
os.makedirs(output_path, exist_ok=True)
print("No valid output_path. Your video will be saved here:", output_path)
elif not os.path.exists(output_path):
os.makedirs(output_path, exist_ok=True)
print("Your video will be saved here:", output_path)
frames_path = os.path.join(output_path, "frames")
video_path = os.path.join(output_path, "video.mp4")
os.makedirs(frames_path, exist_ok=True)
# process
if mode == "Fast" or mode == "Balanced":
tracking_window_size = 0
ebsynth_config = {
"minimum_patch_size": minimum_patch_size,
"threads_per_block": 8,
"num_iter": num_iter,
"gpu_id": 0,
"guide_weight": guide_weight,
"initialize": initialize,
"tracking_window_size": tracking_window_size,
}
if mode == "Fast":
FastModeRunner().run(frames_guide, frames_style, batch_size=batch_size, window_size=window_size, ebsynth_config=ebsynth_config, save_path=frames_path)
elif mode == "Balanced":
BalancedModeRunner().run(frames_guide, frames_style, batch_size=batch_size, window_size=window_size, ebsynth_config=ebsynth_config, save_path=frames_path)
elif mode == "Accurate":
AccurateModeRunner().run(frames_guide, frames_style, batch_size=batch_size, window_size=window_size, ebsynth_config=ebsynth_config, save_path=frames_path)
# output
try:
fps = int(fps)
except:
fps = get_video_fps(video_style) if video_style is not None else 30
print("Fps:", fps)
print("Saving video...")
video_path = save_video(frames_path, video_path, num_frames=len(frames_style), fps=fps)
print("Success!")
print("Your frames are here:", frames_path)
print("Your video is here:", video_path)
return output_path, fps, video_path
class KeyFrameMatcher:
def __init__(self):
pass
def extract_number_from_filename(self, file_name):
result = []
number = -1
for i in file_name:
if ord(i)>=ord("0") and ord(i)<=ord("9"):
if number == -1:
number = 0
number = number*10 + ord(i) - ord("0")
else:
if number != -1:
result.append(number)
number = -1
if number != -1:
result.append(number)
result = tuple(result)
return result
def extract_number_from_filenames(self, file_names):
numbers = [self.extract_number_from_filename(file_name) for file_name in file_names]
min_length = min(len(i) for i in numbers)
for i in range(min_length-1, -1, -1):
if len(set(number[i] for number in numbers))==len(file_names):
return [number[i] for number in numbers]
return list(range(len(file_names)))
def match_using_filename(self, file_names_a, file_names_b):
file_names_b_set = set(file_names_b)
matched_file_name = []
for file_name in file_names_a:
if file_name not in file_names_b_set:
matched_file_name.append(None)
else:
matched_file_name.append(file_name)
return matched_file_name
def match_using_numbers(self, file_names_a, file_names_b):
numbers_a = self.extract_number_from_filenames(file_names_a)
numbers_b = self.extract_number_from_filenames(file_names_b)
numbers_b_dict = {number: file_name for number, file_name in zip(numbers_b, file_names_b)}
matched_file_name = []
for number in numbers_a:
if number in numbers_b_dict:
matched_file_name.append(numbers_b_dict[number])
else:
matched_file_name.append(None)
return matched_file_name
def match_filenames(self, file_names_a, file_names_b):
matched_file_name = self.match_using_filename(file_names_a, file_names_b)
if sum([i is not None for i in matched_file_name]) > 0:
return matched_file_name
matched_file_name = self.match_using_numbers(file_names_a, file_names_b)
return matched_file_name
def detect_frames(frames_path, keyframes_path):
if not os.path.exists(frames_path) and not os.path.exists(keyframes_path):
return "Please input the directory of guide video and rendered frames"
elif not os.path.exists(frames_path):
return "Please input the directory of guide video"
elif not os.path.exists(keyframes_path):
return "Please input the directory of rendered frames"
frames = [os.path.split(i)[-1] for i in search_for_images(frames_path)]
keyframes = [os.path.split(i)[-1] for i in search_for_images(keyframes_path)]
if len(frames)==0:
return f"No images detected in {frames_path}"
if len(keyframes)==0:
return f"No images detected in {keyframes_path}"
matched_keyframes = KeyFrameMatcher().match_filenames(frames, keyframes)
max_filename_length = max([len(i) for i in frames])
if sum([i is not None for i in matched_keyframes])==0:
message = ""
for frame, matched_keyframe in zip(frames, matched_keyframes):
message += frame + " " * (max_filename_length - len(frame) + 1)
message += "--> No matched keyframes\n"
else:
message = ""
for frame, matched_keyframe in zip(frames, matched_keyframes):
message += frame + " " * (max_filename_length - len(frame) + 1)
if matched_keyframe is None:
message += "--> [to be rendered]\n"
else:
message += f"--> {matched_keyframe}\n"
return message
def check_input_for_interpolating(frames_path, keyframes_path):
# search for images
frames = [os.path.split(i)[-1] for i in search_for_images(frames_path)]
keyframes = [os.path.split(i)[-1] for i in search_for_images(keyframes_path)]
# match frames
matched_keyframes = KeyFrameMatcher().match_filenames(frames, keyframes)
file_list = [file_name for file_name in matched_keyframes if file_name is not None]
index_style = [i for i, file_name in enumerate(matched_keyframes) if file_name is not None]
frames_guide = VideoData(None, frames_path)
frames_style = VideoData(None, keyframes_path, file_list=file_list)
# match shape
message = ""
height_guide, width_guide = frames_guide.shape()
height_style, width_style = frames_style.shape()
if height_guide != height_style or width_guide != width_style:
message += f"The shape of frames mismatches. The rendered keyframes will be resized to (height: {height_guide}, width: {width_guide})\n"
frames_style.set_shape(height_guide, width_guide)
return frames_guide, frames_style, index_style, message
def interpolate_video(
frames_path,
keyframes_path,
output_path,
fps,
batch_size,
tracking_window_size,
minimum_patch_size,
num_iter,
guide_weight,
initialize,
progress = None,
):
# input
frames_guide, frames_style, index_style, message = check_input_for_interpolating(frames_path, keyframes_path)
if len(message) > 0:
print(message)
# output
if output_path == "":
output_path = os.path.join(keyframes_path, "output")
os.makedirs(output_path, exist_ok=True)
print("No valid output_path. Your video will be saved here:", output_path)
elif not os.path.exists(output_path):
os.makedirs(output_path, exist_ok=True)
print("Your video will be saved here:", output_path)
output_frames_path = os.path.join(output_path, "frames")
output_video_path = os.path.join(output_path, "video.mp4")
os.makedirs(output_frames_path, exist_ok=True)
# process
ebsynth_config = {
"minimum_patch_size": minimum_patch_size,
"threads_per_block": 8,
"num_iter": num_iter,
"gpu_id": 0,
"guide_weight": guide_weight,
"initialize": initialize,
"tracking_window_size": tracking_window_size
}
if len(index_style)==1:
InterpolationModeSingleFrameRunner().run(frames_guide, frames_style, index_style, batch_size=batch_size, ebsynth_config=ebsynth_config, save_path=output_frames_path)
else:
InterpolationModeRunner().run(frames_guide, frames_style, index_style, batch_size=batch_size, ebsynth_config=ebsynth_config, save_path=output_frames_path)
try:
fps = int(fps)
except:
fps = 30
print("Fps:", fps)
print("Saving video...")
video_path = save_video(output_frames_path, output_video_path, num_frames=len(frames_guide), fps=fps)
print("Success!")
print("Your frames are here:", output_frames_path)
print("Your video is here:", video_path)
return output_path, fps, video_path
def on_ui_tabs():
with gr.Blocks(analytics_enabled=False) as ui_component:
with gr.Tab("Blend"):
gr.Markdown("""
# Blend
Given a guide video and a style video, this algorithm will make the style video fluent according to the motion features of the guide video. Click [here](https://github.com/Artiprocher/sd-webui-fastblend/assets/35051019/208d902d-6aba-48d7-b7d5-cd120ebd306d) to see the example. Note that this extension doesn't support long videos. Please use short videos (e.g., several seconds). The algorithm is mainly designed for 512*512 resolution. Please use a larger `Minimum patch size` for higher resolution.
""")
with gr.Row():
with gr.Column():
with gr.Tab("Guide video"):
video_guide = gr.Video(label="Guide video")
with gr.Tab("Guide video (images format)"):
video_guide_folder = gr.Textbox(label="Guide video (images format)", value="")
with gr.Column():
with gr.Tab("Style video"):
video_style = gr.Video(label="Style video")
with gr.Tab("Style video (images format)"):
video_style_folder = gr.Textbox(label="Style video (images format)", value="")
with gr.Column():
output_path = gr.Textbox(label="Output directory", value="", placeholder="Leave empty to use the directory of style video")
fps = gr.Textbox(label="Fps", value="", placeholder="Leave empty to use the default fps")
video_output = gr.Video(label="Output video", interactive=False, show_share_button=True)
btn = gr.Button(value="Blend")
with gr.Row():
with gr.Column():
gr.Markdown("# Settings")
mode = gr.Radio(["Fast", "Balanced", "Accurate"], label="Inference mode", value="Fast", interactive=True)
window_size = gr.Slider(label="Sliding window size", value=15, minimum=1, maximum=1000, step=1, interactive=True)
batch_size = gr.Slider(label="Batch size", value=8, minimum=1, maximum=128, step=1, interactive=True)
tracking_window_size = gr.Slider(label="Tracking window size (only for accurate mode)", value=0, minimum=0, maximum=10, step=1, interactive=True)
gr.Markdown("## Advanced Settings")
minimum_patch_size = gr.Slider(label="Minimum patch size (odd number)", value=5, minimum=5, maximum=99, step=2, interactive=True)
num_iter = gr.Slider(label="Number of iterations", value=5, minimum=1, maximum=10, step=1, interactive=True)
guide_weight = gr.Slider(label="Guide weight", value=10.0, minimum=0.0, maximum=100.0, step=0.1, interactive=True)
initialize = gr.Radio(["identity", "random"], label="NNF initialization", value="identity", interactive=True)
with gr.Column():
gr.Markdown("""
# Reference
* Output directory: the directory to save the video.
* Inference mode
|Mode|Time|Memory|Quality|Frame by frame output|Description|
|-|-|-|-|-|-|
|Fast|■|■■■|■■|No|Blend the frames using a tree-like data structure, which requires much RAM but is fast.|
|Balanced|■■|■|■■|Yes|Blend the frames naively.|
|Accurate|■■■|■|■■■|Yes|Blend the frames and align them together for higher video quality. When [batch size] >= [sliding window size] * 2 + 1, the performance is the best.|
* Sliding window size: our algorithm will blend the frames in a sliding windows. If the size is n, each frame will be blended with the last n frames and the next n frames. A large sliding window can make the video fluent but sometimes smoggy.
* Batch size: a larger batch size makes the program faster but requires more VRAM.
* Tracking window size (only for accurate mode): The size of window in which our algorithm tracks moving objects. Empirically, 1 is enough.
* Advanced settings
* Minimum patch size (odd number): the minimum patch size used for patch matching. (Default: 5)
* Number of iterations: the number of iterations of patch matching. (Default: 5)
* Guide weight: a parameter that determines how much motion feature applied to the style video. (Default: 10)
* NNF initialization: how to initialize the NNF (Nearest Neighbor Field). (Default: identity)
""")
btn.click(
smooth_video,
inputs=[
video_guide,
video_guide_folder,
video_style,
video_style_folder,
mode,
window_size,
batch_size,
tracking_window_size,
output_path,
fps,
minimum_patch_size,
num_iter,
guide_weight,
initialize
],
outputs=[output_path, fps, video_output]
)
with gr.Tab("Interpolate"):
gr.Markdown("""
# Interpolate
Given a guide video and some rendered keyframes, this algorithm will render the remaining frames. Click [here](https://github.com/Artiprocher/sd-webui-fastblend/assets/35051019/3490c5b4-8f67-478f-86de-f9adc2ace16a) to see the example. The algorithm is experimental and is only tested for 512*512 resolution.
""")
with gr.Row():
with gr.Column():
with gr.Row():
with gr.Column():
video_guide_folder_ = gr.Textbox(label="Guide video (images format)", value="")
with gr.Column():
rendered_keyframes_ = gr.Textbox(label="Rendered keyframes (images format)", value="")
with gr.Row():
detected_frames = gr.Textbox(label="Detected frames", value="Please input the directory of guide video and rendered frames", lines=9, max_lines=9, interactive=False)
video_guide_folder_.change(detect_frames, inputs=[video_guide_folder_, rendered_keyframes_], outputs=detected_frames)
rendered_keyframes_.change(detect_frames, inputs=[video_guide_folder_, rendered_keyframes_], outputs=detected_frames)
with gr.Column():
output_path_ = gr.Textbox(label="Output directory", value="", placeholder="Leave empty to use the directory of rendered keyframes")
fps_ = gr.Textbox(label="Fps", value="", placeholder="Leave empty to use the default fps")
video_output_ = gr.Video(label="Output video", interactive=False, show_share_button=True)
btn_ = gr.Button(value="Interpolate")
with gr.Row():
with gr.Column():
gr.Markdown("# Settings")
batch_size_ = gr.Slider(label="Batch size", value=8, minimum=1, maximum=128, step=1, interactive=True)
tracking_window_size_ = gr.Slider(label="Tracking window size", value=0, minimum=0, maximum=10, step=1, interactive=True)
gr.Markdown("## Advanced Settings")
minimum_patch_size_ = gr.Slider(label="Minimum patch size (odd number, larger is better)", value=15, minimum=5, maximum=99, step=2, interactive=True)
num_iter_ = gr.Slider(label="Number of iterations", value=5, minimum=1, maximum=10, step=1, interactive=True)
guide_weight_ = gr.Slider(label="Guide weight", value=10.0, minimum=0.0, maximum=100.0, step=0.1, interactive=True)
initialize_ = gr.Radio(["identity", "random"], label="NNF initialization", value="identity", interactive=True)
with gr.Column():
gr.Markdown("""
# Reference
* Output directory: the directory to save the video.
* Batch size: a larger batch size makes the program faster but requires more VRAM.
* Tracking window size (only for accurate mode): The size of window in which our algorithm tracks moving objects. Empirically, 1 is enough.
* Advanced settings
* Minimum patch size (odd number): the minimum patch size used for patch matching. **This parameter should be larger than that in blending. (Default: 15)**
* Number of iterations: the number of iterations of patch matching. (Default: 5)
* Guide weight: a parameter that determines how much motion feature applied to the style video. (Default: 10)
* NNF initialization: how to initialize the NNF (Nearest Neighbor Field). (Default: identity)
""")
btn_.click(
interpolate_video,
inputs=[
video_guide_folder_,
rendered_keyframes_,
output_path_,
fps_,
batch_size_,
tracking_window_size_,
minimum_patch_size_,
num_iter_,
guide_weight_,
initialize_,
],
outputs=[output_path_, fps_, video_output_]
)
return [(ui_component, "FastBlend", "FastBlend_ui")]

View File

@@ -0,0 +1,119 @@
import cupy as cp
remapping_kernel = cp.RawKernel(r'''
extern "C" __global__
void remap(
const int height,
const int width,
const int channel,
const int patch_size,
const int pad_size,
const float* source_style,
const int* nnf,
float* target_style
) {
const int r = (patch_size - 1) / 2;
const int x = blockDim.x * blockIdx.x + threadIdx.x;
const int y = blockDim.y * blockIdx.y + threadIdx.y;
if (x >= height or y >= width) return;
const int z = blockIdx.z * (height + pad_size * 2) * (width + pad_size * 2) * channel;
const int pid = (x + pad_size) * (width + pad_size * 2) + (y + pad_size);
const int min_px = x < r ? -x : -r;
const int max_px = x + r > height - 1 ? height - 1 - x : r;
const int min_py = y < r ? -y : -r;
const int max_py = y + r > width - 1 ? width - 1 - y : r;
int num = 0;
for (int px = min_px; px <= max_px; px++){
for (int py = min_py; py <= max_py; py++){
const int nid = (x + px) * width + y + py;
const int x_ = nnf[blockIdx.z * height * width * 2 + nid*2 + 0] - px;
const int y_ = nnf[blockIdx.z * height * width * 2 + nid*2 + 1] - py;
if (x_ < 0 or y_ < 0 or x_ >= height or y_ >= width)continue;
const int pid_ = (x_ + pad_size) * (width + pad_size * 2) + (y_ + pad_size);
num++;
for (int c = 0; c < channel; c++){
target_style[z + pid * channel + c] += source_style[z + pid_ * channel + c];
}
}
}
for (int c = 0; c < channel; c++){
target_style[z + pid * channel + c] /= num;
}
}
''', 'remap')
patch_error_kernel = cp.RawKernel(r'''
extern "C" __global__
void patch_error(
const int height,
const int width,
const int channel,
const int patch_size,
const int pad_size,
const float* source,
const int* nnf,
const float* target,
float* error
) {
const int r = (patch_size - 1) / 2;
const int x = blockDim.x * blockIdx.x + threadIdx.x;
const int y = blockDim.y * blockIdx.y + threadIdx.y;
const int z = blockIdx.z * (height + pad_size * 2) * (width + pad_size * 2) * channel;
if (x >= height or y >= width) return;
const int x_ = nnf[blockIdx.z * height * width * 2 + (x * width + y)*2 + 0];
const int y_ = nnf[blockIdx.z * height * width * 2 + (x * width + y)*2 + 1];
float e = 0;
for (int px = -r; px <= r; px++){
for (int py = -r; py <= r; py++){
const int pid = (x + pad_size + px) * (width + pad_size * 2) + y + pad_size + py;
const int pid_ = (x_ + pad_size + px) * (width + pad_size * 2) + y_ + pad_size + py;
for (int c = 0; c < channel; c++){
const float diff = target[z + pid * channel + c] - source[z + pid_ * channel + c];
e += diff * diff;
}
}
}
error[blockIdx.z * height * width + x * width + y] = e;
}
''', 'patch_error')
pairwise_patch_error_kernel = cp.RawKernel(r'''
extern "C" __global__
void pairwise_patch_error(
const int height,
const int width,
const int channel,
const int patch_size,
const int pad_size,
const float* source_a,
const int* nnf_a,
const float* source_b,
const int* nnf_b,
float* error
) {
const int r = (patch_size - 1) / 2;
const int x = blockDim.x * blockIdx.x + threadIdx.x;
const int y = blockDim.y * blockIdx.y + threadIdx.y;
const int z = blockIdx.z * (height + pad_size * 2) * (width + pad_size * 2) * channel;
if (x >= height or y >= width) return;
const int z_nnf = blockIdx.z * height * width * 2 + (x * width + y) * 2;
const int x_a = nnf_a[z_nnf + 0];
const int y_a = nnf_a[z_nnf + 1];
const int x_b = nnf_b[z_nnf + 0];
const int y_b = nnf_b[z_nnf + 1];
float e = 0;
for (int px = -r; px <= r; px++){
for (int py = -r; py <= r; py++){
const int pid_a = (x_a + pad_size + px) * (width + pad_size * 2) + y_a + pad_size + py;
const int pid_b = (x_b + pad_size + px) * (width + pad_size * 2) + y_b + pad_size + py;
for (int c = 0; c < channel; c++){
const float diff = source_a[z + pid_a * channel + c] - source_b[z + pid_b * channel + c];
e += diff * diff;
}
}
}
error[blockIdx.z * height * width + x * width + y] = e;
}
''', 'pairwise_patch_error')

View File

@@ -0,0 +1,146 @@
import imageio, os
import numpy as np
from PIL import Image
def read_video(file_name):
reader = imageio.get_reader(file_name)
video = []
for frame in reader:
frame = np.array(frame)
video.append(frame)
reader.close()
return video
def get_video_fps(file_name):
reader = imageio.get_reader(file_name)
fps = reader.get_meta_data()["fps"]
reader.close()
return fps
def save_video(frames_path, video_path, num_frames, fps):
writer = imageio.get_writer(video_path, fps=fps, quality=9)
for i in range(num_frames):
frame = np.array(Image.open(os.path.join(frames_path, "%05d.png" % i)))
writer.append_data(frame)
writer.close()
return video_path
class LowMemoryVideo:
def __init__(self, file_name):
self.reader = imageio.get_reader(file_name)
def __len__(self):
return self.reader.count_frames()
def __getitem__(self, item):
return np.array(self.reader.get_data(item))
def __del__(self):
self.reader.close()
def split_file_name(file_name):
result = []
number = -1
for i in file_name:
if ord(i)>=ord("0") and ord(i)<=ord("9"):
if number == -1:
number = 0
number = number*10 + ord(i) - ord("0")
else:
if number != -1:
result.append(number)
number = -1
result.append(i)
if number != -1:
result.append(number)
result = tuple(result)
return result
def search_for_images(folder):
file_list = [i for i in os.listdir(folder) if i.endswith(".jpg") or i.endswith(".png")]
file_list = [(split_file_name(file_name), file_name) for file_name in file_list]
file_list = [i[1] for i in sorted(file_list)]
file_list = [os.path.join(folder, i) for i in file_list]
return file_list
def read_images(folder):
file_list = search_for_images(folder)
frames = [np.array(Image.open(i)) for i in file_list]
return frames
class LowMemoryImageFolder:
def __init__(self, folder, file_list=None):
if file_list is None:
self.file_list = search_for_images(folder)
else:
self.file_list = [os.path.join(folder, file_name) for file_name in file_list]
def __len__(self):
return len(self.file_list)
def __getitem__(self, item):
return np.array(Image.open(self.file_list[item]))
def __del__(self):
pass
class VideoData:
def __init__(self, video_file, image_folder, **kwargs):
if video_file is not None:
self.data_type = "video"
self.data = LowMemoryVideo(video_file, **kwargs)
elif image_folder is not None:
self.data_type = "images"
self.data = LowMemoryImageFolder(image_folder, **kwargs)
else:
raise ValueError("Cannot open video or image folder")
self.length = None
self.height = None
self.width = None
def raw_data(self):
frames = []
for i in range(self.__len__()):
frames.append(self.__getitem__(i))
return frames
def set_length(self, length):
self.length = length
def set_shape(self, height, width):
self.height = height
self.width = width
def __len__(self):
if self.length is None:
return len(self.data)
else:
return self.length
def shape(self):
if self.height is not None and self.width is not None:
return self.height, self.width
else:
height, width, _ = self.__getitem__(0).shape
return height, width
def __getitem__(self, item):
frame = self.data.__getitem__(item)
height, width, _ = frame.shape
if self.height is not None and self.width is not None:
if self.height != height or self.width != width:
frame = Image.fromarray(frame).resize((self.width, self.height))
frame = np.array(frame)
return frame
def __del__(self):
pass

View File

@@ -0,0 +1,298 @@
from .cupy_kernels import remapping_kernel, patch_error_kernel, pairwise_patch_error_kernel
import numpy as np
import cupy as cp
import cv2
class PatchMatcher:
def __init__(
self, height, width, channel, minimum_patch_size,
threads_per_block=8, num_iter=5, gpu_id=0, guide_weight=10.0,
random_search_steps=3, random_search_range=4,
use_mean_target_style=False, use_pairwise_patch_error=False,
tracking_window_size=0
):
self.height = height
self.width = width
self.channel = channel
self.minimum_patch_size = minimum_patch_size
self.threads_per_block = threads_per_block
self.num_iter = num_iter
self.gpu_id = gpu_id
self.guide_weight = guide_weight
self.random_search_steps = random_search_steps
self.random_search_range = random_search_range
self.use_mean_target_style = use_mean_target_style
self.use_pairwise_patch_error = use_pairwise_patch_error
self.tracking_window_size = tracking_window_size
self.patch_size_list = [minimum_patch_size + i*2 for i in range(num_iter)][::-1]
self.pad_size = self.patch_size_list[0] // 2
self.grid = (
(height + threads_per_block - 1) // threads_per_block,
(width + threads_per_block - 1) // threads_per_block
)
self.block = (threads_per_block, threads_per_block)
def pad_image(self, image):
return cp.pad(image, ((0, 0), (self.pad_size, self.pad_size), (self.pad_size, self.pad_size), (0, 0)))
def unpad_image(self, image):
return image[:, self.pad_size: -self.pad_size, self.pad_size: -self.pad_size, :]
def apply_nnf_to_image(self, nnf, source):
batch_size = source.shape[0]
target = cp.zeros((batch_size, self.height + self.pad_size * 2, self.width + self.pad_size * 2, self.channel), dtype=cp.float32)
remapping_kernel(
self.grid + (batch_size,),
self.block,
(self.height, self.width, self.channel, self.patch_size, self.pad_size, source, nnf, target)
)
return target
def get_patch_error(self, source, nnf, target):
batch_size = source.shape[0]
error = cp.zeros((batch_size, self.height, self.width), dtype=cp.float32)
patch_error_kernel(
self.grid + (batch_size,),
self.block,
(self.height, self.width, self.channel, self.patch_size, self.pad_size, source, nnf, target, error)
)
return error
def get_pairwise_patch_error(self, source, nnf):
batch_size = source.shape[0]//2
error = cp.zeros((batch_size, self.height, self.width), dtype=cp.float32)
source_a, nnf_a = source[0::2].copy(), nnf[0::2].copy()
source_b, nnf_b = source[1::2].copy(), nnf[1::2].copy()
pairwise_patch_error_kernel(
self.grid + (batch_size,),
self.block,
(self.height, self.width, self.channel, self.patch_size, self.pad_size, source_a, nnf_a, source_b, nnf_b, error)
)
error = error.repeat(2, axis=0)
return error
def get_error(self, source_guide, target_guide, source_style, target_style, nnf):
error_guide = self.get_patch_error(source_guide, nnf, target_guide)
if self.use_mean_target_style:
target_style = self.apply_nnf_to_image(nnf, source_style)
target_style = target_style.mean(axis=0, keepdims=True)
target_style = target_style.repeat(source_guide.shape[0], axis=0)
if self.use_pairwise_patch_error:
error_style = self.get_pairwise_patch_error(source_style, nnf)
else:
error_style = self.get_patch_error(source_style, nnf, target_style)
error = error_guide * self.guide_weight + error_style
return error
def clamp_bound(self, nnf):
nnf[:,:,:,0] = cp.clip(nnf[:,:,:,0], 0, self.height-1)
nnf[:,:,:,1] = cp.clip(nnf[:,:,:,1], 0, self.width-1)
return nnf
def random_step(self, nnf, r):
batch_size = nnf.shape[0]
step = cp.random.randint(-r, r+1, size=(batch_size, self.height, self.width, 2), dtype=cp.int32)
upd_nnf = self.clamp_bound(nnf + step)
return upd_nnf
def neighboor_step(self, nnf, d):
if d==0:
upd_nnf = cp.concatenate([nnf[:, :1, :], nnf[:, :-1, :]], axis=1)
upd_nnf[:, :, :, 0] += 1
elif d==1:
upd_nnf = cp.concatenate([nnf[:, :, :1], nnf[:, :, :-1]], axis=2)
upd_nnf[:, :, :, 1] += 1
elif d==2:
upd_nnf = cp.concatenate([nnf[:, 1:, :], nnf[:, -1:, :]], axis=1)
upd_nnf[:, :, :, 0] -= 1
elif d==3:
upd_nnf = cp.concatenate([nnf[:, :, 1:], nnf[:, :, -1:]], axis=2)
upd_nnf[:, :, :, 1] -= 1
upd_nnf = self.clamp_bound(upd_nnf)
return upd_nnf
def shift_nnf(self, nnf, d):
if d>0:
d = min(nnf.shape[0], d)
upd_nnf = cp.concatenate([nnf[d:]] + [nnf[-1:]] * d, axis=0)
else:
d = max(-nnf.shape[0], d)
upd_nnf = cp.concatenate([nnf[:1]] * (-d) + [nnf[:d]], axis=0)
return upd_nnf
def track_step(self, nnf, d):
if self.use_pairwise_patch_error:
upd_nnf = cp.zeros_like(nnf)
upd_nnf[0::2] = self.shift_nnf(nnf[0::2], d)
upd_nnf[1::2] = self.shift_nnf(nnf[1::2], d)
else:
upd_nnf = self.shift_nnf(nnf, d)
return upd_nnf
def C(self, n, m):
# not used
c = 1
for i in range(1, n+1):
c *= i
for i in range(1, m+1):
c //= i
for i in range(1, n-m+1):
c //= i
return c
def bezier_step(self, nnf, r):
# not used
n = r * 2 - 1
upd_nnf = cp.zeros(shape=nnf.shape, dtype=cp.float32)
for i, d in enumerate(list(range(-r, 0)) + list(range(1, r+1))):
if d>0:
ctl_nnf = cp.concatenate([nnf[d:]] + [nnf[-1:]] * d, axis=0)
elif d<0:
ctl_nnf = cp.concatenate([nnf[:1]] * (-d) + [nnf[:d]], axis=0)
upd_nnf += ctl_nnf * (self.C(n, i) / 2**n)
upd_nnf = self.clamp_bound(upd_nnf).astype(nnf.dtype)
return upd_nnf
def update(self, source_guide, target_guide, source_style, target_style, nnf, err, upd_nnf):
upd_err = self.get_error(source_guide, target_guide, source_style, target_style, upd_nnf)
upd_idx = (upd_err < err)
nnf[upd_idx] = upd_nnf[upd_idx]
err[upd_idx] = upd_err[upd_idx]
return nnf, err
def propagation(self, source_guide, target_guide, source_style, target_style, nnf, err):
for d in cp.random.permutation(4):
upd_nnf = self.neighboor_step(nnf, d)
nnf, err = self.update(source_guide, target_guide, source_style, target_style, nnf, err, upd_nnf)
return nnf, err
def random_search(self, source_guide, target_guide, source_style, target_style, nnf, err):
for i in range(self.random_search_steps):
upd_nnf = self.random_step(nnf, self.random_search_range)
nnf, err = self.update(source_guide, target_guide, source_style, target_style, nnf, err, upd_nnf)
return nnf, err
def track(self, source_guide, target_guide, source_style, target_style, nnf, err):
for d in range(1, self.tracking_window_size + 1):
upd_nnf = self.track_step(nnf, d)
nnf, err = self.update(source_guide, target_guide, source_style, target_style, nnf, err, upd_nnf)
upd_nnf = self.track_step(nnf, -d)
nnf, err = self.update(source_guide, target_guide, source_style, target_style, nnf, err, upd_nnf)
return nnf, err
def iteration(self, source_guide, target_guide, source_style, target_style, nnf, err):
nnf, err = self.propagation(source_guide, target_guide, source_style, target_style, nnf, err)
nnf, err = self.random_search(source_guide, target_guide, source_style, target_style, nnf, err)
nnf, err = self.track(source_guide, target_guide, source_style, target_style, nnf, err)
return nnf, err
def estimate_nnf(self, source_guide, target_guide, source_style, nnf):
with cp.cuda.Device(self.gpu_id):
source_guide = self.pad_image(source_guide)
target_guide = self.pad_image(target_guide)
source_style = self.pad_image(source_style)
for it in range(self.num_iter):
self.patch_size = self.patch_size_list[it]
target_style = self.apply_nnf_to_image(nnf, source_style)
err = self.get_error(source_guide, target_guide, source_style, target_style, nnf)
nnf, err = self.iteration(source_guide, target_guide, source_style, target_style, nnf, err)
target_style = self.unpad_image(self.apply_nnf_to_image(nnf, source_style))
return nnf, target_style
class PyramidPatchMatcher:
def __init__(
self, image_height, image_width, channel, minimum_patch_size,
threads_per_block=8, num_iter=5, gpu_id=0, guide_weight=10.0,
use_mean_target_style=False, use_pairwise_patch_error=False,
tracking_window_size=0,
initialize="identity"
):
maximum_patch_size = minimum_patch_size + (num_iter - 1) * 2
self.pyramid_level = int(np.log2(min(image_height, image_width) / maximum_patch_size))
self.pyramid_heights = []
self.pyramid_widths = []
self.patch_matchers = []
self.minimum_patch_size = minimum_patch_size
self.num_iter = num_iter
self.gpu_id = gpu_id
self.initialize = initialize
for level in range(self.pyramid_level):
height = image_height//(2**(self.pyramid_level - 1 - level))
width = image_width//(2**(self.pyramid_level - 1 - level))
self.pyramid_heights.append(height)
self.pyramid_widths.append(width)
self.patch_matchers.append(PatchMatcher(
height, width, channel, minimum_patch_size=minimum_patch_size,
threads_per_block=threads_per_block, num_iter=num_iter, gpu_id=gpu_id, guide_weight=guide_weight,
use_mean_target_style=use_mean_target_style, use_pairwise_patch_error=use_pairwise_patch_error,
tracking_window_size=tracking_window_size
))
def resample_image(self, images, level):
height, width = self.pyramid_heights[level], self.pyramid_widths[level]
images = images.get()
images_resample = []
for image in images:
image_resample = cv2.resize(image, (width, height), interpolation=cv2.INTER_AREA)
images_resample.append(image_resample)
images_resample = cp.array(np.stack(images_resample), dtype=cp.float32)
return images_resample
def initialize_nnf(self, batch_size):
if self.initialize == "random":
height, width = self.pyramid_heights[0], self.pyramid_widths[0]
nnf = cp.stack([
cp.random.randint(0, height, (batch_size, height, width), dtype=cp.int32),
cp.random.randint(0, width, (batch_size, height, width), dtype=cp.int32)
], axis=3)
elif self.initialize == "identity":
height, width = self.pyramid_heights[0], self.pyramid_widths[0]
nnf = cp.stack([
cp.repeat(cp.arange(height), width).reshape(height, width),
cp.tile(cp.arange(width), height).reshape(height, width)
], axis=2)
nnf = cp.stack([nnf] * batch_size)
else:
raise NotImplementedError()
return nnf
def update_nnf(self, nnf, level):
# upscale
nnf = nnf.repeat(2, axis=1).repeat(2, axis=2) * 2
nnf[:,[i for i in range(nnf.shape[0]) if i&1],:,0] += 1
nnf[:,:,[i for i in range(nnf.shape[0]) if i&1],1] += 1
# check if scale is 2
height, width = self.pyramid_heights[level], self.pyramid_widths[level]
if height != nnf.shape[0] * 2 or width != nnf.shape[1] * 2:
nnf = nnf.get().astype(np.float32)
nnf = [cv2.resize(n, (width, height), interpolation=cv2.INTER_LINEAR) for n in nnf]
nnf = cp.array(np.stack(nnf), dtype=cp.int32)
nnf = self.patch_matchers[level].clamp_bound(nnf)
return nnf
def apply_nnf_to_image(self, nnf, image):
with cp.cuda.Device(self.gpu_id):
image = self.patch_matchers[-1].pad_image(image)
image = self.patch_matchers[-1].apply_nnf_to_image(nnf, image)
return image
def estimate_nnf(self, source_guide, target_guide, source_style):
with cp.cuda.Device(self.gpu_id):
if not isinstance(source_guide, cp.ndarray):
source_guide = cp.array(source_guide, dtype=cp.float32)
if not isinstance(target_guide, cp.ndarray):
target_guide = cp.array(target_guide, dtype=cp.float32)
if not isinstance(source_style, cp.ndarray):
source_style = cp.array(source_style, dtype=cp.float32)
for level in range(self.pyramid_level):
nnf = self.initialize_nnf(source_guide.shape[0]) if level==0 else self.update_nnf(nnf, level)
source_guide_ = self.resample_image(source_guide, level)
target_guide_ = self.resample_image(target_guide, level)
source_style_ = self.resample_image(source_style, level)
nnf, target_style = self.patch_matchers[level].estimate_nnf(
source_guide_, target_guide_, source_style_, nnf
)
return nnf.get(), target_style.get()

View File

@@ -0,0 +1,4 @@
from .accurate import AccurateModeRunner
from .fast import FastModeRunner
from .balanced import BalancedModeRunner
from .interpolation import InterpolationModeRunner, InterpolationModeSingleFrameRunner

View File

@@ -0,0 +1,35 @@
from ..patch_match import PyramidPatchMatcher
import os
import numpy as np
from PIL import Image
from tqdm import tqdm
class AccurateModeRunner:
def __init__(self):
pass
def run(self, frames_guide, frames_style, batch_size, window_size, ebsynth_config, desc="Accurate Mode", save_path=None):
patch_match_engine = PyramidPatchMatcher(
image_height=frames_style[0].shape[0],
image_width=frames_style[0].shape[1],
channel=3,
use_mean_target_style=True,
**ebsynth_config
)
# run
n = len(frames_style)
for target in tqdm(range(n), desc=desc):
l, r = max(target - window_size, 0), min(target + window_size + 1, n)
remapped_frames = []
for i in range(l, r, batch_size):
j = min(i + batch_size, r)
source_guide = np.stack([frames_guide[source] for source in range(i, j)])
target_guide = np.stack([frames_guide[target]] * (j - i))
source_style = np.stack([frames_style[source] for source in range(i, j)])
_, target_style = patch_match_engine.estimate_nnf(source_guide, target_guide, source_style)
remapped_frames.append(target_style)
frame = np.concatenate(remapped_frames, axis=0).mean(axis=0)
frame = frame.clip(0, 255).astype("uint8")
if save_path is not None:
Image.fromarray(frame).save(os.path.join(save_path, "%05d.png" % target))

View File

@@ -0,0 +1,46 @@
from ..patch_match import PyramidPatchMatcher
import os
import numpy as np
from PIL import Image
from tqdm import tqdm
class BalancedModeRunner:
def __init__(self):
pass
def run(self, frames_guide, frames_style, batch_size, window_size, ebsynth_config, desc="Balanced Mode", save_path=None):
patch_match_engine = PyramidPatchMatcher(
image_height=frames_style[0].shape[0],
image_width=frames_style[0].shape[1],
channel=3,
**ebsynth_config
)
# tasks
n = len(frames_style)
tasks = []
for target in range(n):
for source in range(target - window_size, target + window_size + 1):
if source >= 0 and source < n and source != target:
tasks.append((source, target))
# run
frames = [(None, 1) for i in range(n)]
for batch_id in tqdm(range(0, len(tasks), batch_size), desc=desc):
tasks_batch = tasks[batch_id: min(batch_id+batch_size, len(tasks))]
source_guide = np.stack([frames_guide[source] for source, target in tasks_batch])
target_guide = np.stack([frames_guide[target] for source, target in tasks_batch])
source_style = np.stack([frames_style[source] for source, target in tasks_batch])
_, target_style = patch_match_engine.estimate_nnf(source_guide, target_guide, source_style)
for (source, target), result in zip(tasks_batch, target_style):
frame, weight = frames[target]
if frame is None:
frame = frames_style[target]
frames[target] = (
frame * (weight / (weight + 1)) + result / (weight + 1),
weight + 1
)
if weight + 1 == min(n, target + window_size + 1) - max(0, target - window_size):
frame = frame.clip(0, 255).astype("uint8")
if save_path is not None:
Image.fromarray(frame).save(os.path.join(save_path, "%05d.png" % target))
frames[target] = (None, 1)

View File

@@ -0,0 +1,141 @@
from ..patch_match import PyramidPatchMatcher
import functools, os
import numpy as np
from PIL import Image
from tqdm import tqdm
class TableManager:
def __init__(self):
pass
def task_list(self, n):
tasks = []
max_level = 1
while (1<<max_level)<=n:
max_level += 1
for i in range(n):
j = i
for level in range(max_level):
if i&(1<<level):
continue
j |= 1<<level
if j>=n:
break
meta_data = {
"source": i,
"target": j,
"level": level + 1
}
tasks.append(meta_data)
tasks.sort(key=functools.cmp_to_key(lambda u, v: u["level"]-v["level"]))
return tasks
def build_remapping_table(self, frames_guide, frames_style, patch_match_engine, batch_size, desc=""):
n = len(frames_guide)
tasks = self.task_list(n)
remapping_table = [[(frames_style[i], 1)] for i in range(n)]
for batch_id in tqdm(range(0, len(tasks), batch_size), desc=desc):
tasks_batch = tasks[batch_id: min(batch_id+batch_size, len(tasks))]
source_guide = np.stack([frames_guide[task["source"]] for task in tasks_batch])
target_guide = np.stack([frames_guide[task["target"]] for task in tasks_batch])
source_style = np.stack([frames_style[task["source"]] for task in tasks_batch])
_, target_style = patch_match_engine.estimate_nnf(source_guide, target_guide, source_style)
for task, result in zip(tasks_batch, target_style):
target, level = task["target"], task["level"]
if len(remapping_table[target])==level:
remapping_table[target].append((result, 1))
else:
frame, weight = remapping_table[target][level]
remapping_table[target][level] = (
frame * (weight / (weight + 1)) + result / (weight + 1),
weight + 1
)
return remapping_table
def remapping_table_to_blending_table(self, table):
for i in range(len(table)):
for j in range(1, len(table[i])):
frame_1, weight_1 = table[i][j-1]
frame_2, weight_2 = table[i][j]
frame = (frame_1 + frame_2) / 2
weight = weight_1 + weight_2
table[i][j] = (frame, weight)
return table
def tree_query(self, leftbound, rightbound):
node_list = []
node_index = rightbound
while node_index>=leftbound:
node_level = 0
while (1<<node_level)&node_index and node_index-(1<<node_level+1)+1>=leftbound:
node_level += 1
node_list.append((node_index, node_level))
node_index -= 1<<node_level
return node_list
def process_window_sum(self, frames_guide, blending_table, patch_match_engine, window_size, batch_size, desc=""):
n = len(blending_table)
tasks = []
frames_result = []
for target in range(n):
node_list = self.tree_query(max(target-window_size, 0), target)
for source, level in node_list:
if source!=target:
meta_data = {
"source": source,
"target": target,
"level": level
}
tasks.append(meta_data)
else:
frames_result.append(blending_table[target][level])
for batch_id in tqdm(range(0, len(tasks), batch_size), desc=desc):
tasks_batch = tasks[batch_id: min(batch_id+batch_size, len(tasks))]
source_guide = np.stack([frames_guide[task["source"]] for task in tasks_batch])
target_guide = np.stack([frames_guide[task["target"]] for task in tasks_batch])
source_style = np.stack([blending_table[task["source"]][task["level"]][0] for task in tasks_batch])
_, target_style = patch_match_engine.estimate_nnf(source_guide, target_guide, source_style)
for task, frame_2 in zip(tasks_batch, target_style):
source, target, level = task["source"], task["target"], task["level"]
frame_1, weight_1 = frames_result[target]
weight_2 = blending_table[source][level][1]
weight = weight_1 + weight_2
frame = frame_1 * (weight_1 / weight) + frame_2 * (weight_2 / weight)
frames_result[target] = (frame, weight)
return frames_result
class FastModeRunner:
def __init__(self):
pass
def run(self, frames_guide, frames_style, batch_size, window_size, ebsynth_config, save_path=None):
frames_guide = frames_guide.raw_data()
frames_style = frames_style.raw_data()
table_manager = TableManager()
patch_match_engine = PyramidPatchMatcher(
image_height=frames_style[0].shape[0],
image_width=frames_style[0].shape[1],
channel=3,
**ebsynth_config
)
# left part
table_l = table_manager.build_remapping_table(frames_guide, frames_style, patch_match_engine, batch_size, desc="Fast Mode Step 1/4")
table_l = table_manager.remapping_table_to_blending_table(table_l)
table_l = table_manager.process_window_sum(frames_guide, table_l, patch_match_engine, window_size, batch_size, desc="Fast Mode Step 2/4")
# right part
table_r = table_manager.build_remapping_table(frames_guide[::-1], frames_style[::-1], patch_match_engine, batch_size, desc="Fast Mode Step 3/4")
table_r = table_manager.remapping_table_to_blending_table(table_r)
table_r = table_manager.process_window_sum(frames_guide[::-1], table_r, patch_match_engine, window_size, batch_size, desc="Fast Mode Step 4/4")[::-1]
# merge
frames = []
for (frame_l, weight_l), frame_m, (frame_r, weight_r) in zip(table_l, frames_style, table_r):
weight_m = -1
weight = weight_l + weight_m + weight_r
frame = frame_l * (weight_l / weight) + frame_m * (weight_m / weight) + frame_r * (weight_r / weight)
frames.append(frame)
frames = [frame.clip(0, 255).astype("uint8") for frame in frames]
if save_path is not None:
for target, frame in enumerate(frames):
Image.fromarray(frame).save(os.path.join(save_path, "%05d.png" % target))

View File

@@ -0,0 +1,121 @@
from ..patch_match import PyramidPatchMatcher
import os
import numpy as np
from PIL import Image
from tqdm import tqdm
class InterpolationModeRunner:
def __init__(self):
pass
def get_index_dict(self, index_style):
index_dict = {}
for i, index in enumerate(index_style):
index_dict[index] = i
return index_dict
def get_weight(self, l, m, r):
weight_l, weight_r = abs(m - r), abs(m - l)
if weight_l + weight_r == 0:
weight_l, weight_r = 0.5, 0.5
else:
weight_l, weight_r = weight_l / (weight_l + weight_r), weight_r / (weight_l + weight_r)
return weight_l, weight_r
def get_task_group(self, index_style, n):
task_group = []
index_style = sorted(index_style)
# first frame
if index_style[0]>0:
tasks = []
for m in range(index_style[0]):
tasks.append((index_style[0], m, index_style[0]))
task_group.append(tasks)
# middle frames
for l, r in zip(index_style[:-1], index_style[1:]):
tasks = []
for m in range(l, r):
tasks.append((l, m, r))
task_group.append(tasks)
# last frame
tasks = []
for m in range(index_style[-1], n):
tasks.append((index_style[-1], m, index_style[-1]))
task_group.append(tasks)
return task_group
def run(self, frames_guide, frames_style, index_style, batch_size, ebsynth_config, save_path=None):
patch_match_engine = PyramidPatchMatcher(
image_height=frames_style[0].shape[0],
image_width=frames_style[0].shape[1],
channel=3,
use_mean_target_style=False,
use_pairwise_patch_error=True,
**ebsynth_config
)
# task
index_dict = self.get_index_dict(index_style)
task_group = self.get_task_group(index_style, len(frames_guide))
# run
for tasks in task_group:
index_start, index_end = min([i[1] for i in tasks]), max([i[1] for i in tasks])
for batch_id in tqdm(range(0, len(tasks), batch_size), desc=f"Rendering frames {index_start}...{index_end}"):
tasks_batch = tasks[batch_id: min(batch_id+batch_size, len(tasks))]
source_guide, target_guide, source_style = [], [], []
for l, m, r in tasks_batch:
# l -> m
source_guide.append(frames_guide[l])
target_guide.append(frames_guide[m])
source_style.append(frames_style[index_dict[l]])
# r -> m
source_guide.append(frames_guide[r])
target_guide.append(frames_guide[m])
source_style.append(frames_style[index_dict[r]])
source_guide = np.stack(source_guide)
target_guide = np.stack(target_guide)
source_style = np.stack(source_style)
_, target_style = patch_match_engine.estimate_nnf(source_guide, target_guide, source_style)
if save_path is not None:
for frame_l, frame_r, (l, m, r) in zip(target_style[0::2], target_style[1::2], tasks_batch):
weight_l, weight_r = self.get_weight(l, m, r)
frame = frame_l * weight_l + frame_r * weight_r
frame = frame.clip(0, 255).astype("uint8")
Image.fromarray(frame).save(os.path.join(save_path, "%05d.png" % m))
class InterpolationModeSingleFrameRunner:
def __init__(self):
pass
def run(self, frames_guide, frames_style, index_style, batch_size, ebsynth_config, save_path=None):
# check input
tracking_window_size = ebsynth_config["tracking_window_size"]
if tracking_window_size * 2 >= batch_size:
raise ValueError("batch_size should be larger than track_window_size * 2")
frame_style = frames_style[0]
frame_guide = frames_guide[index_style[0]]
patch_match_engine = PyramidPatchMatcher(
image_height=frame_style.shape[0],
image_width=frame_style.shape[1],
channel=3,
**ebsynth_config
)
# run
frame_id, n = 0, len(frames_guide)
for i in tqdm(range(0, n, batch_size - tracking_window_size * 2), desc=f"Rendering frames 0...{n}"):
if i + batch_size > n:
l, r = max(n - batch_size, 0), n
else:
l, r = i, i + batch_size
source_guide = np.stack([frame_guide] * (r-l))
target_guide = np.stack([frames_guide[i] for i in range(l, r)])
source_style = np.stack([frame_style] * (r-l))
_, target_style = patch_match_engine.estimate_nnf(source_guide, target_guide, source_style)
for i, frame in zip(range(l, r), target_style):
if i==frame_id:
frame = frame.clip(0, 255).astype("uint8")
Image.fromarray(frame).save(os.path.join(save_path, "%05d.png" % frame_id))
frame_id += 1
if r < n and r-frame_id <= tracking_window_size:
break

View File

@@ -0,0 +1,242 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from PIL import Image
def warp(tenInput, tenFlow, device):
backwarp_tenGrid = {}
k = (str(tenFlow.device), str(tenFlow.size()))
if k not in backwarp_tenGrid:
tenHorizontal = torch.linspace(-1.0, 1.0, tenFlow.shape[3], device=device).view(
1, 1, 1, tenFlow.shape[3]).expand(tenFlow.shape[0], -1, tenFlow.shape[2], -1)
tenVertical = torch.linspace(-1.0, 1.0, tenFlow.shape[2], device=device).view(
1, 1, tenFlow.shape[2], 1).expand(tenFlow.shape[0], -1, -1, tenFlow.shape[3])
backwarp_tenGrid[k] = torch.cat(
[tenHorizontal, tenVertical], 1).to(device)
tenFlow = torch.cat([tenFlow[:, 0:1, :, :] / ((tenInput.shape[3] - 1.0) / 2.0),
tenFlow[:, 1:2, :, :] / ((tenInput.shape[2] - 1.0) / 2.0)], 1)
g = (backwarp_tenGrid[k] + tenFlow).permute(0, 2, 3, 1)
return torch.nn.functional.grid_sample(input=tenInput, grid=g, mode='bilinear', padding_mode='border', align_corners=True)
def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
return nn.Sequential(
nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride,
padding=padding, dilation=dilation, bias=True),
nn.PReLU(out_planes)
)
class IFBlock(nn.Module):
def __init__(self, in_planes, c=64):
super(IFBlock, self).__init__()
self.conv0 = nn.Sequential(conv(in_planes, c//2, 3, 2, 1), conv(c//2, c, 3, 2, 1),)
self.convblock0 = nn.Sequential(conv(c, c), conv(c, c))
self.convblock1 = nn.Sequential(conv(c, c), conv(c, c))
self.convblock2 = nn.Sequential(conv(c, c), conv(c, c))
self.convblock3 = nn.Sequential(conv(c, c), conv(c, c))
self.conv1 = nn.Sequential(nn.ConvTranspose2d(c, c//2, 4, 2, 1), nn.PReLU(c//2), nn.ConvTranspose2d(c//2, 4, 4, 2, 1))
self.conv2 = nn.Sequential(nn.ConvTranspose2d(c, c//2, 4, 2, 1), nn.PReLU(c//2), nn.ConvTranspose2d(c//2, 1, 4, 2, 1))
def forward(self, x, flow, scale=1):
x = F.interpolate(x, scale_factor= 1. / scale, mode="bilinear", align_corners=False, recompute_scale_factor=False)
flow = F.interpolate(flow, scale_factor= 1. / scale, mode="bilinear", align_corners=False, recompute_scale_factor=False) * 1. / scale
feat = self.conv0(torch.cat((x, flow), 1))
feat = self.convblock0(feat) + feat
feat = self.convblock1(feat) + feat
feat = self.convblock2(feat) + feat
feat = self.convblock3(feat) + feat
flow = self.conv1(feat)
mask = self.conv2(feat)
flow = F.interpolate(flow, scale_factor=scale, mode="bilinear", align_corners=False, recompute_scale_factor=False) * scale
mask = F.interpolate(mask, scale_factor=scale, mode="bilinear", align_corners=False, recompute_scale_factor=False)
return flow, mask
class IFNet(nn.Module):
def __init__(self, **kwargs):
super(IFNet, self).__init__()
self.block0 = IFBlock(7+4, c=90)
self.block1 = IFBlock(7+4, c=90)
self.block2 = IFBlock(7+4, c=90)
self.block_tea = IFBlock(10+4, c=90)
def forward(self, x, scale_list=[4, 2, 1], training=False):
if training == False:
channel = x.shape[1] // 2
img0 = x[:, :channel]
img1 = x[:, channel:]
flow_list = []
merged = []
mask_list = []
warped_img0 = img0
warped_img1 = img1
flow = (x[:, :4]).detach() * 0
mask = (x[:, :1]).detach() * 0
block = [self.block0, self.block1, self.block2]
for i in range(3):
f0, m0 = block[i](torch.cat((warped_img0[:, :3], warped_img1[:, :3], mask), 1), flow, scale=scale_list[i])
f1, m1 = block[i](torch.cat((warped_img1[:, :3], warped_img0[:, :3], -mask), 1), torch.cat((flow[:, 2:4], flow[:, :2]), 1), scale=scale_list[i])
flow = flow + (f0 + torch.cat((f1[:, 2:4], f1[:, :2]), 1)) / 2
mask = mask + (m0 + (-m1)) / 2
mask_list.append(mask)
flow_list.append(flow)
warped_img0 = warp(img0, flow[:, :2], device=x.device)
warped_img1 = warp(img1, flow[:, 2:4], device=x.device)
merged.append((warped_img0, warped_img1))
'''
c0 = self.contextnet(img0, flow[:, :2])
c1 = self.contextnet(img1, flow[:, 2:4])
tmp = self.unet(img0, img1, warped_img0, warped_img1, mask, flow, c0, c1)
res = tmp[:, 1:4] * 2 - 1
'''
for i in range(3):
mask_list[i] = torch.sigmoid(mask_list[i])
merged[i] = merged[i][0] * mask_list[i] + merged[i][1] * (1 - mask_list[i])
return flow_list, mask_list[2], merged
@staticmethod
def state_dict_converter():
return IFNetStateDictConverter()
class IFNetStateDictConverter:
def __init__(self):
pass
def from_diffusers(self, state_dict):
state_dict_ = {k.replace("module.", ""): v for k, v in state_dict.items()}
return state_dict_
def from_civitai(self, state_dict):
return self.from_diffusers(state_dict), {"upcast_to_float32": True}
class RIFEInterpolater:
def __init__(self, model, device="cuda"):
self.model = model
self.device = device
# IFNet only does not support float16
self.torch_dtype = torch.float32
@staticmethod
def from_model_manager(model_manager):
return RIFEInterpolater(model_manager.fetch_model("rife"), device=model_manager.device)
def process_image(self, image):
width, height = image.size
if width % 32 != 0 or height % 32 != 0:
width = (width + 31) // 32
height = (height + 31) // 32
image = image.resize((width, height))
image = torch.Tensor(np.array(image, dtype=np.float32)[:, :, [2,1,0]] / 255).permute(2, 0, 1)
return image
def process_images(self, images):
images = [self.process_image(image) for image in images]
images = torch.stack(images)
return images
def decode_images(self, images):
images = (images[:, [2,1,0]].permute(0, 2, 3, 1) * 255).clip(0, 255).numpy().astype(np.uint8)
images = [Image.fromarray(image) for image in images]
return images
def add_interpolated_images(self, images, interpolated_images):
output_images = []
for image, interpolated_image in zip(images, interpolated_images):
output_images.append(image)
output_images.append(interpolated_image)
output_images.append(images[-1])
return output_images
@torch.no_grad()
def interpolate_(self, images, scale=1.0):
input_tensor = self.process_images(images)
input_tensor = torch.cat((input_tensor[:-1], input_tensor[1:]), dim=1)
input_tensor = input_tensor.to(device=self.device, dtype=self.torch_dtype)
flow, mask, merged = self.model(input_tensor, [4/scale, 2/scale, 1/scale])
output_images = self.decode_images(merged[2].cpu())
if output_images[0].size != images[0].size:
output_images = [image.resize(images[0].size) for image in output_images]
return output_images
@torch.no_grad()
def interpolate(self, images, scale=1.0, batch_size=4, num_iter=1, progress_bar=lambda x:x):
# Preprocess
processed_images = self.process_images(images)
for iter in range(num_iter):
# Input
input_tensor = torch.cat((processed_images[:-1], processed_images[1:]), dim=1)
# Interpolate
output_tensor = []
for batch_id in progress_bar(range(0, input_tensor.shape[0], batch_size)):
batch_id_ = min(batch_id + batch_size, input_tensor.shape[0])
batch_input_tensor = input_tensor[batch_id: batch_id_]
batch_input_tensor = batch_input_tensor.to(device=self.device, dtype=self.torch_dtype)
flow, mask, merged = self.model(batch_input_tensor, [4/scale, 2/scale, 1/scale])
output_tensor.append(merged[2].cpu())
# Output
output_tensor = torch.concat(output_tensor, dim=0).clip(0, 1)
processed_images = self.add_interpolated_images(processed_images, output_tensor)
processed_images = torch.stack(processed_images)
# To images
output_images = self.decode_images(processed_images)
if output_images[0].size != images[0].size:
output_images = [image.resize(images[0].size) for image in output_images]
return output_images
class RIFESmoother(RIFEInterpolater):
def __init__(self, model, device="cuda"):
super(RIFESmoother, self).__init__(model, device=device)
@staticmethod
def from_model_manager(model_manager):
return RIFEInterpolater(model_manager.fetch_model("rife"), device=model_manager.device)
def process_tensors(self, input_tensor, scale=1.0, batch_size=4):
output_tensor = []
for batch_id in range(0, input_tensor.shape[0], batch_size):
batch_id_ = min(batch_id + batch_size, input_tensor.shape[0])
batch_input_tensor = input_tensor[batch_id: batch_id_]
batch_input_tensor = batch_input_tensor.to(device=self.device, dtype=self.torch_dtype)
flow, mask, merged = self.model(batch_input_tensor, [4/scale, 2/scale, 1/scale])
output_tensor.append(merged[2].cpu())
output_tensor = torch.concat(output_tensor, dim=0)
return output_tensor
@torch.no_grad()
def __call__(self, rendered_frames, scale=1.0, batch_size=4, num_iter=1, **kwargs):
# Preprocess
processed_images = self.process_images(rendered_frames)
for iter in range(num_iter):
# Input
input_tensor = torch.cat((processed_images[:-2], processed_images[2:]), dim=1)
# Interpolate
output_tensor = self.process_tensors(input_tensor, scale=scale, batch_size=batch_size)
# Blend
input_tensor = torch.cat((processed_images[1:-1], output_tensor), dim=1)
output_tensor = self.process_tensors(input_tensor, scale=scale, batch_size=batch_size)
# Add to frames
processed_images[1:-1] = output_tensor
# To images
output_images = self.decode_images(processed_images)
if output_images[0].size != rendered_frames[0].size:
output_images = [image.resize(rendered_frames[0].size) for image in output_images]
return output_images

View File

@@ -0,0 +1 @@
from .model_manager import *

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,89 @@
import torch
from einops import rearrange
def low_version_attention(query, key, value, attn_bias=None):
scale = 1 / query.shape[-1] ** 0.5
query = query * scale
attn = torch.matmul(query, key.transpose(-2, -1))
if attn_bias is not None:
attn = attn + attn_bias
attn = attn.softmax(-1)
return attn @ value
class Attention(torch.nn.Module):
def __init__(self, q_dim, num_heads, head_dim, kv_dim=None, bias_q=False, bias_kv=False, bias_out=False):
super().__init__()
dim_inner = head_dim * num_heads
kv_dim = kv_dim if kv_dim is not None else q_dim
self.num_heads = num_heads
self.head_dim = head_dim
self.to_q = torch.nn.Linear(q_dim, dim_inner, bias=bias_q)
self.to_k = torch.nn.Linear(kv_dim, dim_inner, bias=bias_kv)
self.to_v = torch.nn.Linear(kv_dim, dim_inner, bias=bias_kv)
self.to_out = torch.nn.Linear(dim_inner, q_dim, bias=bias_out)
def interact_with_ipadapter(self, hidden_states, q, ip_k, ip_v, scale=1.0):
batch_size = q.shape[0]
ip_k = ip_k.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
ip_v = ip_v.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
ip_hidden_states = torch.nn.functional.scaled_dot_product_attention(q, ip_k, ip_v)
hidden_states = hidden_states + scale * ip_hidden_states
return hidden_states
def torch_forward(self, hidden_states, encoder_hidden_states=None, attn_mask=None, ipadapter_kwargs=None, qkv_preprocessor=None):
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
batch_size = encoder_hidden_states.shape[0]
q = self.to_q(hidden_states)
k = self.to_k(encoder_hidden_states)
v = self.to_v(encoder_hidden_states)
q = q.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
k = k.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
v = v.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
if qkv_preprocessor is not None:
q, k, v = qkv_preprocessor(q, k, v)
hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
if ipadapter_kwargs is not None:
hidden_states = self.interact_with_ipadapter(hidden_states, q, **ipadapter_kwargs)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim)
hidden_states = hidden_states.to(q.dtype)
hidden_states = self.to_out(hidden_states)
return hidden_states
def xformers_forward(self, hidden_states, encoder_hidden_states=None, attn_mask=None):
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
q = self.to_q(hidden_states)
k = self.to_k(encoder_hidden_states)
v = self.to_v(encoder_hidden_states)
q = rearrange(q, "b f (n d) -> (b n) f d", n=self.num_heads)
k = rearrange(k, "b f (n d) -> (b n) f d", n=self.num_heads)
v = rearrange(v, "b f (n d) -> (b n) f d", n=self.num_heads)
if attn_mask is not None:
hidden_states = low_version_attention(q, k, v, attn_bias=attn_mask)
else:
import xformers.ops as xops
hidden_states = xops.memory_efficient_attention(q, k, v)
hidden_states = rearrange(hidden_states, "(b n) f d -> b f (n d)", n=self.num_heads)
hidden_states = hidden_states.to(q.dtype)
hidden_states = self.to_out(hidden_states)
return hidden_states
def forward(self, hidden_states, encoder_hidden_states=None, attn_mask=None, ipadapter_kwargs=None, qkv_preprocessor=None):
return self.torch_forward(hidden_states, encoder_hidden_states=encoder_hidden_states, attn_mask=attn_mask, ipadapter_kwargs=ipadapter_kwargs, qkv_preprocessor=qkv_preprocessor)

408
diffsynth/models/cog_dit.py Normal file
View File

@@ -0,0 +1,408 @@
import torch
from einops import rearrange, repeat
from .sd3_dit import TimestepEmbeddings
from .attention import Attention
from .utils import load_state_dict_from_folder
from .tiler import TileWorker2Dto3D
import numpy as np
class CogPatchify(torch.nn.Module):
def __init__(self, dim_in, dim_out, patch_size) -> None:
super().__init__()
self.proj = torch.nn.Conv3d(dim_in, dim_out, kernel_size=(1, patch_size, patch_size), stride=(1, patch_size, patch_size))
def forward(self, hidden_states):
hidden_states = self.proj(hidden_states)
hidden_states = rearrange(hidden_states, "B C T H W -> B (T H W) C")
return hidden_states
class CogAdaLayerNorm(torch.nn.Module):
def __init__(self, dim, dim_cond, single=False):
super().__init__()
self.single = single
self.linear = torch.nn.Linear(dim_cond, dim * (2 if single else 6))
self.norm = torch.nn.LayerNorm(dim, elementwise_affine=True, eps=1e-5)
def forward(self, hidden_states, prompt_emb, emb):
emb = self.linear(torch.nn.functional.silu(emb))
if self.single:
shift, scale = emb.unsqueeze(1).chunk(2, dim=2)
hidden_states = self.norm(hidden_states) * (1 + scale) + shift
return hidden_states
else:
shift_a, scale_a, gate_a, shift_b, scale_b, gate_b = emb.unsqueeze(1).chunk(6, dim=2)
hidden_states = self.norm(hidden_states) * (1 + scale_a) + shift_a
prompt_emb = self.norm(prompt_emb) * (1 + scale_b) + shift_b
return hidden_states, prompt_emb, gate_a, gate_b
class CogDiTBlock(torch.nn.Module):
def __init__(self, dim, dim_cond, num_heads):
super().__init__()
self.norm1 = CogAdaLayerNorm(dim, dim_cond)
self.attn1 = Attention(q_dim=dim, num_heads=48, head_dim=dim//num_heads, bias_q=True, bias_kv=True, bias_out=True)
self.norm_q = torch.nn.LayerNorm((dim//num_heads,), eps=1e-06, elementwise_affine=True)
self.norm_k = torch.nn.LayerNorm((dim//num_heads,), eps=1e-06, elementwise_affine=True)
self.norm2 = CogAdaLayerNorm(dim, dim_cond)
self.ff = torch.nn.Sequential(
torch.nn.Linear(dim, dim*4),
torch.nn.GELU(approximate="tanh"),
torch.nn.Linear(dim*4, dim)
)
def apply_rotary_emb(self, x, freqs_cis):
cos, sin = freqs_cis # [S, D]
cos = cos[None, None]
sin = sin[None, None]
cos, sin = cos.to(x.device), sin.to(x.device)
x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, S, H, D//2]
x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)
return out
def process_qkv(self, q, k, v, image_rotary_emb, text_seq_length):
q = self.norm_q(q)
k = self.norm_k(k)
q[:, :, text_seq_length:] = self.apply_rotary_emb(q[:, :, text_seq_length:], image_rotary_emb)
k[:, :, text_seq_length:] = self.apply_rotary_emb(k[:, :, text_seq_length:], image_rotary_emb)
return q, k, v
def forward(self, hidden_states, prompt_emb, time_emb, image_rotary_emb):
# Attention
norm_hidden_states, norm_encoder_hidden_states, gate_a, gate_b = self.norm1(
hidden_states, prompt_emb, time_emb
)
attention_io = torch.cat([norm_encoder_hidden_states, norm_hidden_states], dim=1)
attention_io = self.attn1(
attention_io,
qkv_preprocessor=lambda q, k, v: self.process_qkv(q, k, v, image_rotary_emb, prompt_emb.shape[1])
)
hidden_states = hidden_states + gate_a * attention_io[:, prompt_emb.shape[1]:]
prompt_emb = prompt_emb + gate_b * attention_io[:, :prompt_emb.shape[1]]
# Feed forward
norm_hidden_states, norm_encoder_hidden_states, gate_a, gate_b = self.norm2(
hidden_states, prompt_emb, time_emb
)
ff_io = torch.cat([norm_encoder_hidden_states, norm_hidden_states], dim=1)
ff_io = self.ff(ff_io)
hidden_states = hidden_states + gate_a * ff_io[:, prompt_emb.shape[1]:]
prompt_emb = prompt_emb + gate_b * ff_io[:, :prompt_emb.shape[1]]
return hidden_states, prompt_emb
class CogDiT(torch.nn.Module):
def __init__(self):
super().__init__()
self.patchify = CogPatchify(16, 3072, 2)
self.time_embedder = TimestepEmbeddings(3072, 512)
self.context_embedder = torch.nn.Linear(4096, 3072)
self.blocks = torch.nn.ModuleList([CogDiTBlock(3072, 512, 48) for _ in range(42)])
self.norm_final = torch.nn.LayerNorm((3072,), eps=1e-05, elementwise_affine=True)
self.norm_out = CogAdaLayerNorm(3072, 512, single=True)
self.proj_out = torch.nn.Linear(3072, 64, bias=True)
def get_resize_crop_region_for_grid(self, src, tgt_width, tgt_height):
tw = tgt_width
th = tgt_height
h, w = src
r = h / w
if r > (th / tw):
resize_height = th
resize_width = int(round(th / h * w))
else:
resize_width = tw
resize_height = int(round(tw / w * h))
crop_top = int(round((th - resize_height) / 2.0))
crop_left = int(round((tw - resize_width) / 2.0))
return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width)
def get_3d_rotary_pos_embed(
self, embed_dim, crops_coords, grid_size, temporal_size, theta: int = 10000, use_real: bool = True
):
start, stop = crops_coords
grid_h = np.linspace(start[0], stop[0], grid_size[0], endpoint=False, dtype=np.float32)
grid_w = np.linspace(start[1], stop[1], grid_size[1], endpoint=False, dtype=np.float32)
grid_t = np.linspace(0, temporal_size, temporal_size, endpoint=False, dtype=np.float32)
# Compute dimensions for each axis
dim_t = embed_dim // 4
dim_h = embed_dim // 8 * 3
dim_w = embed_dim // 8 * 3
# Temporal frequencies
freqs_t = 1.0 / (theta ** (torch.arange(0, dim_t, 2).float() / dim_t))
grid_t = torch.from_numpy(grid_t).float()
freqs_t = torch.einsum("n , f -> n f", grid_t, freqs_t)
freqs_t = freqs_t.repeat_interleave(2, dim=-1)
# Spatial frequencies for height and width
freqs_h = 1.0 / (theta ** (torch.arange(0, dim_h, 2).float() / dim_h))
freqs_w = 1.0 / (theta ** (torch.arange(0, dim_w, 2).float() / dim_w))
grid_h = torch.from_numpy(grid_h).float()
grid_w = torch.from_numpy(grid_w).float()
freqs_h = torch.einsum("n , f -> n f", grid_h, freqs_h)
freqs_w = torch.einsum("n , f -> n f", grid_w, freqs_w)
freqs_h = freqs_h.repeat_interleave(2, dim=-1)
freqs_w = freqs_w.repeat_interleave(2, dim=-1)
# Broadcast and concatenate tensors along specified dimension
def broadcast(tensors, dim=-1):
num_tensors = len(tensors)
shape_lens = {len(t.shape) for t in tensors}
assert len(shape_lens) == 1, "tensors must all have the same number of dimensions"
shape_len = list(shape_lens)[0]
dim = (dim + shape_len) if dim < 0 else dim
dims = list(zip(*(list(t.shape) for t in tensors)))
expandable_dims = [(i, val) for i, val in enumerate(dims) if i != dim]
assert all(
[*(len(set(t[1])) <= 2 for t in expandable_dims)]
), "invalid dimensions for broadcastable concatenation"
max_dims = [(t[0], max(t[1])) for t in expandable_dims]
expanded_dims = [(t[0], (t[1],) * num_tensors) for t in max_dims]
expanded_dims.insert(dim, (dim, dims[dim]))
expandable_shapes = list(zip(*(t[1] for t in expanded_dims)))
tensors = [t[0].expand(*t[1]) for t in zip(tensors, expandable_shapes)]
return torch.cat(tensors, dim=dim)
freqs = broadcast((freqs_t[:, None, None, :], freqs_h[None, :, None, :], freqs_w[None, None, :, :]), dim=-1)
t, h, w, d = freqs.shape
freqs = freqs.view(t * h * w, d)
# Generate sine and cosine components
sin = freqs.sin()
cos = freqs.cos()
if use_real:
return cos, sin
else:
freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
return freqs_cis
def prepare_rotary_positional_embeddings(
self,
height: int,
width: int,
num_frames: int,
device: torch.device,
):
grid_height = height // 2
grid_width = width // 2
base_size_width = 720 // (8 * 2)
base_size_height = 480 // (8 * 2)
grid_crops_coords = self.get_resize_crop_region_for_grid(
(grid_height, grid_width), base_size_width, base_size_height
)
freqs_cos, freqs_sin = self.get_3d_rotary_pos_embed(
embed_dim=64,
crops_coords=grid_crops_coords,
grid_size=(grid_height, grid_width),
temporal_size=num_frames,
use_real=True,
)
freqs_cos = freqs_cos.to(device=device)
freqs_sin = freqs_sin.to(device=device)
return freqs_cos, freqs_sin
def unpatchify(self, hidden_states, height, width):
hidden_states = rearrange(hidden_states, "B (T H W) (C P Q) -> B C T (H P) (W Q)", P=2, Q=2, H=height//2, W=width//2)
return hidden_states
def build_mask(self, T, H, W, dtype, device, is_bound):
t = repeat(torch.arange(T), "T -> T H W", T=T, H=H, W=W)
h = repeat(torch.arange(H), "H -> T H W", T=T, H=H, W=W)
w = repeat(torch.arange(W), "W -> T H W", T=T, H=H, W=W)
border_width = (H + W) // 4
pad = torch.ones_like(h) * border_width
mask = torch.stack([
pad if is_bound[0] else t + 1,
pad if is_bound[1] else T - t,
pad if is_bound[2] else h + 1,
pad if is_bound[3] else H - h,
pad if is_bound[4] else w + 1,
pad if is_bound[5] else W - w
]).min(dim=0).values
mask = mask.clip(1, border_width)
mask = (mask / border_width).to(dtype=dtype, device=device)
mask = rearrange(mask, "T H W -> 1 1 T H W")
return mask
def tiled_forward(self, hidden_states, timestep, prompt_emb, tile_size=(60, 90), tile_stride=(30, 45)):
B, C, T, H, W = hidden_states.shape
value = torch.zeros((B, C, T, H, W), dtype=hidden_states.dtype, device=hidden_states.device)
weight = torch.zeros((B, C, T, H, W), dtype=hidden_states.dtype, device=hidden_states.device)
# Split tasks
tasks = []
for h in range(0, H, tile_stride):
for w in range(0, W, tile_stride):
if (h-tile_stride >= 0 and h-tile_stride+tile_size >= H) or (w-tile_stride >= 0 and w-tile_stride+tile_size >= W):
continue
h_, w_ = h + tile_size, w + tile_size
if h_ > H: h, h_ = max(H - tile_size, 0), H
if w_ > W: w, w_ = max(W - tile_size, 0), W
tasks.append((h, h_, w, w_))
# Run
for hl, hr, wl, wr in tasks:
mask = self.build_mask(
value.shape[2], (hr-hl), (wr-wl),
hidden_states.dtype, hidden_states.device,
is_bound=(True, True, hl==0, hr>=H, wl==0, wr>=W)
)
model_output = self.forward(hidden_states[:, :, :, hl:hr, wl:wr], timestep, prompt_emb)
value[:, :, :, hl:hr, wl:wr] += model_output * mask
weight[:, :, :, hl:hr, wl:wr] += mask
value = value / weight
return value
def forward(self, hidden_states, timestep, prompt_emb, image_rotary_emb=None, tiled=False, tile_size=90, tile_stride=30, use_gradient_checkpointing=False):
if tiled:
return TileWorker2Dto3D().tiled_forward(
forward_fn=lambda x: self.forward(x, timestep, prompt_emb),
model_input=hidden_states,
tile_size=tile_size, tile_stride=tile_stride,
tile_device=hidden_states.device, tile_dtype=hidden_states.dtype,
computation_device=self.context_embedder.weight.device, computation_dtype=self.context_embedder.weight.dtype
)
num_frames, height, width = hidden_states.shape[-3:]
if image_rotary_emb is None:
image_rotary_emb = self.prepare_rotary_positional_embeddings(height, width, num_frames, device=self.context_embedder.weight.device)
hidden_states = self.patchify(hidden_states)
time_emb = self.time_embedder(timestep, dtype=hidden_states.dtype)
prompt_emb = self.context_embedder(prompt_emb)
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs)
return custom_forward
for block in self.blocks:
if self.training and use_gradient_checkpointing:
hidden_states, prompt_emb = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
hidden_states, prompt_emb, time_emb, image_rotary_emb,
use_reentrant=False,
)
else:
hidden_states, prompt_emb = block(hidden_states, prompt_emb, time_emb, image_rotary_emb)
hidden_states = torch.cat([prompt_emb, hidden_states], dim=1)
hidden_states = self.norm_final(hidden_states)
hidden_states = hidden_states[:, prompt_emb.shape[1]:]
hidden_states = self.norm_out(hidden_states, prompt_emb, time_emb)
hidden_states = self.proj_out(hidden_states)
hidden_states = self.unpatchify(hidden_states, height, width)
return hidden_states
@staticmethod
def state_dict_converter():
return CogDiTStateDictConverter()
@staticmethod
def from_pretrained(file_path, torch_dtype=torch.bfloat16):
model = CogDiT().to(torch_dtype)
state_dict = load_state_dict_from_folder(file_path, torch_dtype=torch_dtype)
state_dict = CogDiT.state_dict_converter().from_diffusers(state_dict)
model.load_state_dict(state_dict)
return model
class CogDiTStateDictConverter:
def __init__(self):
pass
def from_diffusers(self, state_dict):
rename_dict = {
"patch_embed.proj.weight": "patchify.proj.weight",
"patch_embed.proj.bias": "patchify.proj.bias",
"patch_embed.text_proj.weight": "context_embedder.weight",
"patch_embed.text_proj.bias": "context_embedder.bias",
"time_embedding.linear_1.weight": "time_embedder.timestep_embedder.0.weight",
"time_embedding.linear_1.bias": "time_embedder.timestep_embedder.0.bias",
"time_embedding.linear_2.weight": "time_embedder.timestep_embedder.2.weight",
"time_embedding.linear_2.bias": "time_embedder.timestep_embedder.2.bias",
"norm_final.weight": "norm_final.weight",
"norm_final.bias": "norm_final.bias",
"norm_out.linear.weight": "norm_out.linear.weight",
"norm_out.linear.bias": "norm_out.linear.bias",
"norm_out.norm.weight": "norm_out.norm.weight",
"norm_out.norm.bias": "norm_out.norm.bias",
"proj_out.weight": "proj_out.weight",
"proj_out.bias": "proj_out.bias",
}
suffix_dict = {
"norm1.linear.weight": "norm1.linear.weight",
"norm1.linear.bias": "norm1.linear.bias",
"norm1.norm.weight": "norm1.norm.weight",
"norm1.norm.bias": "norm1.norm.bias",
"attn1.norm_q.weight": "norm_q.weight",
"attn1.norm_q.bias": "norm_q.bias",
"attn1.norm_k.weight": "norm_k.weight",
"attn1.norm_k.bias": "norm_k.bias",
"attn1.to_q.weight": "attn1.to_q.weight",
"attn1.to_q.bias": "attn1.to_q.bias",
"attn1.to_k.weight": "attn1.to_k.weight",
"attn1.to_k.bias": "attn1.to_k.bias",
"attn1.to_v.weight": "attn1.to_v.weight",
"attn1.to_v.bias": "attn1.to_v.bias",
"attn1.to_out.0.weight": "attn1.to_out.weight",
"attn1.to_out.0.bias": "attn1.to_out.bias",
"norm2.linear.weight": "norm2.linear.weight",
"norm2.linear.bias": "norm2.linear.bias",
"norm2.norm.weight": "norm2.norm.weight",
"norm2.norm.bias": "norm2.norm.bias",
"ff.net.0.proj.weight": "ff.0.weight",
"ff.net.0.proj.bias": "ff.0.bias",
"ff.net.2.weight": "ff.2.weight",
"ff.net.2.bias": "ff.2.bias",
}
state_dict_ = {}
for name, param in state_dict.items():
if name in rename_dict:
if name == "patch_embed.proj.weight":
param = param.unsqueeze(2)
state_dict_[rename_dict[name]] = param
else:
names = name.split(".")
if names[0] == "transformer_blocks":
suffix = ".".join(names[2:])
state_dict_[f"blocks.{names[1]}." + suffix_dict[suffix]] = param
return state_dict_
def from_civitai(self, state_dict):
return self.from_diffusers(state_dict)

518
diffsynth/models/cog_vae.py Normal file
View File

@@ -0,0 +1,518 @@
import torch
from einops import rearrange, repeat
from .tiler import TileWorker2Dto3D
class Downsample3D(torch.nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: int = 3,
stride: int = 2,
padding: int = 0,
compress_time: bool = False,
):
super().__init__()
self.conv = torch.nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding)
self.compress_time = compress_time
def forward(self, x: torch.Tensor, xq: torch.Tensor) -> torch.Tensor:
if self.compress_time:
batch_size, channels, frames, height, width = x.shape
# (batch_size, channels, frames, height, width) -> (batch_size, height, width, channels, frames) -> (batch_size * height * width, channels, frames)
x = x.permute(0, 3, 4, 1, 2).reshape(batch_size * height * width, channels, frames)
if x.shape[-1] % 2 == 1:
x_first, x_rest = x[..., 0], x[..., 1:]
if x_rest.shape[-1] > 0:
# (batch_size * height * width, channels, frames - 1) -> (batch_size * height * width, channels, (frames - 1) // 2)
x_rest = torch.nn.functional.avg_pool1d(x_rest, kernel_size=2, stride=2)
x = torch.cat([x_first[..., None], x_rest], dim=-1)
# (batch_size * height * width, channels, (frames // 2) + 1) -> (batch_size, height, width, channels, (frames // 2) + 1) -> (batch_size, channels, (frames // 2) + 1, height, width)
x = x.reshape(batch_size, height, width, channels, x.shape[-1]).permute(0, 3, 4, 1, 2)
else:
# (batch_size * height * width, channels, frames) -> (batch_size * height * width, channels, frames // 2)
x = torch.nn.functional.avg_pool1d(x, kernel_size=2, stride=2)
# (batch_size * height * width, channels, frames // 2) -> (batch_size, height, width, channels, frames // 2) -> (batch_size, channels, frames // 2, height, width)
x = x.reshape(batch_size, height, width, channels, x.shape[-1]).permute(0, 3, 4, 1, 2)
# Pad the tensor
pad = (0, 1, 0, 1)
x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
batch_size, channels, frames, height, width = x.shape
# (batch_size, channels, frames, height, width) -> (batch_size, frames, channels, height, width) -> (batch_size * frames, channels, height, width)
x = x.permute(0, 2, 1, 3, 4).reshape(batch_size * frames, channels, height, width)
x = self.conv(x)
# (batch_size * frames, channels, height, width) -> (batch_size, frames, channels, height, width) -> (batch_size, channels, frames, height, width)
x = x.reshape(batch_size, frames, x.shape[1], x.shape[2], x.shape[3]).permute(0, 2, 1, 3, 4)
return x
class Upsample3D(torch.nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: int = 3,
stride: int = 1,
padding: int = 1,
compress_time: bool = False,
) -> None:
super().__init__()
self.conv = torch.nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding)
self.compress_time = compress_time
def forward(self, inputs: torch.Tensor, xq: torch.Tensor) -> torch.Tensor:
if self.compress_time:
if inputs.shape[2] > 1 and inputs.shape[2] % 2 == 1:
# split first frame
x_first, x_rest = inputs[:, :, 0], inputs[:, :, 1:]
x_first = torch.nn.functional.interpolate(x_first, scale_factor=2.0)
x_rest = torch.nn.functional.interpolate(x_rest, scale_factor=2.0)
x_first = x_first[:, :, None, :, :]
inputs = torch.cat([x_first, x_rest], dim=2)
elif inputs.shape[2] > 1:
inputs = torch.nn.functional.interpolate(inputs, scale_factor=2.0)
else:
inputs = inputs.squeeze(2)
inputs = torch.nn.functional.interpolate(inputs, scale_factor=2.0)
inputs = inputs[:, :, None, :, :]
else:
# only interpolate 2D
b, c, t, h, w = inputs.shape
inputs = inputs.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w)
inputs = torch.nn.functional.interpolate(inputs, scale_factor=2.0)
inputs = inputs.reshape(b, t, c, *inputs.shape[2:]).permute(0, 2, 1, 3, 4)
b, c, t, h, w = inputs.shape
inputs = inputs.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w)
inputs = self.conv(inputs)
inputs = inputs.reshape(b, t, *inputs.shape[1:]).permute(0, 2, 1, 3, 4)
return inputs
class CogVideoXSpatialNorm3D(torch.nn.Module):
def __init__(self, f_channels, zq_channels, groups):
super().__init__()
self.norm_layer = torch.nn.GroupNorm(num_channels=f_channels, num_groups=groups, eps=1e-6, affine=True)
self.conv_y = torch.nn.Conv3d(zq_channels, f_channels, kernel_size=1, stride=1)
self.conv_b = torch.nn.Conv3d(zq_channels, f_channels, kernel_size=1, stride=1)
def forward(self, f: torch.Tensor, zq: torch.Tensor) -> torch.Tensor:
if f.shape[2] > 1 and f.shape[2] % 2 == 1:
f_first, f_rest = f[:, :, :1], f[:, :, 1:]
f_first_size, f_rest_size = f_first.shape[-3:], f_rest.shape[-3:]
z_first, z_rest = zq[:, :, :1], zq[:, :, 1:]
z_first = torch.nn.functional.interpolate(z_first, size=f_first_size)
z_rest = torch.nn.functional.interpolate(z_rest, size=f_rest_size)
zq = torch.cat([z_first, z_rest], dim=2)
else:
zq = torch.nn.functional.interpolate(zq, size=f.shape[-3:])
norm_f = self.norm_layer(f)
new_f = norm_f * self.conv_y(zq) + self.conv_b(zq)
return new_f
class Resnet3DBlock(torch.nn.Module):
def __init__(self, in_channels, out_channels, spatial_norm_dim, groups, eps=1e-6, use_conv_shortcut=False):
super().__init__()
self.nonlinearity = torch.nn.SiLU()
if spatial_norm_dim is None:
self.norm1 = torch.nn.GroupNorm(num_channels=in_channels, num_groups=groups, eps=eps)
self.norm2 = torch.nn.GroupNorm(num_channels=out_channels, num_groups=groups, eps=eps)
else:
self.norm1 = CogVideoXSpatialNorm3D(in_channels, spatial_norm_dim, groups)
self.norm2 = CogVideoXSpatialNorm3D(out_channels, spatial_norm_dim, groups)
self.conv1 = CachedConv3d(in_channels, out_channels, kernel_size=3, padding=(0, 1, 1))
self.conv2 = CachedConv3d(out_channels, out_channels, kernel_size=3, padding=(0, 1, 1))
if in_channels != out_channels:
if use_conv_shortcut:
self.conv_shortcut = CachedConv3d(in_channels, out_channels, kernel_size=3, padding=(0, 1, 1))
else:
self.conv_shortcut = torch.nn.Conv3d(in_channels, out_channels, kernel_size=1)
else:
self.conv_shortcut = lambda x: x
def forward(self, hidden_states, zq):
residual = hidden_states
hidden_states = self.norm1(hidden_states, zq) if isinstance(self.norm1, CogVideoXSpatialNorm3D) else self.norm1(hidden_states)
hidden_states = self.nonlinearity(hidden_states)
hidden_states = self.conv1(hidden_states)
hidden_states = self.norm2(hidden_states, zq) if isinstance(self.norm2, CogVideoXSpatialNorm3D) else self.norm2(hidden_states)
hidden_states = self.nonlinearity(hidden_states)
hidden_states = self.conv2(hidden_states)
hidden_states = hidden_states + self.conv_shortcut(residual)
return hidden_states
class CachedConv3d(torch.nn.Conv3d):
def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0):
super().__init__(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding)
self.cached_tensor = None
def clear_cache(self):
self.cached_tensor = None
def forward(self, input: torch.Tensor, use_cache = True) -> torch.Tensor:
if use_cache:
if self.cached_tensor is None:
self.cached_tensor = torch.concat([input[:, :, :1]] * 2, dim=2)
input = torch.concat([self.cached_tensor, input], dim=2)
self.cached_tensor = input[:, :, -2:]
return super().forward(input)
class CogVAEDecoder(torch.nn.Module):
def __init__(self):
super().__init__()
self.scaling_factor = 0.7
self.conv_in = CachedConv3d(16, 512, kernel_size=3, stride=1, padding=(0, 1, 1))
self.blocks = torch.nn.ModuleList([
Resnet3DBlock(512, 512, 16, 32),
Resnet3DBlock(512, 512, 16, 32),
Resnet3DBlock(512, 512, 16, 32),
Resnet3DBlock(512, 512, 16, 32),
Resnet3DBlock(512, 512, 16, 32),
Resnet3DBlock(512, 512, 16, 32),
Upsample3D(512, 512, compress_time=True),
Resnet3DBlock(512, 256, 16, 32),
Resnet3DBlock(256, 256, 16, 32),
Resnet3DBlock(256, 256, 16, 32),
Resnet3DBlock(256, 256, 16, 32),
Upsample3D(256, 256, compress_time=True),
Resnet3DBlock(256, 256, 16, 32),
Resnet3DBlock(256, 256, 16, 32),
Resnet3DBlock(256, 256, 16, 32),
Resnet3DBlock(256, 256, 16, 32),
Upsample3D(256, 256, compress_time=False),
Resnet3DBlock(256, 128, 16, 32),
Resnet3DBlock(128, 128, 16, 32),
Resnet3DBlock(128, 128, 16, 32),
Resnet3DBlock(128, 128, 16, 32),
])
self.norm_out = CogVideoXSpatialNorm3D(128, 16, 32)
self.conv_act = torch.nn.SiLU()
self.conv_out = CachedConv3d(128, 3, kernel_size=3, stride=1, padding=(0, 1, 1))
def forward(self, sample):
sample = sample / self.scaling_factor
hidden_states = self.conv_in(sample)
for block in self.blocks:
hidden_states = block(hidden_states, sample)
hidden_states = self.norm_out(hidden_states, sample)
hidden_states = self.conv_act(hidden_states)
hidden_states = self.conv_out(hidden_states)
return hidden_states
def decode_video(self, sample, tiled=True, tile_size=(60, 90), tile_stride=(30, 45), progress_bar=lambda x:x):
if tiled:
B, C, T, H, W = sample.shape
return TileWorker2Dto3D().tiled_forward(
forward_fn=lambda x: self.decode_small_video(x),
model_input=sample,
tile_size=tile_size, tile_stride=tile_stride,
tile_device=sample.device, tile_dtype=sample.dtype,
computation_device=sample.device, computation_dtype=sample.dtype,
scales=(3/16, (T//2*8+T%2)/T, 8, 8),
progress_bar=progress_bar
)
else:
return self.decode_small_video(sample)
def decode_small_video(self, sample):
B, C, T, H, W = sample.shape
computation_device = self.conv_in.weight.device
computation_dtype = self.conv_in.weight.dtype
value = []
for i in range(T//2):
tl = i*2 + T%2 - (T%2 and i==0)
tr = i*2 + 2 + T%2
model_input = sample[:, :, tl: tr, :, :].to(dtype=computation_dtype, device=computation_device)
model_output = self.forward(model_input).to(dtype=sample.dtype, device=sample.device)
value.append(model_output)
value = torch.concat(value, dim=2)
for name, module in self.named_modules():
if isinstance(module, CachedConv3d):
module.clear_cache()
return value
@staticmethod
def state_dict_converter():
return CogVAEDecoderStateDictConverter()
class CogVAEEncoder(torch.nn.Module):
def __init__(self):
super().__init__()
self.scaling_factor = 0.7
self.conv_in = CachedConv3d(3, 128, kernel_size=3, stride=1, padding=(0, 1, 1))
self.blocks = torch.nn.ModuleList([
Resnet3DBlock(128, 128, None, 32),
Resnet3DBlock(128, 128, None, 32),
Resnet3DBlock(128, 128, None, 32),
Downsample3D(128, 128, compress_time=True),
Resnet3DBlock(128, 256, None, 32),
Resnet3DBlock(256, 256, None, 32),
Resnet3DBlock(256, 256, None, 32),
Downsample3D(256, 256, compress_time=True),
Resnet3DBlock(256, 256, None, 32),
Resnet3DBlock(256, 256, None, 32),
Resnet3DBlock(256, 256, None, 32),
Downsample3D(256, 256, compress_time=False),
Resnet3DBlock(256, 512, None, 32),
Resnet3DBlock(512, 512, None, 32),
Resnet3DBlock(512, 512, None, 32),
Resnet3DBlock(512, 512, None, 32),
Resnet3DBlock(512, 512, None, 32),
])
self.norm_out = torch.nn.GroupNorm(32, 512, eps=1e-06, affine=True)
self.conv_act = torch.nn.SiLU()
self.conv_out = CachedConv3d(512, 32, kernel_size=3, stride=1, padding=(0, 1, 1))
def forward(self, sample):
hidden_states = self.conv_in(sample)
for block in self.blocks:
hidden_states = block(hidden_states, sample)
hidden_states = self.norm_out(hidden_states)
hidden_states = self.conv_act(hidden_states)
hidden_states = self.conv_out(hidden_states)[:, :16]
hidden_states = hidden_states * self.scaling_factor
return hidden_states
def encode_video(self, sample, tiled=True, tile_size=(60, 90), tile_stride=(30, 45), progress_bar=lambda x:x):
if tiled:
B, C, T, H, W = sample.shape
return TileWorker2Dto3D().tiled_forward(
forward_fn=lambda x: self.encode_small_video(x),
model_input=sample,
tile_size=(i * 8 for i in tile_size), tile_stride=(i * 8 for i in tile_stride),
tile_device=sample.device, tile_dtype=sample.dtype,
computation_device=sample.device, computation_dtype=sample.dtype,
scales=(16/3, (T//4+T%2)/T, 1/8, 1/8),
progress_bar=progress_bar
)
else:
return self.encode_small_video(sample)
def encode_small_video(self, sample):
B, C, T, H, W = sample.shape
computation_device = self.conv_in.weight.device
computation_dtype = self.conv_in.weight.dtype
value = []
for i in range(T//8):
t = i*8 + T%2 - (T%2 and i==0)
t_ = i*8 + 8 + T%2
model_input = sample[:, :, t: t_, :, :].to(dtype=computation_dtype, device=computation_device)
model_output = self.forward(model_input).to(dtype=sample.dtype, device=sample.device)
value.append(model_output)
value = torch.concat(value, dim=2)
for name, module in self.named_modules():
if isinstance(module, CachedConv3d):
module.clear_cache()
return value
@staticmethod
def state_dict_converter():
return CogVAEEncoderStateDictConverter()
class CogVAEEncoderStateDictConverter:
def __init__(self):
pass
def from_diffusers(self, state_dict):
rename_dict = {
"encoder.conv_in.conv.weight": "conv_in.weight",
"encoder.conv_in.conv.bias": "conv_in.bias",
"encoder.down_blocks.0.downsamplers.0.conv.weight": "blocks.3.conv.weight",
"encoder.down_blocks.0.downsamplers.0.conv.bias": "blocks.3.conv.bias",
"encoder.down_blocks.1.downsamplers.0.conv.weight": "blocks.7.conv.weight",
"encoder.down_blocks.1.downsamplers.0.conv.bias": "blocks.7.conv.bias",
"encoder.down_blocks.2.downsamplers.0.conv.weight": "blocks.11.conv.weight",
"encoder.down_blocks.2.downsamplers.0.conv.bias": "blocks.11.conv.bias",
"encoder.norm_out.weight": "norm_out.weight",
"encoder.norm_out.bias": "norm_out.bias",
"encoder.conv_out.conv.weight": "conv_out.weight",
"encoder.conv_out.conv.bias": "conv_out.bias",
}
prefix_dict = {
"encoder.down_blocks.0.resnets.0.": "blocks.0.",
"encoder.down_blocks.0.resnets.1.": "blocks.1.",
"encoder.down_blocks.0.resnets.2.": "blocks.2.",
"encoder.down_blocks.1.resnets.0.": "blocks.4.",
"encoder.down_blocks.1.resnets.1.": "blocks.5.",
"encoder.down_blocks.1.resnets.2.": "blocks.6.",
"encoder.down_blocks.2.resnets.0.": "blocks.8.",
"encoder.down_blocks.2.resnets.1.": "blocks.9.",
"encoder.down_blocks.2.resnets.2.": "blocks.10.",
"encoder.down_blocks.3.resnets.0.": "blocks.12.",
"encoder.down_blocks.3.resnets.1.": "blocks.13.",
"encoder.down_blocks.3.resnets.2.": "blocks.14.",
"encoder.mid_block.resnets.0.": "blocks.15.",
"encoder.mid_block.resnets.1.": "blocks.16.",
}
suffix_dict = {
"norm1.norm_layer.weight": "norm1.norm_layer.weight",
"norm1.norm_layer.bias": "norm1.norm_layer.bias",
"norm1.conv_y.conv.weight": "norm1.conv_y.weight",
"norm1.conv_y.conv.bias": "norm1.conv_y.bias",
"norm1.conv_b.conv.weight": "norm1.conv_b.weight",
"norm1.conv_b.conv.bias": "norm1.conv_b.bias",
"norm2.norm_layer.weight": "norm2.norm_layer.weight",
"norm2.norm_layer.bias": "norm2.norm_layer.bias",
"norm2.conv_y.conv.weight": "norm2.conv_y.weight",
"norm2.conv_y.conv.bias": "norm2.conv_y.bias",
"norm2.conv_b.conv.weight": "norm2.conv_b.weight",
"norm2.conv_b.conv.bias": "norm2.conv_b.bias",
"conv1.conv.weight": "conv1.weight",
"conv1.conv.bias": "conv1.bias",
"conv2.conv.weight": "conv2.weight",
"conv2.conv.bias": "conv2.bias",
"conv_shortcut.weight": "conv_shortcut.weight",
"conv_shortcut.bias": "conv_shortcut.bias",
"norm1.weight": "norm1.weight",
"norm1.bias": "norm1.bias",
"norm2.weight": "norm2.weight",
"norm2.bias": "norm2.bias",
}
state_dict_ = {}
for name, param in state_dict.items():
if name in rename_dict:
state_dict_[rename_dict[name]] = param
else:
for prefix in prefix_dict:
if name.startswith(prefix):
suffix = name[len(prefix):]
state_dict_[prefix_dict[prefix] + suffix_dict[suffix]] = param
return state_dict_
def from_civitai(self, state_dict):
return self.from_diffusers(state_dict)
class CogVAEDecoderStateDictConverter:
def __init__(self):
pass
def from_diffusers(self, state_dict):
rename_dict = {
"decoder.conv_in.conv.weight": "conv_in.weight",
"decoder.conv_in.conv.bias": "conv_in.bias",
"decoder.up_blocks.0.upsamplers.0.conv.weight": "blocks.6.conv.weight",
"decoder.up_blocks.0.upsamplers.0.conv.bias": "blocks.6.conv.bias",
"decoder.up_blocks.1.upsamplers.0.conv.weight": "blocks.11.conv.weight",
"decoder.up_blocks.1.upsamplers.0.conv.bias": "blocks.11.conv.bias",
"decoder.up_blocks.2.upsamplers.0.conv.weight": "blocks.16.conv.weight",
"decoder.up_blocks.2.upsamplers.0.conv.bias": "blocks.16.conv.bias",
"decoder.norm_out.norm_layer.weight": "norm_out.norm_layer.weight",
"decoder.norm_out.norm_layer.bias": "norm_out.norm_layer.bias",
"decoder.norm_out.conv_y.conv.weight": "norm_out.conv_y.weight",
"decoder.norm_out.conv_y.conv.bias": "norm_out.conv_y.bias",
"decoder.norm_out.conv_b.conv.weight": "norm_out.conv_b.weight",
"decoder.norm_out.conv_b.conv.bias": "norm_out.conv_b.bias",
"decoder.conv_out.conv.weight": "conv_out.weight",
"decoder.conv_out.conv.bias": "conv_out.bias"
}
prefix_dict = {
"decoder.mid_block.resnets.0.": "blocks.0.",
"decoder.mid_block.resnets.1.": "blocks.1.",
"decoder.up_blocks.0.resnets.0.": "blocks.2.",
"decoder.up_blocks.0.resnets.1.": "blocks.3.",
"decoder.up_blocks.0.resnets.2.": "blocks.4.",
"decoder.up_blocks.0.resnets.3.": "blocks.5.",
"decoder.up_blocks.1.resnets.0.": "blocks.7.",
"decoder.up_blocks.1.resnets.1.": "blocks.8.",
"decoder.up_blocks.1.resnets.2.": "blocks.9.",
"decoder.up_blocks.1.resnets.3.": "blocks.10.",
"decoder.up_blocks.2.resnets.0.": "blocks.12.",
"decoder.up_blocks.2.resnets.1.": "blocks.13.",
"decoder.up_blocks.2.resnets.2.": "blocks.14.",
"decoder.up_blocks.2.resnets.3.": "blocks.15.",
"decoder.up_blocks.3.resnets.0.": "blocks.17.",
"decoder.up_blocks.3.resnets.1.": "blocks.18.",
"decoder.up_blocks.3.resnets.2.": "blocks.19.",
"decoder.up_blocks.3.resnets.3.": "blocks.20.",
}
suffix_dict = {
"norm1.norm_layer.weight": "norm1.norm_layer.weight",
"norm1.norm_layer.bias": "norm1.norm_layer.bias",
"norm1.conv_y.conv.weight": "norm1.conv_y.weight",
"norm1.conv_y.conv.bias": "norm1.conv_y.bias",
"norm1.conv_b.conv.weight": "norm1.conv_b.weight",
"norm1.conv_b.conv.bias": "norm1.conv_b.bias",
"norm2.norm_layer.weight": "norm2.norm_layer.weight",
"norm2.norm_layer.bias": "norm2.norm_layer.bias",
"norm2.conv_y.conv.weight": "norm2.conv_y.weight",
"norm2.conv_y.conv.bias": "norm2.conv_y.bias",
"norm2.conv_b.conv.weight": "norm2.conv_b.weight",
"norm2.conv_b.conv.bias": "norm2.conv_b.bias",
"conv1.conv.weight": "conv1.weight",
"conv1.conv.bias": "conv1.bias",
"conv2.conv.weight": "conv2.weight",
"conv2.conv.bias": "conv2.bias",
"conv_shortcut.weight": "conv_shortcut.weight",
"conv_shortcut.bias": "conv_shortcut.bias",
}
state_dict_ = {}
for name, param in state_dict.items():
if name in rename_dict:
state_dict_[rename_dict[name]] = param
else:
for prefix in prefix_dict:
if name.startswith(prefix):
suffix = name[len(prefix):]
state_dict_[prefix_dict[prefix] + suffix_dict[suffix]] = param
return state_dict_
def from_civitai(self, state_dict):
return self.from_diffusers(state_dict)

View File

@@ -1,96 +0,0 @@
from transformers import DINOv3ViTModel, DINOv3ViTImageProcessorFast
from transformers.models.dinov3_vit.modeling_dinov3_vit import DINOv3ViTConfig
import torch
from ..core.device.npu_compatible_device import get_device_type
class DINOv3ImageEncoder(DINOv3ViTModel):
def __init__(self):
config = DINOv3ViTConfig(
architectures = [
"DINOv3ViTModel"
],
attention_dropout = 0.0,
drop_path_rate = 0.0,
dtype = "float32",
hidden_act = "silu",
hidden_size = 4096,
image_size = 224,
initializer_range = 0.02,
intermediate_size = 8192,
key_bias = False,
layer_norm_eps = 1e-05,
layerscale_value = 1.0,
mlp_bias = True,
model_type = "dinov3_vit",
num_attention_heads = 32,
num_channels = 3,
num_hidden_layers = 40,
num_register_tokens = 4,
patch_size = 16,
pos_embed_jitter = None,
pos_embed_rescale = 2.0,
pos_embed_shift = None,
proj_bias = True,
query_bias = False,
rope_theta = 100.0,
transformers_version = "4.56.1",
use_gated_mlp = True,
value_bias = False
)
super().__init__(config)
self.processor = DINOv3ViTImageProcessorFast(
crop_size = None,
data_format = "channels_first",
default_to_square = True,
device = None,
disable_grouping = None,
do_center_crop = None,
do_convert_rgb = None,
do_normalize = True,
do_rescale = True,
do_resize = True,
image_mean = [
0.485,
0.456,
0.406
],
image_processor_type = "DINOv3ViTImageProcessorFast",
image_std = [
0.229,
0.224,
0.225
],
input_data_format = None,
resample = 2,
rescale_factor = 0.00392156862745098,
return_tensors = None,
size = {
"height": 224,
"width": 224
}
)
def forward(self, image, torch_dtype=torch.bfloat16, device=get_device_type()):
inputs = self.processor(images=image, return_tensors="pt")
pixel_values = inputs["pixel_values"].to(dtype=torch_dtype, device=device)
bool_masked_pos = None
head_mask = None
pixel_values = pixel_values.to(torch_dtype)
hidden_states = self.embeddings(pixel_values, bool_masked_pos=bool_masked_pos)
position_embeddings = self.rope_embeddings(pixel_values)
for i, layer_module in enumerate(self.layer):
layer_head_mask = head_mask[i] if head_mask is not None else None
hidden_states = layer_module(
hidden_states,
attention_mask=layer_head_mask,
position_embeddings=position_embeddings,
)
sequence_output = self.norm(hidden_states)
pooled_output = sequence_output[:, 0, :]
return pooled_output

View File

@@ -0,0 +1,111 @@
from huggingface_hub import hf_hub_download
from modelscope import snapshot_download
import os, shutil
from typing_extensions import Literal, TypeAlias
from typing import List
from ..configs.model_config import preset_models_on_huggingface, preset_models_on_modelscope, Preset_model_id
def download_from_modelscope(model_id, origin_file_path, local_dir):
os.makedirs(local_dir, exist_ok=True)
file_name = os.path.basename(origin_file_path)
if file_name in os.listdir(local_dir):
print(f" {file_name} has been already in {local_dir}.")
else:
print(f" Start downloading {os.path.join(local_dir, file_name)}")
snapshot_download(model_id, allow_file_pattern=origin_file_path, local_dir=local_dir)
downloaded_file_path = os.path.join(local_dir, origin_file_path)
target_file_path = os.path.join(local_dir, os.path.split(origin_file_path)[-1])
if downloaded_file_path != target_file_path:
shutil.move(downloaded_file_path, target_file_path)
shutil.rmtree(os.path.join(local_dir, origin_file_path.split("/")[0]))
def download_from_huggingface(model_id, origin_file_path, local_dir):
os.makedirs(local_dir, exist_ok=True)
file_name = os.path.basename(origin_file_path)
if file_name in os.listdir(local_dir):
print(f" {file_name} has been already in {local_dir}.")
else:
print(f" Start downloading {os.path.join(local_dir, file_name)}")
hf_hub_download(model_id, origin_file_path, local_dir=local_dir)
downloaded_file_path = os.path.join(local_dir, origin_file_path)
target_file_path = os.path.join(local_dir, file_name)
if downloaded_file_path != target_file_path:
shutil.move(downloaded_file_path, target_file_path)
shutil.rmtree(os.path.join(local_dir, origin_file_path.split("/")[0]))
Preset_model_website: TypeAlias = Literal[
"HuggingFace",
"ModelScope",
]
website_to_preset_models = {
"HuggingFace": preset_models_on_huggingface,
"ModelScope": preset_models_on_modelscope,
}
website_to_download_fn = {
"HuggingFace": download_from_huggingface,
"ModelScope": download_from_modelscope,
}
def download_customized_models(
model_id,
origin_file_path,
local_dir,
downloading_priority: List[Preset_model_website] = ["ModelScope", "HuggingFace"],
):
downloaded_files = []
for website in downloading_priority:
# Check if the file is downloaded.
file_to_download = os.path.join(local_dir, os.path.basename(origin_file_path))
if file_to_download in downloaded_files:
continue
# Download
website_to_download_fn[website](model_id, origin_file_path, local_dir)
if os.path.basename(origin_file_path) in os.listdir(local_dir):
downloaded_files.append(file_to_download)
return downloaded_files
def download_models(
model_id_list: List[Preset_model_id] = [],
downloading_priority: List[Preset_model_website] = ["ModelScope", "HuggingFace"],
):
print(f"Downloading models: {model_id_list}")
downloaded_files = []
load_files = []
for model_id in model_id_list:
for website in downloading_priority:
if model_id in website_to_preset_models[website]:
# Parse model metadata
model_metadata = website_to_preset_models[website][model_id]
if isinstance(model_metadata, list):
file_data = model_metadata
else:
file_data = model_metadata.get("file_list", [])
# Try downloading the model from this website.
model_files = []
for model_id, origin_file_path, local_dir in file_data:
# Check if the file is downloaded.
file_to_download = os.path.join(local_dir, os.path.basename(origin_file_path))
if file_to_download in downloaded_files:
continue
# Download
website_to_download_fn[website](model_id, origin_file_path, local_dir)
if os.path.basename(origin_file_path) in os.listdir(local_dir):
downloaded_files.append(file_to_download)
model_files.append(file_to_download)
# If the model is successfully downloaded, break.
if len(model_files) > 0:
if isinstance(model_metadata, dict) and "load_path" in model_metadata:
model_files = model_metadata["load_path"]
load_files.extend(model_files)
break
return load_files

File diff suppressed because it is too large Load Diff

View File

@@ -1,58 +0,0 @@
from transformers import Mistral3ForConditionalGeneration, Mistral3Config
class Flux2TextEncoder(Mistral3ForConditionalGeneration):
def __init__(self):
config = Mistral3Config(**{
"architectures": [
"Mistral3ForConditionalGeneration"
],
"dtype": "bfloat16",
"image_token_index": 10,
"model_type": "mistral3",
"multimodal_projector_bias": False,
"projector_hidden_act": "gelu",
"spatial_merge_size": 2,
"text_config": {
"attention_dropout": 0.0,
"dtype": "bfloat16",
"head_dim": 128,
"hidden_act": "silu",
"hidden_size": 5120,
"initializer_range": 0.02,
"intermediate_size": 32768,
"max_position_embeddings": 131072,
"model_type": "mistral",
"num_attention_heads": 32,
"num_hidden_layers": 40,
"num_key_value_heads": 8,
"rms_norm_eps": 1e-05,
"rope_theta": 1000000000.0,
"sliding_window": None,
"use_cache": True,
"vocab_size": 131072
},
"transformers_version": "4.57.1",
"vision_config": {
"attention_dropout": 0.0,
"dtype": "bfloat16",
"head_dim": 64,
"hidden_act": "silu",
"hidden_size": 1024,
"image_size": 1540,
"initializer_range": 0.02,
"intermediate_size": 4096,
"model_type": "pixtral",
"num_attention_heads": 16,
"num_channels": 3,
"num_hidden_layers": 24,
"patch_size": 14,
"rope_theta": 10000.0
},
"vision_feature_layer": -1
})
super().__init__(config)
def forward(self, input_ids = None, pixel_values = None, attention_mask = None, position_ids = None, past_key_values = None, inputs_embeds = None, labels = None, use_cache = None, output_attentions = None, output_hidden_states = None, return_dict = None, cache_position = None, logits_to_keep = 0, image_sizes = None, **kwargs):
return super().forward(input_ids, pixel_values, attention_mask, position_ids, past_key_values, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict, cache_position, logits_to_keep, image_sizes, **kwargs)

File diff suppressed because it is too large Load Diff

View File

@@ -1,62 +1,9 @@
import torch
from einops import rearrange, repeat
from .flux_dit import RoPEEmbedding, TimestepEmbeddings, FluxJointTransformerBlock, FluxSingleTransformerBlock, RMSNorm
# from .utils import hash_state_dict_keys, init_weights_on_device
from contextlib import contextmanager
from .utils import hash_state_dict_keys, init_weights_on_device
def hash_state_dict_keys(state_dict, with_shape=True):
keys_str = convert_state_dict_keys_to_single_str(state_dict, with_shape=with_shape)
keys_str = keys_str.encode(encoding="UTF-8")
return hashlib.md5(keys_str).hexdigest()
@contextmanager
def init_weights_on_device(device = torch.device("meta"), include_buffers :bool = False):
old_register_parameter = torch.nn.Module.register_parameter
if include_buffers:
old_register_buffer = torch.nn.Module.register_buffer
def register_empty_parameter(module, name, param):
old_register_parameter(module, name, param)
if param is not None:
param_cls = type(module._parameters[name])
kwargs = module._parameters[name].__dict__
kwargs["requires_grad"] = param.requires_grad
module._parameters[name] = param_cls(module._parameters[name].to(device), **kwargs)
def register_empty_buffer(module, name, buffer, persistent=True):
old_register_buffer(module, name, buffer, persistent=persistent)
if buffer is not None:
module._buffers[name] = module._buffers[name].to(device)
def patch_tensor_constructor(fn):
def wrapper(*args, **kwargs):
kwargs["device"] = device
return fn(*args, **kwargs)
return wrapper
if include_buffers:
tensor_constructors_to_patch = {
torch_function_name: getattr(torch, torch_function_name)
for torch_function_name in ["empty", "zeros", "ones", "full"]
}
else:
tensor_constructors_to_patch = {}
try:
torch.nn.Module.register_parameter = register_empty_parameter
if include_buffers:
torch.nn.Module.register_buffer = register_empty_buffer
for torch_function_name in tensor_constructors_to_patch.keys():
setattr(torch, torch_function_name, patch_tensor_constructor(getattr(torch, torch_function_name)))
yield
finally:
torch.nn.Module.register_parameter = old_register_parameter
if include_buffers:
torch.nn.Module.register_buffer = old_register_buffer
for torch_function_name, old_torch_function in tensor_constructors_to_patch.items():
setattr(torch, torch_function_name, old_torch_function)
class FluxControlNet(torch.nn.Module):
def __init__(self, disable_guidance_embedder=False, num_joint_blocks=5, num_single_blocks=10, num_mode=0, mode_dict={}, additional_input_dim=0):
@@ -155,9 +102,9 @@ class FluxControlNet(torch.nn.Module):
return controlnet_res_stack, controlnet_single_res_stack
# @staticmethod
# def state_dict_converter():
# return FluxControlNetStateDictConverter()
@staticmethod
def state_dict_converter():
return FluxControlNetStateDictConverter()
def quantize(self):
def cast_to(weight, dtype=None, device=None, copy=False):
@@ -371,10 +318,6 @@ class FluxControlNetStateDictConverter:
extra_kwargs = {"num_joint_blocks": 6, "num_single_blocks": 0, "additional_input_dim": 4}
elif hash_value == "0cfd1740758423a2a854d67c136d1e8c":
extra_kwargs = {"num_joint_blocks": 4, "num_single_blocks": 1}
elif hash_value == "7f9583eb8ba86642abb9a21a4b2c9e16":
extra_kwargs = {"num_joint_blocks": 4, "num_single_blocks": 10}
elif hash_value == "43ad5aaa27dd4ee01b832ed16773fa52":
extra_kwargs = {"num_joint_blocks": 6, "num_single_blocks": 0}
else:
extra_kwargs = {}
return state_dict_, extra_kwargs

View File

@@ -1,7 +1,8 @@
import torch
from .general_modules import TimestepEmbeddings, AdaLayerNorm, RMSNorm
from .sd3_dit import TimestepEmbeddings, AdaLayerNorm, RMSNorm
from einops import rearrange
from .tiler import TileWorker
from .utils import init_weights_on_device
def interact_with_ipadapter(hidden_states, q, ip_k, ip_v, scale=1.0):
batch_size, num_tokens = hidden_states.shape[0:2]
@@ -268,29 +269,27 @@ class AdaLayerNormContinuous(torch.nn.Module):
def forward(self, x, conditioning):
emb = self.linear(self.silu(conditioning))
shift, scale = torch.chunk(emb, 2, dim=1)
scale, shift = torch.chunk(emb, 2, dim=1)
x = self.norm(x) * (1 + scale)[:, None] + shift[:, None]
return x
class FluxDiT(torch.nn.Module):
def __init__(self, disable_guidance_embedder=False, input_dim=64, num_blocks=19):
def __init__(self, disable_guidance_embedder=False):
super().__init__()
self.pos_embedder = RoPEEmbedding(3072, 10000, [16, 56, 56])
self.time_embedder = TimestepEmbeddings(256, 3072)
self.guidance_embedder = None if disable_guidance_embedder else TimestepEmbeddings(256, 3072)
self.pooled_text_embedder = torch.nn.Sequential(torch.nn.Linear(768, 3072), torch.nn.SiLU(), torch.nn.Linear(3072, 3072))
self.context_embedder = torch.nn.Linear(4096, 3072)
self.x_embedder = torch.nn.Linear(input_dim, 3072)
self.x_embedder = torch.nn.Linear(64, 3072)
self.blocks = torch.nn.ModuleList([FluxJointTransformerBlock(3072, 24) for _ in range(num_blocks)])
self.blocks = torch.nn.ModuleList([FluxJointTransformerBlock(3072, 24) for _ in range(19)])
self.single_blocks = torch.nn.ModuleList([FluxSingleTransformerBlock(3072, 24) for _ in range(38)])
self.final_norm_out = AdaLayerNormContinuous(3072)
self.final_proj_out = torch.nn.Linear(3072, 64)
self.input_dim = input_dim
def patchify(self, hidden_states):
@@ -320,6 +319,25 @@ class FluxDiT(torch.nn.Module):
return latent_image_ids
def tiled_forward(
self,
hidden_states,
timestep, prompt_emb, pooled_prompt_emb, guidance, text_ids,
tile_size=128, tile_stride=64,
**kwargs
):
# Due to the global positional embedding, we cannot implement layer-wise tiled forward.
hidden_states = TileWorker().tiled_forward(
lambda x: self.forward(x, timestep, prompt_emb, pooled_prompt_emb, guidance, text_ids, image_ids=None),
hidden_states,
tile_size,
tile_stride,
tile_device=hidden_states.device,
tile_dtype=hidden_states.dtype
)
return hidden_states
def construct_mask(self, entity_masks, prompt_seq_len, image_seq_len):
N = len(entity_masks)
batch_size = entity_masks[0].shape[0]
@@ -355,7 +373,8 @@ class FluxDiT(torch.nn.Module):
return attention_mask
def process_entity_masks(self, hidden_states, prompt_emb, entity_prompt_emb, entity_masks, text_ids, image_ids, repeat_dim):
def process_entity_masks(self, hidden_states, prompt_emb, entity_prompt_emb, entity_masks, text_ids, image_ids):
repeat_dim = hidden_states.shape[1]
max_masks = 0
attention_mask = None
prompt_embs = [prompt_emb]
@@ -391,5 +410,330 @@ class FluxDiT(torch.nn.Module):
use_gradient_checkpointing=False,
**kwargs
):
# (Deprecated) The real forward is in `pipelines.flux_image`.
return None
if tiled:
return self.tiled_forward(
hidden_states,
timestep, prompt_emb, pooled_prompt_emb, guidance, text_ids,
tile_size=tile_size, tile_stride=tile_stride,
**kwargs
)
if image_ids is None:
image_ids = self.prepare_image_ids(hidden_states)
conditioning = self.time_embedder(timestep, hidden_states.dtype) + self.pooled_text_embedder(pooled_prompt_emb)
if self.guidance_embedder is not None:
guidance = guidance * 1000
conditioning = conditioning + self.guidance_embedder(guidance, hidden_states.dtype)
height, width = hidden_states.shape[-2:]
hidden_states = self.patchify(hidden_states)
hidden_states = self.x_embedder(hidden_states)
if entity_prompt_emb is not None and entity_masks is not None:
prompt_emb, image_rotary_emb, attention_mask = self.process_entity_masks(hidden_states, prompt_emb, entity_prompt_emb, entity_masks, text_ids, image_ids)
else:
prompt_emb = self.context_embedder(prompt_emb)
image_rotary_emb = self.pos_embedder(torch.cat((text_ids, image_ids), dim=1))
attention_mask = None
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs)
return custom_forward
for block in self.blocks:
if self.training and use_gradient_checkpointing:
hidden_states, prompt_emb = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
hidden_states, prompt_emb, conditioning, image_rotary_emb, attention_mask,
use_reentrant=False,
)
else:
hidden_states, prompt_emb = block(hidden_states, prompt_emb, conditioning, image_rotary_emb, attention_mask)
hidden_states = torch.cat([prompt_emb, hidden_states], dim=1)
for block in self.single_blocks:
if self.training and use_gradient_checkpointing:
hidden_states, prompt_emb = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
hidden_states, prompt_emb, conditioning, image_rotary_emb, attention_mask,
use_reentrant=False,
)
else:
hidden_states, prompt_emb = block(hidden_states, prompt_emb, conditioning, image_rotary_emb, attention_mask)
hidden_states = hidden_states[:, prompt_emb.shape[1]:]
hidden_states = self.final_norm_out(hidden_states, conditioning)
hidden_states = self.final_proj_out(hidden_states)
hidden_states = self.unpatchify(hidden_states, height, width)
return hidden_states
def quantize(self):
def cast_to(weight, dtype=None, device=None, copy=False):
if device is None or weight.device == device:
if not copy:
if dtype is None or weight.dtype == dtype:
return weight
return weight.to(dtype=dtype, copy=copy)
r = torch.empty_like(weight, dtype=dtype, device=device)
r.copy_(weight)
return r
def cast_weight(s, input=None, dtype=None, device=None):
if input is not None:
if dtype is None:
dtype = input.dtype
if device is None:
device = input.device
weight = cast_to(s.weight, dtype, device)
return weight
def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None):
if input is not None:
if dtype is None:
dtype = input.dtype
if bias_dtype is None:
bias_dtype = dtype
if device is None:
device = input.device
bias = None
weight = cast_to(s.weight, dtype, device)
bias = cast_to(s.bias, bias_dtype, device)
return weight, bias
class quantized_layer:
class Linear(torch.nn.Linear):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def forward(self,input,**kwargs):
weight,bias= cast_bias_weight(self,input)
return torch.nn.functional.linear(input,weight,bias)
class RMSNorm(torch.nn.Module):
def __init__(self, module):
super().__init__()
self.module = module
def forward(self,hidden_states,**kwargs):
weight= cast_weight(self.module,hidden_states)
input_dtype = hidden_states.dtype
variance = hidden_states.to(torch.float32).square().mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.module.eps)
hidden_states = hidden_states.to(input_dtype) * weight
return hidden_states
def replace_layer(model):
for name, module in model.named_children():
if isinstance(module, torch.nn.Linear):
with init_weights_on_device():
new_layer = quantized_layer.Linear(module.in_features,module.out_features)
new_layer.weight = module.weight
if module.bias is not None:
new_layer.bias = module.bias
# del module
setattr(model, name, new_layer)
elif isinstance(module, RMSNorm):
if hasattr(module,"quantized"):
continue
module.quantized= True
new_layer = quantized_layer.RMSNorm(module)
setattr(model, name, new_layer)
else:
replace_layer(module)
replace_layer(self)
@staticmethod
def state_dict_converter():
return FluxDiTStateDictConverter()
class FluxDiTStateDictConverter:
def __init__(self):
pass
def from_diffusers(self, state_dict):
global_rename_dict = {
"context_embedder": "context_embedder",
"x_embedder": "x_embedder",
"time_text_embed.timestep_embedder.linear_1": "time_embedder.timestep_embedder.0",
"time_text_embed.timestep_embedder.linear_2": "time_embedder.timestep_embedder.2",
"time_text_embed.guidance_embedder.linear_1": "guidance_embedder.timestep_embedder.0",
"time_text_embed.guidance_embedder.linear_2": "guidance_embedder.timestep_embedder.2",
"time_text_embed.text_embedder.linear_1": "pooled_text_embedder.0",
"time_text_embed.text_embedder.linear_2": "pooled_text_embedder.2",
"norm_out.linear": "final_norm_out.linear",
"proj_out": "final_proj_out",
}
rename_dict = {
"proj_out": "proj_out",
"norm1.linear": "norm1_a.linear",
"norm1_context.linear": "norm1_b.linear",
"attn.to_q": "attn.a_to_q",
"attn.to_k": "attn.a_to_k",
"attn.to_v": "attn.a_to_v",
"attn.to_out.0": "attn.a_to_out",
"attn.add_q_proj": "attn.b_to_q",
"attn.add_k_proj": "attn.b_to_k",
"attn.add_v_proj": "attn.b_to_v",
"attn.to_add_out": "attn.b_to_out",
"ff.net.0.proj": "ff_a.0",
"ff.net.2": "ff_a.2",
"ff_context.net.0.proj": "ff_b.0",
"ff_context.net.2": "ff_b.2",
"attn.norm_q": "attn.norm_q_a",
"attn.norm_k": "attn.norm_k_a",
"attn.norm_added_q": "attn.norm_q_b",
"attn.norm_added_k": "attn.norm_k_b",
}
rename_dict_single = {
"attn.to_q": "a_to_q",
"attn.to_k": "a_to_k",
"attn.to_v": "a_to_v",
"attn.norm_q": "norm_q_a",
"attn.norm_k": "norm_k_a",
"norm.linear": "norm.linear",
"proj_mlp": "proj_in_besides_attn",
"proj_out": "proj_out",
}
state_dict_ = {}
for name, param in state_dict.items():
if name.endswith(".weight") or name.endswith(".bias"):
suffix = ".weight" if name.endswith(".weight") else ".bias"
prefix = name[:-len(suffix)]
if prefix in global_rename_dict:
state_dict_[global_rename_dict[prefix] + suffix] = param
elif prefix.startswith("transformer_blocks."):
names = prefix.split(".")
names[0] = "blocks"
middle = ".".join(names[2:])
if middle in rename_dict:
name_ = ".".join(names[:2] + [rename_dict[middle]] + [suffix[1:]])
state_dict_[name_] = param
elif prefix.startswith("single_transformer_blocks."):
names = prefix.split(".")
names[0] = "single_blocks"
middle = ".".join(names[2:])
if middle in rename_dict_single:
name_ = ".".join(names[:2] + [rename_dict_single[middle]] + [suffix[1:]])
state_dict_[name_] = param
else:
pass
else:
pass
for name in list(state_dict_.keys()):
if ".proj_in_besides_attn." in name:
name_ = name.replace(".proj_in_besides_attn.", ".to_qkv_mlp.")
param = torch.concat([
state_dict_[name.replace(".proj_in_besides_attn.", f".a_to_q.")],
state_dict_[name.replace(".proj_in_besides_attn.", f".a_to_k.")],
state_dict_[name.replace(".proj_in_besides_attn.", f".a_to_v.")],
state_dict_[name],
], dim=0)
state_dict_[name_] = param
state_dict_.pop(name.replace(".proj_in_besides_attn.", f".a_to_q."))
state_dict_.pop(name.replace(".proj_in_besides_attn.", f".a_to_k."))
state_dict_.pop(name.replace(".proj_in_besides_attn.", f".a_to_v."))
state_dict_.pop(name)
for name in list(state_dict_.keys()):
for component in ["a", "b"]:
if f".{component}_to_q." in name:
name_ = name.replace(f".{component}_to_q.", f".{component}_to_qkv.")
param = torch.concat([
state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_q.")],
state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_k.")],
state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_v.")],
], dim=0)
state_dict_[name_] = param
state_dict_.pop(name.replace(f".{component}_to_q.", f".{component}_to_q."))
state_dict_.pop(name.replace(f".{component}_to_q.", f".{component}_to_k."))
state_dict_.pop(name.replace(f".{component}_to_q.", f".{component}_to_v."))
return state_dict_
def from_civitai(self, state_dict):
rename_dict = {
"time_in.in_layer.bias": "time_embedder.timestep_embedder.0.bias",
"time_in.in_layer.weight": "time_embedder.timestep_embedder.0.weight",
"time_in.out_layer.bias": "time_embedder.timestep_embedder.2.bias",
"time_in.out_layer.weight": "time_embedder.timestep_embedder.2.weight",
"txt_in.bias": "context_embedder.bias",
"txt_in.weight": "context_embedder.weight",
"vector_in.in_layer.bias": "pooled_text_embedder.0.bias",
"vector_in.in_layer.weight": "pooled_text_embedder.0.weight",
"vector_in.out_layer.bias": "pooled_text_embedder.2.bias",
"vector_in.out_layer.weight": "pooled_text_embedder.2.weight",
"final_layer.linear.bias": "final_proj_out.bias",
"final_layer.linear.weight": "final_proj_out.weight",
"guidance_in.in_layer.bias": "guidance_embedder.timestep_embedder.0.bias",
"guidance_in.in_layer.weight": "guidance_embedder.timestep_embedder.0.weight",
"guidance_in.out_layer.bias": "guidance_embedder.timestep_embedder.2.bias",
"guidance_in.out_layer.weight": "guidance_embedder.timestep_embedder.2.weight",
"img_in.bias": "x_embedder.bias",
"img_in.weight": "x_embedder.weight",
"final_layer.adaLN_modulation.1.weight": "final_norm_out.linear.weight",
"final_layer.adaLN_modulation.1.bias": "final_norm_out.linear.bias",
}
suffix_rename_dict = {
"img_attn.norm.key_norm.scale": "attn.norm_k_a.weight",
"img_attn.norm.query_norm.scale": "attn.norm_q_a.weight",
"img_attn.proj.bias": "attn.a_to_out.bias",
"img_attn.proj.weight": "attn.a_to_out.weight",
"img_attn.qkv.bias": "attn.a_to_qkv.bias",
"img_attn.qkv.weight": "attn.a_to_qkv.weight",
"img_mlp.0.bias": "ff_a.0.bias",
"img_mlp.0.weight": "ff_a.0.weight",
"img_mlp.2.bias": "ff_a.2.bias",
"img_mlp.2.weight": "ff_a.2.weight",
"img_mod.lin.bias": "norm1_a.linear.bias",
"img_mod.lin.weight": "norm1_a.linear.weight",
"txt_attn.norm.key_norm.scale": "attn.norm_k_b.weight",
"txt_attn.norm.query_norm.scale": "attn.norm_q_b.weight",
"txt_attn.proj.bias": "attn.b_to_out.bias",
"txt_attn.proj.weight": "attn.b_to_out.weight",
"txt_attn.qkv.bias": "attn.b_to_qkv.bias",
"txt_attn.qkv.weight": "attn.b_to_qkv.weight",
"txt_mlp.0.bias": "ff_b.0.bias",
"txt_mlp.0.weight": "ff_b.0.weight",
"txt_mlp.2.bias": "ff_b.2.bias",
"txt_mlp.2.weight": "ff_b.2.weight",
"txt_mod.lin.bias": "norm1_b.linear.bias",
"txt_mod.lin.weight": "norm1_b.linear.weight",
"linear1.bias": "to_qkv_mlp.bias",
"linear1.weight": "to_qkv_mlp.weight",
"linear2.bias": "proj_out.bias",
"linear2.weight": "proj_out.weight",
"modulation.lin.bias": "norm.linear.bias",
"modulation.lin.weight": "norm.linear.weight",
"norm.key_norm.scale": "norm_k_a.weight",
"norm.query_norm.scale": "norm_q_a.weight",
}
state_dict_ = {}
for name, param in state_dict.items():
if name.startswith("model.diffusion_model."):
name = name[len("model.diffusion_model."):]
names = name.split(".")
if name in rename_dict:
rename = rename_dict[name]
if name.startswith("final_layer.adaLN_modulation.1."):
param = torch.concat([param[3072:], param[:3072]], dim=0)
state_dict_[rename] = param
elif names[0] == "double_blocks":
rename = f"blocks.{names[1]}." + suffix_rename_dict[".".join(names[2:])]
state_dict_[rename] = param
elif names[0] == "single_blocks":
if ".".join(names[2:]) in suffix_rename_dict:
rename = f"single_blocks.{names[1]}." + suffix_rename_dict[".".join(names[2:])]
state_dict_[rename] = param
else:
pass
if "guidance_embedder.timestep_embedder.0.weight" not in state_dict_:
return state_dict_, {"disable_guidance_embedder": True}
else:
return state_dict_

View File

@@ -1,129 +0,0 @@
import math
import torch
import torch.nn as nn
# FFN
def FeedForward(dim, mult=4):
inner_dim = int(dim * mult)
return nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, inner_dim, bias=False),
nn.GELU(),
nn.Linear(inner_dim, dim, bias=False),
)
def reshape_tensor(x, heads):
bs, length, width = x.shape
#(bs, length, width) --> (bs, length, n_heads, dim_per_head)
x = x.view(bs, length, heads, -1)
# (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head)
x = x.transpose(1, 2)
# (bs, n_heads, length, dim_per_head) --> (bs*n_heads, length, dim_per_head)
x = x.reshape(bs, heads, length, -1)
return x
class PerceiverAttention(nn.Module):
def __init__(self, *, dim, dim_head=64, heads=8):
super().__init__()
self.scale = dim_head**-0.5
self.dim_head = dim_head
self.heads = heads
inner_dim = dim_head * heads
self.norm1 = nn.LayerNorm(dim)
self.norm2 = nn.LayerNorm(dim)
self.to_q = nn.Linear(dim, inner_dim, bias=False)
self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
self.to_out = nn.Linear(inner_dim, dim, bias=False)
def forward(self, x, latents):
"""
Args:
x (torch.Tensor): image features
shape (b, n1, D)
latent (torch.Tensor): latent features
shape (b, n2, D)
"""
x = self.norm1(x)
latents = self.norm2(latents)
b, l, _ = latents.shape
q = self.to_q(latents)
kv_input = torch.cat((x, latents), dim=-2)
k, v = self.to_kv(kv_input).chunk(2, dim=-1)
q = reshape_tensor(q, self.heads)
k = reshape_tensor(k, self.heads)
v = reshape_tensor(v, self.heads)
# attention
scale = 1 / math.sqrt(math.sqrt(self.dim_head))
weight = (q * scale) @ (k * scale).transpose(-2, -1) # More stable with f16 than dividing afterwards
weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
out = weight @ v
out = out.permute(0, 2, 1, 3).reshape(b, l, -1)
return self.to_out(out)
class InfiniteYouImageProjector(nn.Module):
def __init__(
self,
dim=1280,
depth=4,
dim_head=64,
heads=20,
num_queries=8,
embedding_dim=512,
output_dim=4096,
ff_mult=4,
):
super().__init__()
self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim**0.5)
self.proj_in = nn.Linear(embedding_dim, dim)
self.proj_out = nn.Linear(dim, output_dim)
self.norm_out = nn.LayerNorm(output_dim)
self.layers = nn.ModuleList([])
for _ in range(depth):
self.layers.append(
nn.ModuleList([
PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
FeedForward(dim=dim, mult=ff_mult),
]))
def forward(self, x):
latents = self.latents.repeat(x.size(0), 1, 1)
latents = latents.to(dtype=x.dtype, device=x.device)
x = self.proj_in(x)
for attn, ff in self.layers:
latents = attn(x, latents) + latents
latents = ff(latents) + latents
latents = self.proj_out(latents)
return self.norm_out(latents)
@staticmethod
def state_dict_converter():
return FluxInfiniteYouImageProjectorStateDictConverter()
class FluxInfiniteYouImageProjectorStateDictConverter:
def __init__(self):
pass
def from_diffusers(self, state_dict):
return state_dict['image_proj']

View File

@@ -1,25 +1,9 @@
from .general_modules import RMSNorm
from transformers import SiglipVisionModel, SiglipVisionConfig
from .svd_image_encoder import SVDImageEncoder
from .sd3_dit import RMSNorm
from transformers import CLIPImageProcessor
import torch
class SiglipVisionModelSO400M(SiglipVisionModel):
def __init__(self):
config = SiglipVisionConfig(
hidden_size=1152,
image_size=384,
intermediate_size=4304,
model_type="siglip_vision_model",
num_attention_heads=16,
num_hidden_layers=27,
patch_size=14,
architectures=["SiglipModel"],
initializer_factor=1.0,
torch_dtype="float32",
transformers_version="4.37.0.dev0"
)
super().__init__(config)
class MLPProjModel(torch.nn.Module):
def __init__(self, cross_attention_dim=768, id_embeddings_dim=512, num_tokens=4):
super().__init__()

View File

@@ -1,521 +0,0 @@
import torch
from einops import rearrange
def low_version_attention(query, key, value, attn_bias=None):
scale = 1 / query.shape[-1] ** 0.5
query = query * scale
attn = torch.matmul(query, key.transpose(-2, -1))
if attn_bias is not None:
attn = attn + attn_bias
attn = attn.softmax(-1)
return attn @ value
class Attention(torch.nn.Module):
def __init__(self, q_dim, num_heads, head_dim, kv_dim=None, bias_q=False, bias_kv=False, bias_out=False):
super().__init__()
dim_inner = head_dim * num_heads
kv_dim = kv_dim if kv_dim is not None else q_dim
self.num_heads = num_heads
self.head_dim = head_dim
self.to_q = torch.nn.Linear(q_dim, dim_inner, bias=bias_q)
self.to_k = torch.nn.Linear(kv_dim, dim_inner, bias=bias_kv)
self.to_v = torch.nn.Linear(kv_dim, dim_inner, bias=bias_kv)
self.to_out = torch.nn.Linear(dim_inner, q_dim, bias=bias_out)
def interact_with_ipadapter(self, hidden_states, q, ip_k, ip_v, scale=1.0):
batch_size = q.shape[0]
ip_k = ip_k.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
ip_v = ip_v.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
ip_hidden_states = torch.nn.functional.scaled_dot_product_attention(q, ip_k, ip_v)
hidden_states = hidden_states + scale * ip_hidden_states
return hidden_states
def torch_forward(self, hidden_states, encoder_hidden_states=None, attn_mask=None, ipadapter_kwargs=None, qkv_preprocessor=None):
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
batch_size = encoder_hidden_states.shape[0]
q = self.to_q(hidden_states)
k = self.to_k(encoder_hidden_states)
v = self.to_v(encoder_hidden_states)
q = q.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
k = k.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
v = v.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
if qkv_preprocessor is not None:
q, k, v = qkv_preprocessor(q, k, v)
hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
if ipadapter_kwargs is not None:
hidden_states = self.interact_with_ipadapter(hidden_states, q, **ipadapter_kwargs)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim)
hidden_states = hidden_states.to(q.dtype)
hidden_states = self.to_out(hidden_states)
return hidden_states
def xformers_forward(self, hidden_states, encoder_hidden_states=None, attn_mask=None):
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
q = self.to_q(hidden_states)
k = self.to_k(encoder_hidden_states)
v = self.to_v(encoder_hidden_states)
q = rearrange(q, "b f (n d) -> (b n) f d", n=self.num_heads)
k = rearrange(k, "b f (n d) -> (b n) f d", n=self.num_heads)
v = rearrange(v, "b f (n d) -> (b n) f d", n=self.num_heads)
if attn_mask is not None:
hidden_states = low_version_attention(q, k, v, attn_bias=attn_mask)
else:
import xformers.ops as xops
hidden_states = xops.memory_efficient_attention(q, k, v)
hidden_states = rearrange(hidden_states, "(b n) f d -> b f (n d)", n=self.num_heads)
hidden_states = hidden_states.to(q.dtype)
hidden_states = self.to_out(hidden_states)
return hidden_states
def forward(self, hidden_states, encoder_hidden_states=None, attn_mask=None, ipadapter_kwargs=None, qkv_preprocessor=None):
return self.torch_forward(hidden_states, encoder_hidden_states=encoder_hidden_states, attn_mask=attn_mask, ipadapter_kwargs=ipadapter_kwargs, qkv_preprocessor=qkv_preprocessor)
class CLIPEncoderLayer(torch.nn.Module):
def __init__(self, embed_dim, intermediate_size, num_heads=12, head_dim=64, use_quick_gelu=True):
super().__init__()
self.attn = Attention(q_dim=embed_dim, num_heads=num_heads, head_dim=head_dim, bias_q=True, bias_kv=True, bias_out=True)
self.layer_norm1 = torch.nn.LayerNorm(embed_dim)
self.layer_norm2 = torch.nn.LayerNorm(embed_dim)
self.fc1 = torch.nn.Linear(embed_dim, intermediate_size)
self.fc2 = torch.nn.Linear(intermediate_size, embed_dim)
self.use_quick_gelu = use_quick_gelu
def quickGELU(self, x):
return x * torch.sigmoid(1.702 * x)
def forward(self, hidden_states, attn_mask=None):
residual = hidden_states
hidden_states = self.layer_norm1(hidden_states)
hidden_states = self.attn(hidden_states, attn_mask=attn_mask)
hidden_states = residual + hidden_states
residual = hidden_states
hidden_states = self.layer_norm2(hidden_states)
hidden_states = self.fc1(hidden_states)
if self.use_quick_gelu:
hidden_states = self.quickGELU(hidden_states)
else:
hidden_states = torch.nn.functional.gelu(hidden_states)
hidden_states = self.fc2(hidden_states)
hidden_states = residual + hidden_states
return hidden_states
class SDTextEncoder(torch.nn.Module):
def __init__(self, embed_dim=768, vocab_size=49408, max_position_embeddings=77, num_encoder_layers=12, encoder_intermediate_size=3072):
super().__init__()
# token_embedding
self.token_embedding = torch.nn.Embedding(vocab_size, embed_dim)
# position_embeds (This is a fixed tensor)
self.position_embeds = torch.nn.Parameter(torch.zeros(1, max_position_embeddings, embed_dim))
# encoders
self.encoders = torch.nn.ModuleList([CLIPEncoderLayer(embed_dim, encoder_intermediate_size) for _ in range(num_encoder_layers)])
# attn_mask
self.attn_mask = self.attention_mask(max_position_embeddings)
# final_layer_norm
self.final_layer_norm = torch.nn.LayerNorm(embed_dim)
def attention_mask(self, length):
mask = torch.empty(length, length)
mask.fill_(float("-inf"))
mask.triu_(1)
return mask
def forward(self, input_ids, clip_skip=1):
embeds = self.token_embedding(input_ids) + self.position_embeds
attn_mask = self.attn_mask.to(device=embeds.device, dtype=embeds.dtype)
for encoder_id, encoder in enumerate(self.encoders):
embeds = encoder(embeds, attn_mask=attn_mask)
if encoder_id + clip_skip == len(self.encoders):
break
embeds = self.final_layer_norm(embeds)
return embeds
@staticmethod
def state_dict_converter():
return SDTextEncoderStateDictConverter()
class SDTextEncoderStateDictConverter:
def __init__(self):
pass
def from_diffusers(self, state_dict):
rename_dict = {
"text_model.embeddings.token_embedding.weight": "token_embedding.weight",
"text_model.embeddings.position_embedding.weight": "position_embeds",
"text_model.final_layer_norm.weight": "final_layer_norm.weight",
"text_model.final_layer_norm.bias": "final_layer_norm.bias"
}
attn_rename_dict = {
"self_attn.q_proj": "attn.to_q",
"self_attn.k_proj": "attn.to_k",
"self_attn.v_proj": "attn.to_v",
"self_attn.out_proj": "attn.to_out",
"layer_norm1": "layer_norm1",
"layer_norm2": "layer_norm2",
"mlp.fc1": "fc1",
"mlp.fc2": "fc2",
}
state_dict_ = {}
for name in state_dict:
if name in rename_dict:
param = state_dict[name]
if name == "text_model.embeddings.position_embedding.weight":
param = param.reshape((1, param.shape[0], param.shape[1]))
state_dict_[rename_dict[name]] = param
elif name.startswith("text_model.encoder.layers."):
param = state_dict[name]
names = name.split(".")
layer_id, layer_type, tail = names[3], ".".join(names[4:-1]), names[-1]
name_ = ".".join(["encoders", layer_id, attn_rename_dict[layer_type], tail])
state_dict_[name_] = param
return state_dict_
def from_civitai(self, state_dict):
rename_dict = {
"cond_stage_model.transformer.text_model.embeddings.token_embedding.weight": "token_embedding.weight",
"cond_stage_model.transformer.text_model.encoder.layers.0.layer_norm1.bias": "encoders.0.layer_norm1.bias",
"cond_stage_model.transformer.text_model.encoder.layers.0.layer_norm1.weight": "encoders.0.layer_norm1.weight",
"cond_stage_model.transformer.text_model.encoder.layers.0.layer_norm2.bias": "encoders.0.layer_norm2.bias",
"cond_stage_model.transformer.text_model.encoder.layers.0.layer_norm2.weight": "encoders.0.layer_norm2.weight",
"cond_stage_model.transformer.text_model.encoder.layers.0.mlp.fc1.bias": "encoders.0.fc1.bias",
"cond_stage_model.transformer.text_model.encoder.layers.0.mlp.fc1.weight": "encoders.0.fc1.weight",
"cond_stage_model.transformer.text_model.encoder.layers.0.mlp.fc2.bias": "encoders.0.fc2.bias",
"cond_stage_model.transformer.text_model.encoder.layers.0.mlp.fc2.weight": "encoders.0.fc2.weight",
"cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.k_proj.bias": "encoders.0.attn.to_k.bias",
"cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.k_proj.weight": "encoders.0.attn.to_k.weight",
"cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.out_proj.bias": "encoders.0.attn.to_out.bias",
"cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.out_proj.weight": "encoders.0.attn.to_out.weight",
"cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.q_proj.bias": "encoders.0.attn.to_q.bias",
"cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.q_proj.weight": "encoders.0.attn.to_q.weight",
"cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.v_proj.bias": "encoders.0.attn.to_v.bias",
"cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.v_proj.weight": "encoders.0.attn.to_v.weight",
"cond_stage_model.transformer.text_model.encoder.layers.1.layer_norm1.bias": "encoders.1.layer_norm1.bias",
"cond_stage_model.transformer.text_model.encoder.layers.1.layer_norm1.weight": "encoders.1.layer_norm1.weight",
"cond_stage_model.transformer.text_model.encoder.layers.1.layer_norm2.bias": "encoders.1.layer_norm2.bias",
"cond_stage_model.transformer.text_model.encoder.layers.1.layer_norm2.weight": "encoders.1.layer_norm2.weight",
"cond_stage_model.transformer.text_model.encoder.layers.1.mlp.fc1.bias": "encoders.1.fc1.bias",
"cond_stage_model.transformer.text_model.encoder.layers.1.mlp.fc1.weight": "encoders.1.fc1.weight",
"cond_stage_model.transformer.text_model.encoder.layers.1.mlp.fc2.bias": "encoders.1.fc2.bias",
"cond_stage_model.transformer.text_model.encoder.layers.1.mlp.fc2.weight": "encoders.1.fc2.weight",
"cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.k_proj.bias": "encoders.1.attn.to_k.bias",
"cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.k_proj.weight": "encoders.1.attn.to_k.weight",
"cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.out_proj.bias": "encoders.1.attn.to_out.bias",
"cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.out_proj.weight": "encoders.1.attn.to_out.weight",
"cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.q_proj.bias": "encoders.1.attn.to_q.bias",
"cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.q_proj.weight": "encoders.1.attn.to_q.weight",
"cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.v_proj.bias": "encoders.1.attn.to_v.bias",
"cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.v_proj.weight": "encoders.1.attn.to_v.weight",
"cond_stage_model.transformer.text_model.encoder.layers.10.layer_norm1.bias": "encoders.10.layer_norm1.bias",
"cond_stage_model.transformer.text_model.encoder.layers.10.layer_norm1.weight": "encoders.10.layer_norm1.weight",
"cond_stage_model.transformer.text_model.encoder.layers.10.layer_norm2.bias": "encoders.10.layer_norm2.bias",
"cond_stage_model.transformer.text_model.encoder.layers.10.layer_norm2.weight": "encoders.10.layer_norm2.weight",
"cond_stage_model.transformer.text_model.encoder.layers.10.mlp.fc1.bias": "encoders.10.fc1.bias",
"cond_stage_model.transformer.text_model.encoder.layers.10.mlp.fc1.weight": "encoders.10.fc1.weight",
"cond_stage_model.transformer.text_model.encoder.layers.10.mlp.fc2.bias": "encoders.10.fc2.bias",
"cond_stage_model.transformer.text_model.encoder.layers.10.mlp.fc2.weight": "encoders.10.fc2.weight",
"cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.k_proj.bias": "encoders.10.attn.to_k.bias",
"cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.k_proj.weight": "encoders.10.attn.to_k.weight",
"cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.out_proj.bias": "encoders.10.attn.to_out.bias",
"cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.out_proj.weight": "encoders.10.attn.to_out.weight",
"cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.q_proj.bias": "encoders.10.attn.to_q.bias",
"cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.q_proj.weight": "encoders.10.attn.to_q.weight",
"cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.v_proj.bias": "encoders.10.attn.to_v.bias",
"cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.v_proj.weight": "encoders.10.attn.to_v.weight",
"cond_stage_model.transformer.text_model.encoder.layers.11.layer_norm1.bias": "encoders.11.layer_norm1.bias",
"cond_stage_model.transformer.text_model.encoder.layers.11.layer_norm1.weight": "encoders.11.layer_norm1.weight",
"cond_stage_model.transformer.text_model.encoder.layers.11.layer_norm2.bias": "encoders.11.layer_norm2.bias",
"cond_stage_model.transformer.text_model.encoder.layers.11.layer_norm2.weight": "encoders.11.layer_norm2.weight",
"cond_stage_model.transformer.text_model.encoder.layers.11.mlp.fc1.bias": "encoders.11.fc1.bias",
"cond_stage_model.transformer.text_model.encoder.layers.11.mlp.fc1.weight": "encoders.11.fc1.weight",
"cond_stage_model.transformer.text_model.encoder.layers.11.mlp.fc2.bias": "encoders.11.fc2.bias",
"cond_stage_model.transformer.text_model.encoder.layers.11.mlp.fc2.weight": "encoders.11.fc2.weight",
"cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.k_proj.bias": "encoders.11.attn.to_k.bias",
"cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.k_proj.weight": "encoders.11.attn.to_k.weight",
"cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.out_proj.bias": "encoders.11.attn.to_out.bias",
"cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.out_proj.weight": "encoders.11.attn.to_out.weight",
"cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.q_proj.bias": "encoders.11.attn.to_q.bias",
"cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.q_proj.weight": "encoders.11.attn.to_q.weight",
"cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.v_proj.bias": "encoders.11.attn.to_v.bias",
"cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.v_proj.weight": "encoders.11.attn.to_v.weight",
"cond_stage_model.transformer.text_model.encoder.layers.2.layer_norm1.bias": "encoders.2.layer_norm1.bias",
"cond_stage_model.transformer.text_model.encoder.layers.2.layer_norm1.weight": "encoders.2.layer_norm1.weight",
"cond_stage_model.transformer.text_model.encoder.layers.2.layer_norm2.bias": "encoders.2.layer_norm2.bias",
"cond_stage_model.transformer.text_model.encoder.layers.2.layer_norm2.weight": "encoders.2.layer_norm2.weight",
"cond_stage_model.transformer.text_model.encoder.layers.2.mlp.fc1.bias": "encoders.2.fc1.bias",
"cond_stage_model.transformer.text_model.encoder.layers.2.mlp.fc1.weight": "encoders.2.fc1.weight",
"cond_stage_model.transformer.text_model.encoder.layers.2.mlp.fc2.bias": "encoders.2.fc2.bias",
"cond_stage_model.transformer.text_model.encoder.layers.2.mlp.fc2.weight": "encoders.2.fc2.weight",
"cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.k_proj.bias": "encoders.2.attn.to_k.bias",
"cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.k_proj.weight": "encoders.2.attn.to_k.weight",
"cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.out_proj.bias": "encoders.2.attn.to_out.bias",
"cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.out_proj.weight": "encoders.2.attn.to_out.weight",
"cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.q_proj.bias": "encoders.2.attn.to_q.bias",
"cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.q_proj.weight": "encoders.2.attn.to_q.weight",
"cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.v_proj.bias": "encoders.2.attn.to_v.bias",
"cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.v_proj.weight": "encoders.2.attn.to_v.weight",
"cond_stage_model.transformer.text_model.encoder.layers.3.layer_norm1.bias": "encoders.3.layer_norm1.bias",
"cond_stage_model.transformer.text_model.encoder.layers.3.layer_norm1.weight": "encoders.3.layer_norm1.weight",
"cond_stage_model.transformer.text_model.encoder.layers.3.layer_norm2.bias": "encoders.3.layer_norm2.bias",
"cond_stage_model.transformer.text_model.encoder.layers.3.layer_norm2.weight": "encoders.3.layer_norm2.weight",
"cond_stage_model.transformer.text_model.encoder.layers.3.mlp.fc1.bias": "encoders.3.fc1.bias",
"cond_stage_model.transformer.text_model.encoder.layers.3.mlp.fc1.weight": "encoders.3.fc1.weight",
"cond_stage_model.transformer.text_model.encoder.layers.3.mlp.fc2.bias": "encoders.3.fc2.bias",
"cond_stage_model.transformer.text_model.encoder.layers.3.mlp.fc2.weight": "encoders.3.fc2.weight",
"cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.k_proj.bias": "encoders.3.attn.to_k.bias",
"cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.k_proj.weight": "encoders.3.attn.to_k.weight",
"cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.out_proj.bias": "encoders.3.attn.to_out.bias",
"cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.out_proj.weight": "encoders.3.attn.to_out.weight",
"cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.q_proj.bias": "encoders.3.attn.to_q.bias",
"cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.q_proj.weight": "encoders.3.attn.to_q.weight",
"cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.v_proj.bias": "encoders.3.attn.to_v.bias",
"cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.v_proj.weight": "encoders.3.attn.to_v.weight",
"cond_stage_model.transformer.text_model.encoder.layers.4.layer_norm1.bias": "encoders.4.layer_norm1.bias",
"cond_stage_model.transformer.text_model.encoder.layers.4.layer_norm1.weight": "encoders.4.layer_norm1.weight",
"cond_stage_model.transformer.text_model.encoder.layers.4.layer_norm2.bias": "encoders.4.layer_norm2.bias",
"cond_stage_model.transformer.text_model.encoder.layers.4.layer_norm2.weight": "encoders.4.layer_norm2.weight",
"cond_stage_model.transformer.text_model.encoder.layers.4.mlp.fc1.bias": "encoders.4.fc1.bias",
"cond_stage_model.transformer.text_model.encoder.layers.4.mlp.fc1.weight": "encoders.4.fc1.weight",
"cond_stage_model.transformer.text_model.encoder.layers.4.mlp.fc2.bias": "encoders.4.fc2.bias",
"cond_stage_model.transformer.text_model.encoder.layers.4.mlp.fc2.weight": "encoders.4.fc2.weight",
"cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.k_proj.bias": "encoders.4.attn.to_k.bias",
"cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.k_proj.weight": "encoders.4.attn.to_k.weight",
"cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.out_proj.bias": "encoders.4.attn.to_out.bias",
"cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.out_proj.weight": "encoders.4.attn.to_out.weight",
"cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.q_proj.bias": "encoders.4.attn.to_q.bias",
"cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.q_proj.weight": "encoders.4.attn.to_q.weight",
"cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.v_proj.bias": "encoders.4.attn.to_v.bias",
"cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.v_proj.weight": "encoders.4.attn.to_v.weight",
"cond_stage_model.transformer.text_model.encoder.layers.5.layer_norm1.bias": "encoders.5.layer_norm1.bias",
"cond_stage_model.transformer.text_model.encoder.layers.5.layer_norm1.weight": "encoders.5.layer_norm1.weight",
"cond_stage_model.transformer.text_model.encoder.layers.5.layer_norm2.bias": "encoders.5.layer_norm2.bias",
"cond_stage_model.transformer.text_model.encoder.layers.5.layer_norm2.weight": "encoders.5.layer_norm2.weight",
"cond_stage_model.transformer.text_model.encoder.layers.5.mlp.fc1.bias": "encoders.5.fc1.bias",
"cond_stage_model.transformer.text_model.encoder.layers.5.mlp.fc1.weight": "encoders.5.fc1.weight",
"cond_stage_model.transformer.text_model.encoder.layers.5.mlp.fc2.bias": "encoders.5.fc2.bias",
"cond_stage_model.transformer.text_model.encoder.layers.5.mlp.fc2.weight": "encoders.5.fc2.weight",
"cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.k_proj.bias": "encoders.5.attn.to_k.bias",
"cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.k_proj.weight": "encoders.5.attn.to_k.weight",
"cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.out_proj.bias": "encoders.5.attn.to_out.bias",
"cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.out_proj.weight": "encoders.5.attn.to_out.weight",
"cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.q_proj.bias": "encoders.5.attn.to_q.bias",
"cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.q_proj.weight": "encoders.5.attn.to_q.weight",
"cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.v_proj.bias": "encoders.5.attn.to_v.bias",
"cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.v_proj.weight": "encoders.5.attn.to_v.weight",
"cond_stage_model.transformer.text_model.encoder.layers.6.layer_norm1.bias": "encoders.6.layer_norm1.bias",
"cond_stage_model.transformer.text_model.encoder.layers.6.layer_norm1.weight": "encoders.6.layer_norm1.weight",
"cond_stage_model.transformer.text_model.encoder.layers.6.layer_norm2.bias": "encoders.6.layer_norm2.bias",
"cond_stage_model.transformer.text_model.encoder.layers.6.layer_norm2.weight": "encoders.6.layer_norm2.weight",
"cond_stage_model.transformer.text_model.encoder.layers.6.mlp.fc1.bias": "encoders.6.fc1.bias",
"cond_stage_model.transformer.text_model.encoder.layers.6.mlp.fc1.weight": "encoders.6.fc1.weight",
"cond_stage_model.transformer.text_model.encoder.layers.6.mlp.fc2.bias": "encoders.6.fc2.bias",
"cond_stage_model.transformer.text_model.encoder.layers.6.mlp.fc2.weight": "encoders.6.fc2.weight",
"cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.k_proj.bias": "encoders.6.attn.to_k.bias",
"cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.k_proj.weight": "encoders.6.attn.to_k.weight",
"cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.out_proj.bias": "encoders.6.attn.to_out.bias",
"cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.out_proj.weight": "encoders.6.attn.to_out.weight",
"cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.q_proj.bias": "encoders.6.attn.to_q.bias",
"cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.q_proj.weight": "encoders.6.attn.to_q.weight",
"cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.v_proj.bias": "encoders.6.attn.to_v.bias",
"cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.v_proj.weight": "encoders.6.attn.to_v.weight",
"cond_stage_model.transformer.text_model.encoder.layers.7.layer_norm1.bias": "encoders.7.layer_norm1.bias",
"cond_stage_model.transformer.text_model.encoder.layers.7.layer_norm1.weight": "encoders.7.layer_norm1.weight",
"cond_stage_model.transformer.text_model.encoder.layers.7.layer_norm2.bias": "encoders.7.layer_norm2.bias",
"cond_stage_model.transformer.text_model.encoder.layers.7.layer_norm2.weight": "encoders.7.layer_norm2.weight",
"cond_stage_model.transformer.text_model.encoder.layers.7.mlp.fc1.bias": "encoders.7.fc1.bias",
"cond_stage_model.transformer.text_model.encoder.layers.7.mlp.fc1.weight": "encoders.7.fc1.weight",
"cond_stage_model.transformer.text_model.encoder.layers.7.mlp.fc2.bias": "encoders.7.fc2.bias",
"cond_stage_model.transformer.text_model.encoder.layers.7.mlp.fc2.weight": "encoders.7.fc2.weight",
"cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.k_proj.bias": "encoders.7.attn.to_k.bias",
"cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.k_proj.weight": "encoders.7.attn.to_k.weight",
"cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.out_proj.bias": "encoders.7.attn.to_out.bias",
"cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.out_proj.weight": "encoders.7.attn.to_out.weight",
"cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.q_proj.bias": "encoders.7.attn.to_q.bias",
"cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.q_proj.weight": "encoders.7.attn.to_q.weight",
"cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.v_proj.bias": "encoders.7.attn.to_v.bias",
"cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.v_proj.weight": "encoders.7.attn.to_v.weight",
"cond_stage_model.transformer.text_model.encoder.layers.8.layer_norm1.bias": "encoders.8.layer_norm1.bias",
"cond_stage_model.transformer.text_model.encoder.layers.8.layer_norm1.weight": "encoders.8.layer_norm1.weight",
"cond_stage_model.transformer.text_model.encoder.layers.8.layer_norm2.bias": "encoders.8.layer_norm2.bias",
"cond_stage_model.transformer.text_model.encoder.layers.8.layer_norm2.weight": "encoders.8.layer_norm2.weight",
"cond_stage_model.transformer.text_model.encoder.layers.8.mlp.fc1.bias": "encoders.8.fc1.bias",
"cond_stage_model.transformer.text_model.encoder.layers.8.mlp.fc1.weight": "encoders.8.fc1.weight",
"cond_stage_model.transformer.text_model.encoder.layers.8.mlp.fc2.bias": "encoders.8.fc2.bias",
"cond_stage_model.transformer.text_model.encoder.layers.8.mlp.fc2.weight": "encoders.8.fc2.weight",
"cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.k_proj.bias": "encoders.8.attn.to_k.bias",
"cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.k_proj.weight": "encoders.8.attn.to_k.weight",
"cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.out_proj.bias": "encoders.8.attn.to_out.bias",
"cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.out_proj.weight": "encoders.8.attn.to_out.weight",
"cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.q_proj.bias": "encoders.8.attn.to_q.bias",
"cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.q_proj.weight": "encoders.8.attn.to_q.weight",
"cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.v_proj.bias": "encoders.8.attn.to_v.bias",
"cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.v_proj.weight": "encoders.8.attn.to_v.weight",
"cond_stage_model.transformer.text_model.encoder.layers.9.layer_norm1.bias": "encoders.9.layer_norm1.bias",
"cond_stage_model.transformer.text_model.encoder.layers.9.layer_norm1.weight": "encoders.9.layer_norm1.weight",
"cond_stage_model.transformer.text_model.encoder.layers.9.layer_norm2.bias": "encoders.9.layer_norm2.bias",
"cond_stage_model.transformer.text_model.encoder.layers.9.layer_norm2.weight": "encoders.9.layer_norm2.weight",
"cond_stage_model.transformer.text_model.encoder.layers.9.mlp.fc1.bias": "encoders.9.fc1.bias",
"cond_stage_model.transformer.text_model.encoder.layers.9.mlp.fc1.weight": "encoders.9.fc1.weight",
"cond_stage_model.transformer.text_model.encoder.layers.9.mlp.fc2.bias": "encoders.9.fc2.bias",
"cond_stage_model.transformer.text_model.encoder.layers.9.mlp.fc2.weight": "encoders.9.fc2.weight",
"cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.k_proj.bias": "encoders.9.attn.to_k.bias",
"cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.k_proj.weight": "encoders.9.attn.to_k.weight",
"cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.out_proj.bias": "encoders.9.attn.to_out.bias",
"cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.out_proj.weight": "encoders.9.attn.to_out.weight",
"cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.q_proj.bias": "encoders.9.attn.to_q.bias",
"cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.q_proj.weight": "encoders.9.attn.to_q.weight",
"cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.v_proj.bias": "encoders.9.attn.to_v.bias",
"cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.v_proj.weight": "encoders.9.attn.to_v.weight",
"cond_stage_model.transformer.text_model.final_layer_norm.bias": "final_layer_norm.bias",
"cond_stage_model.transformer.text_model.final_layer_norm.weight": "final_layer_norm.weight",
"cond_stage_model.transformer.text_model.embeddings.position_embedding.weight": "position_embeds"
}
state_dict_ = {}
for name in state_dict:
if name in rename_dict:
param = state_dict[name]
if name == "cond_stage_model.transformer.text_model.embeddings.position_embedding.weight":
param = param.reshape((1, param.shape[0], param.shape[1]))
state_dict_[rename_dict[name]] = param
return state_dict_
class LoRALayerBlock(torch.nn.Module):
def __init__(self, L, dim_in, dim_out):
super().__init__()
self.x = torch.nn.Parameter(torch.randn(1, L, dim_in))
self.layer_norm = torch.nn.LayerNorm(dim_out)
def forward(self, lora_A, lora_B):
x = self.x @ lora_A.T @ lora_B.T
x = self.layer_norm(x)
return x
class LoRAEmbedder(torch.nn.Module):
def __init__(self, lora_patterns=None, L=1, out_dim=2048):
super().__init__()
if lora_patterns is None:
lora_patterns = self.default_lora_patterns()
model_dict = {}
for lora_pattern in lora_patterns:
name, dim = lora_pattern["name"], lora_pattern["dim"]
model_dict[name.replace(".", "___")] = LoRALayerBlock(L, dim[0], dim[1])
self.model_dict = torch.nn.ModuleDict(model_dict)
proj_dict = {}
for lora_pattern in lora_patterns:
layer_type, dim = lora_pattern["type"], lora_pattern["dim"]
if layer_type not in proj_dict:
proj_dict[layer_type.replace(".", "___")] = torch.nn.Linear(dim[1], out_dim)
self.proj_dict = torch.nn.ModuleDict(proj_dict)
self.lora_patterns = lora_patterns
def default_lora_patterns(self):
lora_patterns = []
lora_dict = {
"attn.a_to_qkv": (3072, 9216), "attn.a_to_out": (3072, 3072), "ff_a.0": (3072, 12288), "ff_a.2": (12288, 3072), "norm1_a.linear": (3072, 18432),
"attn.b_to_qkv": (3072, 9216), "attn.b_to_out": (3072, 3072), "ff_b.0": (3072, 12288), "ff_b.2": (12288, 3072), "norm1_b.linear": (3072, 18432),
}
for i in range(19):
for suffix in lora_dict:
lora_patterns.append({
"name": f"blocks.{i}.{suffix}",
"dim": lora_dict[suffix],
"type": suffix,
})
lora_dict = {"to_qkv_mlp": (3072, 21504), "proj_out": (15360, 3072), "norm.linear": (3072, 9216)}
for i in range(38):
for suffix in lora_dict:
lora_patterns.append({
"name": f"single_blocks.{i}.{suffix}",
"dim": lora_dict[suffix],
"type": suffix,
})
return lora_patterns
def forward(self, lora):
lora_emb = []
for lora_pattern in self.lora_patterns:
name, layer_type = lora_pattern["name"], lora_pattern["type"]
lora_A = lora[name + ".lora_A.weight"]
lora_B = lora[name + ".lora_B.weight"]
lora_out = self.model_dict[name.replace(".", "___")](lora_A, lora_B)
lora_out = self.proj_dict[layer_type.replace(".", "___")](lora_out)
lora_emb.append(lora_out)
lora_emb = torch.concat(lora_emb, dim=1)
return lora_emb
class FluxLoRAEncoder(torch.nn.Module):
def __init__(self, embed_dim=4096, encoder_intermediate_size=8192, num_encoder_layers=1, num_embeds_per_lora=16, num_special_embeds=1):
super().__init__()
self.num_embeds_per_lora = num_embeds_per_lora
# embedder
self.embedder = LoRAEmbedder(L=num_embeds_per_lora, out_dim=embed_dim)
# encoders
self.encoders = torch.nn.ModuleList([CLIPEncoderLayer(embed_dim, encoder_intermediate_size, num_heads=32, head_dim=128) for _ in range(num_encoder_layers)])
# special embedding
self.special_embeds = torch.nn.Parameter(torch.randn(1, num_special_embeds, embed_dim))
self.num_special_embeds = num_special_embeds
# final layer
self.final_layer_norm = torch.nn.LayerNorm(embed_dim)
self.final_linear = torch.nn.Linear(embed_dim, embed_dim)
def forward(self, lora):
lora_embeds = self.embedder(lora)
special_embeds = self.special_embeds.to(dtype=lora_embeds.dtype, device=lora_embeds.device)
embeds = torch.concat([special_embeds, lora_embeds], dim=1)
for encoder_id, encoder in enumerate(self.encoders):
embeds = encoder(embeds)
embeds = embeds[:, :self.num_special_embeds]
embeds = self.final_layer_norm(embeds)
embeds = self.final_linear(embeds)
return embeds
@staticmethod
def state_dict_converter():
return FluxLoRAEncoderStateDictConverter()
class FluxLoRAEncoderStateDictConverter:
def from_civitai(self, state_dict):
return state_dict

View File

@@ -1,306 +0,0 @@
import torch, math
from ..core.loader import load_state_dict
from typing import Union
class GeneralLoRALoader:
def __init__(self, device="cpu", torch_dtype=torch.float32):
self.device = device
self.torch_dtype = torch_dtype
def get_name_dict(self, lora_state_dict):
lora_name_dict = {}
for key in lora_state_dict:
if ".lora_B." not in key:
continue
keys = key.split(".")
if len(keys) > keys.index("lora_B") + 2:
keys.pop(keys.index("lora_B") + 1)
keys.pop(keys.index("lora_B"))
if keys[0] == "diffusion_model":
keys.pop(0)
keys.pop(-1)
target_name = ".".join(keys)
lora_name_dict[target_name] = (key, key.replace(".lora_B.", ".lora_A."))
return lora_name_dict
def load(self, model: torch.nn.Module, state_dict_lora, alpha=1.0):
updated_num = 0
lora_name_dict = self.get_name_dict(state_dict_lora)
for name, module in model.named_modules():
if name in lora_name_dict:
weight_up = state_dict_lora[lora_name_dict[name][0]].to(device=self.device, dtype=self.torch_dtype)
weight_down = state_dict_lora[lora_name_dict[name][1]].to(device=self.device, dtype=self.torch_dtype)
if len(weight_up.shape) == 4:
weight_up = weight_up.squeeze(3).squeeze(2)
weight_down = weight_down.squeeze(3).squeeze(2)
weight_lora = alpha * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3)
else:
weight_lora = alpha * torch.mm(weight_up, weight_down)
state_dict = module.state_dict()
state_dict["weight"] = state_dict["weight"].to(device=self.device, dtype=self.torch_dtype) + weight_lora
module.load_state_dict(state_dict)
updated_num += 1
print(f"{updated_num} tensors are updated by LoRA.")
class FluxLoRALoader(GeneralLoRALoader):
def __init__(self, device="cpu", torch_dtype=torch.float32):
super().__init__(device=device, torch_dtype=torch_dtype)
self.diffusers_rename_dict = {
"transformer.single_transformer_blocks.blockid.attn.to_k.lora_A.weight":"single_blocks.blockid.a_to_k.lora_A.default.weight",
"transformer.single_transformer_blocks.blockid.attn.to_k.lora_B.weight":"single_blocks.blockid.a_to_k.lora_B.default.weight",
"transformer.single_transformer_blocks.blockid.attn.to_q.lora_A.weight":"single_blocks.blockid.a_to_q.lora_A.default.weight",
"transformer.single_transformer_blocks.blockid.attn.to_q.lora_B.weight":"single_blocks.blockid.a_to_q.lora_B.default.weight",
"transformer.single_transformer_blocks.blockid.attn.to_v.lora_A.weight":"single_blocks.blockid.a_to_v.lora_A.default.weight",
"transformer.single_transformer_blocks.blockid.attn.to_v.lora_B.weight":"single_blocks.blockid.a_to_v.lora_B.default.weight",
"transformer.single_transformer_blocks.blockid.norm.linear.lora_A.weight":"single_blocks.blockid.norm.linear.lora_A.default.weight",
"transformer.single_transformer_blocks.blockid.norm.linear.lora_B.weight":"single_blocks.blockid.norm.linear.lora_B.default.weight",
"transformer.single_transformer_blocks.blockid.proj_mlp.lora_A.weight":"single_blocks.blockid.proj_in_besides_attn.lora_A.default.weight",
"transformer.single_transformer_blocks.blockid.proj_mlp.lora_B.weight":"single_blocks.blockid.proj_in_besides_attn.lora_B.default.weight",
"transformer.single_transformer_blocks.blockid.proj_out.lora_A.weight":"single_blocks.blockid.proj_out.lora_A.default.weight",
"transformer.single_transformer_blocks.blockid.proj_out.lora_B.weight":"single_blocks.blockid.proj_out.lora_B.default.weight",
"transformer.transformer_blocks.blockid.attn.add_k_proj.lora_A.weight":"blocks.blockid.attn.b_to_k.lora_A.default.weight",
"transformer.transformer_blocks.blockid.attn.add_k_proj.lora_B.weight":"blocks.blockid.attn.b_to_k.lora_B.default.weight",
"transformer.transformer_blocks.blockid.attn.add_q_proj.lora_A.weight":"blocks.blockid.attn.b_to_q.lora_A.default.weight",
"transformer.transformer_blocks.blockid.attn.add_q_proj.lora_B.weight":"blocks.blockid.attn.b_to_q.lora_B.default.weight",
"transformer.transformer_blocks.blockid.attn.add_v_proj.lora_A.weight":"blocks.blockid.attn.b_to_v.lora_A.default.weight",
"transformer.transformer_blocks.blockid.attn.add_v_proj.lora_B.weight":"blocks.blockid.attn.b_to_v.lora_B.default.weight",
"transformer.transformer_blocks.blockid.attn.to_add_out.lora_A.weight":"blocks.blockid.attn.b_to_out.lora_A.default.weight",
"transformer.transformer_blocks.blockid.attn.to_add_out.lora_B.weight":"blocks.blockid.attn.b_to_out.lora_B.default.weight",
"transformer.transformer_blocks.blockid.attn.to_k.lora_A.weight":"blocks.blockid.attn.a_to_k.lora_A.default.weight",
"transformer.transformer_blocks.blockid.attn.to_k.lora_B.weight":"blocks.blockid.attn.a_to_k.lora_B.default.weight",
"transformer.transformer_blocks.blockid.attn.to_out.0.lora_A.weight":"blocks.blockid.attn.a_to_out.lora_A.default.weight",
"transformer.transformer_blocks.blockid.attn.to_out.0.lora_B.weight":"blocks.blockid.attn.a_to_out.lora_B.default.weight",
"transformer.transformer_blocks.blockid.attn.to_q.lora_A.weight":"blocks.blockid.attn.a_to_q.lora_A.default.weight",
"transformer.transformer_blocks.blockid.attn.to_q.lora_B.weight":"blocks.blockid.attn.a_to_q.lora_B.default.weight",
"transformer.transformer_blocks.blockid.attn.to_v.lora_A.weight":"blocks.blockid.attn.a_to_v.lora_A.default.weight",
"transformer.transformer_blocks.blockid.attn.to_v.lora_B.weight":"blocks.blockid.attn.a_to_v.lora_B.default.weight",
"transformer.transformer_blocks.blockid.ff.net.0.proj.lora_A.weight":"blocks.blockid.ff_a.0.lora_A.default.weight",
"transformer.transformer_blocks.blockid.ff.net.0.proj.lora_B.weight":"blocks.blockid.ff_a.0.lora_B.default.weight",
"transformer.transformer_blocks.blockid.ff.net.2.lora_A.weight":"blocks.blockid.ff_a.2.lora_A.default.weight",
"transformer.transformer_blocks.blockid.ff.net.2.lora_B.weight":"blocks.blockid.ff_a.2.lora_B.default.weight",
"transformer.transformer_blocks.blockid.ff_context.net.0.proj.lora_A.weight":"blocks.blockid.ff_b.0.lora_A.default.weight",
"transformer.transformer_blocks.blockid.ff_context.net.0.proj.lora_B.weight":"blocks.blockid.ff_b.0.lora_B.default.weight",
"transformer.transformer_blocks.blockid.ff_context.net.2.lora_A.weight":"blocks.blockid.ff_b.2.lora_A.default.weight",
"transformer.transformer_blocks.blockid.ff_context.net.2.lora_B.weight":"blocks.blockid.ff_b.2.lora_B.default.weight",
"transformer.transformer_blocks.blockid.norm1.linear.lora_A.weight":"blocks.blockid.norm1_a.linear.lora_A.default.weight",
"transformer.transformer_blocks.blockid.norm1.linear.lora_B.weight":"blocks.blockid.norm1_a.linear.lora_B.default.weight",
"transformer.transformer_blocks.blockid.norm1_context.linear.lora_A.weight":"blocks.blockid.norm1_b.linear.lora_A.default.weight",
"transformer.transformer_blocks.blockid.norm1_context.linear.lora_B.weight":"blocks.blockid.norm1_b.linear.lora_B.default.weight",
}
self.civitai_rename_dict = {
"lora_unet_double_blocks_blockid_img_mod_lin.lora_down.weight": "blocks.blockid.norm1_a.linear.lora_A.default.weight",
"lora_unet_double_blocks_blockid_img_mod_lin.lora_up.weight": "blocks.blockid.norm1_a.linear.lora_B.default.weight",
"lora_unet_double_blocks_blockid_txt_mod_lin.lora_down.weight": "blocks.blockid.norm1_b.linear.lora_A.default.weight",
"lora_unet_double_blocks_blockid_txt_mod_lin.lora_up.weight": "blocks.blockid.norm1_b.linear.lora_B.default.weight",
"lora_unet_double_blocks_blockid_img_attn_qkv.lora_down.weight": "blocks.blockid.attn.a_to_qkv.lora_A.default.weight",
"lora_unet_double_blocks_blockid_img_attn_qkv.lora_up.weight": "blocks.blockid.attn.a_to_qkv.lora_B.default.weight",
"lora_unet_double_blocks_blockid_txt_attn_qkv.lora_down.weight": "blocks.blockid.attn.b_to_qkv.lora_A.default.weight",
"lora_unet_double_blocks_blockid_txt_attn_qkv.lora_up.weight": "blocks.blockid.attn.b_to_qkv.lora_B.default.weight",
"lora_unet_double_blocks_blockid_img_attn_proj.lora_down.weight": "blocks.blockid.attn.a_to_out.lora_A.default.weight",
"lora_unet_double_blocks_blockid_img_attn_proj.lora_up.weight": "blocks.blockid.attn.a_to_out.lora_B.default.weight",
"lora_unet_double_blocks_blockid_txt_attn_proj.lora_down.weight": "blocks.blockid.attn.b_to_out.lora_A.default.weight",
"lora_unet_double_blocks_blockid_txt_attn_proj.lora_up.weight": "blocks.blockid.attn.b_to_out.lora_B.default.weight",
"lora_unet_double_blocks_blockid_img_mlp_0.lora_down.weight": "blocks.blockid.ff_a.0.lora_A.default.weight",
"lora_unet_double_blocks_blockid_img_mlp_0.lora_up.weight": "blocks.blockid.ff_a.0.lora_B.default.weight",
"lora_unet_double_blocks_blockid_img_mlp_2.lora_down.weight": "blocks.blockid.ff_a.2.lora_A.default.weight",
"lora_unet_double_blocks_blockid_img_mlp_2.lora_up.weight": "blocks.blockid.ff_a.2.lora_B.default.weight",
"lora_unet_double_blocks_blockid_txt_mlp_0.lora_down.weight": "blocks.blockid.ff_b.0.lora_A.default.weight",
"lora_unet_double_blocks_blockid_txt_mlp_0.lora_up.weight": "blocks.blockid.ff_b.0.lora_B.default.weight",
"lora_unet_double_blocks_blockid_txt_mlp_2.lora_down.weight": "blocks.blockid.ff_b.2.lora_A.default.weight",
"lora_unet_double_blocks_blockid_txt_mlp_2.lora_up.weight": "blocks.blockid.ff_b.2.lora_B.default.weight",
"lora_unet_single_blocks_blockid_modulation_lin.lora_down.weight": "single_blocks.blockid.norm.linear.lora_A.default.weight",
"lora_unet_single_blocks_blockid_modulation_lin.lora_up.weight": "single_blocks.blockid.norm.linear.lora_B.default.weight",
"lora_unet_single_blocks_blockid_linear1.lora_down.weight": "single_blocks.blockid.to_qkv_mlp.lora_A.default.weight",
"lora_unet_single_blocks_blockid_linear1.lora_up.weight": "single_blocks.blockid.to_qkv_mlp.lora_B.default.weight",
"lora_unet_single_blocks_blockid_linear2.lora_down.weight": "single_blocks.blockid.proj_out.lora_A.default.weight",
"lora_unet_single_blocks_blockid_linear2.lora_up.weight": "single_blocks.blockid.proj_out.lora_B.default.weight",
}
def load(self, model: torch.nn.Module, state_dict_lora, alpha=1.0):
super().load(model, state_dict_lora, alpha)
def convert_state_dict(self,state_dict):
def guess_block_id(name,model_resource):
if model_resource == 'civitai':
names = name.split("_")
for i in names:
if i.isdigit():
return i, name.replace(f"_{i}_", "_blockid_")
if model_resource == 'diffusers':
names = name.split(".")
for i in names:
if i.isdigit():
return i, name.replace(f"transformer_blocks.{i}.", "transformer_blocks.blockid.")
return None, None
def guess_resource(state_dict):
for k in state_dict:
if "lora_unet_" in k:
return 'civitai'
elif k.startswith("transformer."):
return 'diffusers'
else:
None
model_resource = guess_resource(state_dict)
if model_resource is None:
return state_dict
rename_dict = self.diffusers_rename_dict if model_resource == 'diffusers' else self.civitai_rename_dict
def guess_alpha(state_dict):
for name, param in state_dict.items():
if ".alpha" in name:
for suffix in [".lora_down.weight", ".lora_A.weight"]:
name_ = name.replace(".alpha", suffix)
if name_ in state_dict:
lora_alpha = param.item() / state_dict[name_].shape[0]
lora_alpha = math.sqrt(lora_alpha)
return lora_alpha
return 1
alpha = guess_alpha(state_dict)
state_dict_ = {}
for name, param in state_dict.items():
block_id, source_name = guess_block_id(name,model_resource)
if alpha != 1:
param *= alpha
if source_name in rename_dict:
target_name = rename_dict[source_name]
target_name = target_name.replace(".blockid.", f".{block_id}.")
state_dict_[target_name] = param
else:
state_dict_[name] = param
if model_resource == 'diffusers':
for name in list(state_dict_.keys()):
if "single_blocks." in name and ".a_to_q." in name:
mlp = state_dict_.get(name.replace(".a_to_q.", ".proj_in_besides_attn."), None)
if mlp is None:
dim = 4
if 'lora_A' in name:
dim = 1
mlp = torch.zeros(dim * state_dict_[name].shape[0],
*state_dict_[name].shape[1:],
dtype=state_dict_[name].dtype)
else:
state_dict_.pop(name.replace(".a_to_q.", ".proj_in_besides_attn."))
if 'lora_A' in name:
param = torch.concat([
state_dict_.pop(name),
state_dict_.pop(name.replace(".a_to_q.", ".a_to_k.")),
state_dict_.pop(name.replace(".a_to_q.", ".a_to_v.")),
mlp,
], dim=0)
elif 'lora_B' in name:
d, r = state_dict_[name].shape
param = torch.zeros((3*d+mlp.shape[0], 3*r+mlp.shape[1]), dtype=state_dict_[name].dtype, device=state_dict_[name].device)
param[:d, :r] = state_dict_.pop(name)
param[d:2*d, r:2*r] = state_dict_.pop(name.replace(".a_to_q.", ".a_to_k."))
param[2*d:3*d, 2*r:3*r] = state_dict_.pop(name.replace(".a_to_q.", ".a_to_v."))
param[3*d:, 3*r:] = mlp
else:
param = torch.concat([
state_dict_.pop(name),
state_dict_.pop(name.replace(".a_to_q.", ".a_to_k.")),
state_dict_.pop(name.replace(".a_to_q.", ".a_to_v.")),
mlp,
], dim=0)
name_ = name.replace(".a_to_q.", ".to_qkv_mlp.")
state_dict_[name_] = param
for name in list(state_dict_.keys()):
for component in ["a", "b"]:
if f".{component}_to_q." in name:
name_ = name.replace(f".{component}_to_q.", f".{component}_to_qkv.")
concat_dim = 0
if 'lora_A' in name:
param = torch.concat([
state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_q.")],
state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_k.")],
state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_v.")],
], dim=0)
elif 'lora_B' in name:
origin = state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_q.")]
d, r = origin.shape
# print(d, r)
param = torch.zeros((3*d, 3*r), dtype=origin.dtype, device=origin.device)
param[:d, :r] = state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_q.")]
param[d:2*d, r:2*r] = state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_k.")]
param[2*d:3*d, 2*r:3*r] = state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_v.")]
else:
param = torch.concat([
state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_q.")],
state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_k.")],
state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_v.")],
], dim=0)
state_dict_[name_] = param
state_dict_.pop(name.replace(f".{component}_to_q.", f".{component}_to_q."))
state_dict_.pop(name.replace(f".{component}_to_q.", f".{component}_to_k."))
state_dict_.pop(name.replace(f".{component}_to_q.", f".{component}_to_v."))
return state_dict_
class LoraMerger(torch.nn.Module):
def __init__(self, dim):
super().__init__()
self.weight_base = torch.nn.Parameter(torch.randn((dim,)))
self.weight_lora = torch.nn.Parameter(torch.randn((dim,)))
self.weight_cross = torch.nn.Parameter(torch.randn((dim,)))
self.weight_out = torch.nn.Parameter(torch.ones((dim,)))
self.bias = torch.nn.Parameter(torch.randn((dim,)))
self.activation = torch.nn.Sigmoid()
self.norm_base = torch.nn.LayerNorm(dim, eps=1e-5)
self.norm_lora = torch.nn.LayerNorm(dim, eps=1e-5)
def forward(self, base_output, lora_outputs):
norm_base_output = self.norm_base(base_output)
norm_lora_outputs = self.norm_lora(lora_outputs)
gate = self.activation(
norm_base_output * self.weight_base \
+ norm_lora_outputs * self.weight_lora \
+ norm_base_output * norm_lora_outputs * self.weight_cross + self.bias
)
output = base_output + (self.weight_out * gate * lora_outputs).sum(dim=0)
return output
class FluxLoraPatcher(torch.nn.Module):
def __init__(self, lora_patterns=None):
super().__init__()
if lora_patterns is None:
lora_patterns = self.default_lora_patterns()
model_dict = {}
for lora_pattern in lora_patterns:
name, dim = lora_pattern["name"], lora_pattern["dim"]
model_dict[name.replace(".", "___")] = LoraMerger(dim)
self.model_dict = torch.nn.ModuleDict(model_dict)
def default_lora_patterns(self):
lora_patterns = []
lora_dict = {
"attn.a_to_qkv": 9216, "attn.a_to_out": 3072, "ff_a.0": 12288, "ff_a.2": 3072, "norm1_a.linear": 18432,
"attn.b_to_qkv": 9216, "attn.b_to_out": 3072, "ff_b.0": 12288, "ff_b.2": 3072, "norm1_b.linear": 18432,
}
for i in range(19):
for suffix in lora_dict:
lora_patterns.append({
"name": f"blocks.{i}.{suffix}",
"dim": lora_dict[suffix]
})
lora_dict = {"to_qkv_mlp": 21504, "proj_out": 3072, "norm.linear": 9216}
for i in range(38):
for suffix in lora_dict:
lora_patterns.append({
"name": f"single_blocks.{i}.{suffix}",
"dim": lora_dict[suffix]
})
return lora_patterns
def forward(self, base_output, lora_outputs, name):
return self.model_dict[name.replace(".", "___")](base_output, lora_outputs)

View File

@@ -0,0 +1,32 @@
import torch
from transformers import T5EncoderModel, T5Config
from .sd_text_encoder import SDTextEncoder
class FluxTextEncoder2(T5EncoderModel):
def __init__(self, config):
super().__init__(config)
self.eval()
def forward(self, input_ids):
outputs = super().forward(input_ids=input_ids)
prompt_emb = outputs.last_hidden_state
return prompt_emb
@staticmethod
def state_dict_converter():
return FluxTextEncoder2StateDictConverter()
class FluxTextEncoder2StateDictConverter():
def __init__(self):
pass
def from_diffusers(self, state_dict):
state_dict_ = state_dict
return state_dict_
def from_civitai(self, state_dict):
return self.from_diffusers(state_dict)

View File

@@ -1,112 +0,0 @@
import torch
class Attention(torch.nn.Module):
def __init__(self, q_dim, num_heads, head_dim, kv_dim=None, bias_q=False, bias_kv=False, bias_out=False):
super().__init__()
dim_inner = head_dim * num_heads
kv_dim = kv_dim if kv_dim is not None else q_dim
self.num_heads = num_heads
self.head_dim = head_dim
self.to_q = torch.nn.Linear(q_dim, dim_inner, bias=bias_q)
self.to_k = torch.nn.Linear(kv_dim, dim_inner, bias=bias_kv)
self.to_v = torch.nn.Linear(kv_dim, dim_inner, bias=bias_kv)
self.to_out = torch.nn.Linear(dim_inner, q_dim, bias=bias_out)
def forward(self, hidden_states, encoder_hidden_states=None, attn_mask=None):
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
batch_size = encoder_hidden_states.shape[0]
q = self.to_q(hidden_states)
k = self.to_k(encoder_hidden_states)
v = self.to_v(encoder_hidden_states)
q = q.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
k = k.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
v = v.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim)
hidden_states = hidden_states.to(q.dtype)
hidden_states = self.to_out(hidden_states)
return hidden_states
class CLIPEncoderLayer(torch.nn.Module):
def __init__(self, embed_dim, intermediate_size, num_heads=12, head_dim=64, use_quick_gelu=True):
super().__init__()
self.attn = Attention(q_dim=embed_dim, num_heads=num_heads, head_dim=head_dim, bias_q=True, bias_kv=True, bias_out=True)
self.layer_norm1 = torch.nn.LayerNorm(embed_dim)
self.layer_norm2 = torch.nn.LayerNorm(embed_dim)
self.fc1 = torch.nn.Linear(embed_dim, intermediate_size)
self.fc2 = torch.nn.Linear(intermediate_size, embed_dim)
self.use_quick_gelu = use_quick_gelu
def quickGELU(self, x):
return x * torch.sigmoid(1.702 * x)
def forward(self, hidden_states, attn_mask=None):
residual = hidden_states
hidden_states = self.layer_norm1(hidden_states)
hidden_states = self.attn(hidden_states, attn_mask=attn_mask)
hidden_states = residual + hidden_states
residual = hidden_states
hidden_states = self.layer_norm2(hidden_states)
hidden_states = self.fc1(hidden_states)
if self.use_quick_gelu:
hidden_states = self.quickGELU(hidden_states)
else:
hidden_states = torch.nn.functional.gelu(hidden_states)
hidden_states = self.fc2(hidden_states)
hidden_states = residual + hidden_states
return hidden_states
class FluxTextEncoderClip(torch.nn.Module):
def __init__(self, embed_dim=768, vocab_size=49408, max_position_embeddings=77, num_encoder_layers=12, encoder_intermediate_size=3072):
super().__init__()
# token_embedding
self.token_embedding = torch.nn.Embedding(vocab_size, embed_dim)
# position_embeds (This is a fixed tensor)
self.position_embeds = torch.nn.Parameter(torch.zeros(1, max_position_embeddings, embed_dim))
# encoders
self.encoders = torch.nn.ModuleList([CLIPEncoderLayer(embed_dim, encoder_intermediate_size) for _ in range(num_encoder_layers)])
# attn_mask
self.attn_mask = self.attention_mask(max_position_embeddings)
# final_layer_norm
self.final_layer_norm = torch.nn.LayerNorm(embed_dim)
def attention_mask(self, length):
mask = torch.empty(length, length)
mask.fill_(float("-inf"))
mask.triu_(1)
return mask
def forward(self, input_ids, clip_skip=2, extra_mask=None):
embeds = self.token_embedding(input_ids)
embeds = embeds + self.position_embeds.to(dtype=embeds.dtype, device=input_ids.device)
attn_mask = self.attn_mask.to(device=embeds.device, dtype=embeds.dtype)
if extra_mask is not None:
attn_mask[:, extra_mask[0]==0] = float("-inf")
for encoder_id, encoder in enumerate(self.encoders):
embeds = encoder(embeds, attn_mask=attn_mask)
if encoder_id + clip_skip == len(self.encoders):
hidden_states = embeds
embeds = self.final_layer_norm(embeds)
pooled_embeds = embeds[torch.arange(embeds.shape[0]), input_ids.to(dtype=torch.int).argmax(dim=-1)]
return pooled_embeds, hidden_states

View File

@@ -1,43 +0,0 @@
import torch
from transformers import T5EncoderModel, T5Config
class FluxTextEncoderT5(T5EncoderModel):
def __init__(self):
config = T5Config(**{
"architectures": [
"T5EncoderModel"
],
"classifier_dropout": 0.0,
"d_ff": 10240,
"d_kv": 64,
"d_model": 4096,
"decoder_start_token_id": 0,
"dense_act_fn": "gelu_new",
"dropout_rate": 0.1,
"dtype": "bfloat16",
"eos_token_id": 1,
"feed_forward_proj": "gated-gelu",
"initializer_factor": 1.0,
"is_encoder_decoder": True,
"is_gated_act": True,
"layer_norm_epsilon": 1e-06,
"model_type": "t5",
"num_decoder_layers": 24,
"num_heads": 64,
"num_layers": 24,
"output_past": True,
"pad_token_id": 0,
"relative_attention_max_distance": 128,
"relative_attention_num_buckets": 32,
"tie_word_embeddings": False,
"transformers_version": "4.57.1",
"use_cache": True,
"vocab_size": 32128
})
super().__init__(config)
def forward(self, input_ids):
outputs = super().forward(input_ids=input_ids)
prompt_emb = outputs.last_hidden_state
return prompt_emb

View File

@@ -1,451 +1,303 @@
import torch
from einops import rearrange, repeat
from .sd3_vae_encoder import SD3VAEEncoder, SDVAEEncoderStateDictConverter
from .sd3_vae_decoder import SD3VAEDecoder, SDVAEDecoderStateDictConverter
class TileWorker:
class FluxVAEEncoder(SD3VAEEncoder):
def __init__(self):
super().__init__()
self.scaling_factor = 0.3611
self.shift_factor = 0.1159
@staticmethod
def state_dict_converter():
return FluxVAEEncoderStateDictConverter()
class FluxVAEDecoder(SD3VAEDecoder):
def __init__(self):
super().__init__()
self.scaling_factor = 0.3611
self.shift_factor = 0.1159
@staticmethod
def state_dict_converter():
return FluxVAEDecoderStateDictConverter()
class FluxVAEEncoderStateDictConverter(SDVAEEncoderStateDictConverter):
def __init__(self):
pass
def mask(self, height, width, border_width):
# Create a mask with shape (height, width).
# The centre area is filled with 1, and the border line is filled with values in range (0, 1].
x = torch.arange(height).repeat(width, 1).T
y = torch.arange(width).repeat(height, 1)
mask = torch.stack([x + 1, height - x, y + 1, width - y]).min(dim=0).values
mask = (mask / border_width).clip(0, 1)
return mask
def tile(self, model_input, tile_size, tile_stride, tile_device, tile_dtype):
# Convert a tensor (b, c, h, w) to (b, c, tile_size, tile_size, tile_num)
batch_size, channel, _, _ = model_input.shape
model_input = model_input.to(device=tile_device, dtype=tile_dtype)
unfold_operator = torch.nn.Unfold(
kernel_size=(tile_size, tile_size),
stride=(tile_stride, tile_stride)
)
model_input = unfold_operator(model_input)
model_input = model_input.view((batch_size, channel, tile_size, tile_size, -1))
return model_input
def tiled_inference(self, forward_fn, model_input, tile_batch_size, inference_device, inference_dtype, tile_device, tile_dtype):
# Call y=forward_fn(x) for each tile
tile_num = model_input.shape[-1]
model_output_stack = []
for tile_id in range(0, tile_num, tile_batch_size):
# process input
tile_id_ = min(tile_id + tile_batch_size, tile_num)
x = model_input[:, :, :, :, tile_id: tile_id_]
x = x.to(device=inference_device, dtype=inference_dtype)
x = rearrange(x, "b c h w n -> (n b) c h w")
# process output
y = forward_fn(x)
y = rearrange(y, "(n b) c h w -> b c h w n", n=tile_id_-tile_id)
y = y.to(device=tile_device, dtype=tile_dtype)
model_output_stack.append(y)
model_output = torch.concat(model_output_stack, dim=-1)
return model_output
def io_scale(self, model_output, tile_size):
# Determine the size modification happened in forward_fn
# We only consider the same scale on height and width.
io_scale = model_output.shape[2] / tile_size
return io_scale
def from_civitai(self, state_dict):
rename_dict = {
"encoder.conv_in.bias": "conv_in.bias",
"encoder.conv_in.weight": "conv_in.weight",
"encoder.conv_out.bias": "conv_out.bias",
"encoder.conv_out.weight": "conv_out.weight",
"encoder.down.0.block.0.conv1.bias": "blocks.0.conv1.bias",
"encoder.down.0.block.0.conv1.weight": "blocks.0.conv1.weight",
"encoder.down.0.block.0.conv2.bias": "blocks.0.conv2.bias",
"encoder.down.0.block.0.conv2.weight": "blocks.0.conv2.weight",
"encoder.down.0.block.0.norm1.bias": "blocks.0.norm1.bias",
"encoder.down.0.block.0.norm1.weight": "blocks.0.norm1.weight",
"encoder.down.0.block.0.norm2.bias": "blocks.0.norm2.bias",
"encoder.down.0.block.0.norm2.weight": "blocks.0.norm2.weight",
"encoder.down.0.block.1.conv1.bias": "blocks.1.conv1.bias",
"encoder.down.0.block.1.conv1.weight": "blocks.1.conv1.weight",
"encoder.down.0.block.1.conv2.bias": "blocks.1.conv2.bias",
"encoder.down.0.block.1.conv2.weight": "blocks.1.conv2.weight",
"encoder.down.0.block.1.norm1.bias": "blocks.1.norm1.bias",
"encoder.down.0.block.1.norm1.weight": "blocks.1.norm1.weight",
"encoder.down.0.block.1.norm2.bias": "blocks.1.norm2.bias",
"encoder.down.0.block.1.norm2.weight": "blocks.1.norm2.weight",
"encoder.down.0.downsample.conv.bias": "blocks.2.conv.bias",
"encoder.down.0.downsample.conv.weight": "blocks.2.conv.weight",
"encoder.down.1.block.0.conv1.bias": "blocks.3.conv1.bias",
"encoder.down.1.block.0.conv1.weight": "blocks.3.conv1.weight",
"encoder.down.1.block.0.conv2.bias": "blocks.3.conv2.bias",
"encoder.down.1.block.0.conv2.weight": "blocks.3.conv2.weight",
"encoder.down.1.block.0.nin_shortcut.bias": "blocks.3.conv_shortcut.bias",
"encoder.down.1.block.0.nin_shortcut.weight": "blocks.3.conv_shortcut.weight",
"encoder.down.1.block.0.norm1.bias": "blocks.3.norm1.bias",
"encoder.down.1.block.0.norm1.weight": "blocks.3.norm1.weight",
"encoder.down.1.block.0.norm2.bias": "blocks.3.norm2.bias",
"encoder.down.1.block.0.norm2.weight": "blocks.3.norm2.weight",
"encoder.down.1.block.1.conv1.bias": "blocks.4.conv1.bias",
"encoder.down.1.block.1.conv1.weight": "blocks.4.conv1.weight",
"encoder.down.1.block.1.conv2.bias": "blocks.4.conv2.bias",
"encoder.down.1.block.1.conv2.weight": "blocks.4.conv2.weight",
"encoder.down.1.block.1.norm1.bias": "blocks.4.norm1.bias",
"encoder.down.1.block.1.norm1.weight": "blocks.4.norm1.weight",
"encoder.down.1.block.1.norm2.bias": "blocks.4.norm2.bias",
"encoder.down.1.block.1.norm2.weight": "blocks.4.norm2.weight",
"encoder.down.1.downsample.conv.bias": "blocks.5.conv.bias",
"encoder.down.1.downsample.conv.weight": "blocks.5.conv.weight",
"encoder.down.2.block.0.conv1.bias": "blocks.6.conv1.bias",
"encoder.down.2.block.0.conv1.weight": "blocks.6.conv1.weight",
"encoder.down.2.block.0.conv2.bias": "blocks.6.conv2.bias",
"encoder.down.2.block.0.conv2.weight": "blocks.6.conv2.weight",
"encoder.down.2.block.0.nin_shortcut.bias": "blocks.6.conv_shortcut.bias",
"encoder.down.2.block.0.nin_shortcut.weight": "blocks.6.conv_shortcut.weight",
"encoder.down.2.block.0.norm1.bias": "blocks.6.norm1.bias",
"encoder.down.2.block.0.norm1.weight": "blocks.6.norm1.weight",
"encoder.down.2.block.0.norm2.bias": "blocks.6.norm2.bias",
"encoder.down.2.block.0.norm2.weight": "blocks.6.norm2.weight",
"encoder.down.2.block.1.conv1.bias": "blocks.7.conv1.bias",
"encoder.down.2.block.1.conv1.weight": "blocks.7.conv1.weight",
"encoder.down.2.block.1.conv2.bias": "blocks.7.conv2.bias",
"encoder.down.2.block.1.conv2.weight": "blocks.7.conv2.weight",
"encoder.down.2.block.1.norm1.bias": "blocks.7.norm1.bias",
"encoder.down.2.block.1.norm1.weight": "blocks.7.norm1.weight",
"encoder.down.2.block.1.norm2.bias": "blocks.7.norm2.bias",
"encoder.down.2.block.1.norm2.weight": "blocks.7.norm2.weight",
"encoder.down.2.downsample.conv.bias": "blocks.8.conv.bias",
"encoder.down.2.downsample.conv.weight": "blocks.8.conv.weight",
"encoder.down.3.block.0.conv1.bias": "blocks.9.conv1.bias",
"encoder.down.3.block.0.conv1.weight": "blocks.9.conv1.weight",
"encoder.down.3.block.0.conv2.bias": "blocks.9.conv2.bias",
"encoder.down.3.block.0.conv2.weight": "blocks.9.conv2.weight",
"encoder.down.3.block.0.norm1.bias": "blocks.9.norm1.bias",
"encoder.down.3.block.0.norm1.weight": "blocks.9.norm1.weight",
"encoder.down.3.block.0.norm2.bias": "blocks.9.norm2.bias",
"encoder.down.3.block.0.norm2.weight": "blocks.9.norm2.weight",
"encoder.down.3.block.1.conv1.bias": "blocks.10.conv1.bias",
"encoder.down.3.block.1.conv1.weight": "blocks.10.conv1.weight",
"encoder.down.3.block.1.conv2.bias": "blocks.10.conv2.bias",
"encoder.down.3.block.1.conv2.weight": "blocks.10.conv2.weight",
"encoder.down.3.block.1.norm1.bias": "blocks.10.norm1.bias",
"encoder.down.3.block.1.norm1.weight": "blocks.10.norm1.weight",
"encoder.down.3.block.1.norm2.bias": "blocks.10.norm2.bias",
"encoder.down.3.block.1.norm2.weight": "blocks.10.norm2.weight",
"encoder.mid.attn_1.k.bias": "blocks.12.transformer_blocks.0.to_k.bias",
"encoder.mid.attn_1.k.weight": "blocks.12.transformer_blocks.0.to_k.weight",
"encoder.mid.attn_1.norm.bias": "blocks.12.norm.bias",
"encoder.mid.attn_1.norm.weight": "blocks.12.norm.weight",
"encoder.mid.attn_1.proj_out.bias": "blocks.12.transformer_blocks.0.to_out.bias",
"encoder.mid.attn_1.proj_out.weight": "blocks.12.transformer_blocks.0.to_out.weight",
"encoder.mid.attn_1.q.bias": "blocks.12.transformer_blocks.0.to_q.bias",
"encoder.mid.attn_1.q.weight": "blocks.12.transformer_blocks.0.to_q.weight",
"encoder.mid.attn_1.v.bias": "blocks.12.transformer_blocks.0.to_v.bias",
"encoder.mid.attn_1.v.weight": "blocks.12.transformer_blocks.0.to_v.weight",
"encoder.mid.block_1.conv1.bias": "blocks.11.conv1.bias",
"encoder.mid.block_1.conv1.weight": "blocks.11.conv1.weight",
"encoder.mid.block_1.conv2.bias": "blocks.11.conv2.bias",
"encoder.mid.block_1.conv2.weight": "blocks.11.conv2.weight",
"encoder.mid.block_1.norm1.bias": "blocks.11.norm1.bias",
"encoder.mid.block_1.norm1.weight": "blocks.11.norm1.weight",
"encoder.mid.block_1.norm2.bias": "blocks.11.norm2.bias",
"encoder.mid.block_1.norm2.weight": "blocks.11.norm2.weight",
"encoder.mid.block_2.conv1.bias": "blocks.13.conv1.bias",
"encoder.mid.block_2.conv1.weight": "blocks.13.conv1.weight",
"encoder.mid.block_2.conv2.bias": "blocks.13.conv2.bias",
"encoder.mid.block_2.conv2.weight": "blocks.13.conv2.weight",
"encoder.mid.block_2.norm1.bias": "blocks.13.norm1.bias",
"encoder.mid.block_2.norm1.weight": "blocks.13.norm1.weight",
"encoder.mid.block_2.norm2.bias": "blocks.13.norm2.bias",
"encoder.mid.block_2.norm2.weight": "blocks.13.norm2.weight",
"encoder.norm_out.bias": "conv_norm_out.bias",
"encoder.norm_out.weight": "conv_norm_out.weight",
}
state_dict_ = {}
for name in state_dict:
if name in rename_dict:
param = state_dict[name]
if "transformer_blocks" in rename_dict[name]:
param = param.squeeze()
state_dict_[rename_dict[name]] = param
return state_dict_
def untile(self, model_output, height, width, tile_size, tile_stride, border_width, tile_device, tile_dtype):
# The reversed function of tile
mask = self.mask(tile_size, tile_size, border_width)
mask = mask.to(device=tile_device, dtype=tile_dtype)
mask = rearrange(mask, "h w -> 1 1 h w 1")
model_output = model_output * mask
fold_operator = torch.nn.Fold(
output_size=(height, width),
kernel_size=(tile_size, tile_size),
stride=(tile_stride, tile_stride)
)
mask = repeat(mask[0, 0, :, :, 0], "h w -> 1 (h w) n", n=model_output.shape[-1])
model_output = rearrange(model_output, "b c h w n -> b (c h w) n")
model_output = fold_operator(model_output) / fold_operator(mask)
return model_output
def tiled_forward(self, forward_fn, model_input, tile_size, tile_stride, tile_batch_size=1, tile_device="cpu", tile_dtype=torch.float32, border_width=None):
# Prepare
inference_device, inference_dtype = model_input.device, model_input.dtype
height, width = model_input.shape[2], model_input.shape[3]
border_width = int(tile_stride*0.5) if border_width is None else border_width
# tile
model_input = self.tile(model_input, tile_size, tile_stride, tile_device, tile_dtype)
# inference
model_output = self.tiled_inference(forward_fn, model_input, tile_batch_size, inference_device, inference_dtype, tile_device, tile_dtype)
# resize
io_scale = self.io_scale(model_output, tile_size)
height, width = int(height*io_scale), int(width*io_scale)
tile_size, tile_stride = int(tile_size*io_scale), int(tile_stride*io_scale)
border_width = int(border_width*io_scale)
# untile
model_output = self.untile(model_output, height, width, tile_size, tile_stride, border_width, tile_device, tile_dtype)
# Done!
model_output = model_output.to(device=inference_device, dtype=inference_dtype)
return model_output
class ConvAttention(torch.nn.Module):
def __init__(self, q_dim, num_heads, head_dim, kv_dim=None, bias_q=False, bias_kv=False, bias_out=False):
super().__init__()
dim_inner = head_dim * num_heads
kv_dim = kv_dim if kv_dim is not None else q_dim
self.num_heads = num_heads
self.head_dim = head_dim
self.to_q = torch.nn.Conv2d(q_dim, dim_inner, kernel_size=(1, 1), bias=bias_q)
self.to_k = torch.nn.Conv2d(kv_dim, dim_inner, kernel_size=(1, 1), bias=bias_kv)
self.to_v = torch.nn.Conv2d(kv_dim, dim_inner, kernel_size=(1, 1), bias=bias_kv)
self.to_out = torch.nn.Conv2d(dim_inner, q_dim, kernel_size=(1, 1), bias=bias_out)
def forward(self, hidden_states, encoder_hidden_states=None, attn_mask=None):
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
batch_size = encoder_hidden_states.shape[0]
conv_input = rearrange(hidden_states, "B L C -> B C L 1")
q = self.to_q(conv_input)
q = rearrange(q[:, :, :, 0], "B C L -> B L C")
conv_input = rearrange(encoder_hidden_states, "B L C -> B C L 1")
k = self.to_k(conv_input)
v = self.to_v(conv_input)
k = rearrange(k[:, :, :, 0], "B C L -> B L C")
v = rearrange(v[:, :, :, 0], "B C L -> B L C")
q = q.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
k = k.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
v = v.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim)
hidden_states = hidden_states.to(q.dtype)
conv_input = rearrange(hidden_states, "B L C -> B C L 1")
hidden_states = self.to_out(conv_input)
hidden_states = rearrange(hidden_states[:, :, :, 0], "B C L -> B L C")
return hidden_states
class Attention(torch.nn.Module):
def __init__(self, q_dim, num_heads, head_dim, kv_dim=None, bias_q=False, bias_kv=False, bias_out=False):
super().__init__()
dim_inner = head_dim * num_heads
kv_dim = kv_dim if kv_dim is not None else q_dim
self.num_heads = num_heads
self.head_dim = head_dim
self.to_q = torch.nn.Linear(q_dim, dim_inner, bias=bias_q)
self.to_k = torch.nn.Linear(kv_dim, dim_inner, bias=bias_kv)
self.to_v = torch.nn.Linear(kv_dim, dim_inner, bias=bias_kv)
self.to_out = torch.nn.Linear(dim_inner, q_dim, bias=bias_out)
def forward(self, hidden_states, encoder_hidden_states=None, attn_mask=None):
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
batch_size = encoder_hidden_states.shape[0]
q = self.to_q(hidden_states)
k = self.to_k(encoder_hidden_states)
v = self.to_v(encoder_hidden_states)
q = q.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
k = k.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
v = v.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim)
hidden_states = hidden_states.to(q.dtype)
hidden_states = self.to_out(hidden_states)
return hidden_states
class VAEAttentionBlock(torch.nn.Module):
def __init__(self, num_attention_heads, attention_head_dim, in_channels, num_layers=1, norm_num_groups=32, eps=1e-5, use_conv_attention=True):
super().__init__()
inner_dim = num_attention_heads * attention_head_dim
self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=eps, affine=True)
if use_conv_attention:
self.transformer_blocks = torch.nn.ModuleList([
ConvAttention(
inner_dim,
num_attention_heads,
attention_head_dim,
bias_q=True,
bias_kv=True,
bias_out=True
)
for d in range(num_layers)
])
else:
self.transformer_blocks = torch.nn.ModuleList([
Attention(
inner_dim,
num_attention_heads,
attention_head_dim,
bias_q=True,
bias_kv=True,
bias_out=True
)
for d in range(num_layers)
])
def forward(self, hidden_states, time_emb, text_emb, res_stack):
batch, _, height, width = hidden_states.shape
residual = hidden_states
hidden_states = self.norm(hidden_states)
inner_dim = hidden_states.shape[1]
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
for block in self.transformer_blocks:
hidden_states = block(hidden_states)
hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
hidden_states = hidden_states + residual
return hidden_states, time_emb, text_emb, res_stack
class ResnetBlock(torch.nn.Module):
def __init__(self, in_channels, out_channels, temb_channels=None, groups=32, eps=1e-5):
super().__init__()
self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
if temb_channels is not None:
self.time_emb_proj = torch.nn.Linear(temb_channels, out_channels)
self.norm2 = torch.nn.GroupNorm(num_groups=groups, num_channels=out_channels, eps=eps, affine=True)
self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
self.nonlinearity = torch.nn.SiLU()
self.conv_shortcut = None
if in_channels != out_channels:
self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=True)
def forward(self, hidden_states, time_emb, text_emb, res_stack, **kwargs):
x = hidden_states
x = self.norm1(x)
x = self.nonlinearity(x)
x = self.conv1(x)
if time_emb is not None:
emb = self.nonlinearity(time_emb)
emb = self.time_emb_proj(emb)[:, :, None, None]
x = x + emb
x = self.norm2(x)
x = self.nonlinearity(x)
x = self.conv2(x)
if self.conv_shortcut is not None:
hidden_states = self.conv_shortcut(hidden_states)
hidden_states = hidden_states + x
return hidden_states, time_emb, text_emb, res_stack
class UpSampler(torch.nn.Module):
def __init__(self, channels):
super().__init__()
self.conv = torch.nn.Conv2d(channels, channels, 3, padding=1)
def forward(self, hidden_states, time_emb, text_emb, res_stack, **kwargs):
hidden_states = torch.nn.functional.interpolate(hidden_states, scale_factor=2.0, mode="nearest")
hidden_states = self.conv(hidden_states)
return hidden_states, time_emb, text_emb, res_stack
class DownSampler(torch.nn.Module):
def __init__(self, channels, padding=1, extra_padding=False):
super().__init__()
self.conv = torch.nn.Conv2d(channels, channels, 3, stride=2, padding=padding)
self.extra_padding = extra_padding
def forward(self, hidden_states, time_emb, text_emb, res_stack, **kwargs):
if self.extra_padding:
hidden_states = torch.nn.functional.pad(hidden_states, (0, 1, 0, 1), mode="constant", value=0)
hidden_states = self.conv(hidden_states)
return hidden_states, time_emb, text_emb, res_stack
class FluxVAEDecoder(torch.nn.Module):
def __init__(self, use_conv_attention=True):
super().__init__()
self.scaling_factor = 0.3611
self.shift_factor = 0.1159
self.conv_in = torch.nn.Conv2d(16, 512, kernel_size=3, padding=1) # Different from SD 1.x
self.blocks = torch.nn.ModuleList([
# UNetMidBlock2D
ResnetBlock(512, 512, eps=1e-6),
VAEAttentionBlock(1, 512, 512, 1, eps=1e-6, use_conv_attention=use_conv_attention),
ResnetBlock(512, 512, eps=1e-6),
# UpDecoderBlock2D
ResnetBlock(512, 512, eps=1e-6),
ResnetBlock(512, 512, eps=1e-6),
ResnetBlock(512, 512, eps=1e-6),
UpSampler(512),
# UpDecoderBlock2D
ResnetBlock(512, 512, eps=1e-6),
ResnetBlock(512, 512, eps=1e-6),
ResnetBlock(512, 512, eps=1e-6),
UpSampler(512),
# UpDecoderBlock2D
ResnetBlock(512, 256, eps=1e-6),
ResnetBlock(256, 256, eps=1e-6),
ResnetBlock(256, 256, eps=1e-6),
UpSampler(256),
# UpDecoderBlock2D
ResnetBlock(256, 128, eps=1e-6),
ResnetBlock(128, 128, eps=1e-6),
ResnetBlock(128, 128, eps=1e-6),
])
self.conv_norm_out = torch.nn.GroupNorm(num_channels=128, num_groups=32, eps=1e-6)
self.conv_act = torch.nn.SiLU()
self.conv_out = torch.nn.Conv2d(128, 3, kernel_size=3, padding=1)
def tiled_forward(self, sample, tile_size=64, tile_stride=32):
hidden_states = TileWorker().tiled_forward(
lambda x: self.forward(x),
sample,
tile_size,
tile_stride,
tile_device=sample.device,
tile_dtype=sample.dtype
)
return hidden_states
def forward(self, sample, tiled=False, tile_size=64, tile_stride=32, **kwargs):
# For VAE Decoder, we do not need to apply the tiler on each layer.
if tiled:
return self.tiled_forward(sample, tile_size=tile_size, tile_stride=tile_stride)
# 1. pre-process
hidden_states = sample / self.scaling_factor + self.shift_factor
hidden_states = self.conv_in(hidden_states)
time_emb = None
text_emb = None
res_stack = None
# 2. blocks
for i, block in enumerate(self.blocks):
hidden_states, time_emb, text_emb, res_stack = block(hidden_states, time_emb, text_emb, res_stack)
# 3. output
hidden_states = self.conv_norm_out(hidden_states)
hidden_states = self.conv_act(hidden_states)
hidden_states = self.conv_out(hidden_states)
return hidden_states
class FluxVAEEncoder(torch.nn.Module):
def __init__(self, use_conv_attention=True):
super().__init__()
self.scaling_factor = 0.3611
self.shift_factor = 0.1159
self.conv_in = torch.nn.Conv2d(3, 128, kernel_size=3, padding=1)
self.blocks = torch.nn.ModuleList([
# DownEncoderBlock2D
ResnetBlock(128, 128, eps=1e-6),
ResnetBlock(128, 128, eps=1e-6),
DownSampler(128, padding=0, extra_padding=True),
# DownEncoderBlock2D
ResnetBlock(128, 256, eps=1e-6),
ResnetBlock(256, 256, eps=1e-6),
DownSampler(256, padding=0, extra_padding=True),
# DownEncoderBlock2D
ResnetBlock(256, 512, eps=1e-6),
ResnetBlock(512, 512, eps=1e-6),
DownSampler(512, padding=0, extra_padding=True),
# DownEncoderBlock2D
ResnetBlock(512, 512, eps=1e-6),
ResnetBlock(512, 512, eps=1e-6),
# UNetMidBlock2D
ResnetBlock(512, 512, eps=1e-6),
VAEAttentionBlock(1, 512, 512, 1, eps=1e-6, use_conv_attention=use_conv_attention),
ResnetBlock(512, 512, eps=1e-6),
])
self.conv_norm_out = torch.nn.GroupNorm(num_channels=512, num_groups=32, eps=1e-6)
self.conv_act = torch.nn.SiLU()
self.conv_out = torch.nn.Conv2d(512, 32, kernel_size=3, padding=1)
def tiled_forward(self, sample, tile_size=64, tile_stride=32):
hidden_states = TileWorker().tiled_forward(
lambda x: self.forward(x),
sample,
tile_size,
tile_stride,
tile_device=sample.device,
tile_dtype=sample.dtype
)
return hidden_states
def forward(self, sample, tiled=False, tile_size=64, tile_stride=32, **kwargs):
# For VAE Decoder, we do not need to apply the tiler on each layer.
if tiled:
return self.tiled_forward(sample, tile_size=tile_size, tile_stride=tile_stride)
# 1. pre-process
hidden_states = self.conv_in(sample)
time_emb = None
text_emb = None
res_stack = None
# 2. blocks
for i, block in enumerate(self.blocks):
hidden_states, time_emb, text_emb, res_stack = block(hidden_states, time_emb, text_emb, res_stack)
# 3. output
hidden_states = self.conv_norm_out(hidden_states)
hidden_states = self.conv_act(hidden_states)
hidden_states = self.conv_out(hidden_states)
hidden_states = hidden_states[:, :16]
hidden_states = (hidden_states - self.shift_factor) * self.scaling_factor
return hidden_states
def encode_video(self, sample, batch_size=8):
B = sample.shape[0]
hidden_states = []
for i in range(0, sample.shape[2], batch_size):
j = min(i + batch_size, sample.shape[2])
sample_batch = rearrange(sample[:,:,i:j], "B C T H W -> (B T) C H W")
hidden_states_batch = self(sample_batch)
hidden_states_batch = rearrange(hidden_states_batch, "(B T) C H W -> B C T H W", B=B)
hidden_states.append(hidden_states_batch)
hidden_states = torch.concat(hidden_states, dim=2)
return hidden_states
class FluxVAEDecoderStateDictConverter(SDVAEDecoderStateDictConverter):
def __init__(self):
pass
def from_civitai(self, state_dict):
rename_dict = {
"decoder.conv_in.bias": "conv_in.bias",
"decoder.conv_in.weight": "conv_in.weight",
"decoder.conv_out.bias": "conv_out.bias",
"decoder.conv_out.weight": "conv_out.weight",
"decoder.mid.attn_1.k.bias": "blocks.1.transformer_blocks.0.to_k.bias",
"decoder.mid.attn_1.k.weight": "blocks.1.transformer_blocks.0.to_k.weight",
"decoder.mid.attn_1.norm.bias": "blocks.1.norm.bias",
"decoder.mid.attn_1.norm.weight": "blocks.1.norm.weight",
"decoder.mid.attn_1.proj_out.bias": "blocks.1.transformer_blocks.0.to_out.bias",
"decoder.mid.attn_1.proj_out.weight": "blocks.1.transformer_blocks.0.to_out.weight",
"decoder.mid.attn_1.q.bias": "blocks.1.transformer_blocks.0.to_q.bias",
"decoder.mid.attn_1.q.weight": "blocks.1.transformer_blocks.0.to_q.weight",
"decoder.mid.attn_1.v.bias": "blocks.1.transformer_blocks.0.to_v.bias",
"decoder.mid.attn_1.v.weight": "blocks.1.transformer_blocks.0.to_v.weight",
"decoder.mid.block_1.conv1.bias": "blocks.0.conv1.bias",
"decoder.mid.block_1.conv1.weight": "blocks.0.conv1.weight",
"decoder.mid.block_1.conv2.bias": "blocks.0.conv2.bias",
"decoder.mid.block_1.conv2.weight": "blocks.0.conv2.weight",
"decoder.mid.block_1.norm1.bias": "blocks.0.norm1.bias",
"decoder.mid.block_1.norm1.weight": "blocks.0.norm1.weight",
"decoder.mid.block_1.norm2.bias": "blocks.0.norm2.bias",
"decoder.mid.block_1.norm2.weight": "blocks.0.norm2.weight",
"decoder.mid.block_2.conv1.bias": "blocks.2.conv1.bias",
"decoder.mid.block_2.conv1.weight": "blocks.2.conv1.weight",
"decoder.mid.block_2.conv2.bias": "blocks.2.conv2.bias",
"decoder.mid.block_2.conv2.weight": "blocks.2.conv2.weight",
"decoder.mid.block_2.norm1.bias": "blocks.2.norm1.bias",
"decoder.mid.block_2.norm1.weight": "blocks.2.norm1.weight",
"decoder.mid.block_2.norm2.bias": "blocks.2.norm2.bias",
"decoder.mid.block_2.norm2.weight": "blocks.2.norm2.weight",
"decoder.norm_out.bias": "conv_norm_out.bias",
"decoder.norm_out.weight": "conv_norm_out.weight",
"decoder.up.0.block.0.conv1.bias": "blocks.15.conv1.bias",
"decoder.up.0.block.0.conv1.weight": "blocks.15.conv1.weight",
"decoder.up.0.block.0.conv2.bias": "blocks.15.conv2.bias",
"decoder.up.0.block.0.conv2.weight": "blocks.15.conv2.weight",
"decoder.up.0.block.0.nin_shortcut.bias": "blocks.15.conv_shortcut.bias",
"decoder.up.0.block.0.nin_shortcut.weight": "blocks.15.conv_shortcut.weight",
"decoder.up.0.block.0.norm1.bias": "blocks.15.norm1.bias",
"decoder.up.0.block.0.norm1.weight": "blocks.15.norm1.weight",
"decoder.up.0.block.0.norm2.bias": "blocks.15.norm2.bias",
"decoder.up.0.block.0.norm2.weight": "blocks.15.norm2.weight",
"decoder.up.0.block.1.conv1.bias": "blocks.16.conv1.bias",
"decoder.up.0.block.1.conv1.weight": "blocks.16.conv1.weight",
"decoder.up.0.block.1.conv2.bias": "blocks.16.conv2.bias",
"decoder.up.0.block.1.conv2.weight": "blocks.16.conv2.weight",
"decoder.up.0.block.1.norm1.bias": "blocks.16.norm1.bias",
"decoder.up.0.block.1.norm1.weight": "blocks.16.norm1.weight",
"decoder.up.0.block.1.norm2.bias": "blocks.16.norm2.bias",
"decoder.up.0.block.1.norm2.weight": "blocks.16.norm2.weight",
"decoder.up.0.block.2.conv1.bias": "blocks.17.conv1.bias",
"decoder.up.0.block.2.conv1.weight": "blocks.17.conv1.weight",
"decoder.up.0.block.2.conv2.bias": "blocks.17.conv2.bias",
"decoder.up.0.block.2.conv2.weight": "blocks.17.conv2.weight",
"decoder.up.0.block.2.norm1.bias": "blocks.17.norm1.bias",
"decoder.up.0.block.2.norm1.weight": "blocks.17.norm1.weight",
"decoder.up.0.block.2.norm2.bias": "blocks.17.norm2.bias",
"decoder.up.0.block.2.norm2.weight": "blocks.17.norm2.weight",
"decoder.up.1.block.0.conv1.bias": "blocks.11.conv1.bias",
"decoder.up.1.block.0.conv1.weight": "blocks.11.conv1.weight",
"decoder.up.1.block.0.conv2.bias": "blocks.11.conv2.bias",
"decoder.up.1.block.0.conv2.weight": "blocks.11.conv2.weight",
"decoder.up.1.block.0.nin_shortcut.bias": "blocks.11.conv_shortcut.bias",
"decoder.up.1.block.0.nin_shortcut.weight": "blocks.11.conv_shortcut.weight",
"decoder.up.1.block.0.norm1.bias": "blocks.11.norm1.bias",
"decoder.up.1.block.0.norm1.weight": "blocks.11.norm1.weight",
"decoder.up.1.block.0.norm2.bias": "blocks.11.norm2.bias",
"decoder.up.1.block.0.norm2.weight": "blocks.11.norm2.weight",
"decoder.up.1.block.1.conv1.bias": "blocks.12.conv1.bias",
"decoder.up.1.block.1.conv1.weight": "blocks.12.conv1.weight",
"decoder.up.1.block.1.conv2.bias": "blocks.12.conv2.bias",
"decoder.up.1.block.1.conv2.weight": "blocks.12.conv2.weight",
"decoder.up.1.block.1.norm1.bias": "blocks.12.norm1.bias",
"decoder.up.1.block.1.norm1.weight": "blocks.12.norm1.weight",
"decoder.up.1.block.1.norm2.bias": "blocks.12.norm2.bias",
"decoder.up.1.block.1.norm2.weight": "blocks.12.norm2.weight",
"decoder.up.1.block.2.conv1.bias": "blocks.13.conv1.bias",
"decoder.up.1.block.2.conv1.weight": "blocks.13.conv1.weight",
"decoder.up.1.block.2.conv2.bias": "blocks.13.conv2.bias",
"decoder.up.1.block.2.conv2.weight": "blocks.13.conv2.weight",
"decoder.up.1.block.2.norm1.bias": "blocks.13.norm1.bias",
"decoder.up.1.block.2.norm1.weight": "blocks.13.norm1.weight",
"decoder.up.1.block.2.norm2.bias": "blocks.13.norm2.bias",
"decoder.up.1.block.2.norm2.weight": "blocks.13.norm2.weight",
"decoder.up.1.upsample.conv.bias": "blocks.14.conv.bias",
"decoder.up.1.upsample.conv.weight": "blocks.14.conv.weight",
"decoder.up.2.block.0.conv1.bias": "blocks.7.conv1.bias",
"decoder.up.2.block.0.conv1.weight": "blocks.7.conv1.weight",
"decoder.up.2.block.0.conv2.bias": "blocks.7.conv2.bias",
"decoder.up.2.block.0.conv2.weight": "blocks.7.conv2.weight",
"decoder.up.2.block.0.norm1.bias": "blocks.7.norm1.bias",
"decoder.up.2.block.0.norm1.weight": "blocks.7.norm1.weight",
"decoder.up.2.block.0.norm2.bias": "blocks.7.norm2.bias",
"decoder.up.2.block.0.norm2.weight": "blocks.7.norm2.weight",
"decoder.up.2.block.1.conv1.bias": "blocks.8.conv1.bias",
"decoder.up.2.block.1.conv1.weight": "blocks.8.conv1.weight",
"decoder.up.2.block.1.conv2.bias": "blocks.8.conv2.bias",
"decoder.up.2.block.1.conv2.weight": "blocks.8.conv2.weight",
"decoder.up.2.block.1.norm1.bias": "blocks.8.norm1.bias",
"decoder.up.2.block.1.norm1.weight": "blocks.8.norm1.weight",
"decoder.up.2.block.1.norm2.bias": "blocks.8.norm2.bias",
"decoder.up.2.block.1.norm2.weight": "blocks.8.norm2.weight",
"decoder.up.2.block.2.conv1.bias": "blocks.9.conv1.bias",
"decoder.up.2.block.2.conv1.weight": "blocks.9.conv1.weight",
"decoder.up.2.block.2.conv2.bias": "blocks.9.conv2.bias",
"decoder.up.2.block.2.conv2.weight": "blocks.9.conv2.weight",
"decoder.up.2.block.2.norm1.bias": "blocks.9.norm1.bias",
"decoder.up.2.block.2.norm1.weight": "blocks.9.norm1.weight",
"decoder.up.2.block.2.norm2.bias": "blocks.9.norm2.bias",
"decoder.up.2.block.2.norm2.weight": "blocks.9.norm2.weight",
"decoder.up.2.upsample.conv.bias": "blocks.10.conv.bias",
"decoder.up.2.upsample.conv.weight": "blocks.10.conv.weight",
"decoder.up.3.block.0.conv1.bias": "blocks.3.conv1.bias",
"decoder.up.3.block.0.conv1.weight": "blocks.3.conv1.weight",
"decoder.up.3.block.0.conv2.bias": "blocks.3.conv2.bias",
"decoder.up.3.block.0.conv2.weight": "blocks.3.conv2.weight",
"decoder.up.3.block.0.norm1.bias": "blocks.3.norm1.bias",
"decoder.up.3.block.0.norm1.weight": "blocks.3.norm1.weight",
"decoder.up.3.block.0.norm2.bias": "blocks.3.norm2.bias",
"decoder.up.3.block.0.norm2.weight": "blocks.3.norm2.weight",
"decoder.up.3.block.1.conv1.bias": "blocks.4.conv1.bias",
"decoder.up.3.block.1.conv1.weight": "blocks.4.conv1.weight",
"decoder.up.3.block.1.conv2.bias": "blocks.4.conv2.bias",
"decoder.up.3.block.1.conv2.weight": "blocks.4.conv2.weight",
"decoder.up.3.block.1.norm1.bias": "blocks.4.norm1.bias",
"decoder.up.3.block.1.norm1.weight": "blocks.4.norm1.weight",
"decoder.up.3.block.1.norm2.bias": "blocks.4.norm2.bias",
"decoder.up.3.block.1.norm2.weight": "blocks.4.norm2.weight",
"decoder.up.3.block.2.conv1.bias": "blocks.5.conv1.bias",
"decoder.up.3.block.2.conv1.weight": "blocks.5.conv1.weight",
"decoder.up.3.block.2.conv2.bias": "blocks.5.conv2.bias",
"decoder.up.3.block.2.conv2.weight": "blocks.5.conv2.weight",
"decoder.up.3.block.2.norm1.bias": "blocks.5.norm1.bias",
"decoder.up.3.block.2.norm1.weight": "blocks.5.norm1.weight",
"decoder.up.3.block.2.norm2.bias": "blocks.5.norm2.bias",
"decoder.up.3.block.2.norm2.weight": "blocks.5.norm2.weight",
"decoder.up.3.upsample.conv.bias": "blocks.6.conv.bias",
"decoder.up.3.upsample.conv.weight": "blocks.6.conv.weight",
}
state_dict_ = {}
for name in state_dict:
if name in rename_dict:
param = state_dict[name]
if "transformer_blocks" in rename_dict[name]:
param = param.squeeze()
state_dict_[rename_dict[name]] = param
return state_dict_

View File

@@ -1,56 +0,0 @@
import torch
from .general_modules import TemporalTimesteps
class MultiValueEncoder(torch.nn.Module):
def __init__(self, encoders=()):
super().__init__()
if not isinstance(encoders, list):
encoders = [encoders]
self.encoders = torch.nn.ModuleList(encoders)
def __call__(self, values, dtype):
emb = []
for encoder, value in zip(self.encoders, values):
if value is not None:
value = value.unsqueeze(0)
emb.append(encoder(value, dtype))
emb = torch.concat(emb, dim=0)
return emb
class SingleValueEncoder(torch.nn.Module):
def __init__(self, dim_in=256, dim_out=4096, prefer_len=32, computation_device=None):
super().__init__()
self.prefer_len = prefer_len
self.prefer_proj = TemporalTimesteps(num_channels=dim_in, flip_sin_to_cos=True, downscale_freq_shift=0, computation_device=computation_device)
self.prefer_value_embedder = torch.nn.Sequential(
torch.nn.Linear(dim_in, dim_out), torch.nn.SiLU(), torch.nn.Linear(dim_out, dim_out)
)
self.positional_embedding = torch.nn.Parameter(
torch.randn(self.prefer_len, dim_out)
)
def forward(self, value, dtype):
value = value * 1000
emb = self.prefer_proj(value).to(dtype)
emb = self.prefer_value_embedder(emb).squeeze(0)
base_embeddings = emb.expand(self.prefer_len, -1)
positional_embedding = self.positional_embedding.to(dtype=base_embeddings.dtype, device=base_embeddings.device)
learned_embeddings = base_embeddings + positional_embedding
return learned_embeddings
@staticmethod
def state_dict_converter():
return SingleValueEncoderStateDictConverter()
class SingleValueEncoderStateDictConverter:
def __init__(self):
pass
def from_diffusers(self, state_dict):
return state_dict
def from_civitai(self, state_dict):
return state_dict

View File

@@ -1,146 +0,0 @@
import torch, math
def get_timestep_embedding(
timesteps: torch.Tensor,
embedding_dim: int,
flip_sin_to_cos: bool = False,
downscale_freq_shift: float = 1,
scale: float = 1,
max_period: int = 10000,
computation_device = None,
align_dtype_to_timestep = False,
):
assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"
half_dim = embedding_dim // 2
exponent = -math.log(max_period) * torch.arange(
start=0, end=half_dim, dtype=torch.float32, device=timesteps.device if computation_device is None else computation_device
)
exponent = exponent / (half_dim - downscale_freq_shift)
emb = torch.exp(exponent)
if align_dtype_to_timestep:
emb = emb.to(timesteps.dtype)
emb = timesteps[:, None].float() * emb[None, :]
# scale embeddings
emb = scale * emb
# concat sine and cosine embeddings
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)
# flip sine and cosine embeddings
if flip_sin_to_cos:
emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)
# zero pad
if embedding_dim % 2 == 1:
emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
return emb
class TemporalTimesteps(torch.nn.Module):
def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float, computation_device = None, scale=1, align_dtype_to_timestep=False):
super().__init__()
self.num_channels = num_channels
self.flip_sin_to_cos = flip_sin_to_cos
self.downscale_freq_shift = downscale_freq_shift
self.computation_device = computation_device
self.scale = scale
self.align_dtype_to_timestep = align_dtype_to_timestep
def forward(self, timesteps):
t_emb = get_timestep_embedding(
timesteps,
self.num_channels,
flip_sin_to_cos=self.flip_sin_to_cos,
downscale_freq_shift=self.downscale_freq_shift,
computation_device=self.computation_device,
scale=self.scale,
align_dtype_to_timestep=self.align_dtype_to_timestep,
)
return t_emb
class DiffusersCompatibleTimestepProj(torch.nn.Module):
def __init__(self, dim_in, dim_out):
super().__init__()
self.linear_1 = torch.nn.Linear(dim_in, dim_out)
self.act = torch.nn.SiLU()
self.linear_2 = torch.nn.Linear(dim_out, dim_out)
def forward(self, x):
x = self.linear_1(x)
x = self.act(x)
x = self.linear_2(x)
return x
class TimestepEmbeddings(torch.nn.Module):
def __init__(self, dim_in, dim_out, computation_device=None, diffusers_compatible_format=False, scale=1, align_dtype_to_timestep=False, use_additional_t_cond=False):
super().__init__()
self.time_proj = TemporalTimesteps(num_channels=dim_in, flip_sin_to_cos=True, downscale_freq_shift=0, computation_device=computation_device, scale=scale, align_dtype_to_timestep=align_dtype_to_timestep)
if diffusers_compatible_format:
self.timestep_embedder = DiffusersCompatibleTimestepProj(dim_in, dim_out)
else:
self.timestep_embedder = torch.nn.Sequential(
torch.nn.Linear(dim_in, dim_out), torch.nn.SiLU(), torch.nn.Linear(dim_out, dim_out)
)
self.use_additional_t_cond = use_additional_t_cond
if use_additional_t_cond:
self.addition_t_embedding = torch.nn.Embedding(2, dim_out)
def forward(self, timestep, dtype, addition_t_cond=None):
time_emb = self.time_proj(timestep).to(dtype)
time_emb = self.timestep_embedder(time_emb)
if addition_t_cond is not None:
addition_t_emb = self.addition_t_embedding(addition_t_cond)
addition_t_emb = addition_t_emb.to(dtype=dtype)
time_emb = time_emb + addition_t_emb
return time_emb
class RMSNorm(torch.nn.Module):
def __init__(self, dim, eps, elementwise_affine=True):
super().__init__()
self.eps = eps
if elementwise_affine:
self.weight = torch.nn.Parameter(torch.ones((dim,)))
else:
self.weight = None
def forward(self, hidden_states):
input_dtype = hidden_states.dtype
variance = hidden_states.to(torch.float32).square().mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
hidden_states = hidden_states.to(input_dtype)
if self.weight is not None:
hidden_states = hidden_states * self.weight
return hidden_states
class AdaLayerNorm(torch.nn.Module):
def __init__(self, dim, single=False, dual=False):
super().__init__()
self.single = single
self.dual = dual
self.linear = torch.nn.Linear(dim, dim * [[6, 2][single], 9][dual])
self.norm = torch.nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
def forward(self, x, emb):
emb = self.linear(torch.nn.functional.silu(emb))
if self.single:
scale, shift = emb.unsqueeze(1).chunk(2, dim=2)
x = self.norm(x) * (1 + scale) + shift
return x
elif self.dual:
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp, shift_msa2, scale_msa2, gate_msa2 = emb.unsqueeze(1).chunk(9, dim=2)
norm_x = self.norm(x)
x = norm_x * (1 + scale_msa) + shift_msa
norm_x2 = norm_x * (1 + scale_msa2) + shift_msa2
return x, gate_msa, shift_mlp, scale_mlp, gate_mlp, norm_x2, gate_msa2
else:
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.unsqueeze(1).chunk(6, dim=2)
x = self.norm(x) * (1 + scale_msa) + shift_msa
return x, gate_msa, shift_mlp, scale_mlp, gate_mlp

View File

@@ -0,0 +1,451 @@
from .attention import Attention
from einops import repeat, rearrange
import math
import torch
class HunyuanDiTRotaryEmbedding(torch.nn.Module):
def __init__(self, q_norm_shape=88, k_norm_shape=88, rotary_emb_on_k=True):
super().__init__()
self.q_norm = torch.nn.LayerNorm((q_norm_shape,), elementwise_affine=True, eps=1e-06)
self.k_norm = torch.nn.LayerNorm((k_norm_shape,), elementwise_affine=True, eps=1e-06)
self.rotary_emb_on_k = rotary_emb_on_k
self.k_cache, self.v_cache = [], []
def reshape_for_broadcast(self, freqs_cis, x):
ndim = x.ndim
shape = [d if i == ndim - 2 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
return freqs_cis[0].view(*shape), freqs_cis[1].view(*shape)
def rotate_half(self, x):
x_real, x_imag = x.float().reshape(*x.shape[:-1], -1, 2).unbind(-1)
return torch.stack([-x_imag, x_real], dim=-1).flatten(3)
def apply_rotary_emb(self, xq, xk, freqs_cis):
xk_out = None
cos, sin = self.reshape_for_broadcast(freqs_cis, xq)
cos, sin = cos.to(xq.device), sin.to(xq.device)
xq_out = (xq.float() * cos + self.rotate_half(xq.float()) * sin).type_as(xq)
if xk is not None:
xk_out = (xk.float() * cos + self.rotate_half(xk.float()) * sin).type_as(xk)
return xq_out, xk_out
def forward(self, q, k, v, freqs_cis_img, to_cache=False):
# norm
q = self.q_norm(q)
k = self.k_norm(k)
# RoPE
if self.rotary_emb_on_k:
q, k = self.apply_rotary_emb(q, k, freqs_cis_img)
else:
q, _ = self.apply_rotary_emb(q, None, freqs_cis_img)
if to_cache:
self.k_cache.append(k)
self.v_cache.append(v)
elif len(self.k_cache) > 0 and len(self.v_cache) > 0:
k = torch.concat([k] + self.k_cache, dim=2)
v = torch.concat([v] + self.v_cache, dim=2)
self.k_cache, self.v_cache = [], []
return q, k, v
class FP32_Layernorm(torch.nn.LayerNorm):
def forward(self, inputs):
origin_dtype = inputs.dtype
return torch.nn.functional.layer_norm(inputs.float(), self.normalized_shape, self.weight.float(), self.bias.float(), self.eps).to(origin_dtype)
class FP32_SiLU(torch.nn.SiLU):
def forward(self, inputs):
origin_dtype = inputs.dtype
return torch.nn.functional.silu(inputs.float(), inplace=False).to(origin_dtype)
class HunyuanDiTFinalLayer(torch.nn.Module):
def __init__(self, final_hidden_size=1408, condition_dim=1408, patch_size=2, out_channels=8):
super().__init__()
self.norm_final = torch.nn.LayerNorm(final_hidden_size, elementwise_affine=False, eps=1e-6)
self.linear = torch.nn.Linear(final_hidden_size, patch_size * patch_size * out_channels, bias=True)
self.adaLN_modulation = torch.nn.Sequential(
FP32_SiLU(),
torch.nn.Linear(condition_dim, 2 * final_hidden_size, bias=True)
)
def modulate(self, x, shift, scale):
return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
def forward(self, hidden_states, condition_emb):
shift, scale = self.adaLN_modulation(condition_emb).chunk(2, dim=1)
hidden_states = self.modulate(self.norm_final(hidden_states), shift, scale)
hidden_states = self.linear(hidden_states)
return hidden_states
class HunyuanDiTBlock(torch.nn.Module):
def __init__(
self,
hidden_dim=1408,
condition_dim=1408,
num_heads=16,
mlp_ratio=4.3637,
text_dim=1024,
skip_connection=False
):
super().__init__()
self.norm1 = FP32_Layernorm((hidden_dim,), eps=1e-6, elementwise_affine=True)
self.rota1 = HunyuanDiTRotaryEmbedding(hidden_dim//num_heads, hidden_dim//num_heads)
self.attn1 = Attention(hidden_dim, num_heads, hidden_dim//num_heads, bias_q=True, bias_kv=True, bias_out=True)
self.norm2 = FP32_Layernorm((hidden_dim,), eps=1e-6, elementwise_affine=True)
self.rota2 = HunyuanDiTRotaryEmbedding(hidden_dim//num_heads, hidden_dim//num_heads, rotary_emb_on_k=False)
self.attn2 = Attention(hidden_dim, num_heads, hidden_dim//num_heads, kv_dim=text_dim, bias_q=True, bias_kv=True, bias_out=True)
self.norm3 = FP32_Layernorm((hidden_dim,), eps=1e-6, elementwise_affine=True)
self.modulation = torch.nn.Sequential(FP32_SiLU(), torch.nn.Linear(condition_dim, hidden_dim, bias=True))
self.mlp = torch.nn.Sequential(
torch.nn.Linear(hidden_dim, int(hidden_dim*mlp_ratio), bias=True),
torch.nn.GELU(approximate="tanh"),
torch.nn.Linear(int(hidden_dim*mlp_ratio), hidden_dim, bias=True)
)
if skip_connection:
self.skip_norm = FP32_Layernorm((hidden_dim * 2,), eps=1e-6, elementwise_affine=True)
self.skip_linear = torch.nn.Linear(hidden_dim * 2, hidden_dim, bias=True)
else:
self.skip_norm, self.skip_linear = None, None
def forward(self, hidden_states, condition_emb, text_emb, freq_cis_img, residual=None, to_cache=False):
# Long Skip Connection
if self.skip_norm is not None and self.skip_linear is not None:
hidden_states = torch.cat([hidden_states, residual], dim=-1)
hidden_states = self.skip_norm(hidden_states)
hidden_states = self.skip_linear(hidden_states)
# Self-Attention
shift_msa = self.modulation(condition_emb).unsqueeze(dim=1)
attn_input = self.norm1(hidden_states) + shift_msa
hidden_states = hidden_states + self.attn1(attn_input, qkv_preprocessor=lambda q, k, v: self.rota1(q, k, v, freq_cis_img, to_cache=to_cache))
# Cross-Attention
attn_input = self.norm3(hidden_states)
hidden_states = hidden_states + self.attn2(attn_input, text_emb, qkv_preprocessor=lambda q, k, v: self.rota2(q, k, v, freq_cis_img))
# FFN Layer
mlp_input = self.norm2(hidden_states)
hidden_states = hidden_states + self.mlp(mlp_input)
return hidden_states
class AttentionPool(torch.nn.Module):
def __init__(self, spacial_dim, embed_dim, num_heads, output_dim = None):
super().__init__()
self.positional_embedding = torch.nn.Parameter(torch.randn(spacial_dim + 1, embed_dim) / embed_dim ** 0.5)
self.k_proj = torch.nn.Linear(embed_dim, embed_dim)
self.q_proj = torch.nn.Linear(embed_dim, embed_dim)
self.v_proj = torch.nn.Linear(embed_dim, embed_dim)
self.c_proj = torch.nn.Linear(embed_dim, output_dim or embed_dim)
self.num_heads = num_heads
def forward(self, x):
x = x.permute(1, 0, 2) # NLC -> LNC
x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (L+1)NC
x = x + self.positional_embedding[:, None, :].to(x.dtype) # (L+1)NC
x, _ = torch.nn.functional.multi_head_attention_forward(
query=x[:1], key=x, value=x,
embed_dim_to_check=x.shape[-1],
num_heads=self.num_heads,
q_proj_weight=self.q_proj.weight,
k_proj_weight=self.k_proj.weight,
v_proj_weight=self.v_proj.weight,
in_proj_weight=None,
in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
bias_k=None,
bias_v=None,
add_zero_attn=False,
dropout_p=0,
out_proj_weight=self.c_proj.weight,
out_proj_bias=self.c_proj.bias,
use_separate_proj_weight=True,
training=self.training,
need_weights=False
)
return x.squeeze(0)
class PatchEmbed(torch.nn.Module):
def __init__(
self,
patch_size=(2, 2),
in_chans=4,
embed_dim=1408,
bias=True,
):
super().__init__()
self.proj = torch.nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias)
def forward(self, x):
x = self.proj(x)
x = x.flatten(2).transpose(1, 2) # BCHW -> BNC
return x
def timestep_embedding(t, dim, max_period=10000, repeat_only=False):
# https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
if not repeat_only:
half = dim // 2
freqs = torch.exp(
-math.log(max_period)
* torch.arange(start=0, end=half, dtype=torch.float32)
/ half
).to(device=t.device) # size: [dim/2], 一个指数衰减的曲线
args = t[:, None].float() * freqs[None]
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
if dim % 2:
embedding = torch.cat(
[embedding, torch.zeros_like(embedding[:, :1])], dim=-1
)
else:
embedding = repeat(t, "b -> b d", d=dim)
return embedding
class TimestepEmbedder(torch.nn.Module):
def __init__(self, hidden_size=1408, frequency_embedding_size=256):
super().__init__()
self.mlp = torch.nn.Sequential(
torch.nn.Linear(frequency_embedding_size, hidden_size, bias=True),
torch.nn.SiLU(),
torch.nn.Linear(hidden_size, hidden_size, bias=True),
)
self.frequency_embedding_size = frequency_embedding_size
def forward(self, t):
t_freq = timestep_embedding(t, self.frequency_embedding_size).type(self.mlp[0].weight.dtype)
t_emb = self.mlp(t_freq)
return t_emb
class HunyuanDiT(torch.nn.Module):
def __init__(self, num_layers_down=21, num_layers_up=19, in_channels=4, out_channels=8, hidden_dim=1408, text_dim=1024, t5_dim=2048, text_length=77, t5_length=256):
super().__init__()
# Embedders
self.text_emb_padding = torch.nn.Parameter(torch.randn(text_length + t5_length, text_dim, dtype=torch.float32))
self.t5_embedder = torch.nn.Sequential(
torch.nn.Linear(t5_dim, t5_dim * 4, bias=True),
FP32_SiLU(),
torch.nn.Linear(t5_dim * 4, text_dim, bias=True),
)
self.t5_pooler = AttentionPool(t5_length, t5_dim, num_heads=8, output_dim=1024)
self.style_embedder = torch.nn.Parameter(torch.randn(hidden_dim))
self.patch_embedder = PatchEmbed(in_chans=in_channels)
self.timestep_embedder = TimestepEmbedder()
self.extra_embedder = torch.nn.Sequential(
torch.nn.Linear(256 * 6 + 1024 + hidden_dim, hidden_dim * 4),
FP32_SiLU(),
torch.nn.Linear(hidden_dim * 4, hidden_dim),
)
# Transformer blocks
self.num_layers_down = num_layers_down
self.num_layers_up = num_layers_up
self.blocks = torch.nn.ModuleList(
[HunyuanDiTBlock(skip_connection=False) for _ in range(num_layers_down)] + \
[HunyuanDiTBlock(skip_connection=True) for _ in range(num_layers_up)]
)
# Output layers
self.final_layer = HunyuanDiTFinalLayer()
self.out_channels = out_channels
def prepare_text_emb(self, text_emb, text_emb_t5, text_emb_mask, text_emb_mask_t5):
text_emb_mask = text_emb_mask.bool()
text_emb_mask_t5 = text_emb_mask_t5.bool()
text_emb_t5 = self.t5_embedder(text_emb_t5)
text_emb = torch.cat([text_emb, text_emb_t5], dim=1)
text_emb_mask = torch.cat([text_emb_mask, text_emb_mask_t5], dim=-1)
text_emb = torch.where(text_emb_mask.unsqueeze(2), text_emb, self.text_emb_padding.to(text_emb))
return text_emb
def prepare_extra_emb(self, text_emb_t5, timestep, size_emb, dtype, batch_size):
# Text embedding
pooled_text_emb_t5 = self.t5_pooler(text_emb_t5)
# Timestep embedding
timestep_emb = self.timestep_embedder(timestep)
# Size embedding
size_emb = timestep_embedding(size_emb.view(-1), 256).to(dtype)
size_emb = size_emb.view(-1, 6 * 256)
# Style embedding
style_emb = repeat(self.style_embedder, "D -> B D", B=batch_size)
# Concatenate all extra vectors
extra_emb = torch.cat([pooled_text_emb_t5, size_emb, style_emb], dim=1)
condition_emb = timestep_emb + self.extra_embedder(extra_emb)
return condition_emb
def unpatchify(self, x, h, w):
return rearrange(x, "B (H W) (P Q C) -> B C (H P) (W Q)", H=h, W=w, P=2, Q=2)
def build_mask(self, data, is_bound):
_, _, H, W = data.shape
h = repeat(torch.arange(H), "H -> H W", H=H, W=W)
w = repeat(torch.arange(W), "W -> H W", H=H, W=W)
border_width = (H + W) // 4
pad = torch.ones_like(h) * border_width
mask = torch.stack([
pad if is_bound[0] else h + 1,
pad if is_bound[1] else H - h,
pad if is_bound[2] else w + 1,
pad if is_bound[3] else W - w
]).min(dim=0).values
mask = mask.clip(1, border_width)
mask = (mask / border_width).to(dtype=data.dtype, device=data.device)
mask = rearrange(mask, "H W -> 1 H W")
return mask
def tiled_block_forward(self, block, hidden_states, condition_emb, text_emb, freq_cis_img, residual, torch_dtype, data_device, computation_device, tile_size, tile_stride):
B, C, H, W = hidden_states.shape
weight = torch.zeros((1, 1, H, W), dtype=torch_dtype, device=data_device)
values = torch.zeros((B, C, H, W), dtype=torch_dtype, device=data_device)
# Split tasks
tasks = []
for h in range(0, H, tile_stride):
for w in range(0, W, tile_stride):
if (h-tile_stride >= 0 and h-tile_stride+tile_size >= H) or (w-tile_stride >= 0 and w-tile_stride+tile_size >= W):
continue
h_, w_ = h + tile_size, w + tile_size
if h_ > H: h, h_ = H - tile_size, H
if w_ > W: w, w_ = W - tile_size, W
tasks.append((h, h_, w, w_))
# Run
for hl, hr, wl, wr in tasks:
hidden_states_batch = hidden_states[:, :, hl:hr, wl:wr].to(computation_device)
hidden_states_batch = rearrange(hidden_states_batch, "B C H W -> B (H W) C")
if residual is not None:
residual_batch = residual[:, :, hl:hr, wl:wr].to(computation_device)
residual_batch = rearrange(residual_batch, "B C H W -> B (H W) C")
else:
residual_batch = None
# Forward
hidden_states_batch = block(hidden_states_batch, condition_emb, text_emb, freq_cis_img, residual_batch).to(data_device)
hidden_states_batch = rearrange(hidden_states_batch, "B (H W) C -> B C H W", H=hr-hl)
mask = self.build_mask(hidden_states_batch, is_bound=(hl==0, hr>=H, wl==0, wr>=W))
values[:, :, hl:hr, wl:wr] += hidden_states_batch * mask
weight[:, :, hl:hr, wl:wr] += mask
values /= weight
return values
def forward(
self, hidden_states, text_emb, text_emb_t5, text_emb_mask, text_emb_mask_t5, timestep, size_emb, freq_cis_img,
tiled=False, tile_size=64, tile_stride=32,
to_cache=False,
use_gradient_checkpointing=False,
):
# Embeddings
text_emb = self.prepare_text_emb(text_emb, text_emb_t5, text_emb_mask, text_emb_mask_t5)
condition_emb = self.prepare_extra_emb(text_emb_t5, timestep, size_emb, hidden_states.dtype, hidden_states.shape[0])
# Input
height, width = hidden_states.shape[-2], hidden_states.shape[-1]
hidden_states = self.patch_embedder(hidden_states)
# Blocks
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs)
return custom_forward
if tiled:
hidden_states = rearrange(hidden_states, "B (H W) C -> B C H W", H=height//2)
residuals = []
for block_id, block in enumerate(self.blocks):
residual = residuals.pop() if block_id >= self.num_layers_down else None
hidden_states = self.tiled_block_forward(
block, hidden_states, condition_emb, text_emb, freq_cis_img, residual,
torch_dtype=hidden_states.dtype, data_device=hidden_states.device, computation_device=hidden_states.device,
tile_size=tile_size, tile_stride=tile_stride
)
if block_id < self.num_layers_down - 2:
residuals.append(hidden_states)
hidden_states = rearrange(hidden_states, "B C H W -> B (H W) C")
else:
residuals = []
for block_id, block in enumerate(self.blocks):
residual = residuals.pop() if block_id >= self.num_layers_down else None
if self.training and use_gradient_checkpointing:
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
hidden_states, condition_emb, text_emb, freq_cis_img, residual,
use_reentrant=False,
)
else:
hidden_states = block(hidden_states, condition_emb, text_emb, freq_cis_img, residual, to_cache=to_cache)
if block_id < self.num_layers_down - 2:
residuals.append(hidden_states)
# Output
hidden_states = self.final_layer(hidden_states, condition_emb)
hidden_states = self.unpatchify(hidden_states, height//2, width//2)
hidden_states, _ = hidden_states.chunk(2, dim=1)
return hidden_states
@staticmethod
def state_dict_converter():
return HunyuanDiTStateDictConverter()
class HunyuanDiTStateDictConverter():
def __init__(self):
pass
def from_diffusers(self, state_dict):
state_dict_ = {}
for name, param in state_dict.items():
name_ = name
name_ = name_.replace(".default_modulation.", ".modulation.")
name_ = name_.replace(".mlp.fc1.", ".mlp.0.")
name_ = name_.replace(".mlp.fc2.", ".mlp.2.")
name_ = name_.replace(".attn1.q_norm.", ".rota1.q_norm.")
name_ = name_.replace(".attn2.q_norm.", ".rota2.q_norm.")
name_ = name_.replace(".attn1.k_norm.", ".rota1.k_norm.")
name_ = name_.replace(".attn2.k_norm.", ".rota2.k_norm.")
name_ = name_.replace(".q_proj.", ".to_q.")
name_ = name_.replace(".out_proj.", ".to_out.")
name_ = name_.replace("text_embedding_padding", "text_emb_padding")
name_ = name_.replace("mlp_t5.0.", "t5_embedder.0.")
name_ = name_.replace("mlp_t5.2.", "t5_embedder.2.")
name_ = name_.replace("pooler.", "t5_pooler.")
name_ = name_.replace("x_embedder.", "patch_embedder.")
name_ = name_.replace("t_embedder.", "timestep_embedder.")
name_ = name_.replace("t5_pooler.to_q.", "t5_pooler.q_proj.")
name_ = name_.replace("style_embedder.weight", "style_embedder")
if ".kv_proj." in name_:
param_k = param[:param.shape[0]//2]
param_v = param[param.shape[0]//2:]
state_dict_[name_.replace(".kv_proj.", ".to_k.")] = param_k
state_dict_[name_.replace(".kv_proj.", ".to_v.")] = param_v
elif ".Wqkv." in name_:
param_q = param[:param.shape[0]//3]
param_k = param[param.shape[0]//3:param.shape[0]//3*2]
param_v = param[param.shape[0]//3*2:]
state_dict_[name_.replace(".Wqkv.", ".to_q.")] = param_q
state_dict_[name_.replace(".Wqkv.", ".to_k.")] = param_k
state_dict_[name_.replace(".Wqkv.", ".to_v.")] = param_v
elif "style_embedder" in name_:
state_dict_[name_] = param.squeeze()
else:
state_dict_[name_] = param
return state_dict_
def from_civitai(self, state_dict):
return self.from_diffusers(state_dict)

View File

@@ -0,0 +1,163 @@
from transformers import BertModel, BertConfig, T5EncoderModel, T5Config
import torch
class HunyuanDiTCLIPTextEncoder(BertModel):
def __init__(self):
config = BertConfig(
_name_or_path = "",
architectures = ["BertModel"],
attention_probs_dropout_prob = 0.1,
bos_token_id = 0,
classifier_dropout = None,
directionality = "bidi",
eos_token_id = 2,
hidden_act = "gelu",
hidden_dropout_prob = 0.1,
hidden_size = 1024,
initializer_range = 0.02,
intermediate_size = 4096,
layer_norm_eps = 1e-12,
max_position_embeddings = 512,
model_type = "bert",
num_attention_heads = 16,
num_hidden_layers = 24,
output_past = True,
pad_token_id = 0,
pooler_fc_size = 768,
pooler_num_attention_heads = 12,
pooler_num_fc_layers = 3,
pooler_size_per_head = 128,
pooler_type = "first_token_transform",
position_embedding_type = "absolute",
torch_dtype = "float32",
transformers_version = "4.37.2",
type_vocab_size = 2,
use_cache = True,
vocab_size = 47020
)
super().__init__(config, add_pooling_layer=False)
self.eval()
def forward(self, input_ids, attention_mask, clip_skip=1):
input_shape = input_ids.size()
batch_size, seq_length = input_shape
device = input_ids.device
past_key_values_length = 0
if attention_mask is None:
attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)
embedding_output = self.embeddings(
input_ids=input_ids,
position_ids=None,
token_type_ids=None,
inputs_embeds=None,
past_key_values_length=0,
)
encoder_outputs = self.encoder(
embedding_output,
attention_mask=extended_attention_mask,
head_mask=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
past_key_values=None,
use_cache=False,
output_attentions=False,
output_hidden_states=True,
return_dict=True,
)
all_hidden_states = encoder_outputs.hidden_states
prompt_emb = all_hidden_states[-clip_skip]
if clip_skip > 1:
mean, std = all_hidden_states[-1].mean(), all_hidden_states[-1].std()
prompt_emb = (prompt_emb - prompt_emb.mean()) / prompt_emb.std() * std + mean
return prompt_emb
@staticmethod
def state_dict_converter():
return HunyuanDiTCLIPTextEncoderStateDictConverter()
class HunyuanDiTT5TextEncoder(T5EncoderModel):
def __init__(self):
config = T5Config(
_name_or_path = "../HunyuanDiT/t2i/mt5",
architectures = ["MT5ForConditionalGeneration"],
classifier_dropout = 0.0,
d_ff = 5120,
d_kv = 64,
d_model = 2048,
decoder_start_token_id = 0,
dense_act_fn = "gelu_new",
dropout_rate = 0.1,
eos_token_id = 1,
feed_forward_proj = "gated-gelu",
initializer_factor = 1.0,
is_encoder_decoder = True,
is_gated_act = True,
layer_norm_epsilon = 1e-06,
model_type = "t5",
num_decoder_layers = 24,
num_heads = 32,
num_layers = 24,
output_past = True,
pad_token_id = 0,
relative_attention_max_distance = 128,
relative_attention_num_buckets = 32,
tie_word_embeddings = False,
tokenizer_class = "T5Tokenizer",
transformers_version = "4.37.2",
use_cache = True,
vocab_size = 250112
)
super().__init__(config)
self.eval()
def forward(self, input_ids, attention_mask, clip_skip=1):
outputs = super().forward(
input_ids=input_ids,
attention_mask=attention_mask,
output_hidden_states=True,
)
prompt_emb = outputs.hidden_states[-clip_skip]
if clip_skip > 1:
mean, std = outputs.hidden_states[-1].mean(), outputs.hidden_states[-1].std()
prompt_emb = (prompt_emb - prompt_emb.mean()) / prompt_emb.std() * std + mean
return prompt_emb
@staticmethod
def state_dict_converter():
return HunyuanDiTT5TextEncoderStateDictConverter()
class HunyuanDiTCLIPTextEncoderStateDictConverter():
def __init__(self):
pass
def from_diffusers(self, state_dict):
state_dict_ = {name[5:]: param for name, param in state_dict.items() if name.startswith("bert.")}
return state_dict_
def from_civitai(self, state_dict):
return self.from_diffusers(state_dict)
class HunyuanDiTT5TextEncoderStateDictConverter():
def __init__(self):
pass
def from_diffusers(self, state_dict):
state_dict_ = {name: param for name, param in state_dict.items() if name.startswith("encoder.")}
state_dict_["shared.weight"] = state_dict["shared.weight"]
return state_dict_
def from_civitai(self, state_dict):
return self.from_diffusers(state_dict)

View File

@@ -0,0 +1,885 @@
import torch
from .sd3_dit import TimestepEmbeddings, RMSNorm
from .utils import init_weights_on_device
from einops import rearrange, repeat
from tqdm import tqdm
from typing import Union, Tuple, List
def HunyuanVideoRope(latents):
def _to_tuple(x, dim=2):
if isinstance(x, int):
return (x,) * dim
elif len(x) == dim:
return x
else:
raise ValueError(f"Expected length {dim} or int, but got {x}")
def get_meshgrid_nd(start, *args, dim=2):
"""
Get n-D meshgrid with start, stop and num.
Args:
start (int or tuple): If len(args) == 0, start is num; If len(args) == 1, start is start, args[0] is stop,
step is 1; If len(args) == 2, start is start, args[0] is stop, args[1] is num. For n-dim, start/stop/num
should be int or n-tuple. If n-tuple is provided, the meshgrid will be stacked following the dim order in
n-tuples.
*args: See above.
dim (int): Dimension of the meshgrid. Defaults to 2.
Returns:
grid (np.ndarray): [dim, ...]
"""
if len(args) == 0:
# start is grid_size
num = _to_tuple(start, dim=dim)
start = (0,) * dim
stop = num
elif len(args) == 1:
# start is start, args[0] is stop, step is 1
start = _to_tuple(start, dim=dim)
stop = _to_tuple(args[0], dim=dim)
num = [stop[i] - start[i] for i in range(dim)]
elif len(args) == 2:
# start is start, args[0] is stop, args[1] is num
start = _to_tuple(start, dim=dim) # Left-Top eg: 12,0
stop = _to_tuple(args[0], dim=dim) # Right-Bottom eg: 20,32
num = _to_tuple(args[1], dim=dim) # Target Size eg: 32,124
else:
raise ValueError(f"len(args) should be 0, 1 or 2, but got {len(args)}")
# PyTorch implement of np.linspace(start[i], stop[i], num[i], endpoint=False)
axis_grid = []
for i in range(dim):
a, b, n = start[i], stop[i], num[i]
g = torch.linspace(a, b, n + 1, dtype=torch.float32)[:n]
axis_grid.append(g)
grid = torch.meshgrid(*axis_grid, indexing="ij") # dim x [W, H, D]
grid = torch.stack(grid, dim=0) # [dim, W, H, D]
return grid
def get_1d_rotary_pos_embed(
dim: int,
pos: Union[torch.FloatTensor, int],
theta: float = 10000.0,
use_real: bool = False,
theta_rescale_factor: float = 1.0,
interpolation_factor: float = 1.0,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
"""
Precompute the frequency tensor for complex exponential (cis) with given dimensions.
(Note: `cis` means `cos + i * sin`, where i is the imaginary unit.)
This function calculates a frequency tensor with complex exponential using the given dimension 'dim'
and the end index 'end'. The 'theta' parameter scales the frequencies.
The returned tensor contains complex values in complex64 data type.
Args:
dim (int): Dimension of the frequency tensor.
pos (int or torch.FloatTensor): Position indices for the frequency tensor. [S] or scalar
theta (float, optional): Scaling factor for frequency computation. Defaults to 10000.0.
use_real (bool, optional): If True, return real part and imaginary part separately.
Otherwise, return complex numbers.
theta_rescale_factor (float, optional): Rescale factor for theta. Defaults to 1.0.
Returns:
freqs_cis: Precomputed frequency tensor with complex exponential. [S, D/2]
freqs_cos, freqs_sin: Precomputed frequency tensor with real and imaginary parts separately. [S, D]
"""
if isinstance(pos, int):
pos = torch.arange(pos).float()
# proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning
# has some connection to NTK literature
if theta_rescale_factor != 1.0:
theta *= theta_rescale_factor ** (dim / (dim - 2))
freqs = 1.0 / (
theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)
) # [D/2]
# assert interpolation_factor == 1.0, f"interpolation_factor: {interpolation_factor}"
freqs = torch.outer(pos * interpolation_factor, freqs) # [S, D/2]
if use_real:
freqs_cos = freqs.cos().repeat_interleave(2, dim=1) # [S, D]
freqs_sin = freqs.sin().repeat_interleave(2, dim=1) # [S, D]
return freqs_cos, freqs_sin
else:
freqs_cis = torch.polar(
torch.ones_like(freqs), freqs
) # complex64 # [S, D/2]
return freqs_cis
def get_nd_rotary_pos_embed(
rope_dim_list,
start,
*args,
theta=10000.0,
use_real=False,
theta_rescale_factor: Union[float, List[float]] = 1.0,
interpolation_factor: Union[float, List[float]] = 1.0,
):
"""
This is a n-d version of precompute_freqs_cis, which is a RoPE for tokens with n-d structure.
Args:
rope_dim_list (list of int): Dimension of each rope. len(rope_dim_list) should equal to n.
sum(rope_dim_list) should equal to head_dim of attention layer.
start (int | tuple of int | list of int): If len(args) == 0, start is num; If len(args) == 1, start is start,
args[0] is stop, step is 1; If len(args) == 2, start is start, args[0] is stop, args[1] is num.
*args: See above.
theta (float): Scaling factor for frequency computation. Defaults to 10000.0.
use_real (bool): If True, return real part and imaginary part separately. Otherwise, return complex numbers.
Some libraries such as TensorRT does not support complex64 data type. So it is useful to provide a real
part and an imaginary part separately.
theta_rescale_factor (float): Rescale factor for theta. Defaults to 1.0.
Returns:
pos_embed (torch.Tensor): [HW, D/2]
"""
grid = get_meshgrid_nd(
start, *args, dim=len(rope_dim_list)
) # [3, W, H, D] / [2, W, H]
if isinstance(theta_rescale_factor, int) or isinstance(theta_rescale_factor, float):
theta_rescale_factor = [theta_rescale_factor] * len(rope_dim_list)
elif isinstance(theta_rescale_factor, list) and len(theta_rescale_factor) == 1:
theta_rescale_factor = [theta_rescale_factor[0]] * len(rope_dim_list)
assert len(theta_rescale_factor) == len(
rope_dim_list
), "len(theta_rescale_factor) should equal to len(rope_dim_list)"
if isinstance(interpolation_factor, int) or isinstance(interpolation_factor, float):
interpolation_factor = [interpolation_factor] * len(rope_dim_list)
elif isinstance(interpolation_factor, list) and len(interpolation_factor) == 1:
interpolation_factor = [interpolation_factor[0]] * len(rope_dim_list)
assert len(interpolation_factor) == len(
rope_dim_list
), "len(interpolation_factor) should equal to len(rope_dim_list)"
# use 1/ndim of dimensions to encode grid_axis
embs = []
for i in range(len(rope_dim_list)):
emb = get_1d_rotary_pos_embed(
rope_dim_list[i],
grid[i].reshape(-1),
theta,
use_real=use_real,
theta_rescale_factor=theta_rescale_factor[i],
interpolation_factor=interpolation_factor[i],
) # 2 x [WHD, rope_dim_list[i]]
embs.append(emb)
if use_real:
cos = torch.cat([emb[0] for emb in embs], dim=1) # (WHD, D/2)
sin = torch.cat([emb[1] for emb in embs], dim=1) # (WHD, D/2)
return cos, sin
else:
emb = torch.cat(embs, dim=1) # (WHD, D/2)
return emb
freqs_cos, freqs_sin = get_nd_rotary_pos_embed(
[16, 56, 56],
[latents.shape[2], latents.shape[3] // 2, latents.shape[4] // 2],
theta=256,
use_real=True,
theta_rescale_factor=1,
)
return freqs_cos, freqs_sin
class PatchEmbed(torch.nn.Module):
def __init__(self, patch_size=(1, 2, 2), in_channels=16, embed_dim=3072):
super().__init__()
self.proj = torch.nn.Conv3d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size)
def forward(self, x):
x = self.proj(x)
x = x.flatten(2).transpose(1, 2)
return x
class IndividualTokenRefinerBlock(torch.nn.Module):
def __init__(self, hidden_size=3072, num_heads=24):
super().__init__()
self.num_heads = num_heads
self.norm1 = torch.nn.LayerNorm(hidden_size, elementwise_affine=True, eps=1e-6)
self.self_attn_qkv = torch.nn.Linear(hidden_size, hidden_size * 3)
self.self_attn_proj = torch.nn.Linear(hidden_size, hidden_size)
self.norm2 = torch.nn.LayerNorm(hidden_size, elementwise_affine=True, eps=1e-6)
self.mlp = torch.nn.Sequential(
torch.nn.Linear(hidden_size, hidden_size * 4),
torch.nn.SiLU(),
torch.nn.Linear(hidden_size * 4, hidden_size)
)
self.adaLN_modulation = torch.nn.Sequential(
torch.nn.SiLU(),
torch.nn.Linear(hidden_size, hidden_size * 2, device="cuda", dtype=torch.bfloat16),
)
def forward(self, x, c, attn_mask=None):
gate_msa, gate_mlp = self.adaLN_modulation(c).chunk(2, dim=1)
norm_x = self.norm1(x)
qkv = self.self_attn_qkv(norm_x)
q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
attn = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
attn = rearrange(attn, "B H L D -> B L (H D)")
x = x + self.self_attn_proj(attn) * gate_msa.unsqueeze(1)
x = x + self.mlp(self.norm2(x)) * gate_mlp.unsqueeze(1)
return x
class SingleTokenRefiner(torch.nn.Module):
def __init__(self, in_channels=4096, hidden_size=3072, depth=2):
super().__init__()
self.input_embedder = torch.nn.Linear(in_channels, hidden_size, bias=True)
self.t_embedder = TimestepEmbeddings(256, hidden_size, computation_device="cpu")
self.c_embedder = torch.nn.Sequential(
torch.nn.Linear(in_channels, hidden_size),
torch.nn.SiLU(),
torch.nn.Linear(hidden_size, hidden_size)
)
self.blocks = torch.nn.ModuleList([IndividualTokenRefinerBlock(hidden_size=hidden_size) for _ in range(depth)])
def forward(self, x, t, mask=None):
timestep_aware_representations = self.t_embedder(t, dtype=torch.float32)
mask_float = mask.float().unsqueeze(-1)
context_aware_representations = (x * mask_float).sum(dim=1) / mask_float.sum(dim=1)
context_aware_representations = self.c_embedder(context_aware_representations)
c = timestep_aware_representations + context_aware_representations
x = self.input_embedder(x)
mask = mask.to(device=x.device, dtype=torch.bool)
mask = repeat(mask, "B L -> B 1 D L", D=mask.shape[-1])
mask = mask & mask.transpose(2, 3)
mask[:, :, :, 0] = True
for block in self.blocks:
x = block(x, c, mask)
return x
class ModulateDiT(torch.nn.Module):
def __init__(self, hidden_size, factor=6):
super().__init__()
self.act = torch.nn.SiLU()
self.linear = torch.nn.Linear(hidden_size, factor * hidden_size)
def forward(self, x):
return self.linear(self.act(x))
def modulate(x, shift=None, scale=None):
if scale is None and shift is None:
return x
elif shift is None:
return x * (1 + scale.unsqueeze(1))
elif scale is None:
return x + shift.unsqueeze(1)
else:
return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
def reshape_for_broadcast(
freqs_cis,
x: torch.Tensor,
head_first=False,
):
ndim = x.ndim
assert 0 <= 1 < ndim
if isinstance(freqs_cis, tuple):
# freqs_cis: (cos, sin) in real space
if head_first:
assert freqs_cis[0].shape == (
x.shape[-2],
x.shape[-1],
), f"freqs_cis shape {freqs_cis[0].shape} does not match x shape {x.shape}"
shape = [
d if i == ndim - 2 or i == ndim - 1 else 1
for i, d in enumerate(x.shape)
]
else:
assert freqs_cis[0].shape == (
x.shape[1],
x.shape[-1],
), f"freqs_cis shape {freqs_cis[0].shape} does not match x shape {x.shape}"
shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
return freqs_cis[0].view(*shape), freqs_cis[1].view(*shape)
else:
# freqs_cis: values in complex space
if head_first:
assert freqs_cis.shape == (
x.shape[-2],
x.shape[-1],
), f"freqs_cis shape {freqs_cis.shape} does not match x shape {x.shape}"
shape = [
d if i == ndim - 2 or i == ndim - 1 else 1
for i, d in enumerate(x.shape)
]
else:
assert freqs_cis.shape == (
x.shape[1],
x.shape[-1],
), f"freqs_cis shape {freqs_cis.shape} does not match x shape {x.shape}"
shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
return freqs_cis.view(*shape)
def rotate_half(x):
x_real, x_imag = (
x.float().reshape(*x.shape[:-1], -1, 2).unbind(-1)
) # [B, S, H, D//2]
return torch.stack([-x_imag, x_real], dim=-1).flatten(3)
def apply_rotary_emb(
xq: torch.Tensor,
xk: torch.Tensor,
freqs_cis,
head_first: bool = False,
):
xk_out = None
if isinstance(freqs_cis, tuple):
cos, sin = reshape_for_broadcast(freqs_cis, xq, head_first) # [S, D]
cos, sin = cos.to(xq.device), sin.to(xq.device)
# real * cos - imag * sin
# imag * cos + real * sin
xq_out = (xq.float() * cos + rotate_half(xq.float()) * sin).type_as(xq)
xk_out = (xk.float() * cos + rotate_half(xk.float()) * sin).type_as(xk)
else:
# view_as_complex will pack [..., D/2, 2](real) to [..., D/2](complex)
xq_ = torch.view_as_complex(
xq.float().reshape(*xq.shape[:-1], -1, 2)
) # [B, S, H, D//2]
freqs_cis = reshape_for_broadcast(freqs_cis, xq_, head_first).to(
xq.device
) # [S, D//2] --> [1, S, 1, D//2]
# (real, imag) * (cos, sin) = (real * cos - imag * sin, imag * cos + real * sin)
# view_as_real will expand [..., D/2](complex) to [..., D/2, 2](real)
xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3).type_as(xq)
xk_ = torch.view_as_complex(
xk.float().reshape(*xk.shape[:-1], -1, 2)
) # [B, S, H, D//2]
xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3).type_as(xk)
return xq_out, xk_out
def attention(q, k, v):
q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
x = torch.nn.functional.scaled_dot_product_attention(q, k, v)
x = x.transpose(1, 2).flatten(2, 3)
return x
class MMDoubleStreamBlockComponent(torch.nn.Module):
def __init__(self, hidden_size=3072, heads_num=24, mlp_width_ratio=4):
super().__init__()
self.heads_num = heads_num
self.mod = ModulateDiT(hidden_size)
self.norm1 = torch.nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.to_qkv = torch.nn.Linear(hidden_size, hidden_size * 3)
self.norm_q = RMSNorm(dim=hidden_size // heads_num, eps=1e-6)
self.norm_k = RMSNorm(dim=hidden_size // heads_num, eps=1e-6)
self.to_out = torch.nn.Linear(hidden_size, hidden_size)
self.norm2 = torch.nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.ff = torch.nn.Sequential(
torch.nn.Linear(hidden_size, hidden_size * mlp_width_ratio),
torch.nn.GELU(approximate="tanh"),
torch.nn.Linear(hidden_size * mlp_width_ratio, hidden_size)
)
def forward(self, hidden_states, conditioning, freqs_cis=None):
mod1_shift, mod1_scale, mod1_gate, mod2_shift, mod2_scale, mod2_gate = self.mod(conditioning).chunk(6, dim=-1)
norm_hidden_states = self.norm1(hidden_states)
norm_hidden_states = modulate(norm_hidden_states, shift=mod1_shift, scale=mod1_scale)
qkv = self.to_qkv(norm_hidden_states)
q, k, v = rearrange(qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num)
q = self.norm_q(q)
k = self.norm_k(k)
if freqs_cis is not None:
q, k = apply_rotary_emb(q, k, freqs_cis, head_first=False)
return (q, k, v), (mod1_gate, mod2_shift, mod2_scale, mod2_gate)
def process_ff(self, hidden_states, attn_output, mod):
mod1_gate, mod2_shift, mod2_scale, mod2_gate = mod
hidden_states = hidden_states + self.to_out(attn_output) * mod1_gate.unsqueeze(1)
hidden_states = hidden_states + self.ff(modulate(self.norm2(hidden_states), shift=mod2_shift, scale=mod2_scale)) * mod2_gate.unsqueeze(1)
return hidden_states
class MMDoubleStreamBlock(torch.nn.Module):
def __init__(self, hidden_size=3072, heads_num=24, mlp_width_ratio=4):
super().__init__()
self.component_a = MMDoubleStreamBlockComponent(hidden_size, heads_num, mlp_width_ratio)
self.component_b = MMDoubleStreamBlockComponent(hidden_size, heads_num, mlp_width_ratio)
def forward(self, hidden_states_a, hidden_states_b, conditioning, freqs_cis):
(q_a, k_a, v_a), mod_a = self.component_a(hidden_states_a, conditioning, freqs_cis)
(q_b, k_b, v_b), mod_b = self.component_b(hidden_states_b, conditioning, freqs_cis=None)
q_a, q_b = torch.concat([q_a, q_b[:, :71]], dim=1), q_b[:, 71:].contiguous()
k_a, k_b = torch.concat([k_a, k_b[:, :71]], dim=1), k_b[:, 71:].contiguous()
v_a, v_b = torch.concat([v_a, v_b[:, :71]], dim=1), v_b[:, 71:].contiguous()
attn_output_a = attention(q_a, k_a, v_a)
attn_output_b = attention(q_b, k_b, v_b)
attn_output_a, attn_output_b = attn_output_a[:, :-71].contiguous(), torch.concat([attn_output_a[:, -71:], attn_output_b], dim=1)
hidden_states_a = self.component_a.process_ff(hidden_states_a, attn_output_a, mod_a)
hidden_states_b = self.component_b.process_ff(hidden_states_b, attn_output_b, mod_b)
return hidden_states_a, hidden_states_b
class MMSingleStreamBlockOriginal(torch.nn.Module):
def __init__(self, hidden_size=3072, heads_num=24, mlp_width_ratio=4):
super().__init__()
self.hidden_size = hidden_size
self.heads_num = heads_num
self.mlp_hidden_dim = hidden_size * mlp_width_ratio
self.linear1 = torch.nn.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim)
self.linear2 = torch.nn.Linear(hidden_size + self.mlp_hidden_dim, hidden_size)
self.q_norm = RMSNorm(dim=hidden_size // heads_num, eps=1e-6)
self.k_norm = RMSNorm(dim=hidden_size // heads_num, eps=1e-6)
self.pre_norm = torch.nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.mlp_act = torch.nn.GELU(approximate="tanh")
self.modulation = ModulateDiT(hidden_size, factor=3)
def forward(self, x, vec, freqs_cis=None, txt_len=256):
mod_shift, mod_scale, mod_gate = self.modulation(vec).chunk(3, dim=-1)
x_mod = modulate(self.pre_norm(x), shift=mod_shift, scale=mod_scale)
qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
q, k, v = rearrange(qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num)
q = self.q_norm(q)
k = self.k_norm(k)
q_a, q_b = q[:, :-txt_len, :, :], q[:, -txt_len:, :, :]
k_a, k_b = k[:, :-txt_len, :, :], k[:, -txt_len:, :, :]
q_a, k_a = apply_rotary_emb(q_a, k_a, freqs_cis, head_first=False)
q = torch.cat((q_a, q_b), dim=1)
k = torch.cat((k_a, k_b), dim=1)
attn_output_a = attention(q[:, :-185].contiguous(), k[:, :-185].contiguous(), v[:, :-185].contiguous())
attn_output_b = attention(q[:, -185:].contiguous(), k[:, -185:].contiguous(), v[:, -185:].contiguous())
attn_output = torch.concat([attn_output_a, attn_output_b], dim=1)
output = self.linear2(torch.cat((attn_output, self.mlp_act(mlp)), 2))
return x + output * mod_gate.unsqueeze(1)
class MMSingleStreamBlock(torch.nn.Module):
def __init__(self, hidden_size=3072, heads_num=24, mlp_width_ratio=4):
super().__init__()
self.heads_num = heads_num
self.mod = ModulateDiT(hidden_size, factor=3)
self.norm = torch.nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.to_qkv = torch.nn.Linear(hidden_size, hidden_size * 3)
self.norm_q = RMSNorm(dim=hidden_size // heads_num, eps=1e-6)
self.norm_k = RMSNorm(dim=hidden_size // heads_num, eps=1e-6)
self.to_out = torch.nn.Linear(hidden_size, hidden_size)
self.ff = torch.nn.Sequential(
torch.nn.Linear(hidden_size, hidden_size * mlp_width_ratio),
torch.nn.GELU(approximate="tanh"),
torch.nn.Linear(hidden_size * mlp_width_ratio, hidden_size, bias=False)
)
def forward(self, hidden_states, conditioning, freqs_cis=None, txt_len=256):
mod_shift, mod_scale, mod_gate = self.mod(conditioning).chunk(3, dim=-1)
norm_hidden_states = self.norm(hidden_states)
norm_hidden_states = modulate(norm_hidden_states, shift=mod_shift, scale=mod_scale)
qkv = self.to_qkv(norm_hidden_states)
q, k, v = rearrange(qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num)
q = self.norm_q(q)
k = self.norm_k(k)
q_a, q_b = q[:, :-txt_len, :, :], q[:, -txt_len:, :, :]
k_a, k_b = k[:, :-txt_len, :, :], k[:, -txt_len:, :, :]
q_a, k_a = apply_rotary_emb(q_a, k_a, freqs_cis, head_first=False)
q_a, q_b = torch.concat([q_a, q_b[:, :71]], dim=1), q_b[:, 71:].contiguous()
k_a, k_b = torch.concat([k_a, k_b[:, :71]], dim=1), k_b[:, 71:].contiguous()
v_a, v_b = v[:, :-185].contiguous(), v[:, -185:].contiguous()
attn_output_a = attention(q_a, k_a, v_a)
attn_output_b = attention(q_b, k_b, v_b)
attn_output = torch.concat([attn_output_a, attn_output_b], dim=1)
hidden_states = hidden_states + self.to_out(attn_output) * mod_gate.unsqueeze(1)
hidden_states = hidden_states + self.ff(norm_hidden_states) * mod_gate.unsqueeze(1)
return hidden_states
class FinalLayer(torch.nn.Module):
def __init__(self, hidden_size=3072, patch_size=(1, 2, 2), out_channels=16):
super().__init__()
self.norm_final = torch.nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.linear = torch.nn.Linear(hidden_size, patch_size[0] * patch_size[1] * patch_size[2] * out_channels)
self.adaLN_modulation = torch.nn.Sequential(torch.nn.SiLU(), torch.nn.Linear(hidden_size, 2 * hidden_size))
def forward(self, x, c):
shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
x = modulate(self.norm_final(x), shift=shift, scale=scale)
x = self.linear(x)
return x
class HunyuanVideoDiT(torch.nn.Module):
def __init__(self, in_channels=16, hidden_size=3072, text_dim=4096, num_double_blocks=20, num_single_blocks=40):
super().__init__()
self.img_in = PatchEmbed(in_channels=in_channels, embed_dim=hidden_size)
self.txt_in = SingleTokenRefiner(in_channels=text_dim, hidden_size=hidden_size)
self.time_in = TimestepEmbeddings(256, hidden_size, computation_device="cpu")
self.vector_in = torch.nn.Sequential(
torch.nn.Linear(768, hidden_size),
torch.nn.SiLU(),
torch.nn.Linear(hidden_size, hidden_size)
)
self.guidance_in = TimestepEmbeddings(256, hidden_size, computation_device="cpu")
self.double_blocks = torch.nn.ModuleList([MMDoubleStreamBlock(hidden_size) for _ in range(num_double_blocks)])
self.single_blocks = torch.nn.ModuleList([MMSingleStreamBlock(hidden_size) for _ in range(num_single_blocks)])
self.final_layer = FinalLayer(hidden_size)
# TODO: remove these parameters
self.dtype = torch.bfloat16
self.patch_size = [1, 2, 2]
self.hidden_size = 3072
self.heads_num = 24
self.rope_dim_list = [16, 56, 56]
def unpatchify(self, x, T, H, W):
x = rearrange(x, "B (T H W) (C pT pH pW) -> B C (T pT) (H pH) (W pW)", H=H, W=W, pT=1, pH=2, pW=2)
return x
def enable_block_wise_offload(self, warm_device="cuda", cold_device="cpu"):
self.warm_device = warm_device
self.cold_device = cold_device
self.to(self.cold_device)
def load_models_to_device(self, loadmodel_names=[], device="cpu"):
for model_name in loadmodel_names:
model = getattr(self, model_name)
if model is not None:
model.to(device)
torch.cuda.empty_cache()
def prepare_freqs(self, latents):
return HunyuanVideoRope(latents)
def forward(
self,
x: torch.Tensor,
t: torch.Tensor,
prompt_emb: torch.Tensor = None,
text_mask: torch.Tensor = None,
pooled_prompt_emb: torch.Tensor = None,
freqs_cos: torch.Tensor = None,
freqs_sin: torch.Tensor = None,
guidance: torch.Tensor = None,
**kwargs
):
B, C, T, H, W = x.shape
vec = self.time_in(t, dtype=torch.float32) + self.vector_in(pooled_prompt_emb) + self.guidance_in(guidance * 1000, dtype=torch.float32)
img = self.img_in(x)
txt = self.txt_in(prompt_emb, t, text_mask)
for block in tqdm(self.double_blocks, desc="Double stream blocks"):
img, txt = block(img, txt, vec, (freqs_cos, freqs_sin))
x = torch.concat([img, txt], dim=1)
for block in tqdm(self.single_blocks, desc="Single stream blocks"):
x = block(x, vec, (freqs_cos, freqs_sin))
img = x[:, :-256]
img = self.final_layer(img, vec)
img = self.unpatchify(img, T=T//1, H=H//2, W=W//2)
return img
def enable_auto_offload(self, dtype=torch.bfloat16, device="cuda"):
def cast_to(weight, dtype=None, device=None, copy=False):
if device is None or weight.device == device:
if not copy:
if dtype is None or weight.dtype == dtype:
return weight
return weight.to(dtype=dtype, copy=copy)
r = torch.empty_like(weight, dtype=dtype, device=device)
r.copy_(weight)
return r
def cast_weight(s, input=None, dtype=None, device=None):
if input is not None:
if dtype is None:
dtype = input.dtype
if device is None:
device = input.device
weight = cast_to(s.weight, dtype, device)
return weight
def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None):
if input is not None:
if dtype is None:
dtype = input.dtype
if bias_dtype is None:
bias_dtype = dtype
if device is None:
device = input.device
weight = cast_to(s.weight, dtype, device)
bias = cast_to(s.bias, bias_dtype, device) if s.bias is not None else None
return weight, bias
class quantized_layer:
class Linear(torch.nn.Linear):
def __init__(self, *args, dtype=torch.bfloat16, device="cuda", **kwargs):
super().__init__(*args, **kwargs)
self.dtype = dtype
self.device = device
def block_forward_(self, x, i, j, dtype, device):
weight_ = cast_to(
self.weight[j * self.block_size: (j + 1) * self.block_size, i * self.block_size: (i + 1) * self.block_size],
dtype=dtype, device=device
)
if self.bias is None or i > 0:
bias_ = None
else:
bias_ = cast_to(self.bias[j * self.block_size: (j + 1) * self.block_size], dtype=dtype, device=device)
x_ = x[..., i * self.block_size: (i + 1) * self.block_size]
y_ = torch.nn.functional.linear(x_, weight_, bias_)
del x_, weight_, bias_
torch.cuda.empty_cache()
return y_
def block_forward(self, x, **kwargs):
# This feature can only reduce 2GB VRAM, so we disable it.
y = torch.zeros(x.shape[:-1] + (self.out_features,), dtype=x.dtype, device=x.device)
for i in range((self.in_features + self.block_size - 1) // self.block_size):
for j in range((self.out_features + self.block_size - 1) // self.block_size):
y[..., j * self.block_size: (j + 1) * self.block_size] += self.block_forward_(x, i, j, dtype=x.dtype, device=x.device)
return y
def forward(self, x, **kwargs):
weight, bias = cast_bias_weight(self, x, dtype=self.dtype, device=self.device)
return torch.nn.functional.linear(x, weight, bias)
class RMSNorm(torch.nn.Module):
def __init__(self, module, dtype=torch.bfloat16, device="cuda"):
super().__init__()
self.module = module
self.dtype = dtype
self.device = device
def forward(self, hidden_states, **kwargs):
input_dtype = hidden_states.dtype
variance = hidden_states.to(torch.float32).square().mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.module.eps)
hidden_states = hidden_states.to(input_dtype)
if self.module.weight is not None:
weight = cast_weight(self.module, hidden_states, dtype=torch.bfloat16, device="cuda")
hidden_states = hidden_states * weight
return hidden_states
class Conv3d(torch.nn.Conv3d):
def __init__(self, *args, dtype=torch.bfloat16, device="cuda", **kwargs):
super().__init__(*args, **kwargs)
self.dtype = dtype
self.device = device
def forward(self, x):
weight, bias = cast_bias_weight(self, x, dtype=self.dtype, device=self.device)
return torch.nn.functional.conv3d(x, weight, bias, self.stride, self.padding, self.dilation, self.groups)
class LayerNorm(torch.nn.LayerNorm):
def __init__(self, *args, dtype=torch.bfloat16, device="cuda", **kwargs):
super().__init__(*args, **kwargs)
self.dtype = dtype
self.device = device
def forward(self, x):
if self.weight is not None and self.bias is not None:
weight, bias = cast_bias_weight(self, x, dtype=self.dtype, device=self.device)
return torch.nn.functional.layer_norm(x, self.normalized_shape, weight, bias, self.eps)
else:
return torch.nn.functional.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
def replace_layer(model, dtype=torch.bfloat16, device="cuda"):
for name, module in model.named_children():
if isinstance(module, torch.nn.Linear):
with init_weights_on_device():
new_layer = quantized_layer.Linear(
module.in_features, module.out_features, bias=module.bias is not None,
dtype=dtype, device=device
)
new_layer.load_state_dict(module.state_dict(), assign=True)
setattr(model, name, new_layer)
elif isinstance(module, torch.nn.Conv3d):
with init_weights_on_device():
new_layer = quantized_layer.Conv3d(
module.in_channels, module.out_channels, kernel_size=module.kernel_size, stride=module.stride,
dtype=dtype, device=device
)
new_layer.load_state_dict(module.state_dict(), assign=True)
setattr(model, name, new_layer)
elif isinstance(module, RMSNorm):
new_layer = quantized_layer.RMSNorm(
module,
dtype=dtype, device=device
)
setattr(model, name, new_layer)
elif isinstance(module, torch.nn.LayerNorm):
with init_weights_on_device():
new_layer = quantized_layer.LayerNorm(
module.normalized_shape, elementwise_affine=module.elementwise_affine, eps=module.eps,
dtype=dtype, device=device
)
new_layer.load_state_dict(module.state_dict(), assign=True)
setattr(model, name, new_layer)
else:
replace_layer(module, dtype=dtype, device=device)
replace_layer(self, dtype=dtype, device=device)
@staticmethod
def state_dict_converter():
return HunyuanVideoDiTStateDictConverter()
class HunyuanVideoDiTStateDictConverter:
def __init__(self):
pass
def from_civitai(self, state_dict):
if "module" in state_dict:
state_dict = state_dict["module"]
direct_dict = {
"img_in.proj": "img_in.proj",
"time_in.mlp.0": "time_in.timestep_embedder.0",
"time_in.mlp.2": "time_in.timestep_embedder.2",
"vector_in.in_layer": "vector_in.0",
"vector_in.out_layer": "vector_in.2",
"guidance_in.mlp.0": "guidance_in.timestep_embedder.0",
"guidance_in.mlp.2": "guidance_in.timestep_embedder.2",
"txt_in.input_embedder": "txt_in.input_embedder",
"txt_in.t_embedder.mlp.0": "txt_in.t_embedder.timestep_embedder.0",
"txt_in.t_embedder.mlp.2": "txt_in.t_embedder.timestep_embedder.2",
"txt_in.c_embedder.linear_1": "txt_in.c_embedder.0",
"txt_in.c_embedder.linear_2": "txt_in.c_embedder.2",
"final_layer.linear": "final_layer.linear",
"final_layer.adaLN_modulation.1": "final_layer.adaLN_modulation.1",
}
txt_suffix_dict = {
"norm1": "norm1",
"self_attn_qkv": "self_attn_qkv",
"self_attn_proj": "self_attn_proj",
"norm2": "norm2",
"mlp.fc1": "mlp.0",
"mlp.fc2": "mlp.2",
"adaLN_modulation.1": "adaLN_modulation.1",
}
double_suffix_dict = {
"img_mod.linear": "component_a.mod.linear",
"img_attn_qkv": "component_a.to_qkv",
"img_attn_q_norm": "component_a.norm_q",
"img_attn_k_norm": "component_a.norm_k",
"img_attn_proj": "component_a.to_out",
"img_mlp.fc1": "component_a.ff.0",
"img_mlp.fc2": "component_a.ff.2",
"txt_mod.linear": "component_b.mod.linear",
"txt_attn_qkv": "component_b.to_qkv",
"txt_attn_q_norm": "component_b.norm_q",
"txt_attn_k_norm": "component_b.norm_k",
"txt_attn_proj": "component_b.to_out",
"txt_mlp.fc1": "component_b.ff.0",
"txt_mlp.fc2": "component_b.ff.2",
}
single_suffix_dict = {
"linear1": ["to_qkv", "ff.0"],
"linear2": ["to_out", "ff.2"],
"q_norm": "norm_q",
"k_norm": "norm_k",
"modulation.linear": "mod.linear",
}
# single_suffix_dict = {
# "linear1": "linear1",
# "linear2": "linear2",
# "q_norm": "q_norm",
# "k_norm": "k_norm",
# "modulation.linear": "modulation.linear",
# }
state_dict_ = {}
for name, param in state_dict.items():
names = name.split(".")
direct_name = ".".join(names[:-1])
if direct_name in direct_dict:
name_ = direct_dict[direct_name] + "." + names[-1]
state_dict_[name_] = param
elif names[0] == "double_blocks":
prefix = ".".join(names[:2])
suffix = ".".join(names[2:-1])
name_ = prefix + "." + double_suffix_dict[suffix] + "." + names[-1]
state_dict_[name_] = param
elif names[0] == "single_blocks":
prefix = ".".join(names[:2])
suffix = ".".join(names[2:-1])
if isinstance(single_suffix_dict[suffix], list):
if suffix == "linear1":
name_a, name_b = single_suffix_dict[suffix]
param_a, param_b = torch.split(param, (3072*3, 3072*4), dim=0)
state_dict_[prefix + "." + name_a + "." + names[-1]] = param_a
state_dict_[prefix + "." + name_b + "." + names[-1]] = param_b
elif suffix == "linear2":
if names[-1] == "weight":
name_a, name_b = single_suffix_dict[suffix]
param_a, param_b = torch.split(param, (3072*1, 3072*4), dim=-1)
state_dict_[prefix + "." + name_a + "." + names[-1]] = param_a
state_dict_[prefix + "." + name_b + "." + names[-1]] = param_b
else:
name_a, name_b = single_suffix_dict[suffix]
state_dict_[prefix + "." + name_a + "." + names[-1]] = param
else:
pass
else:
name_ = prefix + "." + single_suffix_dict[suffix] + "." + names[-1]
state_dict_[name_] = param
elif names[0] == "txt_in":
prefix = ".".join(names[:4]).replace(".individual_token_refiner.", ".")
suffix = ".".join(names[4:-1])
name_ = prefix + "." + txt_suffix_dict[suffix] + "." + names[-1]
state_dict_[name_] = param
else:
pass
return state_dict_

View File

@@ -0,0 +1,55 @@
from transformers import LlamaModel, LlamaConfig, DynamicCache
from copy import deepcopy
import torch
class HunyuanVideoLLMEncoder(LlamaModel):
def __init__(self, config: LlamaConfig):
super().__init__(config)
self.auto_offload = False
def enable_auto_offload(self, **kwargs):
self.auto_offload = True
def forward(
self,
input_ids,
attention_mask,
hidden_state_skip_layer=2
):
embed_tokens = deepcopy(self.embed_tokens).to(input_ids.device) if self.auto_offload else self.embed_tokens
inputs_embeds = embed_tokens(input_ids)
past_key_values = DynamicCache()
cache_position = torch.arange(0, inputs_embeds.shape[1], device=inputs_embeds.device)
position_ids = cache_position.unsqueeze(0)
causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position, None, False)
hidden_states = inputs_embeds
# create position embeddings to be shared across the decoder layers
rotary_emb = deepcopy(self.rotary_emb).to(input_ids.device) if self.auto_offload else self.rotary_emb
position_embeddings = rotary_emb(hidden_states, position_ids)
# decoder layers
for layer_id, decoder_layer in enumerate(self.layers):
if self.auto_offload:
decoder_layer = deepcopy(decoder_layer).to(hidden_states.device)
layer_outputs = decoder_layer(
hidden_states,
attention_mask=causal_mask,
position_ids=position_ids,
past_key_value=past_key_values,
output_attentions=False,
use_cache=True,
cache_position=cache_position,
position_embeddings=position_embeddings,
)
hidden_states = layer_outputs[0]
if layer_id + hidden_state_skip_layer + 1 >= len(self.layers):
break
return hidden_states

View File

@@ -0,0 +1,507 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
import numpy as np
from tqdm import tqdm
from einops import repeat
class CausalConv3d(nn.Module):
def __init__(self, in_channel, out_channel, kernel_size, stride=1, dilation=1, pad_mode='replicate', **kwargs):
super().__init__()
self.pad_mode = pad_mode
self.time_causal_padding = (kernel_size // 2, kernel_size // 2, kernel_size // 2, kernel_size // 2, kernel_size - 1, 0
) # W, H, T
self.conv = nn.Conv3d(in_channel, out_channel, kernel_size, stride=stride, dilation=dilation, **kwargs)
def forward(self, x):
x = F.pad(x, self.time_causal_padding, mode=self.pad_mode)
return self.conv(x)
class UpsampleCausal3D(nn.Module):
def __init__(self, channels, use_conv=False, out_channels=None, kernel_size=None, bias=True, upsample_factor=(2, 2, 2)):
super().__init__()
self.channels = channels
self.out_channels = out_channels or channels
self.upsample_factor = upsample_factor
self.conv = None
if use_conv:
kernel_size = 3 if kernel_size is None else kernel_size
self.conv = CausalConv3d(self.channels, self.out_channels, kernel_size=kernel_size, bias=bias)
def forward(self, hidden_states):
# Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16
dtype = hidden_states.dtype
if dtype == torch.bfloat16:
hidden_states = hidden_states.to(torch.float32)
# upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
if hidden_states.shape[0] >= 64:
hidden_states = hidden_states.contiguous()
# interpolate
B, C, T, H, W = hidden_states.shape
first_h, other_h = hidden_states.split((1, T - 1), dim=2)
if T > 1:
other_h = F.interpolate(other_h, scale_factor=self.upsample_factor, mode="nearest")
first_h = F.interpolate(first_h.squeeze(2), scale_factor=self.upsample_factor[1:], mode="nearest").unsqueeze(2)
hidden_states = torch.cat((first_h, other_h), dim=2) if T > 1 else first_h
# If the input is bfloat16, we cast back to bfloat16
if dtype == torch.bfloat16:
hidden_states = hidden_states.to(dtype)
if self.conv:
hidden_states = self.conv(hidden_states)
return hidden_states
class ResnetBlockCausal3D(nn.Module):
def __init__(self, in_channels, out_channels=None, dropout=0.0, groups=32, eps=1e-6, conv_shortcut_bias=True):
super().__init__()
self.pre_norm = True
self.in_channels = in_channels
out_channels = in_channels if out_channels is None else out_channels
self.out_channels = out_channels
self.norm1 = nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
self.conv1 = CausalConv3d(in_channels, out_channels, kernel_size=3, stride=1)
self.norm2 = nn.GroupNorm(num_groups=groups, num_channels=out_channels, eps=eps, affine=True)
self.conv2 = CausalConv3d(out_channels, out_channels, kernel_size=3, stride=1)
self.dropout = nn.Dropout(dropout)
self.nonlinearity = nn.SiLU()
self.conv_shortcut = None
if in_channels != out_channels:
self.conv_shortcut = CausalConv3d(in_channels, out_channels, kernel_size=1, stride=1, bias=conv_shortcut_bias)
def forward(self, input_tensor):
hidden_states = input_tensor
# conv1
hidden_states = self.norm1(hidden_states)
hidden_states = self.nonlinearity(hidden_states)
hidden_states = self.conv1(hidden_states)
# conv2
hidden_states = self.norm2(hidden_states)
hidden_states = self.nonlinearity(hidden_states)
hidden_states = self.dropout(hidden_states)
hidden_states = self.conv2(hidden_states)
# shortcut
if self.conv_shortcut is not None:
input_tensor = (self.conv_shortcut(input_tensor))
# shortcut and scale
output_tensor = input_tensor + hidden_states
return output_tensor
def prepare_causal_attention_mask(n_frame, n_hw, dtype, device, batch_size=None):
seq_len = n_frame * n_hw
mask = torch.full((seq_len, seq_len), float("-inf"), dtype=dtype, device=device)
for i in range(seq_len):
i_frame = i // n_hw
mask[i, :(i_frame + 1) * n_hw] = 0
if batch_size is not None:
mask = mask.unsqueeze(0).expand(batch_size, -1, -1)
return mask
class Attention(nn.Module):
def __init__(self,
in_channels,
num_heads,
head_dim,
num_groups=32,
dropout=0.0,
eps=1e-6,
bias=True,
residual_connection=True):
super().__init__()
self.num_heads = num_heads
self.head_dim = head_dim
self.residual_connection = residual_connection
dim_inner = head_dim * num_heads
self.group_norm = nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=eps, affine=True)
self.to_q = nn.Linear(in_channels, dim_inner, bias=bias)
self.to_k = nn.Linear(in_channels, dim_inner, bias=bias)
self.to_v = nn.Linear(in_channels, dim_inner, bias=bias)
self.to_out = nn.Sequential(nn.Linear(dim_inner, in_channels, bias=bias), nn.Dropout(dropout))
def forward(self, input_tensor, attn_mask=None):
hidden_states = self.group_norm(input_tensor.transpose(1, 2)).transpose(1, 2)
batch_size = hidden_states.shape[0]
q = self.to_q(hidden_states)
k = self.to_k(hidden_states)
v = self.to_v(hidden_states)
q = q.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
k = k.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
v = v.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
if attn_mask is not None:
attn_mask = attn_mask.view(batch_size, self.num_heads, -1, attn_mask.shape[-1])
hidden_states = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim)
hidden_states = self.to_out(hidden_states)
if self.residual_connection:
output_tensor = input_tensor + hidden_states
return output_tensor
class UNetMidBlockCausal3D(nn.Module):
def __init__(self, in_channels, dropout=0.0, num_layers=1, eps=1e-6, num_groups=32, attention_head_dim=None):
super().__init__()
resnets = [
ResnetBlockCausal3D(
in_channels=in_channels,
out_channels=in_channels,
dropout=dropout,
groups=num_groups,
eps=eps,
)
]
attentions = []
attention_head_dim = attention_head_dim or in_channels
for _ in range(num_layers):
attentions.append(
Attention(
in_channels,
num_heads=in_channels // attention_head_dim,
head_dim=attention_head_dim,
num_groups=num_groups,
dropout=dropout,
eps=eps,
bias=True,
residual_connection=True,
))
resnets.append(
ResnetBlockCausal3D(
in_channels=in_channels,
out_channels=in_channels,
dropout=dropout,
groups=num_groups,
eps=eps,
))
self.attentions = nn.ModuleList(attentions)
self.resnets = nn.ModuleList(resnets)
def forward(self, hidden_states):
hidden_states = self.resnets[0](hidden_states)
for attn, resnet in zip(self.attentions, self.resnets[1:]):
B, C, T, H, W = hidden_states.shape
hidden_states = rearrange(hidden_states, "b c f h w -> b (f h w) c")
attn_mask = prepare_causal_attention_mask(T, H * W, hidden_states.dtype, hidden_states.device, batch_size=B)
hidden_states = attn(hidden_states, attn_mask=attn_mask)
hidden_states = rearrange(hidden_states, "b (f h w) c -> b c f h w", f=T, h=H, w=W)
hidden_states = resnet(hidden_states)
return hidden_states
class UpDecoderBlockCausal3D(nn.Module):
def __init__(
self,
in_channels,
out_channels,
dropout=0.0,
num_layers=1,
eps=1e-6,
num_groups=32,
add_upsample=True,
upsample_scale_factor=(2, 2, 2),
):
super().__init__()
resnets = []
for i in range(num_layers):
cur_in_channel = in_channels if i == 0 else out_channels
resnets.append(
ResnetBlockCausal3D(
in_channels=cur_in_channel,
out_channels=out_channels,
groups=num_groups,
dropout=dropout,
eps=eps,
))
self.resnets = nn.ModuleList(resnets)
self.upsamplers = None
if add_upsample:
self.upsamplers = nn.ModuleList([
UpsampleCausal3D(
out_channels,
use_conv=True,
out_channels=out_channels,
upsample_factor=upsample_scale_factor,
)
])
def forward(self, hidden_states):
for resnet in self.resnets:
hidden_states = resnet(hidden_states)
if self.upsamplers is not None:
for upsampler in self.upsamplers:
hidden_states = upsampler(hidden_states)
return hidden_states
class DecoderCausal3D(nn.Module):
def __init__(
self,
in_channels=16,
out_channels=3,
eps=1e-6,
dropout=0.0,
block_out_channels=[128, 256, 512, 512],
layers_per_block=2,
num_groups=32,
time_compression_ratio=4,
spatial_compression_ratio=8,
gradient_checkpointing=False,
):
super().__init__()
self.layers_per_block = layers_per_block
self.conv_in = CausalConv3d(in_channels, block_out_channels[-1], kernel_size=3, stride=1)
self.up_blocks = nn.ModuleList([])
# mid
self.mid_block = UNetMidBlockCausal3D(
in_channels=block_out_channels[-1],
dropout=dropout,
eps=eps,
num_groups=num_groups,
attention_head_dim=block_out_channels[-1],
)
# up
reversed_block_out_channels = list(reversed(block_out_channels))
output_channel = reversed_block_out_channels[0]
for i in range(len(block_out_channels)):
prev_output_channel = output_channel
output_channel = reversed_block_out_channels[i]
is_final_block = i == len(block_out_channels) - 1
num_spatial_upsample_layers = int(np.log2(spatial_compression_ratio))
num_time_upsample_layers = int(np.log2(time_compression_ratio))
add_spatial_upsample = bool(i < num_spatial_upsample_layers)
add_time_upsample = bool(i >= len(block_out_channels) - 1 - num_time_upsample_layers and not is_final_block)
upsample_scale_factor_HW = (2, 2) if add_spatial_upsample else (1, 1)
upsample_scale_factor_T = (2,) if add_time_upsample else (1,)
upsample_scale_factor = tuple(upsample_scale_factor_T + upsample_scale_factor_HW)
up_block = UpDecoderBlockCausal3D(
in_channels=prev_output_channel,
out_channels=output_channel,
dropout=dropout,
num_layers=layers_per_block + 1,
eps=eps,
num_groups=num_groups,
add_upsample=bool(add_spatial_upsample or add_time_upsample),
upsample_scale_factor=upsample_scale_factor,
)
self.up_blocks.append(up_block)
prev_output_channel = output_channel
# out
self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=num_groups, eps=eps)
self.conv_act = nn.SiLU()
self.conv_out = CausalConv3d(block_out_channels[0], out_channels, kernel_size=3)
self.gradient_checkpointing = gradient_checkpointing
def forward(self, hidden_states):
hidden_states = self.conv_in(hidden_states)
if self.training and self.gradient_checkpointing:
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs)
return custom_forward
# middle
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(self.mid_block),
hidden_states,
use_reentrant=False,
)
# up
for up_block in self.up_blocks:
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(up_block),
hidden_states,
use_reentrant=False,
)
else:
# middle
hidden_states = self.mid_block(hidden_states)
# up
for up_block in self.up_blocks:
hidden_states = up_block(hidden_states)
# post-process
hidden_states = self.conv_norm_out(hidden_states)
hidden_states = self.conv_act(hidden_states)
hidden_states = self.conv_out(hidden_states)
return hidden_states
class HunyuanVideoVAEDecoder(nn.Module):
def __init__(
self,
in_channels=16,
out_channels=3,
eps=1e-6,
dropout=0.0,
block_out_channels=[128, 256, 512, 512],
layers_per_block=2,
num_groups=32,
time_compression_ratio=4,
spatial_compression_ratio=8,
gradient_checkpointing=False,
):
super().__init__()
self.decoder = DecoderCausal3D(
in_channels=in_channels,
out_channels=out_channels,
eps=eps,
dropout=dropout,
block_out_channels=block_out_channels,
layers_per_block=layers_per_block,
num_groups=num_groups,
time_compression_ratio=time_compression_ratio,
spatial_compression_ratio=spatial_compression_ratio,
gradient_checkpointing=gradient_checkpointing,
)
self.post_quant_conv = nn.Conv3d(in_channels, in_channels, kernel_size=1)
self.scaling_factor = 0.476986
def forward(self, latents):
latents = latents / self.scaling_factor
latents = self.post_quant_conv(latents)
dec = self.decoder(latents)
return dec
def build_1d_mask(self, length, left_bound, right_bound, border_width):
x = torch.ones((length,))
if not left_bound:
x[:border_width] = (torch.arange(border_width) + 1) / border_width
if not right_bound:
x[-border_width:] = torch.flip((torch.arange(border_width) + 1) / border_width, dims=(0,))
return x
def build_mask(self, data, is_bound, border_width):
_, _, T, H, W = data.shape
t = self.build_1d_mask(T, is_bound[0], is_bound[1], border_width[0])
h = self.build_1d_mask(H, is_bound[2], is_bound[3], border_width[1])
w = self.build_1d_mask(W, is_bound[4], is_bound[5], border_width[2])
t = repeat(t, "T -> T H W", T=T, H=H, W=W)
h = repeat(h, "H -> T H W", T=T, H=H, W=W)
w = repeat(w, "W -> T H W", T=T, H=H, W=W)
mask = torch.stack([t, h, w]).min(dim=0).values
mask = rearrange(mask, "T H W -> 1 1 T H W")
return mask
def tile_forward(self, hidden_states, tile_size, tile_stride):
B, C, T, H, W = hidden_states.shape
size_t, size_h, size_w = tile_size
stride_t, stride_h, stride_w = tile_stride
# Split tasks
tasks = []
for t in range(0, T, stride_t):
if (t-stride_t >= 0 and t-stride_t+size_t >= T): continue
for h in range(0, H, stride_h):
if (h-stride_h >= 0 and h-stride_h+size_h >= H): continue
for w in range(0, W, stride_w):
if (w-stride_w >= 0 and w-stride_w+size_w >= W): continue
t_, h_, w_ = t + size_t, h + size_h, w + size_w
tasks.append((t, t_, h, h_, w, w_))
# Run
torch_dtype = self.post_quant_conv.weight.dtype
data_device = hidden_states.device
computation_device = self.post_quant_conv.weight.device
weight = torch.zeros((1, 1, (T - 1) * 4 + 1, H * 8, W * 8), dtype=torch_dtype, device=data_device)
values = torch.zeros((B, 3, (T - 1) * 4 + 1, H * 8, W * 8), dtype=torch_dtype, device=data_device)
for t, t_, h, h_, w, w_ in tqdm(tasks, desc="VAE decoding"):
hidden_states_batch = hidden_states[:, :, t:t_, h:h_, w:w_].to(computation_device)
hidden_states_batch = self.forward(hidden_states_batch).to(data_device)
if t > 0:
hidden_states_batch = hidden_states_batch[:, :, 1:]
mask = self.build_mask(
hidden_states_batch,
is_bound=(t==0, t_>=T, h==0, h_>=H, w==0, w_>=W),
border_width=((size_t - stride_t) * 4, (size_h - stride_h) * 8, (size_w - stride_w) * 8)
).to(dtype=torch_dtype, device=data_device)
target_t = 0 if t==0 else t * 4 + 1
target_h = h * 8
target_w = w * 8
values[
:,
:,
target_t: target_t + hidden_states_batch.shape[2],
target_h: target_h + hidden_states_batch.shape[3],
target_w: target_w + hidden_states_batch.shape[4],
] += hidden_states_batch * mask
weight[
:,
:,
target_t: target_t + hidden_states_batch.shape[2],
target_h: target_h + hidden_states_batch.shape[3],
target_w: target_w + hidden_states_batch.shape[4],
] += mask
return values / weight
def decode_video(self, latents, tile_size=(17, 32, 32), tile_stride=(12, 24, 24)):
latents = latents.to(self.post_quant_conv.weight.dtype)
return self.tile_forward(latents, tile_size=tile_size, tile_stride=tile_stride)
@staticmethod
def state_dict_converter():
return HunyuanVideoVAEDecoderStateDictConverter()
class HunyuanVideoVAEDecoderStateDictConverter:
def __init__(self):
pass
def from_diffusers(self, state_dict):
state_dict_ = {}
for name in state_dict:
if name.startswith('decoder.') or name.startswith('post_quant_conv.'):
state_dict_[name] = state_dict[name]
return state_dict_

View File

@@ -0,0 +1,307 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat
import numpy as np
from tqdm import tqdm
from .hunyuan_video_vae_decoder import CausalConv3d, ResnetBlockCausal3D, UNetMidBlockCausal3D
class DownsampleCausal3D(nn.Module):
def __init__(self, channels, out_channels, kernel_size=3, bias=True, stride=2):
super().__init__()
self.conv = CausalConv3d(channels, out_channels, kernel_size, stride=stride, bias=bias)
def forward(self, hidden_states):
hidden_states = self.conv(hidden_states)
return hidden_states
class DownEncoderBlockCausal3D(nn.Module):
def __init__(
self,
in_channels,
out_channels,
dropout=0.0,
num_layers=1,
eps=1e-6,
num_groups=32,
add_downsample=True,
downsample_stride=2,
):
super().__init__()
resnets = []
for i in range(num_layers):
cur_in_channel = in_channels if i == 0 else out_channels
resnets.append(
ResnetBlockCausal3D(
in_channels=cur_in_channel,
out_channels=out_channels,
groups=num_groups,
dropout=dropout,
eps=eps,
))
self.resnets = nn.ModuleList(resnets)
self.downsamplers = None
if add_downsample:
self.downsamplers = nn.ModuleList([DownsampleCausal3D(
out_channels,
out_channels,
stride=downsample_stride,
)])
def forward(self, hidden_states):
for resnet in self.resnets:
hidden_states = resnet(hidden_states)
if self.downsamplers is not None:
for downsampler in self.downsamplers:
hidden_states = downsampler(hidden_states)
return hidden_states
class EncoderCausal3D(nn.Module):
def __init__(
self,
in_channels: int = 3,
out_channels: int = 16,
eps=1e-6,
dropout=0.0,
block_out_channels=[128, 256, 512, 512],
layers_per_block=2,
num_groups=32,
time_compression_ratio: int = 4,
spatial_compression_ratio: int = 8,
gradient_checkpointing=False,
):
super().__init__()
self.conv_in = CausalConv3d(in_channels, block_out_channels[0], kernel_size=3, stride=1)
self.down_blocks = nn.ModuleList([])
# down
output_channel = block_out_channels[0]
for i in range(len(block_out_channels)):
input_channel = output_channel
output_channel = block_out_channels[i]
is_final_block = i == len(block_out_channels) - 1
num_spatial_downsample_layers = int(np.log2(spatial_compression_ratio))
num_time_downsample_layers = int(np.log2(time_compression_ratio))
add_spatial_downsample = bool(i < num_spatial_downsample_layers)
add_time_downsample = bool(i >= (len(block_out_channels) - 1 - num_time_downsample_layers) and not is_final_block)
downsample_stride_HW = (2, 2) if add_spatial_downsample else (1, 1)
downsample_stride_T = (2,) if add_time_downsample else (1,)
downsample_stride = tuple(downsample_stride_T + downsample_stride_HW)
down_block = DownEncoderBlockCausal3D(
in_channels=input_channel,
out_channels=output_channel,
dropout=dropout,
num_layers=layers_per_block,
eps=eps,
num_groups=num_groups,
add_downsample=bool(add_spatial_downsample or add_time_downsample),
downsample_stride=downsample_stride,
)
self.down_blocks.append(down_block)
# mid
self.mid_block = UNetMidBlockCausal3D(
in_channels=block_out_channels[-1],
dropout=dropout,
eps=eps,
num_groups=num_groups,
attention_head_dim=block_out_channels[-1],
)
# out
self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[-1], num_groups=num_groups, eps=eps)
self.conv_act = nn.SiLU()
self.conv_out = CausalConv3d(block_out_channels[-1], 2 * out_channels, kernel_size=3)
self.gradient_checkpointing = gradient_checkpointing
def forward(self, hidden_states):
hidden_states = self.conv_in(hidden_states)
if self.training and self.gradient_checkpointing:
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs)
return custom_forward
# down
for down_block in self.down_blocks:
torch.utils.checkpoint.checkpoint(
create_custom_forward(down_block),
hidden_states,
use_reentrant=False,
)
# middle
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(self.mid_block),
hidden_states,
use_reentrant=False,
)
else:
# down
for down_block in self.down_blocks:
hidden_states = down_block(hidden_states)
# middle
hidden_states = self.mid_block(hidden_states)
# post-process
hidden_states = self.conv_norm_out(hidden_states)
hidden_states = self.conv_act(hidden_states)
hidden_states = self.conv_out(hidden_states)
return hidden_states
class HunyuanVideoVAEEncoder(nn.Module):
def __init__(
self,
in_channels=3,
out_channels=16,
eps=1e-6,
dropout=0.0,
block_out_channels=[128, 256, 512, 512],
layers_per_block=2,
num_groups=32,
time_compression_ratio=4,
spatial_compression_ratio=8,
gradient_checkpointing=False,
):
super().__init__()
self.encoder = EncoderCausal3D(
in_channels=in_channels,
out_channels=out_channels,
eps=eps,
dropout=dropout,
block_out_channels=block_out_channels,
layers_per_block=layers_per_block,
num_groups=num_groups,
time_compression_ratio=time_compression_ratio,
spatial_compression_ratio=spatial_compression_ratio,
gradient_checkpointing=gradient_checkpointing,
)
self.quant_conv = nn.Conv3d(2 * out_channels, 2 * out_channels, kernel_size=1)
self.scaling_factor = 0.476986
def forward(self, images):
latents = self.encoder(images)
latents = self.quant_conv(latents)
latents = latents[:, :16]
latents = latents * self.scaling_factor
return latents
def build_1d_mask(self, length, left_bound, right_bound, border_width):
x = torch.ones((length,))
if not left_bound:
x[:border_width] = (torch.arange(border_width) + 1) / border_width
if not right_bound:
x[-border_width:] = torch.flip((torch.arange(border_width) + 1) / border_width, dims=(0,))
return x
def build_mask(self, data, is_bound, border_width):
_, _, T, H, W = data.shape
t = self.build_1d_mask(T, is_bound[0], is_bound[1], border_width[0])
h = self.build_1d_mask(H, is_bound[2], is_bound[3], border_width[1])
w = self.build_1d_mask(W, is_bound[4], is_bound[5], border_width[2])
t = repeat(t, "T -> T H W", T=T, H=H, W=W)
h = repeat(h, "H -> T H W", T=T, H=H, W=W)
w = repeat(w, "W -> T H W", T=T, H=H, W=W)
mask = torch.stack([t, h, w]).min(dim=0).values
mask = rearrange(mask, "T H W -> 1 1 T H W")
return mask
def tile_forward(self, hidden_states, tile_size, tile_stride):
B, C, T, H, W = hidden_states.shape
size_t, size_h, size_w = tile_size
stride_t, stride_h, stride_w = tile_stride
# Split tasks
tasks = []
for t in range(0, T, stride_t):
if (t-stride_t >= 0 and t-stride_t+size_t >= T): continue
for h in range(0, H, stride_h):
if (h-stride_h >= 0 and h-stride_h+size_h >= H): continue
for w in range(0, W, stride_w):
if (w-stride_w >= 0 and w-stride_w+size_w >= W): continue
t_, h_, w_ = t + size_t, h + size_h, w + size_w
tasks.append((t, t_, h, h_, w, w_))
# Run
torch_dtype = self.quant_conv.weight.dtype
data_device = hidden_states.device
computation_device = self.quant_conv.weight.device
weight = torch.zeros((1, 1, (T - 1) // 4 + 1, H // 8, W // 8), dtype=torch_dtype, device=data_device)
values = torch.zeros((B, 16, (T - 1) // 4 + 1, H // 8, W // 8), dtype=torch_dtype, device=data_device)
for t, t_, h, h_, w, w_ in tqdm(tasks, desc="VAE encoding"):
hidden_states_batch = hidden_states[:, :, t:t_, h:h_, w:w_].to(computation_device)
hidden_states_batch = self.forward(hidden_states_batch).to(data_device)
if t > 0:
hidden_states_batch = hidden_states_batch[:, :, 1:]
mask = self.build_mask(
hidden_states_batch,
is_bound=(t==0, t_>=T, h==0, h_>=H, w==0, w_>=W),
border_width=((size_t - stride_t) // 4, (size_h - stride_h) // 8, (size_w - stride_w) // 8)
).to(dtype=torch_dtype, device=data_device)
target_t = 0 if t==0 else t // 4 + 1
target_h = h // 8
target_w = w // 8
values[
:,
:,
target_t: target_t + hidden_states_batch.shape[2],
target_h: target_h + hidden_states_batch.shape[3],
target_w: target_w + hidden_states_batch.shape[4],
] += hidden_states_batch * mask
weight[
:,
:,
target_t: target_t + hidden_states_batch.shape[2],
target_h: target_h + hidden_states_batch.shape[3],
target_w: target_w + hidden_states_batch.shape[4],
] += mask
return values / weight
def encode_video(self, latents, tile_size=(65, 256, 256), tile_stride=(48, 192, 192)):
latents = latents.to(self.quant_conv.weight.dtype)
return self.tile_forward(latents, tile_size=tile_size, tile_stride=tile_stride)
@staticmethod
def state_dict_converter():
return HunyuanVideoVAEEncoderStateDictConverter()
class HunyuanVideoVAEEncoderStateDictConverter:
def __init__(self):
pass
def from_diffusers(self, state_dict):
state_dict_ = {}
for name in state_dict:
if name.startswith('encoder.') or name.startswith('quant_conv.'):
state_dict_[name] = state_dict[name]
return state_dict_

File diff suppressed because one or more lines are too long

View File

@@ -1,902 +0,0 @@
from typing import List, Optional, Tuple
import math
import torch
import torch.nn as nn
import torch.amp as amp
import numpy as np
import torch.nn.functional as F
from einops import rearrange, repeat
from .wan_video_dit import flash_attention
from ..core.device.npu_compatible_device import get_device_type
from ..core.gradient import gradient_checkpoint_forward
class RMSNorm_FP32(torch.nn.Module):
def __init__(self, dim: int, eps: float):
super().__init__()
self.eps = eps
self.weight = nn.Parameter(torch.ones(dim))
def _norm(self, x):
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
def forward(self, x):
output = self._norm(x.float()).type_as(x)
return output * self.weight
def broadcat(tensors, dim=-1):
num_tensors = len(tensors)
shape_lens = set(list(map(lambda t: len(t.shape), tensors)))
assert len(shape_lens) == 1, "tensors must all have the same number of dimensions"
shape_len = list(shape_lens)[0]
dim = (dim + shape_len) if dim < 0 else dim
dims = list(zip(*map(lambda t: list(t.shape), tensors)))
expandable_dims = [(i, val) for i, val in enumerate(dims) if i != dim]
assert all(
[*map(lambda t: len(set(t[1])) <= 2, expandable_dims)]
), "invalid dimensions for broadcastable concatentation"
max_dims = list(map(lambda t: (t[0], max(t[1])), expandable_dims))
expanded_dims = list(map(lambda t: (t[0], (t[1],) * num_tensors), max_dims))
expanded_dims.insert(dim, (dim, dims[dim]))
expandable_shapes = list(zip(*map(lambda t: t[1], expanded_dims)))
tensors = list(map(lambda t: t[0].expand(*t[1]), zip(tensors, expandable_shapes)))
return torch.cat(tensors, dim=dim)
def rotate_half(x):
x = rearrange(x, "... (d r) -> ... d r", r=2)
x1, x2 = x.unbind(dim=-1)
x = torch.stack((-x2, x1), dim=-1)
return rearrange(x, "... d r -> ... (d r)")
class RotaryPositionalEmbedding(nn.Module):
def __init__(self,
head_dim,
cp_split_hw=None
):
"""Rotary positional embedding for 3D
Reference : https://blog.eleuther.ai/rotary-embeddings/
Paper: https://arxiv.org/pdf/2104.09864.pdf
Args:
dim: Dimension of embedding
base: Base value for exponential
"""
super().__init__()
self.head_dim = head_dim
assert self.head_dim % 8 == 0, 'Dim must be a multiply of 8 for 3D RoPE.'
self.cp_split_hw = cp_split_hw
# We take the assumption that the longest side of grid will not larger than 512, i.e, 512 * 8 = 4098 input pixels
self.base = 10000
self.freqs_dict = {}
def register_grid_size(self, grid_size):
if grid_size not in self.freqs_dict:
self.freqs_dict.update({
grid_size: self.precompute_freqs_cis_3d(grid_size)
})
def precompute_freqs_cis_3d(self, grid_size):
num_frames, height, width = grid_size
dim_t = self.head_dim - 4 * (self.head_dim // 6)
dim_h = 2 * (self.head_dim // 6)
dim_w = 2 * (self.head_dim // 6)
freqs_t = 1.0 / (self.base ** (torch.arange(0, dim_t, 2)[: (dim_t // 2)].float() / dim_t))
freqs_h = 1.0 / (self.base ** (torch.arange(0, dim_h, 2)[: (dim_h // 2)].float() / dim_h))
freqs_w = 1.0 / (self.base ** (torch.arange(0, dim_w, 2)[: (dim_w // 2)].float() / dim_w))
grid_t = np.linspace(0, num_frames, num_frames, endpoint=False, dtype=np.float32)
grid_h = np.linspace(0, height, height, endpoint=False, dtype=np.float32)
grid_w = np.linspace(0, width, width, endpoint=False, dtype=np.float32)
grid_t = torch.from_numpy(grid_t).float()
grid_h = torch.from_numpy(grid_h).float()
grid_w = torch.from_numpy(grid_w).float()
freqs_t = torch.einsum("..., f -> ... f", grid_t, freqs_t)
freqs_h = torch.einsum("..., f -> ... f", grid_h, freqs_h)
freqs_w = torch.einsum("..., f -> ... f", grid_w, freqs_w)
freqs_t = repeat(freqs_t, "... n -> ... (n r)", r=2)
freqs_h = repeat(freqs_h, "... n -> ... (n r)", r=2)
freqs_w = repeat(freqs_w, "... n -> ... (n r)", r=2)
freqs = broadcat((freqs_t[:, None, None, :], freqs_h[None, :, None, :], freqs_w[None, None, :, :]), dim=-1)
# (T H W D)
freqs = rearrange(freqs, "T H W D -> (T H W) D")
# if self.cp_split_hw[0] * self.cp_split_hw[1] > 1:
# with torch.no_grad():
# freqs = rearrange(freqs, "(T H W) D -> T H W D", T=num_frames, H=height, W=width)
# freqs = context_parallel_util.split_cp_2d(freqs, seq_dim_hw=(1, 2), split_hw=self.cp_split_hw)
# freqs = rearrange(freqs, "T H W D -> (T H W) D")
return freqs
def forward(self, q, k, grid_size):
"""3D RoPE.
Args:
query: [B, head, seq, head_dim]
key: [B, head, seq, head_dim]
Returns:
query and key with the same shape as input.
"""
if grid_size not in self.freqs_dict:
self.register_grid_size(grid_size)
freqs_cis = self.freqs_dict[grid_size].to(q.device)
q_, k_ = q.float(), k.float()
freqs_cis = freqs_cis.float().to(q.device)
cos, sin = freqs_cis.cos(), freqs_cis.sin()
cos, sin = rearrange(cos, 'n d -> 1 1 n d'), rearrange(sin, 'n d -> 1 1 n d')
q_ = (q_ * cos) + (rotate_half(q_) * sin)
k_ = (k_ * cos) + (rotate_half(k_) * sin)
return q_.type_as(q), k_.type_as(k)
class Attention(nn.Module):
def __init__(
self,
dim: int,
num_heads: int,
enable_flashattn3: bool = False,
enable_flashattn2: bool = False,
enable_xformers: bool = False,
enable_bsa: bool = False,
bsa_params: dict = None,
cp_split_hw: Optional[List[int]] = None
) -> None:
super().__init__()
assert dim % num_heads == 0, "dim should be divisible by num_heads"
self.dim = dim
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.scale = self.head_dim**-0.5
self.enable_flashattn3 = enable_flashattn3
self.enable_flashattn2 = enable_flashattn2
self.enable_xformers = enable_xformers
self.enable_bsa = enable_bsa
self.bsa_params = bsa_params
self.cp_split_hw = cp_split_hw
self.qkv = nn.Linear(dim, dim * 3, bias=True)
self.q_norm = RMSNorm_FP32(self.head_dim, eps=1e-6)
self.k_norm = RMSNorm_FP32(self.head_dim, eps=1e-6)
self.proj = nn.Linear(dim, dim)
self.rope_3d = RotaryPositionalEmbedding(
self.head_dim,
cp_split_hw=cp_split_hw
)
def _process_attn(self, q, k, v, shape):
q = rearrange(q, "B H S D -> B S (H D)")
k = rearrange(k, "B H S D -> B S (H D)")
v = rearrange(v, "B H S D -> B S (H D)")
x = flash_attention(q, k, v, num_heads=self.num_heads)
x = rearrange(x, "B S (H D) -> B H S D", H=self.num_heads)
return x
def forward(self, x: torch.Tensor, shape=None, num_cond_latents=None, return_kv=False) -> torch.Tensor:
"""
"""
B, N, C = x.shape
qkv = self.qkv(x)
qkv_shape = (B, N, 3, self.num_heads, self.head_dim)
qkv = qkv.view(qkv_shape).permute((2, 0, 3, 1, 4)) # [3, B, H, N, D]
q, k, v = qkv.unbind(0)
q, k = self.q_norm(q), self.k_norm(k)
if return_kv:
k_cache, v_cache = k.clone(), v.clone()
q, k = self.rope_3d(q, k, shape)
# cond mode
if num_cond_latents is not None and num_cond_latents > 0:
num_cond_latents_thw = num_cond_latents * (N // shape[0])
# process the condition tokens
q_cond = q[:, :, :num_cond_latents_thw].contiguous()
k_cond = k[:, :, :num_cond_latents_thw].contiguous()
v_cond = v[:, :, :num_cond_latents_thw].contiguous()
x_cond = self._process_attn(q_cond, k_cond, v_cond, shape)
# process the noise tokens
q_noise = q[:, :, num_cond_latents_thw:].contiguous()
x_noise = self._process_attn(q_noise, k, v, shape)
# merge x_cond and x_noise
x = torch.cat([x_cond, x_noise], dim=2).contiguous()
else:
x = self._process_attn(q, k, v, shape)
x_output_shape = (B, N, C)
x = x.transpose(1, 2) # [B, H, N, D] --> [B, N, H, D]
x = x.reshape(x_output_shape) # [B, N, H, D] --> [B, N, C]
x = self.proj(x)
if return_kv:
return x, (k_cache, v_cache)
else:
return x
def forward_with_kv_cache(self, x: torch.Tensor, shape=None, num_cond_latents=None, kv_cache=None) -> torch.Tensor:
"""
"""
B, N, C = x.shape
qkv = self.qkv(x)
qkv_shape = (B, N, 3, self.num_heads, self.head_dim)
qkv = qkv.view(qkv_shape).permute((2, 0, 3, 1, 4)) # [3, B, H, N, D]
q, k, v = qkv.unbind(0)
q, k = self.q_norm(q), self.k_norm(k)
T, H, W = shape
k_cache, v_cache = kv_cache
assert k_cache.shape[0] == v_cache.shape[0] and k_cache.shape[0] in [1, B]
if k_cache.shape[0] == 1:
k_cache = k_cache.repeat(B, 1, 1, 1)
v_cache = v_cache.repeat(B, 1, 1, 1)
if num_cond_latents is not None and num_cond_latents > 0:
k_full = torch.cat([k_cache, k], dim=2).contiguous()
v_full = torch.cat([v_cache, v], dim=2).contiguous()
q_padding = torch.cat([torch.empty_like(k_cache), q], dim=2).contiguous()
q_padding, k_full = self.rope_3d(q_padding, k_full, (T + num_cond_latents, H, W))
q = q_padding[:, :, -N:].contiguous()
x = self._process_attn(q, k_full, v_full, shape)
x_output_shape = (B, N, C)
x = x.transpose(1, 2) # [B, H, N, D] --> [B, N, H, D]
x = x.reshape(x_output_shape) # [B, N, H, D] --> [B, N, C]
x = self.proj(x)
return x
class MultiHeadCrossAttention(nn.Module):
def __init__(
self,
dim,
num_heads,
enable_flashattn3=False,
enable_flashattn2=False,
enable_xformers=False,
):
super(MultiHeadCrossAttention, self).__init__()
assert dim % num_heads == 0, "d_model must be divisible by num_heads"
self.dim = dim
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.q_linear = nn.Linear(dim, dim)
self.kv_linear = nn.Linear(dim, dim * 2)
self.proj = nn.Linear(dim, dim)
self.q_norm = RMSNorm_FP32(self.head_dim, eps=1e-6)
self.k_norm = RMSNorm_FP32(self.head_dim, eps=1e-6)
self.enable_flashattn3 = enable_flashattn3
self.enable_flashattn2 = enable_flashattn2
self.enable_xformers = enable_xformers
def _process_cross_attn(self, x, cond, kv_seqlen):
B, N, C = x.shape
assert C == self.dim and cond.shape[2] == self.dim
q = self.q_linear(x).view(1, -1, self.num_heads, self.head_dim)
kv = self.kv_linear(cond).view(1, -1, 2, self.num_heads, self.head_dim)
k, v = kv.unbind(2)
q, k = self.q_norm(q), self.k_norm(k)
q = rearrange(q, "B S H D -> B S (H D)")
k = rearrange(k, "B S H D -> B S (H D)")
v = rearrange(v, "B S H D -> B S (H D)")
x = flash_attention(q, k, v, num_heads=self.num_heads)
x = x.view(B, -1, C)
x = self.proj(x)
return x
def forward(self, x, cond, kv_seqlen, num_cond_latents=None, shape=None):
"""
x: [B, N, C]
cond: [B, M, C]
"""
if num_cond_latents is None or num_cond_latents == 0:
return self._process_cross_attn(x, cond, kv_seqlen)
else:
B, N, C = x.shape
if num_cond_latents is not None and num_cond_latents > 0:
assert shape is not None, "SHOULD pass in the shape"
num_cond_latents_thw = num_cond_latents * (N // shape[0])
x_noise = x[:, num_cond_latents_thw:] # [B, N_noise, C]
output_noise = self._process_cross_attn(x_noise, cond, kv_seqlen) # [B, N_noise, C]
output = torch.cat([
torch.zeros((B, num_cond_latents_thw, C), dtype=output_noise.dtype, device=output_noise.device),
output_noise
], dim=1).contiguous()
else:
raise NotImplementedError
return output
class LayerNorm_FP32(nn.LayerNorm):
def __init__(self, dim, eps, elementwise_affine):
super().__init__(dim, eps=eps, elementwise_affine=elementwise_affine)
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
origin_dtype = inputs.dtype
out = F.layer_norm(
inputs.float(),
self.normalized_shape,
None if self.weight is None else self.weight.float(),
None if self.bias is None else self.bias.float() ,
self.eps
).to(origin_dtype)
return out
def modulate_fp32(norm_func, x, shift, scale):
# Suppose x is (B, N, D), shift is (B, -1, D), scale is (B, -1, D)
# ensure the modulation params be fp32
assert shift.dtype == torch.float32, scale.dtype == torch.float32
dtype = x.dtype
x = norm_func(x.to(torch.float32))
x = x * (scale + 1) + shift
x = x.to(dtype)
return x
class FinalLayer_FP32(nn.Module):
"""
The final layer of DiT.
"""
def __init__(self, hidden_size, num_patch, out_channels, adaln_tembed_dim):
super().__init__()
self.hidden_size = hidden_size
self.num_patch = num_patch
self.out_channels = out_channels
self.adaln_tembed_dim = adaln_tembed_dim
self.norm_final = LayerNorm_FP32(hidden_size, elementwise_affine=False, eps=1e-6)
self.linear = nn.Linear(hidden_size, num_patch * out_channels, bias=True)
self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(adaln_tembed_dim, 2 * hidden_size, bias=True))
def forward(self, x, t, latent_shape):
# timestep shape: [B, T, C]
assert t.dtype == torch.float32
B, N, C = x.shape
T, _, _ = latent_shape
with amp.autocast(get_device_type(), dtype=torch.float32):
shift, scale = self.adaLN_modulation(t).unsqueeze(2).chunk(2, dim=-1) # [B, T, 1, C]
x = modulate_fp32(self.norm_final, x.view(B, T, -1, C), shift, scale).view(B, N, C)
x = self.linear(x)
return x
class FeedForwardSwiGLU(nn.Module):
def __init__(
self,
dim: int,
hidden_dim: int,
multiple_of: int = 256,
ffn_dim_multiplier: Optional[float] = None,
):
super().__init__()
hidden_dim = int(2 * hidden_dim / 3)
# custom dim factor multiplier
if ffn_dim_multiplier is not None:
hidden_dim = int(ffn_dim_multiplier * hidden_dim)
hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
self.dim = dim
self.hidden_dim = hidden_dim
self.w1 = nn.Linear(dim, hidden_dim, bias=False)
self.w2 = nn.Linear(hidden_dim, dim, bias=False)
self.w3 = nn.Linear(dim, hidden_dim, bias=False)
def forward(self, x):
return self.w2(F.silu(self.w1(x)) * self.w3(x))
class TimestepEmbedder(nn.Module):
"""
Embeds scalar timesteps into vector representations.
"""
def __init__(self, t_embed_dim, frequency_embedding_size=256):
super().__init__()
self.t_embed_dim = t_embed_dim
self.frequency_embedding_size = frequency_embedding_size
self.mlp = nn.Sequential(
nn.Linear(frequency_embedding_size, t_embed_dim, bias=True),
nn.SiLU(),
nn.Linear(t_embed_dim, t_embed_dim, bias=True),
)
@staticmethod
def timestep_embedding(t, dim, max_period=10000):
"""
Create sinusoidal timestep embeddings.
:param t: a 1-D Tensor of N indices, one per batch element.
These may be fractional.
:param dim: the dimension of the output.
:param max_period: controls the minimum frequency of the embeddings.
:return: an (N, D) Tensor of positional embeddings.
"""
half = dim // 2
freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half)
freqs = freqs.to(device=t.device)
args = t[:, None].float() * freqs[None]
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
if dim % 2:
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
return embedding
def forward(self, t, dtype):
t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
if t_freq.dtype != dtype:
t_freq = t_freq.to(dtype)
t_emb = self.mlp(t_freq)
return t_emb
class CaptionEmbedder(nn.Module):
"""
Embeds class labels into vector representations.
"""
def __init__(self, in_channels, hidden_size):
super().__init__()
self.in_channels = in_channels
self.hidden_size = hidden_size
self.y_proj = nn.Sequential(
nn.Linear(in_channels, hidden_size, bias=True),
nn.GELU(approximate="tanh"),
nn.Linear(hidden_size, hidden_size, bias=True),
)
def forward(self, caption):
B, _, N, C = caption.shape
caption = self.y_proj(caption)
return caption
class PatchEmbed3D(nn.Module):
"""Video to Patch Embedding.
Args:
patch_size (int): Patch token size. Default: (2,4,4).
in_chans (int): Number of input video channels. Default: 3.
embed_dim (int): Number of linear projection output channels. Default: 96.
norm_layer (nn.Module, optional): Normalization layer. Default: None
"""
def __init__(
self,
patch_size=(2, 4, 4),
in_chans=3,
embed_dim=96,
norm_layer=None,
flatten=True,
):
super().__init__()
self.patch_size = patch_size
self.flatten = flatten
self.in_chans = in_chans
self.embed_dim = embed_dim
self.proj = nn.Conv3d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
if norm_layer is not None:
self.norm = norm_layer(embed_dim)
else:
self.norm = None
def forward(self, x):
"""Forward function."""
# padding
_, _, D, H, W = x.size()
if W % self.patch_size[2] != 0:
x = F.pad(x, (0, self.patch_size[2] - W % self.patch_size[2]))
if H % self.patch_size[1] != 0:
x = F.pad(x, (0, 0, 0, self.patch_size[1] - H % self.patch_size[1]))
if D % self.patch_size[0] != 0:
x = F.pad(x, (0, 0, 0, 0, 0, self.patch_size[0] - D % self.patch_size[0]))
B, C, T, H, W = x.shape
x = self.proj(x) # (B C T H W)
if self.norm is not None:
D, Wh, Ww = x.size(2), x.size(3), x.size(4)
x = x.flatten(2).transpose(1, 2)
x = self.norm(x)
x = x.transpose(1, 2).view(-1, self.embed_dim, D, Wh, Ww)
if self.flatten:
x = x.flatten(2).transpose(1, 2) # BCTHW -> BNC
return x
class LongCatSingleStreamBlock(nn.Module):
def __init__(
self,
hidden_size: int,
num_heads: int,
mlp_ratio: int,
adaln_tembed_dim: int,
enable_flashattn3: bool = False,
enable_flashattn2: bool = False,
enable_xformers: bool = False,
enable_bsa: bool = False,
bsa_params=None,
cp_split_hw=None
):
super().__init__()
self.hidden_size = hidden_size
# scale and gate modulation
self.adaLN_modulation = nn.Sequential(
nn.SiLU(),
nn.Linear(adaln_tembed_dim, 6 * hidden_size, bias=True)
)
self.mod_norm_attn = LayerNorm_FP32(hidden_size, eps=1e-6, elementwise_affine=False)
self.mod_norm_ffn = LayerNorm_FP32(hidden_size, eps=1e-6, elementwise_affine=False)
self.pre_crs_attn_norm = LayerNorm_FP32(hidden_size, eps=1e-6, elementwise_affine=True)
self.attn = Attention(
dim=hidden_size,
num_heads=num_heads,
enable_flashattn3=enable_flashattn3,
enable_flashattn2=enable_flashattn2,
enable_xformers=enable_xformers,
enable_bsa=enable_bsa,
bsa_params=bsa_params,
cp_split_hw=cp_split_hw
)
self.cross_attn = MultiHeadCrossAttention(
dim=hidden_size,
num_heads=num_heads,
enable_flashattn3=enable_flashattn3,
enable_flashattn2=enable_flashattn2,
enable_xformers=enable_xformers,
)
self.ffn = FeedForwardSwiGLU(dim=hidden_size, hidden_dim=int(hidden_size * mlp_ratio))
def forward(self, x, y, t, y_seqlen, latent_shape, num_cond_latents=None, return_kv=False, kv_cache=None, skip_crs_attn=False):
"""
x: [B, N, C]
y: [1, N_valid_tokens, C]
t: [B, T, C_t]
y_seqlen: [B]; type of a list
latent_shape: latent shape of a single item
"""
x_dtype = x.dtype
B, N, C = x.shape
T, _, _ = latent_shape # S != T*H*W in case of CP split on H*W.
# compute modulation params in fp32
with amp.autocast(device_type=get_device_type(), dtype=torch.float32):
shift_msa, scale_msa, gate_msa, \
shift_mlp, scale_mlp, gate_mlp = \
self.adaLN_modulation(t).unsqueeze(2).chunk(6, dim=-1) # [B, T, 1, C]
# self attn with modulation
x_m = modulate_fp32(self.mod_norm_attn, x.view(B, T, -1, C), shift_msa, scale_msa).view(B, N, C)
if kv_cache is not None:
kv_cache = (kv_cache[0].to(x.device), kv_cache[1].to(x.device))
attn_outputs = self.attn.forward_with_kv_cache(x_m, shape=latent_shape, num_cond_latents=num_cond_latents, kv_cache=kv_cache)
else:
attn_outputs = self.attn(x_m, shape=latent_shape, num_cond_latents=num_cond_latents, return_kv=return_kv)
if return_kv:
x_s, kv_cache = attn_outputs
else:
x_s = attn_outputs
with amp.autocast(device_type=get_device_type(), dtype=torch.float32):
x = x + (gate_msa * x_s.view(B, -1, N//T, C)).view(B, -1, C) # [B, N, C]
x = x.to(x_dtype)
# cross attn
if not skip_crs_attn:
if kv_cache is not None:
num_cond_latents = None
x = x + self.cross_attn(self.pre_crs_attn_norm(x), y, y_seqlen, num_cond_latents=num_cond_latents, shape=latent_shape)
# ffn with modulation
x_m = modulate_fp32(self.mod_norm_ffn, x.view(B, -1, N//T, C), shift_mlp, scale_mlp).view(B, -1, C)
x_s = self.ffn(x_m)
with amp.autocast(device_type=get_device_type(), dtype=torch.float32):
x = x + (gate_mlp * x_s.view(B, -1, N//T, C)).view(B, -1, C) # [B, N, C]
x = x.to(x_dtype)
if return_kv:
return x, kv_cache
else:
return x
class LongCatVideoTransformer3DModel(torch.nn.Module):
def __init__(
self,
in_channels: int = 16,
out_channels: int = 16,
hidden_size: int = 4096,
depth: int = 48,
num_heads: int = 32,
caption_channels: int = 4096,
mlp_ratio: int = 4,
adaln_tembed_dim: int = 512,
frequency_embedding_size: int = 256,
# default params
patch_size: Tuple[int] = (1, 2, 2),
# attention config
enable_flashattn3: bool = False,
enable_flashattn2: bool = True,
enable_xformers: bool = False,
enable_bsa: bool = False,
bsa_params: dict = {'sparsity': 0.9375, 'chunk_3d_shape_q': [4, 4, 4], 'chunk_3d_shape_k': [4, 4, 4]},
cp_split_hw: Optional[List[int]] = [1, 1],
text_tokens_zero_pad: bool = True,
) -> None:
super().__init__()
self.patch_size = patch_size
self.in_channels = in_channels
self.out_channels = out_channels
self.cp_split_hw = cp_split_hw
self.x_embedder = PatchEmbed3D(patch_size, in_channels, hidden_size)
self.t_embedder = TimestepEmbedder(t_embed_dim=adaln_tembed_dim, frequency_embedding_size=frequency_embedding_size)
self.y_embedder = CaptionEmbedder(
in_channels=caption_channels,
hidden_size=hidden_size,
)
self.blocks = nn.ModuleList(
[
LongCatSingleStreamBlock(
hidden_size=hidden_size,
num_heads=num_heads,
mlp_ratio=mlp_ratio,
adaln_tembed_dim=adaln_tembed_dim,
enable_flashattn3=enable_flashattn3,
enable_flashattn2=enable_flashattn2,
enable_xformers=enable_xformers,
enable_bsa=enable_bsa,
bsa_params=bsa_params,
cp_split_hw=cp_split_hw
)
for i in range(depth)
]
)
self.final_layer = FinalLayer_FP32(
hidden_size,
np.prod(self.patch_size),
out_channels,
adaln_tembed_dim,
)
self.gradient_checkpointing = False
self.text_tokens_zero_pad = text_tokens_zero_pad
self.lora_dict = {}
self.active_loras = []
def enable_loras(self, lora_key_list=[]):
self.disable_all_loras()
module_loras = {} # {module_name: [lora1, lora2, ...]}
model_device = next(self.parameters()).device
model_dtype = next(self.parameters()).dtype
for lora_key in lora_key_list:
if lora_key in self.lora_dict:
for lora in self.lora_dict[lora_key].loras:
lora.to(model_device, dtype=model_dtype, non_blocking=True)
module_name = lora.lora_name.replace("lora___lorahyphen___", "").replace("___lorahyphen___", ".")
if module_name not in module_loras:
module_loras[module_name] = []
module_loras[module_name].append(lora)
self.active_loras.append(lora_key)
for module_name, loras in module_loras.items():
module = self._get_module_by_name(module_name)
if not hasattr(module, 'org_forward'):
module.org_forward = module.forward
module.forward = self._create_multi_lora_forward(module, loras)
def _create_multi_lora_forward(self, module, loras):
def multi_lora_forward(x, *args, **kwargs):
weight_dtype = x.dtype
org_output = module.org_forward(x, *args, **kwargs)
total_lora_output = 0
for lora in loras:
if lora.use_lora:
lx = lora.lora_down(x.to(lora.lora_down.weight.dtype))
lx = lora.lora_up(lx)
lora_output = lx.to(weight_dtype) * lora.multiplier * lora.alpha_scale
total_lora_output += lora_output
return org_output + total_lora_output
return multi_lora_forward
def _get_module_by_name(self, module_name):
try:
module = self
for part in module_name.split('.'):
module = getattr(module, part)
return module
except AttributeError as e:
raise ValueError(f"Cannot find module: {module_name}, error: {e}")
def disable_all_loras(self):
for name, module in self.named_modules():
if hasattr(module, 'org_forward'):
module.forward = module.org_forward
delattr(module, 'org_forward')
for lora_key, lora_network in self.lora_dict.items():
for lora in lora_network.loras:
lora.to("cpu")
self.active_loras.clear()
def enable_bsa(self,):
for block in self.blocks:
block.attn.enable_bsa = True
def disable_bsa(self,):
for block in self.blocks:
block.attn.enable_bsa = False
def forward(
self,
hidden_states,
timestep,
encoder_hidden_states,
encoder_attention_mask=None,
num_cond_latents=0,
return_kv=False,
kv_cache_dict={},
skip_crs_attn=False,
offload_kv_cache=False,
use_gradient_checkpointing=False,
use_gradient_checkpointing_offload=False,
):
B, _, T, H, W = hidden_states.shape
N_t = T // self.patch_size[0]
N_h = H // self.patch_size[1]
N_w = W // self.patch_size[2]
assert self.patch_size[0]==1, "Currently, 3D x_embedder should not compress the temporal dimension."
# expand the shape of timestep from [B] to [B, T]
if len(timestep.shape) == 1:
timestep = timestep.unsqueeze(1).expand(-1, N_t).clone() # [B, T]
timestep[:, :num_cond_latents] = 0
dtype = hidden_states.dtype
hidden_states = hidden_states.to(dtype)
timestep = timestep.to(dtype)
encoder_hidden_states = encoder_hidden_states.to(dtype)
hidden_states = self.x_embedder(hidden_states) # [B, N, C]
with amp.autocast(device_type=get_device_type(), dtype=torch.float32):
t = self.t_embedder(timestep.float().flatten(), dtype=torch.float32).reshape(B, N_t, -1) # [B, T, C_t]
encoder_hidden_states = self.y_embedder(encoder_hidden_states) # [B, 1, N_token, C]
if self.text_tokens_zero_pad and encoder_attention_mask is not None:
encoder_hidden_states = encoder_hidden_states * encoder_attention_mask[:, None, :, None]
encoder_attention_mask = (encoder_attention_mask * 0 + 1).to(encoder_attention_mask.dtype)
if encoder_attention_mask is not None:
encoder_attention_mask = encoder_attention_mask.squeeze(1).squeeze(1)
encoder_hidden_states = encoder_hidden_states.squeeze(1).masked_select(encoder_attention_mask.unsqueeze(-1) != 0).view(1, -1, hidden_states.shape[-1]) # [1, N_valid_tokens, C]
y_seqlens = encoder_attention_mask.sum(dim=1).tolist() # [B]
else:
y_seqlens = [encoder_hidden_states.shape[2]] * encoder_hidden_states.shape[0]
encoder_hidden_states = encoder_hidden_states.squeeze(1).view(1, -1, hidden_states.shape[-1])
# if self.cp_split_hw[0] * self.cp_split_hw[1] > 1:
# hidden_states = rearrange(hidden_states, "B (T H W) C -> B T H W C", T=N_t, H=N_h, W=N_w)
# hidden_states = context_parallel_util.split_cp_2d(hidden_states, seq_dim_hw=(2, 3), split_hw=self.cp_split_hw)
# hidden_states = rearrange(hidden_states, "B T H W C -> B (T H W) C")
# blocks
kv_cache_dict_ret = {}
for i, block in enumerate(self.blocks):
block_outputs = gradient_checkpoint_forward(
block,
use_gradient_checkpointing=use_gradient_checkpointing,
use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
x=hidden_states,
y=encoder_hidden_states,
t=t,
y_seqlen=y_seqlens,
latent_shape=(N_t, N_h, N_w),
num_cond_latents=num_cond_latents,
return_kv=return_kv,
kv_cache=kv_cache_dict.get(i, None),
skip_crs_attn=skip_crs_attn,
)
if return_kv:
hidden_states, kv_cache = block_outputs
if offload_kv_cache:
kv_cache_dict_ret[i] = (kv_cache[0].cpu(), kv_cache[1].cpu())
else:
kv_cache_dict_ret[i] = (kv_cache[0].contiguous(), kv_cache[1].contiguous())
else:
hidden_states = block_outputs
hidden_states = self.final_layer(hidden_states, t, (N_t, N_h, N_w)) # [B, N, C=T_p*H_p*W_p*C_out]
# if self.cp_split_hw[0] * self.cp_split_hw[1] > 1:
# hidden_states = context_parallel_util.gather_cp_2d(hidden_states, shape=(N_t, N_h, N_w), split_hw=self.cp_split_hw)
hidden_states = self.unpatchify(hidden_states, N_t, N_h, N_w) # [B, C_out, H, W]
# cast to float32 for better accuracy
hidden_states = hidden_states.to(torch.float32)
if return_kv:
return hidden_states, kv_cache_dict_ret
else:
return hidden_states
def unpatchify(self, x, N_t, N_h, N_w):
"""
Args:
x (torch.Tensor): of shape [B, N, C]
Return:
x (torch.Tensor): of shape [B, C_out, T, H, W]
"""
T_p, H_p, W_p = self.patch_size
x = rearrange(
x,
"B (N_t N_h N_w) (T_p H_p W_p C_out) -> B C_out (N_t T_p) (N_h H_p) (N_w W_p)",
N_t=N_t,
N_h=N_h,
N_w=N_w,
T_p=T_p,
H_p=H_p,
W_p=W_p,
C_out=self.out_channels,
)
return x
@staticmethod
def state_dict_converter():
return LongCatVideoTransformer3DModelDictConverter()
class LongCatVideoTransformer3DModelDictConverter:
def __init__(self):
pass
def from_diffusers(self, state_dict):
return state_dict
def from_civitai(self, state_dict):
return state_dict

367
diffsynth/models/lora.py Normal file
View File

@@ -0,0 +1,367 @@
import torch
from .sd_unet import SDUNet
from .sdxl_unet import SDXLUNet
from .sd_text_encoder import SDTextEncoder
from .sdxl_text_encoder import SDXLTextEncoder, SDXLTextEncoder2
from .sd3_dit import SD3DiT
from .flux_dit import FluxDiT
from .hunyuan_dit import HunyuanDiT
from .cog_dit import CogDiT
from .hunyuan_video_dit import HunyuanVideoDiT
class LoRAFromCivitai:
def __init__(self):
self.supported_model_classes = []
self.lora_prefix = []
self.renamed_lora_prefix = {}
self.special_keys = {}
def convert_state_dict(self, state_dict, lora_prefix="lora_unet_", alpha=1.0):
for key in state_dict:
if ".lora_up" in key:
return self.convert_state_dict_up_down(state_dict, lora_prefix, alpha)
return self.convert_state_dict_AB(state_dict, lora_prefix, alpha)
def convert_state_dict_up_down(self, state_dict, lora_prefix="lora_unet_", alpha=1.0):
renamed_lora_prefix = self.renamed_lora_prefix.get(lora_prefix, "")
state_dict_ = {}
for key in state_dict:
if ".lora_up" not in key:
continue
if not key.startswith(lora_prefix):
continue
weight_up = state_dict[key].to(device="cuda", dtype=torch.float16)
weight_down = state_dict[key.replace(".lora_up", ".lora_down")].to(device="cuda", dtype=torch.float16)
if len(weight_up.shape) == 4:
weight_up = weight_up.squeeze(3).squeeze(2).to(torch.float32)
weight_down = weight_down.squeeze(3).squeeze(2).to(torch.float32)
lora_weight = alpha * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3)
else:
lora_weight = alpha * torch.mm(weight_up, weight_down)
target_name = key.split(".")[0].replace(lora_prefix, renamed_lora_prefix).replace("_", ".") + ".weight"
for special_key in self.special_keys:
target_name = target_name.replace(special_key, self.special_keys[special_key])
state_dict_[target_name] = lora_weight.cpu()
return state_dict_
def convert_state_dict_AB(self, state_dict, lora_prefix="", alpha=1.0, device="cuda", torch_dtype=torch.float16):
state_dict_ = {}
for key in state_dict:
if ".lora_B." not in key:
continue
if not key.startswith(lora_prefix):
continue
weight_up = state_dict[key].to(device=device, dtype=torch_dtype)
weight_down = state_dict[key.replace(".lora_B.", ".lora_A.")].to(device=device, dtype=torch_dtype)
if len(weight_up.shape) == 4:
weight_up = weight_up.squeeze(3).squeeze(2)
weight_down = weight_down.squeeze(3).squeeze(2)
lora_weight = alpha * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3)
else:
lora_weight = alpha * torch.mm(weight_up, weight_down)
keys = key.split(".")
keys.pop(keys.index("lora_B"))
target_name = ".".join(keys)
target_name = target_name[len(lora_prefix):]
state_dict_[target_name] = lora_weight.cpu()
return state_dict_
def load(self, model, state_dict_lora, lora_prefix, alpha=1.0, model_resource=None):
state_dict_model = model.state_dict()
state_dict_lora = self.convert_state_dict(state_dict_lora, lora_prefix=lora_prefix, alpha=alpha)
if model_resource == "diffusers":
state_dict_lora = model.__class__.state_dict_converter().from_diffusers(state_dict_lora)
elif model_resource == "civitai":
state_dict_lora = model.__class__.state_dict_converter().from_civitai(state_dict_lora)
if isinstance(state_dict_lora, tuple):
state_dict_lora = state_dict_lora[0]
if len(state_dict_lora) > 0:
print(f" {len(state_dict_lora)} tensors are updated.")
for name in state_dict_lora:
fp8=False
if state_dict_model[name].dtype == torch.float8_e4m3fn:
state_dict_model[name]= state_dict_model[name].to(state_dict_lora[name].dtype)
fp8=True
state_dict_model[name] += state_dict_lora[name].to(
dtype=state_dict_model[name].dtype, device=state_dict_model[name].device)
if fp8:
state_dict_model[name] = state_dict_model[name].to(torch.float8_e4m3fn)
model.load_state_dict(state_dict_model)
def match(self, model, state_dict_lora):
for lora_prefix, model_class in zip(self.lora_prefix, self.supported_model_classes):
if not isinstance(model, model_class):
continue
state_dict_model = model.state_dict()
for model_resource in ["diffusers", "civitai"]:
try:
state_dict_lora_ = self.convert_state_dict(state_dict_lora, lora_prefix=lora_prefix, alpha=1.0)
converter_fn = model.__class__.state_dict_converter().from_diffusers if model_resource == "diffusers" \
else model.__class__.state_dict_converter().from_civitai
state_dict_lora_ = converter_fn(state_dict_lora_)
if isinstance(state_dict_lora_, tuple):
state_dict_lora_ = state_dict_lora_[0]
if len(state_dict_lora_) == 0:
continue
for name in state_dict_lora_:
if name not in state_dict_model:
break
else:
return lora_prefix, model_resource
except:
pass
return None
class SDLoRAFromCivitai(LoRAFromCivitai):
def __init__(self):
super().__init__()
self.supported_model_classes = [SDUNet, SDTextEncoder]
self.lora_prefix = ["lora_unet_", "lora_te_"]
self.special_keys = {
"down.blocks": "down_blocks",
"up.blocks": "up_blocks",
"mid.block": "mid_block",
"proj.in": "proj_in",
"proj.out": "proj_out",
"transformer.blocks": "transformer_blocks",
"to.q": "to_q",
"to.k": "to_k",
"to.v": "to_v",
"to.out": "to_out",
"text.model": "text_model",
"self.attn.q.proj": "self_attn.q_proj",
"self.attn.k.proj": "self_attn.k_proj",
"self.attn.v.proj": "self_attn.v_proj",
"self.attn.out.proj": "self_attn.out_proj",
"input.blocks": "model.diffusion_model.input_blocks",
"middle.block": "model.diffusion_model.middle_block",
"output.blocks": "model.diffusion_model.output_blocks",
}
class SDXLLoRAFromCivitai(LoRAFromCivitai):
def __init__(self):
super().__init__()
self.supported_model_classes = [SDXLUNet, SDXLTextEncoder, SDXLTextEncoder2]
self.lora_prefix = ["lora_unet_", "lora_te1_", "lora_te2_"]
self.renamed_lora_prefix = {"lora_te2_": "2"}
self.special_keys = {
"down.blocks": "down_blocks",
"up.blocks": "up_blocks",
"mid.block": "mid_block",
"proj.in": "proj_in",
"proj.out": "proj_out",
"transformer.blocks": "transformer_blocks",
"to.q": "to_q",
"to.k": "to_k",
"to.v": "to_v",
"to.out": "to_out",
"text.model": "conditioner.embedders.0.transformer.text_model",
"self.attn.q.proj": "self_attn.q_proj",
"self.attn.k.proj": "self_attn.k_proj",
"self.attn.v.proj": "self_attn.v_proj",
"self.attn.out.proj": "self_attn.out_proj",
"input.blocks": "model.diffusion_model.input_blocks",
"middle.block": "model.diffusion_model.middle_block",
"output.blocks": "model.diffusion_model.output_blocks",
"2conditioner.embedders.0.transformer.text_model.encoder.layers": "text_model.encoder.layers"
}
class FluxLoRAFromCivitai(LoRAFromCivitai):
def __init__(self):
super().__init__()
self.supported_model_classes = [FluxDiT, FluxDiT]
self.lora_prefix = ["lora_unet_", "transformer."]
self.renamed_lora_prefix = {}
self.special_keys = {
"single.blocks": "single_blocks",
"double.blocks": "double_blocks",
"img.attn": "img_attn",
"img.mlp": "img_mlp",
"img.mod": "img_mod",
"txt.attn": "txt_attn",
"txt.mlp": "txt_mlp",
"txt.mod": "txt_mod",
}
class GeneralLoRAFromPeft:
def __init__(self):
self.supported_model_classes = [SDUNet, SDXLUNet, SD3DiT, HunyuanDiT, FluxDiT, CogDiT]
def fetch_device_dtype_from_state_dict(self, state_dict):
device, torch_dtype = None, None
for name, param in state_dict.items():
device, torch_dtype = param.device, param.dtype
break
return device, torch_dtype
def convert_state_dict(self, state_dict, alpha=1.0, target_state_dict={}):
device, torch_dtype = self.fetch_device_dtype_from_state_dict(target_state_dict)
state_dict_ = {}
for key in state_dict:
if ".lora_B." not in key:
continue
weight_up = state_dict[key].to(device=device, dtype=torch_dtype)
weight_down = state_dict[key.replace(".lora_B.", ".lora_A.")].to(device=device, dtype=torch_dtype)
if len(weight_up.shape) == 4:
weight_up = weight_up.squeeze(3).squeeze(2)
weight_down = weight_down.squeeze(3).squeeze(2)
lora_weight = alpha * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3)
else:
lora_weight = alpha * torch.mm(weight_up, weight_down)
keys = key.split(".")
if len(keys) > keys.index("lora_B") + 2:
keys.pop(keys.index("lora_B") + 1)
keys.pop(keys.index("lora_B"))
target_name = ".".join(keys)
if target_name not in target_state_dict:
return {}
state_dict_[target_name] = lora_weight.cpu()
return state_dict_
def load(self, model, state_dict_lora, lora_prefix="", alpha=1.0, model_resource=""):
state_dict_model = model.state_dict()
state_dict_lora = self.convert_state_dict(state_dict_lora, alpha=alpha, target_state_dict=state_dict_model)
if len(state_dict_lora) > 0:
print(f" {len(state_dict_lora)} tensors are updated.")
for name in state_dict_lora:
state_dict_model[name] += state_dict_lora[name].to(
dtype=state_dict_model[name].dtype,
device=state_dict_model[name].device
)
model.load_state_dict(state_dict_model)
def match(self, model, state_dict_lora):
for model_class in self.supported_model_classes:
if not isinstance(model, model_class):
continue
state_dict_model = model.state_dict()
try:
state_dict_lora_ = self.convert_state_dict(state_dict_lora, alpha=1.0, target_state_dict=state_dict_model)
if len(state_dict_lora_) > 0:
return "", ""
except:
pass
return None
class HunyuanVideoLoRAFromCivitai(LoRAFromCivitai):
def __init__(self):
super().__init__()
self.supported_model_classes = [HunyuanVideoDiT, HunyuanVideoDiT]
self.lora_prefix = ["diffusion_model.", "transformer."]
self.special_keys = {}
class FluxLoRAConverter:
def __init__(self):
pass
@staticmethod
def align_to_opensource_format(state_dict, alpha=1.0):
prefix_rename_dict = {
"single_blocks": "lora_unet_single_blocks",
"blocks": "lora_unet_double_blocks",
}
middle_rename_dict = {
"norm.linear": "modulation_lin",
"to_qkv_mlp": "linear1",
"proj_out": "linear2",
"norm1_a.linear": "img_mod_lin",
"norm1_b.linear": "txt_mod_lin",
"attn.a_to_qkv": "img_attn_qkv",
"attn.b_to_qkv": "txt_attn_qkv",
"attn.a_to_out": "img_attn_proj",
"attn.b_to_out": "txt_attn_proj",
"ff_a.0": "img_mlp_0",
"ff_a.2": "img_mlp_2",
"ff_b.0": "txt_mlp_0",
"ff_b.2": "txt_mlp_2",
}
suffix_rename_dict = {
"lora_B.weight": "lora_up.weight",
"lora_A.weight": "lora_down.weight",
}
state_dict_ = {}
for name, param in state_dict.items():
names = name.split(".")
if names[-2] != "lora_A" and names[-2] != "lora_B":
names.pop(-2)
prefix = names[0]
middle = ".".join(names[2:-2])
suffix = ".".join(names[-2:])
block_id = names[1]
if middle not in middle_rename_dict:
continue
rename = prefix_rename_dict[prefix] + "_" + block_id + "_" + middle_rename_dict[middle] + "." + suffix_rename_dict[suffix]
state_dict_[rename] = param
if rename.endswith("lora_up.weight"):
state_dict_[rename.replace("lora_up.weight", "alpha")] = torch.tensor((alpha,))[0]
return state_dict_
@staticmethod
def align_to_diffsynth_format(state_dict):
rename_dict = {
"lora_unet_double_blocks_blockid_img_mod_lin.lora_down.weight": "blocks.blockid.norm1_a.linear.lora_A.default.weight",
"lora_unet_double_blocks_blockid_img_mod_lin.lora_up.weight": "blocks.blockid.norm1_a.linear.lora_B.default.weight",
"lora_unet_double_blocks_blockid_txt_mod_lin.lora_down.weight": "blocks.blockid.norm1_b.linear.lora_A.default.weight",
"lora_unet_double_blocks_blockid_txt_mod_lin.lora_up.weight": "blocks.blockid.norm1_b.linear.lora_B.default.weight",
"lora_unet_double_blocks_blockid_img_attn_qkv.lora_down.weight": "blocks.blockid.attn.a_to_qkv.lora_A.default.weight",
"lora_unet_double_blocks_blockid_img_attn_qkv.lora_up.weight": "blocks.blockid.attn.a_to_qkv.lora_B.default.weight",
"lora_unet_double_blocks_blockid_txt_attn_qkv.lora_down.weight": "blocks.blockid.attn.b_to_qkv.lora_A.default.weight",
"lora_unet_double_blocks_blockid_txt_attn_qkv.lora_up.weight": "blocks.blockid.attn.b_to_qkv.lora_B.default.weight",
"lora_unet_double_blocks_blockid_img_attn_proj.lora_down.weight": "blocks.blockid.attn.a_to_out.lora_A.default.weight",
"lora_unet_double_blocks_blockid_img_attn_proj.lora_up.weight": "blocks.blockid.attn.a_to_out.lora_B.default.weight",
"lora_unet_double_blocks_blockid_txt_attn_proj.lora_down.weight": "blocks.blockid.attn.b_to_out.lora_A.default.weight",
"lora_unet_double_blocks_blockid_txt_attn_proj.lora_up.weight": "blocks.blockid.attn.b_to_out.lora_B.default.weight",
"lora_unet_double_blocks_blockid_img_mlp_0.lora_down.weight": "blocks.blockid.ff_a.0.lora_A.default.weight",
"lora_unet_double_blocks_blockid_img_mlp_0.lora_up.weight": "blocks.blockid.ff_a.0.lora_B.default.weight",
"lora_unet_double_blocks_blockid_img_mlp_2.lora_down.weight": "blocks.blockid.ff_a.2.lora_A.default.weight",
"lora_unet_double_blocks_blockid_img_mlp_2.lora_up.weight": "blocks.blockid.ff_a.2.lora_B.default.weight",
"lora_unet_double_blocks_blockid_txt_mlp_0.lora_down.weight": "blocks.blockid.ff_b.0.lora_A.default.weight",
"lora_unet_double_blocks_blockid_txt_mlp_0.lora_up.weight": "blocks.blockid.ff_b.0.lora_B.default.weight",
"lora_unet_double_blocks_blockid_txt_mlp_2.lora_down.weight": "blocks.blockid.ff_b.2.lora_A.default.weight",
"lora_unet_double_blocks_blockid_txt_mlp_2.lora_up.weight": "blocks.blockid.ff_b.2.lora_B.default.weight",
"lora_unet_single_blocks_blockid_modulation_lin.lora_down.weight": "single_blocks.blockid.norm.linear.lora_A.default.weight",
"lora_unet_single_blocks_blockid_modulation_lin.lora_up.weight": "single_blocks.blockid.norm.linear.lora_B.default.weight",
"lora_unet_single_blocks_blockid_linear1.lora_down.weight": "single_blocks.blockid.to_qkv_mlp.lora_A.default.weight",
"lora_unet_single_blocks_blockid_linear1.lora_up.weight": "single_blocks.blockid.to_qkv_mlp.lora_B.default.weight",
"lora_unet_single_blocks_blockid_linear2.lora_down.weight": "single_blocks.blockid.proj_out.lora_A.default.weight",
"lora_unet_single_blocks_blockid_linear2.lora_up.weight": "single_blocks.blockid.proj_out.lora_B.default.weight",
}
def guess_block_id(name):
names = name.split("_")
for i in names:
if i.isdigit():
return i, name.replace(f"_{i}_", "_blockid_")
return None, None
state_dict_ = {}
for name, param in state_dict.items():
block_id, source_name = guess_block_id(name)
if source_name in rename_dict:
target_name = rename_dict[source_name]
target_name = target_name.replace(".blockid.", f".{block_id}.")
state_dict_[target_name] = param
else:
state_dict_[name] = param
return state_dict_
def get_lora_loaders():
return [SDLoRAFromCivitai(), SDXLLoRAFromCivitai(), FluxLoRAFromCivitai(), HunyuanVideoLoRAFromCivitai(), GeneralLoRAFromPeft()]

File diff suppressed because it is too large Load Diff

View File

@@ -1,371 +0,0 @@
from dataclasses import dataclass
from typing import NamedTuple, Protocol, Tuple
import torch
from torch import nn
from enum import Enum
class VideoPixelShape(NamedTuple):
"""
Shape of the tensor representing the video pixel array. Assumes BGR channel format.
"""
batch: int
frames: int
height: int
width: int
fps: float
class SpatioTemporalScaleFactors(NamedTuple):
"""
Describes the spatiotemporal downscaling between decoded video space and
the corresponding VAE latent grid.
"""
time: int
width: int
height: int
@classmethod
def default(cls) -> "SpatioTemporalScaleFactors":
return cls(time=8, width=32, height=32)
VIDEO_SCALE_FACTORS = SpatioTemporalScaleFactors.default()
class VideoLatentShape(NamedTuple):
"""
Shape of the tensor representing video in VAE latent space.
The latent representation is a 5D tensor with dimensions ordered as
(batch, channels, frames, height, width). Spatial and temporal dimensions
are downscaled relative to pixel space according to the VAE's scale factors.
"""
batch: int
channels: int
frames: int
height: int
width: int
def to_torch_shape(self) -> torch.Size:
return torch.Size([self.batch, self.channels, self.frames, self.height, self.width])
@staticmethod
def from_torch_shape(shape: torch.Size) -> "VideoLatentShape":
return VideoLatentShape(
batch=shape[0],
channels=shape[1],
frames=shape[2],
height=shape[3],
width=shape[4],
)
def mask_shape(self) -> "VideoLatentShape":
return self._replace(channels=1)
@staticmethod
def from_pixel_shape(
shape: VideoPixelShape,
latent_channels: int = 128,
scale_factors: SpatioTemporalScaleFactors = VIDEO_SCALE_FACTORS,
) -> "VideoLatentShape":
frames = (shape.frames - 1) // scale_factors[0] + 1
height = shape.height // scale_factors[1]
width = shape.width // scale_factors[2]
return VideoLatentShape(
batch=shape.batch,
channels=latent_channels,
frames=frames,
height=height,
width=width,
)
def upscale(self, scale_factors: SpatioTemporalScaleFactors = VIDEO_SCALE_FACTORS) -> "VideoLatentShape":
return self._replace(
channels=3,
frames=(self.frames - 1) * scale_factors.time + 1,
height=self.height * scale_factors.height,
width=self.width * scale_factors.width,
)
class AudioLatentShape(NamedTuple):
"""
Shape of audio in VAE latent space: (batch, channels, frames, mel_bins).
mel_bins is the number of frequency bins from the mel-spectrogram encoding.
"""
batch: int
channels: int
frames: int
mel_bins: int
def to_torch_shape(self) -> torch.Size:
return torch.Size([self.batch, self.channels, self.frames, self.mel_bins])
def mask_shape(self) -> "AudioLatentShape":
return self._replace(channels=1, mel_bins=1)
@staticmethod
def from_torch_shape(shape: torch.Size) -> "AudioLatentShape":
return AudioLatentShape(
batch=shape[0],
channels=shape[1],
frames=shape[2],
mel_bins=shape[3],
)
@staticmethod
def from_duration(
batch: int,
duration: float,
channels: int = 8,
mel_bins: int = 16,
sample_rate: int = 16000,
hop_length: int = 160,
audio_latent_downsample_factor: int = 4,
) -> "AudioLatentShape":
latents_per_second = float(sample_rate) / float(hop_length) / float(audio_latent_downsample_factor)
return AudioLatentShape(
batch=batch,
channels=channels,
frames=round(duration * latents_per_second),
mel_bins=mel_bins,
)
@staticmethod
def from_video_pixel_shape(
shape: VideoPixelShape,
channels: int = 8,
mel_bins: int = 16,
sample_rate: int = 16000,
hop_length: int = 160,
audio_latent_downsample_factor: int = 4,
) -> "AudioLatentShape":
return AudioLatentShape.from_duration(
batch=shape.batch,
duration=float(shape.frames) / float(shape.fps),
channels=channels,
mel_bins=mel_bins,
sample_rate=sample_rate,
hop_length=hop_length,
audio_latent_downsample_factor=audio_latent_downsample_factor,
)
@dataclass(frozen=True)
class LatentState:
"""
State of latents during the diffusion denoising process.
Attributes:
latent: The current noisy latent tensor being denoised.
denoise_mask: Mask encoding the denoising strength for each token (1 = full denoising, 0 = no denoising).
positions: Positional indices for each latent element, used for positional embeddings.
clean_latent: Initial state of the latent before denoising, may include conditioning latents.
"""
latent: torch.Tensor
denoise_mask: torch.Tensor
positions: torch.Tensor
clean_latent: torch.Tensor
def clone(self) -> "LatentState":
return LatentState(
latent=self.latent.clone(),
denoise_mask=self.denoise_mask.clone(),
positions=self.positions.clone(),
clean_latent=self.clean_latent.clone(),
)
class NormType(Enum):
"""Normalization layer types: GROUP (GroupNorm) or PIXEL (per-location RMS norm)."""
GROUP = "group"
PIXEL = "pixel"
class PixelNorm(nn.Module):
"""
Per-pixel (per-location) RMS normalization layer.
For each element along the chosen dimension, this layer normalizes the tensor
by the root-mean-square of its values across that dimension:
y = x / sqrt(mean(x^2, dim=dim, keepdim=True) + eps)
"""
def __init__(self, dim: int = 1, eps: float = 1e-8) -> None:
"""
Args:
dim: Dimension along which to compute the RMS (typically channels).
eps: Small constant added for numerical stability.
"""
super().__init__()
self.dim = dim
self.eps = eps
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Apply RMS normalization along the configured dimension.
"""
# Compute mean of squared values along `dim`, keep dimensions for broadcasting.
mean_sq = torch.mean(x**2, dim=self.dim, keepdim=True)
# Normalize by the root-mean-square (RMS).
rms = torch.sqrt(mean_sq + self.eps)
return x / rms
def build_normalization_layer(
in_channels: int, *, num_groups: int = 32, normtype: NormType = NormType.GROUP
) -> nn.Module:
"""
Create a normalization layer based on the normalization type.
Args:
in_channels: Number of input channels
num_groups: Number of groups for group normalization
normtype: Type of normalization: "group" or "pixel"
Returns:
A normalization layer
"""
if normtype == NormType.GROUP:
return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)
if normtype == NormType.PIXEL:
return PixelNorm(dim=1, eps=1e-6)
raise ValueError(f"Invalid normalization type: {normtype}")
def rms_norm(x: torch.Tensor, weight: torch.Tensor | None = None, eps: float = 1e-6) -> torch.Tensor:
"""Root-mean-square (RMS) normalize `x` over its last dimension.
Thin wrapper around `torch.nn.functional.rms_norm` that infers the normalized
shape and forwards `weight` and `eps`.
"""
return torch.nn.functional.rms_norm(x, (x.shape[-1],), weight=weight, eps=eps)
@dataclass(frozen=True)
class Modality:
"""
Input data for a single modality (video or audio) in the transformer.
Bundles the latent tokens, timestep embeddings, positional information,
and text conditioning context for processing by the diffusion transformer.
"""
latent: (
torch.Tensor
) # Shape: (B, T, D) where B is the batch size, T is the number of tokens, and D is input dimension
timesteps: torch.Tensor # Shape: (B, T) where T is the number of timesteps
positions: (
torch.Tensor
) # Shape: (B, 3, T) for video, where 3 is the number of dimensions and T is the number of tokens
context: torch.Tensor
enabled: bool = True
context_mask: torch.Tensor | None = None
def to_denoised(
sample: torch.Tensor,
velocity: torch.Tensor,
sigma: float | torch.Tensor,
calc_dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
"""
Convert the sample and its denoising velocity to denoised sample.
Returns:
Denoised sample
"""
if isinstance(sigma, torch.Tensor):
sigma = sigma.to(calc_dtype)
return (sample.to(calc_dtype) - velocity.to(calc_dtype) * sigma).to(sample.dtype)
class Patchifier(Protocol):
"""
Protocol for patchifiers that convert latent tensors into patches and assemble them back.
"""
def patchify(
self,
latents: torch.Tensor,
) -> torch.Tensor:
...
"""
Convert latent tensors into flattened patch tokens.
Args:
latents: Latent tensor to patchify.
Returns:
Flattened patch tokens tensor.
"""
def unpatchify(
self,
latents: torch.Tensor,
output_shape: AudioLatentShape | VideoLatentShape,
) -> torch.Tensor:
"""
Converts latent tensors between spatio-temporal formats and flattened sequence representations.
Args:
latents: Patch tokens that must be rearranged back into the latent grid constructed by `patchify`.
output_shape: Shape of the output tensor. Note that output_shape is either AudioLatentShape or
VideoLatentShape.
Returns:
Dense latent tensor restored from the flattened representation.
"""
@property
def patch_size(self) -> Tuple[int, int, int]:
...
"""
Returns the patch size as a tuple of (temporal, height, width) dimensions
"""
def get_patch_grid_bounds(
self,
output_shape: AudioLatentShape | VideoLatentShape,
device: torch.device | None = None,
) -> torch.Tensor:
...
"""
Compute metadata describing where each latent patch resides within the
grid specified by `output_shape`.
Args:
output_shape: Target grid layout for the patches.
device: Target device for the returned tensor.
Returns:
Tensor containing patch coordinate metadata such as spatial or temporal intervals.
"""
def get_pixel_coords(
latent_coords: torch.Tensor,
scale_factors: SpatioTemporalScaleFactors,
causal_fix: bool = False,
) -> torch.Tensor:
"""
Map latent-space `[start, end)` coordinates to their pixel-space equivalents by scaling
each axis (frame/time, height, width) with the corresponding VAE downsampling factors.
Optionally compensate for causal encoding that keeps the first frame at unit temporal scale.
Args:
latent_coords: Tensor of latent bounds shaped `(batch, 3, num_patches, 2)`.
scale_factors: SpatioTemporalScaleFactors tuple `(temporal, height, width)` with integer scale factors applied
per axis.
causal_fix: When True, rewrites the temporal axis of the first frame so causal VAEs
that treat frame zero differently still yield non-negative timestamps.
"""
# Broadcast the VAE scale factors so they align with the `(batch, axis, patch, bound)` layout.
broadcast_shape = [1] * latent_coords.ndim
broadcast_shape[1] = -1 # axis dimension corresponds to (frame/time, height, width)
scale_tensor = torch.tensor(scale_factors, device=latent_coords.device).view(*broadcast_shape)
# Apply per-axis scaling to convert latent bounds into pixel-space coordinates.
pixel_coords = latent_coords * scale_tensor
if causal_fix:
# VAE temporal stride for the very first frame is 1 instead of `scale_factors[0]`.
# Shift and clamp to keep the first-frame timestamps causal and non-negative.
pixel_coords[:, 0, ...] = (pixel_coords[:, 0, ...] + 1 - scale_factors[0]).clamp(min=0)
return pixel_coords

File diff suppressed because it is too large Load Diff

View File

@@ -1,366 +0,0 @@
import torch
from transformers import Gemma3ForConditionalGeneration, Gemma3Config, AutoTokenizer
from .ltx2_dit import (LTXRopeType, generate_freq_grid_np, generate_freq_grid_pytorch, precompute_freqs_cis, Attention,
FeedForward)
from .ltx2_common import rms_norm
class LTX2TextEncoder(Gemma3ForConditionalGeneration):
def __init__(self):
config = Gemma3Config(
**{
"architectures": ["Gemma3ForConditionalGeneration"],
"boi_token_index": 255999,
"dtype": "bfloat16",
"eoi_token_index": 256000,
"eos_token_id": [1, 106],
"image_token_index": 262144,
"initializer_range": 0.02,
"mm_tokens_per_image": 256,
"model_type": "gemma3",
"text_config": {
"_sliding_window_pattern": 6,
"attention_bias": False,
"attention_dropout": 0.0,
"attn_logit_softcapping": None,
"cache_implementation": "hybrid",
"dtype": "bfloat16",
"final_logit_softcapping": None,
"head_dim": 256,
"hidden_activation": "gelu_pytorch_tanh",
"hidden_size": 3840,
"initializer_range": 0.02,
"intermediate_size": 15360,
"layer_types": [
"sliding_attention", "sliding_attention", "sliding_attention", "sliding_attention",
"sliding_attention", "full_attention", "sliding_attention", "sliding_attention",
"sliding_attention", "sliding_attention", "sliding_attention", "full_attention",
"sliding_attention", "sliding_attention", "sliding_attention", "sliding_attention",
"sliding_attention", "full_attention", "sliding_attention", "sliding_attention",
"sliding_attention", "sliding_attention", "sliding_attention", "full_attention",
"sliding_attention", "sliding_attention", "sliding_attention", "sliding_attention",
"sliding_attention", "full_attention", "sliding_attention", "sliding_attention",
"sliding_attention", "sliding_attention", "sliding_attention", "full_attention",
"sliding_attention", "sliding_attention", "sliding_attention", "sliding_attention",
"sliding_attention", "full_attention", "sliding_attention", "sliding_attention",
"sliding_attention", "sliding_attention", "sliding_attention", "full_attention"
],
"max_position_embeddings": 131072,
"model_type": "gemma3_text",
"num_attention_heads": 16,
"num_hidden_layers": 48,
"num_key_value_heads": 8,
"query_pre_attn_scalar": 256,
"rms_norm_eps": 1e-06,
"rope_local_base_freq": 10000,
"rope_scaling": {
"factor": 8.0,
"rope_type": "linear"
},
"rope_theta": 1000000,
"sliding_window": 1024,
"sliding_window_pattern": 6,
"use_bidirectional_attention": False,
"use_cache": True,
"vocab_size": 262208
},
"transformers_version": "4.57.3",
"vision_config": {
"attention_dropout": 0.0,
"dtype": "bfloat16",
"hidden_act": "gelu_pytorch_tanh",
"hidden_size": 1152,
"image_size": 896,
"intermediate_size": 4304,
"layer_norm_eps": 1e-06,
"model_type": "siglip_vision_model",
"num_attention_heads": 16,
"num_channels": 3,
"num_hidden_layers": 27,
"patch_size": 14,
"vision_use_head": False
}
})
super().__init__(config)
class LTXVGemmaTokenizer:
"""
Tokenizer wrapper for Gemma models compatible with LTXV processes.
This class wraps HuggingFace's `AutoTokenizer` for use with Gemma text encoders,
ensuring correct settings and output formatting for downstream consumption.
"""
def __init__(self, tokenizer_path: str, max_length: int = 1024):
"""
Initialize the tokenizer.
Args:
tokenizer_path (str): Path to the pretrained tokenizer files or model directory.
max_length (int, optional): Max sequence length for encoding. Defaults to 256.
"""
self.tokenizer = AutoTokenizer.from_pretrained(
tokenizer_path, local_files_only=True, model_max_length=max_length
)
# Gemma expects left padding for chat-style prompts; for plain text it doesn't matter much.
self.tokenizer.padding_side = "left"
if self.tokenizer.pad_token is None:
self.tokenizer.pad_token = self.tokenizer.eos_token
self.max_length = max_length
def tokenize_with_weights(self, text: str, return_word_ids: bool = False) -> dict[str, list[tuple[int, int]]]:
"""
Tokenize the given text and return token IDs and attention weights.
Args:
text (str): The input string to tokenize.
return_word_ids (bool, optional): If True, includes the token's position (index) in the output tuples.
If False (default), omits the indices.
Returns:
dict[str, list[tuple[int, int]]] OR dict[str, list[tuple[int, int, int]]]:
A dictionary with a "gemma" key mapping to:
- a list of (token_id, attention_mask) tuples if return_word_ids is False;
- a list of (token_id, attention_mask, index) tuples if return_word_ids is True.
Example:
>>> tokenizer = LTXVGemmaTokenizer("path/to/tokenizer", max_length=8)
>>> tokenizer.tokenize_with_weights("hello world")
{'gemma': [(1234, 1), (5678, 1), (2, 0), ...]}
"""
text = text.strip()
encoded = self.tokenizer(
text,
padding="max_length",
max_length=self.max_length,
truncation=True,
return_tensors="pt",
)
input_ids = encoded.input_ids
attention_mask = encoded.attention_mask
tuples = [
(token_id, attn, i) for i, (token_id, attn) in enumerate(zip(input_ids[0], attention_mask[0], strict=True))
]
out = {"gemma": tuples}
if not return_word_ids:
# Return only (token_id, attention_mask) pairs, omitting token position
out = {k: [(t, w) for t, w, _ in v] for k, v in out.items()}
return out
class GemmaFeaturesExtractorProjLinear(torch.nn.Module):
"""
Feature extractor module for Gemma models.
This module applies a single linear projection to the input tensor.
It expects a flattened feature tensor of shape (batch_size, 3840*49).
The linear layer maps this to a (batch_size, 3840) embedding.
Attributes:
aggregate_embed (torch.nn.Linear): Linear projection layer.
"""
def __init__(self) -> None:
"""
Initialize the GemmaFeaturesExtractorProjLinear module.
The input dimension is expected to be 3840 * 49, and the output is 3840.
"""
super().__init__()
self.aggregate_embed = torch.nn.Linear(3840 * 49, 3840, bias=False)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Forward pass for the feature extractor.
Args:
x (torch.Tensor): Input tensor of shape (batch_size, 3840 * 49).
Returns:
torch.Tensor: Output tensor of shape (batch_size, 3840).
"""
return self.aggregate_embed(x)
class _BasicTransformerBlock1D(torch.nn.Module):
def __init__(
self,
dim: int,
heads: int,
dim_head: int,
rope_type: LTXRopeType = LTXRopeType.INTERLEAVED,
):
super().__init__()
self.attn1 = Attention(
query_dim=dim,
heads=heads,
dim_head=dim_head,
rope_type=rope_type,
)
self.ff = FeedForward(
dim,
dim_out=dim,
)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: torch.Tensor | None = None,
pe: torch.Tensor | None = None,
) -> torch.Tensor:
# Notice that normalization is always applied before the real computation in the following blocks.
# 1. Normalization Before Self-Attention
norm_hidden_states = rms_norm(hidden_states)
norm_hidden_states = norm_hidden_states.squeeze(1)
# 2. Self-Attention
attn_output = self.attn1(norm_hidden_states, mask=attention_mask, pe=pe)
hidden_states = attn_output + hidden_states
if hidden_states.ndim == 4:
hidden_states = hidden_states.squeeze(1)
# 3. Normalization before Feed-Forward
norm_hidden_states = rms_norm(hidden_states)
# 4. Feed-forward
ff_output = self.ff(norm_hidden_states)
hidden_states = ff_output + hidden_states
if hidden_states.ndim == 4:
hidden_states = hidden_states.squeeze(1)
return hidden_states
class Embeddings1DConnector(torch.nn.Module):
"""
Embeddings1DConnector applies a 1D transformer-based processing to sequential embeddings (e.g., for video, audio, or
other modalities). It supports rotary positional encoding (rope), optional causal temporal positioning, and can
substitute padded positions with learnable registers. The module is highly configurable for head size, number of
layers, and register usage.
Args:
attention_head_dim (int): Dimension of each attention head (default=128).
num_attention_heads (int): Number of attention heads (default=30).
num_layers (int): Number of transformer layers (default=2).
positional_embedding_theta (float): Scaling factor for position embedding (default=10000.0).
positional_embedding_max_pos (list[int] | None): Max positions for positional embeddings (default=[1]).
causal_temporal_positioning (bool): If True, uses causal attention (default=False).
num_learnable_registers (int | None): Number of learnable registers to replace padded tokens. If None, disables
register replacement. (default=128)
rope_type (LTXRopeType): The RoPE variant to use (default=DEFAULT_ROPE_TYPE).
double_precision_rope (bool): Use double precision rope calculation (default=False).
"""
_supports_gradient_checkpointing = True
def __init__(
self,
attention_head_dim: int = 128,
num_attention_heads: int = 30,
num_layers: int = 2,
positional_embedding_theta: float = 10000.0,
positional_embedding_max_pos: list[int] | None = [4096],
causal_temporal_positioning: bool = False,
num_learnable_registers: int | None = 128,
rope_type: LTXRopeType = LTXRopeType.SPLIT,
double_precision_rope: bool = True,
):
super().__init__()
self.num_attention_heads = num_attention_heads
self.inner_dim = num_attention_heads * attention_head_dim
self.causal_temporal_positioning = causal_temporal_positioning
self.positional_embedding_theta = positional_embedding_theta
self.positional_embedding_max_pos = (
positional_embedding_max_pos if positional_embedding_max_pos is not None else [1]
)
self.rope_type = rope_type
self.double_precision_rope = double_precision_rope
self.transformer_1d_blocks = torch.nn.ModuleList(
[
_BasicTransformerBlock1D(
dim=self.inner_dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
rope_type=rope_type,
)
for _ in range(num_layers)
]
)
self.num_learnable_registers = num_learnable_registers
if self.num_learnable_registers:
self.learnable_registers = torch.nn.Parameter(
torch.rand(self.num_learnable_registers, self.inner_dim, dtype=torch.bfloat16) * 2.0 - 1.0
)
def _replace_padded_with_learnable_registers(
self, hidden_states: torch.Tensor, attention_mask: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
assert hidden_states.shape[1] % self.num_learnable_registers == 0, (
f"Hidden states sequence length {hidden_states.shape[1]} must be divisible by num_learnable_registers "
f"{self.num_learnable_registers}."
)
num_registers_duplications = hidden_states.shape[1] // self.num_learnable_registers
learnable_registers = torch.tile(self.learnable_registers, (num_registers_duplications, 1))
attention_mask_binary = (attention_mask.squeeze(1).squeeze(1).unsqueeze(-1) >= -9000.0).int()
non_zero_hidden_states = hidden_states[:, attention_mask_binary.squeeze().bool(), :]
non_zero_nums = non_zero_hidden_states.shape[1]
pad_length = hidden_states.shape[1] - non_zero_nums
adjusted_hidden_states = torch.nn.functional.pad(non_zero_hidden_states, pad=(0, 0, 0, pad_length), value=0)
flipped_mask = torch.flip(attention_mask_binary, dims=[1])
hidden_states = flipped_mask * adjusted_hidden_states + (1 - flipped_mask) * learnable_registers
attention_mask = torch.full_like(
attention_mask,
0.0,
dtype=attention_mask.dtype,
device=attention_mask.device,
)
return hidden_states, attention_mask
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Forward pass of Embeddings1DConnector.
Args:
hidden_states (torch.Tensor): Input tensor of embeddings (shape [batch, seq_len, feature_dim]).
attention_mask (torch.Tensor|None): Optional mask for valid tokens (shape compatible with hidden_states).
Returns:
tuple[torch.Tensor, torch.Tensor]: Processed features and the corresponding (possibly modified) mask.
"""
if self.num_learnable_registers:
hidden_states, attention_mask = self._replace_padded_with_learnable_registers(hidden_states, attention_mask)
indices_grid = torch.arange(hidden_states.shape[1], dtype=torch.float32, device=hidden_states.device)
indices_grid = indices_grid[None, None, :]
freq_grid_generator = generate_freq_grid_np if self.double_precision_rope else generate_freq_grid_pytorch
freqs_cis = precompute_freqs_cis(
indices_grid=indices_grid,
dim=self.inner_dim,
out_dtype=hidden_states.dtype,
theta=self.positional_embedding_theta,
max_pos=self.positional_embedding_max_pos,
num_attention_heads=self.num_attention_heads,
rope_type=self.rope_type,
freq_grid_generator=freq_grid_generator,
)
for block in self.transformer_1d_blocks:
hidden_states = block(hidden_states, attention_mask=attention_mask, pe=freqs_cis)
hidden_states = rms_norm(hidden_states)
return hidden_states, attention_mask
class LTX2TextEncoderPostModules(torch.nn.Module):
def __init__(self,):
super().__init__()
self.feature_extractor_linear = GemmaFeaturesExtractorProjLinear()
self.embeddings_connector = Embeddings1DConnector()
self.audio_embeddings_connector = Embeddings1DConnector()

View File

@@ -1,313 +0,0 @@
import math
from typing import Optional, Tuple
import torch
from einops import rearrange
import torch.nn.functional as F
from .ltx2_video_vae import LTX2VideoEncoder
class PixelShuffleND(torch.nn.Module):
"""
N-dimensional pixel shuffle operation for upsampling tensors.
Args:
dims (int): Number of dimensions to apply pixel shuffle to.
- 1: Temporal (e.g., frames)
- 2: Spatial (e.g., height and width)
- 3: Spatiotemporal (e.g., depth, height, width)
upscale_factors (tuple[int, int, int], optional): Upscaling factors for each dimension.
For dims=1, only the first value is used.
For dims=2, the first two values are used.
For dims=3, all three values are used.
The input tensor is rearranged so that the channel dimension is split into
smaller channels and upscaling factors, and the upscaling factors are moved
into the corresponding spatial/temporal dimensions.
Note:
This operation is equivalent to the patchifier operation in for the models. Consider
using this class instead.
"""
def __init__(self, dims: int, upscale_factors: tuple[int, int, int] = (2, 2, 2)):
super().__init__()
assert dims in [1, 2, 3], "dims must be 1, 2, or 3"
self.dims = dims
self.upscale_factors = upscale_factors
def forward(self, x: torch.Tensor) -> torch.Tensor:
if self.dims == 3:
return rearrange(
x,
"b (c p1 p2 p3) d h w -> b c (d p1) (h p2) (w p3)",
p1=self.upscale_factors[0],
p2=self.upscale_factors[1],
p3=self.upscale_factors[2],
)
elif self.dims == 2:
return rearrange(
x,
"b (c p1 p2) h w -> b c (h p1) (w p2)",
p1=self.upscale_factors[0],
p2=self.upscale_factors[1],
)
elif self.dims == 1:
return rearrange(
x,
"b (c p1) f h w -> b c (f p1) h w",
p1=self.upscale_factors[0],
)
else:
raise ValueError(f"Unsupported dims: {self.dims}")
class ResBlock(torch.nn.Module):
"""
Residual block with two convolutional layers, group normalization, and SiLU activation.
Args:
channels (int): Number of input and output channels.
mid_channels (Optional[int]): Number of channels in the intermediate convolution layer. Defaults to `channels`
if not specified.
dims (int): Dimensionality of the convolution (2 for Conv2d, 3 for Conv3d). Defaults to 3.
"""
def __init__(self, channels: int, mid_channels: Optional[int] = None, dims: int = 3):
super().__init__()
if mid_channels is None:
mid_channels = channels
conv = torch.nn.Conv2d if dims == 2 else torch.nn.Conv3d
self.conv1 = conv(channels, mid_channels, kernel_size=3, padding=1)
self.norm1 = torch.nn.GroupNorm(32, mid_channels)
self.conv2 = conv(mid_channels, channels, kernel_size=3, padding=1)
self.norm2 = torch.nn.GroupNorm(32, channels)
self.activation = torch.nn.SiLU()
def forward(self, x: torch.Tensor) -> torch.Tensor:
residual = x
x = self.conv1(x)
x = self.norm1(x)
x = self.activation(x)
x = self.conv2(x)
x = self.norm2(x)
x = self.activation(x + residual)
return x
class BlurDownsample(torch.nn.Module):
"""
Anti-aliased spatial downsampling by integer stride using a fixed separable binomial kernel.
Applies only on H,W. Works for dims=2 or dims=3 (per-frame).
"""
def __init__(self, dims: int, stride: int, kernel_size: int = 5) -> None:
super().__init__()
assert dims in (2, 3)
assert isinstance(stride, int)
assert stride >= 1
assert kernel_size >= 3
assert kernel_size % 2 == 1
self.dims = dims
self.stride = stride
self.kernel_size = kernel_size
# 5x5 separable binomial kernel using binomial coefficients [1, 4, 6, 4, 1] from
# the 4th row of Pascal's triangle. This kernel is used for anti-aliasing and
# provides a smooth approximation of a Gaussian filter (often called a "binomial filter").
# The 2D kernel is constructed as the outer product and normalized.
k = torch.tensor([math.comb(kernel_size - 1, k) for k in range(kernel_size)])
k2d = k[:, None] @ k[None, :]
k2d = (k2d / k2d.sum()).float() # shape (kernel_size, kernel_size)
self.register_buffer("kernel", k2d[None, None, :, :]) # (1, 1, kernel_size, kernel_size)
def forward(self, x: torch.Tensor) -> torch.Tensor:
if self.stride == 1:
return x
if self.dims == 2:
return self._apply_2d(x)
else:
# dims == 3: apply per-frame on H,W
b, _, f, _, _ = x.shape
x = rearrange(x, "b c f h w -> (b f) c h w")
x = self._apply_2d(x)
h2, w2 = x.shape[-2:]
x = rearrange(x, "(b f) c h w -> b c f h w", b=b, f=f, h=h2, w=w2)
return x
def _apply_2d(self, x2d: torch.Tensor) -> torch.Tensor:
c = x2d.shape[1]
weight = self.kernel.expand(c, 1, self.kernel_size, self.kernel_size) # depthwise
x2d = F.conv2d(x2d, weight=weight, bias=None, stride=self.stride, padding=self.kernel_size // 2, groups=c)
return x2d
def _rational_for_scale(scale: float) -> Tuple[int, int]:
mapping = {0.75: (3, 4), 1.5: (3, 2), 2.0: (2, 1), 4.0: (4, 1)}
if float(scale) not in mapping:
raise ValueError(f"Unsupported scale {scale}. Choose from {list(mapping.keys())}")
return mapping[float(scale)]
class SpatialRationalResampler(torch.nn.Module):
"""
Fully-learned rational spatial scaling: up by 'num' via PixelShuffle, then anti-aliased
downsample by 'den' using fixed blur + stride. Operates on H,W only.
For dims==3, work per-frame for spatial scaling (temporal axis untouched).
Args:
mid_channels (`int`): Number of intermediate channels for the convolution layer
scale (`float`): Spatial scaling factor. Supported values are:
- 0.75: Downsample by 3/4 (reduce spatial size)
- 1.5: Upsample by 3/2 (increase spatial size)
- 2.0: Upsample by 2x (double spatial size)
- 4.0: Upsample by 4x (quadruple spatial size)
Any other value will raise a ValueError.
"""
def __init__(self, mid_channels: int, scale: float):
super().__init__()
self.scale = float(scale)
self.num, self.den = _rational_for_scale(self.scale)
self.conv = torch.nn.Conv2d(mid_channels, (self.num**2) * mid_channels, kernel_size=3, padding=1)
self.pixel_shuffle = PixelShuffleND(2, upscale_factors=(self.num, self.num))
self.blur_down = BlurDownsample(dims=2, stride=self.den)
def forward(self, x: torch.Tensor) -> torch.Tensor:
b, _, f, _, _ = x.shape
x = rearrange(x, "b c f h w -> (b f) c h w")
x = self.conv(x)
x = self.pixel_shuffle(x)
x = self.blur_down(x)
x = rearrange(x, "(b f) c h w -> b c f h w", b=b, f=f)
return x
class LTX2LatentUpsampler(torch.nn.Module):
"""
Model to upsample VAE latents spatially and/or temporally.
Args:
in_channels (`int`): Number of channels in the input latent
mid_channels (`int`): Number of channels in the middle layers
num_blocks_per_stage (`int`): Number of ResBlocks to use in each stage (pre/post upsampling)
dims (`int`): Number of dimensions for convolutions (2 or 3)
spatial_upsample (`bool`): Whether to spatially upsample the latent
temporal_upsample (`bool`): Whether to temporally upsample the latent
spatial_scale (`float`): Scale factor for spatial upsampling
rational_resampler (`bool`): Whether to use a rational resampler for spatial upsampling
"""
def __init__(
self,
in_channels: int = 128,
mid_channels: int = 1024,
num_blocks_per_stage: int = 4,
dims: int = 3,
spatial_upsample: bool = True,
temporal_upsample: bool = False,
spatial_scale: float = 2.0,
rational_resampler: bool = True,
):
super().__init__()
self.in_channels = in_channels
self.mid_channels = mid_channels
self.num_blocks_per_stage = num_blocks_per_stage
self.dims = dims
self.spatial_upsample = spatial_upsample
self.temporal_upsample = temporal_upsample
self.spatial_scale = float(spatial_scale)
self.rational_resampler = rational_resampler
conv = torch.nn.Conv2d if dims == 2 else torch.nn.Conv3d
self.initial_conv = conv(in_channels, mid_channels, kernel_size=3, padding=1)
self.initial_norm = torch.nn.GroupNorm(32, mid_channels)
self.initial_activation = torch.nn.SiLU()
self.res_blocks = torch.nn.ModuleList([ResBlock(mid_channels, dims=dims) for _ in range(num_blocks_per_stage)])
if spatial_upsample and temporal_upsample:
self.upsampler = torch.nn.Sequential(
torch.nn.Conv3d(mid_channels, 8 * mid_channels, kernel_size=3, padding=1),
PixelShuffleND(3),
)
elif spatial_upsample:
if rational_resampler:
self.upsampler = SpatialRationalResampler(mid_channels=mid_channels, scale=self.spatial_scale)
else:
self.upsampler = torch.nn.Sequential(
torch.nn.Conv2d(mid_channels, 4 * mid_channels, kernel_size=3, padding=1),
PixelShuffleND(2),
)
elif temporal_upsample:
self.upsampler = torch.nn.Sequential(
torch.nn.Conv3d(mid_channels, 2 * mid_channels, kernel_size=3, padding=1),
PixelShuffleND(1),
)
else:
raise ValueError("Either spatial_upsample or temporal_upsample must be True")
self.post_upsample_res_blocks = torch.nn.ModuleList(
[ResBlock(mid_channels, dims=dims) for _ in range(num_blocks_per_stage)]
)
self.final_conv = conv(mid_channels, in_channels, kernel_size=3, padding=1)
def forward(self, latent: torch.Tensor) -> torch.Tensor:
b, _, f, _, _ = latent.shape
if self.dims == 2:
x = rearrange(latent, "b c f h w -> (b f) c h w")
x = self.initial_conv(x)
x = self.initial_norm(x)
x = self.initial_activation(x)
for block in self.res_blocks:
x = block(x)
x = self.upsampler(x)
for block in self.post_upsample_res_blocks:
x = block(x)
x = self.final_conv(x)
x = rearrange(x, "(b f) c h w -> b c f h w", b=b, f=f)
else:
x = self.initial_conv(latent)
x = self.initial_norm(x)
x = self.initial_activation(x)
for block in self.res_blocks:
x = block(x)
if self.temporal_upsample:
x = self.upsampler(x)
# remove the first frame after upsampling.
# This is done because the first frame encodes one pixel frame.
x = x[:, :, 1:, :, :]
elif isinstance(self.upsampler, SpatialRationalResampler):
x = self.upsampler(x)
else:
x = rearrange(x, "b c f h w -> (b f) c h w")
x = self.upsampler(x)
x = rearrange(x, "(b f) c h w -> b c f h w", b=b, f=f)
for block in self.post_upsample_res_blocks:
x = block(x)
x = self.final_conv(x)
return x
def upsample_video(latent: torch.Tensor, video_encoder: LTX2VideoEncoder, upsampler: "LTX2LatentUpsampler") -> torch.Tensor:
"""
Apply upsampling to the latent representation using the provided upsampler,
with normalization and un-normalization based on the video encoder's per-channel statistics.
Args:
latent: Input latent tensor of shape [B, C, F, H, W].
video_encoder: VideoEncoder with per_channel_statistics for normalization.
upsampler: LTX2LatentUpsampler module to perform upsampling.
Returns:
torch.Tensor: Upsampled and re-normalized latent tensor.
"""
latent = video_encoder.per_channel_statistics.un_normalize(latent)
latent = upsampler(latent)
latent = video_encoder.per_channel_statistics.normalize(latent)
return latent

File diff suppressed because it is too large Load Diff

View File

@@ -1,113 +0,0 @@
from ..core.loader import load_model, hash_model_file
from ..core.vram import AutoWrappedModule
from ..configs import MODEL_CONFIGS, VRAM_MANAGEMENT_MODULE_MAPS, VERSION_CHECKER_MAPS
import importlib, json, torch
class ModelPool:
def __init__(self):
self.model = []
self.model_name = []
self.model_path = []
def import_model_class(self, model_class):
split = model_class.rfind(".")
model_resource, model_class = model_class[:split], model_class[split+1:]
model_class = importlib.import_module(model_resource).__getattribute__(model_class)
return model_class
def need_to_enable_vram_management(self, vram_config):
return vram_config["offload_dtype"] is not None and vram_config["offload_device"] is not None
def fetch_module_map(self, model_class, vram_config):
if self.need_to_enable_vram_management(vram_config):
if model_class in VRAM_MANAGEMENT_MODULE_MAPS:
vram_module_map = VRAM_MANAGEMENT_MODULE_MAPS[model_class] if model_class not in VERSION_CHECKER_MAPS else VERSION_CHECKER_MAPS[model_class]()
module_map = {self.import_model_class(source): self.import_model_class(target) for source, target in vram_module_map.items()}
else:
module_map = {self.import_model_class(model_class): AutoWrappedModule}
else:
module_map = None
return module_map
def load_model_file(self, config, path, vram_config, vram_limit=None, state_dict=None):
model_class = self.import_model_class(config["model_class"])
model_config = config.get("extra_kwargs", {})
if "state_dict_converter" in config:
state_dict_converter = self.import_model_class(config["state_dict_converter"])
else:
state_dict_converter = None
module_map = self.fetch_module_map(config["model_class"], vram_config)
model = load_model(
model_class, path, model_config,
vram_config["computation_dtype"], vram_config["computation_device"],
state_dict_converter,
use_disk_map=True,
vram_config=vram_config, module_map=module_map, vram_limit=vram_limit,
state_dict=state_dict,
)
return model
def default_vram_config(self):
vram_config = {
"offload_dtype": None,
"offload_device": None,
"onload_dtype": torch.bfloat16,
"onload_device": "cpu",
"preparing_dtype": torch.bfloat16,
"preparing_device": "cpu",
"computation_dtype": torch.bfloat16,
"computation_device": "cpu",
}
return vram_config
def auto_load_model(self, path, vram_config=None, vram_limit=None, clear_parameters=False, state_dict=None):
print(f"Loading models from: {json.dumps(path, indent=4)}")
if vram_config is None:
vram_config = self.default_vram_config()
model_hash = hash_model_file(path)
loaded = False
for config in MODEL_CONFIGS:
if config["model_hash"] == model_hash:
model = self.load_model_file(config, path, vram_config, vram_limit=vram_limit, state_dict=state_dict)
if clear_parameters: self.clear_parameters(model)
self.model.append(model)
model_name = config["model_name"]
self.model_name.append(model_name)
self.model_path.append(path)
model_info = {"model_name": model_name, "model_class": config["model_class"], "extra_kwargs": config.get("extra_kwargs")}
print(f"Loaded model: {json.dumps(model_info, indent=4)}")
loaded = True
if not loaded:
raise ValueError(f"Cannot detect the model type. File: {path}. Model hash: {model_hash}")
def fetch_model(self, model_name, index=None):
fetched_models = []
fetched_model_paths = []
for model, model_path, model_name_ in zip(self.model, self.model_path, self.model_name):
if model_name == model_name_:
fetched_models.append(model)
fetched_model_paths.append(model_path)
if len(fetched_models) == 0:
print(f"No {model_name} models available. This is not an error.")
model = None
elif len(fetched_models) == 1:
print(f"Using {model_name} from {json.dumps(fetched_model_paths[0], indent=4)}.")
model = fetched_models[0]
else:
if index is None:
model = fetched_models[0]
print(f"More than one {model_name} models are loaded: {fetched_model_paths}. Using {model_name} from {json.dumps(fetched_model_paths[0], indent=4)}.")
elif isinstance(index, int):
model = fetched_models[:index]
print(f"More than one {model_name} models are loaded: {fetched_model_paths}. Using {model_name} from {json.dumps(fetched_model_paths[:index], indent=4)}.")
else:
model = fetched_models
print(f"More than one {model_name} models are loaded: {fetched_model_paths}. Using {model_name} from {json.dumps(fetched_model_paths, indent=4)}.")
return model
def clear_parameters(self, model: torch.nn.Module):
for name, module in model.named_children():
self.clear_parameters(module)
for name, param in model.named_parameters(recurse=False):
setattr(model, name, None)

View File

@@ -0,0 +1,448 @@
import os, torch, json, importlib
from typing import List
from .downloader import download_models, download_customized_models, Preset_model_id, Preset_model_website
from .sd_text_encoder import SDTextEncoder
from .sd_unet import SDUNet
from .sd_vae_encoder import SDVAEEncoder
from .sd_vae_decoder import SDVAEDecoder
from .lora import get_lora_loaders
from .sdxl_text_encoder import SDXLTextEncoder, SDXLTextEncoder2
from .sdxl_unet import SDXLUNet
from .sdxl_vae_decoder import SDXLVAEDecoder
from .sdxl_vae_encoder import SDXLVAEEncoder
from .sd3_text_encoder import SD3TextEncoder1, SD3TextEncoder2, SD3TextEncoder3
from .sd3_dit import SD3DiT
from .sd3_vae_decoder import SD3VAEDecoder
from .sd3_vae_encoder import SD3VAEEncoder
from .sd_controlnet import SDControlNet
from .sdxl_controlnet import SDXLControlNetUnion
from .sd_motion import SDMotionModel
from .sdxl_motion import SDXLMotionModel
from .svd_image_encoder import SVDImageEncoder
from .svd_unet import SVDUNet
from .svd_vae_decoder import SVDVAEDecoder
from .svd_vae_encoder import SVDVAEEncoder
from .sd_ipadapter import SDIpAdapter, IpAdapterCLIPImageEmbedder
from .sdxl_ipadapter import SDXLIpAdapter, IpAdapterXLCLIPImageEmbedder
from .hunyuan_dit_text_encoder import HunyuanDiTCLIPTextEncoder, HunyuanDiTT5TextEncoder
from .hunyuan_dit import HunyuanDiT
from .hunyuan_video_vae_decoder import HunyuanVideoVAEDecoder
from .hunyuan_video_vae_encoder import HunyuanVideoVAEEncoder
from .flux_dit import FluxDiT
from .flux_text_encoder import FluxTextEncoder2
from .flux_vae import FluxVAEEncoder, FluxVAEDecoder
from .flux_ipadapter import FluxIpAdapter
from .cog_vae import CogVAEEncoder, CogVAEDecoder
from .cog_dit import CogDiT
from ..extensions.RIFE import IFNet
from ..extensions.ESRGAN import RRDBNet
from ..configs.model_config import model_loader_configs, huggingface_model_loader_configs, patch_model_loader_configs
from .utils import load_state_dict, init_weights_on_device, hash_state_dict_keys, split_state_dict_with_prefix
def load_model_from_single_file(state_dict, model_names, model_classes, model_resource, torch_dtype, device):
loaded_model_names, loaded_models = [], []
for model_name, model_class in zip(model_names, model_classes):
print(f" model_name: {model_name} model_class: {model_class.__name__}")
state_dict_converter = model_class.state_dict_converter()
if model_resource == "civitai":
state_dict_results = state_dict_converter.from_civitai(state_dict)
elif model_resource == "diffusers":
state_dict_results = state_dict_converter.from_diffusers(state_dict)
if isinstance(state_dict_results, tuple):
model_state_dict, extra_kwargs = state_dict_results
print(f" This model is initialized with extra kwargs: {extra_kwargs}")
else:
model_state_dict, extra_kwargs = state_dict_results, {}
torch_dtype = torch.float32 if extra_kwargs.get("upcast_to_float32", False) else torch_dtype
with init_weights_on_device():
model= model_class(**extra_kwargs)
model.load_state_dict(model_state_dict, assign=True)
model = model.to(dtype=torch_dtype, device=device)
loaded_model_names.append(model_name)
loaded_models.append(model)
return loaded_model_names, loaded_models
def load_model_from_huggingface_folder(file_path, model_names, model_classes, torch_dtype, device):
loaded_model_names, loaded_models = [], []
for model_name, model_class in zip(model_names, model_classes):
if torch_dtype in [torch.float32, torch.float16, torch.bfloat16]:
model = model_class.from_pretrained(file_path, torch_dtype=torch_dtype).eval()
else:
model = model_class.from_pretrained(file_path).eval().to(dtype=torch_dtype)
if torch_dtype == torch.float16 and hasattr(model, "half"):
model = model.half()
try:
model = model.to(device=device)
except:
pass
loaded_model_names.append(model_name)
loaded_models.append(model)
return loaded_model_names, loaded_models
def load_single_patch_model_from_single_file(state_dict, model_name, model_class, base_model, extra_kwargs, torch_dtype, device):
print(f" model_name: {model_name} model_class: {model_class.__name__} extra_kwargs: {extra_kwargs}")
base_state_dict = base_model.state_dict()
base_model.to("cpu")
del base_model
model = model_class(**extra_kwargs)
model.load_state_dict(base_state_dict, strict=False)
model.load_state_dict(state_dict, strict=False)
model.to(dtype=torch_dtype, device=device)
return model
def load_patch_model_from_single_file(state_dict, model_names, model_classes, extra_kwargs, model_manager, torch_dtype, device):
loaded_model_names, loaded_models = [], []
for model_name, model_class in zip(model_names, model_classes):
while True:
for model_id in range(len(model_manager.model)):
base_model_name = model_manager.model_name[model_id]
if base_model_name == model_name:
base_model_path = model_manager.model_path[model_id]
base_model = model_manager.model[model_id]
print(f" Adding patch model to {base_model_name} ({base_model_path})")
patched_model = load_single_patch_model_from_single_file(
state_dict, model_name, model_class, base_model, extra_kwargs, torch_dtype, device)
loaded_model_names.append(base_model_name)
loaded_models.append(patched_model)
model_manager.model.pop(model_id)
model_manager.model_path.pop(model_id)
model_manager.model_name.pop(model_id)
break
else:
break
return loaded_model_names, loaded_models
class ModelDetectorTemplate:
def __init__(self):
pass
def match(self, file_path="", state_dict={}):
return False
def load(self, file_path="", state_dict={}, device="cuda", torch_dtype=torch.float16, **kwargs):
return [], []
class ModelDetectorFromSingleFile:
def __init__(self, model_loader_configs=[]):
self.keys_hash_with_shape_dict = {}
self.keys_hash_dict = {}
for metadata in model_loader_configs:
self.add_model_metadata(*metadata)
def add_model_metadata(self, keys_hash, keys_hash_with_shape, model_names, model_classes, model_resource):
self.keys_hash_with_shape_dict[keys_hash_with_shape] = (model_names, model_classes, model_resource)
if keys_hash is not None:
self.keys_hash_dict[keys_hash] = (model_names, model_classes, model_resource)
def match(self, file_path="", state_dict={}):
if isinstance(file_path, str) and os.path.isdir(file_path):
return False
if len(state_dict) == 0:
state_dict = load_state_dict(file_path)
keys_hash_with_shape = hash_state_dict_keys(state_dict, with_shape=True)
if keys_hash_with_shape in self.keys_hash_with_shape_dict:
return True
keys_hash = hash_state_dict_keys(state_dict, with_shape=False)
if keys_hash in self.keys_hash_dict:
return True
return False
def load(self, file_path="", state_dict={}, device="cuda", torch_dtype=torch.float16, **kwargs):
if len(state_dict) == 0:
state_dict = load_state_dict(file_path)
# Load models with strict matching
keys_hash_with_shape = hash_state_dict_keys(state_dict, with_shape=True)
if keys_hash_with_shape in self.keys_hash_with_shape_dict:
model_names, model_classes, model_resource = self.keys_hash_with_shape_dict[keys_hash_with_shape]
loaded_model_names, loaded_models = load_model_from_single_file(state_dict, model_names, model_classes, model_resource, torch_dtype, device)
return loaded_model_names, loaded_models
# Load models without strict matching
# (the shape of parameters may be inconsistent, and the state_dict_converter will modify the model architecture)
keys_hash = hash_state_dict_keys(state_dict, with_shape=False)
if keys_hash in self.keys_hash_dict:
model_names, model_classes, model_resource = self.keys_hash_dict[keys_hash]
loaded_model_names, loaded_models = load_model_from_single_file(state_dict, model_names, model_classes, model_resource, torch_dtype, device)
return loaded_model_names, loaded_models
return loaded_model_names, loaded_models
class ModelDetectorFromSplitedSingleFile(ModelDetectorFromSingleFile):
def __init__(self, model_loader_configs=[]):
super().__init__(model_loader_configs)
def match(self, file_path="", state_dict={}):
if isinstance(file_path, str) and os.path.isdir(file_path):
return False
if len(state_dict) == 0:
state_dict = load_state_dict(file_path)
splited_state_dict = split_state_dict_with_prefix(state_dict)
for sub_state_dict in splited_state_dict:
if super().match(file_path, sub_state_dict):
return True
return False
def load(self, file_path="", state_dict={}, device="cuda", torch_dtype=torch.float16, **kwargs):
# Split the state_dict and load from each component
splited_state_dict = split_state_dict_with_prefix(state_dict)
valid_state_dict = {}
for sub_state_dict in splited_state_dict:
if super().match(file_path, sub_state_dict):
valid_state_dict.update(sub_state_dict)
if super().match(file_path, valid_state_dict):
loaded_model_names, loaded_models = super().load(file_path, valid_state_dict, device, torch_dtype)
else:
loaded_model_names, loaded_models = [], []
for sub_state_dict in splited_state_dict:
if super().match(file_path, sub_state_dict):
loaded_model_names_, loaded_models_ = super().load(file_path, valid_state_dict, device, torch_dtype)
loaded_model_names += loaded_model_names_
loaded_models += loaded_models_
return loaded_model_names, loaded_models
class ModelDetectorFromHuggingfaceFolder:
def __init__(self, model_loader_configs=[]):
self.architecture_dict = {}
for metadata in model_loader_configs:
self.add_model_metadata(*metadata)
def add_model_metadata(self, architecture, huggingface_lib, model_name, redirected_architecture):
self.architecture_dict[architecture] = (huggingface_lib, model_name, redirected_architecture)
def match(self, file_path="", state_dict={}):
if not isinstance(file_path, str) or os.path.isfile(file_path):
return False
file_list = os.listdir(file_path)
if "config.json" not in file_list:
return False
with open(os.path.join(file_path, "config.json"), "r") as f:
config = json.load(f)
if "architectures" not in config and "_class_name" not in config:
return False
return True
def load(self, file_path="", state_dict={}, device="cuda", torch_dtype=torch.float16, **kwargs):
with open(os.path.join(file_path, "config.json"), "r") as f:
config = json.load(f)
loaded_model_names, loaded_models = [], []
architectures = config["architectures"] if "architectures" in config else [config["_class_name"]]
for architecture in architectures:
huggingface_lib, model_name, redirected_architecture = self.architecture_dict[architecture]
if redirected_architecture is not None:
architecture = redirected_architecture
model_class = importlib.import_module(huggingface_lib).__getattribute__(architecture)
loaded_model_names_, loaded_models_ = load_model_from_huggingface_folder(file_path, [model_name], [model_class], torch_dtype, device)
loaded_model_names += loaded_model_names_
loaded_models += loaded_models_
return loaded_model_names, loaded_models
class ModelDetectorFromPatchedSingleFile:
def __init__(self, model_loader_configs=[]):
self.keys_hash_with_shape_dict = {}
for metadata in model_loader_configs:
self.add_model_metadata(*metadata)
def add_model_metadata(self, keys_hash_with_shape, model_name, model_class, extra_kwargs):
self.keys_hash_with_shape_dict[keys_hash_with_shape] = (model_name, model_class, extra_kwargs)
def match(self, file_path="", state_dict={}):
if not isinstance(file_path, str) or os.path.isdir(file_path):
return False
if len(state_dict) == 0:
state_dict = load_state_dict(file_path)
keys_hash_with_shape = hash_state_dict_keys(state_dict, with_shape=True)
if keys_hash_with_shape in self.keys_hash_with_shape_dict:
return True
return False
def load(self, file_path="", state_dict={}, device="cuda", torch_dtype=torch.float16, model_manager=None, **kwargs):
if len(state_dict) == 0:
state_dict = load_state_dict(file_path)
# Load models with strict matching
loaded_model_names, loaded_models = [], []
keys_hash_with_shape = hash_state_dict_keys(state_dict, with_shape=True)
if keys_hash_with_shape in self.keys_hash_with_shape_dict:
model_names, model_classes, extra_kwargs = self.keys_hash_with_shape_dict[keys_hash_with_shape]
loaded_model_names_, loaded_models_ = load_patch_model_from_single_file(
state_dict, model_names, model_classes, extra_kwargs, model_manager, torch_dtype, device)
loaded_model_names += loaded_model_names_
loaded_models += loaded_models_
return loaded_model_names, loaded_models
class ModelManager:
def __init__(
self,
torch_dtype=torch.float16,
device="cuda",
model_id_list: List[Preset_model_id] = [],
downloading_priority: List[Preset_model_website] = ["ModelScope", "HuggingFace"],
file_path_list: List[str] = [],
):
self.torch_dtype = torch_dtype
self.device = device
self.model = []
self.model_path = []
self.model_name = []
downloaded_files = download_models(model_id_list, downloading_priority) if len(model_id_list) > 0 else []
self.model_detector = [
ModelDetectorFromSingleFile(model_loader_configs),
ModelDetectorFromSplitedSingleFile(model_loader_configs),
ModelDetectorFromHuggingfaceFolder(huggingface_model_loader_configs),
ModelDetectorFromPatchedSingleFile(patch_model_loader_configs),
]
self.load_models(downloaded_files + file_path_list)
def load_model_from_single_file(self, file_path="", state_dict={}, model_names=[], model_classes=[], model_resource=None):
print(f"Loading models from file: {file_path}")
if len(state_dict) == 0:
state_dict = load_state_dict(file_path)
model_names, models = load_model_from_single_file(state_dict, model_names, model_classes, model_resource, self.torch_dtype, self.device)
for model_name, model in zip(model_names, models):
self.model.append(model)
self.model_path.append(file_path)
self.model_name.append(model_name)
print(f" The following models are loaded: {model_names}.")
def load_model_from_huggingface_folder(self, file_path="", model_names=[], model_classes=[]):
print(f"Loading models from folder: {file_path}")
model_names, models = load_model_from_huggingface_folder(file_path, model_names, model_classes, self.torch_dtype, self.device)
for model_name, model in zip(model_names, models):
self.model.append(model)
self.model_path.append(file_path)
self.model_name.append(model_name)
print(f" The following models are loaded: {model_names}.")
def load_patch_model_from_single_file(self, file_path="", state_dict={}, model_names=[], model_classes=[], extra_kwargs={}):
print(f"Loading patch models from file: {file_path}")
model_names, models = load_patch_model_from_single_file(
state_dict, model_names, model_classes, extra_kwargs, self, self.torch_dtype, self.device)
for model_name, model in zip(model_names, models):
self.model.append(model)
self.model_path.append(file_path)
self.model_name.append(model_name)
print(f" The following patched models are loaded: {model_names}.")
def load_lora(self, file_path="", state_dict={}, lora_alpha=1.0):
if isinstance(file_path, list):
for file_path_ in file_path:
self.load_lora(file_path_, state_dict=state_dict, lora_alpha=lora_alpha)
else:
print(f"Loading LoRA models from file: {file_path}")
if len(state_dict) == 0:
state_dict = load_state_dict(file_path)
for model_name, model, model_path in zip(self.model_name, self.model, self.model_path):
for lora in get_lora_loaders():
match_results = lora.match(model, state_dict)
if match_results is not None:
print(f" Adding LoRA to {model_name} ({model_path}).")
lora_prefix, model_resource = match_results
lora.load(model, state_dict, lora_prefix, alpha=lora_alpha, model_resource=model_resource)
break
def load_model(self, file_path, model_names=None, device=None, torch_dtype=None):
print(f"Loading models from: {file_path}")
if device is None: device = self.device
if torch_dtype is None: torch_dtype = self.torch_dtype
if isinstance(file_path, list):
state_dict = {}
for path in file_path:
state_dict.update(load_state_dict(path))
elif os.path.isfile(file_path):
state_dict = load_state_dict(file_path)
else:
state_dict = None
for model_detector in self.model_detector:
if model_detector.match(file_path, state_dict):
model_names, models = model_detector.load(
file_path, state_dict,
device=device, torch_dtype=torch_dtype,
allowed_model_names=model_names, model_manager=self
)
for model_name, model in zip(model_names, models):
self.model.append(model)
self.model_path.append(file_path)
self.model_name.append(model_name)
print(f" The following models are loaded: {model_names}.")
break
else:
print(f" We cannot detect the model type. No models are loaded.")
def load_models(self, file_path_list, model_names=None, device=None, torch_dtype=None):
for file_path in file_path_list:
self.load_model(file_path, model_names, device=device, torch_dtype=torch_dtype)
def fetch_model(self, model_name, file_path=None, require_model_path=False):
fetched_models = []
fetched_model_paths = []
for model, model_path, model_name_ in zip(self.model, self.model_path, self.model_name):
if file_path is not None and file_path != model_path:
continue
if model_name == model_name_:
fetched_models.append(model)
fetched_model_paths.append(model_path)
if len(fetched_models) == 0:
print(f"No {model_name} models available.")
return None
if len(fetched_models) == 1:
print(f"Using {model_name} from {fetched_model_paths[0]}.")
else:
print(f"More than one {model_name} models are loaded in model manager: {fetched_model_paths}. Using {model_name} from {fetched_model_paths[0]}.")
if require_model_path:
return fetched_models[0], fetched_model_paths[0]
else:
return fetched_models[0]
def to(self, device):
for model in self.model:
model.to(device)

Some files were not shown because too many files have changed in this diff Show More