mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-28 04:07:49 +00:00
Compare commits
186 Commits
Commits listed (SHA1 hashes only), from 2cefc20ed6 (first listed) to a8cb4a21d1 (last listed).
31 README.md
@@ -9,13 +9,17 @@
<a href="https://trendshift.io/repositories/10946" target="_blank"><img src="https://trendshift.io/api/badge/repositories/10946" alt="modelscope%2FDiffSynth-Studio | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a>
</p>

Document: https://diffsynth-studio.readthedocs.io/zh-cn/latest/index.html

## Introduction

DiffSynth Studio is a Diffusion engine. We have restructured architectures including Text Encoder, UNet, VAE, among others, maintaining compatibility with models from the open-source community while enhancing computational performance. We provide many interesting features. Enjoy the magic of Diffusion models!

Until now, DiffSynth Studio has supported the following models:

* [CogVideo](https://huggingface.co/THUDM/CogVideoX-5b)
* [StepVideo](https://github.com/stepfun-ai/Step-Video-T2V)
* [HunyuanVideo](https://github.com/Tencent/HunyuanVideo)
* [CogVideoX](https://huggingface.co/THUDM/CogVideoX-5b)
* [FLUX](https://huggingface.co/black-forest-labs/FLUX.1-dev)
* [ExVideo](https://huggingface.co/ECNU-CILab/ExVideo-SVD-128f-v1)
* [Kolors](https://huggingface.co/Kwai-Kolors/Kolors)
@@ -32,6 +36,26 @@ Until now, DiffSynth Studio has supported the following models:

## News

- **February 17, 2025** We support [StepVideo](https://modelscope.cn/models/stepfun-ai/stepvideo-t2v/summary)! State-of-the-art video synthesis model! See [./examples/stepvideo](./examples/stepvideo/).

- **December 31, 2024** We propose EliGen, a novel framework for precise entity-level controlled text-to-image generation, complemented by an inpainting fusion pipeline to extend its capabilities to image inpainting tasks. EliGen seamlessly integrates with existing community models, such as IP-Adapter and In-Context LoRA, enhancing its versatility. For more details, see [./examples/EntityControl](./examples/EntityControl/) and the usage sketch below.
    - Paper: [EliGen: Entity-Level Controlled Image Generation with Regional Attention](https://arxiv.org/abs/2501.01097)
    - Model: [ModelScope](https://www.modelscope.cn/models/DiffSynth-Studio/Eligen)
    - Online Demo: [ModelScope EliGen Studio](https://www.modelscope.cn/studios/DiffSynth-Studio/EliGen)
    - Training Dataset: [EliGen Train Set](https://www.modelscope.cn/datasets/DiffSynth-Studio/EliGenTrainSet)

- **December 19, 2024** We implement advanced VRAM management for HunyuanVideo, making it possible to generate videos at a resolution of 129x720x1280 using 24GB of VRAM, or at 129x512x384 resolution with just 6GB of VRAM. Please refer to [./examples/HunyuanVideo/](./examples/HunyuanVideo/) for more details.

- **December 18, 2024** We propose ArtAug, an approach designed to improve text-to-image synthesis models through synthesis-understanding interactions. We have trained an ArtAug enhancement module for FLUX.1-dev in the format of LoRA. This model integrates the aesthetic understanding of Qwen2-VL-72B into FLUX.1-dev, leading to an improvement in the quality of generated images.
    - Paper: https://arxiv.org/abs/2412.12888
    - Examples: https://github.com/modelscope/DiffSynth-Studio/tree/main/examples/ArtAug
    - Model: [ModelScope](https://www.modelscope.cn/models/DiffSynth-Studio/ArtAug-lora-FLUX.1dev-v1), [HuggingFace](https://huggingface.co/ECNU-CILab/ArtAug-lora-FLUX.1dev-v1)
    - Demo: [ModelScope](https://modelscope.cn/aigc/imageGeneration?tab=advanced&versionId=7228&modelType=LoRA&sdVersion=FLUX_1&modelUrl=modelscope%3A%2F%2FDiffSynth-Studio%2FArtAug-lora-FLUX.1dev-v1%3Frevision%3Dv1.0), HuggingFace (Coming soon)

- **October 25, 2024** We provide extensive FLUX ControlNet support. This project supports many different ControlNet models that can be freely combined, even if their structures differ. Additionally, ControlNet models are compatible with high-resolution refinement and partition control techniques, enabling very powerful controllable image generation. See [`./examples/ControlNet/`](./examples/ControlNet/).

- **October 8, 2024.** We release the extended LoRA based on CogVideoX-5B and ExVideo. You can download this model from [ModelScope](https://modelscope.cn/models/ECNU-CILab/ExVideo-CogVideoX-LoRA-129f-v1) or [HuggingFace](https://huggingface.co/ECNU-CILab/ExVideo-CogVideoX-LoRA-129f-v1).

- **August 22, 2024.** CogVideoX-5B is supported in this project. See [here](/examples/video_synthesis/). We provide several interesting features for this text-to-video model, including
    - Text to video
    - Video editing
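
To make the EliGen entry above concrete: the Gradio app added in this change set (`apps/gradio/entity_level_control.py`, shown further down) drives EliGen by loading FLUX.1-dev through `ModelManager`, attaching the EliGen LoRA fetched with `download_customized_models`, and passing entity prompts and masks to `FluxImagePipeline`. The minimal sketch below reuses exactly those calls; the global prompt, entity prompts, and mask rectangles are illustrative placeholders.

```python
import torch
from PIL import Image, ImageDraw
from diffsynth import ModelManager, FluxImagePipeline, download_customized_models

# Load FLUX.1-dev and attach the EliGen LoRA (same calls as in the Gradio app).
model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cuda", model_id_list=["FLUX.1-dev"])
model_manager.load_lora(
    download_customized_models(
        model_id="DiffSynth-Studio/Eligen",
        origin_file_path="model_bf16.safetensors",
        local_dir="models/lora/entity_control",
    ),
    lora_alpha=1,
)
pipe = FluxImagePipeline.from_model_manager(model_manager)

# Illustrative entity masks: solid white rectangles on a black canvas (RGB, same size as the output).
def rect_mask(box, size=(1024, 1024)):
    mask = Image.new("RGB", size, (0, 0, 0))
    ImageDraw.Draw(mask).rectangle(box, fill=(255, 255, 255))
    return mask

entity_prompts = ["a red vintage car", "a golden retriever"]  # illustrative local prompts
entity_masks = [rect_mask((40, 420, 500, 920)), rect_mask((560, 460, 980, 940))]

torch.manual_seed(42)
image = pipe(
    prompt="a sunny street by the river, photorealistic",  # illustrative global prompt
    cfg_scale=3.0,
    embedded_guidance=3.5,
    num_inference_steps=50,
    height=1024,
    width=1024,
    eligen_entity_prompts=entity_prompts,
    eligen_entity_masks=entity_masks,
)
image.save("eligen_example.png")
```

As the app's own instructions note, solid rectangular masks generally yield better results than freehand ones.
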
@@ -137,10 +161,11 @@ https://github.com/user-attachments/assets/26b044c1-4a60-44a4-842f-627ff289d006
#### Long Video Synthesis

We trained an extended video synthesis model, which can generate 128 frames. [`examples/ExVideo`](./examples/ExVideo/)
We trained extended video synthesis models, which can generate 128 frames. [`examples/ExVideo`](./examples/ExVideo/)

https://github.com/modelscope/DiffSynth-Studio/assets/35051019/d97f6aa9-8064-4b5b-9d49-ed6001bb9acc

https://github.com/user-attachments/assets/321ee04b-8c17-479e-8a95-8cbcf21f8d7e

#### Toon Shading
@@ -164,7 +189,7 @@ LoRA fine-tuning is supported in [`examples/train`](./examples/train/).
|FLUX|Stable Diffusion 3|
|-|-|
|||
|||

|Kolors|Hunyuan-DiT|
|-|-|
390 apps/gradio/entity_level_control.py Normal file
@@ -0,0 +1,390 @@
|
||||
import os
|
||||
import torch
|
||||
import numpy as np
|
||||
from PIL import Image, ImageDraw, ImageFont
|
||||
import random
|
||||
import json
|
||||
import gradio as gr
|
||||
from diffsynth import ModelManager, FluxImagePipeline, download_customized_models
|
||||
from modelscope import dataset_snapshot_download
|
||||
|
||||
|
||||
dataset_snapshot_download(dataset_id="DiffSynth-Studio/examples_in_diffsynth", local_dir="./", allow_file_pattern=f"data/examples/eligen/entity_control/*")
|
||||
example_json = 'data/examples/eligen/entity_control/ui_examples.json'
|
||||
with open(example_json, 'r') as f:
|
||||
examples = json.load(f)['examples']
|
||||
|
||||
for idx in range(len(examples)):
|
||||
example_id = examples[idx]['example_id']
|
||||
entity_prompts = examples[idx]['local_prompt_list']
|
||||
examples[idx]['mask_lists'] = [Image.open(f"data/examples/eligen/entity_control/example_{example_id}/{i}.png").convert('RGB') for i in range(len(entity_prompts))]
|
||||
|
||||
def create_canvas_data(background, masks):
|
||||
if background.shape[-1] == 3:
|
||||
background = np.dstack([background, np.full(background.shape[:2], 255, dtype=np.uint8)])
|
||||
layers = []
|
||||
for mask in masks:
|
||||
if mask is not None:
|
||||
mask_single_channel = mask if mask.ndim == 2 else mask[..., 0]
|
||||
layer = np.zeros((mask_single_channel.shape[0], mask_single_channel.shape[1], 4), dtype=np.uint8)
|
||||
layer[..., -1] = mask_single_channel
|
||||
layers.append(layer)
|
||||
else:
|
||||
layers.append(np.zeros_like(background))
|
||||
|
||||
composite = background.copy()
|
||||
for layer in layers:
|
||||
if layer.size > 0:
|
||||
composite = np.where(layer[..., -1:] > 0, layer, composite)
|
||||
return {
|
||||
"background": background,
|
||||
"layers": layers,
|
||||
"composite": composite,
|
||||
}
|
||||
|
||||
def load_example(load_example_button):
|
||||
example_idx = int(load_example_button.split()[-1]) - 1
|
||||
example = examples[example_idx]
|
||||
result = [
|
||||
50,
|
||||
example["global_prompt"],
|
||||
example["negative_prompt"],
|
||||
example["seed"],
|
||||
*example["local_prompt_list"],
|
||||
]
|
||||
num_entities = len(example["local_prompt_list"])
|
||||
result += [""] * (config["max_num_painter_layers"] - num_entities)
|
||||
masks = []
|
||||
for mask in example["mask_lists"]:
|
||||
mask_single_channel = np.array(mask.convert("L"))
|
||||
masks.append(mask_single_channel)
|
||||
for _ in range(config["max_num_painter_layers"] - len(masks)):
|
||||
blank_mask = np.zeros_like(masks[0]) if masks else np.zeros((512, 512), dtype=np.uint8)
|
||||
masks.append(blank_mask)
|
||||
background = np.ones((masks[0].shape[0], masks[0].shape[1], 4), dtype=np.uint8) * 255
|
||||
canvas_data_list = []
|
||||
for mask in masks:
|
||||
canvas_data = create_canvas_data(background, [mask])
|
||||
canvas_data_list.append(canvas_data)
|
||||
result.extend(canvas_data_list)
|
||||
return result
|
||||
|
||||
def save_mask_prompts(masks, mask_prompts, global_prompt, seed=0, random_dir='0000000'):
|
||||
save_dir = os.path.join('workdirs/tmp_mask', random_dir)
|
||||
print(f'save to {save_dir}')
|
||||
os.makedirs(save_dir, exist_ok=True)
|
||||
for i, mask in enumerate(masks):
|
||||
save_path = os.path.join(save_dir, f'{i}.png')
|
||||
mask.save(save_path)
|
||||
sample = {
|
||||
"global_prompt": global_prompt,
|
||||
"mask_prompts": mask_prompts,
|
||||
"seed": seed,
|
||||
}
|
||||
with open(os.path.join(save_dir, f"prompts.json"), 'w') as f:
|
||||
json.dump(sample, f, indent=4)
|
||||
|
||||
def visualize_masks(image, masks, mask_prompts, font_size=35, use_random_colors=False):
|
||||
# Create a blank image for overlays
|
||||
overlay = Image.new('RGBA', image.size, (0, 0, 0, 0))
|
||||
colors = [
|
||||
(165, 238, 173, 80),
|
||||
(76, 102, 221, 80),
|
||||
(221, 160, 77, 80),
|
||||
(204, 93, 71, 80),
|
||||
(145, 187, 149, 80),
|
||||
(134, 141, 172, 80),
|
||||
(157, 137, 109, 80),
|
||||
(153, 104, 95, 80),
|
||||
(165, 238, 173, 80),
|
||||
(76, 102, 221, 80),
|
||||
(221, 160, 77, 80),
|
||||
(204, 93, 71, 80),
|
||||
(145, 187, 149, 80),
|
||||
(134, 141, 172, 80),
|
||||
(157, 137, 109, 80),
|
||||
(153, 104, 95, 80),
|
||||
]
|
||||
# Generate random colors for each mask
|
||||
if use_random_colors:
|
||||
colors = [(random.randint(0, 255), random.randint(0, 255), random.randint(0, 255), 80) for _ in range(len(masks))]
|
||||
# Font settings
|
||||
try:
|
||||
font = ImageFont.truetype("arial", font_size) # Adjust as needed
|
||||
except IOError:
|
||||
font = ImageFont.load_default(font_size)
|
||||
# Overlay each mask onto the overlay image
|
||||
for mask, mask_prompt, color in zip(masks, mask_prompts, colors):
|
||||
if mask is None:
|
||||
continue
|
||||
# Convert mask to RGBA mode
|
||||
mask_rgba = mask.convert('RGBA')
|
||||
mask_data = mask_rgba.getdata()
|
||||
new_data = [(color if item[:3] == (255, 255, 255) else (0, 0, 0, 0)) for item in mask_data]
|
||||
mask_rgba.putdata(new_data)
|
||||
# Draw the mask prompt text on the mask
|
||||
draw = ImageDraw.Draw(mask_rgba)
|
||||
mask_bbox = mask.getbbox() # Get the bounding box of the mask
|
||||
if mask_bbox is None:
|
||||
continue
|
||||
text_position = (mask_bbox[0] + 10, mask_bbox[1] + 10) # Adjust text position based on mask position
|
||||
draw.text(text_position, mask_prompt, fill=(255, 255, 255, 255), font=font)
|
||||
# Alpha composite the overlay with this mask
|
||||
overlay = Image.alpha_composite(overlay, mask_rgba)
|
||||
# Composite the overlay onto the original image
|
||||
result = Image.alpha_composite(image.convert('RGBA'), overlay)
|
||||
return result
|
||||
|
||||
config = {
|
||||
"model_config": {
|
||||
"FLUX": {
|
||||
"model_folder": "models/FLUX",
|
||||
"pipeline_class": FluxImagePipeline,
|
||||
"default_parameters": {
|
||||
"cfg_scale": 3.0,
|
||||
"embedded_guidance": 3.5,
|
||||
"num_inference_steps": 30,
|
||||
}
|
||||
},
|
||||
},
|
||||
"max_num_painter_layers": 8,
|
||||
"max_num_model_cache": 1,
|
||||
}
|
||||
|
||||
model_dict = {}
|
||||
|
||||
def load_model(model_type='FLUX', model_path='FLUX.1-dev'):
|
||||
global model_dict
|
||||
model_key = f"{model_type}:{model_path}"
|
||||
if model_key in model_dict:
|
||||
return model_dict[model_key]
|
||||
model_path = os.path.join(config["model_config"][model_type]["model_folder"], model_path)
|
||||
model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cuda", model_id_list=["FLUX.1-dev"])
|
||||
model_manager.load_lora(
|
||||
download_customized_models(
|
||||
model_id="DiffSynth-Studio/Eligen",
|
||||
origin_file_path="model_bf16.safetensors",
|
||||
local_dir="models/lora/entity_control",
|
||||
),
|
||||
lora_alpha=1,
|
||||
)
|
||||
pipe = config["model_config"][model_type]["pipeline_class"].from_model_manager(model_manager)
|
||||
model_dict[model_key] = model_manager, pipe
|
||||
return model_manager, pipe
|
||||
|
||||
|
||||
with gr.Blocks() as app:
|
||||
gr.Markdown(
|
||||
"""## EliGen: Entity-Level Controllable Text-to-Image Model
|
||||
1. On the left, input the **global prompt** for the overall image, such as "a person stands by the river."
|
||||
2. On the right, input the **local prompt** for each entity, such as "person," and draw the corresponding mask in the **Entity Mask Painter**. Generally, solid rectangular masks yield better results.
|
||||
3. Click the **Generate** button to create the image. By selecting different **random seeds**, you can generate diverse images.
|
||||
4. **You can directly click the "Load Example" button on any sample at the bottom to load example inputs.**
|
||||
"""
|
||||
)
|
||||
|
||||
loading_status = gr.Textbox(label="Loading Model...", value="Loading model... Please wait...", visible=True)
|
||||
main_interface = gr.Column(visible=False)
|
||||
|
||||
def initialize_model():
|
||||
try:
|
||||
load_model()
|
||||
return {
|
||||
loading_status: gr.update(value="Model loaded successfully!", visible=False),
|
||||
main_interface: gr.update(visible=True),
|
||||
}
|
||||
except Exception as e:
|
||||
print(f'Failed to load model with error: {e}')
|
||||
return {
|
||||
loading_status: gr.update(value=f"Failed to load model: {str(e)}", visible=True),
|
||||
main_interface: gr.update(visible=True),
|
||||
}
|
||||
|
||||
app.load(initialize_model, inputs=None, outputs=[loading_status, main_interface])
|
||||
|
||||
with main_interface:
|
||||
with gr.Row():
|
||||
local_prompt_list = []
|
||||
canvas_list = []
|
||||
random_mask_dir = gr.State(f'{random.randint(0, 1000000):08d}')
|
||||
with gr.Column(scale=382, min_width=100):
|
||||
model_type = gr.State('FLUX')
|
||||
model_path = gr.State('FLUX.1-dev')
|
||||
with gr.Accordion(label="Global prompt"):
|
||||
prompt = gr.Textbox(label="Global Prompt", lines=3)
|
||||
negative_prompt = gr.Textbox(label="Negative prompt", value="worst quality, low quality, monochrome, zombie, interlocked fingers, Aissist, cleavage, nsfw, blur,", lines=3)
|
||||
with gr.Accordion(label="Inference Options", open=True):
|
||||
seed = gr.Number(minimum=0, maximum=10**9, value=42, interactive=True, label="Random seed", show_label=True)
|
||||
num_inference_steps = gr.Slider(minimum=1, maximum=100, value=30, step=1, interactive=True, label="Inference steps")
|
||||
cfg_scale = gr.Slider(minimum=2.0, maximum=10.0, value=3.0, step=0.1, interactive=True, label="Classifier-free guidance scale")
|
||||
embedded_guidance = gr.Slider(minimum=0.0, maximum=10.0, value=3.5, step=0.1, interactive=True, label="Embedded guidance scale")
|
||||
height = gr.Slider(minimum=64, maximum=2048, value=1024, step=64, interactive=True, label="Height")
|
||||
width = gr.Slider(minimum=64, maximum=2048, value=1024, step=64, interactive=True, label="Width")
|
||||
with gr.Accordion(label="Inpaint Input Image", open=False):
|
||||
input_image = gr.Image(sources=None, show_label=False, interactive=True, type="pil")
|
||||
background_weight = gr.Slider(minimum=0.0, maximum=1000., value=0., step=1, interactive=False, label="background_weight", visible=False)
|
||||
|
||||
with gr.Column():
|
||||
reset_input_button = gr.Button(value="Reset Inpaint Input")
|
||||
send_input_to_painter = gr.Button(value="Set as painter's background")
|
||||
@gr.on(inputs=[input_image], outputs=[input_image], triggers=reset_input_button.click)
|
||||
def reset_input_image(input_image):
|
||||
return None
|
||||
|
||||
with gr.Column(scale=618, min_width=100):
|
||||
with gr.Accordion(label="Entity Painter"):
|
||||
for painter_layer_id in range(config["max_num_painter_layers"]):
|
||||
with gr.Tab(label=f"Entity {painter_layer_id}"):
|
||||
local_prompt = gr.Textbox(label="Local prompt", key=f"local_prompt_{painter_layer_id}")
|
||||
canvas = gr.ImageEditor(
|
||||
canvas_size=(512, 512),
|
||||
sources=None,
|
||||
layers=False,
|
||||
interactive=True,
|
||||
image_mode="RGBA",
|
||||
brush=gr.Brush(
|
||||
default_size=50,
|
||||
default_color="#000000",
|
||||
colors=["#000000"],
|
||||
),
|
||||
label="Entity Mask Painter",
|
||||
key=f"canvas_{painter_layer_id}",
|
||||
width=width,
|
||||
height=height,
|
||||
)
|
||||
@gr.on(inputs=[height, width, canvas], outputs=canvas, triggers=[height.change, width.change, canvas.clear], show_progress="hidden")
|
||||
def resize_canvas(height, width, canvas):
|
||||
h, w = canvas["background"].shape[:2]
|
||||
if h != height or width != w:
|
||||
return np.ones((height, width, 3), dtype=np.uint8) * 255
|
||||
else:
|
||||
return canvas
|
||||
local_prompt_list.append(local_prompt)
|
||||
canvas_list.append(canvas)
|
||||
with gr.Accordion(label="Results"):
|
||||
run_button = gr.Button(value="Generate", variant="primary")
|
||||
output_image = gr.Image(sources=None, show_label=False, interactive=False, type="pil")
|
||||
with gr.Row():
|
||||
with gr.Column():
|
||||
output_to_painter_button = gr.Button(value="Set as painter's background")
|
||||
with gr.Column():
|
||||
return_with_mask = gr.Checkbox(value=False, interactive=True, label="show result with mask painting")
|
||||
output_to_input_button = gr.Button(value="Set as input image", visible=False, interactive=False)
|
||||
real_output = gr.State(None)
|
||||
mask_out = gr.State(None)
|
||||
|
||||
@gr.on(
|
||||
inputs=[model_type, model_path, prompt, negative_prompt, cfg_scale, embedded_guidance, num_inference_steps, height, width, return_with_mask, seed, input_image, background_weight, random_mask_dir] + local_prompt_list + canvas_list,
|
||||
outputs=[output_image, real_output, mask_out],
|
||||
triggers=run_button.click
|
||||
)
|
||||
def generate_image(model_type, model_path, prompt, negative_prompt, cfg_scale, embedded_guidance, num_inference_steps, height, width, return_with_mask, seed, input_image, background_weight, random_mask_dir, *args, progress=gr.Progress()):
|
||||
_, pipe = load_model(model_type, model_path)
|
||||
input_params = {
|
||||
"prompt": prompt,
|
||||
"negative_prompt": negative_prompt,
|
||||
"cfg_scale": cfg_scale,
|
||||
"num_inference_steps": num_inference_steps,
|
||||
"height": height,
|
||||
"width": width,
|
||||
"progress_bar_cmd": progress.tqdm,
|
||||
}
|
||||
if isinstance(pipe, FluxImagePipeline):
|
||||
input_params["embedded_guidance"] = embedded_guidance
|
||||
if input_image is not None:
|
||||
input_params["input_image"] = input_image.resize((width, height)).convert("RGB")
|
||||
input_params["enable_eligen_inpaint"] = True
|
||||
|
||||
local_prompt_list, canvas_list = (
|
||||
args[0 * config["max_num_painter_layers"]: 1 * config["max_num_painter_layers"]],
|
||||
args[1 * config["max_num_painter_layers"]: 2 * config["max_num_painter_layers"]],
|
||||
)
|
||||
local_prompts, masks = [], []
|
||||
for local_prompt, canvas in zip(local_prompt_list, canvas_list):
|
||||
if isinstance(local_prompt, str) and len(local_prompt) > 0:
|
||||
local_prompts.append(local_prompt)
|
||||
masks.append(Image.fromarray(canvas["layers"][0][:, :, -1]).convert("RGB"))
|
||||
entity_masks = None if len(masks) == 0 else masks
|
||||
entity_prompts = None if len(local_prompts) == 0 else local_prompts
|
||||
input_params.update({
|
||||
"eligen_entity_prompts": entity_prompts,
|
||||
"eligen_entity_masks": entity_masks,
|
||||
})
|
||||
torch.manual_seed(seed)
|
||||
# save_mask_prompts(masks, local_prompts, prompt, seed, random_mask_dir)
|
||||
image = pipe(**input_params)
|
||||
masks = [mask.resize(image.size) for mask in masks]
|
||||
image_with_mask = visualize_masks(image, masks, local_prompts)
|
||||
|
||||
real_output = gr.State(image)
|
||||
mask_out = gr.State(image_with_mask)
|
||||
|
||||
if return_with_mask:
|
||||
return image_with_mask, real_output, mask_out
|
||||
return image, real_output, mask_out
|
||||
|
||||
@gr.on(inputs=[input_image] + canvas_list, outputs=canvas_list, triggers=send_input_to_painter.click)
|
||||
def send_input_to_painter_background(input_image, *canvas_list):
|
||||
if input_image is None:
|
||||
return tuple(canvas_list)
|
||||
for canvas in canvas_list:
|
||||
h, w = canvas["background"].shape[:2]
|
||||
canvas["background"] = input_image.resize((w, h))
|
||||
return tuple(canvas_list)
|
||||
@gr.on(inputs=[real_output] + canvas_list, outputs=canvas_list, triggers=output_to_painter_button.click)
|
||||
def send_output_to_painter_background(real_output, *canvas_list):
|
||||
if real_output is None:
|
||||
return tuple(canvas_list)
|
||||
for canvas in canvas_list:
|
||||
h, w = canvas["background"].shape[:2]
|
||||
canvas["background"] = real_output.value.resize((w, h))
|
||||
return tuple(canvas_list)
|
||||
@gr.on(inputs=[return_with_mask, real_output, mask_out], outputs=[output_image], triggers=[return_with_mask.change], show_progress="hidden")
|
||||
def show_output(return_with_mask, real_output, mask_out):
|
||||
if return_with_mask:
|
||||
return mask_out.value
|
||||
else:
|
||||
return real_output.value
|
||||
@gr.on(inputs=[real_output], outputs=[input_image], triggers=output_to_input_button.click)
|
||||
def send_output_to_pipe_input(real_output):
|
||||
return real_output.value
|
||||
|
||||
with gr.Column():
|
||||
gr.Markdown("## Examples")
|
||||
for i in range(0, len(examples), 2):
|
||||
with gr.Row():
|
||||
if i < len(examples):
|
||||
example = examples[i]
|
||||
with gr.Column():
|
||||
example_image = gr.Image(
|
||||
value=f"data/examples/eligen/entity_control/example_{example['example_id']}/example_image.png",
|
||||
label=example["description"],
|
||||
interactive=False,
|
||||
width=1024,
|
||||
height=512
|
||||
)
|
||||
load_example_button = gr.Button(value=f"Load Example {example['example_id']}")
|
||||
load_example_button.click(
|
||||
load_example,
|
||||
inputs=[load_example_button],
|
||||
outputs=[num_inference_steps, prompt, negative_prompt, seed] + local_prompt_list + canvas_list
|
||||
)
|
||||
|
||||
if i + 1 < len(examples):
|
||||
example = examples[i + 1]
|
||||
with gr.Column():
|
||||
example_image = gr.Image(
|
||||
value=f"data/examples/eligen/entity_control/example_{example['example_id']}/example_image.png",
|
||||
label=example["description"],
|
||||
interactive=False,
|
||||
width=1024,
|
||||
height=512
|
||||
)
|
||||
load_example_button = gr.Button(value=f"Load Example {example['example_id']}")
|
||||
load_example_button.click(
|
||||
load_example,
|
||||
inputs=[load_example_button],
|
||||
outputs=[num_inference_steps, prompt, negative_prompt, seed] + local_prompt_list + canvas_list
|
||||
)
|
||||
app.config["show_progress"] = "hidden"
|
||||
app.launch()
|
||||
@@ -33,16 +33,28 @@ from ..models.hunyuan_dit_text_encoder import HunyuanDiTCLIPTextEncoder, Hunyuan
|
||||
from ..models.hunyuan_dit import HunyuanDiT
|
||||
|
||||
from ..models.flux_dit import FluxDiT
|
||||
from ..models.flux_text_encoder import FluxTextEncoder1, FluxTextEncoder2
|
||||
from ..models.flux_text_encoder import FluxTextEncoder2
|
||||
from ..models.flux_vae import FluxVAEEncoder, FluxVAEDecoder
|
||||
from ..models.flux_controlnet import FluxControlNet
|
||||
from ..models.flux_ipadapter import FluxIpAdapter
|
||||
|
||||
from ..models.cog_vae import CogVAEEncoder, CogVAEDecoder
|
||||
from ..models.cog_dit import CogDiT
|
||||
|
||||
from ..models.omnigen import OmniGenTransformer
|
||||
|
||||
from ..models.hunyuan_video_vae_decoder import HunyuanVideoVAEDecoder
|
||||
from ..models.hunyuan_video_vae_encoder import HunyuanVideoVAEEncoder
|
||||
|
||||
from ..extensions.RIFE import IFNet
|
||||
from ..extensions.ESRGAN import RRDBNet
|
||||
|
||||
from ..models.hunyuan_video_dit import HunyuanVideoDiT
|
||||
|
||||
from ..models.stepvideo_vae import StepVideoVAE
|
||||
from ..models.stepvideo_dit import StepVideoModel
|
||||
|
||||
from ..models.wanx_vae import WanXVideoVAE
|
||||
|
||||
model_loader_configs = [
|
||||
# These configs are provided for detecting model type automatically.
|
||||
@@ -72,13 +84,32 @@ model_loader_configs = [
|
||||
(None, "c96a285a6888465f87de22a984d049fb", ["sd_motion_modules"], [SDMotionModel], "civitai"),
|
||||
(None, "72907b92caed19bdb2adb89aa4063fe2", ["sdxl_motion_modules"], [SDXLMotionModel], "civitai"),
|
||||
(None, "31d2d9614fba60511fc9bf2604aa01f7", ["sdxl_controlnet"], [SDXLControlNetUnion], "diffusers"),
|
||||
(None, "94eefa3dac9cec93cb1ebaf1747d7b78", ["flux_text_encoder_1"], [FluxTextEncoder1], "diffusers"),
|
||||
(None, "94eefa3dac9cec93cb1ebaf1747d7b78", ["sd3_text_encoder_1"], [SD3TextEncoder1], "diffusers"),
|
||||
(None, "1aafa3cc91716fb6b300cc1cd51b85a3", ["flux_vae_encoder", "flux_vae_decoder"], [FluxVAEEncoder, FluxVAEDecoder], "diffusers"),
|
||||
(None, "21ea55f476dfc4fd135587abb59dfe5d", ["flux_vae_encoder", "flux_vae_decoder"], [FluxVAEEncoder, FluxVAEDecoder], "civitai"),
|
||||
(None, "a29710fea6dddb0314663ee823598e50", ["flux_dit"], [FluxDiT], "civitai"),
|
||||
(None, "57b02550baab820169365b3ee3afa2c9", ["flux_dit"], [FluxDiT], "civitai"),
|
||||
(None, "3394f306c4cbf04334b712bf5aaed95f", ["flux_dit"], [FluxDiT], "civitai"),
|
||||
(None, "023f054d918a84ccf503481fd1e3379e", ["flux_dit"], [FluxDiT], "civitai"),
|
||||
(None, "280189ee084bca10f70907bf6ce1649d", ["cog_vae_encoder", "cog_vae_decoder"], [CogVAEEncoder, CogVAEDecoder], "diffusers"),
|
||||
(None, "9b9313d104ac4df27991352fec013fd4", ["rife"], [IFNet], "civitai"),
|
||||
(None, "6b7116078c4170bfbeaedc8fe71f6649", ["esrgan"], [RRDBNet], "civitai"),
|
||||
(None, "61cbcbc7ac11f169c5949223efa960d1", ["omnigen_transformer"], [OmniGenTransformer], "diffusers"),
|
||||
(None, "78d18b9101345ff695f312e7e62538c0", ["flux_controlnet"], [FluxControlNet], "diffusers"),
|
||||
(None, "b001c89139b5f053c715fe772362dd2a", ["flux_controlnet"], [FluxControlNet], "diffusers"),
|
||||
(None, "52357cb26250681367488a8954c271e8", ["flux_controlnet"], [FluxControlNet], "diffusers"),
|
||||
(None, "0cfd1740758423a2a854d67c136d1e8c", ["flux_controlnet"], [FluxControlNet], "diffusers"),
|
||||
(None, "4daaa66cc656a8fe369908693dad0a35", ["flux_ipadapter"], [FluxIpAdapter], "diffusers"),
|
||||
(None, "51aed3d27d482fceb5e0739b03060e8f", ["sd3_dit", "sd3_vae_encoder", "sd3_vae_decoder"], [SD3DiT, SD3VAEEncoder, SD3VAEDecoder], "civitai"),
|
||||
(None, "98cc34ccc5b54ae0e56bdea8688dcd5a", ["sd3_text_encoder_2"], [SD3TextEncoder2], "civitai"),
|
||||
(None, "77ff18050dbc23f50382e45d51a779fe", ["sd3_dit", "sd3_vae_encoder", "sd3_vae_decoder"], [SD3DiT, SD3VAEEncoder, SD3VAEDecoder], "civitai"),
|
||||
(None, "5da81baee73198a7c19e6d2fe8b5148e", ["sd3_text_encoder_1"], [SD3TextEncoder1], "diffusers"),
|
||||
(None, "aeb82dce778a03dcb4d726cb03f3c43f", ["hunyuan_video_vae_decoder", "hunyuan_video_vae_encoder"], [HunyuanVideoVAEDecoder, HunyuanVideoVAEEncoder], "diffusers"),
|
||||
(None, "b9588f02e78f5ccafc9d7c0294e46308", ["hunyuan_video_dit"], [HunyuanVideoDiT], "civitai"),
|
||||
(None, "84ef4bd4757f60e906b54aa6a7815dc6", ["hunyuan_video_dit"], [HunyuanVideoDiT], "civitai"),
|
||||
(None, "68beaf8429b7c11aa8ca05b1bd0058bd", ["stepvideo_vae"], [StepVideoVAE], "civitai"),
|
||||
(None, "5c0216a2132b082c10cb7a0e0377e681", ["stepvideo_dit"], [StepVideoModel], "civitai"),
|
||||
(None, "1378ea763357eea97acdef78e65d6d96", ["wanxvideo_vae"], [WanXVideoVAE], "civitai")
|
||||
]
|
||||
huggingface_model_loader_configs = [
|
||||
# These configs are provided for detecting model type automatically.
|
||||
@@ -87,9 +118,12 @@ huggingface_model_loader_configs = [
|
||||
("MarianMTModel", "transformers.models.marian.modeling_marian", "translator", None),
|
||||
("BloomForCausalLM", "transformers.models.bloom.modeling_bloom", "beautiful_prompt", None),
|
||||
("Qwen2ForCausalLM", "transformers.models.qwen2.modeling_qwen2", "qwen_prompt", None),
|
||||
("LlamaForCausalLM", "transformers.models.llama.modeling_llama", "omost_prompt", None),
|
||||
# ("LlamaForCausalLM", "transformers.models.llama.modeling_llama", "omost_prompt", None),
|
||||
("T5EncoderModel", "diffsynth.models.flux_text_encoder", "flux_text_encoder_2", "FluxTextEncoder2"),
|
||||
("CogVideoXTransformer3DModel", "diffsynth.models.cog_dit", "cog_dit", "CogDiT"),
|
||||
("SiglipModel", "transformers.models.siglip.modeling_siglip", "siglip_vision_model", "SiglipVisionModel"),
|
||||
("LlamaForCausalLM", "diffsynth.models.hunyuan_video_text_encoder", "hunyuan_video_text_encoder_2", "HunyuanVideoLLMEncoder"),
|
||||
("Step1Model", "diffsynth.models.stepvideo_text_encoder", "stepvideo_text_encoder_2", "STEP1TextEncoder"),
|
||||
]
|
||||
patch_model_loader_configs = [
|
||||
# These configs are provided for detecting model type automatically.
|
||||
@@ -110,6 +144,117 @@ preset_models_on_huggingface = {
|
||||
"ExVideo-SVD-128f-v1": [
|
||||
("ECNU-CILab/ExVideo-SVD-128f-v1", "model.fp16.safetensors", "models/stable_video_diffusion"),
|
||||
],
|
||||
# Stable Diffusion
|
||||
"StableDiffusion_v15": [
|
||||
("benjamin-paine/stable-diffusion-v1-5", "v1-5-pruned-emaonly.safetensors", "models/stable_diffusion"),
|
||||
],
|
||||
"DreamShaper_8": [
|
||||
("Yntec/Dreamshaper8", "dreamshaper_8.safetensors", "models/stable_diffusion"),
|
||||
],
|
||||
# Textual Inversion
|
||||
"TextualInversion_VeryBadImageNegative_v1.3": [
|
||||
("gemasai/verybadimagenegative_v1.3", "verybadimagenegative_v1.3.pt", "models/textual_inversion"),
|
||||
],
|
||||
# Stable Diffusion XL
|
||||
"StableDiffusionXL_v1": [
|
||||
("stabilityai/stable-diffusion-xl-base-1.0", "sd_xl_base_1.0.safetensors", "models/stable_diffusion_xl"),
|
||||
],
|
||||
"BluePencilXL_v200": [
|
||||
("frankjoshua/bluePencilXL_v200", "bluePencilXL_v200.safetensors", "models/stable_diffusion_xl"),
|
||||
],
|
||||
"StableDiffusionXL_Turbo": [
|
||||
("stabilityai/sdxl-turbo", "sd_xl_turbo_1.0_fp16.safetensors", "models/stable_diffusion_xl_turbo"),
|
||||
],
|
||||
# Stable Diffusion 3
|
||||
"StableDiffusion3": [
|
||||
("stabilityai/stable-diffusion-3-medium", "sd3_medium_incl_clips_t5xxlfp16.safetensors", "models/stable_diffusion_3"),
|
||||
],
|
||||
"StableDiffusion3_without_T5": [
|
||||
("stabilityai/stable-diffusion-3-medium", "sd3_medium_incl_clips.safetensors", "models/stable_diffusion_3"),
|
||||
],
|
||||
# ControlNet
|
||||
"ControlNet_v11f1p_sd15_depth": [
|
||||
("lllyasviel/ControlNet-v1-1", "control_v11f1p_sd15_depth.pth", "models/ControlNet"),
|
||||
("lllyasviel/Annotators", "dpt_hybrid-midas-501f0c75.pt", "models/Annotators")
|
||||
],
|
||||
"ControlNet_v11p_sd15_softedge": [
|
||||
("lllyasviel/ControlNet-v1-1", "control_v11p_sd15_softedge.pth", "models/ControlNet"),
|
||||
("lllyasviel/Annotators", "ControlNetHED.pth", "models/Annotators")
|
||||
],
|
||||
"ControlNet_v11f1e_sd15_tile": [
|
||||
("lllyasviel/ControlNet-v1-1", "control_v11f1e_sd15_tile.pth", "models/ControlNet")
|
||||
],
|
||||
"ControlNet_v11p_sd15_lineart": [
|
||||
("lllyasviel/ControlNet-v1-1", "control_v11p_sd15_lineart.pth", "models/ControlNet"),
|
||||
("lllyasviel/Annotators", "sk_model.pth", "models/Annotators"),
|
||||
("lllyasviel/Annotators", "sk_model2.pth", "models/Annotators")
|
||||
],
|
||||
"ControlNet_union_sdxl_promax": [
|
||||
("xinsir/controlnet-union-sdxl-1.0", "diffusion_pytorch_model_promax.safetensors", "models/ControlNet/controlnet_union"),
|
||||
("lllyasviel/Annotators", "dpt_hybrid-midas-501f0c75.pt", "models/Annotators")
|
||||
],
|
||||
# AnimateDiff
|
||||
"AnimateDiff_v2": [
|
||||
("guoyww/animatediff", "mm_sd_v15_v2.ckpt", "models/AnimateDiff"),
|
||||
],
|
||||
"AnimateDiff_xl_beta": [
|
||||
("guoyww/animatediff", "mm_sdxl_v10_beta.ckpt", "models/AnimateDiff"),
|
||||
],
|
||||
|
||||
# Qwen Prompt
|
||||
"QwenPrompt": [
|
||||
("Qwen/Qwen2-1.5B-Instruct", "config.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
|
||||
("Qwen/Qwen2-1.5B-Instruct", "generation_config.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
|
||||
("Qwen/Qwen2-1.5B-Instruct", "model.safetensors", "models/QwenPrompt/qwen2-1.5b-instruct"),
|
||||
("Qwen/Qwen2-1.5B-Instruct", "special_tokens_map.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
|
||||
("Qwen/Qwen2-1.5B-Instruct", "tokenizer.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
|
||||
("Qwen/Qwen2-1.5B-Instruct", "tokenizer_config.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
|
||||
("Qwen/Qwen2-1.5B-Instruct", "merges.txt", "models/QwenPrompt/qwen2-1.5b-instruct"),
|
||||
("Qwen/Qwen2-1.5B-Instruct", "vocab.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
|
||||
],
|
||||
# Beautiful Prompt
|
||||
"BeautifulPrompt": [
|
||||
("alibaba-pai/pai-bloom-1b1-text2prompt-sd", "config.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
|
||||
("alibaba-pai/pai-bloom-1b1-text2prompt-sd", "generation_config.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
|
||||
("alibaba-pai/pai-bloom-1b1-text2prompt-sd", "model.safetensors", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
|
||||
("alibaba-pai/pai-bloom-1b1-text2prompt-sd", "special_tokens_map.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
|
||||
("alibaba-pai/pai-bloom-1b1-text2prompt-sd", "tokenizer.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
|
||||
("alibaba-pai/pai-bloom-1b1-text2prompt-sd", "tokenizer_config.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
|
||||
],
|
||||
# Omost prompt
|
||||
"OmostPrompt":[
|
||||
("lllyasviel/omost-llama-3-8b-4bits", "model-00001-of-00002.safetensors", "models/OmostPrompt/omost-llama-3-8b-4bits"),
|
||||
("lllyasviel/omost-llama-3-8b-4bits", "model-00002-of-00002.safetensors", "models/OmostPrompt/omost-llama-3-8b-4bits"),
|
||||
("lllyasviel/omost-llama-3-8b-4bits", "tokenizer.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
|
||||
("lllyasviel/omost-llama-3-8b-4bits", "tokenizer_config.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
|
||||
("lllyasviel/omost-llama-3-8b-4bits", "config.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
|
||||
("lllyasviel/omost-llama-3-8b-4bits", "generation_config.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
|
||||
("lllyasviel/omost-llama-3-8b-4bits", "model.safetensors.index.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
|
||||
("lllyasviel/omost-llama-3-8b-4bits", "special_tokens_map.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
|
||||
],
|
||||
# Translator
|
||||
"opus-mt-zh-en": [
|
||||
("Helsinki-NLP/opus-mt-zh-en", "config.json", "models/translator/opus-mt-zh-en"),
|
||||
("Helsinki-NLP/opus-mt-zh-en", "generation_config.json", "models/translator/opus-mt-zh-en"),
|
||||
("Helsinki-NLP/opus-mt-zh-en", "metadata.json", "models/translator/opus-mt-zh-en"),
|
||||
("Helsinki-NLP/opus-mt-zh-en", "pytorch_model.bin", "models/translator/opus-mt-zh-en"),
|
||||
("Helsinki-NLP/opus-mt-zh-en", "source.spm", "models/translator/opus-mt-zh-en"),
|
||||
("Helsinki-NLP/opus-mt-zh-en", "target.spm", "models/translator/opus-mt-zh-en"),
|
||||
("Helsinki-NLP/opus-mt-zh-en", "tokenizer_config.json", "models/translator/opus-mt-zh-en"),
|
||||
("Helsinki-NLP/opus-mt-zh-en", "vocab.json", "models/translator/opus-mt-zh-en"),
|
||||
],
|
||||
# IP-Adapter
|
||||
"IP-Adapter-SD": [
|
||||
("h94/IP-Adapter", "models/image_encoder/model.safetensors", "models/IpAdapter/stable_diffusion/image_encoder"),
|
||||
("h94/IP-Adapter", "models/ip-adapter_sd15.bin", "models/IpAdapter/stable_diffusion"),
|
||||
],
|
||||
"IP-Adapter-SDXL": [
|
||||
("h94/IP-Adapter", "sdxl_models/image_encoder/model.safetensors", "models/IpAdapter/stable_diffusion_xl/image_encoder"),
|
||||
("h94/IP-Adapter", "sdxl_models/ip-adapter_sdxl.bin", "models/IpAdapter/stable_diffusion_xl"),
|
||||
],
|
||||
"SDXL-vae-fp16-fix": [
|
||||
("madebyollin/sdxl-vae-fp16-fix", "diffusion_pytorch_model.safetensors", "models/sdxl-vae-fp16-fix")
|
||||
],
|
||||
# Kolors
|
||||
"Kolors": [
|
||||
("Kwai-Kolors/Kolors", "text_encoder/config.json", "models/kolors/Kolors/text_encoder"),
|
||||
@@ -134,6 +279,40 @@ preset_models_on_huggingface = {
|
||||
("black-forest-labs/FLUX.1-dev", "ae.safetensors", "models/FLUX/FLUX.1-dev"),
|
||||
("black-forest-labs/FLUX.1-dev", "flux1-dev.safetensors", "models/FLUX/FLUX.1-dev"),
|
||||
],
|
||||
"InstantX/FLUX.1-dev-IP-Adapter": {
|
||||
"file_list": [
|
||||
("InstantX/FLUX.1-dev-IP-Adapter", "ip-adapter.bin", "models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter"),
|
||||
("google/siglip-so400m-patch14-384", "model.safetensors", "models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/image_encoder"),
|
||||
("google/siglip-so400m-patch14-384", "config.json", "models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/image_encoder"),
|
||||
],
|
||||
"load_path": [
|
||||
"models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/ip-adapter.bin",
|
||||
"models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/image_encoder",
|
||||
],
|
||||
},
|
||||
# RIFE
|
||||
"RIFE": [
|
||||
("AlexWortega/RIFE", "flownet.pkl", "models/RIFE"),
|
||||
],
|
||||
# CogVideo
|
||||
"CogVideoX-5B": [
|
||||
("THUDM/CogVideoX-5b", "text_encoder/config.json", "models/CogVideo/CogVideoX-5b/text_encoder"),
|
||||
("THUDM/CogVideoX-5b", "text_encoder/model.safetensors.index.json", "models/CogVideo/CogVideoX-5b/text_encoder"),
|
||||
("THUDM/CogVideoX-5b", "text_encoder/model-00001-of-00002.safetensors", "models/CogVideo/CogVideoX-5b/text_encoder"),
|
||||
("THUDM/CogVideoX-5b", "text_encoder/model-00002-of-00002.safetensors", "models/CogVideo/CogVideoX-5b/text_encoder"),
|
||||
("THUDM/CogVideoX-5b", "transformer/config.json", "models/CogVideo/CogVideoX-5b/transformer"),
|
||||
("THUDM/CogVideoX-5b", "transformer/diffusion_pytorch_model.safetensors.index.json", "models/CogVideo/CogVideoX-5b/transformer"),
|
||||
("THUDM/CogVideoX-5b", "transformer/diffusion_pytorch_model-00001-of-00002.safetensors", "models/CogVideo/CogVideoX-5b/transformer"),
|
||||
("THUDM/CogVideoX-5b", "transformer/diffusion_pytorch_model-00002-of-00002.safetensors", "models/CogVideo/CogVideoX-5b/transformer"),
|
||||
("THUDM/CogVideoX-5b", "vae/diffusion_pytorch_model.safetensors", "models/CogVideo/CogVideoX-5b/vae"),
|
||||
],
|
||||
# Stable Diffusion 3.5
|
||||
"StableDiffusion3.5-large": [
|
||||
("stabilityai/stable-diffusion-3.5-large", "sd3.5_large.safetensors", "models/stable_diffusion_3"),
|
||||
("stabilityai/stable-diffusion-3.5-large", "text_encoders/clip_l.safetensors", "models/stable_diffusion_3/text_encoders"),
|
||||
("stabilityai/stable-diffusion-3.5-large", "text_encoders/clip_g.safetensors", "models/stable_diffusion_3/text_encoders"),
|
||||
("stabilityai/stable-diffusion-3.5-large", "text_encoders/t5xxl_fp16.safetensors", "models/stable_diffusion_3/text_encoders"),
|
||||
],
|
||||
}
|
||||
preset_models_on_modelscope = {
|
||||
# Hunyuan DiT
|
||||
@@ -151,6 +330,9 @@ preset_models_on_modelscope = {
|
||||
"ExVideo-SVD-128f-v1": [
|
||||
("ECNU-CILab/ExVideo-SVD-128f-v1", "model.fp16.safetensors", "models/stable_video_diffusion"),
|
||||
],
|
||||
"ExVideo-CogVideoX-LoRA-129f-v1": [
|
||||
("ECNU-CILab/ExVideo-CogVideoX-LoRA-129f-v1", "ExVideo-CogVideoX-LoRA-129f-v1.safetensors", "models/lora"),
|
||||
],
|
||||
# Stable Diffusion
|
||||
"StableDiffusion_v15": [
|
||||
("AI-ModelScope/stable-diffusion-v1-5", "v1-5-pruned-emaonly.safetensors", "models/stable_diffusion"),
|
||||
@@ -209,6 +391,24 @@ preset_models_on_modelscope = {
|
||||
("AI-ModelScope/controlnet-union-sdxl-1.0", "diffusion_pytorch_model_promax.safetensors", "models/ControlNet/controlnet_union"),
|
||||
("sd_lora/Annotators", "dpt_hybrid-midas-501f0c75.pt", "models/Annotators")
|
||||
],
|
||||
"Annotators:Depth": [
|
||||
("sd_lora/Annotators", "dpt_hybrid-midas-501f0c75.pt", "models/Annotators"),
|
||||
],
|
||||
"Annotators:Softedge": [
|
||||
("sd_lora/Annotators", "ControlNetHED.pth", "models/Annotators"),
|
||||
],
|
||||
"Annotators:Lineart": [
|
||||
("sd_lora/Annotators", "sk_model.pth", "models/Annotators"),
|
||||
("sd_lora/Annotators", "sk_model2.pth", "models/Annotators"),
|
||||
],
|
||||
"Annotators:Normal": [
|
||||
("sd_lora/Annotators", "scannet.pt", "models/Annotators"),
|
||||
],
|
||||
"Annotators:Openpose": [
|
||||
("sd_lora/Annotators", "body_pose_model.pth", "models/Annotators"),
|
||||
("sd_lora/Annotators", "facenet.pth", "models/Annotators"),
|
||||
("sd_lora/Annotators", "hand_pose_model.pth", "models/Annotators"),
|
||||
],
|
||||
# AnimateDiff
|
||||
"AnimateDiff_v2": [
|
||||
("Shanghai_AI_Laboratory/animatediff", "mm_sd_v15_v2.ckpt", "models/AnimateDiff"),
|
||||
@@ -221,48 +421,67 @@ preset_models_on_modelscope = {
|
||||
("Damo_XR_Lab/cv_rife_video-frame-interpolation", "flownet.pkl", "models/RIFE"),
|
||||
],
|
||||
# Qwen Prompt
|
||||
"QwenPrompt": [
|
||||
("qwen/Qwen2-1.5B-Instruct", "config.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
|
||||
("qwen/Qwen2-1.5B-Instruct", "generation_config.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
|
||||
("qwen/Qwen2-1.5B-Instruct", "model.safetensors", "models/QwenPrompt/qwen2-1.5b-instruct"),
|
||||
("qwen/Qwen2-1.5B-Instruct", "special_tokens_map.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
|
||||
("qwen/Qwen2-1.5B-Instruct", "tokenizer.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
|
||||
("qwen/Qwen2-1.5B-Instruct", "tokenizer_config.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
|
||||
("qwen/Qwen2-1.5B-Instruct", "merges.txt", "models/QwenPrompt/qwen2-1.5b-instruct"),
|
||||
("qwen/Qwen2-1.5B-Instruct", "vocab.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
|
||||
],
|
||||
"QwenPrompt": {
|
||||
"file_list": [
|
||||
("qwen/Qwen2-1.5B-Instruct", "config.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
|
||||
("qwen/Qwen2-1.5B-Instruct", "generation_config.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
|
||||
("qwen/Qwen2-1.5B-Instruct", "model.safetensors", "models/QwenPrompt/qwen2-1.5b-instruct"),
|
||||
("qwen/Qwen2-1.5B-Instruct", "special_tokens_map.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
|
||||
("qwen/Qwen2-1.5B-Instruct", "tokenizer.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
|
||||
("qwen/Qwen2-1.5B-Instruct", "tokenizer_config.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
|
||||
("qwen/Qwen2-1.5B-Instruct", "merges.txt", "models/QwenPrompt/qwen2-1.5b-instruct"),
|
||||
("qwen/Qwen2-1.5B-Instruct", "vocab.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
|
||||
],
|
||||
"load_path": [
|
||||
"models/QwenPrompt/qwen2-1.5b-instruct",
|
||||
],
|
||||
},
|
||||
# Beautiful Prompt
|
||||
"BeautifulPrompt": [
|
||||
("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "config.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
|
||||
("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "generation_config.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
|
||||
("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "model.safetensors", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
|
||||
("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "special_tokens_map.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
|
||||
("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "tokenizer.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
|
||||
("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "tokenizer_config.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
|
||||
],
|
||||
"BeautifulPrompt": {
|
||||
"file_list": [
|
||||
("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "config.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
|
||||
("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "generation_config.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
|
||||
("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "model.safetensors", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
|
||||
("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "special_tokens_map.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
|
||||
("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "tokenizer.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
|
||||
("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "tokenizer_config.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
|
||||
],
|
||||
"load_path": [
|
||||
"models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd",
|
||||
],
|
||||
},
|
||||
# Omost prompt
|
||||
"OmostPrompt":[
|
||||
("Omost/omost-llama-3-8b-4bits", "model-00001-of-00002.safetensors", "models/OmostPrompt/omost-llama-3-8b-4bits"),
|
||||
("Omost/omost-llama-3-8b-4bits", "model-00002-of-00002.safetensors", "models/OmostPrompt/omost-llama-3-8b-4bits"),
|
||||
("Omost/omost-llama-3-8b-4bits", "tokenizer.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
|
||||
("Omost/omost-llama-3-8b-4bits", "tokenizer_config.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
|
||||
("Omost/omost-llama-3-8b-4bits", "config.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
|
||||
("Omost/omost-llama-3-8b-4bits", "generation_config.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
|
||||
("Omost/omost-llama-3-8b-4bits", "model.safetensors.index.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
|
||||
("Omost/omost-llama-3-8b-4bits", "special_tokens_map.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
|
||||
],
|
||||
|
||||
"OmostPrompt": {
|
||||
"file_list": [
|
||||
("Omost/omost-llama-3-8b-4bits", "model-00001-of-00002.safetensors", "models/OmostPrompt/omost-llama-3-8b-4bits"),
|
||||
("Omost/omost-llama-3-8b-4bits", "model-00002-of-00002.safetensors", "models/OmostPrompt/omost-llama-3-8b-4bits"),
|
||||
("Omost/omost-llama-3-8b-4bits", "tokenizer.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
|
||||
("Omost/omost-llama-3-8b-4bits", "tokenizer_config.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
|
||||
("Omost/omost-llama-3-8b-4bits", "config.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
|
||||
("Omost/omost-llama-3-8b-4bits", "generation_config.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
|
||||
("Omost/omost-llama-3-8b-4bits", "model.safetensors.index.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
|
||||
("Omost/omost-llama-3-8b-4bits", "special_tokens_map.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
|
||||
],
|
||||
"load_path": [
|
||||
"models/OmostPrompt/omost-llama-3-8b-4bits",
|
||||
],
|
||||
},
|
||||
# Translator
|
||||
"opus-mt-zh-en": [
|
||||
("moxying/opus-mt-zh-en", "config.json", "models/translator/opus-mt-zh-en"),
|
||||
("moxying/opus-mt-zh-en", "generation_config.json", "models/translator/opus-mt-zh-en"),
|
||||
("moxying/opus-mt-zh-en", "metadata.json", "models/translator/opus-mt-zh-en"),
|
||||
("moxying/opus-mt-zh-en", "pytorch_model.bin", "models/translator/opus-mt-zh-en"),
|
||||
("moxying/opus-mt-zh-en", "source.spm", "models/translator/opus-mt-zh-en"),
|
||||
("moxying/opus-mt-zh-en", "target.spm", "models/translator/opus-mt-zh-en"),
|
||||
("moxying/opus-mt-zh-en", "tokenizer_config.json", "models/translator/opus-mt-zh-en"),
|
||||
("moxying/opus-mt-zh-en", "vocab.json", "models/translator/opus-mt-zh-en"),
|
||||
],
|
||||
"opus-mt-zh-en": {
|
||||
"file_list": [
|
||||
("moxying/opus-mt-zh-en", "config.json", "models/translator/opus-mt-zh-en"),
|
||||
("moxying/opus-mt-zh-en", "generation_config.json", "models/translator/opus-mt-zh-en"),
|
||||
("moxying/opus-mt-zh-en", "metadata.json", "models/translator/opus-mt-zh-en"),
|
||||
("moxying/opus-mt-zh-en", "pytorch_model.bin", "models/translator/opus-mt-zh-en"),
|
||||
("moxying/opus-mt-zh-en", "source.spm", "models/translator/opus-mt-zh-en"),
|
||||
("moxying/opus-mt-zh-en", "target.spm", "models/translator/opus-mt-zh-en"),
|
||||
("moxying/opus-mt-zh-en", "tokenizer_config.json", "models/translator/opus-mt-zh-en"),
|
||||
("moxying/opus-mt-zh-en", "vocab.json", "models/translator/opus-mt-zh-en"),
|
||||
],
|
||||
"load_path": [
|
||||
"models/translator/opus-mt-zh-en",
|
||||
],
|
||||
},
|
||||
# IP-Adapter
|
||||
"IP-Adapter-SD": [
|
||||
("AI-ModelScope/IP-Adapter", "models/image_encoder/model.safetensors", "models/IpAdapter/stable_diffusion/image_encoder"),
|
||||
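
The hunks above and below switch several preset entries (QwenPrompt, BeautifulPrompt, OmostPrompt, opus-mt-zh-en, Kolors, FLUX.1-dev) from a flat list of download tuples to a dict with `file_list` and `load_path` keys. The code that consumes these presets is not part of this diff, so the snippet below is only an illustration of how a consumer might accept both shapes; `normalize_preset` is a hypothetical helper, and the sample entry is copied from the QwenPrompt config above.

```python
from typing import Optional, Tuple, Union

# Hypothetical helper: normalize an old-style (list) or new-style (dict) preset entry.
# Each file tuple is (repo_id, origin_file_path, local_dir), as in the configs above.
def normalize_preset(entry: Union[list, dict]) -> Tuple[list, Optional[list]]:
    if isinstance(entry, dict):
        return entry["file_list"], entry.get("load_path")
    return entry, None  # old-style entries carry no explicit load paths

file_list, load_path = normalize_preset({
    "file_list": [
        ("qwen/Qwen2-1.5B-Instruct", "config.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
    ],
    "load_path": ["models/QwenPrompt/qwen2-1.5b-instruct"],
})
print(file_list, load_path)
```
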
@@ -273,32 +492,99 @@ preset_models_on_modelscope = {
|
||||
("AI-ModelScope/IP-Adapter", "sdxl_models/ip-adapter_sdxl.bin", "models/IpAdapter/stable_diffusion_xl"),
|
||||
],
|
||||
# Kolors
|
||||
"Kolors": [
|
||||
("Kwai-Kolors/Kolors", "text_encoder/config.json", "models/kolors/Kolors/text_encoder"),
|
||||
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model.bin.index.json", "models/kolors/Kolors/text_encoder"),
|
||||
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00001-of-00007.bin", "models/kolors/Kolors/text_encoder"),
|
||||
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00002-of-00007.bin", "models/kolors/Kolors/text_encoder"),
|
||||
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00003-of-00007.bin", "models/kolors/Kolors/text_encoder"),
|
||||
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00004-of-00007.bin", "models/kolors/Kolors/text_encoder"),
|
||||
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00005-of-00007.bin", "models/kolors/Kolors/text_encoder"),
|
||||
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00006-of-00007.bin", "models/kolors/Kolors/text_encoder"),
|
||||
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00007-of-00007.bin", "models/kolors/Kolors/text_encoder"),
|
||||
("Kwai-Kolors/Kolors", "unet/diffusion_pytorch_model.safetensors", "models/kolors/Kolors/unet"),
|
||||
("Kwai-Kolors/Kolors", "vae/diffusion_pytorch_model.safetensors", "models/kolors/Kolors/vae"),
|
||||
],
|
||||
"Kolors": {
|
||||
"file_list": [
|
||||
("Kwai-Kolors/Kolors", "text_encoder/config.json", "models/kolors/Kolors/text_encoder"),
|
||||
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model.bin.index.json", "models/kolors/Kolors/text_encoder"),
|
||||
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00001-of-00007.bin", "models/kolors/Kolors/text_encoder"),
|
||||
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00002-of-00007.bin", "models/kolors/Kolors/text_encoder"),
|
||||
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00003-of-00007.bin", "models/kolors/Kolors/text_encoder"),
|
||||
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00004-of-00007.bin", "models/kolors/Kolors/text_encoder"),
|
||||
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00005-of-00007.bin", "models/kolors/Kolors/text_encoder"),
|
||||
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00006-of-00007.bin", "models/kolors/Kolors/text_encoder"),
|
||||
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00007-of-00007.bin", "models/kolors/Kolors/text_encoder"),
|
||||
("Kwai-Kolors/Kolors", "unet/diffusion_pytorch_model.safetensors", "models/kolors/Kolors/unet"),
|
||||
("Kwai-Kolors/Kolors", "vae/diffusion_pytorch_model.safetensors", "models/kolors/Kolors/vae"),
|
||||
],
|
||||
"load_path": [
|
||||
"models/kolors/Kolors/text_encoder",
|
||||
"models/kolors/Kolors/unet/diffusion_pytorch_model.safetensors",
|
||||
"models/kolors/Kolors/vae/diffusion_pytorch_model.safetensors",
|
||||
],
|
||||
},
|
||||
"SDXL-vae-fp16-fix": [
|
||||
("AI-ModelScope/sdxl-vae-fp16-fix", "diffusion_pytorch_model.safetensors", "models/sdxl-vae-fp16-fix")
|
||||
],
|
||||
# FLUX
|
||||
"FLUX.1-dev": [
|
||||
("AI-ModelScope/FLUX.1-dev", "text_encoder/model.safetensors", "models/FLUX/FLUX.1-dev/text_encoder"),
|
||||
("AI-ModelScope/FLUX.1-dev", "text_encoder_2/config.json", "models/FLUX/FLUX.1-dev/text_encoder_2"),
|
||||
("AI-ModelScope/FLUX.1-dev", "text_encoder_2/model-00001-of-00002.safetensors", "models/FLUX/FLUX.1-dev/text_encoder_2"),
|
||||
("AI-ModelScope/FLUX.1-dev", "text_encoder_2/model-00002-of-00002.safetensors", "models/FLUX/FLUX.1-dev/text_encoder_2"),
|
||||
("AI-ModelScope/FLUX.1-dev", "text_encoder_2/model.safetensors.index.json", "models/FLUX/FLUX.1-dev/text_encoder_2"),
|
||||
("AI-ModelScope/FLUX.1-dev", "ae.safetensors", "models/FLUX/FLUX.1-dev"),
|
||||
("AI-ModelScope/FLUX.1-dev", "flux1-dev.safetensors", "models/FLUX/FLUX.1-dev"),
|
||||
"FLUX.1-dev": {
|
||||
"file_list": [
|
||||
("AI-ModelScope/FLUX.1-dev", "text_encoder/model.safetensors", "models/FLUX/FLUX.1-dev/text_encoder"),
|
||||
("AI-ModelScope/FLUX.1-dev", "text_encoder_2/config.json", "models/FLUX/FLUX.1-dev/text_encoder_2"),
|
||||
("AI-ModelScope/FLUX.1-dev", "text_encoder_2/model-00001-of-00002.safetensors", "models/FLUX/FLUX.1-dev/text_encoder_2"),
|
||||
("AI-ModelScope/FLUX.1-dev", "text_encoder_2/model-00002-of-00002.safetensors", "models/FLUX/FLUX.1-dev/text_encoder_2"),
|
||||
("AI-ModelScope/FLUX.1-dev", "text_encoder_2/model.safetensors.index.json", "models/FLUX/FLUX.1-dev/text_encoder_2"),
|
||||
("AI-ModelScope/FLUX.1-dev", "ae.safetensors", "models/FLUX/FLUX.1-dev"),
|
||||
("AI-ModelScope/FLUX.1-dev", "flux1-dev.safetensors", "models/FLUX/FLUX.1-dev"),
|
||||
],
|
||||
"load_path": [
|
||||
"models/FLUX/FLUX.1-dev/text_encoder/model.safetensors",
|
||||
"models/FLUX/FLUX.1-dev/text_encoder_2",
|
||||
"models/FLUX/FLUX.1-dev/ae.safetensors",
|
||||
"models/FLUX/FLUX.1-dev/flux1-dev.safetensors"
|
||||
],
|
||||
},
|
||||
"FLUX.1-schnell": {
|
||||
"file_list": [
|
||||
("AI-ModelScope/FLUX.1-dev", "text_encoder/model.safetensors", "models/FLUX/FLUX.1-dev/text_encoder"),
|
||||
("AI-ModelScope/FLUX.1-dev", "text_encoder_2/config.json", "models/FLUX/FLUX.1-dev/text_encoder_2"),
|
||||
("AI-ModelScope/FLUX.1-dev", "text_encoder_2/model-00001-of-00002.safetensors", "models/FLUX/FLUX.1-dev/text_encoder_2"),
|
||||
("AI-ModelScope/FLUX.1-dev", "text_encoder_2/model-00002-of-00002.safetensors", "models/FLUX/FLUX.1-dev/text_encoder_2"),
|
||||
("AI-ModelScope/FLUX.1-dev", "text_encoder_2/model.safetensors.index.json", "models/FLUX/FLUX.1-dev/text_encoder_2"),
|
||||
("AI-ModelScope/FLUX.1-dev", "ae.safetensors", "models/FLUX/FLUX.1-dev"),
|
||||
("AI-ModelScope/FLUX.1-schnell", "flux1-schnell.safetensors", "models/FLUX/FLUX.1-schnell"),
|
||||
],
|
||||
"load_path": [
|
||||
"models/FLUX/FLUX.1-dev/text_encoder/model.safetensors",
|
||||
"models/FLUX/FLUX.1-dev/text_encoder_2",
|
||||
"models/FLUX/FLUX.1-dev/ae.safetensors",
|
||||
"models/FLUX/FLUX.1-schnell/flux1-schnell.safetensors"
|
||||
],
|
||||
},
|
||||
"InstantX/FLUX.1-dev-Controlnet-Union-alpha": [
|
||||
("InstantX/FLUX.1-dev-Controlnet-Union-alpha", "diffusion_pytorch_model.safetensors", "models/ControlNet/InstantX/FLUX.1-dev-Controlnet-Union-alpha"),
|
||||
],
|
||||
"jasperai/Flux.1-dev-Controlnet-Depth": [
|
||||
("jasperai/Flux.1-dev-Controlnet-Depth", "diffusion_pytorch_model.safetensors", "models/ControlNet/jasperai/Flux.1-dev-Controlnet-Depth"),
|
||||
],
|
||||
"jasperai/Flux.1-dev-Controlnet-Surface-Normals": [
|
||||
("jasperai/Flux.1-dev-Controlnet-Surface-Normals", "diffusion_pytorch_model.safetensors", "models/ControlNet/jasperai/Flux.1-dev-Controlnet-Surface-Normals"),
|
||||
],
|
||||
"jasperai/Flux.1-dev-Controlnet-Upscaler": [
|
||||
("jasperai/Flux.1-dev-Controlnet-Upscaler", "diffusion_pytorch_model.safetensors", "models/ControlNet/jasperai/Flux.1-dev-Controlnet-Upscaler"),
|
||||
],
|
||||
"alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Alpha": [
|
||||
("alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Alpha", "diffusion_pytorch_model.safetensors", "models/ControlNet/alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Alpha"),
|
||||
],
|
||||
"alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta": [
|
||||
("alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta", "diffusion_pytorch_model.safetensors", "models/ControlNet/alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta"),
|
||||
],
|
||||
"Shakker-Labs/FLUX.1-dev-ControlNet-Depth": [
|
||||
("Shakker-Labs/FLUX.1-dev-ControlNet-Depth", "diffusion_pytorch_model.safetensors", "models/ControlNet/Shakker-Labs/FLUX.1-dev-ControlNet-Depth"),
|
||||
],
|
||||
"Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro": [
|
||||
("Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro", "diffusion_pytorch_model.safetensors", "models/ControlNet/Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro"),
|
||||
],
|
||||
"InstantX/FLUX.1-dev-IP-Adapter": {
|
||||
"file_list": [
|
||||
("InstantX/FLUX.1-dev-IP-Adapter", "ip-adapter.bin", "models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter"),
|
||||
("AI-ModelScope/siglip-so400m-patch14-384", "model.safetensors", "models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/image_encoder"),
|
||||
("AI-ModelScope/siglip-so400m-patch14-384", "config.json", "models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/image_encoder"),
|
||||
],
|
||||
"load_path": [
|
||||
"models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/ip-adapter.bin",
|
||||
"models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/image_encoder",
|
||||
],
|
||||
},
|
||||
# ESRGAN
|
||||
"ESRGAN_x4": [
|
||||
("AI-ModelScope/Real-ESRGAN", "RealESRGAN_x4.pth", "models/ESRGAN"),
|
||||
@@ -307,23 +593,103 @@ preset_models_on_modelscope = {
|
||||
"RIFE": [
|
||||
("AI-ModelScope/RIFE", "flownet.pkl", "models/RIFE"),
|
||||
],
|
||||
# Omnigen
|
||||
"OmniGen-v1": {
|
||||
"file_list": [
|
||||
("BAAI/OmniGen-v1", "vae/diffusion_pytorch_model.safetensors", "models/OmniGen/OmniGen-v1/vae"),
|
||||
("BAAI/OmniGen-v1", "model.safetensors", "models/OmniGen/OmniGen-v1"),
|
||||
("BAAI/OmniGen-v1", "config.json", "models/OmniGen/OmniGen-v1"),
|
||||
("BAAI/OmniGen-v1", "special_tokens_map.json", "models/OmniGen/OmniGen-v1"),
|
||||
("BAAI/OmniGen-v1", "tokenizer_config.json", "models/OmniGen/OmniGen-v1"),
|
||||
("BAAI/OmniGen-v1", "tokenizer.json", "models/OmniGen/OmniGen-v1"),
|
||||
],
|
||||
"load_path": [
|
||||
"models/OmniGen/OmniGen-v1/vae/diffusion_pytorch_model.safetensors",
|
||||
"models/OmniGen/OmniGen-v1/model.safetensors",
|
||||
]
|
||||
},
|
||||
# CogVideo
|
||||
"CogVideoX-5B": [
|
||||
("ZhipuAI/CogVideoX-5b", "text_encoder/config.json", "models/CogVideo/CogVideoX-5b/text_encoder"),
|
||||
("ZhipuAI/CogVideoX-5b", "text_encoder/model.safetensors.index.json", "models/CogVideo/CogVideoX-5b/text_encoder"),
|
||||
("ZhipuAI/CogVideoX-5b", "text_encoder/model-00001-of-00002.safetensors", "models/CogVideo/CogVideoX-5b/text_encoder"),
|
||||
("ZhipuAI/CogVideoX-5b", "text_encoder/model-00002-of-00002.safetensors", "models/CogVideo/CogVideoX-5b/text_encoder"),
|
||||
("ZhipuAI/CogVideoX-5b", "transformer/config.json", "models/CogVideo/CogVideoX-5b/transformer"),
|
||||
("ZhipuAI/CogVideoX-5b", "transformer/diffusion_pytorch_model.safetensors.index.json", "models/CogVideo/CogVideoX-5b/transformer"),
|
||||
("ZhipuAI/CogVideoX-5b", "transformer/diffusion_pytorch_model-00001-of-00002.safetensors", "models/CogVideo/CogVideoX-5b/transformer"),
|
||||
("ZhipuAI/CogVideoX-5b", "transformer/diffusion_pytorch_model-00002-of-00002.safetensors", "models/CogVideo/CogVideoX-5b/transformer"),
|
||||
("ZhipuAI/CogVideoX-5b", "vae/diffusion_pytorch_model.safetensors", "models/CogVideo/CogVideoX-5b/vae"),
|
||||
"CogVideoX-5B": {
|
||||
"file_list": [
|
||||
("ZhipuAI/CogVideoX-5b", "text_encoder/config.json", "models/CogVideo/CogVideoX-5b/text_encoder"),
|
||||
("ZhipuAI/CogVideoX-5b", "text_encoder/model.safetensors.index.json", "models/CogVideo/CogVideoX-5b/text_encoder"),
|
||||
("ZhipuAI/CogVideoX-5b", "text_encoder/model-00001-of-00002.safetensors", "models/CogVideo/CogVideoX-5b/text_encoder"),
|
||||
("ZhipuAI/CogVideoX-5b", "text_encoder/model-00002-of-00002.safetensors", "models/CogVideo/CogVideoX-5b/text_encoder"),
|
||||
("ZhipuAI/CogVideoX-5b", "transformer/config.json", "models/CogVideo/CogVideoX-5b/transformer"),
|
||||
("ZhipuAI/CogVideoX-5b", "transformer/diffusion_pytorch_model.safetensors.index.json", "models/CogVideo/CogVideoX-5b/transformer"),
|
||||
("ZhipuAI/CogVideoX-5b", "transformer/diffusion_pytorch_model-00001-of-00002.safetensors", "models/CogVideo/CogVideoX-5b/transformer"),
|
||||
("ZhipuAI/CogVideoX-5b", "transformer/diffusion_pytorch_model-00002-of-00002.safetensors", "models/CogVideo/CogVideoX-5b/transformer"),
|
||||
("ZhipuAI/CogVideoX-5b", "vae/diffusion_pytorch_model.safetensors", "models/CogVideo/CogVideoX-5b/vae"),
|
||||
],
|
||||
"load_path": [
|
||||
"models/CogVideo/CogVideoX-5b/text_encoder",
|
||||
"models/CogVideo/CogVideoX-5b/transformer",
|
||||
"models/CogVideo/CogVideoX-5b/vae/diffusion_pytorch_model.safetensors",
|
||||
],
|
||||
},
|
||||
# Stable Diffusion 3.5
|
||||
"StableDiffusion3.5-large": [
|
||||
("AI-ModelScope/stable-diffusion-3.5-large", "sd3.5_large.safetensors", "models/stable_diffusion_3"),
|
||||
("AI-ModelScope/stable-diffusion-3.5-large", "text_encoders/clip_l.safetensors", "models/stable_diffusion_3/text_encoders"),
|
||||
("AI-ModelScope/stable-diffusion-3.5-large", "text_encoders/clip_g.safetensors", "models/stable_diffusion_3/text_encoders"),
|
||||
("AI-ModelScope/stable-diffusion-3.5-large", "text_encoders/t5xxl_fp16.safetensors", "models/stable_diffusion_3/text_encoders"),
|
||||
],
|
||||
"StableDiffusion3.5-medium": [
|
||||
("AI-ModelScope/stable-diffusion-3.5-medium", "sd3.5_medium.safetensors", "models/stable_diffusion_3"),
|
||||
("AI-ModelScope/stable-diffusion-3.5-large", "text_encoders/clip_l.safetensors", "models/stable_diffusion_3/text_encoders"),
|
||||
("AI-ModelScope/stable-diffusion-3.5-large", "text_encoders/clip_g.safetensors", "models/stable_diffusion_3/text_encoders"),
|
||||
("AI-ModelScope/stable-diffusion-3.5-large", "text_encoders/t5xxl_fp16.safetensors", "models/stable_diffusion_3/text_encoders"),
|
||||
],
|
||||
"StableDiffusion3.5-large-turbo": [
|
||||
("AI-ModelScope/stable-diffusion-3.5-large-turbo", "sd3.5_large_turbo.safetensors", "models/stable_diffusion_3"),
|
||||
("AI-ModelScope/stable-diffusion-3.5-large", "text_encoders/clip_l.safetensors", "models/stable_diffusion_3/text_encoders"),
|
||||
("AI-ModelScope/stable-diffusion-3.5-large", "text_encoders/clip_g.safetensors", "models/stable_diffusion_3/text_encoders"),
|
||||
("AI-ModelScope/stable-diffusion-3.5-large", "text_encoders/t5xxl_fp16.safetensors", "models/stable_diffusion_3/text_encoders"),
|
||||
],
|
||||
"HunyuanVideo":{
|
||||
"file_list": [
|
||||
("AI-ModelScope/clip-vit-large-patch14", "model.safetensors", "models/HunyuanVideo/text_encoder"),
|
||||
("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "model-00001-of-00004.safetensors", "models/HunyuanVideo/text_encoder_2"),
|
||||
("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "model-00002-of-00004.safetensors", "models/HunyuanVideo/text_encoder_2"),
|
||||
("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "model-00003-of-00004.safetensors", "models/HunyuanVideo/text_encoder_2"),
|
||||
("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "model-00004-of-00004.safetensors", "models/HunyuanVideo/text_encoder_2"),
|
||||
("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "config.json", "models/HunyuanVideo/text_encoder_2"),
|
||||
("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "model.safetensors.index.json", "models/HunyuanVideo/text_encoder_2"),
|
||||
("AI-ModelScope/HunyuanVideo", "hunyuan-video-t2v-720p/vae/pytorch_model.pt", "models/HunyuanVideo/vae"),
|
||||
("AI-ModelScope/HunyuanVideo", "hunyuan-video-t2v-720p/transformers/mp_rank_00_model_states.pt", "models/HunyuanVideo/transformers")
|
||||
],
|
||||
"load_path": [
|
||||
"models/HunyuanVideo/text_encoder/model.safetensors",
|
||||
"models/HunyuanVideo/text_encoder_2",
|
||||
"models/HunyuanVideo/vae/pytorch_model.pt",
|
||||
"models/HunyuanVideo/transformers/mp_rank_00_model_states.pt"
|
||||
],
|
||||
},
|
||||
"HunyuanVideo-fp8":{
|
||||
"file_list": [
|
||||
("AI-ModelScope/clip-vit-large-patch14", "model.safetensors", "models/HunyuanVideo/text_encoder"),
|
||||
("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "model-00001-of-00004.safetensors", "models/HunyuanVideo/text_encoder_2"),
|
||||
("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "model-00002-of-00004.safetensors", "models/HunyuanVideo/text_encoder_2"),
|
||||
("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "model-00003-of-00004.safetensors", "models/HunyuanVideo/text_encoder_2"),
|
||||
("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "model-00004-of-00004.safetensors", "models/HunyuanVideo/text_encoder_2"),
|
||||
("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "config.json", "models/HunyuanVideo/text_encoder_2"),
|
||||
("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "model.safetensors.index.json", "models/HunyuanVideo/text_encoder_2"),
|
||||
("AI-ModelScope/HunyuanVideo", "hunyuan-video-t2v-720p/vae/pytorch_model.pt", "models/HunyuanVideo/vae"),
|
||||
("DiffSynth-Studio/HunyuanVideo-safetensors", "model.fp8.safetensors", "models/HunyuanVideo/transformers")
|
||||
],
|
||||
"load_path": [
|
||||
"models/HunyuanVideo/text_encoder/model.safetensors",
|
||||
"models/HunyuanVideo/text_encoder_2",
|
||||
"models/HunyuanVideo/vae/pytorch_model.pt",
|
||||
"models/HunyuanVideo/transformers/model.fp8.safetensors"
|
||||
],
|
||||
},
|
||||
}
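
Each preset above is either a plain list of (model_id, origin_file_path, local_dir) tuples or, in the newer format used by Kolors, FLUX and the video models, a dict with a "file_list" plus a "load_path" that download_models hands back for loading. A minimal sketch of a custom entry in the dict format; the model name, repo id and paths here are hypothetical:

    my_presets = {
        "MyModel-example": {  # hypothetical preset id
            "file_list": [
                # (repo id on ModelScope, file inside the repo, local target folder)
                ("my-org/MyModel", "unet/diffusion_pytorch_model.safetensors", "models/MyModel/unet"),
                ("my-org/MyModel", "vae/diffusion_pytorch_model.safetensors", "models/MyModel/vae"),
            ],
            "load_path": [
                # what download_models() returns, ready to pass to the model loader
                "models/MyModel/unet/diffusion_pytorch_model.safetensors",
                "models/MyModel/vae/diffusion_pytorch_model.safetensors",
            ],
        },
    }
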
|
||||
Preset_model_id: TypeAlias = Literal[
|
||||
"HunyuanDiT",
|
||||
"stable-video-diffusion-img2vid-xt",
|
||||
"ExVideo-SVD-128f-v1",
|
||||
"ExVideo-CogVideoX-LoRA-129f-v1",
|
||||
"StableDiffusion_v15",
|
||||
"DreamShaper_8",
|
||||
"AingDiffusion_v12",
|
||||
@@ -349,10 +715,30 @@ Preset_model_id: TypeAlias = Literal[
|
||||
"SDXL-vae-fp16-fix",
|
||||
"ControlNet_union_sdxl_promax",
|
||||
"FLUX.1-dev",
|
||||
"FLUX.1-schnell",
|
||||
"InstantX/FLUX.1-dev-Controlnet-Union-alpha",
|
||||
"jasperai/Flux.1-dev-Controlnet-Depth",
|
||||
"jasperai/Flux.1-dev-Controlnet-Surface-Normals",
|
||||
"jasperai/Flux.1-dev-Controlnet-Upscaler",
|
||||
"alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Alpha",
|
||||
"alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta",
|
||||
"Shakker-Labs/FLUX.1-dev-ControlNet-Depth",
|
||||
"Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro",
|
||||
"InstantX/FLUX.1-dev-IP-Adapter",
|
||||
"SDXL_lora_zyd232_ChineseInkStyle_SDXL_v1_0",
|
||||
"QwenPrompt",
|
||||
"OmostPrompt",
|
||||
"ESRGAN_x4",
|
||||
"RIFE",
|
||||
"OmniGen-v1",
|
||||
"CogVideoX-5B",
|
||||
]
|
||||
"Annotators:Depth",
|
||||
"Annotators:Softedge",
|
||||
"Annotators:Lineart",
|
||||
"Annotators:Normal",
|
||||
"Annotators:Openpose",
|
||||
"StableDiffusion3.5-large",
|
||||
"StableDiffusion3.5-medium",
|
||||
"HunyuanVideo",
|
||||
"HunyuanVideo-fp8",
|
||||
]
|
||||
|
||||
@@ -1,2 +1,2 @@
|
||||
from .controlnet_unit import ControlNetConfigUnit, ControlNetUnit, MultiControlNetManager
|
||||
from .controlnet_unit import ControlNetConfigUnit, ControlNetUnit, MultiControlNetManager, FluxMultiControlNetManager
|
||||
from .processors import Annotator
|
||||
|
||||
@@ -4,10 +4,11 @@ from .processors import Processor_id
|
||||
|
||||
|
||||
class ControlNetConfigUnit:
|
||||
def __init__(self, processor_id: Processor_id, model_path, scale=1.0):
|
||||
def __init__(self, processor_id: Processor_id, model_path, scale=1.0, skip_processor=False):
|
||||
self.processor_id = processor_id
|
||||
self.model_path = model_path
|
||||
self.scale = scale
|
||||
self.skip_processor = skip_processor
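
The added skip_processor flag marks a unit whose control image is already preprocessed, so the annotator stage can be bypassed. A small usage sketch; ControlNetConfigUnit is re-exported by the controlnets package __init__ shown earlier (the exact import path is assumed), and the depth model path is only an example:

    from diffsynth.controlnets import ControlNetConfigUnit  # import path assumed from the package __init__

    # The control image is already a depth map, so the depth annotator can be skipped.
    unit = ControlNetConfigUnit(
        processor_id="depth",
        model_path="models/ControlNet/jasperai/Flux.1-dev-Controlnet-Depth/diffusion_pytorch_model.safetensors",
        scale=0.8,
        skip_processor=True,
    )
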
|
||||
|
||||
|
||||
class ControlNetUnit:
|
||||
@@ -30,6 +31,8 @@ class MultiControlNetManager:
|
||||
def to(self, device):
|
||||
for model in self.models:
|
||||
model.to(device)
|
||||
for processor in self.processors:
|
||||
processor.to(device)
|
||||
|
||||
def process_image(self, image, processor_id=None):
|
||||
if processor_id is None:
|
||||
@@ -60,3 +63,29 @@ class MultiControlNetManager:
|
||||
else:
|
||||
res_stack = [i + j for i, j in zip(res_stack, res_stack_)]
|
||||
return res_stack
|
||||
|
||||
|
||||
class FluxMultiControlNetManager(MultiControlNetManager):
|
||||
def __init__(self, controlnet_units=[]):
|
||||
super().__init__(controlnet_units=controlnet_units)
|
||||
|
||||
def process_image(self, image, processor_id=None):
|
||||
if processor_id is None:
|
||||
processed_image = [processor(image) for processor in self.processors]
|
||||
else:
|
||||
processed_image = [self.processors[processor_id](image)]
|
||||
return processed_image
|
||||
|
||||
def __call__(self, conditionings, **kwargs):
|
||||
res_stack, single_res_stack = None, None
|
||||
for processor, conditioning, model, scale in zip(self.processors, conditionings, self.models, self.scales):
|
||||
res_stack_, single_res_stack_ = model(controlnet_conditioning=conditioning, processor_id=processor.processor_id, **kwargs)
|
||||
res_stack_ = [res * scale for res in res_stack_]
|
||||
single_res_stack_ = [res * scale for res in single_res_stack_]
|
||||
if res_stack is None:
|
||||
res_stack = res_stack_
|
||||
single_res_stack = single_res_stack_
|
||||
else:
|
||||
res_stack = [i + j for i, j in zip(res_stack, res_stack_)]
|
||||
single_res_stack = [i + j for i, j in zip(single_res_stack, single_res_stack_)]
|
||||
return res_stack, single_res_stack
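
FluxMultiControlNetManager.__call__ returns one residual stack for the joint blocks and one for the single blocks, each already scaled and summed across all ControlNets. A rough sketch of how a pipeline could fold them back into the FLUX DiT; the dit, conditioning and controlnet_kwargs names are illustrative, not the exact pipeline code:

    # res_stack holds one tensor per joint block, single_res_stack one per single block.
    res_stack, single_res_stack = controlnet_manager(conditionings, **controlnet_kwargs)

    for block, res in zip(dit.blocks, res_stack):
        hidden_states, prompt_emb = block(hidden_states, prompt_emb, conditioning, image_rotary_emb)
        hidden_states = hidden_states + res  # inject the joint-block residual

    hidden_states = torch.cat([prompt_emb, hidden_states], dim=1)
    for block, res in zip(dit.single_blocks, single_res_stack):
        hidden_states, prompt_emb = block(hidden_states, prompt_emb, conditioning, image_rotary_emb)
        # the single-block residual only touches the image tokens, not the prompt tokens
        hidden_states[:, prompt_emb.shape[1]:] = hidden_states[:, prompt_emb.shape[1]:] + res
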
|
||||
|
||||
@@ -3,37 +3,47 @@ import warnings
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore")
|
||||
from controlnet_aux.processor import (
|
||||
CannyDetector, MidasDetector, HEDdetector, LineartDetector, LineartAnimeDetector, OpenposeDetector
|
||||
CannyDetector, MidasDetector, HEDdetector, LineartDetector, LineartAnimeDetector, OpenposeDetector, NormalBaeDetector
|
||||
)
|
||||
|
||||
|
||||
Processor_id: TypeAlias = Literal[
|
||||
"canny", "depth", "softedge", "lineart", "lineart_anime", "openpose", "tile"
|
||||
"canny", "depth", "softedge", "lineart", "lineart_anime", "openpose", "normal", "tile", "none", "inpaint"
|
||||
]
|
||||
|
||||
class Annotator:
|
||||
def __init__(self, processor_id: Processor_id, model_path="models/Annotators", detect_resolution=None, device='cuda'):
|
||||
if processor_id == "canny":
|
||||
self.processor = CannyDetector()
|
||||
elif processor_id == "depth":
|
||||
self.processor = MidasDetector.from_pretrained(model_path).to(device)
|
||||
elif processor_id == "softedge":
|
||||
self.processor = HEDdetector.from_pretrained(model_path).to(device)
|
||||
elif processor_id == "lineart":
|
||||
self.processor = LineartDetector.from_pretrained(model_path).to(device)
|
||||
elif processor_id == "lineart_anime":
|
||||
self.processor = LineartAnimeDetector.from_pretrained(model_path).to(device)
|
||||
elif processor_id == "openpose":
|
||||
self.processor = OpenposeDetector.from_pretrained(model_path).to(device)
|
||||
elif processor_id == "tile":
|
||||
self.processor = None
|
||||
def __init__(self, processor_id: Processor_id, model_path="models/Annotators", detect_resolution=None, device='cuda', skip_processor=False):
|
||||
if not skip_processor:
|
||||
if processor_id == "canny":
|
||||
self.processor = CannyDetector()
|
||||
elif processor_id == "depth":
|
||||
self.processor = MidasDetector.from_pretrained(model_path).to(device)
|
||||
elif processor_id == "softedge":
|
||||
self.processor = HEDdetector.from_pretrained(model_path).to(device)
|
||||
elif processor_id == "lineart":
|
||||
self.processor = LineartDetector.from_pretrained(model_path).to(device)
|
||||
elif processor_id == "lineart_anime":
|
||||
self.processor = LineartAnimeDetector.from_pretrained(model_path).to(device)
|
||||
elif processor_id == "openpose":
|
||||
self.processor = OpenposeDetector.from_pretrained(model_path).to(device)
|
||||
elif processor_id == "normal":
|
||||
self.processor = NormalBaeDetector.from_pretrained(model_path).to(device)
|
||||
elif processor_id == "tile" or processor_id == "none" or processor_id == "inpaint":
|
||||
self.processor = None
|
||||
else:
|
||||
raise ValueError(f"Unsupported processor_id: {processor_id}")
|
||||
else:
|
||||
raise ValueError(f"Unsupported processor_id: {processor_id}")
|
||||
self.processor = None
|
||||
|
||||
self.processor_id = processor_id
|
||||
self.detect_resolution = detect_resolution
|
||||
|
||||
def to(self,device):
if hasattr(self.processor,"model") and hasattr(self.processor.model,"to"):
self.processor.model.to(device)

def __call__(self, image):
def __call__(self, image, mask=None):
|
||||
width, height = image.size
|
||||
if self.processor_id == "openpose":
|
||||
kwargs = {
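
The Annotator now knows the "normal" (NormalBae) detector, treats "none" and "inpaint" like "tile", and can be created with skip_processor=True; __call__ additionally accepts an optional mask. A short usage sketch; the import path and the pass-through behaviour of a None processor are assumptions, since the full __call__ body is not shown in this hunk:

    from PIL import Image
    from diffsynth.controlnets import Annotator  # import path assumed from the package __init__

    image = Image.open("input.png").convert("RGB")  # placeholder input

    normal_annotator = Annotator("normal", model_path="models/Annotators", device="cuda")
    normal_map = normal_annotator(image)

    # With skip_processor=True the processor stays None; the annotator is expected to
    # pass the image through unchanged (assumption), which suits already-prepared control images.
    passthrough = Annotator("inpaint", skip_processor=True)
    control_image = passthrough(image)
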
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
import torch, os
|
||||
import torch, os, torchvision
|
||||
from torchvision import transforms
|
||||
import pandas as pd
|
||||
from PIL import Image
|
||||
@@ -11,9 +11,10 @@ class TextImageDataset(torch.utils.data.Dataset):
|
||||
metadata = pd.read_csv(os.path.join(dataset_path, "train/metadata.csv"))
|
||||
self.path = [os.path.join(dataset_path, "train", file_name) for file_name in metadata["file_name"]]
|
||||
self.text = metadata["text"].to_list()
|
||||
self.height = height
|
||||
self.width = width
|
||||
self.image_processor = transforms.Compose(
|
||||
[
|
||||
transforms.Resize(max(height, width), interpolation=transforms.InterpolationMode.BILINEAR),
|
||||
transforms.CenterCrop((height, width)) if center_crop else transforms.RandomCrop((height, width)),
|
||||
transforms.RandomHorizontalFlip() if random_flip else transforms.Lambda(lambda x: x),
|
||||
transforms.ToTensor(),
|
||||
@@ -27,6 +28,11 @@ class TextImageDataset(torch.utils.data.Dataset):
|
||||
data_id = (data_id + index) % len(self.path) # For fixed seed.
|
||||
text = self.text[data_id]
|
||||
image = Image.open(self.path[data_id]).convert("RGB")
|
||||
target_height, target_width = self.height, self.width
|
||||
width, height = image.size
|
||||
scale = max(target_width / width, target_height / height)
|
||||
shape = [round(height*scale),round(width*scale)]
|
||||
image = torchvision.transforms.functional.resize(image,shape,interpolation=transforms.InterpolationMode.BILINEAR)
|
||||
image = self.image_processor(image)
|
||||
return {"text": text, "image": image}
|
||||
|
||||
|
||||
@@ -135,8 +135,8 @@ class VideoData:
|
||||
frame.save(os.path.join(folder, f"{i}.png"))
|
||||
|
||||
|
||||
def save_video(frames, save_path, fps, quality=9):
|
||||
writer = imageio.get_writer(save_path, fps=fps, quality=quality)
|
||||
def save_video(frames, save_path, fps, quality=9, ffmpeg_params=None):
|
||||
writer = imageio.get_writer(save_path, fps=fps, quality=quality, ffmpeg_params=ffmpeg_params)
|
||||
for frame in tqdm(frames, desc="Saving video"):
|
||||
frame = np.array(frame)
|
||||
writer.append_data(frame)
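
save_video now forwards an optional ffmpeg_params list straight to imageio's writer. A minimal call; the specific flags are just ordinary ffmpeg x264 options chosen for illustration:

    from diffsynth import save_video  # top-level re-export as used in the repo's examples

    save_video(
        frames, "output.mp4", fps=30, quality=9,
        ffmpeg_params=["-crf", "18", "-preset", "slow"],  # passed through to ffmpeg as-is
    )
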
|
||||
|
||||
@@ -107,6 +107,12 @@ class ESRGAN(torch.nn.Module):
|
||||
|
||||
@torch.no_grad()
|
||||
def upscale(self, images, batch_size=4, progress_bar=lambda x:x):
|
||||
if not isinstance(images, list):
|
||||
images = [images]
|
||||
is_single_image = True
|
||||
else:
|
||||
is_single_image = False
|
||||
|
||||
# Preprocess
|
||||
input_tensor = self.process_images(images)
|
||||
|
||||
@@ -126,4 +132,6 @@ class ESRGAN(torch.nn.Module):
|
||||
|
||||
# To images
|
||||
output_images = self.decode_images(output_tensor)
|
||||
if is_single_image:
|
||||
output_images = output_images[0]
|
||||
return output_images
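
ESRGAN.upscale now accepts either one image or a list and mirrors the input type on output. The calling convention, with model loading omitted:

    # esrgan is an already-loaded ESRGAN module (loading not shown here)
    single = esrgan.upscale(image)            # single PIL image in, single upscaled image out
    batch = esrgan.upscale([img_a, img_b])    # list in, list out; processed in batches of 4 by default
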
|
||||
|
||||
@@ -283,7 +283,7 @@ class CogDiT(torch.nn.Module):
|
||||
return value
|
||||
|
||||
|
||||
def forward(self, hidden_states, timestep, prompt_emb, image_rotary_emb=None, tiled=False, tile_size=90, tile_stride=30):
|
||||
def forward(self, hidden_states, timestep, prompt_emb, image_rotary_emb=None, tiled=False, tile_size=90, tile_stride=30, use_gradient_checkpointing=False):
|
||||
if tiled:
|
||||
return TileWorker2Dto3D().tiled_forward(
|
||||
forward_fn=lambda x: self.forward(x, timestep, prompt_emb),
|
||||
@@ -298,8 +298,21 @@ class CogDiT(torch.nn.Module):
|
||||
hidden_states = self.patchify(hidden_states)
|
||||
time_emb = self.time_embedder(timestep, dtype=hidden_states.dtype)
|
||||
prompt_emb = self.context_embedder(prompt_emb)
|
||||
|
||||
def create_custom_forward(module):
|
||||
def custom_forward(*inputs):
|
||||
return module(*inputs)
|
||||
return custom_forward
|
||||
|
||||
for block in self.blocks:
|
||||
hidden_states, prompt_emb = block(hidden_states, prompt_emb, time_emb, image_rotary_emb)
|
||||
if self.training and use_gradient_checkpointing:
|
||||
hidden_states, prompt_emb = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(block),
|
||||
hidden_states, prompt_emb, time_emb, image_rotary_emb,
|
||||
use_reentrant=False,
|
||||
)
|
||||
else:
|
||||
hidden_states, prompt_emb = block(hidden_states, prompt_emb, time_emb, image_rotary_emb)
|
||||
|
||||
hidden_states = torch.cat([prompt_emb, hidden_states], dim=1)
|
||||
hidden_states = self.norm_final(hidden_states)
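
The new use_gradient_checkpointing path wraps each block in torch.utils.checkpoint, recomputing activations in the backward pass instead of storing them. The same pattern in isolation, with a toy stand-in block:

    import torch
    from torch.utils.checkpoint import checkpoint

    def create_custom_forward(module):
        def custom_forward(*inputs):
            return module(*inputs)
        return custom_forward

    block = torch.nn.Linear(64, 64)               # stand-in for a transformer block
    x = torch.randn(2, 64, requires_grad=True)

    # Activations inside `block` are recomputed during backward, trading compute for memory.
    y = checkpoint(create_custom_forward(block), x, use_reentrant=False)
    y.sum().backward()
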
|
||||
|
||||
@@ -8,27 +8,32 @@ from ..configs.model_config import preset_models_on_huggingface, preset_models_o
|
||||
|
||||
def download_from_modelscope(model_id, origin_file_path, local_dir):
|
||||
os.makedirs(local_dir, exist_ok=True)
|
||||
if os.path.basename(origin_file_path) in os.listdir(local_dir):
|
||||
print(f" {os.path.basename(origin_file_path)} has been already in {local_dir}.")
|
||||
return
|
||||
file_name = os.path.basename(origin_file_path)
|
||||
if file_name in os.listdir(local_dir):
|
||||
print(f" {file_name} has been already in {local_dir}.")
|
||||
else:
|
||||
print(f" Start downloading {os.path.join(local_dir, os.path.basename(origin_file_path))}")
|
||||
snapshot_download(model_id, allow_file_pattern=origin_file_path, local_dir=local_dir)
|
||||
downloaded_file_path = os.path.join(local_dir, origin_file_path)
|
||||
target_file_path = os.path.join(local_dir, os.path.split(origin_file_path)[-1])
|
||||
if downloaded_file_path != target_file_path:
|
||||
shutil.move(downloaded_file_path, target_file_path)
|
||||
shutil.rmtree(os.path.join(local_dir, origin_file_path.split("/")[0]))
|
||||
print(f" Start downloading {os.path.join(local_dir, file_name)}")
|
||||
snapshot_download(model_id, allow_file_pattern=origin_file_path, local_dir=local_dir)
|
||||
downloaded_file_path = os.path.join(local_dir, origin_file_path)
|
||||
target_file_path = os.path.join(local_dir, os.path.split(origin_file_path)[-1])
|
||||
if downloaded_file_path != target_file_path:
|
||||
shutil.move(downloaded_file_path, target_file_path)
|
||||
shutil.rmtree(os.path.join(local_dir, origin_file_path.split("/")[0]))
|
||||
|
||||
|
||||
def download_from_huggingface(model_id, origin_file_path, local_dir):
|
||||
os.makedirs(local_dir, exist_ok=True)
|
||||
if os.path.basename(origin_file_path) in os.listdir(local_dir):
|
||||
print(f" {os.path.basename(origin_file_path)} has been already in {local_dir}.")
|
||||
return
|
||||
file_name = os.path.basename(origin_file_path)
|
||||
if file_name in os.listdir(local_dir):
|
||||
print(f" {file_name} has been already in {local_dir}.")
|
||||
else:
|
||||
print(f" Start downloading {os.path.join(local_dir, os.path.basename(origin_file_path))}")
|
||||
hf_hub_download(model_id, origin_file_path, local_dir=local_dir)
|
||||
print(f" Start downloading {os.path.join(local_dir, file_name)}")
|
||||
hf_hub_download(model_id, origin_file_path, local_dir=local_dir)
|
||||
downloaded_file_path = os.path.join(local_dir, origin_file_path)
|
||||
target_file_path = os.path.join(local_dir, file_name)
|
||||
if downloaded_file_path != target_file_path:
|
||||
shutil.move(downloaded_file_path, target_file_path)
|
||||
shutil.rmtree(os.path.join(local_dir, origin_file_path.split("/")[0]))
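
Both download helpers fetch the nested repo file into local_dir and then flatten it: the file is moved to the top of local_dir and the now-empty repo subfolder is removed. A worked example of just the path handling, using one of the Kolors files listed above (no download involved):

    import os

    local_dir = "models/kolors/Kolors/text_encoder"
    origin_file_path = "text_encoder/pytorch_model-00001-of-00007.bin"

    downloaded_file_path = os.path.join(local_dir, origin_file_path)
    # models/kolors/Kolors/text_encoder/text_encoder/pytorch_model-00001-of-00007.bin
    target_file_path = os.path.join(local_dir, os.path.basename(origin_file_path))
    # models/kolors/Kolors/text_encoder/pytorch_model-00001-of-00007.bin
    leftover_dir = os.path.join(local_dir, origin_file_path.split("/")[0])
    # models/kolors/Kolors/text_encoder/text_encoder  -> removed after the move
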
|
||||
|
||||
|
||||
Preset_model_website: TypeAlias = Literal[
|
||||
@@ -45,16 +50,47 @@ website_to_download_fn = {
|
||||
}
|
||||
|
||||
|
||||
def download_customized_models(
|
||||
model_id,
|
||||
origin_file_path,
|
||||
local_dir,
|
||||
downloading_priority: List[Preset_model_website] = ["ModelScope", "HuggingFace"],
|
||||
):
|
||||
downloaded_files = []
|
||||
for website in downloading_priority:
|
||||
# Check if the file is downloaded.
|
||||
file_to_download = os.path.join(local_dir, os.path.basename(origin_file_path))
|
||||
if file_to_download in downloaded_files:
|
||||
continue
|
||||
# Download
|
||||
website_to_download_fn[website](model_id, origin_file_path, local_dir)
|
||||
if os.path.basename(origin_file_path) in os.listdir(local_dir):
|
||||
downloaded_files.append(file_to_download)
|
||||
return downloaded_files
|
||||
|
||||
|
||||
def download_models(
|
||||
model_id_list: List[Preset_model_id] = [],
|
||||
downloading_priority: List[Preset_model_website] = ["ModelScope", "HuggingFace"],
|
||||
):
|
||||
print(f"Downloading models: {model_id_list}")
|
||||
downloaded_files = []
|
||||
load_files = []
|
||||
|
||||
for model_id in model_id_list:
|
||||
for website in downloading_priority:
|
||||
if model_id in website_to_preset_models[website]:
|
||||
for model_id, origin_file_path, local_dir in website_to_preset_models[website][model_id]:
|
||||
|
||||
# Parse model metadata
|
||||
model_metadata = website_to_preset_models[website][model_id]
|
||||
if isinstance(model_metadata, list):
|
||||
file_data = model_metadata
|
||||
else:
|
||||
file_data = model_metadata.get("file_list", [])
|
||||
|
||||
# Try downloading the model from this website.
|
||||
model_files = []
|
||||
for model_id, origin_file_path, local_dir in file_data:
|
||||
# Check if the file is downloaded.
|
||||
file_to_download = os.path.join(local_dir, os.path.basename(origin_file_path))
|
||||
if file_to_download in downloaded_files:
|
||||
@@ -63,4 +99,13 @@ def download_models(
|
||||
website_to_download_fn[website](model_id, origin_file_path, local_dir)
|
||||
if os.path.basename(origin_file_path) in os.listdir(local_dir):
|
||||
downloaded_files.append(file_to_download)
|
||||
return downloaded_files
|
||||
model_files.append(file_to_download)
|
||||
|
||||
# If the model is successfully downloaded, break.
|
||||
if len(model_files) > 0:
|
||||
if isinstance(model_metadata, dict) and "load_path" in model_metadata:
|
||||
model_files = model_metadata["load_path"]
|
||||
load_files.extend(model_files)
|
||||
break
|
||||
|
||||
return load_files
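
With dict-format presets, download_models now returns the preset's load_path entries instead of every individual file, so the result can be fed straight to a model loader. A minimal sketch; the ModelManager usage follows the repository's usual examples and is not part of this diff:

    import torch
    from diffsynth import download_models, ModelManager

    load_files = download_models(["FLUX.1-dev"])
    # e.g. ["models/FLUX/FLUX.1-dev/text_encoder/model.safetensors",
    #       "models/FLUX/FLUX.1-dev/text_encoder_2",
    #       "models/FLUX/FLUX.1-dev/ae.safetensors",
    #       "models/FLUX/FLUX.1-dev/flux1-dev.safetensors"]

    model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cuda")
    model_manager.load_models(load_files)
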
|
||||
|
||||
327
diffsynth/models/flux_controlnet.py
Normal file
@@ -0,0 +1,327 @@
|
||||
import torch
|
||||
from einops import rearrange, repeat
|
||||
from .flux_dit import RoPEEmbedding, TimestepEmbeddings, FluxJointTransformerBlock, FluxSingleTransformerBlock, RMSNorm
|
||||
from .utils import hash_state_dict_keys, init_weights_on_device
|
||||
|
||||
|
||||
|
||||
class FluxControlNet(torch.nn.Module):
|
||||
def __init__(self, disable_guidance_embedder=False, num_joint_blocks=5, num_single_blocks=10, num_mode=0, mode_dict={}, additional_input_dim=0):
|
||||
super().__init__()
|
||||
self.pos_embedder = RoPEEmbedding(3072, 10000, [16, 56, 56])
|
||||
self.time_embedder = TimestepEmbeddings(256, 3072)
|
||||
self.guidance_embedder = None if disable_guidance_embedder else TimestepEmbeddings(256, 3072)
|
||||
self.pooled_text_embedder = torch.nn.Sequential(torch.nn.Linear(768, 3072), torch.nn.SiLU(), torch.nn.Linear(3072, 3072))
|
||||
self.context_embedder = torch.nn.Linear(4096, 3072)
|
||||
self.x_embedder = torch.nn.Linear(64, 3072)
|
||||
|
||||
self.blocks = torch.nn.ModuleList([FluxJointTransformerBlock(3072, 24) for _ in range(num_joint_blocks)])
|
||||
self.single_blocks = torch.nn.ModuleList([FluxSingleTransformerBlock(3072, 24) for _ in range(num_single_blocks)])
|
||||
|
||||
self.controlnet_blocks = torch.nn.ModuleList([torch.nn.Linear(3072, 3072) for _ in range(num_joint_blocks)])
|
||||
self.controlnet_single_blocks = torch.nn.ModuleList([torch.nn.Linear(3072, 3072) for _ in range(num_single_blocks)])
|
||||
|
||||
self.mode_dict = mode_dict
|
||||
self.controlnet_mode_embedder = torch.nn.Embedding(num_mode, 3072) if len(mode_dict) > 0 else None
|
||||
self.controlnet_x_embedder = torch.nn.Linear(64 + additional_input_dim, 3072)
|
||||
|
||||
|
||||
def prepare_image_ids(self, latents):
|
||||
batch_size, _, height, width = latents.shape
|
||||
latent_image_ids = torch.zeros(height // 2, width // 2, 3)
|
||||
latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None]
|
||||
latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :]
|
||||
|
||||
latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
|
||||
|
||||
latent_image_ids = latent_image_ids[None, :].repeat(batch_size, 1, 1, 1)
|
||||
latent_image_ids = latent_image_ids.reshape(
|
||||
batch_size, latent_image_id_height * latent_image_id_width, latent_image_id_channels
|
||||
)
|
||||
latent_image_ids = latent_image_ids.to(device=latents.device, dtype=latents.dtype)
|
||||
|
||||
return latent_image_ids
|
||||
|
||||
|
||||
def patchify(self, hidden_states):
|
||||
hidden_states = rearrange(hidden_states, "B C (H P) (W Q) -> B (H W) (C P Q)", P=2, Q=2)
|
||||
return hidden_states
|
||||
|
||||
|
||||
def align_res_stack_to_original_blocks(self, res_stack, num_blocks, hidden_states):
|
||||
if len(res_stack) == 0:
|
||||
return [torch.zeros_like(hidden_states)] * num_blocks
|
||||
interval = (num_blocks + len(res_stack) - 1) // len(res_stack)
|
||||
aligned_res_stack = [res_stack[block_id // interval] for block_id in range(num_blocks)]
|
||||
return aligned_res_stack
|
||||
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states,
|
||||
controlnet_conditioning,
|
||||
timestep, prompt_emb, pooled_prompt_emb, guidance, text_ids, image_ids=None,
|
||||
processor_id=None,
|
||||
tiled=False, tile_size=128, tile_stride=64,
|
||||
**kwargs
|
||||
):
|
||||
if image_ids is None:
|
||||
image_ids = self.prepare_image_ids(hidden_states)
|
||||
|
||||
conditioning = self.time_embedder(timestep, hidden_states.dtype) + self.pooled_text_embedder(pooled_prompt_emb)
|
||||
if self.guidance_embedder is not None:
|
||||
guidance = guidance * 1000
|
||||
conditioning = conditioning + self.guidance_embedder(guidance, hidden_states.dtype)
|
||||
prompt_emb = self.context_embedder(prompt_emb)
|
||||
if self.controlnet_mode_embedder is not None: # Different from FluxDiT
|
||||
processor_id = torch.tensor([self.mode_dict[processor_id]], dtype=torch.int)
|
||||
processor_id = repeat(processor_id, "D -> B D", B=1).to(text_ids.device)
|
||||
prompt_emb = torch.concat([self.controlnet_mode_embedder(processor_id), prompt_emb], dim=1)
|
||||
text_ids = torch.cat([text_ids[:, :1], text_ids], dim=1)
|
||||
image_rotary_emb = self.pos_embedder(torch.cat((text_ids, image_ids), dim=1))
|
||||
|
||||
hidden_states = self.patchify(hidden_states)
|
||||
hidden_states = self.x_embedder(hidden_states)
|
||||
controlnet_conditioning = self.patchify(controlnet_conditioning) # Different from FluxDiT
|
||||
hidden_states = hidden_states + self.controlnet_x_embedder(controlnet_conditioning) # Different from FluxDiT
|
||||
|
||||
controlnet_res_stack = []
|
||||
for block, controlnet_block in zip(self.blocks, self.controlnet_blocks):
|
||||
hidden_states, prompt_emb = block(hidden_states, prompt_emb, conditioning, image_rotary_emb)
|
||||
controlnet_res_stack.append(controlnet_block(hidden_states))
|
||||
|
||||
controlnet_single_res_stack = []
|
||||
hidden_states = torch.cat([prompt_emb, hidden_states], dim=1)
|
||||
for block, controlnet_block in zip(self.single_blocks, self.controlnet_single_blocks):
|
||||
hidden_states, prompt_emb = block(hidden_states, prompt_emb, conditioning, image_rotary_emb)
|
||||
controlnet_single_res_stack.append(controlnet_block(hidden_states[:, prompt_emb.shape[1]:]))
|
||||
|
||||
controlnet_res_stack = self.align_res_stack_to_original_blocks(controlnet_res_stack, 19, hidden_states[:, prompt_emb.shape[1]:])
|
||||
controlnet_single_res_stack = self.align_res_stack_to_original_blocks(controlnet_single_res_stack, 38, hidden_states[:, prompt_emb.shape[1]:])
|
||||
|
||||
return controlnet_res_stack, controlnet_single_res_stack
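
align_res_stack_to_original_blocks stretches the ControlNet's few residuals over the DiT's full block count by repeating each one over an interval. A small numeric illustration for the default 5 joint ControlNet blocks mapped onto FLUX.1-dev's 19 joint blocks:

    num_blocks, res_stack = 19, ["r0", "r1", "r2", "r3", "r4"]
    interval = (num_blocks + len(res_stack) - 1) // len(res_stack)  # ceil(19 / 5) = 4
    aligned = [res_stack[block_id // interval] for block_id in range(num_blocks)]
    # ['r0', 'r0', 'r0', 'r0', 'r1', 'r1', 'r1', 'r1', 'r2', 'r2', 'r2', 'r2',
    #  'r3', 'r3', 'r3', 'r3', 'r4', 'r4', 'r4']
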
|
||||
|
||||
|
||||
@staticmethod
|
||||
def state_dict_converter():
|
||||
return FluxControlNetStateDictConverter()
|
||||
|
||||
def quantize(self):
|
||||
def cast_to(weight, dtype=None, device=None, copy=False):
|
||||
if device is None or weight.device == device:
|
||||
if not copy:
|
||||
if dtype is None or weight.dtype == dtype:
|
||||
return weight
|
||||
return weight.to(dtype=dtype, copy=copy)
|
||||
|
||||
r = torch.empty_like(weight, dtype=dtype, device=device)
|
||||
r.copy_(weight)
|
||||
return r
|
||||
|
||||
def cast_weight(s, input=None, dtype=None, device=None):
|
||||
if input is not None:
|
||||
if dtype is None:
|
||||
dtype = input.dtype
|
||||
if device is None:
|
||||
device = input.device
|
||||
weight = cast_to(s.weight, dtype, device)
|
||||
return weight
|
||||
|
||||
def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None):
|
||||
if input is not None:
|
||||
if dtype is None:
|
||||
dtype = input.dtype
|
||||
if bias_dtype is None:
|
||||
bias_dtype = dtype
|
||||
if device is None:
|
||||
device = input.device
|
||||
bias = None
|
||||
weight = cast_to(s.weight, dtype, device)
|
||||
bias = cast_to(s.bias, bias_dtype, device)
|
||||
return weight, bias
|
||||
|
||||
class quantized_layer:
|
||||
class QLinear(torch.nn.Linear):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def forward(self,input,**kwargs):
|
||||
weight,bias= cast_bias_weight(self,input)
|
||||
return torch.nn.functional.linear(input,weight,bias)
|
||||
|
||||
class QRMSNorm(torch.nn.Module):
|
||||
def __init__(self, module):
|
||||
super().__init__()
|
||||
self.module = module
|
||||
|
||||
def forward(self,hidden_states,**kwargs):
|
||||
weight= cast_weight(self.module,hidden_states)
|
||||
input_dtype = hidden_states.dtype
|
||||
variance = hidden_states.to(torch.float32).square().mean(-1, keepdim=True)
|
||||
hidden_states = hidden_states * torch.rsqrt(variance + self.module.eps)
|
||||
hidden_states = hidden_states.to(input_dtype) * weight
|
||||
return hidden_states
|
||||
|
||||
class QEmbedding(torch.nn.Embedding):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def forward(self,input,**kwargs):
|
||||
weight= cast_weight(self,input)
|
||||
return torch.nn.functional.embedding(
|
||||
input, weight, self.padding_idx, self.max_norm,
|
||||
self.norm_type, self.scale_grad_by_freq, self.sparse)
|
||||
|
||||
def replace_layer(model):
|
||||
for name, module in model.named_children():
|
||||
if isinstance(module,quantized_layer.QRMSNorm):
|
||||
continue
|
||||
if isinstance(module, torch.nn.Linear):
|
||||
with init_weights_on_device():
|
||||
new_layer = quantized_layer.QLinear(module.in_features,module.out_features)
|
||||
new_layer.weight = module.weight
|
||||
if module.bias is not None:
|
||||
new_layer.bias = module.bias
|
||||
setattr(model, name, new_layer)
|
||||
elif isinstance(module, RMSNorm):
|
||||
if hasattr(module,"quantized"):
|
||||
continue
|
||||
module.quantized= True
|
||||
new_layer = quantized_layer.QRMSNorm(module)
|
||||
setattr(model, name, new_layer)
|
||||
elif isinstance(module,torch.nn.Embedding):
|
||||
rows, cols = module.weight.shape
|
||||
new_layer = quantized_layer.QEmbedding(
|
||||
num_embeddings=rows,
|
||||
embedding_dim=cols,
|
||||
_weight=module.weight,
|
||||
# _freeze=module.freeze,
|
||||
padding_idx=module.padding_idx,
|
||||
max_norm=module.max_norm,
|
||||
norm_type=module.norm_type,
|
||||
scale_grad_by_freq=module.scale_grad_by_freq,
|
||||
sparse=module.sparse)
|
||||
setattr(model, name, new_layer)
|
||||
else:
|
||||
replace_layer(module)
|
||||
|
||||
replace_layer(self)
|
||||
|
||||
|
||||
|
||||
class FluxControlNetStateDictConverter:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def from_diffusers(self, state_dict):
|
||||
hash_value = hash_state_dict_keys(state_dict)
|
||||
global_rename_dict = {
|
||||
"context_embedder": "context_embedder",
|
||||
"x_embedder": "x_embedder",
|
||||
"time_text_embed.timestep_embedder.linear_1": "time_embedder.timestep_embedder.0",
|
||||
"time_text_embed.timestep_embedder.linear_2": "time_embedder.timestep_embedder.2",
|
||||
"time_text_embed.guidance_embedder.linear_1": "guidance_embedder.timestep_embedder.0",
|
||||
"time_text_embed.guidance_embedder.linear_2": "guidance_embedder.timestep_embedder.2",
|
||||
"time_text_embed.text_embedder.linear_1": "pooled_text_embedder.0",
|
||||
"time_text_embed.text_embedder.linear_2": "pooled_text_embedder.2",
|
||||
"norm_out.linear": "final_norm_out.linear",
|
||||
"proj_out": "final_proj_out",
|
||||
}
|
||||
rename_dict = {
|
||||
"proj_out": "proj_out",
|
||||
"norm1.linear": "norm1_a.linear",
|
||||
"norm1_context.linear": "norm1_b.linear",
|
||||
"attn.to_q": "attn.a_to_q",
|
||||
"attn.to_k": "attn.a_to_k",
|
||||
"attn.to_v": "attn.a_to_v",
|
||||
"attn.to_out.0": "attn.a_to_out",
|
||||
"attn.add_q_proj": "attn.b_to_q",
|
||||
"attn.add_k_proj": "attn.b_to_k",
|
||||
"attn.add_v_proj": "attn.b_to_v",
|
||||
"attn.to_add_out": "attn.b_to_out",
|
||||
"ff.net.0.proj": "ff_a.0",
|
||||
"ff.net.2": "ff_a.2",
|
||||
"ff_context.net.0.proj": "ff_b.0",
|
||||
"ff_context.net.2": "ff_b.2",
|
||||
"attn.norm_q": "attn.norm_q_a",
|
||||
"attn.norm_k": "attn.norm_k_a",
|
||||
"attn.norm_added_q": "attn.norm_q_b",
|
||||
"attn.norm_added_k": "attn.norm_k_b",
|
||||
}
|
||||
rename_dict_single = {
|
||||
"attn.to_q": "a_to_q",
|
||||
"attn.to_k": "a_to_k",
|
||||
"attn.to_v": "a_to_v",
|
||||
"attn.norm_q": "norm_q_a",
|
||||
"attn.norm_k": "norm_k_a",
|
||||
"norm.linear": "norm.linear",
|
||||
"proj_mlp": "proj_in_besides_attn",
|
||||
"proj_out": "proj_out",
|
||||
}
|
||||
state_dict_ = {}
|
||||
for name, param in state_dict.items():
|
||||
if name.endswith(".weight") or name.endswith(".bias"):
|
||||
suffix = ".weight" if name.endswith(".weight") else ".bias"
|
||||
prefix = name[:-len(suffix)]
|
||||
if prefix in global_rename_dict:
|
||||
state_dict_[global_rename_dict[prefix] + suffix] = param
|
||||
elif prefix.startswith("transformer_blocks."):
|
||||
names = prefix.split(".")
|
||||
names[0] = "blocks"
|
||||
middle = ".".join(names[2:])
|
||||
if middle in rename_dict:
|
||||
name_ = ".".join(names[:2] + [rename_dict[middle]] + [suffix[1:]])
|
||||
state_dict_[name_] = param
|
||||
elif prefix.startswith("single_transformer_blocks."):
|
||||
names = prefix.split(".")
|
||||
names[0] = "single_blocks"
|
||||
middle = ".".join(names[2:])
|
||||
if middle in rename_dict_single:
|
||||
name_ = ".".join(names[:2] + [rename_dict_single[middle]] + [suffix[1:]])
|
||||
state_dict_[name_] = param
|
||||
else:
|
||||
state_dict_[name] = param
|
||||
else:
|
||||
state_dict_[name] = param
|
||||
for name in list(state_dict_.keys()):
|
||||
if ".proj_in_besides_attn." in name:
|
||||
name_ = name.replace(".proj_in_besides_attn.", ".to_qkv_mlp.")
|
||||
param = torch.concat([
|
||||
state_dict_[name.replace(".proj_in_besides_attn.", f".a_to_q.")],
|
||||
state_dict_[name.replace(".proj_in_besides_attn.", f".a_to_k.")],
|
||||
state_dict_[name.replace(".proj_in_besides_attn.", f".a_to_v.")],
|
||||
state_dict_[name],
|
||||
], dim=0)
|
||||
state_dict_[name_] = param
|
||||
state_dict_.pop(name.replace(".proj_in_besides_attn.", f".a_to_q."))
|
||||
state_dict_.pop(name.replace(".proj_in_besides_attn.", f".a_to_k."))
|
||||
state_dict_.pop(name.replace(".proj_in_besides_attn.", f".a_to_v."))
|
||||
state_dict_.pop(name)
|
||||
for name in list(state_dict_.keys()):
|
||||
for component in ["a", "b"]:
|
||||
if f".{component}_to_q." in name:
|
||||
name_ = name.replace(f".{component}_to_q.", f".{component}_to_qkv.")
|
||||
param = torch.concat([
|
||||
state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_q.")],
|
||||
state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_k.")],
|
||||
state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_v.")],
|
||||
], dim=0)
|
||||
state_dict_[name_] = param
|
||||
state_dict_.pop(name.replace(f".{component}_to_q.", f".{component}_to_q."))
|
||||
state_dict_.pop(name.replace(f".{component}_to_q.", f".{component}_to_k."))
|
||||
state_dict_.pop(name.replace(f".{component}_to_q.", f".{component}_to_v."))
|
||||
if hash_value == "78d18b9101345ff695f312e7e62538c0":
|
||||
extra_kwargs = {"num_mode": 10, "mode_dict": {"canny": 0, "tile": 1, "depth": 2, "blur": 3, "pose": 4, "gray": 5, "lq": 6}}
|
||||
elif hash_value == "b001c89139b5f053c715fe772362dd2a":
|
||||
extra_kwargs = {"num_single_blocks": 0}
|
||||
elif hash_value == "52357cb26250681367488a8954c271e8":
|
||||
extra_kwargs = {"num_joint_blocks": 6, "num_single_blocks": 0, "additional_input_dim": 4}
|
||||
elif hash_value == "0cfd1740758423a2a854d67c136d1e8c":
|
||||
extra_kwargs = {"num_joint_blocks": 4, "num_single_blocks": 1}
|
||||
else:
|
||||
extra_kwargs = {}
|
||||
return state_dict_, extra_kwargs
|
||||
|
||||
|
||||
def from_civitai(self, state_dict):
|
||||
return self.from_diffusers(state_dict)
|
||||
@@ -1,8 +1,15 @@
|
||||
import torch
|
||||
from .sd3_dit import TimestepEmbeddings, AdaLayerNorm
|
||||
from .sd3_dit import TimestepEmbeddings, AdaLayerNorm, RMSNorm
|
||||
from einops import rearrange
|
||||
from .tiler import TileWorker
|
||||
from .utils import init_weights_on_device
|
||||
|
||||
def interact_with_ipadapter(hidden_states, q, ip_k, ip_v, scale=1.0):
|
||||
batch_size, num_tokens = hidden_states.shape[0:2]
|
||||
ip_hidden_states = torch.nn.functional.scaled_dot_product_attention(q, ip_k, ip_v)
|
||||
ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, num_tokens, -1)
|
||||
hidden_states = hidden_states + scale * ip_hidden_states
|
||||
return hidden_states
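
interact_with_ipadapter runs a second scaled-dot-product attention with the existing queries against the IP-Adapter keys and values and adds the result, scaled, onto the regular attention output. A shape-level sketch; all dimensions are illustrative:

    import torch
    from diffsynth.models.flux_dit import interact_with_ipadapter  # defined just above

    batch, heads, tokens, head_dim = 1, 24, 4096, 128
    ip_tokens = 4  # IP-Adapter image-prompt tokens

    hidden_states = torch.randn(batch, tokens, heads * head_dim)  # output of the regular attention
    q = torch.randn(batch, heads, tokens, head_dim)               # queries reused from that attention
    ip_k = torch.randn(batch, heads, ip_tokens, head_dim)
    ip_v = torch.randn(batch, heads, ip_tokens, head_dim)

    out = interact_with_ipadapter(hidden_states, q, ip_k, ip_v, scale=0.7)
    # out.shape == (1, 4096, 3072): the IP-Adapter branch is blended in at 0.7 strength
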
|
||||
|
||||
|
||||
class RoPEEmbedding(torch.nn.Module):
|
||||
@@ -33,23 +40,8 @@ class RoPEEmbedding(torch.nn.Module):
|
||||
n_axes = ids.shape[-1]
|
||||
emb = torch.cat([self.rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)], dim=-3)
|
||||
return emb.unsqueeze(1)
|
||||
|
||||
|
||||
|
||||
class RMSNorm(torch.nn.Module):
|
||||
def __init__(self, dim, eps):
|
||||
super().__init__()
|
||||
self.weight = torch.nn.Parameter(torch.ones((dim,)))
|
||||
self.eps = eps
|
||||
|
||||
def forward(self, hidden_states):
|
||||
input_dtype = hidden_states.dtype
|
||||
variance = hidden_states.to(torch.float32).square().mean(-1, keepdim=True)
|
||||
hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
|
||||
hidden_states = hidden_states.to(input_dtype) * self.weight
|
||||
return hidden_states
|
||||
|
||||
|
||||
|
||||
class FluxJointAttention(torch.nn.Module):
|
||||
def __init__(self, dim_a, dim_b, num_heads, head_dim, only_out_a=False):
|
||||
@@ -78,8 +70,7 @@ class FluxJointAttention(torch.nn.Module):
|
||||
xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
|
||||
return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
|
||||
|
||||
|
||||
def forward(self, hidden_states_a, hidden_states_b, image_rotary_emb):
|
||||
def forward(self, hidden_states_a, hidden_states_b, image_rotary_emb, attn_mask=None, ipadapter_kwargs_list=None):
|
||||
batch_size = hidden_states_a.shape[0]
|
||||
|
||||
# Part A
|
||||
@@ -100,17 +91,19 @@ class FluxJointAttention(torch.nn.Module):
|
||||
|
||||
q, k = self.apply_rope(q, k, image_rotary_emb)
|
||||
|
||||
hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v)
|
||||
hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
|
||||
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim)
|
||||
hidden_states = hidden_states.to(q.dtype)
|
||||
hidden_states_b, hidden_states_a = hidden_states[:, :hidden_states_b.shape[1]], hidden_states[:, hidden_states_b.shape[1]:]
|
||||
if ipadapter_kwargs_list is not None:
|
||||
hidden_states_a = interact_with_ipadapter(hidden_states_a, q_a, **ipadapter_kwargs_list)
|
||||
hidden_states_a = self.a_to_out(hidden_states_a)
|
||||
if self.only_out_a:
|
||||
return hidden_states_a
|
||||
else:
|
||||
hidden_states_b = self.b_to_out(hidden_states_b)
|
||||
return hidden_states_a, hidden_states_b
|
||||
|
||||
|
||||
|
||||
|
||||
class FluxJointTransformerBlock(torch.nn.Module):
|
||||
@@ -136,12 +129,12 @@ class FluxJointTransformerBlock(torch.nn.Module):
|
||||
)
|
||||
|
||||
|
||||
def forward(self, hidden_states_a, hidden_states_b, temb, image_rotary_emb):
|
||||
def forward(self, hidden_states_a, hidden_states_b, temb, image_rotary_emb, attn_mask=None, ipadapter_kwargs_list=None):
|
||||
norm_hidden_states_a, gate_msa_a, shift_mlp_a, scale_mlp_a, gate_mlp_a = self.norm1_a(hidden_states_a, emb=temb)
|
||||
norm_hidden_states_b, gate_msa_b, shift_mlp_b, scale_mlp_b, gate_mlp_b = self.norm1_b(hidden_states_b, emb=temb)
|
||||
|
||||
# Attention
|
||||
attn_output_a, attn_output_b = self.attn(norm_hidden_states_a, norm_hidden_states_b, image_rotary_emb)
|
||||
attn_output_a, attn_output_b = self.attn(norm_hidden_states_a, norm_hidden_states_b, image_rotary_emb, attn_mask, ipadapter_kwargs_list)
|
||||
|
||||
# Part A
|
||||
hidden_states_a = hidden_states_a + gate_msa_a * attn_output_a
|
||||
@@ -154,7 +147,7 @@ class FluxJointTransformerBlock(torch.nn.Module):
|
||||
hidden_states_b = hidden_states_b + gate_mlp_b * self.ff_b(norm_hidden_states_b)
|
||||
|
||||
return hidden_states_a, hidden_states_b
|
||||
|
||||
|
||||
|
||||
|
||||
class FluxSingleAttention(torch.nn.Module):
|
||||
@@ -191,7 +184,7 @@ class FluxSingleAttention(torch.nn.Module):
|
||||
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim)
|
||||
hidden_states = hidden_states.to(q.dtype)
|
||||
return hidden_states
|
||||
|
||||
|
||||
|
||||
|
||||
class AdaLayerNormSingle(torch.nn.Module):
|
||||
@@ -207,7 +200,7 @@ class AdaLayerNormSingle(torch.nn.Module):
|
||||
shift_msa, scale_msa, gate_msa = emb.chunk(3, dim=1)
|
||||
x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
|
||||
return x, gate_msa
|
||||
|
||||
|
||||
|
||||
|
||||
class FluxSingleTransformerBlock(torch.nn.Module):
|
||||
@@ -232,8 +225,8 @@ class FluxSingleTransformerBlock(torch.nn.Module):
|
||||
xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
|
||||
return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
|
||||
|
||||
|
||||
def process_attention(self, hidden_states, image_rotary_emb):
|
||||
|
||||
def process_attention(self, hidden_states, image_rotary_emb, attn_mask=None, ipadapter_kwargs_list=None):
|
||||
batch_size = hidden_states.shape[0]
|
||||
|
||||
qkv = hidden_states.view(batch_size, -1, 3 * self.num_heads, self.head_dim).transpose(1, 2)
|
||||
@@ -242,27 +235,29 @@ class FluxSingleTransformerBlock(torch.nn.Module):
|
||||
|
||||
q, k = self.apply_rope(q, k, image_rotary_emb)
|
||||
|
||||
hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v)
|
||||
hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
|
||||
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim)
|
||||
hidden_states = hidden_states.to(q.dtype)
|
||||
if ipadapter_kwargs_list is not None:
|
||||
hidden_states = interact_with_ipadapter(hidden_states, q, **ipadapter_kwargs_list)
|
||||
return hidden_states
|
||||
|
||||
|
||||
def forward(self, hidden_states_a, hidden_states_b, temb, image_rotary_emb):
|
||||
def forward(self, hidden_states_a, hidden_states_b, temb, image_rotary_emb, attn_mask=None, ipadapter_kwargs_list=None):
|
||||
residual = hidden_states_a
|
||||
norm_hidden_states, gate = self.norm(hidden_states_a, emb=temb)
|
||||
hidden_states_a = self.to_qkv_mlp(norm_hidden_states)
|
||||
attn_output, mlp_hidden_states = hidden_states_a[:, :, :self.dim * 3], hidden_states_a[:, :, self.dim * 3:]
|
||||
|
||||
attn_output = self.process_attention(attn_output, image_rotary_emb)
|
||||
attn_output = self.process_attention(attn_output, image_rotary_emb, attn_mask, ipadapter_kwargs_list)
|
||||
mlp_hidden_states = torch.nn.functional.gelu(mlp_hidden_states, approximate="tanh")
|
||||
|
||||
hidden_states_a = torch.cat([attn_output, mlp_hidden_states], dim=2)
|
||||
hidden_states_a = gate.unsqueeze(1) * self.proj_out(hidden_states_a)
|
||||
hidden_states_a = residual + hidden_states_a
|
||||
|
||||
|
||||
return hidden_states_a, hidden_states_b
|
||||
|
||||
|
||||
|
||||
|
||||
class AdaLayerNormContinuous(torch.nn.Module):
|
||||
@@ -281,11 +276,11 @@ class AdaLayerNormContinuous(torch.nn.Module):
|
||||
|
||||
|
||||
class FluxDiT(torch.nn.Module):
|
||||
def __init__(self):
|
||||
def __init__(self, disable_guidance_embedder=False):
|
||||
super().__init__()
|
||||
self.pos_embedder = RoPEEmbedding(3072, 10000, [16, 56, 56])
|
||||
self.time_embedder = TimestepEmbeddings(256, 3072)
|
||||
self.guidance_embedder = TimestepEmbeddings(256, 3072)
|
||||
self.guidance_embedder = None if disable_guidance_embedder else TimestepEmbeddings(256, 3072)
|
||||
self.pooled_text_embedder = torch.nn.Sequential(torch.nn.Linear(768, 3072), torch.nn.SiLU(), torch.nn.Linear(3072, 3072))
|
||||
self.context_embedder = torch.nn.Linear(4096, 3072)
|
||||
self.x_embedder = torch.nn.Linear(64, 3072)
|
||||
@@ -305,7 +300,7 @@ class FluxDiT(torch.nn.Module):
|
||||
def unpatchify(self, hidden_states, height, width):
|
||||
hidden_states = rearrange(hidden_states, "B (H W) (C P Q) -> B C (H P) (W Q)", P=2, Q=2, H=height//2, W=width//2)
|
||||
return hidden_states
|
||||
|
||||
|
||||
|
||||
def prepare_image_ids(self, latents):
|
||||
batch_size, _, height, width = latents.shape
|
||||
@@ -322,7 +317,7 @@ class FluxDiT(torch.nn.Module):
|
||||
latent_image_ids = latent_image_ids.to(device=latents.device, dtype=latents.dtype)
|
||||
|
||||
return latent_image_ids
|
||||
|
||||
|
||||
|
||||
def tiled_forward(
|
||||
self,
|
||||
@@ -343,11 +338,75 @@ class FluxDiT(torch.nn.Module):
|
||||
return hidden_states
|
||||
|
||||
|
||||
def construct_mask(self, entity_masks, prompt_seq_len, image_seq_len):
|
||||
N = len(entity_masks)
|
||||
batch_size = entity_masks[0].shape[0]
|
||||
total_seq_len = N * prompt_seq_len + image_seq_len
|
||||
patched_masks = [self.patchify(entity_masks[i]) for i in range(N)]
|
||||
attention_mask = torch.ones((batch_size, total_seq_len, total_seq_len), dtype=torch.bool).to(device=entity_masks[0].device)
|
||||
|
||||
image_start = N * prompt_seq_len
|
||||
image_end = N * prompt_seq_len + image_seq_len
|
||||
# prompt-image mask
|
||||
for i in range(N):
|
||||
prompt_start = i * prompt_seq_len
|
||||
prompt_end = (i + 1) * prompt_seq_len
|
||||
image_mask = torch.sum(patched_masks[i], dim=-1) > 0
|
||||
image_mask = image_mask.unsqueeze(1).repeat(1, prompt_seq_len, 1)
|
||||
# prompt update with image
|
||||
attention_mask[:, prompt_start:prompt_end, image_start:image_end] = image_mask
|
||||
# image update with prompt
|
||||
attention_mask[:, image_start:image_end, prompt_start:prompt_end] = image_mask.transpose(1, 2)
|
||||
# prompt-prompt mask
|
||||
for i in range(N):
|
||||
for j in range(N):
|
||||
if i != j:
|
||||
prompt_start_i = i * prompt_seq_len
|
||||
prompt_end_i = (i + 1) * prompt_seq_len
|
||||
prompt_start_j = j * prompt_seq_len
|
||||
prompt_end_j = (j + 1) * prompt_seq_len
|
||||
attention_mask[:, prompt_start_i:prompt_end_i, prompt_start_j:prompt_end_j] = False
|
||||
|
||||
attention_mask = attention_mask.float()
|
||||
attention_mask[attention_mask == 0] = float('-inf')
|
||||
attention_mask[attention_mask == 1] = 0
|
||||
return attention_mask
|
||||
|
||||
|
||||
def process_entity_masks(self, hidden_states, prompt_emb, entity_prompt_emb, entity_masks, text_ids, image_ids):
|
||||
repeat_dim = hidden_states.shape[1]
|
||||
max_masks = 0
|
||||
attention_mask = None
|
||||
prompt_embs = [prompt_emb]
|
||||
if entity_masks is not None:
|
||||
# entity_masks
|
||||
batch_size, max_masks = entity_masks.shape[0], entity_masks.shape[1]
|
||||
entity_masks = entity_masks.repeat(1, 1, repeat_dim, 1, 1)
|
||||
entity_masks = [entity_masks[:, i, None].squeeze(1) for i in range(max_masks)]
|
||||
# global mask
|
||||
global_mask = torch.ones_like(entity_masks[0]).to(device=hidden_states.device, dtype=hidden_states.dtype)
|
||||
entity_masks = entity_masks + [global_mask] # append global to last
|
||||
# attention mask
|
||||
attention_mask = self.construct_mask(entity_masks, prompt_emb.shape[1], hidden_states.shape[1])
|
||||
attention_mask = attention_mask.to(device=hidden_states.device, dtype=hidden_states.dtype)
|
||||
attention_mask = attention_mask.unsqueeze(1)
|
||||
# embds: n_masks * b * seq * d
|
||||
local_embs = [entity_prompt_emb[:, i, None].squeeze(1) for i in range(max_masks)]
|
||||
prompt_embs = local_embs + prompt_embs # append global to last
|
||||
prompt_embs = [self.context_embedder(prompt_emb) for prompt_emb in prompt_embs]
|
||||
prompt_emb = torch.cat(prompt_embs, dim=1)
|
||||
|
||||
# positional embedding
|
||||
text_ids = torch.cat([text_ids] * (max_masks + 1), dim=1)
|
||||
image_rotary_emb = self.pos_embedder(torch.cat((text_ids, image_ids), dim=1))
|
||||
return prompt_emb, image_rotary_emb, attention_mask
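
construct_mask builds one boolean attention mask over every prompt segment plus the image tokens: each entity prompt may attend only to the image patches its mask covers, different entity prompts never attend to each other, and blocked positions become -inf before softmax. The sizes involved for one concrete, purely illustrative configuration:

    # Two entity prompts plus the appended global prompt, for a 1024x1024 image:
    # the FLUX VAE downsamples by 8 (128x128 latent) and patchify(P=Q=2) gives 64x64 tokens.
    N = 3                        # 2 local prompts + 1 global prompt appended last
    prompt_seq_len = 512         # a typical T5 sequence length (illustrative)
    image_seq_len = 64 * 64      # 4096 image tokens
    total_seq_len = N * prompt_seq_len + image_seq_len  # 5632 rows/columns in the mask
    # prompt i <-> image: allowed only where entity mask i covers the patch
    # prompt i <-> prompt j (i != j): blocked
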


    def forward(
        self,
        hidden_states,
        timestep, prompt_emb, pooled_prompt_emb, guidance, text_ids, image_ids=None,
-        tiled=False, tile_size=128, tile_stride=64,
+        tiled=False, tile_size=128, tile_stride=64, entity_prompt_emb=None, entity_masks=None,
        use_gradient_checkpointing=False,
        **kwargs
    ):
@@ -358,45 +417,51 @@ class FluxDiT(torch.nn.Module):
                tile_size=tile_size, tile_stride=tile_stride,
                **kwargs
            )


        if image_ids is None:
            image_ids = self.prepare_image_ids(hidden_states)

-        conditioning = self.time_embedder(timestep, hidden_states.dtype)\
-            + self.guidance_embedder(guidance, hidden_states.dtype)\
-            + self.pooled_text_embedder(pooled_prompt_emb)
-        prompt_emb = self.context_embedder(prompt_emb)
-        image_rotary_emb = self.pos_embedder(torch.cat((text_ids, image_ids), dim=1))

+        conditioning = self.time_embedder(timestep, hidden_states.dtype) + self.pooled_text_embedder(pooled_prompt_emb)
+        if self.guidance_embedder is not None:
+            guidance = guidance * 1000
+            conditioning = conditioning + self.guidance_embedder(guidance, hidden_states.dtype)

        height, width = hidden_states.shape[-2:]
        hidden_states = self.patchify(hidden_states)
        hidden_states = self.x_embedder(hidden_states)


+        if entity_prompt_emb is not None and entity_masks is not None:
+            prompt_emb, image_rotary_emb, attention_mask = self.process_entity_masks(hidden_states, prompt_emb, entity_prompt_emb, entity_masks, text_ids, image_ids)
+        else:
+            prompt_emb = self.context_embedder(prompt_emb)
+            image_rotary_emb = self.pos_embedder(torch.cat((text_ids, image_ids), dim=1))
+            attention_mask = None

        def create_custom_forward(module):
            def custom_forward(*inputs):
                return module(*inputs)
            return custom_forward


        for block in self.blocks:
            if self.training and use_gradient_checkpointing:
                hidden_states, prompt_emb = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(block),
-                    hidden_states, prompt_emb, conditioning, image_rotary_emb,
+                    hidden_states, prompt_emb, conditioning, image_rotary_emb, attention_mask,
                    use_reentrant=False,
                )
            else:
-                hidden_states, prompt_emb = block(hidden_states, prompt_emb, conditioning, image_rotary_emb)
+                hidden_states, prompt_emb = block(hidden_states, prompt_emb, conditioning, image_rotary_emb, attention_mask)

        hidden_states = torch.cat([prompt_emb, hidden_states], dim=1)
        for block in self.single_blocks:
            if self.training and use_gradient_checkpointing:
                hidden_states, prompt_emb = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(block),
-                    hidden_states, prompt_emb, conditioning, image_rotary_emb,
+                    hidden_states, prompt_emb, conditioning, image_rotary_emb, attention_mask,
                    use_reentrant=False,
                )
            else:
-                hidden_states, prompt_emb = block(hidden_states, prompt_emb, conditioning, image_rotary_emb)
+                hidden_states, prompt_emb = block(hidden_states, prompt_emb, conditioning, image_rotary_emb, attention_mask)
        hidden_states = hidden_states[:, prompt_emb.shape[1]:]

        hidden_states = self.final_norm_out(hidden_states, conditioning)
@@ -406,10 +471,87 @@ class FluxDiT(torch.nn.Module):
        return hidden_states


    def quantize(self):
        def cast_to(weight, dtype=None, device=None, copy=False):
            if device is None or weight.device == device:
                if not copy:
                    if dtype is None or weight.dtype == dtype:
                        return weight
                return weight.to(dtype=dtype, copy=copy)

            r = torch.empty_like(weight, dtype=dtype, device=device)
            r.copy_(weight)
            return r

        def cast_weight(s, input=None, dtype=None, device=None):
            if input is not None:
                if dtype is None:
                    dtype = input.dtype
                if device is None:
                    device = input.device
            weight = cast_to(s.weight, dtype, device)
            return weight

        def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None):
            if input is not None:
                if dtype is None:
                    dtype = input.dtype
                if bias_dtype is None:
                    bias_dtype = dtype
                if device is None:
                    device = input.device
            bias = None
            weight = cast_to(s.weight, dtype, device)
            bias = cast_to(s.bias, bias_dtype, device)
            return weight, bias

        class quantized_layer:
            class Linear(torch.nn.Linear):
                def __init__(self, *args, **kwargs):
                    super().__init__(*args, **kwargs)

                def forward(self, input, **kwargs):
                    weight, bias = cast_bias_weight(self, input)
                    return torch.nn.functional.linear(input, weight, bias)

            class RMSNorm(torch.nn.Module):
                def __init__(self, module):
                    super().__init__()
                    self.module = module

                def forward(self, hidden_states, **kwargs):
                    weight = cast_weight(self.module, hidden_states)
                    input_dtype = hidden_states.dtype
                    variance = hidden_states.to(torch.float32).square().mean(-1, keepdim=True)
                    hidden_states = hidden_states * torch.rsqrt(variance + self.module.eps)
                    hidden_states = hidden_states.to(input_dtype) * weight
                    return hidden_states

        def replace_layer(model):
            for name, module in model.named_children():
                if isinstance(module, torch.nn.Linear):
                    with init_weights_on_device():
                        new_layer = quantized_layer.Linear(module.in_features, module.out_features)
                    new_layer.weight = module.weight
                    if module.bias is not None:
                        new_layer.bias = module.bias
                    # del module
                    setattr(model, name, new_layer)
                elif isinstance(module, RMSNorm):
                    if hasattr(module, "quantized"):
                        continue
                    module.quantized = True
                    new_layer = quantized_layer.RMSNorm(module)
                    setattr(model, name, new_layer)
                else:
                    replace_layer(module)

        replace_layer(self)
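    # Illustrative note on quantize() (an assumption about intended usage, not an API
    # guarantee): weights may be stored in a low-precision dtype such as torch.float8_e4m3fn,
    # and the wrapped Linear/RMSNorm layers above cast them to the activation dtype and
    # device on the fly inside each forward call, e.g. (hypothetical variable names):
    #
    #   dit.load_state_dict(fp8_state_dict, assign=True)   # fp8 storage is an assumption
    #   dit.quantize()
    #   output = dit(latents, timestep, ...)                # compute runs in the activation dtype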


    @staticmethod
    def state_dict_converter():
        return FluxDiTStateDictConverter()


class FluxDiTStateDictConverter:
@@ -513,7 +655,7 @@ class FluxDiTStateDictConverter:
                    state_dict_.pop(name.replace(f".{component}_to_q.", f".{component}_to_k."))
                    state_dict_.pop(name.replace(f".{component}_to_q.", f".{component}_to_v."))
        return state_dict_


    def from_civitai(self, state_dict):
        rename_dict = {
            "time_in.in_layer.bias": "time_embedder.timestep_embedder.0.bias",
@@ -574,6 +716,8 @@ class FluxDiTStateDictConverter:
        }
        state_dict_ = {}
        for name, param in state_dict.items():
+            if name.startswith("model.diffusion_model."):
+                name = name[len("model.diffusion_model."):]
            names = name.split(".")
            if name in rename_dict:
                rename = rename_dict[name]
@@ -589,5 +733,7 @@ class FluxDiTStateDictConverter:
                    state_dict_[rename] = param
            else:
                pass
-        return state_dict_

+        if "guidance_embedder.timestep_embedder.0.weight" not in state_dict_:
+            return state_dict_, {"disable_guidance_embedder": True}
+        else:
+            return state_dict_

94  diffsynth/models/flux_ipadapter.py  Normal file
@@ -0,0 +1,94 @@
from .svd_image_encoder import SVDImageEncoder
from .sd3_dit import RMSNorm
from transformers import CLIPImageProcessor
import torch


class MLPProjModel(torch.nn.Module):
    def __init__(self, cross_attention_dim=768, id_embeddings_dim=512, num_tokens=4):
        super().__init__()

        self.cross_attention_dim = cross_attention_dim
        self.num_tokens = num_tokens

        self.proj = torch.nn.Sequential(
            torch.nn.Linear(id_embeddings_dim, id_embeddings_dim*2),
            torch.nn.GELU(),
            torch.nn.Linear(id_embeddings_dim*2, cross_attention_dim*num_tokens),
        )
        self.norm = torch.nn.LayerNorm(cross_attention_dim)

    def forward(self, id_embeds):
        x = self.proj(id_embeds)
        x = x.reshape(-1, self.num_tokens, self.cross_attention_dim)
        x = self.norm(x)
        return x


class IpAdapterModule(torch.nn.Module):
    def __init__(self, num_attention_heads, attention_head_dim, input_dim):
        super().__init__()
        self.num_heads = num_attention_heads
        self.head_dim = attention_head_dim
        output_dim = num_attention_heads * attention_head_dim
        self.to_k_ip = torch.nn.Linear(input_dim, output_dim, bias=False)
        self.to_v_ip = torch.nn.Linear(input_dim, output_dim, bias=False)
        self.norm_added_k = RMSNorm(attention_head_dim, eps=1e-5, elementwise_affine=False)


    def forward(self, hidden_states):
        batch_size = hidden_states.shape[0]
        # ip_k
        ip_k = self.to_k_ip(hidden_states)
        ip_k = ip_k.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        ip_k = self.norm_added_k(ip_k)
        # ip_v
        ip_v = self.to_v_ip(hidden_states)
        ip_v = ip_v.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        return ip_k, ip_v


class FluxIpAdapter(torch.nn.Module):
    def __init__(self, num_attention_heads=24, attention_head_dim=128, cross_attention_dim=4096, num_tokens=128, num_blocks=57):
        super().__init__()
        self.ipadapter_modules = torch.nn.ModuleList([IpAdapterModule(num_attention_heads, attention_head_dim, cross_attention_dim) for _ in range(num_blocks)])
        self.image_proj = MLPProjModel(cross_attention_dim=cross_attention_dim, id_embeddings_dim=1152, num_tokens=num_tokens)
        self.set_adapter()

    def set_adapter(self):
        self.call_block_id = {i: i for i in range(len(self.ipadapter_modules))}

    def forward(self, hidden_states, scale=1.0):
        hidden_states = self.image_proj(hidden_states)
        hidden_states = hidden_states.view(1, -1, hidden_states.shape[-1])
        ip_kv_dict = {}
        for block_id in self.call_block_id:
            ipadapter_id = self.call_block_id[block_id]
            ip_k, ip_v = self.ipadapter_modules[ipadapter_id](hidden_states)
            ip_kv_dict[block_id] = {
                "ip_k": ip_k,
                "ip_v": ip_v,
                "scale": scale
            }
        return ip_kv_dict

    @staticmethod
    def state_dict_converter():
        return FluxIpAdapterStateDictConverter()


class FluxIpAdapterStateDictConverter:
    def __init__(self):
        pass

    def from_diffusers(self, state_dict):
        state_dict_ = {}
        for name in state_dict["ip_adapter"]:
            name_ = 'ipadapter_modules.' + name
            state_dict_[name_] = state_dict["ip_adapter"][name]
        for name in state_dict["image_proj"]:
            name_ = "image_proj." + name
            state_dict_[name_] = state_dict["image_proj"][name]
        return state_dict_

    def from_civitai(self, state_dict):
        return self.from_diffusers(state_dict)
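
# Minimal usage sketch (a deliberately tiny configuration chosen only for illustration;
# the real adapter defaults to 24 heads x 128 dims, 128 tokens and 57 blocks, and the
# 1152-dim image embedding is an assumption about the upstream image encoder):
def _demo_flux_ipadapter():
    adapter = FluxIpAdapter(num_attention_heads=2, attention_head_dim=8,
                            cross_attention_dim=16, num_tokens=4, num_blocks=3)
    image_emb = torch.randn(1, 1152)                     # hypothetical pooled image embedding
    ip_kv_dict = adapter(image_emb, scale=0.7)
    assert len(ip_kv_dict) == 3                          # one key/value pair per DiT block
    assert ip_kv_dict[0]["ip_k"].shape == (1, 2, 4, 8)   # (batch, heads, tokens, head_dim)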

@@ -3,26 +3,6 @@ from transformers import T5EncoderModel, T5Config
from .sd_text_encoder import SDTextEncoder


class FluxTextEncoder1(SDTextEncoder):
    def __init__(self, vocab_size=49408):
        super().__init__(vocab_size=vocab_size)

    def forward(self, input_ids, clip_skip=2):
        embeds = self.token_embedding(input_ids) + self.position_embeds
        attn_mask = self.attn_mask.to(device=embeds.device, dtype=embeds.dtype)
        for encoder_id, encoder in enumerate(self.encoders):
            embeds = encoder(embeds, attn_mask=attn_mask)
            if encoder_id + clip_skip == len(self.encoders):
                hidden_states = embeds
        embeds = self.final_layer_norm(embeds)
        pooled_embeds = embeds[torch.arange(embeds.shape[0]), input_ids.to(dtype=torch.int).argmax(dim=-1)]
        return embeds, pooled_embeds

    @staticmethod
    def state_dict_converter():
        return FluxTextEncoder1StateDictConverter()


class FluxTextEncoder2(T5EncoderModel):
    def __init__(self, config):
@@ -40,47 +20,6 @@ class FluxTextEncoder2(T5EncoderModel):


class FluxTextEncoder1StateDictConverter:
    def __init__(self):
        pass

    def from_diffusers(self, state_dict):
        rename_dict = {
            "text_model.embeddings.token_embedding.weight": "token_embedding.weight",
            "text_model.embeddings.position_embedding.weight": "position_embeds",
            "text_model.final_layer_norm.weight": "final_layer_norm.weight",
            "text_model.final_layer_norm.bias": "final_layer_norm.bias"
        }
        attn_rename_dict = {
            "self_attn.q_proj": "attn.to_q",
            "self_attn.k_proj": "attn.to_k",
            "self_attn.v_proj": "attn.to_v",
            "self_attn.out_proj": "attn.to_out",
            "layer_norm1": "layer_norm1",
            "layer_norm2": "layer_norm2",
            "mlp.fc1": "fc1",
            "mlp.fc2": "fc2",
        }
        state_dict_ = {}
        for name in state_dict:
            if name in rename_dict:
                param = state_dict[name]
                if name == "text_model.embeddings.position_embedding.weight":
                    param = param.reshape((1, param.shape[0], param.shape[1]))
                state_dict_[rename_dict[name]] = param
            elif name.startswith("text_model.encoder.layers."):
                param = state_dict[name]
                names = name.split(".")
                layer_id, layer_type, tail = names[3], ".".join(names[4:-1]), names[-1]
                name_ = ".".join(["encoders", layer_id, attn_rename_dict[layer_type], tail])
                state_dict_[name_] = param
        return state_dict_

    def from_civitai(self, state_dict):
        return self.from_diffusers(state_dict)



class FluxTextEncoder2StateDictConverter:
    def __init__(self):
        pass

885  diffsynth/models/hunyuan_video_dit.py  Normal file
@@ -0,0 +1,885 @@
|
||||
import torch
|
||||
from .sd3_dit import TimestepEmbeddings, RMSNorm
|
||||
from .utils import init_weights_on_device
|
||||
from einops import rearrange, repeat
|
||||
from tqdm import tqdm
|
||||
from typing import Union, Tuple, List
|
||||
|
||||
|
||||
def HunyuanVideoRope(latents):
|
||||
def _to_tuple(x, dim=2):
|
||||
if isinstance(x, int):
|
||||
return (x,) * dim
|
||||
elif len(x) == dim:
|
||||
return x
|
||||
else:
|
||||
raise ValueError(f"Expected length {dim} or int, but got {x}")
|
||||
|
||||
|
||||
def get_meshgrid_nd(start, *args, dim=2):
|
||||
"""
|
||||
Get n-D meshgrid with start, stop and num.
|
||||
|
||||
Args:
|
||||
start (int or tuple): If len(args) == 0, start is num; If len(args) == 1, start is start, args[0] is stop,
|
||||
step is 1; If len(args) == 2, start is start, args[0] is stop, args[1] is num. For n-dim, start/stop/num
|
||||
should be int or n-tuple. If n-tuple is provided, the meshgrid will be stacked following the dim order in
|
||||
n-tuples.
|
||||
*args: See above.
|
||||
dim (int): Dimension of the meshgrid. Defaults to 2.
|
||||
|
||||
Returns:
|
||||
grid (np.ndarray): [dim, ...]
|
||||
"""
|
||||
if len(args) == 0:
|
||||
# start is grid_size
|
||||
num = _to_tuple(start, dim=dim)
|
||||
start = (0,) * dim
|
||||
stop = num
|
||||
elif len(args) == 1:
|
||||
# start is start, args[0] is stop, step is 1
|
||||
start = _to_tuple(start, dim=dim)
|
||||
stop = _to_tuple(args[0], dim=dim)
|
||||
num = [stop[i] - start[i] for i in range(dim)]
|
||||
elif len(args) == 2:
|
||||
# start is start, args[0] is stop, args[1] is num
|
||||
start = _to_tuple(start, dim=dim) # Left-Top eg: 12,0
|
||||
stop = _to_tuple(args[0], dim=dim) # Right-Bottom eg: 20,32
|
||||
num = _to_tuple(args[1], dim=dim) # Target Size eg: 32,124
|
||||
else:
|
||||
raise ValueError(f"len(args) should be 0, 1 or 2, but got {len(args)}")
|
||||
|
||||
# PyTorch implement of np.linspace(start[i], stop[i], num[i], endpoint=False)
|
||||
axis_grid = []
|
||||
for i in range(dim):
|
||||
a, b, n = start[i], stop[i], num[i]
|
||||
g = torch.linspace(a, b, n + 1, dtype=torch.float32)[:n]
|
||||
axis_grid.append(g)
|
||||
grid = torch.meshgrid(*axis_grid, indexing="ij") # dim x [W, H, D]
|
||||
grid = torch.stack(grid, dim=0) # [dim, W, H, D]
|
||||
|
||||
return grid
|
||||
|
||||
|
||||
def get_1d_rotary_pos_embed(
|
||||
dim: int,
|
||||
pos: Union[torch.FloatTensor, int],
|
||||
theta: float = 10000.0,
|
||||
use_real: bool = False,
|
||||
theta_rescale_factor: float = 1.0,
|
||||
interpolation_factor: float = 1.0,
|
||||
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
|
||||
"""
|
||||
Precompute the frequency tensor for complex exponential (cis) with given dimensions.
|
||||
(Note: `cis` means `cos + i * sin`, where i is the imaginary unit.)
|
||||
|
||||
This function calculates a frequency tensor with complex exponential using the given dimension 'dim'
|
||||
and the end index 'end'. The 'theta' parameter scales the frequencies.
|
||||
The returned tensor contains complex values in complex64 data type.
|
||||
|
||||
Args:
|
||||
dim (int): Dimension of the frequency tensor.
|
||||
pos (int or torch.FloatTensor): Position indices for the frequency tensor. [S] or scalar
|
||||
theta (float, optional): Scaling factor for frequency computation. Defaults to 10000.0.
|
||||
use_real (bool, optional): If True, return real part and imaginary part separately.
|
||||
Otherwise, return complex numbers.
|
||||
theta_rescale_factor (float, optional): Rescale factor for theta. Defaults to 1.0.
|
||||
|
||||
Returns:
|
||||
freqs_cis: Precomputed frequency tensor with complex exponential. [S, D/2]
|
||||
freqs_cos, freqs_sin: Precomputed frequency tensor with real and imaginary parts separately. [S, D]
|
||||
"""
|
||||
if isinstance(pos, int):
|
||||
pos = torch.arange(pos).float()
|
||||
|
||||
# proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning
|
||||
# has some connection to NTK literature
|
||||
if theta_rescale_factor != 1.0:
|
||||
theta *= theta_rescale_factor ** (dim / (dim - 2))
|
||||
|
||||
freqs = 1.0 / (
|
||||
theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)
|
||||
) # [D/2]
|
||||
# assert interpolation_factor == 1.0, f"interpolation_factor: {interpolation_factor}"
|
||||
freqs = torch.outer(pos * interpolation_factor, freqs) # [S, D/2]
|
||||
if use_real:
|
||||
freqs_cos = freqs.cos().repeat_interleave(2, dim=1) # [S, D]
|
||||
freqs_sin = freqs.sin().repeat_interleave(2, dim=1) # [S, D]
|
||||
return freqs_cos, freqs_sin
|
||||
else:
|
||||
freqs_cis = torch.polar(
|
||||
torch.ones_like(freqs), freqs
|
||||
) # complex64 # [S, D/2]
|
||||
return freqs_cis
|
||||
|
||||
|
||||
def get_nd_rotary_pos_embed(
|
||||
rope_dim_list,
|
||||
start,
|
||||
*args,
|
||||
theta=10000.0,
|
||||
use_real=False,
|
||||
theta_rescale_factor: Union[float, List[float]] = 1.0,
|
||||
interpolation_factor: Union[float, List[float]] = 1.0,
|
||||
):
|
||||
"""
|
||||
This is a n-d version of precompute_freqs_cis, which is a RoPE for tokens with n-d structure.
|
||||
|
||||
Args:
|
||||
rope_dim_list (list of int): Dimension of each rope. len(rope_dim_list) should equal to n.
|
||||
sum(rope_dim_list) should equal to head_dim of attention layer.
|
||||
start (int | tuple of int | list of int): If len(args) == 0, start is num; If len(args) == 1, start is start,
|
||||
args[0] is stop, step is 1; If len(args) == 2, start is start, args[0] is stop, args[1] is num.
|
||||
*args: See above.
|
||||
theta (float): Scaling factor for frequency computation. Defaults to 10000.0.
|
||||
use_real (bool): If True, return real part and imaginary part separately. Otherwise, return complex numbers.
|
||||
Some libraries such as TensorRT do not support the complex64 data type, so it is useful to provide a real
|
||||
part and an imaginary part separately.
|
||||
theta_rescale_factor (float): Rescale factor for theta. Defaults to 1.0.
|
||||
|
||||
Returns:
|
||||
pos_embed (torch.Tensor): [HW, D/2]
|
||||
"""
|
||||
|
||||
grid = get_meshgrid_nd(
|
||||
start, *args, dim=len(rope_dim_list)
|
||||
) # [3, W, H, D] / [2, W, H]
|
||||
|
||||
if isinstance(theta_rescale_factor, int) or isinstance(theta_rescale_factor, float):
|
||||
theta_rescale_factor = [theta_rescale_factor] * len(rope_dim_list)
|
||||
elif isinstance(theta_rescale_factor, list) and len(theta_rescale_factor) == 1:
|
||||
theta_rescale_factor = [theta_rescale_factor[0]] * len(rope_dim_list)
|
||||
assert len(theta_rescale_factor) == len(
|
||||
rope_dim_list
|
||||
), "len(theta_rescale_factor) should equal to len(rope_dim_list)"
|
||||
|
||||
if isinstance(interpolation_factor, int) or isinstance(interpolation_factor, float):
|
||||
interpolation_factor = [interpolation_factor] * len(rope_dim_list)
|
||||
elif isinstance(interpolation_factor, list) and len(interpolation_factor) == 1:
|
||||
interpolation_factor = [interpolation_factor[0]] * len(rope_dim_list)
|
||||
assert len(interpolation_factor) == len(
|
||||
rope_dim_list
|
||||
), "len(interpolation_factor) should equal to len(rope_dim_list)"
|
||||
|
||||
# use 1/ndim of dimensions to encode grid_axis
|
||||
embs = []
|
||||
for i in range(len(rope_dim_list)):
|
||||
emb = get_1d_rotary_pos_embed(
|
||||
rope_dim_list[i],
|
||||
grid[i].reshape(-1),
|
||||
theta,
|
||||
use_real=use_real,
|
||||
theta_rescale_factor=theta_rescale_factor[i],
|
||||
interpolation_factor=interpolation_factor[i],
|
||||
) # 2 x [WHD, rope_dim_list[i]]
|
||||
embs.append(emb)
|
||||
|
||||
if use_real:
|
||||
cos = torch.cat([emb[0] for emb in embs], dim=1) # (WHD, D/2)
|
||||
sin = torch.cat([emb[1] for emb in embs], dim=1) # (WHD, D/2)
|
||||
return cos, sin
|
||||
else:
|
||||
emb = torch.cat(embs, dim=1) # (WHD, D/2)
|
||||
return emb
|
||||
|
||||
freqs_cos, freqs_sin = get_nd_rotary_pos_embed(
|
||||
[16, 56, 56],
|
||||
[latents.shape[2], latents.shape[3] // 2, latents.shape[4] // 2],
|
||||
theta=256,
|
||||
use_real=True,
|
||||
theta_rescale_factor=1,
|
||||
)
|
||||
return freqs_cos, freqs_sin
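

# Illustrative sanity check for the RoPE helper above (not used by the model): only the
# latent shape matters, and the returned tables have one row per (T, H/2, W/2) position
# with 16 + 56 + 56 = 128 rotary channels, matching the attention head dimension.
def _demo_hunyuan_video_rope():
    latents = torch.zeros(1, 16, 9, 30, 52)             # (B, C, T, H, W), toy sizes
    freqs_cos, freqs_sin = HunyuanVideoRope(latents)
    assert freqs_cos.shape == (9 * 15 * 26, 128)         # 3510 positions x 128 channels
    assert freqs_sin.shape == freqs_cos.shape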
|
||||
|
||||
|
||||
class PatchEmbed(torch.nn.Module):
|
||||
def __init__(self, patch_size=(1, 2, 2), in_channels=16, embed_dim=3072):
|
||||
super().__init__()
|
||||
self.proj = torch.nn.Conv3d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.proj(x)
|
||||
x = x.flatten(2).transpose(1, 2)
|
||||
return x
|
||||
|
||||
|
||||
class IndividualTokenRefinerBlock(torch.nn.Module):
|
||||
def __init__(self, hidden_size=3072, num_heads=24):
|
||||
super().__init__()
|
||||
self.num_heads = num_heads
|
||||
self.norm1 = torch.nn.LayerNorm(hidden_size, elementwise_affine=True, eps=1e-6)
|
||||
self.self_attn_qkv = torch.nn.Linear(hidden_size, hidden_size * 3)
|
||||
self.self_attn_proj = torch.nn.Linear(hidden_size, hidden_size)
|
||||
|
||||
self.norm2 = torch.nn.LayerNorm(hidden_size, elementwise_affine=True, eps=1e-6)
|
||||
self.mlp = torch.nn.Sequential(
|
||||
torch.nn.Linear(hidden_size, hidden_size * 4),
|
||||
torch.nn.SiLU(),
|
||||
torch.nn.Linear(hidden_size * 4, hidden_size)
|
||||
)
|
||||
self.adaLN_modulation = torch.nn.Sequential(
|
||||
torch.nn.SiLU(),
|
||||
torch.nn.Linear(hidden_size, hidden_size * 2, device="cuda", dtype=torch.bfloat16),
|
||||
)
|
||||
|
||||
def forward(self, x, c, attn_mask=None):
|
||||
gate_msa, gate_mlp = self.adaLN_modulation(c).chunk(2, dim=1)
|
||||
|
||||
norm_x = self.norm1(x)
|
||||
qkv = self.self_attn_qkv(norm_x)
|
||||
q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
|
||||
|
||||
attn = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
|
||||
attn = rearrange(attn, "B H L D -> B L (H D)")
|
||||
|
||||
x = x + self.self_attn_proj(attn) * gate_msa.unsqueeze(1)
|
||||
x = x + self.mlp(self.norm2(x)) * gate_mlp.unsqueeze(1)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class SingleTokenRefiner(torch.nn.Module):
|
||||
def __init__(self, in_channels=4096, hidden_size=3072, depth=2):
|
||||
super().__init__()
|
||||
self.input_embedder = torch.nn.Linear(in_channels, hidden_size, bias=True)
|
||||
self.t_embedder = TimestepEmbeddings(256, hidden_size, computation_device="cpu")
|
||||
self.c_embedder = torch.nn.Sequential(
|
||||
torch.nn.Linear(in_channels, hidden_size),
|
||||
torch.nn.SiLU(),
|
||||
torch.nn.Linear(hidden_size, hidden_size)
|
||||
)
|
||||
self.blocks = torch.nn.ModuleList([IndividualTokenRefinerBlock(hidden_size=hidden_size) for _ in range(depth)])
|
||||
|
||||
def forward(self, x, t, mask=None):
|
||||
timestep_aware_representations = self.t_embedder(t, dtype=torch.float32)
|
||||
|
||||
mask_float = mask.float().unsqueeze(-1)
|
||||
context_aware_representations = (x * mask_float).sum(dim=1) / mask_float.sum(dim=1)
|
||||
context_aware_representations = self.c_embedder(context_aware_representations)
|
||||
c = timestep_aware_representations + context_aware_representations
|
||||
|
||||
x = self.input_embedder(x)
|
||||
|
||||
mask = mask.to(device=x.device, dtype=torch.bool)
|
||||
mask = repeat(mask, "B L -> B 1 D L", D=mask.shape[-1])
|
||||
mask = mask & mask.transpose(2, 3)
|
||||
mask[:, :, :, 0] = True
|
||||
|
||||
for block in self.blocks:
|
||||
x = block(x, c, mask)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class ModulateDiT(torch.nn.Module):
|
||||
def __init__(self, hidden_size, factor=6):
|
||||
super().__init__()
|
||||
self.act = torch.nn.SiLU()
|
||||
self.linear = torch.nn.Linear(hidden_size, factor * hidden_size)
|
||||
|
||||
def forward(self, x):
|
||||
return self.linear(self.act(x))
|
||||
|
||||
|
||||
def modulate(x, shift=None, scale=None):
|
||||
if scale is None and shift is None:
|
||||
return x
|
||||
elif shift is None:
|
||||
return x * (1 + scale.unsqueeze(1))
|
||||
elif scale is None:
|
||||
return x + shift.unsqueeze(1)
|
||||
else:
|
||||
return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
|
||||
|
||||
|
||||
def reshape_for_broadcast(
|
||||
freqs_cis,
|
||||
x: torch.Tensor,
|
||||
head_first=False,
|
||||
):
|
||||
ndim = x.ndim
|
||||
assert 0 <= 1 < ndim
|
||||
|
||||
if isinstance(freqs_cis, tuple):
|
||||
# freqs_cis: (cos, sin) in real space
|
||||
if head_first:
|
||||
assert freqs_cis[0].shape == (
|
||||
x.shape[-2],
|
||||
x.shape[-1],
|
||||
), f"freqs_cis shape {freqs_cis[0].shape} does not match x shape {x.shape}"
|
||||
shape = [
|
||||
d if i == ndim - 2 or i == ndim - 1 else 1
|
||||
for i, d in enumerate(x.shape)
|
||||
]
|
||||
else:
|
||||
assert freqs_cis[0].shape == (
|
||||
x.shape[1],
|
||||
x.shape[-1],
|
||||
), f"freqs_cis shape {freqs_cis[0].shape} does not match x shape {x.shape}"
|
||||
shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
|
||||
return freqs_cis[0].view(*shape), freqs_cis[1].view(*shape)
|
||||
else:
|
||||
# freqs_cis: values in complex space
|
||||
if head_first:
|
||||
assert freqs_cis.shape == (
|
||||
x.shape[-2],
|
||||
x.shape[-1],
|
||||
), f"freqs_cis shape {freqs_cis.shape} does not match x shape {x.shape}"
|
||||
shape = [
|
||||
d if i == ndim - 2 or i == ndim - 1 else 1
|
||||
for i, d in enumerate(x.shape)
|
||||
]
|
||||
else:
|
||||
assert freqs_cis.shape == (
|
||||
x.shape[1],
|
||||
x.shape[-1],
|
||||
), f"freqs_cis shape {freqs_cis.shape} does not match x shape {x.shape}"
|
||||
shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
|
||||
return freqs_cis.view(*shape)
|
||||
|
||||
|
||||
def rotate_half(x):
|
||||
x_real, x_imag = (
|
||||
x.float().reshape(*x.shape[:-1], -1, 2).unbind(-1)
|
||||
) # [B, S, H, D//2]
|
||||
return torch.stack([-x_imag, x_real], dim=-1).flatten(3)
|
||||
|
||||
|
||||
def apply_rotary_emb(
|
||||
xq: torch.Tensor,
|
||||
xk: torch.Tensor,
|
||||
freqs_cis,
|
||||
head_first: bool = False,
|
||||
):
|
||||
xk_out = None
|
||||
if isinstance(freqs_cis, tuple):
|
||||
cos, sin = reshape_for_broadcast(freqs_cis, xq, head_first) # [S, D]
|
||||
cos, sin = cos.to(xq.device), sin.to(xq.device)
|
||||
# real * cos - imag * sin
|
||||
# imag * cos + real * sin
|
||||
xq_out = (xq.float() * cos + rotate_half(xq.float()) * sin).type_as(xq)
|
||||
xk_out = (xk.float() * cos + rotate_half(xk.float()) * sin).type_as(xk)
|
||||
else:
|
||||
# view_as_complex will pack [..., D/2, 2](real) to [..., D/2](complex)
|
||||
xq_ = torch.view_as_complex(
|
||||
xq.float().reshape(*xq.shape[:-1], -1, 2)
|
||||
) # [B, S, H, D//2]
|
||||
freqs_cis = reshape_for_broadcast(freqs_cis, xq_, head_first).to(
|
||||
xq.device
|
||||
) # [S, D//2] --> [1, S, 1, D//2]
|
||||
# (real, imag) * (cos, sin) = (real * cos - imag * sin, imag * cos + real * sin)
|
||||
# view_as_real will expand [..., D/2](complex) to [..., D/2, 2](real)
|
||||
xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3).type_as(xq)
|
||||
xk_ = torch.view_as_complex(
|
||||
xk.float().reshape(*xk.shape[:-1], -1, 2)
|
||||
) # [B, S, H, D//2]
|
||||
xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3).type_as(xk)
|
||||
|
||||
return xq_out, xk_out
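

# Small illustrative check (not used by the model): rotary embedding is a pure rotation of
# channel pairs, so it preserves the norm of every query/key vector it is applied to.
def _demo_rope_preserves_norm():
    dim, seq = 8, 5
    # Build a real-valued (cos, sin) table the same way the helpers above do.
    freqs = 1.0 / (10000.0 ** (torch.arange(0, dim, 2).float() / dim))      # [D/2]
    freqs = torch.outer(torch.arange(seq).float(), freqs)                   # [S, D/2]
    cos = freqs.cos().repeat_interleave(2, dim=1)                           # [S, D]
    sin = freqs.sin().repeat_interleave(2, dim=1)                           # [S, D]
    xq = torch.randn(1, seq, 2, dim)                                        # (B, S, H, D)
    xk = torch.randn(1, seq, 2, dim)
    xq_out, _ = apply_rotary_emb(xq, xk, (cos, sin), head_first=False)
    assert torch.allclose(xq.norm(dim=-1), xq_out.norm(dim=-1), atol=1e-5)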
|
||||
|
||||
|
||||
def attention(q, k, v):
|
||||
q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
|
||||
x = torch.nn.functional.scaled_dot_product_attention(q, k, v)
|
||||
x = x.transpose(1, 2).flatten(2, 3)
|
||||
return x
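

# Layout note (illustration only): attention() takes q, k, v of shape (B, L, H, D) and
# returns (B, L, H * D); e.g. inputs of shape (2, 7, 3, 8) yield an output of (2, 7, 24).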
|
||||
|
||||
|
||||
class MMDoubleStreamBlockComponent(torch.nn.Module):
|
||||
def __init__(self, hidden_size=3072, heads_num=24, mlp_width_ratio=4):
|
||||
super().__init__()
|
||||
self.heads_num = heads_num
|
||||
|
||||
self.mod = ModulateDiT(hidden_size)
|
||||
self.norm1 = torch.nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
||||
|
||||
self.to_qkv = torch.nn.Linear(hidden_size, hidden_size * 3)
|
||||
self.norm_q = RMSNorm(dim=hidden_size // heads_num, eps=1e-6)
|
||||
self.norm_k = RMSNorm(dim=hidden_size // heads_num, eps=1e-6)
|
||||
self.to_out = torch.nn.Linear(hidden_size, hidden_size)
|
||||
|
||||
self.norm2 = torch.nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
||||
self.ff = torch.nn.Sequential(
|
||||
torch.nn.Linear(hidden_size, hidden_size * mlp_width_ratio),
|
||||
torch.nn.GELU(approximate="tanh"),
|
||||
torch.nn.Linear(hidden_size * mlp_width_ratio, hidden_size)
|
||||
)
|
||||
|
||||
def forward(self, hidden_states, conditioning, freqs_cis=None):
|
||||
mod1_shift, mod1_scale, mod1_gate, mod2_shift, mod2_scale, mod2_gate = self.mod(conditioning).chunk(6, dim=-1)
|
||||
|
||||
norm_hidden_states = self.norm1(hidden_states)
|
||||
norm_hidden_states = modulate(norm_hidden_states, shift=mod1_shift, scale=mod1_scale)
|
||||
qkv = self.to_qkv(norm_hidden_states)
|
||||
q, k, v = rearrange(qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num)
|
||||
|
||||
q = self.norm_q(q)
|
||||
k = self.norm_k(k)
|
||||
|
||||
if freqs_cis is not None:
|
||||
q, k = apply_rotary_emb(q, k, freqs_cis, head_first=False)
|
||||
|
||||
return (q, k, v), (mod1_gate, mod2_shift, mod2_scale, mod2_gate)
|
||||
|
||||
def process_ff(self, hidden_states, attn_output, mod):
|
||||
mod1_gate, mod2_shift, mod2_scale, mod2_gate = mod
|
||||
hidden_states = hidden_states + self.to_out(attn_output) * mod1_gate.unsqueeze(1)
|
||||
hidden_states = hidden_states + self.ff(modulate(self.norm2(hidden_states), shift=mod2_shift, scale=mod2_scale)) * mod2_gate.unsqueeze(1)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class MMDoubleStreamBlock(torch.nn.Module):
|
||||
def __init__(self, hidden_size=3072, heads_num=24, mlp_width_ratio=4):
|
||||
super().__init__()
|
||||
self.component_a = MMDoubleStreamBlockComponent(hidden_size, heads_num, mlp_width_ratio)
|
||||
self.component_b = MMDoubleStreamBlockComponent(hidden_size, heads_num, mlp_width_ratio)
|
||||
|
||||
def forward(self, hidden_states_a, hidden_states_b, conditioning, freqs_cis):
|
||||
(q_a, k_a, v_a), mod_a = self.component_a(hidden_states_a, conditioning, freqs_cis)
|
||||
(q_b, k_b, v_b), mod_b = self.component_b(hidden_states_b, conditioning, freqs_cis=None)
|
||||
|
||||
q_a, q_b = torch.concat([q_a, q_b[:, :71]], dim=1), q_b[:, 71:].contiguous()
|
||||
k_a, k_b = torch.concat([k_a, k_b[:, :71]], dim=1), k_b[:, 71:].contiguous()
|
||||
v_a, v_b = torch.concat([v_a, v_b[:, :71]], dim=1), v_b[:, 71:].contiguous()
|
||||
attn_output_a = attention(q_a, k_a, v_a)
|
||||
attn_output_b = attention(q_b, k_b, v_b)
|
||||
attn_output_a, attn_output_b = attn_output_a[:, :-71].contiguous(), torch.concat([attn_output_a[:, -71:], attn_output_b], dim=1)
|
||||
|
||||
hidden_states_a = self.component_a.process_ff(hidden_states_a, attn_output_a, mod_a)
|
||||
hidden_states_b = self.component_b.process_ff(hidden_states_b, attn_output_b, mod_b)
|
||||
return hidden_states_a, hidden_states_b
|
||||
|
||||
|
||||
class MMSingleStreamBlockOriginal(torch.nn.Module):
|
||||
def __init__(self, hidden_size=3072, heads_num=24, mlp_width_ratio=4):
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
self.heads_num = heads_num
|
||||
self.mlp_hidden_dim = hidden_size * mlp_width_ratio
|
||||
|
||||
self.linear1 = torch.nn.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim)
|
||||
self.linear2 = torch.nn.Linear(hidden_size + self.mlp_hidden_dim, hidden_size)
|
||||
|
||||
self.q_norm = RMSNorm(dim=hidden_size // heads_num, eps=1e-6)
|
||||
self.k_norm = RMSNorm(dim=hidden_size // heads_num, eps=1e-6)
|
||||
|
||||
self.pre_norm = torch.nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
||||
|
||||
self.mlp_act = torch.nn.GELU(approximate="tanh")
|
||||
self.modulation = ModulateDiT(hidden_size, factor=3)
|
||||
|
||||
def forward(self, x, vec, freqs_cis=None, txt_len=256):
|
||||
mod_shift, mod_scale, mod_gate = self.modulation(vec).chunk(3, dim=-1)
|
||||
x_mod = modulate(self.pre_norm(x), shift=mod_shift, scale=mod_scale)
|
||||
qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
|
||||
q, k, v = rearrange(qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num)
|
||||
q = self.q_norm(q)
|
||||
k = self.k_norm(k)
|
||||
|
||||
q_a, q_b = q[:, :-txt_len, :, :], q[:, -txt_len:, :, :]
|
||||
k_a, k_b = k[:, :-txt_len, :, :], k[:, -txt_len:, :, :]
|
||||
q_a, k_a = apply_rotary_emb(q_a, k_a, freqs_cis, head_first=False)
|
||||
q = torch.cat((q_a, q_b), dim=1)
|
||||
k = torch.cat((k_a, k_b), dim=1)
|
||||
|
||||
attn_output_a = attention(q[:, :-185].contiguous(), k[:, :-185].contiguous(), v[:, :-185].contiguous())
|
||||
attn_output_b = attention(q[:, -185:].contiguous(), k[:, -185:].contiguous(), v[:, -185:].contiguous())
|
||||
attn_output = torch.concat([attn_output_a, attn_output_b], dim=1)
|
||||
|
||||
output = self.linear2(torch.cat((attn_output, self.mlp_act(mlp)), 2))
|
||||
return x + output * mod_gate.unsqueeze(1)
|
||||
|
||||
|
||||
class MMSingleStreamBlock(torch.nn.Module):
|
||||
def __init__(self, hidden_size=3072, heads_num=24, mlp_width_ratio=4):
|
||||
super().__init__()
|
||||
self.heads_num = heads_num
|
||||
|
||||
self.mod = ModulateDiT(hidden_size, factor=3)
|
||||
self.norm = torch.nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
||||
|
||||
self.to_qkv = torch.nn.Linear(hidden_size, hidden_size * 3)
|
||||
self.norm_q = RMSNorm(dim=hidden_size // heads_num, eps=1e-6)
|
||||
self.norm_k = RMSNorm(dim=hidden_size // heads_num, eps=1e-6)
|
||||
self.to_out = torch.nn.Linear(hidden_size, hidden_size)
|
||||
|
||||
self.ff = torch.nn.Sequential(
|
||||
torch.nn.Linear(hidden_size, hidden_size * mlp_width_ratio),
|
||||
torch.nn.GELU(approximate="tanh"),
|
||||
torch.nn.Linear(hidden_size * mlp_width_ratio, hidden_size, bias=False)
|
||||
)
|
||||
|
||||
def forward(self, hidden_states, conditioning, freqs_cis=None, txt_len=256):
|
||||
mod_shift, mod_scale, mod_gate = self.mod(conditioning).chunk(3, dim=-1)
|
||||
|
||||
norm_hidden_states = self.norm(hidden_states)
|
||||
norm_hidden_states = modulate(norm_hidden_states, shift=mod_shift, scale=mod_scale)
|
||||
qkv = self.to_qkv(norm_hidden_states)
|
||||
|
||||
q, k, v = rearrange(qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num)
|
||||
|
||||
q = self.norm_q(q)
|
||||
k = self.norm_k(k)
|
||||
|
||||
q_a, q_b = q[:, :-txt_len, :, :], q[:, -txt_len:, :, :]
|
||||
k_a, k_b = k[:, :-txt_len, :, :], k[:, -txt_len:, :, :]
|
||||
q_a, k_a = apply_rotary_emb(q_a, k_a, freqs_cis, head_first=False)
|
||||
|
||||
q_a, q_b = torch.concat([q_a, q_b[:, :71]], dim=1), q_b[:, 71:].contiguous()
|
||||
k_a, k_b = torch.concat([k_a, k_b[:, :71]], dim=1), k_b[:, 71:].contiguous()
|
||||
v_a, v_b = v[:, :-185].contiguous(), v[:, -185:].contiguous()
|
||||
|
||||
attn_output_a = attention(q_a, k_a, v_a)
|
||||
attn_output_b = attention(q_b, k_b, v_b)
|
||||
attn_output = torch.concat([attn_output_a, attn_output_b], dim=1)
|
||||
|
||||
hidden_states = hidden_states + self.to_out(attn_output) * mod_gate.unsqueeze(1)
|
||||
hidden_states = hidden_states + self.ff(norm_hidden_states) * mod_gate.unsqueeze(1)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class FinalLayer(torch.nn.Module):
|
||||
def __init__(self, hidden_size=3072, patch_size=(1, 2, 2), out_channels=16):
|
||||
super().__init__()
|
||||
|
||||
self.norm_final = torch.nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
||||
self.linear = torch.nn.Linear(hidden_size, patch_size[0] * patch_size[1] * patch_size[2] * out_channels)
|
||||
|
||||
self.adaLN_modulation = torch.nn.Sequential(torch.nn.SiLU(), torch.nn.Linear(hidden_size, 2 * hidden_size))
|
||||
|
||||
def forward(self, x, c):
|
||||
shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
|
||||
x = modulate(self.norm_final(x), shift=shift, scale=scale)
|
||||
x = self.linear(x)
|
||||
return x
|
||||
|
||||
|
||||
class HunyuanVideoDiT(torch.nn.Module):
|
||||
def __init__(self, in_channels=16, hidden_size=3072, text_dim=4096, num_double_blocks=20, num_single_blocks=40):
|
||||
super().__init__()
|
||||
self.img_in = PatchEmbed(in_channels=in_channels, embed_dim=hidden_size)
|
||||
self.txt_in = SingleTokenRefiner(in_channels=text_dim, hidden_size=hidden_size)
|
||||
self.time_in = TimestepEmbeddings(256, hidden_size, computation_device="cpu")
|
||||
self.vector_in = torch.nn.Sequential(
|
||||
torch.nn.Linear(768, hidden_size),
|
||||
torch.nn.SiLU(),
|
||||
torch.nn.Linear(hidden_size, hidden_size)
|
||||
)
|
||||
self.guidance_in = TimestepEmbeddings(256, hidden_size, computation_device="cpu")
|
||||
self.double_blocks = torch.nn.ModuleList([MMDoubleStreamBlock(hidden_size) for _ in range(num_double_blocks)])
|
||||
self.single_blocks = torch.nn.ModuleList([MMSingleStreamBlock(hidden_size) for _ in range(num_single_blocks)])
|
||||
self.final_layer = FinalLayer(hidden_size)
|
||||
|
||||
# TODO: remove these parameters
|
||||
self.dtype = torch.bfloat16
|
||||
self.patch_size = [1, 2, 2]
|
||||
self.hidden_size = 3072
|
||||
self.heads_num = 24
|
||||
self.rope_dim_list = [16, 56, 56]
|
||||
|
||||
def unpatchify(self, x, T, H, W):
|
||||
x = rearrange(x, "B (T H W) (C pT pH pW) -> B C (T pT) (H pH) (W pW)", H=H, W=W, pT=1, pH=2, pW=2)
|
||||
return x
|
||||
|
||||
def enable_block_wise_offload(self, warm_device="cuda", cold_device="cpu"):
|
||||
self.warm_device = warm_device
|
||||
self.cold_device = cold_device
|
||||
self.to(self.cold_device)
|
||||
|
||||
def load_models_to_device(self, loadmodel_names=[], device="cpu"):
|
||||
for model_name in loadmodel_names:
|
||||
model = getattr(self, model_name)
|
||||
if model is not None:
|
||||
model.to(device)
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def prepare_freqs(self, latents):
|
||||
return HunyuanVideoRope(latents)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
t: torch.Tensor,
|
||||
prompt_emb: torch.Tensor = None,
|
||||
text_mask: torch.Tensor = None,
|
||||
pooled_prompt_emb: torch.Tensor = None,
|
||||
freqs_cos: torch.Tensor = None,
|
||||
freqs_sin: torch.Tensor = None,
|
||||
guidance: torch.Tensor = None,
|
||||
**kwargs
|
||||
):
|
||||
B, C, T, H, W = x.shape
|
||||
|
||||
vec = self.time_in(t, dtype=torch.float32) + self.vector_in(pooled_prompt_emb) + self.guidance_in(guidance * 1000, dtype=torch.float32)
|
||||
img = self.img_in(x)
|
||||
txt = self.txt_in(prompt_emb, t, text_mask)
|
||||
|
||||
for block in tqdm(self.double_blocks, desc="Double stream blocks"):
|
||||
img, txt = block(img, txt, vec, (freqs_cos, freqs_sin))
|
||||
|
||||
x = torch.concat([img, txt], dim=1)
|
||||
for block in tqdm(self.single_blocks, desc="Single stream blocks"):
|
||||
x = block(x, vec, (freqs_cos, freqs_sin))
|
||||
|
||||
img = x[:, :-256]
|
||||
img = self.final_layer(img, vec)
|
||||
img = self.unpatchify(img, T=T//1, H=H//2, W=W//2)
|
||||
return img
|
||||
|
||||
|
||||
def enable_auto_offload(self, dtype=torch.bfloat16, device="cuda"):
|
||||
def cast_to(weight, dtype=None, device=None, copy=False):
|
||||
if device is None or weight.device == device:
|
||||
if not copy:
|
||||
if dtype is None or weight.dtype == dtype:
|
||||
return weight
|
||||
return weight.to(dtype=dtype, copy=copy)
|
||||
|
||||
r = torch.empty_like(weight, dtype=dtype, device=device)
|
||||
r.copy_(weight)
|
||||
return r
|
||||
|
||||
def cast_weight(s, input=None, dtype=None, device=None):
|
||||
if input is not None:
|
||||
if dtype is None:
|
||||
dtype = input.dtype
|
||||
if device is None:
|
||||
device = input.device
|
||||
weight = cast_to(s.weight, dtype, device)
|
||||
return weight
|
||||
|
||||
def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None):
|
||||
if input is not None:
|
||||
if dtype is None:
|
||||
dtype = input.dtype
|
||||
if bias_dtype is None:
|
||||
bias_dtype = dtype
|
||||
if device is None:
|
||||
device = input.device
|
||||
weight = cast_to(s.weight, dtype, device)
|
||||
bias = cast_to(s.bias, bias_dtype, device) if s.bias is not None else None
|
||||
return weight, bias
|
||||
|
||||
class quantized_layer:
|
||||
class Linear(torch.nn.Linear):
|
||||
def __init__(self, *args, dtype=torch.bfloat16, device="cuda", **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.dtype = dtype
|
||||
self.device = device
|
||||
|
||||
def block_forward_(self, x, i, j, dtype, device):
|
||||
weight_ = cast_to(
|
||||
self.weight[j * self.block_size: (j + 1) * self.block_size, i * self.block_size: (i + 1) * self.block_size],
|
||||
dtype=dtype, device=device
|
||||
)
|
||||
if self.bias is None or i > 0:
|
||||
bias_ = None
|
||||
else:
|
||||
bias_ = cast_to(self.bias[j * self.block_size: (j + 1) * self.block_size], dtype=dtype, device=device)
|
||||
x_ = x[..., i * self.block_size: (i + 1) * self.block_size]
|
||||
y_ = torch.nn.functional.linear(x_, weight_, bias_)
|
||||
del x_, weight_, bias_
|
||||
torch.cuda.empty_cache()
|
||||
return y_
|
||||
|
||||
def block_forward(self, x, **kwargs):
|
||||
# This feature only reduces VRAM usage by about 2 GB, so we disable it.
|
||||
y = torch.zeros(x.shape[:-1] + (self.out_features,), dtype=x.dtype, device=x.device)
|
||||
for i in range((self.in_features + self.block_size - 1) // self.block_size):
|
||||
for j in range((self.out_features + self.block_size - 1) // self.block_size):
|
||||
y[..., j * self.block_size: (j + 1) * self.block_size] += self.block_forward_(x, i, j, dtype=x.dtype, device=x.device)
|
||||
return y
|
||||
|
||||
def forward(self, x, **kwargs):
|
||||
weight, bias = cast_bias_weight(self, x, dtype=self.dtype, device=self.device)
|
||||
return torch.nn.functional.linear(x, weight, bias)
|
||||
|
||||
|
||||
class RMSNorm(torch.nn.Module):
|
||||
def __init__(self, module, dtype=torch.bfloat16, device="cuda"):
|
||||
super().__init__()
|
||||
self.module = module
|
||||
self.dtype = dtype
|
||||
self.device = device
|
||||
|
||||
def forward(self, hidden_states, **kwargs):
|
||||
input_dtype = hidden_states.dtype
|
||||
variance = hidden_states.to(torch.float32).square().mean(-1, keepdim=True)
|
||||
hidden_states = hidden_states * torch.rsqrt(variance + self.module.eps)
|
||||
hidden_states = hidden_states.to(input_dtype)
|
||||
if self.module.weight is not None:
|
||||
weight = cast_weight(self.module, hidden_states, dtype=torch.bfloat16, device="cuda")
|
||||
hidden_states = hidden_states * weight
|
||||
return hidden_states
|
||||
|
||||
class Conv3d(torch.nn.Conv3d):
|
||||
def __init__(self, *args, dtype=torch.bfloat16, device="cuda", **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.dtype = dtype
|
||||
self.device = device
|
||||
|
||||
def forward(self, x):
|
||||
weight, bias = cast_bias_weight(self, x, dtype=self.dtype, device=self.device)
|
||||
return torch.nn.functional.conv3d(x, weight, bias, self.stride, self.padding, self.dilation, self.groups)
|
||||
|
||||
class LayerNorm(torch.nn.LayerNorm):
|
||||
def __init__(self, *args, dtype=torch.bfloat16, device="cuda", **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.dtype = dtype
|
||||
self.device = device
|
||||
|
||||
def forward(self, x):
|
||||
if self.weight is not None and self.bias is not None:
|
||||
weight, bias = cast_bias_weight(self, x, dtype=self.dtype, device=self.device)
|
||||
return torch.nn.functional.layer_norm(x, self.normalized_shape, weight, bias, self.eps)
|
||||
else:
|
||||
return torch.nn.functional.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
|
||||
|
||||
def replace_layer(model, dtype=torch.bfloat16, device="cuda"):
|
||||
for name, module in model.named_children():
|
||||
if isinstance(module, torch.nn.Linear):
|
||||
with init_weights_on_device():
|
||||
new_layer = quantized_layer.Linear(
|
||||
module.in_features, module.out_features, bias=module.bias is not None,
|
||||
dtype=dtype, device=device
|
||||
)
|
||||
new_layer.load_state_dict(module.state_dict(), assign=True)
|
||||
setattr(model, name, new_layer)
|
||||
elif isinstance(module, torch.nn.Conv3d):
|
||||
with init_weights_on_device():
|
||||
new_layer = quantized_layer.Conv3d(
|
||||
module.in_channels, module.out_channels, kernel_size=module.kernel_size, stride=module.stride,
|
||||
dtype=dtype, device=device
|
||||
)
|
||||
new_layer.load_state_dict(module.state_dict(), assign=True)
|
||||
setattr(model, name, new_layer)
|
||||
elif isinstance(module, RMSNorm):
|
||||
new_layer = quantized_layer.RMSNorm(
|
||||
module,
|
||||
dtype=dtype, device=device
|
||||
)
|
||||
setattr(model, name, new_layer)
|
||||
elif isinstance(module, torch.nn.LayerNorm):
|
||||
with init_weights_on_device():
|
||||
new_layer = quantized_layer.LayerNorm(
|
||||
module.normalized_shape, elementwise_affine=module.elementwise_affine, eps=module.eps,
|
||||
dtype=dtype, device=device
|
||||
)
|
||||
new_layer.load_state_dict(module.state_dict(), assign=True)
|
||||
setattr(model, name, new_layer)
|
||||
else:
|
||||
replace_layer(module, dtype=dtype, device=device)
|
||||
|
||||
replace_layer(self, dtype=dtype, device=device)
|
||||
|
||||
@staticmethod
|
||||
def state_dict_converter():
|
||||
return HunyuanVideoDiTStateDictConverter()
|
||||
|
||||
|
||||
|
||||
class HunyuanVideoDiTStateDictConverter:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def from_civitai(self, state_dict):
|
||||
if "module" in state_dict:
|
||||
state_dict = state_dict["module"]
|
||||
direct_dict = {
|
||||
"img_in.proj": "img_in.proj",
|
||||
"time_in.mlp.0": "time_in.timestep_embedder.0",
|
||||
"time_in.mlp.2": "time_in.timestep_embedder.2",
|
||||
"vector_in.in_layer": "vector_in.0",
|
||||
"vector_in.out_layer": "vector_in.2",
|
||||
"guidance_in.mlp.0": "guidance_in.timestep_embedder.0",
|
||||
"guidance_in.mlp.2": "guidance_in.timestep_embedder.2",
|
||||
"txt_in.input_embedder": "txt_in.input_embedder",
|
||||
"txt_in.t_embedder.mlp.0": "txt_in.t_embedder.timestep_embedder.0",
|
||||
"txt_in.t_embedder.mlp.2": "txt_in.t_embedder.timestep_embedder.2",
|
||||
"txt_in.c_embedder.linear_1": "txt_in.c_embedder.0",
|
||||
"txt_in.c_embedder.linear_2": "txt_in.c_embedder.2",
|
||||
"final_layer.linear": "final_layer.linear",
|
||||
"final_layer.adaLN_modulation.1": "final_layer.adaLN_modulation.1",
|
||||
}
|
||||
txt_suffix_dict = {
|
||||
"norm1": "norm1",
|
||||
"self_attn_qkv": "self_attn_qkv",
|
||||
"self_attn_proj": "self_attn_proj",
|
||||
"norm2": "norm2",
|
||||
"mlp.fc1": "mlp.0",
|
||||
"mlp.fc2": "mlp.2",
|
||||
"adaLN_modulation.1": "adaLN_modulation.1",
|
||||
}
|
||||
double_suffix_dict = {
|
||||
"img_mod.linear": "component_a.mod.linear",
|
||||
"img_attn_qkv": "component_a.to_qkv",
|
||||
"img_attn_q_norm": "component_a.norm_q",
|
||||
"img_attn_k_norm": "component_a.norm_k",
|
||||
"img_attn_proj": "component_a.to_out",
|
||||
"img_mlp.fc1": "component_a.ff.0",
|
||||
"img_mlp.fc2": "component_a.ff.2",
|
||||
"txt_mod.linear": "component_b.mod.linear",
|
||||
"txt_attn_qkv": "component_b.to_qkv",
|
||||
"txt_attn_q_norm": "component_b.norm_q",
|
||||
"txt_attn_k_norm": "component_b.norm_k",
|
||||
"txt_attn_proj": "component_b.to_out",
|
||||
"txt_mlp.fc1": "component_b.ff.0",
|
||||
"txt_mlp.fc2": "component_b.ff.2",
|
||||
}
|
||||
single_suffix_dict = {
|
||||
"linear1": ["to_qkv", "ff.0"],
|
||||
"linear2": ["to_out", "ff.2"],
|
||||
"q_norm": "norm_q",
|
||||
"k_norm": "norm_k",
|
||||
"modulation.linear": "mod.linear",
|
||||
}
|
||||
# single_suffix_dict = {
|
||||
# "linear1": "linear1",
|
||||
# "linear2": "linear2",
|
||||
# "q_norm": "q_norm",
|
||||
# "k_norm": "k_norm",
|
||||
# "modulation.linear": "modulation.linear",
|
||||
# }
|
||||
state_dict_ = {}
|
||||
for name, param in state_dict.items():
|
||||
names = name.split(".")
|
||||
direct_name = ".".join(names[:-1])
|
||||
if direct_name in direct_dict:
|
||||
name_ = direct_dict[direct_name] + "." + names[-1]
|
||||
state_dict_[name_] = param
|
||||
elif names[0] == "double_blocks":
|
||||
prefix = ".".join(names[:2])
|
||||
suffix = ".".join(names[2:-1])
|
||||
name_ = prefix + "." + double_suffix_dict[suffix] + "." + names[-1]
|
||||
state_dict_[name_] = param
|
||||
elif names[0] == "single_blocks":
|
||||
prefix = ".".join(names[:2])
|
||||
suffix = ".".join(names[2:-1])
|
||||
if isinstance(single_suffix_dict[suffix], list):
|
||||
if suffix == "linear1":
|
||||
name_a, name_b = single_suffix_dict[suffix]
|
||||
param_a, param_b = torch.split(param, (3072*3, 3072*4), dim=0)
|
||||
state_dict_[prefix + "." + name_a + "." + names[-1]] = param_a
|
||||
state_dict_[prefix + "." + name_b + "." + names[-1]] = param_b
|
||||
elif suffix == "linear2":
|
||||
if names[-1] == "weight":
|
||||
name_a, name_b = single_suffix_dict[suffix]
|
||||
param_a, param_b = torch.split(param, (3072*1, 3072*4), dim=-1)
|
||||
state_dict_[prefix + "." + name_a + "." + names[-1]] = param_a
|
||||
state_dict_[prefix + "." + name_b + "." + names[-1]] = param_b
|
||||
else:
|
||||
name_a, name_b = single_suffix_dict[suffix]
|
||||
state_dict_[prefix + "." + name_a + "." + names[-1]] = param
|
||||
else:
|
||||
pass
|
||||
else:
|
||||
name_ = prefix + "." + single_suffix_dict[suffix] + "." + names[-1]
|
||||
state_dict_[name_] = param
|
||||
elif names[0] == "txt_in":
|
||||
prefix = ".".join(names[:4]).replace(".individual_token_refiner.", ".")
|
||||
suffix = ".".join(names[4:-1])
|
||||
name_ = prefix + "." + txt_suffix_dict[suffix] + "." + names[-1]
|
||||
state_dict_[name_] = param
|
||||
else:
|
||||
pass
|
||||
return state_dict_
|
||||
55  diffsynth/models/hunyuan_video_text_encoder.py  Normal file
@@ -0,0 +1,55 @@
|
||||
from transformers import LlamaModel, LlamaConfig, DynamicCache
|
||||
from copy import deepcopy
|
||||
import torch
|
||||
|
||||
|
||||
class HunyuanVideoLLMEncoder(LlamaModel):
|
||||
def __init__(self, config: LlamaConfig):
|
||||
super().__init__(config)
|
||||
self.auto_offload = False
|
||||
|
||||
|
||||
def enable_auto_offload(self, **kwargs):
|
||||
self.auto_offload = True
|
||||
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids,
|
||||
attention_mask,
|
||||
hidden_state_skip_layer=2
|
||||
):
|
||||
embed_tokens = deepcopy(self.embed_tokens).to(input_ids.device) if self.auto_offload else self.embed_tokens
|
||||
inputs_embeds = embed_tokens(input_ids)
|
||||
|
||||
past_key_values = DynamicCache()
|
||||
|
||||
cache_position = torch.arange(0, inputs_embeds.shape[1], device=inputs_embeds.device)
|
||||
position_ids = cache_position.unsqueeze(0)
|
||||
|
||||
causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position, None, False)
|
||||
hidden_states = inputs_embeds
|
||||
|
||||
# create position embeddings to be shared across the decoder layers
|
||||
rotary_emb = deepcopy(self.rotary_emb).to(input_ids.device) if self.auto_offload else self.rotary_emb
|
||||
position_embeddings = rotary_emb(hidden_states, position_ids)
|
||||
|
||||
# decoder layers
|
||||
for layer_id, decoder_layer in enumerate(self.layers):
|
||||
if self.auto_offload:
|
||||
decoder_layer = deepcopy(decoder_layer).to(hidden_states.device)
|
||||
layer_outputs = decoder_layer(
|
||||
hidden_states,
|
||||
attention_mask=causal_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_value=past_key_values,
|
||||
output_attentions=False,
|
||||
use_cache=True,
|
||||
cache_position=cache_position,
|
||||
position_embeddings=position_embeddings,
|
||||
)
|
||||
hidden_states = layer_outputs[0]
|
||||
if layer_id + hidden_state_skip_layer + 1 >= len(self.layers):
|
||||
break
|
||||
|
||||
return hidden_states
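

# Illustrative usage with a toy LlamaConfig (the tiny sizes are assumptions made purely
# for demonstration; the real pipeline loads full Llama-style text-encoder weights).
# The encoder returns hidden states from an intermediate layer, stopping
# hidden_state_skip_layer layers before the end instead of running the whole stack.
def _demo_llm_encoder():
    config = LlamaConfig(vocab_size=128, hidden_size=64, intermediate_size=128,
                         num_hidden_layers=4, num_attention_heads=4)
    encoder = HunyuanVideoLLMEncoder(config)
    input_ids = torch.randint(0, 128, (1, 10))
    attention_mask = torch.ones(1, 10, dtype=torch.long)
    hidden = encoder(input_ids, attention_mask, hidden_state_skip_layer=2)
    assert hidden.shape == (1, 10, 64)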
|
||||
507  diffsynth/models/hunyuan_video_vae_decoder.py  Normal file
@@ -0,0 +1,507 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from einops import rearrange
|
||||
import numpy as np
|
||||
from tqdm import tqdm
|
||||
from einops import repeat
|
||||
|
||||
|
||||
class CausalConv3d(nn.Module):
|
||||
|
||||
def __init__(self, in_channel, out_channel, kernel_size, stride=1, dilation=1, pad_mode='replicate', **kwargs):
|
||||
super().__init__()
|
||||
self.pad_mode = pad_mode
|
||||
self.time_causal_padding = (kernel_size // 2, kernel_size // 2, kernel_size // 2, kernel_size // 2, kernel_size - 1, 0)  # (W, W, H, H, T_front, T_back); causal along T
|
||||
self.conv = nn.Conv3d(in_channel, out_channel, kernel_size, stride=stride, dilation=dilation, **kwargs)
|
||||
|
||||
def forward(self, x):
|
||||
x = F.pad(x, self.time_causal_padding, mode=self.pad_mode)
|
||||
return self.conv(x)
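

# Illustrative check (not part of the model): with kernel_size=3 the temporal padding is
# (2, 0), i.e. only past frames are padded in, so the convolution is causal in time and
# keeps the (T, H, W) size at stride 1.
def _demo_causal_conv3d():
    conv = CausalConv3d(4, 4, kernel_size=3)
    x = torch.randn(1, 4, 5, 8, 8)            # (B, C, T, H, W)
    assert conv(x).shape == (1, 4, 5, 8, 8)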
|
||||
|
||||
|
||||
class UpsampleCausal3D(nn.Module):
|
||||
|
||||
def __init__(self, channels, use_conv=False, out_channels=None, kernel_size=None, bias=True, upsample_factor=(2, 2, 2)):
|
||||
super().__init__()
|
||||
self.channels = channels
|
||||
self.out_channels = out_channels or channels
|
||||
self.upsample_factor = upsample_factor
|
||||
self.conv = None
|
||||
if use_conv:
|
||||
kernel_size = 3 if kernel_size is None else kernel_size
|
||||
self.conv = CausalConv3d(self.channels, self.out_channels, kernel_size=kernel_size, bias=bias)
|
||||
|
||||
def forward(self, hidden_states):
|
||||
# Cast to float32 as the 'upsample_nearest2d_out_frame' op does not support bfloat16
|
||||
dtype = hidden_states.dtype
|
||||
if dtype == torch.bfloat16:
|
||||
hidden_states = hidden_states.to(torch.float32)
|
||||
|
||||
# upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
|
||||
if hidden_states.shape[0] >= 64:
|
||||
hidden_states = hidden_states.contiguous()
|
||||
|
||||
# interpolate
|
||||
B, C, T, H, W = hidden_states.shape
|
||||
first_h, other_h = hidden_states.split((1, T - 1), dim=2)
|
||||
if T > 1:
|
||||
other_h = F.interpolate(other_h, scale_factor=self.upsample_factor, mode="nearest")
|
||||
first_h = F.interpolate(first_h.squeeze(2), scale_factor=self.upsample_factor[1:], mode="nearest").unsqueeze(2)
|
||||
hidden_states = torch.cat((first_h, other_h), dim=2) if T > 1 else first_h
|
||||
|
||||
# If the input is bfloat16, we cast back to bfloat16
|
||||
if dtype == torch.bfloat16:
|
||||
hidden_states = hidden_states.to(dtype)
|
||||
|
||||
if self.conv:
|
||||
hidden_states = self.conv(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class ResnetBlockCausal3D(nn.Module):
|
||||
|
||||
def __init__(self, in_channels, out_channels=None, dropout=0.0, groups=32, eps=1e-6, conv_shortcut_bias=True):
|
||||
super().__init__()
|
||||
self.pre_norm = True
|
||||
self.in_channels = in_channels
|
||||
out_channels = in_channels if out_channels is None else out_channels
|
||||
self.out_channels = out_channels
|
||||
|
||||
self.norm1 = nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
|
||||
self.conv1 = CausalConv3d(in_channels, out_channels, kernel_size=3, stride=1)
|
||||
|
||||
self.norm2 = nn.GroupNorm(num_groups=groups, num_channels=out_channels, eps=eps, affine=True)
|
||||
self.conv2 = CausalConv3d(out_channels, out_channels, kernel_size=3, stride=1)
|
||||
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
self.nonlinearity = nn.SiLU()
|
||||
|
||||
self.conv_shortcut = None
|
||||
if in_channels != out_channels:
|
||||
self.conv_shortcut = CausalConv3d(in_channels, out_channels, kernel_size=1, stride=1, bias=conv_shortcut_bias)
|
||||
|
||||
def forward(self, input_tensor):
|
||||
hidden_states = input_tensor
|
||||
# conv1
|
||||
hidden_states = self.norm1(hidden_states)
|
||||
hidden_states = self.nonlinearity(hidden_states)
|
||||
hidden_states = self.conv1(hidden_states)
|
||||
|
||||
# conv2
|
||||
hidden_states = self.norm2(hidden_states)
|
||||
hidden_states = self.nonlinearity(hidden_states)
|
||||
hidden_states = self.dropout(hidden_states)
|
||||
hidden_states = self.conv2(hidden_states)
|
||||
# shortcut
|
||||
if self.conv_shortcut is not None:
|
||||
input_tensor = (self.conv_shortcut(input_tensor))
|
||||
# shortcut and scale
|
||||
output_tensor = input_tensor + hidden_states
|
||||
|
||||
return output_tensor
|
||||
|
||||
|
||||
def prepare_causal_attention_mask(n_frame, n_hw, dtype, device, batch_size=None):
|
||||
seq_len = n_frame * n_hw
|
||||
mask = torch.full((seq_len, seq_len), float("-inf"), dtype=dtype, device=device)
|
||||
for i in range(seq_len):
|
||||
i_frame = i // n_hw
|
||||
mask[i, :(i_frame + 1) * n_hw] = 0
|
||||
if batch_size is not None:
|
||||
mask = mask.unsqueeze(0).expand(batch_size, -1, -1)
|
||||
return mask
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
in_channels,
|
||||
num_heads,
|
||||
head_dim,
|
||||
num_groups=32,
|
||||
dropout=0.0,
|
||||
eps=1e-6,
|
||||
bias=True,
|
||||
residual_connection=True):
|
||||
super().__init__()
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = head_dim
|
||||
self.residual_connection = residual_connection
|
||||
dim_inner = head_dim * num_heads
|
||||
self.group_norm = nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=eps, affine=True)
|
||||
self.to_q = nn.Linear(in_channels, dim_inner, bias=bias)
|
||||
self.to_k = nn.Linear(in_channels, dim_inner, bias=bias)
|
||||
self.to_v = nn.Linear(in_channels, dim_inner, bias=bias)
|
||||
self.to_out = nn.Sequential(nn.Linear(dim_inner, in_channels, bias=bias), nn.Dropout(dropout))
|
||||
|
||||
def forward(self, input_tensor, attn_mask=None):
|
||||
hidden_states = self.group_norm(input_tensor.transpose(1, 2)).transpose(1, 2)
|
||||
batch_size = hidden_states.shape[0]
|
||||
|
||||
q = self.to_q(hidden_states)
|
||||
k = self.to_k(hidden_states)
|
||||
v = self.to_v(hidden_states)
|
||||
|
||||
q = q.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
k = k.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
v = v.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
|
||||
if attn_mask is not None:
|
||||
attn_mask = attn_mask.view(batch_size, self.num_heads, -1, attn_mask.shape[-1])
|
||||
hidden_states = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
|
||||
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim)
|
||||
hidden_states = self.to_out(hidden_states)
|
||||
if self.residual_connection:
|
||||
output_tensor = input_tensor + hidden_states
|
||||
return output_tensor
|
||||
|
||||
|
||||
class UNetMidBlockCausal3D(nn.Module):
|
||||
|
||||
def __init__(self, in_channels, dropout=0.0, num_layers=1, eps=1e-6, num_groups=32, attention_head_dim=None):
|
||||
super().__init__()
|
||||
resnets = [
|
||||
ResnetBlockCausal3D(
|
||||
in_channels=in_channels,
|
||||
out_channels=in_channels,
|
||||
dropout=dropout,
|
||||
groups=num_groups,
|
||||
eps=eps,
|
||||
)
|
||||
]
|
||||
attentions = []
|
||||
attention_head_dim = attention_head_dim or in_channels
|
||||
|
||||
for _ in range(num_layers):
|
||||
attentions.append(
|
||||
Attention(
|
||||
in_channels,
|
||||
num_heads=in_channels // attention_head_dim,
|
||||
head_dim=attention_head_dim,
|
||||
num_groups=num_groups,
|
||||
dropout=dropout,
|
||||
eps=eps,
|
||||
bias=True,
|
||||
residual_connection=True,
|
||||
))
|
||||
|
||||
resnets.append(
|
||||
ResnetBlockCausal3D(
|
||||
in_channels=in_channels,
|
||||
out_channels=in_channels,
|
||||
dropout=dropout,
|
||||
groups=num_groups,
|
||||
eps=eps,
|
||||
))
|
||||
|
||||
self.attentions = nn.ModuleList(attentions)
|
||||
self.resnets = nn.ModuleList(resnets)
|
||||
|
||||
def forward(self, hidden_states):
|
||||
hidden_states = self.resnets[0](hidden_states)
|
||||
for attn, resnet in zip(self.attentions, self.resnets[1:]):
|
||||
B, C, T, H, W = hidden_states.shape
|
||||
hidden_states = rearrange(hidden_states, "b c f h w -> b (f h w) c")
|
||||
attn_mask = prepare_causal_attention_mask(T, H * W, hidden_states.dtype, hidden_states.device, batch_size=B)
|
||||
hidden_states = attn(hidden_states, attn_mask=attn_mask)
|
||||
hidden_states = rearrange(hidden_states, "b (f h w) c -> b c f h w", f=T, h=H, w=W)
|
||||
hidden_states = resnet(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class UpDecoderBlockCausal3D(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
dropout=0.0,
|
||||
num_layers=1,
|
||||
eps=1e-6,
|
||||
num_groups=32,
|
||||
add_upsample=True,
|
||||
upsample_scale_factor=(2, 2, 2),
|
||||
):
|
||||
super().__init__()
|
||||
resnets = []
|
||||
for i in range(num_layers):
|
||||
cur_in_channel = in_channels if i == 0 else out_channels
|
||||
resnets.append(
|
||||
ResnetBlockCausal3D(
|
||||
in_channels=cur_in_channel,
|
||||
out_channels=out_channels,
|
||||
groups=num_groups,
|
||||
dropout=dropout,
|
||||
eps=eps,
|
||||
))
|
||||
self.resnets = nn.ModuleList(resnets)
|
||||
|
||||
self.upsamplers = None
|
||||
if add_upsample:
|
||||
self.upsamplers = nn.ModuleList([
|
||||
UpsampleCausal3D(
|
||||
out_channels,
|
||||
use_conv=True,
|
||||
out_channels=out_channels,
|
||||
upsample_factor=upsample_scale_factor,
|
||||
)
|
||||
])
|
||||
|
||||
def forward(self, hidden_states):
|
||||
for resnet in self.resnets:
|
||||
hidden_states = resnet(hidden_states)
|
||||
if self.upsamplers is not None:
|
||||
for upsampler in self.upsamplers:
|
||||
hidden_states = upsampler(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class DecoderCausal3D(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels=16,
|
||||
out_channels=3,
|
||||
eps=1e-6,
|
||||
dropout=0.0,
|
||||
block_out_channels=[128, 256, 512, 512],
|
||||
layers_per_block=2,
|
||||
num_groups=32,
|
||||
time_compression_ratio=4,
|
||||
spatial_compression_ratio=8,
|
||||
gradient_checkpointing=False,
|
||||
):
|
||||
super().__init__()
|
||||
self.layers_per_block = layers_per_block
|
||||
|
||||
self.conv_in = CausalConv3d(in_channels, block_out_channels[-1], kernel_size=3, stride=1)
|
||||
self.up_blocks = nn.ModuleList([])
|
||||
|
||||
# mid
|
||||
self.mid_block = UNetMidBlockCausal3D(
|
||||
in_channels=block_out_channels[-1],
|
||||
dropout=dropout,
|
||||
eps=eps,
|
||||
num_groups=num_groups,
|
||||
attention_head_dim=block_out_channels[-1],
|
||||
)
|
||||
|
||||
# up
|
||||
reversed_block_out_channels = list(reversed(block_out_channels))
|
||||
output_channel = reversed_block_out_channels[0]
|
||||
for i in range(len(block_out_channels)):
|
||||
prev_output_channel = output_channel
|
||||
output_channel = reversed_block_out_channels[i]
|
||||
is_final_block = i == len(block_out_channels) - 1
|
||||
num_spatial_upsample_layers = int(np.log2(spatial_compression_ratio))
|
||||
num_time_upsample_layers = int(np.log2(time_compression_ratio))
|
||||
|
||||
add_spatial_upsample = bool(i < num_spatial_upsample_layers)
|
||||
add_time_upsample = bool(i >= len(block_out_channels) - 1 - num_time_upsample_layers and not is_final_block)
|
||||
|
||||
upsample_scale_factor_HW = (2, 2) if add_spatial_upsample else (1, 1)
|
||||
upsample_scale_factor_T = (2,) if add_time_upsample else (1,)
|
||||
upsample_scale_factor = tuple(upsample_scale_factor_T + upsample_scale_factor_HW)
|
||||
|
||||
up_block = UpDecoderBlockCausal3D(
|
||||
in_channels=prev_output_channel,
|
||||
out_channels=output_channel,
|
||||
dropout=dropout,
|
||||
num_layers=layers_per_block + 1,
|
||||
eps=eps,
|
||||
num_groups=num_groups,
|
||||
add_upsample=bool(add_spatial_upsample or add_time_upsample),
|
||||
upsample_scale_factor=upsample_scale_factor,
|
||||
)
|
||||
|
||||
self.up_blocks.append(up_block)
|
||||
prev_output_channel = output_channel
|
||||
|
||||
# out
|
||||
self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=num_groups, eps=eps)
|
||||
self.conv_act = nn.SiLU()
|
||||
self.conv_out = CausalConv3d(block_out_channels[0], out_channels, kernel_size=3)
|
||||
|
||||
self.gradient_checkpointing = gradient_checkpointing
|
||||
|
||||
def forward(self, hidden_states):
|
||||
hidden_states = self.conv_in(hidden_states)
|
||||
if self.training and self.gradient_checkpointing:
|
||||
|
||||
def create_custom_forward(module):
|
||||
|
||||
def custom_forward(*inputs):
|
||||
return module(*inputs)
|
||||
|
||||
return custom_forward
|
||||
|
||||
# middle
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(self.mid_block),
|
||||
hidden_states,
|
||||
use_reentrant=False,
|
||||
)
|
||||
# up
|
||||
for up_block in self.up_blocks:
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(up_block),
|
||||
hidden_states,
|
||||
use_reentrant=False,
|
||||
)
|
||||
else:
|
||||
# middle
|
||||
hidden_states = self.mid_block(hidden_states)
|
||||
# up
|
||||
for up_block in self.up_blocks:
|
||||
hidden_states = up_block(hidden_states)
|
||||
# post-process
|
||||
hidden_states = self.conv_norm_out(hidden_states)
|
||||
hidden_states = self.conv_act(hidden_states)
|
||||
hidden_states = self.conv_out(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class HunyuanVideoVAEDecoder(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels=16,
|
||||
out_channels=3,
|
||||
eps=1e-6,
|
||||
dropout=0.0,
|
||||
block_out_channels=[128, 256, 512, 512],
|
||||
layers_per_block=2,
|
||||
num_groups=32,
|
||||
time_compression_ratio=4,
|
||||
spatial_compression_ratio=8,
|
||||
gradient_checkpointing=False,
|
||||
):
|
||||
super().__init__()
|
||||
self.decoder = DecoderCausal3D(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
eps=eps,
|
||||
dropout=dropout,
|
||||
block_out_channels=block_out_channels,
|
||||
layers_per_block=layers_per_block,
|
||||
num_groups=num_groups,
|
||||
time_compression_ratio=time_compression_ratio,
|
||||
spatial_compression_ratio=spatial_compression_ratio,
|
||||
gradient_checkpointing=gradient_checkpointing,
|
||||
)
|
||||
self.post_quant_conv = nn.Conv3d(in_channels, in_channels, kernel_size=1)
|
||||
self.scaling_factor = 0.476986
|
||||
|
||||
|
||||
def forward(self, latents):
|
||||
latents = latents / self.scaling_factor
|
||||
latents = self.post_quant_conv(latents)
|
||||
dec = self.decoder(latents)
|
||||
return dec
|
||||
|
||||
|
||||
def build_1d_mask(self, length, left_bound, right_bound, border_width):
|
||||
x = torch.ones((length,))
|
||||
if not left_bound:
|
||||
x[:border_width] = (torch.arange(border_width) + 1) / border_width
|
||||
if not right_bound:
|
||||
x[-border_width:] = torch.flip((torch.arange(border_width) + 1) / border_width, dims=(0,))
|
||||
return x
|
||||
|
||||
|
||||
def build_mask(self, data, is_bound, border_width):
|
||||
_, _, T, H, W = data.shape
|
||||
t = self.build_1d_mask(T, is_bound[0], is_bound[1], border_width[0])
|
||||
h = self.build_1d_mask(H, is_bound[2], is_bound[3], border_width[1])
|
||||
w = self.build_1d_mask(W, is_bound[4], is_bound[5], border_width[2])
|
||||
|
||||
t = repeat(t, "T -> T H W", T=T, H=H, W=W)
|
||||
h = repeat(h, "H -> T H W", T=T, H=H, W=W)
|
||||
w = repeat(w, "W -> T H W", T=T, H=H, W=W)
|
||||
|
||||
mask = torch.stack([t, h, w]).min(dim=0).values
|
||||
mask = rearrange(mask, "T H W -> 1 1 T H W")
|
||||
return mask
|
||||
|
||||
|
||||
def tile_forward(self, hidden_states, tile_size, tile_stride):
|
||||
B, C, T, H, W = hidden_states.shape
|
||||
size_t, size_h, size_w = tile_size
|
||||
stride_t, stride_h, stride_w = tile_stride
|
||||
|
||||
# Split tasks
|
||||
tasks = []
|
||||
for t in range(0, T, stride_t):
|
||||
if (t-stride_t >= 0 and t-stride_t+size_t >= T): continue
|
||||
for h in range(0, H, stride_h):
|
||||
if (h-stride_h >= 0 and h-stride_h+size_h >= H): continue
|
||||
for w in range(0, W, stride_w):
|
||||
if (w-stride_w >= 0 and w-stride_w+size_w >= W): continue
|
||||
t_, h_, w_ = t + size_t, h + size_h, w + size_w
|
||||
tasks.append((t, t_, h, h_, w, w_))
|
||||
|
||||
# Run
|
||||
torch_dtype = self.post_quant_conv.weight.dtype
|
||||
data_device = hidden_states.device
|
||||
computation_device = self.post_quant_conv.weight.device
|
||||
|
||||
weight = torch.zeros((1, 1, (T - 1) * 4 + 1, H * 8, W * 8), dtype=torch_dtype, device=data_device)
|
||||
values = torch.zeros((B, 3, (T - 1) * 4 + 1, H * 8, W * 8), dtype=torch_dtype, device=data_device)
|
||||
|
||||
for t, t_, h, h_, w, w_ in tqdm(tasks, desc="VAE decoding"):
|
||||
hidden_states_batch = hidden_states[:, :, t:t_, h:h_, w:w_].to(computation_device)
|
||||
hidden_states_batch = self.forward(hidden_states_batch).to(data_device)
|
||||
if t > 0:
|
||||
hidden_states_batch = hidden_states_batch[:, :, 1:]
|
||||
|
||||
mask = self.build_mask(
|
||||
hidden_states_batch,
|
||||
is_bound=(t==0, t_>=T, h==0, h_>=H, w==0, w_>=W),
|
||||
border_width=((size_t - stride_t) * 4, (size_h - stride_h) * 8, (size_w - stride_w) * 8)
|
||||
).to(dtype=torch_dtype, device=data_device)
|
||||
|
||||
target_t = 0 if t==0 else t * 4 + 1
|
||||
target_h = h * 8
|
||||
target_w = w * 8
|
||||
values[
|
||||
:,
|
||||
:,
|
||||
target_t: target_t + hidden_states_batch.shape[2],
|
||||
target_h: target_h + hidden_states_batch.shape[3],
|
||||
target_w: target_w + hidden_states_batch.shape[4],
|
||||
] += hidden_states_batch * mask
|
||||
weight[
|
||||
:,
|
||||
:,
|
||||
target_t: target_t + hidden_states_batch.shape[2],
|
||||
target_h: target_h + hidden_states_batch.shape[3],
|
||||
target_w: target_w + hidden_states_batch.shape[4],
|
||||
] += mask
|
||||
return values / weight
|
||||
|
||||
|
||||
def decode_video(self, latents, tile_size=(17, 32, 32), tile_stride=(12, 24, 24)):
|
||||
latents = latents.to(self.post_quant_conv.weight.dtype)
|
||||
return self.tile_forward(latents, tile_size=tile_size, tile_stride=tile_stride)
|
||||
|
||||
@staticmethod
|
||||
def state_dict_converter():
|
||||
return HunyuanVideoVAEDecoderStateDictConverter()
|
||||
|
||||
|
||||
class HunyuanVideoVAEDecoderStateDictConverter:
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def from_diffusers(self, state_dict):
|
||||
state_dict_ = {}
|
||||
for name in state_dict:
|
||||
if name.startswith('decoder.') or name.startswith('post_quant_conv.'):
|
||||
state_dict_[name] = state_dict[name]
|
||||
return state_dict_
|
||||
307
diffsynth/models/hunyuan_video_vae_encoder.py
Normal file
307
diffsynth/models/hunyuan_video_vae_encoder.py
Normal file
@@ -0,0 +1,307 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from einops import rearrange, repeat
|
||||
import numpy as np
|
||||
from tqdm import tqdm
|
||||
from .hunyuan_video_vae_decoder import CausalConv3d, ResnetBlockCausal3D, UNetMidBlockCausal3D
|
||||
|
||||
|
||||
class DownsampleCausal3D(nn.Module):
|
||||
|
||||
def __init__(self, channels, out_channels, kernel_size=3, bias=True, stride=2):
|
||||
super().__init__()
|
||||
self.conv = CausalConv3d(channels, out_channels, kernel_size, stride=stride, bias=bias)
|
||||
|
||||
def forward(self, hidden_states):
|
||||
hidden_states = self.conv(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class DownEncoderBlockCausal3D(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
dropout=0.0,
|
||||
num_layers=1,
|
||||
eps=1e-6,
|
||||
num_groups=32,
|
||||
add_downsample=True,
|
||||
downsample_stride=2,
|
||||
):
|
||||
|
||||
super().__init__()
|
||||
resnets = []
|
||||
for i in range(num_layers):
|
||||
cur_in_channel = in_channels if i == 0 else out_channels
|
||||
resnets.append(
|
||||
ResnetBlockCausal3D(
|
||||
in_channels=cur_in_channel,
|
||||
out_channels=out_channels,
|
||||
groups=num_groups,
|
||||
dropout=dropout,
|
||||
eps=eps,
|
||||
))
|
||||
self.resnets = nn.ModuleList(resnets)
|
||||
|
||||
self.downsamplers = None
|
||||
if add_downsample:
|
||||
self.downsamplers = nn.ModuleList([DownsampleCausal3D(
|
||||
out_channels,
|
||||
out_channels,
|
||||
stride=downsample_stride,
|
||||
)])
|
||||
|
||||
def forward(self, hidden_states):
|
||||
for resnet in self.resnets:
|
||||
hidden_states = resnet(hidden_states)
|
||||
|
||||
if self.downsamplers is not None:
|
||||
for downsampler in self.downsamplers:
|
||||
hidden_states = downsampler(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class EncoderCausal3D(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int = 3,
|
||||
out_channels: int = 16,
|
||||
eps=1e-6,
|
||||
dropout=0.0,
|
||||
block_out_channels=[128, 256, 512, 512],
|
||||
layers_per_block=2,
|
||||
num_groups=32,
|
||||
time_compression_ratio: int = 4,
|
||||
spatial_compression_ratio: int = 8,
|
||||
gradient_checkpointing=False,
|
||||
):
|
||||
super().__init__()
|
||||
self.conv_in = CausalConv3d(in_channels, block_out_channels[0], kernel_size=3, stride=1)
|
||||
self.down_blocks = nn.ModuleList([])
|
||||
|
||||
# down
|
||||
output_channel = block_out_channels[0]
|
||||
for i in range(len(block_out_channels)):
|
||||
input_channel = output_channel
|
||||
output_channel = block_out_channels[i]
|
||||
is_final_block = i == len(block_out_channels) - 1
|
||||
num_spatial_downsample_layers = int(np.log2(spatial_compression_ratio))
|
||||
num_time_downsample_layers = int(np.log2(time_compression_ratio))
|
||||
|
||||
add_spatial_downsample = bool(i < num_spatial_downsample_layers)
|
||||
add_time_downsample = bool(i >= (len(block_out_channels) - 1 - num_time_downsample_layers) and not is_final_block)
|
||||
|
||||
downsample_stride_HW = (2, 2) if add_spatial_downsample else (1, 1)
|
||||
downsample_stride_T = (2,) if add_time_downsample else (1,)
|
||||
downsample_stride = tuple(downsample_stride_T + downsample_stride_HW)
|
||||
down_block = DownEncoderBlockCausal3D(
|
||||
in_channels=input_channel,
|
||||
out_channels=output_channel,
|
||||
dropout=dropout,
|
||||
num_layers=layers_per_block,
|
||||
eps=eps,
|
||||
num_groups=num_groups,
|
||||
add_downsample=bool(add_spatial_downsample or add_time_downsample),
|
||||
downsample_stride=downsample_stride,
|
||||
)
|
||||
self.down_blocks.append(down_block)
|
||||
|
||||
# mid
|
||||
self.mid_block = UNetMidBlockCausal3D(
|
||||
in_channels=block_out_channels[-1],
|
||||
dropout=dropout,
|
||||
eps=eps,
|
||||
num_groups=num_groups,
|
||||
attention_head_dim=block_out_channels[-1],
|
||||
)
|
||||
# out
|
||||
self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[-1], num_groups=num_groups, eps=eps)
|
||||
self.conv_act = nn.SiLU()
|
||||
self.conv_out = CausalConv3d(block_out_channels[-1], 2 * out_channels, kernel_size=3)
|
||||
|
||||
self.gradient_checkpointing = gradient_checkpointing
|
||||
|
||||
def forward(self, hidden_states):
|
||||
hidden_states = self.conv_in(hidden_states)
|
||||
if self.training and self.gradient_checkpointing:
|
||||
|
||||
def create_custom_forward(module):
|
||||
|
||||
def custom_forward(*inputs):
|
||||
return module(*inputs)
|
||||
|
||||
return custom_forward
|
||||
|
||||
# down
|
||||
for down_block in self.down_blocks:
|
||||
torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(down_block),
|
||||
hidden_states,
|
||||
use_reentrant=False,
|
||||
)
|
||||
# middle
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(self.mid_block),
|
||||
hidden_states,
|
||||
use_reentrant=False,
|
||||
)
|
||||
else:
|
||||
# down
|
||||
for down_block in self.down_blocks:
|
||||
hidden_states = down_block(hidden_states)
|
||||
# middle
|
||||
hidden_states = self.mid_block(hidden_states)
|
||||
# post-process
|
||||
hidden_states = self.conv_norm_out(hidden_states)
|
||||
hidden_states = self.conv_act(hidden_states)
|
||||
hidden_states = self.conv_out(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class HunyuanVideoVAEEncoder(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels=3,
|
||||
out_channels=16,
|
||||
eps=1e-6,
|
||||
dropout=0.0,
|
||||
block_out_channels=[128, 256, 512, 512],
|
||||
layers_per_block=2,
|
||||
num_groups=32,
|
||||
time_compression_ratio=4,
|
||||
spatial_compression_ratio=8,
|
||||
gradient_checkpointing=False,
|
||||
):
|
||||
super().__init__()
|
||||
self.encoder = EncoderCausal3D(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
eps=eps,
|
||||
dropout=dropout,
|
||||
block_out_channels=block_out_channels,
|
||||
layers_per_block=layers_per_block,
|
||||
num_groups=num_groups,
|
||||
time_compression_ratio=time_compression_ratio,
|
||||
spatial_compression_ratio=spatial_compression_ratio,
|
||||
gradient_checkpointing=gradient_checkpointing,
|
||||
)
|
||||
self.quant_conv = nn.Conv3d(2 * out_channels, 2 * out_channels, kernel_size=1)
|
||||
self.scaling_factor = 0.476986
|
||||
|
||||
|
||||
def forward(self, images):
|
||||
latents = self.encoder(images)
|
||||
latents = self.quant_conv(latents)
|
||||
latents = latents[:, :16]
|
||||
latents = latents * self.scaling_factor
|
||||
return latents
|
||||
|
||||
|
||||
def build_1d_mask(self, length, left_bound, right_bound, border_width):
|
||||
x = torch.ones((length,))
|
||||
if not left_bound:
|
||||
x[:border_width] = (torch.arange(border_width) + 1) / border_width
|
||||
if not right_bound:
|
||||
x[-border_width:] = torch.flip((torch.arange(border_width) + 1) / border_width, dims=(0,))
|
||||
return x
|
||||
|
||||
|
||||
def build_mask(self, data, is_bound, border_width):
|
||||
_, _, T, H, W = data.shape
|
||||
t = self.build_1d_mask(T, is_bound[0], is_bound[1], border_width[0])
|
||||
h = self.build_1d_mask(H, is_bound[2], is_bound[3], border_width[1])
|
||||
w = self.build_1d_mask(W, is_bound[4], is_bound[5], border_width[2])
|
||||
|
||||
t = repeat(t, "T -> T H W", T=T, H=H, W=W)
|
||||
h = repeat(h, "H -> T H W", T=T, H=H, W=W)
|
||||
w = repeat(w, "W -> T H W", T=T, H=H, W=W)
|
||||
|
||||
mask = torch.stack([t, h, w]).min(dim=0).values
|
||||
mask = rearrange(mask, "T H W -> 1 1 T H W")
|
||||
return mask
|
||||
|
||||
|
||||
def tile_forward(self, hidden_states, tile_size, tile_stride):
|
||||
B, C, T, H, W = hidden_states.shape
|
||||
size_t, size_h, size_w = tile_size
|
||||
stride_t, stride_h, stride_w = tile_stride
|
||||
|
||||
# Split tasks
|
||||
tasks = []
|
||||
for t in range(0, T, stride_t):
|
||||
if (t-stride_t >= 0 and t-stride_t+size_t >= T): continue
|
||||
for h in range(0, H, stride_h):
|
||||
if (h-stride_h >= 0 and h-stride_h+size_h >= H): continue
|
||||
for w in range(0, W, stride_w):
|
||||
if (w-stride_w >= 0 and w-stride_w+size_w >= W): continue
|
||||
t_, h_, w_ = t + size_t, h + size_h, w + size_w
|
||||
tasks.append((t, t_, h, h_, w, w_))
|
||||
|
||||
# Run
|
||||
torch_dtype = self.quant_conv.weight.dtype
|
||||
data_device = hidden_states.device
|
||||
computation_device = self.quant_conv.weight.device
|
||||
|
||||
weight = torch.zeros((1, 1, (T - 1) // 4 + 1, H // 8, W // 8), dtype=torch_dtype, device=data_device)
|
||||
values = torch.zeros((B, 16, (T - 1) // 4 + 1, H // 8, W // 8), dtype=torch_dtype, device=data_device)
|
||||
|
||||
for t, t_, h, h_, w, w_ in tqdm(tasks, desc="VAE encoding"):
|
||||
hidden_states_batch = hidden_states[:, :, t:t_, h:h_, w:w_].to(computation_device)
|
||||
hidden_states_batch = self.forward(hidden_states_batch).to(data_device)
|
||||
if t > 0:
|
||||
hidden_states_batch = hidden_states_batch[:, :, 1:]
|
||||
|
||||
mask = self.build_mask(
|
||||
hidden_states_batch,
|
||||
is_bound=(t==0, t_>=T, h==0, h_>=H, w==0, w_>=W),
|
||||
border_width=((size_t - stride_t) // 4, (size_h - stride_h) // 8, (size_w - stride_w) // 8)
|
||||
).to(dtype=torch_dtype, device=data_device)
|
||||
|
||||
target_t = 0 if t==0 else t // 4 + 1
|
||||
target_h = h // 8
|
||||
target_w = w // 8
|
||||
values[
|
||||
:,
|
||||
:,
|
||||
target_t: target_t + hidden_states_batch.shape[2],
|
||||
target_h: target_h + hidden_states_batch.shape[3],
|
||||
target_w: target_w + hidden_states_batch.shape[4],
|
||||
] += hidden_states_batch * mask
|
||||
weight[
|
||||
:,
|
||||
:,
|
||||
target_t: target_t + hidden_states_batch.shape[2],
|
||||
target_h: target_h + hidden_states_batch.shape[3],
|
||||
target_w: target_w + hidden_states_batch.shape[4],
|
||||
] += mask
|
||||
return values / weight
|
||||
|
||||
|
||||
def encode_video(self, latents, tile_size=(65, 256, 256), tile_stride=(48, 192, 192)):
|
||||
latents = latents.to(self.quant_conv.weight.dtype)
|
||||
return self.tile_forward(latents, tile_size=tile_size, tile_stride=tile_stride)
|
||||
|
||||
|
||||
@staticmethod
|
||||
def state_dict_converter():
|
||||
return HunyuanVideoVAEEncoderStateDictConverter()
|
||||
|
||||
|
||||
class HunyuanVideoVAEEncoderStateDictConverter:
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def from_diffusers(self, state_dict):
|
||||
state_dict_ = {}
|
||||
for name in state_dict:
|
||||
if name.startswith('encoder.') or name.startswith('quant_conv.'):
|
||||
state_dict_[name] = state_dict[name]
|
||||
return state_dict_
|
||||
@@ -6,6 +6,8 @@ from .sdxl_text_encoder import SDXLTextEncoder, SDXLTextEncoder2
|
||||
from .sd3_dit import SD3DiT
|
||||
from .flux_dit import FluxDiT
|
||||
from .hunyuan_dit import HunyuanDiT
|
||||
from .cog_dit import CogDiT
|
||||
from .hunyuan_video_dit import HunyuanVideoDiT
|
||||
|
||||
|
||||
|
||||
@@ -77,11 +79,19 @@ class LoRAFromCivitai:
|
||||
state_dict_lora = model.__class__.state_dict_converter().from_diffusers(state_dict_lora)
|
||||
elif model_resource == "civitai":
|
||||
state_dict_lora = model.__class__.state_dict_converter().from_civitai(state_dict_lora)
|
||||
if isinstance(state_dict_lora, tuple):
|
||||
state_dict_lora = state_dict_lora[0]
|
||||
if len(state_dict_lora) > 0:
|
||||
print(f" {len(state_dict_lora)} tensors are updated.")
|
||||
for name in state_dict_lora:
|
||||
fp8=False
|
||||
if state_dict_model[name].dtype == torch.float8_e4m3fn:
|
||||
state_dict_model[name]= state_dict_model[name].to(state_dict_lora[name].dtype)
|
||||
fp8=True
|
||||
state_dict_model[name] += state_dict_lora[name].to(
|
||||
dtype=state_dict_model[name].dtype, device=state_dict_model[name].device)
|
||||
if fp8:
|
||||
state_dict_model[name] = state_dict_model[name].to(torch.float8_e4m3fn)
|
||||
model.load_state_dict(state_dict_model)
|
||||
|
||||
|
||||
@@ -96,6 +106,8 @@ class LoRAFromCivitai:
|
||||
converter_fn = model.__class__.state_dict_converter().from_diffusers if model_resource == "diffusers" \
|
||||
else model.__class__.state_dict_converter().from_civitai
|
||||
state_dict_lora_ = converter_fn(state_dict_lora_)
|
||||
if isinstance(state_dict_lora_, tuple):
|
||||
state_dict_lora_ = state_dict_lora_[0]
|
||||
if len(state_dict_lora_) == 0:
|
||||
continue
|
||||
for name in state_dict_lora_:
|
||||
@@ -185,7 +197,7 @@ class FluxLoRAFromCivitai(LoRAFromCivitai):
|
||||
|
||||
class GeneralLoRAFromPeft:
|
||||
def __init__(self):
|
||||
self.supported_model_classes = [SDUNet, SDXLUNet, SD3DiT, HunyuanDiT, FluxDiT]
|
||||
self.supported_model_classes = [SDUNet, SDXLUNet, SD3DiT, HunyuanDiT, FluxDiT, CogDiT]
|
||||
|
||||
|
||||
def fetch_device_dtype_from_state_dict(self, state_dict):
|
||||
@@ -248,5 +260,108 @@ class GeneralLoRAFromPeft:
|
||||
return None
|
||||
|
||||
|
||||
class HunyuanVideoLoRAFromCivitai(LoRAFromCivitai):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.supported_model_classes = [HunyuanVideoDiT, HunyuanVideoDiT]
|
||||
self.lora_prefix = ["diffusion_model.", "transformer."]
|
||||
self.special_keys = {}
|
||||
|
||||
|
||||
class FluxLoRAConverter:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def align_to_opensource_format(state_dict, alpha=1.0):
|
||||
prefix_rename_dict = {
|
||||
"single_blocks": "lora_unet_single_blocks",
|
||||
"blocks": "lora_unet_double_blocks",
|
||||
}
|
||||
middle_rename_dict = {
|
||||
"norm.linear": "modulation_lin",
|
||||
"to_qkv_mlp": "linear1",
|
||||
"proj_out": "linear2",
|
||||
|
||||
"norm1_a.linear": "img_mod_lin",
|
||||
"norm1_b.linear": "txt_mod_lin",
|
||||
"attn.a_to_qkv": "img_attn_qkv",
|
||||
"attn.b_to_qkv": "txt_attn_qkv",
|
||||
"attn.a_to_out": "img_attn_proj",
|
||||
"attn.b_to_out": "txt_attn_proj",
|
||||
"ff_a.0": "img_mlp_0",
|
||||
"ff_a.2": "img_mlp_2",
|
||||
"ff_b.0": "txt_mlp_0",
|
||||
"ff_b.2": "txt_mlp_2",
|
||||
}
|
||||
suffix_rename_dict = {
|
||||
"lora_B.weight": "lora_up.weight",
|
||||
"lora_A.weight": "lora_down.weight",
|
||||
}
|
||||
state_dict_ = {}
|
||||
for name, param in state_dict.items():
|
||||
names = name.split(".")
|
||||
if names[-2] != "lora_A" and names[-2] != "lora_B":
|
||||
names.pop(-2)
|
||||
prefix = names[0]
|
||||
middle = ".".join(names[2:-2])
|
||||
suffix = ".".join(names[-2:])
|
||||
block_id = names[1]
|
||||
if middle not in middle_rename_dict:
|
||||
continue
|
||||
rename = prefix_rename_dict[prefix] + "_" + block_id + "_" + middle_rename_dict[middle] + "." + suffix_rename_dict[suffix]
|
||||
state_dict_[rename] = param
|
||||
if rename.endswith("lora_up.weight"):
|
||||
state_dict_[rename.replace("lora_up.weight", "alpha")] = torch.tensor((alpha,))[0]
|
||||
return state_dict_
|
||||
|
||||
@staticmethod
|
||||
def align_to_diffsynth_format(state_dict):
|
||||
rename_dict = {
|
||||
"lora_unet_double_blocks_blockid_img_mod_lin.lora_down.weight": "blocks.blockid.norm1_a.linear.lora_A.default.weight",
|
||||
"lora_unet_double_blocks_blockid_img_mod_lin.lora_up.weight": "blocks.blockid.norm1_a.linear.lora_B.default.weight",
|
||||
"lora_unet_double_blocks_blockid_txt_mod_lin.lora_down.weight": "blocks.blockid.norm1_b.linear.lora_A.default.weight",
|
||||
"lora_unet_double_blocks_blockid_txt_mod_lin.lora_up.weight": "blocks.blockid.norm1_b.linear.lora_B.default.weight",
|
||||
"lora_unet_double_blocks_blockid_img_attn_qkv.lora_down.weight": "blocks.blockid.attn.a_to_qkv.lora_A.default.weight",
|
||||
"lora_unet_double_blocks_blockid_img_attn_qkv.lora_up.weight": "blocks.blockid.attn.a_to_qkv.lora_B.default.weight",
|
||||
"lora_unet_double_blocks_blockid_txt_attn_qkv.lora_down.weight": "blocks.blockid.attn.b_to_qkv.lora_A.default.weight",
|
||||
"lora_unet_double_blocks_blockid_txt_attn_qkv.lora_up.weight": "blocks.blockid.attn.b_to_qkv.lora_B.default.weight",
|
||||
"lora_unet_double_blocks_blockid_img_attn_proj.lora_down.weight": "blocks.blockid.attn.a_to_out.lora_A.default.weight",
|
||||
"lora_unet_double_blocks_blockid_img_attn_proj.lora_up.weight": "blocks.blockid.attn.a_to_out.lora_B.default.weight",
|
||||
"lora_unet_double_blocks_blockid_txt_attn_proj.lora_down.weight": "blocks.blockid.attn.b_to_out.lora_A.default.weight",
|
||||
"lora_unet_double_blocks_blockid_txt_attn_proj.lora_up.weight": "blocks.blockid.attn.b_to_out.lora_B.default.weight",
|
||||
"lora_unet_double_blocks_blockid_img_mlp_0.lora_down.weight": "blocks.blockid.ff_a.0.lora_A.default.weight",
|
||||
"lora_unet_double_blocks_blockid_img_mlp_0.lora_up.weight": "blocks.blockid.ff_a.0.lora_B.default.weight",
|
||||
"lora_unet_double_blocks_blockid_img_mlp_2.lora_down.weight": "blocks.blockid.ff_a.2.lora_A.default.weight",
|
||||
"lora_unet_double_blocks_blockid_img_mlp_2.lora_up.weight": "blocks.blockid.ff_a.2.lora_B.default.weight",
|
||||
"lora_unet_double_blocks_blockid_txt_mlp_0.lora_down.weight": "blocks.blockid.ff_b.0.lora_A.default.weight",
|
||||
"lora_unet_double_blocks_blockid_txt_mlp_0.lora_up.weight": "blocks.blockid.ff_b.0.lora_B.default.weight",
|
||||
"lora_unet_double_blocks_blockid_txt_mlp_2.lora_down.weight": "blocks.blockid.ff_b.2.lora_A.default.weight",
|
||||
"lora_unet_double_blocks_blockid_txt_mlp_2.lora_up.weight": "blocks.blockid.ff_b.2.lora_B.default.weight",
|
||||
"lora_unet_single_blocks_blockid_modulation_lin.lora_down.weight": "single_blocks.blockid.norm.linear.lora_A.default.weight",
|
||||
"lora_unet_single_blocks_blockid_modulation_lin.lora_up.weight": "single_blocks.blockid.norm.linear.lora_B.default.weight",
|
||||
"lora_unet_single_blocks_blockid_linear1.lora_down.weight": "single_blocks.blockid.to_qkv_mlp.lora_A.default.weight",
|
||||
"lora_unet_single_blocks_blockid_linear1.lora_up.weight": "single_blocks.blockid.to_qkv_mlp.lora_B.default.weight",
|
||||
"lora_unet_single_blocks_blockid_linear2.lora_down.weight": "single_blocks.blockid.proj_out.lora_A.default.weight",
|
||||
"lora_unet_single_blocks_blockid_linear2.lora_up.weight": "single_blocks.blockid.proj_out.lora_B.default.weight",
|
||||
}
|
||||
def guess_block_id(name):
|
||||
names = name.split("_")
|
||||
for i in names:
|
||||
if i.isdigit():
|
||||
return i, name.replace(f"_{i}_", "_blockid_")
|
||||
return None, None
|
||||
state_dict_ = {}
|
||||
for name, param in state_dict.items():
|
||||
block_id, source_name = guess_block_id(name)
|
||||
if source_name in rename_dict:
|
||||
target_name = rename_dict[source_name]
|
||||
target_name = target_name.replace(".blockid.", f".{block_id}.")
|
||||
state_dict_[target_name] = param
|
||||
else:
|
||||
state_dict_[name] = param
|
||||
return state_dict_
|
||||
|
||||
|
||||
def get_lora_loaders():
|
||||
return [SDLoRAFromCivitai(), SDXLLoRAFromCivitai(), GeneralLoRAFromPeft(), FluxLoRAFromCivitai()]
|
||||
return [SDLoRAFromCivitai(), SDXLLoRAFromCivitai(), FluxLoRAFromCivitai(), HunyuanVideoLoRAFromCivitai(), GeneralLoRAFromPeft()]
|
||||
|
||||
@@ -1,10 +1,7 @@
|
||||
import os, torch, hashlib, json, importlib
|
||||
from safetensors import safe_open
|
||||
from torch import Tensor
|
||||
from typing_extensions import Literal, TypeAlias
|
||||
import os, torch, json, importlib
|
||||
from typing import List
|
||||
|
||||
from .downloader import download_models, Preset_model_id, Preset_model_website
|
||||
from .downloader import download_models, download_customized_models, Preset_model_id, Preset_model_website
|
||||
|
||||
from .sd_text_encoder import SDTextEncoder
|
||||
from .sd_unet import SDUNet
|
||||
@@ -38,10 +35,13 @@ from .sdxl_ipadapter import SDXLIpAdapter, IpAdapterXLCLIPImageEmbedder
|
||||
|
||||
from .hunyuan_dit_text_encoder import HunyuanDiTCLIPTextEncoder, HunyuanDiTT5TextEncoder
|
||||
from .hunyuan_dit import HunyuanDiT
|
||||
from .hunyuan_video_vae_decoder import HunyuanVideoVAEDecoder
|
||||
from .hunyuan_video_vae_encoder import HunyuanVideoVAEEncoder
|
||||
|
||||
from .flux_dit import FluxDiT
|
||||
from .flux_text_encoder import FluxTextEncoder1, FluxTextEncoder2
|
||||
from .flux_text_encoder import FluxTextEncoder2
|
||||
from .flux_vae import FluxVAEEncoder, FluxVAEDecoder
|
||||
from .flux_ipadapter import FluxIpAdapter
|
||||
|
||||
from .cog_vae import CogVAEEncoder, CogVAEDecoder
|
||||
from .cog_dit import CogDiT
|
||||
@@ -50,45 +50,7 @@ from ..extensions.RIFE import IFNet
|
||||
from ..extensions.ESRGAN import RRDBNet
|
||||
|
||||
from ..configs.model_config import model_loader_configs, huggingface_model_loader_configs, patch_model_loader_configs
|
||||
from .utils import load_state_dict
|
||||
|
||||
|
||||
|
||||
def convert_state_dict_keys_to_single_str(state_dict, with_shape=True):
|
||||
keys = []
|
||||
for key, value in state_dict.items():
|
||||
if isinstance(key, str):
|
||||
if isinstance(value, Tensor):
|
||||
if with_shape:
|
||||
shape = "_".join(map(str, list(value.shape)))
|
||||
keys.append(key + ":" + shape)
|
||||
keys.append(key)
|
||||
elif isinstance(value, dict):
|
||||
keys.append(key + "|" + convert_state_dict_keys_to_single_str(value, with_shape=with_shape))
|
||||
keys.sort()
|
||||
keys_str = ",".join(keys)
|
||||
return keys_str
|
||||
|
||||
|
||||
def split_state_dict_with_prefix(state_dict):
|
||||
keys = sorted([key for key in state_dict if isinstance(key, str)])
|
||||
prefix_dict = {}
|
||||
for key in keys:
|
||||
prefix = key if "." not in key else key.split(".")[0]
|
||||
if prefix not in prefix_dict:
|
||||
prefix_dict[prefix] = []
|
||||
prefix_dict[prefix].append(key)
|
||||
state_dicts = []
|
||||
for prefix, keys in prefix_dict.items():
|
||||
sub_state_dict = {key: state_dict[key] for key in keys}
|
||||
state_dicts.append(sub_state_dict)
|
||||
return state_dicts
|
||||
|
||||
|
||||
def hash_state_dict_keys(state_dict, with_shape=True):
|
||||
keys_str = convert_state_dict_keys_to_single_str(state_dict, with_shape=with_shape)
|
||||
keys_str = keys_str.encode(encoding="UTF-8")
|
||||
return hashlib.md5(keys_str).hexdigest()
|
||||
from .utils import load_state_dict, init_weights_on_device, hash_state_dict_keys, split_state_dict_with_prefix
|
||||
|
||||
|
||||
def load_model_from_single_file(state_dict, model_names, model_classes, model_resource, torch_dtype, device):
|
||||
@@ -106,8 +68,10 @@ def load_model_from_single_file(state_dict, model_names, model_classes, model_re
|
||||
else:
|
||||
model_state_dict, extra_kwargs = state_dict_results, {}
|
||||
torch_dtype = torch.float32 if extra_kwargs.get("upcast_to_float32", False) else torch_dtype
|
||||
model = model_class(**extra_kwargs).to(dtype=torch_dtype, device=device)
|
||||
model.load_state_dict(model_state_dict)
|
||||
with init_weights_on_device():
|
||||
model= model_class(**extra_kwargs)
|
||||
model.load_state_dict(model_state_dict, assign=True)
|
||||
model = model.to(dtype=torch_dtype, device=device)
|
||||
loaded_model_names.append(model_name)
|
||||
loaded_models.append(model)
|
||||
return loaded_model_names, loaded_models
|
||||
@@ -116,7 +80,10 @@ def load_model_from_single_file(state_dict, model_names, model_classes, model_re
|
||||
def load_model_from_huggingface_folder(file_path, model_names, model_classes, torch_dtype, device):
|
||||
loaded_model_names, loaded_models = [], []
|
||||
for model_name, model_class in zip(model_names, model_classes):
|
||||
model = model_class.from_pretrained(file_path, torch_dtype=torch_dtype).eval()
|
||||
if torch_dtype in [torch.float32, torch.float16, torch.bfloat16]:
|
||||
model = model_class.from_pretrained(file_path, torch_dtype=torch_dtype).eval()
|
||||
else:
|
||||
model = model_class.from_pretrained(file_path).eval().to(dtype=torch_dtype)
|
||||
if torch_dtype == torch.float16 and hasattr(model, "half"):
|
||||
model = model.half()
|
||||
try:
|
||||
@@ -191,7 +158,7 @@ class ModelDetectorFromSingleFile:
|
||||
|
||||
|
||||
def match(self, file_path="", state_dict={}):
|
||||
if os.path.isdir(file_path):
|
||||
if isinstance(file_path, str) and os.path.isdir(file_path):
|
||||
return False
|
||||
if len(state_dict) == 0:
|
||||
state_dict = load_state_dict(file_path)
|
||||
@@ -233,7 +200,7 @@ class ModelDetectorFromSplitedSingleFile(ModelDetectorFromSingleFile):
|
||||
|
||||
|
||||
def match(self, file_path="", state_dict={}):
|
||||
if os.path.isdir(file_path):
|
||||
if isinstance(file_path, str) and os.path.isdir(file_path):
|
||||
return False
|
||||
if len(state_dict) == 0:
|
||||
state_dict = load_state_dict(file_path)
|
||||
@@ -276,7 +243,7 @@ class ModelDetectorFromHuggingfaceFolder:
|
||||
|
||||
|
||||
def match(self, file_path="", state_dict={}):
|
||||
if os.path.isfile(file_path):
|
||||
if not isinstance(file_path, str) or os.path.isfile(file_path):
|
||||
return False
|
||||
file_list = os.listdir(file_path)
|
||||
if "config.json" not in file_list:
|
||||
@@ -317,7 +284,7 @@ class ModelDetectorFromPatchedSingleFile:
|
||||
|
||||
|
||||
def match(self, file_path="", state_dict={}):
|
||||
if os.path.isdir(file_path):
|
||||
if not isinstance(file_path, str) or os.path.isdir(file_path):
|
||||
return False
|
||||
if len(state_dict) == 0:
|
||||
state_dict = load_state_dict(file_path)
|
||||
@@ -402,22 +369,32 @@ class ModelManager:
|
||||
|
||||
|
||||
def load_lora(self, file_path="", state_dict={}, lora_alpha=1.0):
|
||||
print(f"Loading LoRA models from file: {file_path}")
|
||||
if len(state_dict) == 0:
|
||||
state_dict = load_state_dict(file_path)
|
||||
for model_name, model, model_path in zip(self.model_name, self.model, self.model_path):
|
||||
for lora in get_lora_loaders():
|
||||
match_results = lora.match(model, state_dict)
|
||||
if match_results is not None:
|
||||
print(f" Adding LoRA to {model_name} ({model_path}).")
|
||||
lora_prefix, model_resource = match_results
|
||||
lora.load(model, state_dict, lora_prefix, alpha=lora_alpha, model_resource=model_resource)
|
||||
break
|
||||
if isinstance(file_path, list):
|
||||
for file_path_ in file_path:
|
||||
self.load_lora(file_path_, state_dict=state_dict, lora_alpha=lora_alpha)
|
||||
else:
|
||||
print(f"Loading LoRA models from file: {file_path}")
|
||||
if len(state_dict) == 0:
|
||||
state_dict = load_state_dict(file_path)
|
||||
for model_name, model, model_path in zip(self.model_name, self.model, self.model_path):
|
||||
for lora in get_lora_loaders():
|
||||
match_results = lora.match(model, state_dict)
|
||||
if match_results is not None:
|
||||
print(f" Adding LoRA to {model_name} ({model_path}).")
|
||||
lora_prefix, model_resource = match_results
|
||||
lora.load(model, state_dict, lora_prefix, alpha=lora_alpha, model_resource=model_resource)
|
||||
break
|
||||
|
||||
|
||||
def load_model(self, file_path, model_names=None):
|
||||
def load_model(self, file_path, model_names=None, device=None, torch_dtype=None):
|
||||
print(f"Loading models from: {file_path}")
|
||||
if os.path.isfile(file_path):
|
||||
if device is None: device = self.device
|
||||
if torch_dtype is None: torch_dtype = self.torch_dtype
|
||||
if isinstance(file_path, list):
|
||||
state_dict = {}
|
||||
for path in file_path:
|
||||
state_dict.update(load_state_dict(path))
|
||||
elif os.path.isfile(file_path):
|
||||
state_dict = load_state_dict(file_path)
|
||||
else:
|
||||
state_dict = None
|
||||
@@ -425,7 +402,7 @@ class ModelManager:
|
||||
if model_detector.match(file_path, state_dict):
|
||||
model_names, models = model_detector.load(
|
||||
file_path, state_dict,
|
||||
device=self.device, torch_dtype=self.torch_dtype,
|
||||
device=device, torch_dtype=torch_dtype,
|
||||
allowed_model_names=model_names, model_manager=self
|
||||
)
|
||||
for model_name, model in zip(model_names, models):
|
||||
@@ -438,9 +415,9 @@ class ModelManager:
|
||||
print(f" We cannot detect the model type. No models are loaded.")
|
||||
|
||||
|
||||
def load_models(self, file_path_list, model_names=None):
|
||||
def load_models(self, file_path_list, model_names=None, device=None, torch_dtype=None):
|
||||
for file_path in file_path_list:
|
||||
self.load_model(file_path, model_names)
|
||||
self.load_model(file_path, model_names, device=device, torch_dtype=torch_dtype)
|
||||
|
||||
|
||||
def fetch_model(self, model_name, file_path=None, require_model_path=False):
|
||||
|
||||
803
diffsynth/models/omnigen.py
Normal file
803
diffsynth/models/omnigen.py
Normal file
@@ -0,0 +1,803 @@
|
||||
# The code is revised from DiT
|
||||
import os
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import numpy as np
|
||||
import math
|
||||
from safetensors.torch import load_file
|
||||
from typing import List, Optional, Tuple, Union
|
||||
import torch.utils.checkpoint
|
||||
from huggingface_hub import snapshot_download
|
||||
from transformers.modeling_outputs import BaseModelOutputWithPast
|
||||
from transformers import Phi3Config, Phi3Model
|
||||
from transformers.cache_utils import Cache, DynamicCache
|
||||
from transformers.utils import logging
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class Phi3Transformer(Phi3Model):
|
||||
"""
|
||||
Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`Phi3DecoderLayer`]
|
||||
We only modified the attention mask
|
||||
Args:
|
||||
config: Phi3Config
|
||||
"""
|
||||
def prefetch_layer(self, layer_idx: int, device: torch.device):
|
||||
"Starts prefetching the next layer cache"
|
||||
with torch.cuda.stream(self.prefetch_stream):
|
||||
# Prefetch next layer tensors to GPU
|
||||
for name, param in self.layers[layer_idx].named_parameters():
|
||||
param.data = param.data.to(device, non_blocking=True)
|
||||
|
||||
def evict_previous_layer(self, layer_idx: int):
|
||||
"Moves the previous layer cache to the CPU"
|
||||
prev_layer_idx = layer_idx - 1
|
||||
for name, param in self.layers[prev_layer_idx].named_parameters():
|
||||
param.data = param.data.to("cpu", non_blocking=True)
|
||||
|
||||
def get_offlaod_layer(self, layer_idx: int, device: torch.device):
|
||||
# init stream
|
||||
if not hasattr(self, "prefetch_stream"):
|
||||
self.prefetch_stream = torch.cuda.Stream()
|
||||
|
||||
# delete previous layer
|
||||
torch.cuda.current_stream().synchronize()
|
||||
self.evict_previous_layer(layer_idx)
|
||||
|
||||
# make sure the current layer is ready
|
||||
torch.cuda.synchronize(self.prefetch_stream)
|
||||
|
||||
# load next layer
|
||||
self.prefetch_layer((layer_idx + 1) % len(self.layers), device)
|
||||
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
offload_model: Optional[bool] = False,
|
||||
) -> Union[Tuple, BaseModelOutputWithPast]:
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
if (input_ids is None) ^ (inputs_embeds is not None):
|
||||
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
if use_cache:
|
||||
logger.warning_once(
|
||||
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
||||
)
|
||||
use_cache = False
|
||||
|
||||
# kept for BC (non `Cache` `past_key_values` inputs)
|
||||
return_legacy_cache = False
|
||||
if use_cache and not isinstance(past_key_values, Cache):
|
||||
return_legacy_cache = True
|
||||
if past_key_values is None:
|
||||
past_key_values = DynamicCache()
|
||||
else:
|
||||
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
|
||||
logger.warning_once(
|
||||
"We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and "
|
||||
"will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class "
|
||||
"(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)"
|
||||
)
|
||||
|
||||
# if inputs_embeds is None:
|
||||
# inputs_embeds = self.embed_tokens(input_ids)
|
||||
|
||||
# if cache_position is None:
|
||||
# past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
|
||||
# cache_position = torch.arange(
|
||||
# past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
|
||||
# )
|
||||
# if position_ids is None:
|
||||
# position_ids = cache_position.unsqueeze(0)
|
||||
|
||||
if attention_mask is not None and attention_mask.dim() == 3:
|
||||
dtype = inputs_embeds.dtype
|
||||
min_dtype = torch.finfo(dtype).min
|
||||
attention_mask = (1 - attention_mask) * min_dtype
|
||||
attention_mask = attention_mask.unsqueeze(1).to(inputs_embeds.dtype)
|
||||
else:
|
||||
raise Exception("attention_mask parameter was unavailable or invalid")
|
||||
# causal_mask = self._update_causal_mask(
|
||||
# attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
|
||||
# )
|
||||
|
||||
hidden_states = inputs_embeds
|
||||
|
||||
# decoder layers
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
all_self_attns = () if output_attentions else None
|
||||
next_decoder_cache = None
|
||||
|
||||
layer_idx = -1
|
||||
for decoder_layer in self.layers:
|
||||
layer_idx += 1
|
||||
|
||||
if output_hidden_states:
|
||||
all_hidden_states += (hidden_states,)
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
layer_outputs = self._gradient_checkpointing_func(
|
||||
decoder_layer.__call__,
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
position_ids,
|
||||
past_key_values,
|
||||
output_attentions,
|
||||
use_cache,
|
||||
cache_position,
|
||||
)
|
||||
else:
|
||||
if offload_model and not self.training:
|
||||
self.get_offlaod_layer(layer_idx, device=inputs_embeds.device)
|
||||
layer_outputs = decoder_layer(
|
||||
hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_value=past_key_values,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
cache_position=cache_position,
|
||||
)
|
||||
|
||||
hidden_states = layer_outputs[0]
|
||||
|
||||
if use_cache:
|
||||
next_decoder_cache = layer_outputs[2 if output_attentions else 1]
|
||||
|
||||
if output_attentions:
|
||||
all_self_attns += (layer_outputs[1],)
|
||||
|
||||
hidden_states = self.norm(hidden_states)
|
||||
|
||||
# add hidden states from the last decoder layer
|
||||
if output_hidden_states:
|
||||
print('************')
|
||||
all_hidden_states += (hidden_states,)
|
||||
|
||||
next_cache = next_decoder_cache if use_cache else None
|
||||
if return_legacy_cache:
|
||||
next_cache = next_cache.to_legacy_cache()
|
||||
|
||||
if not return_dict:
|
||||
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
|
||||
return BaseModelOutputWithPast(
|
||||
last_hidden_state=hidden_states,
|
||||
past_key_values=next_cache,
|
||||
hidden_states=all_hidden_states,
|
||||
attentions=all_self_attns,
|
||||
)
|
||||
|
||||
|
||||
def modulate(x, shift, scale):
|
||||
return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
|
||||
|
||||
|
||||
class TimestepEmbedder(nn.Module):
|
||||
"""
|
||||
Embeds scalar timesteps into vector representations.
|
||||
"""
|
||||
def __init__(self, hidden_size, frequency_embedding_size=256):
|
||||
super().__init__()
|
||||
self.mlp = nn.Sequential(
|
||||
nn.Linear(frequency_embedding_size, hidden_size, bias=True),
|
||||
nn.SiLU(),
|
||||
nn.Linear(hidden_size, hidden_size, bias=True),
|
||||
)
|
||||
self.frequency_embedding_size = frequency_embedding_size
|
||||
|
||||
@staticmethod
|
||||
def timestep_embedding(t, dim, max_period=10000):
|
||||
"""
|
||||
Create sinusoidal timestep embeddings.
|
||||
:param t: a 1-D Tensor of N indices, one per batch element.
|
||||
These may be fractional.
|
||||
:param dim: the dimension of the output.
|
||||
:param max_period: controls the minimum frequency of the embeddings.
|
||||
:return: an (N, D) Tensor of positional embeddings.
|
||||
"""
|
||||
# https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
|
||||
half = dim // 2
|
||||
freqs = torch.exp(
|
||||
-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
|
||||
).to(device=t.device)
|
||||
args = t[:, None].float() * freqs[None]
|
||||
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
|
||||
if dim % 2:
|
||||
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
|
||||
return embedding
|
||||
|
||||
def forward(self, t, dtype=torch.float32):
|
||||
t_freq = self.timestep_embedding(t, self.frequency_embedding_size).to(dtype)
|
||||
t_emb = self.mlp(t_freq)
|
||||
return t_emb
|
||||
|
||||
|
||||
class FinalLayer(nn.Module):
|
||||
"""
|
||||
The final layer of DiT.
|
||||
"""
|
||||
def __init__(self, hidden_size, patch_size, out_channels):
|
||||
super().__init__()
|
||||
self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
||||
self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
|
||||
self.adaLN_modulation = nn.Sequential(
|
||||
nn.SiLU(),
|
||||
nn.Linear(hidden_size, 2 * hidden_size, bias=True)
|
||||
)
|
||||
|
||||
def forward(self, x, c):
|
||||
shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
|
||||
x = modulate(self.norm_final(x), shift, scale)
|
||||
x = self.linear(x)
|
||||
return x
|
||||
|
||||
|
||||
def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0, interpolation_scale=1.0, base_size=1):
|
||||
"""
|
||||
grid_size: int of the grid height and width return: pos_embed: [grid_size*grid_size, embed_dim] or
|
||||
[1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
|
||||
"""
|
||||
if isinstance(grid_size, int):
|
||||
grid_size = (grid_size, grid_size)
|
||||
|
||||
grid_h = np.arange(grid_size[0], dtype=np.float32) / (grid_size[0] / base_size) / interpolation_scale
|
||||
grid_w = np.arange(grid_size[1], dtype=np.float32) / (grid_size[1] / base_size) / interpolation_scale
|
||||
grid = np.meshgrid(grid_w, grid_h) # here w goes first
|
||||
grid = np.stack(grid, axis=0)
|
||||
|
||||
grid = grid.reshape([2, 1, grid_size[1], grid_size[0]])
|
||||
pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
|
||||
if cls_token and extra_tokens > 0:
|
||||
pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
|
||||
return pos_embed
|
||||
|
||||
|
||||
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
|
||||
assert embed_dim % 2 == 0
|
||||
|
||||
# use half of dimensions to encode grid_h
|
||||
emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
|
||||
emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
|
||||
|
||||
emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
|
||||
return emb
|
||||
|
||||
|
||||
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
|
||||
"""
|
||||
embed_dim: output dimension for each position
|
||||
pos: a list of positions to be encoded: size (M,)
|
||||
out: (M, D)
|
||||
"""
|
||||
assert embed_dim % 2 == 0
|
||||
omega = np.arange(embed_dim // 2, dtype=np.float64)
|
||||
omega /= embed_dim / 2.
|
||||
omega = 1. / 10000**omega # (D/2,)
|
||||
|
||||
pos = pos.reshape(-1) # (M,)
|
||||
out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product
|
||||
|
||||
emb_sin = np.sin(out) # (M, D/2)
|
||||
emb_cos = np.cos(out) # (M, D/2)
|
||||
|
||||
emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
|
||||
return emb
|
||||
|
||||
|
||||
class PatchEmbedMR(nn.Module):
|
||||
""" 2D Image to Patch Embedding
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
patch_size: int = 2,
|
||||
in_chans: int = 4,
|
||||
embed_dim: int = 768,
|
||||
bias: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.proj(x)
|
||||
x = x.flatten(2).transpose(1, 2) # NCHW -> NLC
|
||||
return x
|
||||
|
||||
|
||||
class OmniGenOriginalModel(nn.Module):
|
||||
"""
|
||||
Diffusion model with a Transformer backbone.
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
transformer_config: Phi3Config,
|
||||
patch_size=2,
|
||||
in_channels=4,
|
||||
pe_interpolation: float = 1.0,
|
||||
pos_embed_max_size: int = 192,
|
||||
):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = in_channels
|
||||
self.patch_size = patch_size
|
||||
self.pos_embed_max_size = pos_embed_max_size
|
||||
|
||||
hidden_size = transformer_config.hidden_size
|
||||
|
||||
self.x_embedder = PatchEmbedMR(patch_size, in_channels, hidden_size, bias=True)
|
||||
self.input_x_embedder = PatchEmbedMR(patch_size, in_channels, hidden_size, bias=True)
|
||||
|
||||
self.time_token = TimestepEmbedder(hidden_size)
|
||||
self.t_embedder = TimestepEmbedder(hidden_size)
|
||||
|
||||
self.pe_interpolation = pe_interpolation
|
||||
pos_embed = get_2d_sincos_pos_embed(hidden_size, pos_embed_max_size, interpolation_scale=self.pe_interpolation, base_size=64)
|
||||
self.register_buffer("pos_embed", torch.from_numpy(pos_embed).float().unsqueeze(0), persistent=True)
|
||||
|
||||
self.final_layer = FinalLayer(hidden_size, patch_size, self.out_channels)
|
||||
|
||||
self.initialize_weights()
|
||||
|
||||
self.llm = Phi3Transformer(config=transformer_config)
|
||||
self.llm.config.use_cache = False
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, model_name):
|
||||
if not os.path.exists(model_name):
|
||||
cache_folder = os.getenv('HF_HUB_CACHE')
|
||||
model_name = snapshot_download(repo_id=model_name,
|
||||
cache_dir=cache_folder,
|
||||
ignore_patterns=['flax_model.msgpack', 'rust_model.ot', 'tf_model.h5'])
|
||||
config = Phi3Config.from_pretrained(model_name)
|
||||
model = cls(config)
|
||||
if os.path.exists(os.path.join(model_name, 'model.safetensors')):
|
||||
print("Loading safetensors")
|
||||
ckpt = load_file(os.path.join(model_name, 'model.safetensors'))
|
||||
else:
|
||||
ckpt = torch.load(os.path.join(model_name, 'model.pt'), map_location='cpu')
|
||||
model.load_state_dict(ckpt)
|
||||
return model
|
||||
|
||||
def initialize_weights(self):
|
||||
assert not hasattr(self, "llama")
|
||||
|
||||
# Initialize transformer layers:
|
||||
def _basic_init(module):
|
||||
if isinstance(module, nn.Linear):
|
||||
torch.nn.init.xavier_uniform_(module.weight)
|
||||
if module.bias is not None:
|
||||
nn.init.constant_(module.bias, 0)
|
||||
self.apply(_basic_init)
|
||||
|
||||
# Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
|
||||
w = self.x_embedder.proj.weight.data
|
||||
nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
|
||||
nn.init.constant_(self.x_embedder.proj.bias, 0)
|
||||
|
||||
        w = self.input_x_embedder.proj.weight.data
        nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
        nn.init.constant_(self.input_x_embedder.proj.bias, 0)
|
||||
|
||||
|
||||
# Initialize timestep embedding MLP:
|
||||
nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
|
||||
nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
|
||||
nn.init.normal_(self.time_token.mlp[0].weight, std=0.02)
|
||||
nn.init.normal_(self.time_token.mlp[2].weight, std=0.02)
|
||||
|
||||
# Zero-out output layers:
|
||||
nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
|
||||
nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
|
||||
nn.init.constant_(self.final_layer.linear.weight, 0)
|
||||
nn.init.constant_(self.final_layer.linear.bias, 0)
|
||||
|
||||
def unpatchify(self, x, h, w):
|
||||
"""
|
||||
x: (N, T, patch_size**2 * C)
|
||||
imgs: (N, H, W, C)
|
||||
"""
|
||||
c = self.out_channels
|
||||
|
||||
x = x.reshape(shape=(x.shape[0], h//self.patch_size, w//self.patch_size, self.patch_size, self.patch_size, c))
|
||||
x = torch.einsum('nhwpqc->nchpwq', x)
|
||||
imgs = x.reshape(shape=(x.shape[0], c, h, w))
|
||||
return imgs
|
||||
|
||||
|
||||
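    # Note (added for clarity): cropped_pos_embed below takes a centred
    # (height/patch_size) x (width/patch_size) window out of the fixed
    # pos_embed_max_size x pos_embed_max_size sin-cos grid registered in __init__.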
def cropped_pos_embed(self, height, width):
|
||||
"""Crops positional embeddings for SD3 compatibility."""
|
||||
if self.pos_embed_max_size is None:
|
||||
raise ValueError("`pos_embed_max_size` must be set for cropping.")
|
||||
|
||||
height = height // self.patch_size
|
||||
width = width // self.patch_size
|
||||
if height > self.pos_embed_max_size:
|
||||
raise ValueError(
|
||||
f"Height ({height}) cannot be greater than `pos_embed_max_size`: {self.pos_embed_max_size}."
|
||||
)
|
||||
if width > self.pos_embed_max_size:
|
||||
raise ValueError(
|
||||
f"Width ({width}) cannot be greater than `pos_embed_max_size`: {self.pos_embed_max_size}."
|
||||
)
|
||||
|
||||
top = (self.pos_embed_max_size - height) // 2
|
||||
left = (self.pos_embed_max_size - width) // 2
|
||||
spatial_pos_embed = self.pos_embed.reshape(1, self.pos_embed_max_size, self.pos_embed_max_size, -1)
|
||||
spatial_pos_embed = spatial_pos_embed[:, top : top + height, left : left + width, :]
|
||||
# print(top, top + height, left, left + width, spatial_pos_embed.size())
|
||||
spatial_pos_embed = spatial_pos_embed.reshape(1, -1, spatial_pos_embed.shape[-1])
|
||||
return spatial_pos_embed
|
||||
|
||||
|
||||
def patch_multiple_resolutions(self, latents, padding_latent=None, is_input_images:bool=False):
|
||||
if isinstance(latents, list):
|
||||
return_list = False
|
||||
if padding_latent is None:
|
||||
padding_latent = [None] * len(latents)
|
||||
return_list = True
|
||||
patched_latents, num_tokens, shapes = [], [], []
|
||||
for latent, padding in zip(latents, padding_latent):
|
||||
height, width = latent.shape[-2:]
|
||||
if is_input_images:
|
||||
latent = self.input_x_embedder(latent)
|
||||
else:
|
||||
latent = self.x_embedder(latent)
|
||||
pos_embed = self.cropped_pos_embed(height, width)
|
||||
latent = latent + pos_embed
|
||||
if padding is not None:
|
||||
latent = torch.cat([latent, padding], dim=-2)
|
||||
patched_latents.append(latent)
|
||||
|
||||
num_tokens.append(pos_embed.size(1))
|
||||
shapes.append([height, width])
|
||||
if not return_list:
|
||||
latents = torch.cat(patched_latents, dim=0)
|
||||
else:
|
||||
latents = patched_latents
|
||||
else:
|
||||
height, width = latents.shape[-2:]
|
||||
if is_input_images:
|
||||
latents = self.input_x_embedder(latents)
|
||||
else:
|
||||
latents = self.x_embedder(latents)
|
||||
pos_embed = self.cropped_pos_embed(height, width)
|
||||
latents = latents + pos_embed
|
||||
num_tokens = latents.size(1)
|
||||
shapes = [height, width]
|
||||
return latents, num_tokens, shapes
|
||||
|
||||
|
||||
def forward(self, x, timestep, input_ids, input_img_latents, input_image_sizes, attention_mask, position_ids, padding_latent=None, past_key_values=None, return_past_key_values=True, offload_model:bool=False):
|
||||
"""
|
||||
|
||||
"""
|
||||
input_is_list = isinstance(x, list)
|
||||
x, num_tokens, shapes = self.patch_multiple_resolutions(x, padding_latent)
|
||||
time_token = self.time_token(timestep, dtype=x[0].dtype).unsqueeze(1)
|
||||
|
||||
if input_img_latents is not None:
|
||||
input_latents, _, _ = self.patch_multiple_resolutions(input_img_latents, is_input_images=True)
|
||||
if input_ids is not None:
|
||||
condition_embeds = self.llm.embed_tokens(input_ids).clone()
|
||||
input_img_inx = 0
|
||||
for b_inx in input_image_sizes.keys():
|
||||
for start_inx, end_inx in input_image_sizes[b_inx]:
|
||||
condition_embeds[b_inx, start_inx: end_inx] = input_latents[input_img_inx]
|
||||
input_img_inx += 1
|
||||
if input_img_latents is not None:
|
||||
assert input_img_inx == len(input_latents)
|
||||
|
||||
input_emb = torch.cat([condition_embeds, time_token, x], dim=1)
|
||||
else:
|
||||
input_emb = torch.cat([time_token, x], dim=1)
|
||||
output = self.llm(inputs_embeds=input_emb, attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, offload_model=offload_model)
|
||||
output, past_key_values = output.last_hidden_state, output.past_key_values
|
||||
if input_is_list:
|
||||
image_embedding = output[:, -max(num_tokens):]
|
||||
time_emb = self.t_embedder(timestep, dtype=x.dtype)
|
||||
x = self.final_layer(image_embedding, time_emb)
|
||||
latents = []
|
||||
for i in range(x.size(0)):
|
||||
latent = x[i:i+1, :num_tokens[i]]
|
||||
latent = self.unpatchify(latent, shapes[i][0], shapes[i][1])
|
||||
latents.append(latent)
|
||||
else:
|
||||
image_embedding = output[:, -num_tokens:]
|
||||
time_emb = self.t_embedder(timestep, dtype=x.dtype)
|
||||
x = self.final_layer(image_embedding, time_emb)
|
||||
latents = self.unpatchify(x, shapes[0], shapes[1])
|
||||
|
||||
if return_past_key_values:
|
||||
return latents, past_key_values
|
||||
return latents
|
||||
|
||||
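    # Note (added for clarity): forward_with_cfg below runs the text-conditional,
    # unconditional and (optionally) image-conditional branches in one batch and
    # recombines them as
    #     out = uncond + img_cfg_scale * (img_cond - uncond) + cfg_scale * (cond - img_cond)
    # so cfg_scale controls the text guidance and img_cfg_scale the image guidance.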
@torch.no_grad()
|
||||
def forward_with_cfg(self, x, timestep, input_ids, input_img_latents, input_image_sizes, attention_mask, position_ids, cfg_scale, use_img_cfg, img_cfg_scale, past_key_values, use_kv_cache, offload_model):
|
||||
self.llm.config.use_cache = use_kv_cache
|
||||
model_out, past_key_values = self.forward(x, timestep, input_ids, input_img_latents, input_image_sizes, attention_mask, position_ids, past_key_values=past_key_values, return_past_key_values=True, offload_model=offload_model)
|
||||
if use_img_cfg:
|
||||
cond, uncond, img_cond = torch.split(model_out, len(model_out) // 3, dim=0)
|
||||
cond = uncond + img_cfg_scale * (img_cond - uncond) + cfg_scale * (cond - img_cond)
|
||||
model_out = [cond, cond, cond]
|
||||
else:
|
||||
cond, uncond = torch.split(model_out, len(model_out) // 2, dim=0)
|
||||
cond = uncond + cfg_scale * (cond - uncond)
|
||||
model_out = [cond, cond]
|
||||
|
||||
return torch.cat(model_out, dim=0), past_key_values
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def forward_with_separate_cfg(self, x, timestep, input_ids, input_img_latents, input_image_sizes, attention_mask, position_ids, cfg_scale, use_img_cfg, img_cfg_scale, past_key_values, use_kv_cache, offload_model):
|
||||
self.llm.config.use_cache = use_kv_cache
|
||||
if past_key_values is None:
|
||||
past_key_values = [None] * len(attention_mask)
|
||||
|
||||
x = torch.split(x, len(x) // len(attention_mask), dim=0)
|
||||
timestep = timestep.to(x[0].dtype)
|
||||
timestep = torch.split(timestep, len(timestep) // len(input_ids), dim=0)
|
||||
|
||||
        model_out, new_past_key_values = [], []
        for i in range(len(input_ids)):
            temp_out, temp_past_key_values = self.forward(x[i], timestep[i], input_ids[i], input_img_latents[i], input_image_sizes[i], attention_mask[i], position_ids[i], past_key_values=past_key_values[i], return_past_key_values=True, offload_model=offload_model)
            model_out.append(temp_out)
            new_past_key_values.append(temp_past_key_values)

        if len(model_out) == 3:
            cond, uncond, img_cond = model_out
            cond = uncond + img_cfg_scale * (img_cond - uncond) + cfg_scale * (cond - img_cond)
            model_out = [cond, cond, cond]
        elif len(model_out) == 2:
            cond, uncond = model_out
            cond = uncond + cfg_scale * (cond - uncond)
            model_out = [cond, cond]
        else:
            return model_out[0]

        return torch.cat(model_out, dim=0), new_past_key_values
|
||||
|
||||
|
||||
|
||||
class OmniGenTransformer(OmniGenOriginalModel):
|
||||
def __init__(self):
|
||||
config = {
|
||||
"_name_or_path": "Phi-3-vision-128k-instruct",
|
||||
"architectures": [
|
||||
"Phi3ForCausalLM"
|
||||
],
|
||||
"attention_dropout": 0.0,
|
||||
"bos_token_id": 1,
|
||||
"eos_token_id": 2,
|
||||
"hidden_act": "silu",
|
||||
"hidden_size": 3072,
|
||||
"initializer_range": 0.02,
|
||||
"intermediate_size": 8192,
|
||||
"max_position_embeddings": 131072,
|
||||
"model_type": "phi3",
|
||||
"num_attention_heads": 32,
|
||||
"num_hidden_layers": 32,
|
||||
"num_key_value_heads": 32,
|
||||
"original_max_position_embeddings": 4096,
|
||||
"rms_norm_eps": 1e-05,
|
||||
"rope_scaling": {
|
||||
"long_factor": [
|
||||
1.0299999713897705,
|
||||
1.0499999523162842,
|
||||
1.0499999523162842,
|
||||
1.0799999237060547,
|
||||
1.2299998998641968,
|
||||
1.2299998998641968,
|
||||
1.2999999523162842,
|
||||
1.4499999284744263,
|
||||
1.5999999046325684,
|
||||
1.6499998569488525,
|
||||
1.8999998569488525,
|
||||
2.859999895095825,
|
||||
3.68999981880188,
|
||||
5.419999599456787,
|
||||
5.489999771118164,
|
||||
5.489999771118164,
|
||||
9.09000015258789,
|
||||
11.579999923706055,
|
||||
15.65999984741211,
|
||||
15.769999504089355,
|
||||
15.789999961853027,
|
||||
18.360000610351562,
|
||||
21.989999771118164,
|
||||
23.079999923706055,
|
||||
30.009998321533203,
|
||||
32.35000228881836,
|
||||
32.590003967285156,
|
||||
35.56000518798828,
|
||||
39.95000457763672,
|
||||
53.840003967285156,
|
||||
56.20000457763672,
|
||||
57.95000457763672,
|
||||
59.29000473022461,
|
||||
59.77000427246094,
|
||||
59.920005798339844,
|
||||
61.190006256103516,
|
||||
61.96000671386719,
|
||||
62.50000762939453,
|
||||
63.3700065612793,
|
||||
63.48000717163086,
|
||||
63.48000717163086,
|
||||
63.66000747680664,
|
||||
63.850006103515625,
|
||||
64.08000946044922,
|
||||
64.760009765625,
|
||||
64.80001068115234,
|
||||
64.81001281738281,
|
||||
64.81001281738281
|
||||
],
|
||||
"short_factor": [
|
||||
1.05,
|
||||
1.05,
|
||||
1.05,
|
||||
1.1,
|
||||
1.1,
|
||||
1.1,
|
||||
1.2500000000000002,
|
||||
1.2500000000000002,
|
||||
1.4000000000000004,
|
||||
1.4500000000000004,
|
||||
1.5500000000000005,
|
||||
1.8500000000000008,
|
||||
1.9000000000000008,
|
||||
2.000000000000001,
|
||||
2.000000000000001,
|
||||
2.000000000000001,
|
||||
2.000000000000001,
|
||||
2.000000000000001,
|
||||
2.000000000000001,
|
||||
2.000000000000001,
|
||||
2.000000000000001,
|
||||
2.000000000000001,
|
||||
2.000000000000001,
|
||||
2.000000000000001,
|
||||
2.000000000000001,
|
||||
2.000000000000001,
|
||||
2.000000000000001,
|
||||
2.000000000000001,
|
||||
2.000000000000001,
|
||||
2.000000000000001,
|
||||
2.000000000000001,
|
||||
2.000000000000001,
|
||||
2.1000000000000005,
|
||||
2.1000000000000005,
|
||||
2.2,
|
||||
2.3499999999999996,
|
||||
2.3499999999999996,
|
||||
2.3499999999999996,
|
||||
2.3499999999999996,
|
||||
2.3999999999999995,
|
||||
2.3999999999999995,
|
||||
2.6499999999999986,
|
||||
2.6999999999999984,
|
||||
2.8999999999999977,
|
||||
2.9499999999999975,
|
||||
3.049999999999997,
|
||||
3.049999999999997,
|
||||
3.049999999999997
|
||||
],
|
||||
"type": "su"
|
||||
},
|
||||
"rope_theta": 10000.0,
|
||||
"sliding_window": 131072,
|
||||
"tie_word_embeddings": False,
|
||||
"torch_dtype": "bfloat16",
|
||||
"transformers_version": "4.38.1",
|
||||
"use_cache": True,
|
||||
"vocab_size": 32064,
|
||||
"_attn_implementation": "sdpa"
|
||||
}
|
||||
config = Phi3Config(**config)
|
||||
super().__init__(config)
|
||||
|
||||
|
||||
def forward(self, x, timestep, input_ids, input_img_latents, input_image_sizes, attention_mask, position_ids, padding_latent=None, past_key_values=None, return_past_key_values=True, offload_model:bool=False):
|
||||
input_is_list = isinstance(x, list)
|
||||
x, num_tokens, shapes = self.patch_multiple_resolutions(x, padding_latent)
|
||||
time_token = self.time_token(timestep, dtype=x[0].dtype).unsqueeze(1)
|
||||
|
||||
if input_img_latents is not None:
|
||||
input_latents, _, _ = self.patch_multiple_resolutions(input_img_latents, is_input_images=True)
|
||||
if input_ids is not None:
|
||||
condition_embeds = self.llm.embed_tokens(input_ids).clone()
|
||||
input_img_inx = 0
|
||||
for b_inx in input_image_sizes.keys():
|
||||
for start_inx, end_inx in input_image_sizes[b_inx]:
|
||||
condition_embeds[b_inx, start_inx: end_inx] = input_latents[input_img_inx]
|
||||
input_img_inx += 1
|
||||
if input_img_latents is not None:
|
||||
assert input_img_inx == len(input_latents)
|
||||
|
||||
input_emb = torch.cat([condition_embeds, time_token, x], dim=1)
|
||||
else:
|
||||
input_emb = torch.cat([time_token, x], dim=1)
|
||||
output = self.llm(inputs_embeds=input_emb, attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, offload_model=offload_model)
|
||||
output, past_key_values = output.last_hidden_state, output.past_key_values
|
||||
if input_is_list:
|
||||
image_embedding = output[:, -max(num_tokens):]
|
||||
time_emb = self.t_embedder(timestep, dtype=x.dtype)
|
||||
x = self.final_layer(image_embedding, time_emb)
|
||||
latents = []
|
||||
for i in range(x.size(0)):
|
||||
latent = x[i:i+1, :num_tokens[i]]
|
||||
latent = self.unpatchify(latent, shapes[i][0], shapes[i][1])
|
||||
latents.append(latent)
|
||||
else:
|
||||
image_embedding = output[:, -num_tokens:]
|
||||
time_emb = self.t_embedder(timestep, dtype=x.dtype)
|
||||
x = self.final_layer(image_embedding, time_emb)
|
||||
latents = self.unpatchify(x, shapes[0], shapes[1])
|
||||
|
||||
if return_past_key_values:
|
||||
return latents, past_key_values
|
||||
return latents
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def forward_with_separate_cfg(self, x, timestep, input_ids, input_img_latents, input_image_sizes, attention_mask, position_ids, cfg_scale, use_img_cfg, img_cfg_scale, past_key_values, use_kv_cache, offload_model):
|
||||
self.llm.config.use_cache = use_kv_cache
|
||||
if past_key_values is None:
|
||||
past_key_values = [None] * len(attention_mask)
|
||||
|
||||
x = torch.split(x, len(x) // len(attention_mask), dim=0)
|
||||
timestep = timestep.to(x[0].dtype)
|
||||
timestep = torch.split(timestep, len(timestep) // len(input_ids), dim=0)
|
||||
|
||||
        model_out, new_past_key_values = [], []
        for i in range(len(input_ids)):
            temp_out, temp_past_key_values = self.forward(x[i], timestep[i], input_ids[i], input_img_latents[i], input_image_sizes[i], attention_mask[i], position_ids[i], past_key_values=past_key_values[i], return_past_key_values=True, offload_model=offload_model)
            model_out.append(temp_out)
            new_past_key_values.append(temp_past_key_values)

        if len(model_out) == 3:
            cond, uncond, img_cond = model_out
            cond = uncond + img_cfg_scale * (img_cond - uncond) + cfg_scale * (cond - img_cond)
            model_out = [cond, cond, cond]
        elif len(model_out) == 2:
            cond, uncond = model_out
            cond = uncond + cfg_scale * (cond - uncond)
            model_out = [cond, cond]
        else:
            return model_out[0]

        return torch.cat(model_out, dim=0), new_past_key_values
|
||||
|
||||
|
||||
@staticmethod
|
||||
def state_dict_converter():
|
||||
return OmniGenTransformerStateDictConverter()
|
||||
|
||||
|
||||
|
||||
class OmniGenTransformerStateDictConverter:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def from_diffusers(self, state_dict):
|
||||
return state_dict
|
||||
|
||||
def from_civitai(self, state_dict):
|
||||
return state_dict
|
||||
@@ -5,6 +5,26 @@ from .tiler import TileWorker
|
||||
|
||||
|
||||
|
||||
class RMSNorm(torch.nn.Module):
    def __init__(self, dim, eps, elementwise_affine=True):
        super().__init__()
        self.eps = eps
        if elementwise_affine:
            self.weight = torch.nn.Parameter(torch.ones((dim,)))
        else:
            self.weight = None

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        variance = hidden_states.to(torch.float32).square().mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
        hidden_states = hidden_states.to(input_dtype)
        if self.weight is not None:
            hidden_states = hidden_states * self.weight
        return hidden_states


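# Note (added for clarity): RMSNorm above rescales by the root mean square of the
# last dimension, y = x / sqrt(mean(x^2) + eps) * weight, computing the statistic
# in float32 for stability before casting back to the input dtype.

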
class PatchEmbed(torch.nn.Module):
|
||||
def __init__(self, patch_size=2, in_channels=16, embed_dim=1536, pos_embed_max_size=192):
|
||||
super().__init__()
|
||||
@@ -12,7 +32,7 @@ class PatchEmbed(torch.nn.Module):
|
||||
self.patch_size = patch_size
|
||||
|
||||
self.proj = torch.nn.Conv2d(in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size)
|
||||
self.pos_embed = torch.nn.Parameter(torch.zeros(1, self.pos_embed_max_size, self.pos_embed_max_size, 1536))
|
||||
self.pos_embed = torch.nn.Parameter(torch.zeros(1, self.pos_embed_max_size, self.pos_embed_max_size, embed_dim))
|
||||
|
||||
def cropped_pos_embed(self, height, width):
|
||||
height = height // self.patch_size
|
||||
@@ -32,9 +52,9 @@ class PatchEmbed(torch.nn.Module):
|
||||
|
||||
|
||||
class TimestepEmbeddings(torch.nn.Module):
|
||||
def __init__(self, dim_in, dim_out):
|
||||
def __init__(self, dim_in, dim_out, computation_device=None):
|
||||
super().__init__()
|
||||
self.time_proj = TemporalTimesteps(num_channels=dim_in, flip_sin_to_cos=True, downscale_freq_shift=0)
|
||||
self.time_proj = TemporalTimesteps(num_channels=dim_in, flip_sin_to_cos=True, downscale_freq_shift=0, computation_device=computation_device)
|
||||
self.timestep_embedder = torch.nn.Sequential(
|
||||
torch.nn.Linear(dim_in, dim_out), torch.nn.SiLU(), torch.nn.Linear(dim_out, dim_out)
|
||||
)
|
||||
@@ -47,10 +67,11 @@ class TimestepEmbeddings(torch.nn.Module):
|
||||
|
||||
|
||||
class AdaLayerNorm(torch.nn.Module):
|
||||
def __init__(self, dim, single=False):
|
||||
def __init__(self, dim, single=False, dual=False):
|
||||
super().__init__()
|
||||
self.single = single
|
||||
self.linear = torch.nn.Linear(dim, dim * (2 if single else 6))
|
||||
self.dual = dual
|
||||
self.linear = torch.nn.Linear(dim, dim * [[6, 2][single], 9][dual])
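        # The factor above resolves to 6 modulation tensors by default, 2 when
        # single=True (final AdaLN) and 9 when dual=True (extra attention branch).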
|
||||
self.norm = torch.nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
|
||||
|
||||
def forward(self, x, emb):
|
||||
@@ -59,6 +80,12 @@ class AdaLayerNorm(torch.nn.Module):
|
||||
scale, shift = emb.unsqueeze(1).chunk(2, dim=2)
|
||||
x = self.norm(x) * (1 + scale) + shift
|
||||
return x
|
||||
elif self.dual:
|
||||
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp, shift_msa2, scale_msa2, gate_msa2 = emb.unsqueeze(1).chunk(9, dim=2)
|
||||
norm_x = self.norm(x)
|
||||
x = norm_x * (1 + scale_msa) + shift_msa
|
||||
norm_x2 = norm_x * (1 + scale_msa2) + shift_msa2
|
||||
return x, gate_msa, shift_mlp, scale_mlp, gate_mlp, norm_x2, gate_msa2
|
||||
else:
|
||||
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.unsqueeze(1).chunk(6, dim=2)
|
||||
x = self.norm(x) * (1 + scale_msa) + shift_msa
|
||||
@@ -67,7 +94,7 @@ class AdaLayerNorm(torch.nn.Module):
|
||||
|
||||
|
||||
class JointAttention(torch.nn.Module):
|
||||
def __init__(self, dim_a, dim_b, num_heads, head_dim, only_out_a=False):
|
||||
def __init__(self, dim_a, dim_b, num_heads, head_dim, only_out_a=False, use_rms_norm=False):
|
||||
super().__init__()
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = head_dim
|
||||
@@ -80,12 +107,38 @@ class JointAttention(torch.nn.Module):
|
||||
if not only_out_a:
|
||||
self.b_to_out = torch.nn.Linear(dim_b, dim_b)
|
||||
|
||||
if use_rms_norm:
|
||||
self.norm_q_a = RMSNorm(head_dim, eps=1e-6)
|
||||
self.norm_k_a = RMSNorm(head_dim, eps=1e-6)
|
||||
self.norm_q_b = RMSNorm(head_dim, eps=1e-6)
|
||||
self.norm_k_b = RMSNorm(head_dim, eps=1e-6)
|
||||
else:
|
||||
self.norm_q_a = None
|
||||
self.norm_k_a = None
|
||||
self.norm_q_b = None
|
||||
self.norm_k_b = None
|
||||
|
||||
|
||||
def process_qkv(self, hidden_states, to_qkv, norm_q, norm_k):
|
||||
batch_size = hidden_states.shape[0]
|
||||
qkv = to_qkv(hidden_states)
|
||||
qkv = qkv.view(batch_size, -1, 3 * self.num_heads, self.head_dim).transpose(1, 2)
|
||||
q, k, v = qkv.chunk(3, dim=1)
|
||||
if norm_q is not None:
|
||||
q = norm_q(q)
|
||||
if norm_k is not None:
|
||||
k = norm_k(k)
|
||||
return q, k, v
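    # Shape note (added for clarity): process_qkv turns a (B, L, 3*num_heads*head_dim)
    # projection into three (B, num_heads, L, head_dim) tensors, applying the optional
    # per-head RMSNorm to q and k before attention.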
|
||||
|
||||
|
||||
def forward(self, hidden_states_a, hidden_states_b):
|
||||
batch_size = hidden_states_a.shape[0]
|
||||
|
||||
qkv = torch.concat([self.a_to_qkv(hidden_states_a), self.b_to_qkv(hidden_states_b)], dim=1)
|
||||
qkv = qkv.view(batch_size, -1, 3 * self.num_heads, self.head_dim).transpose(1, 2)
|
||||
q, k, v = qkv.chunk(3, dim=1)
|
||||
qa, ka, va = self.process_qkv(hidden_states_a, self.a_to_qkv, self.norm_q_a, self.norm_k_a)
|
||||
qb, kb, vb = self.process_qkv(hidden_states_b, self.b_to_qkv, self.norm_q_b, self.norm_k_b)
|
||||
q = torch.concat([qa, qb], dim=2)
|
||||
k = torch.concat([ka, kb], dim=2)
|
||||
v = torch.concat([va, vb], dim=2)
|
||||
|
||||
hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v)
|
||||
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim)
|
||||
@@ -97,16 +150,58 @@ class JointAttention(torch.nn.Module):
|
||||
else:
|
||||
hidden_states_b = self.b_to_out(hidden_states_b)
|
||||
return hidden_states_a, hidden_states_b
|
||||
|
||||
|
||||
|
||||
|
||||
class JointTransformerBlock(torch.nn.Module):
|
||||
def __init__(self, dim, num_attention_heads):
|
||||
class SingleAttention(torch.nn.Module):
|
||||
def __init__(self, dim_a, num_heads, head_dim, use_rms_norm=False):
|
||||
super().__init__()
|
||||
self.norm1_a = AdaLayerNorm(dim)
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = head_dim
|
||||
|
||||
self.a_to_qkv = torch.nn.Linear(dim_a, dim_a * 3)
|
||||
self.a_to_out = torch.nn.Linear(dim_a, dim_a)
|
||||
|
||||
if use_rms_norm:
|
||||
self.norm_q_a = RMSNorm(head_dim, eps=1e-6)
|
||||
self.norm_k_a = RMSNorm(head_dim, eps=1e-6)
|
||||
else:
|
||||
self.norm_q_a = None
|
||||
self.norm_k_a = None
|
||||
|
||||
|
||||
def process_qkv(self, hidden_states, to_qkv, norm_q, norm_k):
|
||||
batch_size = hidden_states.shape[0]
|
||||
qkv = to_qkv(hidden_states)
|
||||
qkv = qkv.view(batch_size, -1, 3 * self.num_heads, self.head_dim).transpose(1, 2)
|
||||
q, k, v = qkv.chunk(3, dim=1)
|
||||
if norm_q is not None:
|
||||
q = norm_q(q)
|
||||
if norm_k is not None:
|
||||
k = norm_k(k)
|
||||
return q, k, v
|
||||
|
||||
|
||||
def forward(self, hidden_states_a):
|
||||
batch_size = hidden_states_a.shape[0]
|
||||
q, k, v = self.process_qkv(hidden_states_a, self.a_to_qkv, self.norm_q_a, self.norm_k_a)
|
||||
|
||||
hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v)
|
||||
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim)
|
||||
hidden_states = hidden_states.to(q.dtype)
|
||||
hidden_states = self.a_to_out(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
|
||||
class DualTransformerBlock(torch.nn.Module):
|
||||
def __init__(self, dim, num_attention_heads, use_rms_norm=False):
|
||||
super().__init__()
|
||||
self.norm1_a = AdaLayerNorm(dim, dual=True)
|
||||
self.norm1_b = AdaLayerNorm(dim)
|
||||
|
||||
self.attn = JointAttention(dim, dim, num_attention_heads, dim // num_attention_heads)
|
||||
self.attn = JointAttention(dim, dim, num_attention_heads, dim // num_attention_heads, use_rms_norm=use_rms_norm)
|
||||
self.attn2 = JointAttention(dim, dim, num_attention_heads, dim // num_attention_heads, use_rms_norm=use_rms_norm)
|
||||
|
||||
self.norm2_a = torch.nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
|
||||
self.ff_a = torch.nn.Sequential(
|
||||
@@ -124,7 +219,7 @@ class JointTransformerBlock(torch.nn.Module):
|
||||
|
||||
|
||||
def forward(self, hidden_states_a, hidden_states_b, temb):
|
||||
norm_hidden_states_a, gate_msa_a, shift_mlp_a, scale_mlp_a, gate_mlp_a = self.norm1_a(hidden_states_a, emb=temb)
|
||||
norm_hidden_states_a, gate_msa_a, shift_mlp_a, scale_mlp_a, gate_mlp_a, norm_hidden_states_a_2, gate_msa_a_2 = self.norm1_a(hidden_states_a, emb=temb)
|
||||
norm_hidden_states_b, gate_msa_b, shift_mlp_b, scale_mlp_b, gate_mlp_b = self.norm1_b(hidden_states_b, emb=temb)
|
||||
|
||||
# Attention
|
||||
@@ -132,6 +227,58 @@ class JointTransformerBlock(torch.nn.Module):
|
||||
|
||||
# Part A
|
||||
hidden_states_a = hidden_states_a + gate_msa_a * attn_output_a
|
||||
hidden_states_a = hidden_states_a + gate_msa_a_2 * self.attn2(norm_hidden_states_a_2)
|
||||
norm_hidden_states_a = self.norm2_a(hidden_states_a) * (1 + scale_mlp_a) + shift_mlp_a
|
||||
hidden_states_a = hidden_states_a + gate_mlp_a * self.ff_a(norm_hidden_states_a)
|
||||
|
||||
# Part B
|
||||
hidden_states_b = hidden_states_b + gate_msa_b * attn_output_b
|
||||
norm_hidden_states_b = self.norm2_b(hidden_states_b) * (1 + scale_mlp_b) + shift_mlp_b
|
||||
hidden_states_b = hidden_states_b + gate_mlp_b * self.ff_b(norm_hidden_states_b)
|
||||
|
||||
return hidden_states_a, hidden_states_b
|
||||
|
||||
|
||||
|
||||
class JointTransformerBlock(torch.nn.Module):
|
||||
def __init__(self, dim, num_attention_heads, use_rms_norm=False, dual=False):
|
||||
super().__init__()
|
||||
self.norm1_a = AdaLayerNorm(dim, dual=dual)
|
||||
self.norm1_b = AdaLayerNorm(dim)
|
||||
|
||||
self.attn = JointAttention(dim, dim, num_attention_heads, dim // num_attention_heads, use_rms_norm=use_rms_norm)
|
||||
if dual:
|
||||
self.attn2 = SingleAttention(dim, num_attention_heads, dim // num_attention_heads, use_rms_norm=use_rms_norm)
|
||||
|
||||
self.norm2_a = torch.nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
|
||||
self.ff_a = torch.nn.Sequential(
|
||||
torch.nn.Linear(dim, dim*4),
|
||||
torch.nn.GELU(approximate="tanh"),
|
||||
torch.nn.Linear(dim*4, dim)
|
||||
)
|
||||
|
||||
self.norm2_b = torch.nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
|
||||
self.ff_b = torch.nn.Sequential(
|
||||
torch.nn.Linear(dim, dim*4),
|
||||
torch.nn.GELU(approximate="tanh"),
|
||||
torch.nn.Linear(dim*4, dim)
|
||||
)
|
||||
|
||||
|
||||
def forward(self, hidden_states_a, hidden_states_b, temb):
|
||||
if self.norm1_a.dual:
|
||||
norm_hidden_states_a, gate_msa_a, shift_mlp_a, scale_mlp_a, gate_mlp_a, norm_hidden_states_a_2, gate_msa_a_2 = self.norm1_a(hidden_states_a, emb=temb)
|
||||
else:
|
||||
norm_hidden_states_a, gate_msa_a, shift_mlp_a, scale_mlp_a, gate_mlp_a = self.norm1_a(hidden_states_a, emb=temb)
|
||||
norm_hidden_states_b, gate_msa_b, shift_mlp_b, scale_mlp_b, gate_mlp_b = self.norm1_b(hidden_states_b, emb=temb)
|
||||
|
||||
# Attention
|
||||
attn_output_a, attn_output_b = self.attn(norm_hidden_states_a, norm_hidden_states_b)
|
||||
|
||||
# Part A
|
||||
hidden_states_a = hidden_states_a + gate_msa_a * attn_output_a
|
||||
if self.norm1_a.dual:
|
||||
hidden_states_a = hidden_states_a + gate_msa_a_2 * self.attn2(norm_hidden_states_a_2)
|
||||
norm_hidden_states_a = self.norm2_a(hidden_states_a) * (1 + scale_mlp_a) + shift_mlp_a
|
||||
hidden_states_a = hidden_states_a + gate_mlp_a * self.ff_a(norm_hidden_states_a)
|
||||
|
||||
@@ -145,12 +292,12 @@ class JointTransformerBlock(torch.nn.Module):
|
||||
|
||||
|
||||
class JointTransformerFinalBlock(torch.nn.Module):
|
||||
def __init__(self, dim, num_attention_heads):
|
||||
def __init__(self, dim, num_attention_heads, use_rms_norm=False):
|
||||
super().__init__()
|
||||
self.norm1_a = AdaLayerNorm(dim)
|
||||
self.norm1_b = AdaLayerNorm(dim, single=True)
|
||||
|
||||
self.attn = JointAttention(dim, dim, num_attention_heads, dim // num_attention_heads, only_out_a=True)
|
||||
self.attn = JointAttention(dim, dim, num_attention_heads, dim // num_attention_heads, only_out_a=True, use_rms_norm=use_rms_norm)
|
||||
|
||||
self.norm2_a = torch.nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
|
||||
self.ff_a = torch.nn.Sequential(
|
||||
@@ -177,15 +324,17 @@ class JointTransformerFinalBlock(torch.nn.Module):
|
||||
|
||||
|
||||
class SD3DiT(torch.nn.Module):
|
||||
def __init__(self):
|
||||
def __init__(self, embed_dim=1536, num_layers=24, use_rms_norm=False, num_dual_blocks=0, pos_embed_max_size=192):
|
||||
super().__init__()
|
||||
self.pos_embedder = PatchEmbed(patch_size=2, in_channels=16, embed_dim=1536, pos_embed_max_size=192)
|
||||
self.time_embedder = TimestepEmbeddings(256, 1536)
|
||||
self.pooled_text_embedder = torch.nn.Sequential(torch.nn.Linear(2048, 1536), torch.nn.SiLU(), torch.nn.Linear(1536, 1536))
|
||||
self.context_embedder = torch.nn.Linear(4096, 1536)
|
||||
self.blocks = torch.nn.ModuleList([JointTransformerBlock(1536, 24) for _ in range(23)] + [JointTransformerFinalBlock(1536, 24)])
|
||||
self.norm_out = AdaLayerNorm(1536, single=True)
|
||||
self.proj_out = torch.nn.Linear(1536, 64)
|
||||
self.pos_embedder = PatchEmbed(patch_size=2, in_channels=16, embed_dim=embed_dim, pos_embed_max_size=pos_embed_max_size)
|
||||
self.time_embedder = TimestepEmbeddings(256, embed_dim)
|
||||
self.pooled_text_embedder = torch.nn.Sequential(torch.nn.Linear(2048, embed_dim), torch.nn.SiLU(), torch.nn.Linear(embed_dim, embed_dim))
|
||||
self.context_embedder = torch.nn.Linear(4096, embed_dim)
|
||||
self.blocks = torch.nn.ModuleList([JointTransformerBlock(embed_dim, embed_dim//64, use_rms_norm=use_rms_norm, dual=True) for _ in range(num_dual_blocks)]
|
||||
+ [JointTransformerBlock(embed_dim, embed_dim//64, use_rms_norm=use_rms_norm) for _ in range(num_layers-1-num_dual_blocks)]
|
||||
+ [JointTransformerFinalBlock(embed_dim, embed_dim//64, use_rms_norm=use_rms_norm)])
|
||||
self.norm_out = AdaLayerNorm(embed_dim, single=True)
|
||||
self.proj_out = torch.nn.Linear(embed_dim, 64)
|
||||
|
||||
def tiled_forward(self, hidden_states, timestep, prompt_emb, pooled_prompt_emb, tile_size=128, tile_stride=64):
|
||||
# Due to the global positional embedding, we cannot implement layer-wise tiled forward.
|
||||
@@ -238,6 +387,24 @@ class SD3DiTStateDictConverter:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def infer_architecture(self, state_dict):
|
||||
embed_dim = state_dict["blocks.0.ff_a.0.weight"].shape[1]
|
||||
num_layers = 100
|
||||
while num_layers > 0 and f"blocks.{num_layers-1}.ff_a.0.bias" not in state_dict:
|
||||
num_layers -= 1
|
||||
use_rms_norm = "blocks.0.attn.norm_q_a.weight" in state_dict
|
||||
num_dual_blocks = 0
|
||||
while f"blocks.{num_dual_blocks}.attn2.a_to_out.bias" in state_dict:
|
||||
num_dual_blocks += 1
|
||||
pos_embed_max_size = state_dict["pos_embedder.pos_embed"].shape[1]
|
||||
return {
|
||||
"embed_dim": embed_dim,
|
||||
"num_layers": num_layers,
|
||||
"use_rms_norm": use_rms_norm,
|
||||
"num_dual_blocks": num_dual_blocks,
|
||||
"pos_embed_max_size": pos_embed_max_size
|
||||
}
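    # Hypothetical usage (added for clarity; names are illustrative): the tuple returned
    # by from_diffusers is meant to be unpacked as
    #     state_dict, config = converter.from_diffusers(raw_state_dict)
    #     dit = SD3DiT(**config)
    # so the DiT is instantiated with the architecture inferred from the checkpoint.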
|
||||
|
||||
def from_diffusers(self, state_dict):
|
||||
rename_dict = {
|
||||
"context_embedder": "context_embedder",
|
||||
@@ -264,12 +431,17 @@ class SD3DiTStateDictConverter:
|
||||
"ff.net.2": "ff_a.2",
|
||||
"ff_context.net.0.proj": "ff_b.0",
|
||||
"ff_context.net.2": "ff_b.2",
|
||||
|
||||
"attn.norm_q": "attn.norm_q_a",
|
||||
"attn.norm_k": "attn.norm_k_a",
|
||||
"attn.norm_added_q": "attn.norm_q_b",
|
||||
"attn.norm_added_k": "attn.norm_k_b",
|
||||
}
|
||||
state_dict_ = {}
|
||||
for name, param in state_dict.items():
|
||||
if name in rename_dict:
|
||||
if name == "pos_embed.pos_embed":
|
||||
param = param.reshape((1, 192, 192, 1536))
|
||||
param = param.reshape((1, 192, 192, param.shape[-1]))
|
||||
state_dict_[rename_dict[name]] = param
|
||||
elif name.endswith(".weight") or name.endswith(".bias"):
|
||||
suffix = ".weight" if name.endswith(".weight") else ".bias"
|
||||
@@ -283,7 +455,19 @@ class SD3DiTStateDictConverter:
|
||||
if middle in rename_dict:
|
||||
name_ = ".".join(names[:2] + [rename_dict[middle]] + [suffix[1:]])
|
||||
state_dict_[name_] = param
|
||||
return state_dict_
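        # Fuse the separate a_to_q / a_to_k / a_to_v (and b_to_*) projections produced by
        # the renaming above into the single a_to_qkv / b_to_qkv linears used by JointAttention.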
        merged_keys = [name for name in state_dict_ if ".a_to_q." in name or ".b_to_q." in name]
        for key in merged_keys:
            param = torch.concat([
                state_dict_[key],
                state_dict_[key.replace("to_q", "to_k")],
                state_dict_[key.replace("to_q", "to_v")],
            ], dim=0)
            name = key.replace("to_q", "to_qkv")
            state_dict_.pop(key)
            state_dict_.pop(key.replace("to_q", "to_k"))
            state_dict_.pop(key.replace("to_q", "to_v"))
            state_dict_[name] = param
        return state_dict_, self.infer_architecture(state_dict_)
|
||||
|
||||
def from_civitai(self, state_dict):
|
||||
rename_dict = {
|
||||
@@ -291,478 +475,7 @@ class SD3DiTStateDictConverter:
|
||||
"model.diffusion_model.context_embedder.weight": "context_embedder.weight",
|
||||
"model.diffusion_model.final_layer.linear.bias": "proj_out.bias",
|
||||
"model.diffusion_model.final_layer.linear.weight": "proj_out.weight",
|
||||
"model.diffusion_model.joint_blocks.0.context_block.adaLN_modulation.1.bias": "blocks.0.norm1_b.linear.bias",
|
||||
"model.diffusion_model.joint_blocks.0.context_block.adaLN_modulation.1.weight": "blocks.0.norm1_b.linear.weight",
|
||||
"model.diffusion_model.joint_blocks.0.context_block.attn.proj.bias": "blocks.0.attn.b_to_out.bias",
|
||||
"model.diffusion_model.joint_blocks.0.context_block.attn.proj.weight": "blocks.0.attn.b_to_out.weight",
|
||||
"model.diffusion_model.joint_blocks.0.context_block.attn.qkv.bias": ['blocks.0.attn.b_to_q.bias', 'blocks.0.attn.b_to_k.bias', 'blocks.0.attn.b_to_v.bias'],
|
||||
"model.diffusion_model.joint_blocks.0.context_block.attn.qkv.weight": ['blocks.0.attn.b_to_q.weight', 'blocks.0.attn.b_to_k.weight', 'blocks.0.attn.b_to_v.weight'],
|
||||
"model.diffusion_model.joint_blocks.0.context_block.mlp.fc1.bias": "blocks.0.ff_b.0.bias",
|
||||
"model.diffusion_model.joint_blocks.0.context_block.mlp.fc1.weight": "blocks.0.ff_b.0.weight",
|
||||
"model.diffusion_model.joint_blocks.0.context_block.mlp.fc2.bias": "blocks.0.ff_b.2.bias",
|
||||
"model.diffusion_model.joint_blocks.0.context_block.mlp.fc2.weight": "blocks.0.ff_b.2.weight",
|
||||
"model.diffusion_model.joint_blocks.0.x_block.adaLN_modulation.1.bias": "blocks.0.norm1_a.linear.bias",
|
||||
"model.diffusion_model.joint_blocks.0.x_block.adaLN_modulation.1.weight": "blocks.0.norm1_a.linear.weight",
|
||||
"model.diffusion_model.joint_blocks.0.x_block.attn.proj.bias": "blocks.0.attn.a_to_out.bias",
|
||||
"model.diffusion_model.joint_blocks.0.x_block.attn.proj.weight": "blocks.0.attn.a_to_out.weight",
|
||||
"model.diffusion_model.joint_blocks.0.x_block.attn.qkv.bias": ['blocks.0.attn.a_to_q.bias', 'blocks.0.attn.a_to_k.bias', 'blocks.0.attn.a_to_v.bias'],
|
||||
"model.diffusion_model.joint_blocks.0.x_block.attn.qkv.weight": ['blocks.0.attn.a_to_q.weight', 'blocks.0.attn.a_to_k.weight', 'blocks.0.attn.a_to_v.weight'],
|
||||
"model.diffusion_model.joint_blocks.0.x_block.mlp.fc1.bias": "blocks.0.ff_a.0.bias",
|
||||
"model.diffusion_model.joint_blocks.0.x_block.mlp.fc1.weight": "blocks.0.ff_a.0.weight",
|
||||
"model.diffusion_model.joint_blocks.0.x_block.mlp.fc2.bias": "blocks.0.ff_a.2.bias",
|
||||
"model.diffusion_model.joint_blocks.0.x_block.mlp.fc2.weight": "blocks.0.ff_a.2.weight",
|
||||
"model.diffusion_model.joint_blocks.1.context_block.adaLN_modulation.1.bias": "blocks.1.norm1_b.linear.bias",
|
||||
"model.diffusion_model.joint_blocks.1.context_block.adaLN_modulation.1.weight": "blocks.1.norm1_b.linear.weight",
|
||||
"model.diffusion_model.joint_blocks.1.context_block.attn.proj.bias": "blocks.1.attn.b_to_out.bias",
|
||||
"model.diffusion_model.joint_blocks.1.context_block.attn.proj.weight": "blocks.1.attn.b_to_out.weight",
|
||||
"model.diffusion_model.joint_blocks.1.context_block.attn.qkv.bias": ['blocks.1.attn.b_to_q.bias', 'blocks.1.attn.b_to_k.bias', 'blocks.1.attn.b_to_v.bias'],
|
||||
"model.diffusion_model.joint_blocks.1.context_block.attn.qkv.weight": ['blocks.1.attn.b_to_q.weight', 'blocks.1.attn.b_to_k.weight', 'blocks.1.attn.b_to_v.weight'],
|
||||
"model.diffusion_model.joint_blocks.1.context_block.mlp.fc1.bias": "blocks.1.ff_b.0.bias",
|
||||
"model.diffusion_model.joint_blocks.1.context_block.mlp.fc1.weight": "blocks.1.ff_b.0.weight",
|
||||
"model.diffusion_model.joint_blocks.1.context_block.mlp.fc2.bias": "blocks.1.ff_b.2.bias",
|
||||
"model.diffusion_model.joint_blocks.1.context_block.mlp.fc2.weight": "blocks.1.ff_b.2.weight",
|
||||
"model.diffusion_model.joint_blocks.1.x_block.adaLN_modulation.1.bias": "blocks.1.norm1_a.linear.bias",
|
||||
"model.diffusion_model.joint_blocks.1.x_block.adaLN_modulation.1.weight": "blocks.1.norm1_a.linear.weight",
|
||||
"model.diffusion_model.joint_blocks.1.x_block.attn.proj.bias": "blocks.1.attn.a_to_out.bias",
|
||||
"model.diffusion_model.joint_blocks.1.x_block.attn.proj.weight": "blocks.1.attn.a_to_out.weight",
|
||||
"model.diffusion_model.joint_blocks.1.x_block.attn.qkv.bias": ['blocks.1.attn.a_to_q.bias', 'blocks.1.attn.a_to_k.bias', 'blocks.1.attn.a_to_v.bias'],
|
||||
"model.diffusion_model.joint_blocks.1.x_block.attn.qkv.weight": ['blocks.1.attn.a_to_q.weight', 'blocks.1.attn.a_to_k.weight', 'blocks.1.attn.a_to_v.weight'],
|
||||
"model.diffusion_model.joint_blocks.1.x_block.mlp.fc1.bias": "blocks.1.ff_a.0.bias",
|
||||
"model.diffusion_model.joint_blocks.1.x_block.mlp.fc1.weight": "blocks.1.ff_a.0.weight",
|
||||
"model.diffusion_model.joint_blocks.1.x_block.mlp.fc2.bias": "blocks.1.ff_a.2.bias",
|
||||
"model.diffusion_model.joint_blocks.1.x_block.mlp.fc2.weight": "blocks.1.ff_a.2.weight",
|
||||
"model.diffusion_model.joint_blocks.10.context_block.adaLN_modulation.1.bias": "blocks.10.norm1_b.linear.bias",
|
||||
"model.diffusion_model.joint_blocks.10.context_block.adaLN_modulation.1.weight": "blocks.10.norm1_b.linear.weight",
|
||||
"model.diffusion_model.joint_blocks.10.context_block.attn.proj.bias": "blocks.10.attn.b_to_out.bias",
|
||||
"model.diffusion_model.joint_blocks.10.context_block.attn.proj.weight": "blocks.10.attn.b_to_out.weight",
|
||||
"model.diffusion_model.joint_blocks.10.context_block.attn.qkv.bias": ['blocks.10.attn.b_to_q.bias', 'blocks.10.attn.b_to_k.bias', 'blocks.10.attn.b_to_v.bias'],
|
||||
"model.diffusion_model.joint_blocks.10.context_block.attn.qkv.weight": ['blocks.10.attn.b_to_q.weight', 'blocks.10.attn.b_to_k.weight', 'blocks.10.attn.b_to_v.weight'],
|
||||
"model.diffusion_model.joint_blocks.10.context_block.mlp.fc1.bias": "blocks.10.ff_b.0.bias",
|
||||
"model.diffusion_model.joint_blocks.10.context_block.mlp.fc1.weight": "blocks.10.ff_b.0.weight",
|
||||
"model.diffusion_model.joint_blocks.10.context_block.mlp.fc2.bias": "blocks.10.ff_b.2.bias",
|
||||
"model.diffusion_model.joint_blocks.10.context_block.mlp.fc2.weight": "blocks.10.ff_b.2.weight",
|
||||
"model.diffusion_model.joint_blocks.10.x_block.adaLN_modulation.1.bias": "blocks.10.norm1_a.linear.bias",
|
||||
"model.diffusion_model.joint_blocks.10.x_block.adaLN_modulation.1.weight": "blocks.10.norm1_a.linear.weight",
|
||||
"model.diffusion_model.joint_blocks.10.x_block.attn.proj.bias": "blocks.10.attn.a_to_out.bias",
|
||||
"model.diffusion_model.joint_blocks.10.x_block.attn.proj.weight": "blocks.10.attn.a_to_out.weight",
|
||||
"model.diffusion_model.joint_blocks.10.x_block.attn.qkv.bias": ['blocks.10.attn.a_to_q.bias', 'blocks.10.attn.a_to_k.bias', 'blocks.10.attn.a_to_v.bias'],
|
||||
"model.diffusion_model.joint_blocks.10.x_block.attn.qkv.weight": ['blocks.10.attn.a_to_q.weight', 'blocks.10.attn.a_to_k.weight', 'blocks.10.attn.a_to_v.weight'],
|
||||
"model.diffusion_model.joint_blocks.10.x_block.mlp.fc1.bias": "blocks.10.ff_a.0.bias",
|
||||
"model.diffusion_model.joint_blocks.10.x_block.mlp.fc1.weight": "blocks.10.ff_a.0.weight",
|
||||
"model.diffusion_model.joint_blocks.10.x_block.mlp.fc2.bias": "blocks.10.ff_a.2.bias",
|
||||
"model.diffusion_model.joint_blocks.10.x_block.mlp.fc2.weight": "blocks.10.ff_a.2.weight",
|
||||
"model.diffusion_model.joint_blocks.11.context_block.adaLN_modulation.1.bias": "blocks.11.norm1_b.linear.bias",
|
||||
"model.diffusion_model.joint_blocks.11.context_block.adaLN_modulation.1.weight": "blocks.11.norm1_b.linear.weight",
|
||||
"model.diffusion_model.joint_blocks.11.context_block.attn.proj.bias": "blocks.11.attn.b_to_out.bias",
|
||||
"model.diffusion_model.joint_blocks.11.context_block.attn.proj.weight": "blocks.11.attn.b_to_out.weight",
|
||||
"model.diffusion_model.joint_blocks.11.context_block.attn.qkv.bias": ['blocks.11.attn.b_to_q.bias', 'blocks.11.attn.b_to_k.bias', 'blocks.11.attn.b_to_v.bias'],
|
||||
"model.diffusion_model.joint_blocks.11.context_block.attn.qkv.weight": ['blocks.11.attn.b_to_q.weight', 'blocks.11.attn.b_to_k.weight', 'blocks.11.attn.b_to_v.weight'],
|
||||
"model.diffusion_model.joint_blocks.11.context_block.mlp.fc1.bias": "blocks.11.ff_b.0.bias",
|
||||
"model.diffusion_model.joint_blocks.11.context_block.mlp.fc1.weight": "blocks.11.ff_b.0.weight",
|
||||
"model.diffusion_model.joint_blocks.11.context_block.mlp.fc2.bias": "blocks.11.ff_b.2.bias",
|
||||
"model.diffusion_model.joint_blocks.11.context_block.mlp.fc2.weight": "blocks.11.ff_b.2.weight",
|
||||
"model.diffusion_model.joint_blocks.11.x_block.adaLN_modulation.1.bias": "blocks.11.norm1_a.linear.bias",
|
||||
"model.diffusion_model.joint_blocks.11.x_block.adaLN_modulation.1.weight": "blocks.11.norm1_a.linear.weight",
|
||||
"model.diffusion_model.joint_blocks.11.x_block.attn.proj.bias": "blocks.11.attn.a_to_out.bias",
|
||||
"model.diffusion_model.joint_blocks.11.x_block.attn.proj.weight": "blocks.11.attn.a_to_out.weight",
|
||||
"model.diffusion_model.joint_blocks.11.x_block.attn.qkv.bias": ['blocks.11.attn.a_to_q.bias', 'blocks.11.attn.a_to_k.bias', 'blocks.11.attn.a_to_v.bias'],
|
||||
"model.diffusion_model.joint_blocks.11.x_block.attn.qkv.weight": ['blocks.11.attn.a_to_q.weight', 'blocks.11.attn.a_to_k.weight', 'blocks.11.attn.a_to_v.weight'],
|
||||
"model.diffusion_model.joint_blocks.11.x_block.mlp.fc1.bias": "blocks.11.ff_a.0.bias",
|
||||
"model.diffusion_model.joint_blocks.11.x_block.mlp.fc1.weight": "blocks.11.ff_a.0.weight",
|
||||
"model.diffusion_model.joint_blocks.11.x_block.mlp.fc2.bias": "blocks.11.ff_a.2.bias",
|
||||
"model.diffusion_model.joint_blocks.11.x_block.mlp.fc2.weight": "blocks.11.ff_a.2.weight",
|
||||
"model.diffusion_model.joint_blocks.12.context_block.adaLN_modulation.1.bias": "blocks.12.norm1_b.linear.bias",
|
||||
"model.diffusion_model.joint_blocks.12.context_block.adaLN_modulation.1.weight": "blocks.12.norm1_b.linear.weight",
|
||||
"model.diffusion_model.joint_blocks.12.context_block.attn.proj.bias": "blocks.12.attn.b_to_out.bias",
|
||||
"model.diffusion_model.joint_blocks.12.context_block.attn.proj.weight": "blocks.12.attn.b_to_out.weight",
|
||||
"model.diffusion_model.joint_blocks.12.context_block.attn.qkv.bias": ['blocks.12.attn.b_to_q.bias', 'blocks.12.attn.b_to_k.bias', 'blocks.12.attn.b_to_v.bias'],
|
||||
"model.diffusion_model.joint_blocks.12.context_block.attn.qkv.weight": ['blocks.12.attn.b_to_q.weight', 'blocks.12.attn.b_to_k.weight', 'blocks.12.attn.b_to_v.weight'],
|
||||
"model.diffusion_model.joint_blocks.12.context_block.mlp.fc1.bias": "blocks.12.ff_b.0.bias",
|
||||
"model.diffusion_model.joint_blocks.12.context_block.mlp.fc1.weight": "blocks.12.ff_b.0.weight",
|
||||
"model.diffusion_model.joint_blocks.12.context_block.mlp.fc2.bias": "blocks.12.ff_b.2.bias",
|
||||
"model.diffusion_model.joint_blocks.12.context_block.mlp.fc2.weight": "blocks.12.ff_b.2.weight",
|
||||
"model.diffusion_model.joint_blocks.12.x_block.adaLN_modulation.1.bias": "blocks.12.norm1_a.linear.bias",
|
||||
"model.diffusion_model.joint_blocks.12.x_block.adaLN_modulation.1.weight": "blocks.12.norm1_a.linear.weight",
|
||||
"model.diffusion_model.joint_blocks.12.x_block.attn.proj.bias": "blocks.12.attn.a_to_out.bias",
|
||||
"model.diffusion_model.joint_blocks.12.x_block.attn.proj.weight": "blocks.12.attn.a_to_out.weight",
|
||||
"model.diffusion_model.joint_blocks.12.x_block.attn.qkv.bias": ['blocks.12.attn.a_to_q.bias', 'blocks.12.attn.a_to_k.bias', 'blocks.12.attn.a_to_v.bias'],
|
||||
"model.diffusion_model.joint_blocks.12.x_block.attn.qkv.weight": ['blocks.12.attn.a_to_q.weight', 'blocks.12.attn.a_to_k.weight', 'blocks.12.attn.a_to_v.weight'],
|
||||
"model.diffusion_model.joint_blocks.12.x_block.mlp.fc1.bias": "blocks.12.ff_a.0.bias",
|
||||
"model.diffusion_model.joint_blocks.12.x_block.mlp.fc1.weight": "blocks.12.ff_a.0.weight",
|
||||
"model.diffusion_model.joint_blocks.12.x_block.mlp.fc2.bias": "blocks.12.ff_a.2.bias",
|
||||
"model.diffusion_model.joint_blocks.12.x_block.mlp.fc2.weight": "blocks.12.ff_a.2.weight",
|
||||
"model.diffusion_model.joint_blocks.13.context_block.adaLN_modulation.1.bias": "blocks.13.norm1_b.linear.bias",
|
||||
"model.diffusion_model.joint_blocks.13.context_block.adaLN_modulation.1.weight": "blocks.13.norm1_b.linear.weight",
|
||||
"model.diffusion_model.joint_blocks.13.context_block.attn.proj.bias": "blocks.13.attn.b_to_out.bias",
|
||||
"model.diffusion_model.joint_blocks.13.context_block.attn.proj.weight": "blocks.13.attn.b_to_out.weight",
|
||||
"model.diffusion_model.joint_blocks.13.context_block.attn.qkv.bias": ['blocks.13.attn.b_to_q.bias', 'blocks.13.attn.b_to_k.bias', 'blocks.13.attn.b_to_v.bias'],
|
||||
"model.diffusion_model.joint_blocks.13.context_block.attn.qkv.weight": ['blocks.13.attn.b_to_q.weight', 'blocks.13.attn.b_to_k.weight', 'blocks.13.attn.b_to_v.weight'],
|
||||
"model.diffusion_model.joint_blocks.13.context_block.mlp.fc1.bias": "blocks.13.ff_b.0.bias",
|
||||
"model.diffusion_model.joint_blocks.13.context_block.mlp.fc1.weight": "blocks.13.ff_b.0.weight",
|
||||
"model.diffusion_model.joint_blocks.13.context_block.mlp.fc2.bias": "blocks.13.ff_b.2.bias",
|
||||
"model.diffusion_model.joint_blocks.13.context_block.mlp.fc2.weight": "blocks.13.ff_b.2.weight",
|
||||
"model.diffusion_model.joint_blocks.13.x_block.adaLN_modulation.1.bias": "blocks.13.norm1_a.linear.bias",
|
||||
"model.diffusion_model.joint_blocks.13.x_block.adaLN_modulation.1.weight": "blocks.13.norm1_a.linear.weight",
|
||||
"model.diffusion_model.joint_blocks.13.x_block.attn.proj.bias": "blocks.13.attn.a_to_out.bias",
|
||||
"model.diffusion_model.joint_blocks.13.x_block.attn.proj.weight": "blocks.13.attn.a_to_out.weight",
|
||||
"model.diffusion_model.joint_blocks.13.x_block.attn.qkv.bias": ['blocks.13.attn.a_to_q.bias', 'blocks.13.attn.a_to_k.bias', 'blocks.13.attn.a_to_v.bias'],
|
||||
"model.diffusion_model.joint_blocks.13.x_block.attn.qkv.weight": ['blocks.13.attn.a_to_q.weight', 'blocks.13.attn.a_to_k.weight', 'blocks.13.attn.a_to_v.weight'],
|
||||
"model.diffusion_model.joint_blocks.13.x_block.mlp.fc1.bias": "blocks.13.ff_a.0.bias",
|
||||
"model.diffusion_model.joint_blocks.13.x_block.mlp.fc1.weight": "blocks.13.ff_a.0.weight",
|
||||
"model.diffusion_model.joint_blocks.13.x_block.mlp.fc2.bias": "blocks.13.ff_a.2.bias",
|
||||
"model.diffusion_model.joint_blocks.13.x_block.mlp.fc2.weight": "blocks.13.ff_a.2.weight",
|
||||
"model.diffusion_model.joint_blocks.14.context_block.adaLN_modulation.1.bias": "blocks.14.norm1_b.linear.bias",
|
||||
"model.diffusion_model.joint_blocks.14.context_block.adaLN_modulation.1.weight": "blocks.14.norm1_b.linear.weight",
|
||||
"model.diffusion_model.joint_blocks.14.context_block.attn.proj.bias": "blocks.14.attn.b_to_out.bias",
|
||||
"model.diffusion_model.joint_blocks.14.context_block.attn.proj.weight": "blocks.14.attn.b_to_out.weight",
|
||||
"model.diffusion_model.joint_blocks.14.context_block.attn.qkv.bias": ['blocks.14.attn.b_to_q.bias', 'blocks.14.attn.b_to_k.bias', 'blocks.14.attn.b_to_v.bias'],
|
||||
"model.diffusion_model.joint_blocks.14.context_block.attn.qkv.weight": ['blocks.14.attn.b_to_q.weight', 'blocks.14.attn.b_to_k.weight', 'blocks.14.attn.b_to_v.weight'],
|
||||
"model.diffusion_model.joint_blocks.14.context_block.mlp.fc1.bias": "blocks.14.ff_b.0.bias",
|
||||
"model.diffusion_model.joint_blocks.14.context_block.mlp.fc1.weight": "blocks.14.ff_b.0.weight",
|
||||
"model.diffusion_model.joint_blocks.14.context_block.mlp.fc2.bias": "blocks.14.ff_b.2.bias",
|
||||
"model.diffusion_model.joint_blocks.14.context_block.mlp.fc2.weight": "blocks.14.ff_b.2.weight",
|
||||
"model.diffusion_model.joint_blocks.14.x_block.adaLN_modulation.1.bias": "blocks.14.norm1_a.linear.bias",
|
||||
"model.diffusion_model.joint_blocks.14.x_block.adaLN_modulation.1.weight": "blocks.14.norm1_a.linear.weight",
|
||||
"model.diffusion_model.joint_blocks.14.x_block.attn.proj.bias": "blocks.14.attn.a_to_out.bias",
|
||||
"model.diffusion_model.joint_blocks.14.x_block.attn.proj.weight": "blocks.14.attn.a_to_out.weight",
|
||||
"model.diffusion_model.joint_blocks.14.x_block.attn.qkv.bias": ['blocks.14.attn.a_to_q.bias', 'blocks.14.attn.a_to_k.bias', 'blocks.14.attn.a_to_v.bias'],
|
||||
"model.diffusion_model.joint_blocks.14.x_block.attn.qkv.weight": ['blocks.14.attn.a_to_q.weight', 'blocks.14.attn.a_to_k.weight', 'blocks.14.attn.a_to_v.weight'],
|
||||
"model.diffusion_model.joint_blocks.14.x_block.mlp.fc1.bias": "blocks.14.ff_a.0.bias",
|
||||
"model.diffusion_model.joint_blocks.14.x_block.mlp.fc1.weight": "blocks.14.ff_a.0.weight",
|
||||
"model.diffusion_model.joint_blocks.14.x_block.mlp.fc2.bias": "blocks.14.ff_a.2.bias",
|
||||
"model.diffusion_model.joint_blocks.14.x_block.mlp.fc2.weight": "blocks.14.ff_a.2.weight",
|
||||
"model.diffusion_model.joint_blocks.15.context_block.adaLN_modulation.1.bias": "blocks.15.norm1_b.linear.bias",
|
||||
"model.diffusion_model.joint_blocks.15.context_block.adaLN_modulation.1.weight": "blocks.15.norm1_b.linear.weight",
|
||||
"model.diffusion_model.joint_blocks.15.context_block.attn.proj.bias": "blocks.15.attn.b_to_out.bias",
|
||||
"model.diffusion_model.joint_blocks.15.context_block.attn.proj.weight": "blocks.15.attn.b_to_out.weight",
|
||||
"model.diffusion_model.joint_blocks.15.context_block.attn.qkv.bias": ['blocks.15.attn.b_to_q.bias', 'blocks.15.attn.b_to_k.bias', 'blocks.15.attn.b_to_v.bias'],
|
||||
"model.diffusion_model.joint_blocks.15.context_block.attn.qkv.weight": ['blocks.15.attn.b_to_q.weight', 'blocks.15.attn.b_to_k.weight', 'blocks.15.attn.b_to_v.weight'],
|
||||
"model.diffusion_model.joint_blocks.15.context_block.mlp.fc1.bias": "blocks.15.ff_b.0.bias",
|
||||
"model.diffusion_model.joint_blocks.15.context_block.mlp.fc1.weight": "blocks.15.ff_b.0.weight",
|
||||
"model.diffusion_model.joint_blocks.15.context_block.mlp.fc2.bias": "blocks.15.ff_b.2.bias",
|
||||
"model.diffusion_model.joint_blocks.15.context_block.mlp.fc2.weight": "blocks.15.ff_b.2.weight",
|
||||
"model.diffusion_model.joint_blocks.15.x_block.adaLN_modulation.1.bias": "blocks.15.norm1_a.linear.bias",
|
||||
"model.diffusion_model.joint_blocks.15.x_block.adaLN_modulation.1.weight": "blocks.15.norm1_a.linear.weight",
|
||||
"model.diffusion_model.joint_blocks.15.x_block.attn.proj.bias": "blocks.15.attn.a_to_out.bias",
|
||||
"model.diffusion_model.joint_blocks.15.x_block.attn.proj.weight": "blocks.15.attn.a_to_out.weight",
|
||||
"model.diffusion_model.joint_blocks.15.x_block.attn.qkv.bias": ['blocks.15.attn.a_to_q.bias', 'blocks.15.attn.a_to_k.bias', 'blocks.15.attn.a_to_v.bias'],
|
||||
"model.diffusion_model.joint_blocks.15.x_block.attn.qkv.weight": ['blocks.15.attn.a_to_q.weight', 'blocks.15.attn.a_to_k.weight', 'blocks.15.attn.a_to_v.weight'],
|
||||
"model.diffusion_model.joint_blocks.15.x_block.mlp.fc1.bias": "blocks.15.ff_a.0.bias",
|
||||
"model.diffusion_model.joint_blocks.15.x_block.mlp.fc1.weight": "blocks.15.ff_a.0.weight",
|
||||
"model.diffusion_model.joint_blocks.15.x_block.mlp.fc2.bias": "blocks.15.ff_a.2.bias",
|
||||
"model.diffusion_model.joint_blocks.15.x_block.mlp.fc2.weight": "blocks.15.ff_a.2.weight",
|
||||
"model.diffusion_model.joint_blocks.16.context_block.adaLN_modulation.1.bias": "blocks.16.norm1_b.linear.bias",
|
||||
"model.diffusion_model.joint_blocks.16.context_block.adaLN_modulation.1.weight": "blocks.16.norm1_b.linear.weight",
|
||||
"model.diffusion_model.joint_blocks.16.context_block.attn.proj.bias": "blocks.16.attn.b_to_out.bias",
|
||||
"model.diffusion_model.joint_blocks.16.context_block.attn.proj.weight": "blocks.16.attn.b_to_out.weight",
|
||||
"model.diffusion_model.joint_blocks.16.context_block.attn.qkv.bias": ['blocks.16.attn.b_to_q.bias', 'blocks.16.attn.b_to_k.bias', 'blocks.16.attn.b_to_v.bias'],
|
||||
"model.diffusion_model.joint_blocks.16.context_block.attn.qkv.weight": ['blocks.16.attn.b_to_q.weight', 'blocks.16.attn.b_to_k.weight', 'blocks.16.attn.b_to_v.weight'],
|
||||
"model.diffusion_model.joint_blocks.16.context_block.mlp.fc1.bias": "blocks.16.ff_b.0.bias",
|
||||
"model.diffusion_model.joint_blocks.16.context_block.mlp.fc1.weight": "blocks.16.ff_b.0.weight",
|
||||
"model.diffusion_model.joint_blocks.16.context_block.mlp.fc2.bias": "blocks.16.ff_b.2.bias",
|
||||
"model.diffusion_model.joint_blocks.16.context_block.mlp.fc2.weight": "blocks.16.ff_b.2.weight",
|
||||
"model.diffusion_model.joint_blocks.16.x_block.adaLN_modulation.1.bias": "blocks.16.norm1_a.linear.bias",
|
||||
"model.diffusion_model.joint_blocks.16.x_block.adaLN_modulation.1.weight": "blocks.16.norm1_a.linear.weight",
|
||||
"model.diffusion_model.joint_blocks.16.x_block.attn.proj.bias": "blocks.16.attn.a_to_out.bias",
|
||||
"model.diffusion_model.joint_blocks.16.x_block.attn.proj.weight": "blocks.16.attn.a_to_out.weight",
|
||||
"model.diffusion_model.joint_blocks.16.x_block.attn.qkv.bias": ['blocks.16.attn.a_to_q.bias', 'blocks.16.attn.a_to_k.bias', 'blocks.16.attn.a_to_v.bias'],
|
||||
"model.diffusion_model.joint_blocks.16.x_block.attn.qkv.weight": ['blocks.16.attn.a_to_q.weight', 'blocks.16.attn.a_to_k.weight', 'blocks.16.attn.a_to_v.weight'],
|
||||
"model.diffusion_model.joint_blocks.16.x_block.mlp.fc1.bias": "blocks.16.ff_a.0.bias",
|
||||
"model.diffusion_model.joint_blocks.16.x_block.mlp.fc1.weight": "blocks.16.ff_a.0.weight",
|
||||
"model.diffusion_model.joint_blocks.16.x_block.mlp.fc2.bias": "blocks.16.ff_a.2.bias",
|
||||
"model.diffusion_model.joint_blocks.16.x_block.mlp.fc2.weight": "blocks.16.ff_a.2.weight",
|
||||
"model.diffusion_model.joint_blocks.17.context_block.adaLN_modulation.1.bias": "blocks.17.norm1_b.linear.bias",
|
||||
"model.diffusion_model.joint_blocks.17.context_block.adaLN_modulation.1.weight": "blocks.17.norm1_b.linear.weight",
|
||||
"model.diffusion_model.joint_blocks.17.context_block.attn.proj.bias": "blocks.17.attn.b_to_out.bias",
|
||||
"model.diffusion_model.joint_blocks.17.context_block.attn.proj.weight": "blocks.17.attn.b_to_out.weight",
|
||||
"model.diffusion_model.joint_blocks.17.context_block.attn.qkv.bias": ['blocks.17.attn.b_to_q.bias', 'blocks.17.attn.b_to_k.bias', 'blocks.17.attn.b_to_v.bias'],
|
||||
"model.diffusion_model.joint_blocks.17.context_block.attn.qkv.weight": ['blocks.17.attn.b_to_q.weight', 'blocks.17.attn.b_to_k.weight', 'blocks.17.attn.b_to_v.weight'],
|
||||
"model.diffusion_model.joint_blocks.17.context_block.mlp.fc1.bias": "blocks.17.ff_b.0.bias",
|
||||
"model.diffusion_model.joint_blocks.17.context_block.mlp.fc1.weight": "blocks.17.ff_b.0.weight",
|
||||
"model.diffusion_model.joint_blocks.17.context_block.mlp.fc2.bias": "blocks.17.ff_b.2.bias",
|
||||
"model.diffusion_model.joint_blocks.17.context_block.mlp.fc2.weight": "blocks.17.ff_b.2.weight",
|
||||
"model.diffusion_model.joint_blocks.17.x_block.adaLN_modulation.1.bias": "blocks.17.norm1_a.linear.bias",
|
||||
"model.diffusion_model.joint_blocks.17.x_block.adaLN_modulation.1.weight": "blocks.17.norm1_a.linear.weight",
|
||||
"model.diffusion_model.joint_blocks.17.x_block.attn.proj.bias": "blocks.17.attn.a_to_out.bias",
|
||||
"model.diffusion_model.joint_blocks.17.x_block.attn.proj.weight": "blocks.17.attn.a_to_out.weight",
|
||||
"model.diffusion_model.joint_blocks.17.x_block.attn.qkv.bias": ['blocks.17.attn.a_to_q.bias', 'blocks.17.attn.a_to_k.bias', 'blocks.17.attn.a_to_v.bias'],
|
||||
"model.diffusion_model.joint_blocks.17.x_block.attn.qkv.weight": ['blocks.17.attn.a_to_q.weight', 'blocks.17.attn.a_to_k.weight', 'blocks.17.attn.a_to_v.weight'],
|
||||
"model.diffusion_model.joint_blocks.17.x_block.mlp.fc1.bias": "blocks.17.ff_a.0.bias",
|
||||
"model.diffusion_model.joint_blocks.17.x_block.mlp.fc1.weight": "blocks.17.ff_a.0.weight",
|
||||
"model.diffusion_model.joint_blocks.17.x_block.mlp.fc2.bias": "blocks.17.ff_a.2.bias",
|
||||
"model.diffusion_model.joint_blocks.17.x_block.mlp.fc2.weight": "blocks.17.ff_a.2.weight",
|
||||
"model.diffusion_model.joint_blocks.18.context_block.adaLN_modulation.1.bias": "blocks.18.norm1_b.linear.bias",
|
||||
"model.diffusion_model.joint_blocks.18.context_block.adaLN_modulation.1.weight": "blocks.18.norm1_b.linear.weight",
|
||||
"model.diffusion_model.joint_blocks.18.context_block.attn.proj.bias": "blocks.18.attn.b_to_out.bias",
|
||||
"model.diffusion_model.joint_blocks.18.context_block.attn.proj.weight": "blocks.18.attn.b_to_out.weight",
|
||||
"model.diffusion_model.joint_blocks.18.context_block.attn.qkv.bias": ['blocks.18.attn.b_to_q.bias', 'blocks.18.attn.b_to_k.bias', 'blocks.18.attn.b_to_v.bias'],
|
||||
"model.diffusion_model.joint_blocks.18.context_block.attn.qkv.weight": ['blocks.18.attn.b_to_q.weight', 'blocks.18.attn.b_to_k.weight', 'blocks.18.attn.b_to_v.weight'],
|
||||
"model.diffusion_model.joint_blocks.18.context_block.mlp.fc1.bias": "blocks.18.ff_b.0.bias",
|
||||
"model.diffusion_model.joint_blocks.18.context_block.mlp.fc1.weight": "blocks.18.ff_b.0.weight",
|
||||
"model.diffusion_model.joint_blocks.18.context_block.mlp.fc2.bias": "blocks.18.ff_b.2.bias",
|
||||
"model.diffusion_model.joint_blocks.18.context_block.mlp.fc2.weight": "blocks.18.ff_b.2.weight",
|
||||
"model.diffusion_model.joint_blocks.18.x_block.adaLN_modulation.1.bias": "blocks.18.norm1_a.linear.bias",
|
||||
"model.diffusion_model.joint_blocks.18.x_block.adaLN_modulation.1.weight": "blocks.18.norm1_a.linear.weight",
|
||||
"model.diffusion_model.joint_blocks.18.x_block.attn.proj.bias": "blocks.18.attn.a_to_out.bias",
|
||||
"model.diffusion_model.joint_blocks.18.x_block.attn.proj.weight": "blocks.18.attn.a_to_out.weight",
|
||||
"model.diffusion_model.joint_blocks.18.x_block.attn.qkv.bias": ['blocks.18.attn.a_to_q.bias', 'blocks.18.attn.a_to_k.bias', 'blocks.18.attn.a_to_v.bias'],
|
||||
"model.diffusion_model.joint_blocks.18.x_block.attn.qkv.weight": ['blocks.18.attn.a_to_q.weight', 'blocks.18.attn.a_to_k.weight', 'blocks.18.attn.a_to_v.weight'],
|
||||
"model.diffusion_model.joint_blocks.18.x_block.mlp.fc1.bias": "blocks.18.ff_a.0.bias",
|
||||
"model.diffusion_model.joint_blocks.18.x_block.mlp.fc1.weight": "blocks.18.ff_a.0.weight",
|
||||
"model.diffusion_model.joint_blocks.18.x_block.mlp.fc2.bias": "blocks.18.ff_a.2.bias",
|
||||
"model.diffusion_model.joint_blocks.18.x_block.mlp.fc2.weight": "blocks.18.ff_a.2.weight",
|
||||
"model.diffusion_model.joint_blocks.19.context_block.adaLN_modulation.1.bias": "blocks.19.norm1_b.linear.bias",
|
||||
"model.diffusion_model.joint_blocks.19.context_block.adaLN_modulation.1.weight": "blocks.19.norm1_b.linear.weight",
|
||||
"model.diffusion_model.joint_blocks.19.context_block.attn.proj.bias": "blocks.19.attn.b_to_out.bias",
|
||||
"model.diffusion_model.joint_blocks.19.context_block.attn.proj.weight": "blocks.19.attn.b_to_out.weight",
|
||||
"model.diffusion_model.joint_blocks.19.context_block.attn.qkv.bias": ['blocks.19.attn.b_to_q.bias', 'blocks.19.attn.b_to_k.bias', 'blocks.19.attn.b_to_v.bias'],
|
||||
"model.diffusion_model.joint_blocks.19.context_block.attn.qkv.weight": ['blocks.19.attn.b_to_q.weight', 'blocks.19.attn.b_to_k.weight', 'blocks.19.attn.b_to_v.weight'],
|
||||
"model.diffusion_model.joint_blocks.19.context_block.mlp.fc1.bias": "blocks.19.ff_b.0.bias",
|
||||
"model.diffusion_model.joint_blocks.19.context_block.mlp.fc1.weight": "blocks.19.ff_b.0.weight",
|
||||
"model.diffusion_model.joint_blocks.19.context_block.mlp.fc2.bias": "blocks.19.ff_b.2.bias",
|
||||
"model.diffusion_model.joint_blocks.19.context_block.mlp.fc2.weight": "blocks.19.ff_b.2.weight",
|
||||
"model.diffusion_model.joint_blocks.19.x_block.adaLN_modulation.1.bias": "blocks.19.norm1_a.linear.bias",
|
||||
"model.diffusion_model.joint_blocks.19.x_block.adaLN_modulation.1.weight": "blocks.19.norm1_a.linear.weight",
|
||||
"model.diffusion_model.joint_blocks.19.x_block.attn.proj.bias": "blocks.19.attn.a_to_out.bias",
|
||||
"model.diffusion_model.joint_blocks.19.x_block.attn.proj.weight": "blocks.19.attn.a_to_out.weight",
|
||||
"model.diffusion_model.joint_blocks.19.x_block.attn.qkv.bias": ['blocks.19.attn.a_to_q.bias', 'blocks.19.attn.a_to_k.bias', 'blocks.19.attn.a_to_v.bias'],
|
||||
"model.diffusion_model.joint_blocks.19.x_block.attn.qkv.weight": ['blocks.19.attn.a_to_q.weight', 'blocks.19.attn.a_to_k.weight', 'blocks.19.attn.a_to_v.weight'],
|
||||
"model.diffusion_model.joint_blocks.19.x_block.mlp.fc1.bias": "blocks.19.ff_a.0.bias",
|
||||
"model.diffusion_model.joint_blocks.19.x_block.mlp.fc1.weight": "blocks.19.ff_a.0.weight",
|
||||
"model.diffusion_model.joint_blocks.19.x_block.mlp.fc2.bias": "blocks.19.ff_a.2.bias",
|
||||
"model.diffusion_model.joint_blocks.19.x_block.mlp.fc2.weight": "blocks.19.ff_a.2.weight",
|
||||
"model.diffusion_model.joint_blocks.2.context_block.adaLN_modulation.1.bias": "blocks.2.norm1_b.linear.bias",
|
||||
"model.diffusion_model.joint_blocks.2.context_block.adaLN_modulation.1.weight": "blocks.2.norm1_b.linear.weight",
|
||||
"model.diffusion_model.joint_blocks.2.context_block.attn.proj.bias": "blocks.2.attn.b_to_out.bias",
|
||||
"model.diffusion_model.joint_blocks.2.context_block.attn.proj.weight": "blocks.2.attn.b_to_out.weight",
|
||||
"model.diffusion_model.joint_blocks.2.context_block.attn.qkv.bias": ['blocks.2.attn.b_to_q.bias', 'blocks.2.attn.b_to_k.bias', 'blocks.2.attn.b_to_v.bias'],
|
||||
"model.diffusion_model.joint_blocks.2.context_block.attn.qkv.weight": ['blocks.2.attn.b_to_q.weight', 'blocks.2.attn.b_to_k.weight', 'blocks.2.attn.b_to_v.weight'],
|
||||
"model.diffusion_model.joint_blocks.2.context_block.mlp.fc1.bias": "blocks.2.ff_b.0.bias",
|
||||
"model.diffusion_model.joint_blocks.2.context_block.mlp.fc1.weight": "blocks.2.ff_b.0.weight",
|
||||
"model.diffusion_model.joint_blocks.2.context_block.mlp.fc2.bias": "blocks.2.ff_b.2.bias",
|
||||
"model.diffusion_model.joint_blocks.2.context_block.mlp.fc2.weight": "blocks.2.ff_b.2.weight",
|
||||
"model.diffusion_model.joint_blocks.2.x_block.adaLN_modulation.1.bias": "blocks.2.norm1_a.linear.bias",
|
||||
"model.diffusion_model.joint_blocks.2.x_block.adaLN_modulation.1.weight": "blocks.2.norm1_a.linear.weight",
|
||||
"model.diffusion_model.joint_blocks.2.x_block.attn.proj.bias": "blocks.2.attn.a_to_out.bias",
|
||||
"model.diffusion_model.joint_blocks.2.x_block.attn.proj.weight": "blocks.2.attn.a_to_out.weight",
|
||||
"model.diffusion_model.joint_blocks.2.x_block.attn.qkv.bias": ['blocks.2.attn.a_to_q.bias', 'blocks.2.attn.a_to_k.bias', 'blocks.2.attn.a_to_v.bias'],
|
||||
"model.diffusion_model.joint_blocks.2.x_block.attn.qkv.weight": ['blocks.2.attn.a_to_q.weight', 'blocks.2.attn.a_to_k.weight', 'blocks.2.attn.a_to_v.weight'],
|
||||
"model.diffusion_model.joint_blocks.2.x_block.mlp.fc1.bias": "blocks.2.ff_a.0.bias",
|
||||
"model.diffusion_model.joint_blocks.2.x_block.mlp.fc1.weight": "blocks.2.ff_a.0.weight",
|
||||
"model.diffusion_model.joint_blocks.2.x_block.mlp.fc2.bias": "blocks.2.ff_a.2.bias",
|
||||
"model.diffusion_model.joint_blocks.2.x_block.mlp.fc2.weight": "blocks.2.ff_a.2.weight",
|
||||
"model.diffusion_model.joint_blocks.20.context_block.adaLN_modulation.1.bias": "blocks.20.norm1_b.linear.bias",
|
||||
"model.diffusion_model.joint_blocks.20.context_block.adaLN_modulation.1.weight": "blocks.20.norm1_b.linear.weight",
|
||||
"model.diffusion_model.joint_blocks.20.context_block.attn.proj.bias": "blocks.20.attn.b_to_out.bias",
|
||||
"model.diffusion_model.joint_blocks.20.context_block.attn.proj.weight": "blocks.20.attn.b_to_out.weight",
|
||||
"model.diffusion_model.joint_blocks.20.context_block.attn.qkv.bias": ['blocks.20.attn.b_to_q.bias', 'blocks.20.attn.b_to_k.bias', 'blocks.20.attn.b_to_v.bias'],
|
||||
"model.diffusion_model.joint_blocks.20.context_block.attn.qkv.weight": ['blocks.20.attn.b_to_q.weight', 'blocks.20.attn.b_to_k.weight', 'blocks.20.attn.b_to_v.weight'],
|
||||
"model.diffusion_model.joint_blocks.20.context_block.mlp.fc1.bias": "blocks.20.ff_b.0.bias",
|
||||
"model.diffusion_model.joint_blocks.20.context_block.mlp.fc1.weight": "blocks.20.ff_b.0.weight",
|
||||
"model.diffusion_model.joint_blocks.20.context_block.mlp.fc2.bias": "blocks.20.ff_b.2.bias",
|
||||
"model.diffusion_model.joint_blocks.20.context_block.mlp.fc2.weight": "blocks.20.ff_b.2.weight",
|
||||
"model.diffusion_model.joint_blocks.20.x_block.adaLN_modulation.1.bias": "blocks.20.norm1_a.linear.bias",
|
||||
"model.diffusion_model.joint_blocks.20.x_block.adaLN_modulation.1.weight": "blocks.20.norm1_a.linear.weight",
|
||||
"model.diffusion_model.joint_blocks.20.x_block.attn.proj.bias": "blocks.20.attn.a_to_out.bias",
|
||||
"model.diffusion_model.joint_blocks.20.x_block.attn.proj.weight": "blocks.20.attn.a_to_out.weight",
|
||||
"model.diffusion_model.joint_blocks.20.x_block.attn.qkv.bias": ['blocks.20.attn.a_to_q.bias', 'blocks.20.attn.a_to_k.bias', 'blocks.20.attn.a_to_v.bias'],
|
||||
"model.diffusion_model.joint_blocks.20.x_block.attn.qkv.weight": ['blocks.20.attn.a_to_q.weight', 'blocks.20.attn.a_to_k.weight', 'blocks.20.attn.a_to_v.weight'],
|
||||
"model.diffusion_model.joint_blocks.20.x_block.mlp.fc1.bias": "blocks.20.ff_a.0.bias",
|
||||
"model.diffusion_model.joint_blocks.20.x_block.mlp.fc1.weight": "blocks.20.ff_a.0.weight",
|
||||
"model.diffusion_model.joint_blocks.20.x_block.mlp.fc2.bias": "blocks.20.ff_a.2.bias",
|
||||
"model.diffusion_model.joint_blocks.20.x_block.mlp.fc2.weight": "blocks.20.ff_a.2.weight",
|
||||
"model.diffusion_model.joint_blocks.21.context_block.adaLN_modulation.1.bias": "blocks.21.norm1_b.linear.bias",
|
||||
"model.diffusion_model.joint_blocks.21.context_block.adaLN_modulation.1.weight": "blocks.21.norm1_b.linear.weight",
|
||||
"model.diffusion_model.joint_blocks.21.context_block.attn.proj.bias": "blocks.21.attn.b_to_out.bias",
|
||||
"model.diffusion_model.joint_blocks.21.context_block.attn.proj.weight": "blocks.21.attn.b_to_out.weight",
|
||||
"model.diffusion_model.joint_blocks.21.context_block.attn.qkv.bias": ['blocks.21.attn.b_to_q.bias', 'blocks.21.attn.b_to_k.bias', 'blocks.21.attn.b_to_v.bias'],
|
||||
"model.diffusion_model.joint_blocks.21.context_block.attn.qkv.weight": ['blocks.21.attn.b_to_q.weight', 'blocks.21.attn.b_to_k.weight', 'blocks.21.attn.b_to_v.weight'],
|
||||
"model.diffusion_model.joint_blocks.21.context_block.mlp.fc1.bias": "blocks.21.ff_b.0.bias",
|
||||
"model.diffusion_model.joint_blocks.21.context_block.mlp.fc1.weight": "blocks.21.ff_b.0.weight",
|
||||
"model.diffusion_model.joint_blocks.21.context_block.mlp.fc2.bias": "blocks.21.ff_b.2.bias",
|
||||
"model.diffusion_model.joint_blocks.21.context_block.mlp.fc2.weight": "blocks.21.ff_b.2.weight",
|
||||
"model.diffusion_model.joint_blocks.21.x_block.adaLN_modulation.1.bias": "blocks.21.norm1_a.linear.bias",
|
||||
"model.diffusion_model.joint_blocks.21.x_block.adaLN_modulation.1.weight": "blocks.21.norm1_a.linear.weight",
|
||||
"model.diffusion_model.joint_blocks.21.x_block.attn.proj.bias": "blocks.21.attn.a_to_out.bias",
|
||||
"model.diffusion_model.joint_blocks.21.x_block.attn.proj.weight": "blocks.21.attn.a_to_out.weight",
|
||||
"model.diffusion_model.joint_blocks.21.x_block.attn.qkv.bias": ['blocks.21.attn.a_to_q.bias', 'blocks.21.attn.a_to_k.bias', 'blocks.21.attn.a_to_v.bias'],
|
||||
"model.diffusion_model.joint_blocks.21.x_block.attn.qkv.weight": ['blocks.21.attn.a_to_q.weight', 'blocks.21.attn.a_to_k.weight', 'blocks.21.attn.a_to_v.weight'],
|
||||
"model.diffusion_model.joint_blocks.21.x_block.mlp.fc1.bias": "blocks.21.ff_a.0.bias",
|
||||
"model.diffusion_model.joint_blocks.21.x_block.mlp.fc1.weight": "blocks.21.ff_a.0.weight",
|
||||
"model.diffusion_model.joint_blocks.21.x_block.mlp.fc2.bias": "blocks.21.ff_a.2.bias",
|
||||
"model.diffusion_model.joint_blocks.21.x_block.mlp.fc2.weight": "blocks.21.ff_a.2.weight",
|
||||
"model.diffusion_model.joint_blocks.22.context_block.adaLN_modulation.1.bias": "blocks.22.norm1_b.linear.bias",
|
||||
"model.diffusion_model.joint_blocks.22.context_block.adaLN_modulation.1.weight": "blocks.22.norm1_b.linear.weight",
|
||||
"model.diffusion_model.joint_blocks.22.context_block.attn.proj.bias": "blocks.22.attn.b_to_out.bias",
|
||||
"model.diffusion_model.joint_blocks.22.context_block.attn.proj.weight": "blocks.22.attn.b_to_out.weight",
|
||||
"model.diffusion_model.joint_blocks.22.context_block.attn.qkv.bias": ['blocks.22.attn.b_to_q.bias', 'blocks.22.attn.b_to_k.bias', 'blocks.22.attn.b_to_v.bias'],
|
||||
"model.diffusion_model.joint_blocks.22.context_block.attn.qkv.weight": ['blocks.22.attn.b_to_q.weight', 'blocks.22.attn.b_to_k.weight', 'blocks.22.attn.b_to_v.weight'],
|
||||
"model.diffusion_model.joint_blocks.22.context_block.mlp.fc1.bias": "blocks.22.ff_b.0.bias",
|
||||
"model.diffusion_model.joint_blocks.22.context_block.mlp.fc1.weight": "blocks.22.ff_b.0.weight",
|
||||
"model.diffusion_model.joint_blocks.22.context_block.mlp.fc2.bias": "blocks.22.ff_b.2.bias",
|
||||
"model.diffusion_model.joint_blocks.22.context_block.mlp.fc2.weight": "blocks.22.ff_b.2.weight",
|
||||
"model.diffusion_model.joint_blocks.22.x_block.adaLN_modulation.1.bias": "blocks.22.norm1_a.linear.bias",
|
||||
"model.diffusion_model.joint_blocks.22.x_block.adaLN_modulation.1.weight": "blocks.22.norm1_a.linear.weight",
|
||||
"model.diffusion_model.joint_blocks.22.x_block.attn.proj.bias": "blocks.22.attn.a_to_out.bias",
|
||||
"model.diffusion_model.joint_blocks.22.x_block.attn.proj.weight": "blocks.22.attn.a_to_out.weight",
|
||||
"model.diffusion_model.joint_blocks.22.x_block.attn.qkv.bias": ['blocks.22.attn.a_to_q.bias', 'blocks.22.attn.a_to_k.bias', 'blocks.22.attn.a_to_v.bias'],
|
||||
"model.diffusion_model.joint_blocks.22.x_block.attn.qkv.weight": ['blocks.22.attn.a_to_q.weight', 'blocks.22.attn.a_to_k.weight', 'blocks.22.attn.a_to_v.weight'],
|
||||
"model.diffusion_model.joint_blocks.22.x_block.mlp.fc1.bias": "blocks.22.ff_a.0.bias",
|
||||
"model.diffusion_model.joint_blocks.22.x_block.mlp.fc1.weight": "blocks.22.ff_a.0.weight",
|
||||
"model.diffusion_model.joint_blocks.22.x_block.mlp.fc2.bias": "blocks.22.ff_a.2.bias",
|
||||
"model.diffusion_model.joint_blocks.22.x_block.mlp.fc2.weight": "blocks.22.ff_a.2.weight",
|
||||
"model.diffusion_model.joint_blocks.23.context_block.attn.qkv.bias": ['blocks.23.attn.b_to_q.bias', 'blocks.23.attn.b_to_k.bias', 'blocks.23.attn.b_to_v.bias'],
|
||||
"model.diffusion_model.joint_blocks.23.context_block.attn.qkv.weight": ['blocks.23.attn.b_to_q.weight', 'blocks.23.attn.b_to_k.weight', 'blocks.23.attn.b_to_v.weight'],
|
||||
"model.diffusion_model.joint_blocks.23.x_block.adaLN_modulation.1.bias": "blocks.23.norm1_a.linear.bias",
|
||||
"model.diffusion_model.joint_blocks.23.x_block.adaLN_modulation.1.weight": "blocks.23.norm1_a.linear.weight",
|
||||
"model.diffusion_model.joint_blocks.23.x_block.attn.proj.bias": "blocks.23.attn.a_to_out.bias",
|
||||
"model.diffusion_model.joint_blocks.23.x_block.attn.proj.weight": "blocks.23.attn.a_to_out.weight",
|
||||
"model.diffusion_model.joint_blocks.23.x_block.attn.qkv.bias": ['blocks.23.attn.a_to_q.bias', 'blocks.23.attn.a_to_k.bias', 'blocks.23.attn.a_to_v.bias'],
|
||||
"model.diffusion_model.joint_blocks.23.x_block.attn.qkv.weight": ['blocks.23.attn.a_to_q.weight', 'blocks.23.attn.a_to_k.weight', 'blocks.23.attn.a_to_v.weight'],
|
||||
"model.diffusion_model.joint_blocks.23.x_block.mlp.fc1.bias": "blocks.23.ff_a.0.bias",
|
||||
"model.diffusion_model.joint_blocks.23.x_block.mlp.fc1.weight": "blocks.23.ff_a.0.weight",
|
||||
"model.diffusion_model.joint_blocks.23.x_block.mlp.fc2.bias": "blocks.23.ff_a.2.bias",
|
||||
"model.diffusion_model.joint_blocks.23.x_block.mlp.fc2.weight": "blocks.23.ff_a.2.weight",
|
||||
"model.diffusion_model.joint_blocks.3.context_block.adaLN_modulation.1.bias": "blocks.3.norm1_b.linear.bias",
|
||||
"model.diffusion_model.joint_blocks.3.context_block.adaLN_modulation.1.weight": "blocks.3.norm1_b.linear.weight",
|
||||
"model.diffusion_model.joint_blocks.3.context_block.attn.proj.bias": "blocks.3.attn.b_to_out.bias",
|
||||
"model.diffusion_model.joint_blocks.3.context_block.attn.proj.weight": "blocks.3.attn.b_to_out.weight",
|
||||
"model.diffusion_model.joint_blocks.3.context_block.attn.qkv.bias": ['blocks.3.attn.b_to_q.bias', 'blocks.3.attn.b_to_k.bias', 'blocks.3.attn.b_to_v.bias'],
|
||||
"model.diffusion_model.joint_blocks.3.context_block.attn.qkv.weight": ['blocks.3.attn.b_to_q.weight', 'blocks.3.attn.b_to_k.weight', 'blocks.3.attn.b_to_v.weight'],
|
||||
"model.diffusion_model.joint_blocks.3.context_block.mlp.fc1.bias": "blocks.3.ff_b.0.bias",
|
||||
"model.diffusion_model.joint_blocks.3.context_block.mlp.fc1.weight": "blocks.3.ff_b.0.weight",
|
||||
"model.diffusion_model.joint_blocks.3.context_block.mlp.fc2.bias": "blocks.3.ff_b.2.bias",
|
||||
"model.diffusion_model.joint_blocks.3.context_block.mlp.fc2.weight": "blocks.3.ff_b.2.weight",
|
||||
"model.diffusion_model.joint_blocks.3.x_block.adaLN_modulation.1.bias": "blocks.3.norm1_a.linear.bias",
|
||||
"model.diffusion_model.joint_blocks.3.x_block.adaLN_modulation.1.weight": "blocks.3.norm1_a.linear.weight",
|
||||
"model.diffusion_model.joint_blocks.3.x_block.attn.proj.bias": "blocks.3.attn.a_to_out.bias",
|
||||
"model.diffusion_model.joint_blocks.3.x_block.attn.proj.weight": "blocks.3.attn.a_to_out.weight",
|
||||
"model.diffusion_model.joint_blocks.3.x_block.attn.qkv.bias": ['blocks.3.attn.a_to_q.bias', 'blocks.3.attn.a_to_k.bias', 'blocks.3.attn.a_to_v.bias'],
|
||||
"model.diffusion_model.joint_blocks.3.x_block.attn.qkv.weight": ['blocks.3.attn.a_to_q.weight', 'blocks.3.attn.a_to_k.weight', 'blocks.3.attn.a_to_v.weight'],
|
||||
"model.diffusion_model.joint_blocks.3.x_block.mlp.fc1.bias": "blocks.3.ff_a.0.bias",
|
||||
"model.diffusion_model.joint_blocks.3.x_block.mlp.fc1.weight": "blocks.3.ff_a.0.weight",
|
||||
"model.diffusion_model.joint_blocks.3.x_block.mlp.fc2.bias": "blocks.3.ff_a.2.bias",
|
||||
"model.diffusion_model.joint_blocks.3.x_block.mlp.fc2.weight": "blocks.3.ff_a.2.weight",
|
||||
"model.diffusion_model.joint_blocks.4.context_block.adaLN_modulation.1.bias": "blocks.4.norm1_b.linear.bias",
|
||||
"model.diffusion_model.joint_blocks.4.context_block.adaLN_modulation.1.weight": "blocks.4.norm1_b.linear.weight",
|
||||
"model.diffusion_model.joint_blocks.4.context_block.attn.proj.bias": "blocks.4.attn.b_to_out.bias",
|
||||
"model.diffusion_model.joint_blocks.4.context_block.attn.proj.weight": "blocks.4.attn.b_to_out.weight",
|
||||
"model.diffusion_model.joint_blocks.4.context_block.attn.qkv.bias": ['blocks.4.attn.b_to_q.bias', 'blocks.4.attn.b_to_k.bias', 'blocks.4.attn.b_to_v.bias'],
|
||||
"model.diffusion_model.joint_blocks.4.context_block.attn.qkv.weight": ['blocks.4.attn.b_to_q.weight', 'blocks.4.attn.b_to_k.weight', 'blocks.4.attn.b_to_v.weight'],
|
||||
"model.diffusion_model.joint_blocks.4.context_block.mlp.fc1.bias": "blocks.4.ff_b.0.bias",
|
||||
"model.diffusion_model.joint_blocks.4.context_block.mlp.fc1.weight": "blocks.4.ff_b.0.weight",
|
||||
"model.diffusion_model.joint_blocks.4.context_block.mlp.fc2.bias": "blocks.4.ff_b.2.bias",
|
||||
"model.diffusion_model.joint_blocks.4.context_block.mlp.fc2.weight": "blocks.4.ff_b.2.weight",
|
||||
"model.diffusion_model.joint_blocks.4.x_block.adaLN_modulation.1.bias": "blocks.4.norm1_a.linear.bias",
|
||||
"model.diffusion_model.joint_blocks.4.x_block.adaLN_modulation.1.weight": "blocks.4.norm1_a.linear.weight",
|
||||
"model.diffusion_model.joint_blocks.4.x_block.attn.proj.bias": "blocks.4.attn.a_to_out.bias",
|
||||
"model.diffusion_model.joint_blocks.4.x_block.attn.proj.weight": "blocks.4.attn.a_to_out.weight",
|
||||
"model.diffusion_model.joint_blocks.4.x_block.attn.qkv.bias": ['blocks.4.attn.a_to_q.bias', 'blocks.4.attn.a_to_k.bias', 'blocks.4.attn.a_to_v.bias'],
|
||||
"model.diffusion_model.joint_blocks.4.x_block.attn.qkv.weight": ['blocks.4.attn.a_to_q.weight', 'blocks.4.attn.a_to_k.weight', 'blocks.4.attn.a_to_v.weight'],
|
||||
"model.diffusion_model.joint_blocks.4.x_block.mlp.fc1.bias": "blocks.4.ff_a.0.bias",
|
||||
"model.diffusion_model.joint_blocks.4.x_block.mlp.fc1.weight": "blocks.4.ff_a.0.weight",
|
||||
"model.diffusion_model.joint_blocks.4.x_block.mlp.fc2.bias": "blocks.4.ff_a.2.bias",
|
||||
"model.diffusion_model.joint_blocks.4.x_block.mlp.fc2.weight": "blocks.4.ff_a.2.weight",
|
||||
"model.diffusion_model.joint_blocks.5.context_block.adaLN_modulation.1.bias": "blocks.5.norm1_b.linear.bias",
|
||||
"model.diffusion_model.joint_blocks.5.context_block.adaLN_modulation.1.weight": "blocks.5.norm1_b.linear.weight",
|
||||
"model.diffusion_model.joint_blocks.5.context_block.attn.proj.bias": "blocks.5.attn.b_to_out.bias",
|
||||
"model.diffusion_model.joint_blocks.5.context_block.attn.proj.weight": "blocks.5.attn.b_to_out.weight",
|
||||
"model.diffusion_model.joint_blocks.5.context_block.attn.qkv.bias": ['blocks.5.attn.b_to_q.bias', 'blocks.5.attn.b_to_k.bias', 'blocks.5.attn.b_to_v.bias'],
|
||||
"model.diffusion_model.joint_blocks.5.context_block.attn.qkv.weight": ['blocks.5.attn.b_to_q.weight', 'blocks.5.attn.b_to_k.weight', 'blocks.5.attn.b_to_v.weight'],
|
||||
"model.diffusion_model.joint_blocks.5.context_block.mlp.fc1.bias": "blocks.5.ff_b.0.bias",
|
||||
"model.diffusion_model.joint_blocks.5.context_block.mlp.fc1.weight": "blocks.5.ff_b.0.weight",
|
||||
"model.diffusion_model.joint_blocks.5.context_block.mlp.fc2.bias": "blocks.5.ff_b.2.bias",
|
||||
"model.diffusion_model.joint_blocks.5.context_block.mlp.fc2.weight": "blocks.5.ff_b.2.weight",
|
||||
"model.diffusion_model.joint_blocks.5.x_block.adaLN_modulation.1.bias": "blocks.5.norm1_a.linear.bias",
|
||||
"model.diffusion_model.joint_blocks.5.x_block.adaLN_modulation.1.weight": "blocks.5.norm1_a.linear.weight",
|
||||
"model.diffusion_model.joint_blocks.5.x_block.attn.proj.bias": "blocks.5.attn.a_to_out.bias",
|
||||
"model.diffusion_model.joint_blocks.5.x_block.attn.proj.weight": "blocks.5.attn.a_to_out.weight",
|
||||
"model.diffusion_model.joint_blocks.5.x_block.attn.qkv.bias": ['blocks.5.attn.a_to_q.bias', 'blocks.5.attn.a_to_k.bias', 'blocks.5.attn.a_to_v.bias'],
|
||||
"model.diffusion_model.joint_blocks.5.x_block.attn.qkv.weight": ['blocks.5.attn.a_to_q.weight', 'blocks.5.attn.a_to_k.weight', 'blocks.5.attn.a_to_v.weight'],
|
||||
"model.diffusion_model.joint_blocks.5.x_block.mlp.fc1.bias": "blocks.5.ff_a.0.bias",
|
||||
"model.diffusion_model.joint_blocks.5.x_block.mlp.fc1.weight": "blocks.5.ff_a.0.weight",
|
||||
"model.diffusion_model.joint_blocks.5.x_block.mlp.fc2.bias": "blocks.5.ff_a.2.bias",
|
||||
"model.diffusion_model.joint_blocks.5.x_block.mlp.fc2.weight": "blocks.5.ff_a.2.weight",
|
||||
"model.diffusion_model.joint_blocks.6.context_block.adaLN_modulation.1.bias": "blocks.6.norm1_b.linear.bias",
|
||||
"model.diffusion_model.joint_blocks.6.context_block.adaLN_modulation.1.weight": "blocks.6.norm1_b.linear.weight",
|
||||
"model.diffusion_model.joint_blocks.6.context_block.attn.proj.bias": "blocks.6.attn.b_to_out.bias",
|
||||
"model.diffusion_model.joint_blocks.6.context_block.attn.proj.weight": "blocks.6.attn.b_to_out.weight",
|
||||
"model.diffusion_model.joint_blocks.6.context_block.attn.qkv.bias": ['blocks.6.attn.b_to_q.bias', 'blocks.6.attn.b_to_k.bias', 'blocks.6.attn.b_to_v.bias'],
|
||||
"model.diffusion_model.joint_blocks.6.context_block.attn.qkv.weight": ['blocks.6.attn.b_to_q.weight', 'blocks.6.attn.b_to_k.weight', 'blocks.6.attn.b_to_v.weight'],
|
||||
"model.diffusion_model.joint_blocks.6.context_block.mlp.fc1.bias": "blocks.6.ff_b.0.bias",
|
||||
"model.diffusion_model.joint_blocks.6.context_block.mlp.fc1.weight": "blocks.6.ff_b.0.weight",
|
||||
"model.diffusion_model.joint_blocks.6.context_block.mlp.fc2.bias": "blocks.6.ff_b.2.bias",
|
||||
"model.diffusion_model.joint_blocks.6.context_block.mlp.fc2.weight": "blocks.6.ff_b.2.weight",
|
||||
"model.diffusion_model.joint_blocks.6.x_block.adaLN_modulation.1.bias": "blocks.6.norm1_a.linear.bias",
|
||||
"model.diffusion_model.joint_blocks.6.x_block.adaLN_modulation.1.weight": "blocks.6.norm1_a.linear.weight",
|
||||
"model.diffusion_model.joint_blocks.6.x_block.attn.proj.bias": "blocks.6.attn.a_to_out.bias",
|
||||
"model.diffusion_model.joint_blocks.6.x_block.attn.proj.weight": "blocks.6.attn.a_to_out.weight",
|
||||
"model.diffusion_model.joint_blocks.6.x_block.attn.qkv.bias": ['blocks.6.attn.a_to_q.bias', 'blocks.6.attn.a_to_k.bias', 'blocks.6.attn.a_to_v.bias'],
|
||||
"model.diffusion_model.joint_blocks.6.x_block.attn.qkv.weight": ['blocks.6.attn.a_to_q.weight', 'blocks.6.attn.a_to_k.weight', 'blocks.6.attn.a_to_v.weight'],
|
||||
"model.diffusion_model.joint_blocks.6.x_block.mlp.fc1.bias": "blocks.6.ff_a.0.bias",
|
||||
"model.diffusion_model.joint_blocks.6.x_block.mlp.fc1.weight": "blocks.6.ff_a.0.weight",
|
||||
"model.diffusion_model.joint_blocks.6.x_block.mlp.fc2.bias": "blocks.6.ff_a.2.bias",
|
||||
"model.diffusion_model.joint_blocks.6.x_block.mlp.fc2.weight": "blocks.6.ff_a.2.weight",
|
||||
"model.diffusion_model.joint_blocks.7.context_block.adaLN_modulation.1.bias": "blocks.7.norm1_b.linear.bias",
|
||||
"model.diffusion_model.joint_blocks.7.context_block.adaLN_modulation.1.weight": "blocks.7.norm1_b.linear.weight",
|
||||
"model.diffusion_model.joint_blocks.7.context_block.attn.proj.bias": "blocks.7.attn.b_to_out.bias",
|
||||
"model.diffusion_model.joint_blocks.7.context_block.attn.proj.weight": "blocks.7.attn.b_to_out.weight",
|
||||
"model.diffusion_model.joint_blocks.7.context_block.attn.qkv.bias": ['blocks.7.attn.b_to_q.bias', 'blocks.7.attn.b_to_k.bias', 'blocks.7.attn.b_to_v.bias'],
|
||||
"model.diffusion_model.joint_blocks.7.context_block.attn.qkv.weight": ['blocks.7.attn.b_to_q.weight', 'blocks.7.attn.b_to_k.weight', 'blocks.7.attn.b_to_v.weight'],
|
||||
"model.diffusion_model.joint_blocks.7.context_block.mlp.fc1.bias": "blocks.7.ff_b.0.bias",
|
||||
"model.diffusion_model.joint_blocks.7.context_block.mlp.fc1.weight": "blocks.7.ff_b.0.weight",
|
||||
"model.diffusion_model.joint_blocks.7.context_block.mlp.fc2.bias": "blocks.7.ff_b.2.bias",
|
||||
"model.diffusion_model.joint_blocks.7.context_block.mlp.fc2.weight": "blocks.7.ff_b.2.weight",
|
||||
"model.diffusion_model.joint_blocks.7.x_block.adaLN_modulation.1.bias": "blocks.7.norm1_a.linear.bias",
|
||||
"model.diffusion_model.joint_blocks.7.x_block.adaLN_modulation.1.weight": "blocks.7.norm1_a.linear.weight",
|
||||
"model.diffusion_model.joint_blocks.7.x_block.attn.proj.bias": "blocks.7.attn.a_to_out.bias",
|
||||
"model.diffusion_model.joint_blocks.7.x_block.attn.proj.weight": "blocks.7.attn.a_to_out.weight",
|
||||
"model.diffusion_model.joint_blocks.7.x_block.attn.qkv.bias": ['blocks.7.attn.a_to_q.bias', 'blocks.7.attn.a_to_k.bias', 'blocks.7.attn.a_to_v.bias'],
|
||||
"model.diffusion_model.joint_blocks.7.x_block.attn.qkv.weight": ['blocks.7.attn.a_to_q.weight', 'blocks.7.attn.a_to_k.weight', 'blocks.7.attn.a_to_v.weight'],
|
||||
"model.diffusion_model.joint_blocks.7.x_block.mlp.fc1.bias": "blocks.7.ff_a.0.bias",
|
||||
"model.diffusion_model.joint_blocks.7.x_block.mlp.fc1.weight": "blocks.7.ff_a.0.weight",
|
||||
"model.diffusion_model.joint_blocks.7.x_block.mlp.fc2.bias": "blocks.7.ff_a.2.bias",
|
||||
"model.diffusion_model.joint_blocks.7.x_block.mlp.fc2.weight": "blocks.7.ff_a.2.weight",
|
||||
"model.diffusion_model.joint_blocks.8.context_block.adaLN_modulation.1.bias": "blocks.8.norm1_b.linear.bias",
|
||||
"model.diffusion_model.joint_blocks.8.context_block.adaLN_modulation.1.weight": "blocks.8.norm1_b.linear.weight",
|
||||
"model.diffusion_model.joint_blocks.8.context_block.attn.proj.bias": "blocks.8.attn.b_to_out.bias",
|
||||
"model.diffusion_model.joint_blocks.8.context_block.attn.proj.weight": "blocks.8.attn.b_to_out.weight",
|
||||
"model.diffusion_model.joint_blocks.8.context_block.attn.qkv.bias": ['blocks.8.attn.b_to_q.bias', 'blocks.8.attn.b_to_k.bias', 'blocks.8.attn.b_to_v.bias'],
|
||||
"model.diffusion_model.joint_blocks.8.context_block.attn.qkv.weight": ['blocks.8.attn.b_to_q.weight', 'blocks.8.attn.b_to_k.weight', 'blocks.8.attn.b_to_v.weight'],
|
||||
"model.diffusion_model.joint_blocks.8.context_block.mlp.fc1.bias": "blocks.8.ff_b.0.bias",
|
||||
"model.diffusion_model.joint_blocks.8.context_block.mlp.fc1.weight": "blocks.8.ff_b.0.weight",
|
||||
"model.diffusion_model.joint_blocks.8.context_block.mlp.fc2.bias": "blocks.8.ff_b.2.bias",
|
||||
"model.diffusion_model.joint_blocks.8.context_block.mlp.fc2.weight": "blocks.8.ff_b.2.weight",
|
||||
"model.diffusion_model.joint_blocks.8.x_block.adaLN_modulation.1.bias": "blocks.8.norm1_a.linear.bias",
|
||||
"model.diffusion_model.joint_blocks.8.x_block.adaLN_modulation.1.weight": "blocks.8.norm1_a.linear.weight",
|
||||
"model.diffusion_model.joint_blocks.8.x_block.attn.proj.bias": "blocks.8.attn.a_to_out.bias",
|
||||
"model.diffusion_model.joint_blocks.8.x_block.attn.proj.weight": "blocks.8.attn.a_to_out.weight",
|
||||
"model.diffusion_model.joint_blocks.8.x_block.attn.qkv.bias": ['blocks.8.attn.a_to_q.bias', 'blocks.8.attn.a_to_k.bias', 'blocks.8.attn.a_to_v.bias'],
|
||||
"model.diffusion_model.joint_blocks.8.x_block.attn.qkv.weight": ['blocks.8.attn.a_to_q.weight', 'blocks.8.attn.a_to_k.weight', 'blocks.8.attn.a_to_v.weight'],
|
||||
"model.diffusion_model.joint_blocks.8.x_block.mlp.fc1.bias": "blocks.8.ff_a.0.bias",
|
||||
"model.diffusion_model.joint_blocks.8.x_block.mlp.fc1.weight": "blocks.8.ff_a.0.weight",
|
||||
"model.diffusion_model.joint_blocks.8.x_block.mlp.fc2.bias": "blocks.8.ff_a.2.bias",
|
||||
"model.diffusion_model.joint_blocks.8.x_block.mlp.fc2.weight": "blocks.8.ff_a.2.weight",
|
||||
"model.diffusion_model.joint_blocks.9.context_block.adaLN_modulation.1.bias": "blocks.9.norm1_b.linear.bias",
|
||||
"model.diffusion_model.joint_blocks.9.context_block.adaLN_modulation.1.weight": "blocks.9.norm1_b.linear.weight",
|
||||
"model.diffusion_model.joint_blocks.9.context_block.attn.proj.bias": "blocks.9.attn.b_to_out.bias",
|
||||
"model.diffusion_model.joint_blocks.9.context_block.attn.proj.weight": "blocks.9.attn.b_to_out.weight",
|
||||
"model.diffusion_model.joint_blocks.9.context_block.attn.qkv.bias": ['blocks.9.attn.b_to_q.bias', 'blocks.9.attn.b_to_k.bias', 'blocks.9.attn.b_to_v.bias'],
|
||||
"model.diffusion_model.joint_blocks.9.context_block.attn.qkv.weight": ['blocks.9.attn.b_to_q.weight', 'blocks.9.attn.b_to_k.weight', 'blocks.9.attn.b_to_v.weight'],
|
||||
"model.diffusion_model.joint_blocks.9.context_block.mlp.fc1.bias": "blocks.9.ff_b.0.bias",
|
||||
"model.diffusion_model.joint_blocks.9.context_block.mlp.fc1.weight": "blocks.9.ff_b.0.weight",
|
||||
"model.diffusion_model.joint_blocks.9.context_block.mlp.fc2.bias": "blocks.9.ff_b.2.bias",
|
||||
"model.diffusion_model.joint_blocks.9.context_block.mlp.fc2.weight": "blocks.9.ff_b.2.weight",
|
||||
"model.diffusion_model.joint_blocks.9.x_block.adaLN_modulation.1.bias": "blocks.9.norm1_a.linear.bias",
|
||||
"model.diffusion_model.joint_blocks.9.x_block.adaLN_modulation.1.weight": "blocks.9.norm1_a.linear.weight",
|
||||
"model.diffusion_model.joint_blocks.9.x_block.attn.proj.bias": "blocks.9.attn.a_to_out.bias",
|
||||
"model.diffusion_model.joint_blocks.9.x_block.attn.proj.weight": "blocks.9.attn.a_to_out.weight",
|
||||
"model.diffusion_model.joint_blocks.9.x_block.attn.qkv.bias": ['blocks.9.attn.a_to_q.bias', 'blocks.9.attn.a_to_k.bias', 'blocks.9.attn.a_to_v.bias'],
|
||||
"model.diffusion_model.joint_blocks.9.x_block.attn.qkv.weight": ['blocks.9.attn.a_to_q.weight', 'blocks.9.attn.a_to_k.weight', 'blocks.9.attn.a_to_v.weight'],
|
||||
"model.diffusion_model.joint_blocks.9.x_block.mlp.fc1.bias": "blocks.9.ff_a.0.bias",
|
||||
"model.diffusion_model.joint_blocks.9.x_block.mlp.fc1.weight": "blocks.9.ff_a.0.weight",
|
||||
"model.diffusion_model.joint_blocks.9.x_block.mlp.fc2.bias": "blocks.9.ff_a.2.bias",
|
||||
"model.diffusion_model.joint_blocks.9.x_block.mlp.fc2.weight": "blocks.9.ff_a.2.weight",
|
||||
|
||||
"model.diffusion_model.pos_embed": "pos_embedder.pos_embed",
|
||||
"model.diffusion_model.t_embedder.mlp.0.bias": "time_embedder.timestep_embedder.0.bias",
|
||||
"model.diffusion_model.t_embedder.mlp.0.weight": "time_embedder.timestep_embedder.0.weight",
|
||||
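Note on the mapping convention above: parameters of each `x_block` (the image-token branch) are renamed to `*_a` targets (`norm1_a`, `ff_a`, `a_to_*`), while `context_block` (the text-token branch) maps to the `*_b` names, and every fused `attn.qkv` tensor is associated with a list of three target names (q, k, v). A minimal illustrative sketch of how such a fused tensor could be split into the three listed targets; the hidden size of 1536 is only an assumption for illustration:

import torch

hidden = 1536                                        # assumed hidden size, illustration only
qkv_weight = torch.randn(3 * hidden, hidden)         # fused [q; k; v] projection weight
q_w, k_w, v_w = torch.chunk(qkv_weight, 3, dim=0)    # one chunk per target name in the list
assert q_w.shape == (hidden, hidden)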
@@ -780,19 +493,59 @@ class SD3DiTStateDictConverter:
"model.diffusion_model.final_layer.adaLN_modulation.1.weight": "norm_out.linear.weight",
"model.diffusion_model.final_layer.adaLN_modulation.1.bias": "norm_out.linear.bias",
}
for i in range(40):
rename_dict.update({
f"model.diffusion_model.joint_blocks.{i}.context_block.adaLN_modulation.1.bias": f"blocks.{i}.norm1_b.linear.bias",
f"model.diffusion_model.joint_blocks.{i}.context_block.adaLN_modulation.1.weight": f"blocks.{i}.norm1_b.linear.weight",
f"model.diffusion_model.joint_blocks.{i}.context_block.attn.proj.bias": f"blocks.{i}.attn.b_to_out.bias",
f"model.diffusion_model.joint_blocks.{i}.context_block.attn.proj.weight": f"blocks.{i}.attn.b_to_out.weight",
f"model.diffusion_model.joint_blocks.{i}.context_block.attn.qkv.bias": [f'blocks.{i}.attn.b_to_q.bias', f'blocks.{i}.attn.b_to_k.bias', f'blocks.{i}.attn.b_to_v.bias'],
f"model.diffusion_model.joint_blocks.{i}.context_block.attn.qkv.weight": [f'blocks.{i}.attn.b_to_q.weight', f'blocks.{i}.attn.b_to_k.weight', f'blocks.{i}.attn.b_to_v.weight'],
f"model.diffusion_model.joint_blocks.{i}.context_block.mlp.fc1.bias": f"blocks.{i}.ff_b.0.bias",
f"model.diffusion_model.joint_blocks.{i}.context_block.mlp.fc1.weight": f"blocks.{i}.ff_b.0.weight",
f"model.diffusion_model.joint_blocks.{i}.context_block.mlp.fc2.bias": f"blocks.{i}.ff_b.2.bias",
f"model.diffusion_model.joint_blocks.{i}.context_block.mlp.fc2.weight": f"blocks.{i}.ff_b.2.weight",
f"model.diffusion_model.joint_blocks.{i}.x_block.adaLN_modulation.1.bias": f"blocks.{i}.norm1_a.linear.bias",
f"model.diffusion_model.joint_blocks.{i}.x_block.adaLN_modulation.1.weight": f"blocks.{i}.norm1_a.linear.weight",
f"model.diffusion_model.joint_blocks.{i}.x_block.attn.proj.bias": f"blocks.{i}.attn.a_to_out.bias",
f"model.diffusion_model.joint_blocks.{i}.x_block.attn.proj.weight": f"blocks.{i}.attn.a_to_out.weight",
f"model.diffusion_model.joint_blocks.{i}.x_block.attn.qkv.bias": [f'blocks.{i}.attn.a_to_q.bias', f'blocks.{i}.attn.a_to_k.bias', f'blocks.{i}.attn.a_to_v.bias'],
f"model.diffusion_model.joint_blocks.{i}.x_block.attn.qkv.weight": [f'blocks.{i}.attn.a_to_q.weight', f'blocks.{i}.attn.a_to_k.weight', f'blocks.{i}.attn.a_to_v.weight'],
f"model.diffusion_model.joint_blocks.{i}.x_block.mlp.fc1.bias": f"blocks.{i}.ff_a.0.bias",
f"model.diffusion_model.joint_blocks.{i}.x_block.mlp.fc1.weight": f"blocks.{i}.ff_a.0.weight",
f"model.diffusion_model.joint_blocks.{i}.x_block.mlp.fc2.bias": f"blocks.{i}.ff_a.2.bias",
f"model.diffusion_model.joint_blocks.{i}.x_block.mlp.fc2.weight": f"blocks.{i}.ff_a.2.weight",
f"model.diffusion_model.joint_blocks.{i}.x_block.attn.ln_q.weight": f"blocks.{i}.attn.norm_q_a.weight",
f"model.diffusion_model.joint_blocks.{i}.x_block.attn.ln_k.weight": f"blocks.{i}.attn.norm_k_a.weight",
f"model.diffusion_model.joint_blocks.{i}.context_block.attn.ln_q.weight": f"blocks.{i}.attn.norm_q_b.weight",
f"model.diffusion_model.joint_blocks.{i}.context_block.attn.ln_k.weight": f"blocks.{i}.attn.norm_k_b.weight",

f"model.diffusion_model.joint_blocks.{i}.x_block.attn2.ln_q.weight": f"blocks.{i}.attn2.norm_q_a.weight",
f"model.diffusion_model.joint_blocks.{i}.x_block.attn2.ln_k.weight": f"blocks.{i}.attn2.norm_k_a.weight",
f"model.diffusion_model.joint_blocks.{i}.x_block.attn2.qkv.weight": f"blocks.{i}.attn2.a_to_qkv.weight",
f"model.diffusion_model.joint_blocks.{i}.x_block.attn2.qkv.bias": f"blocks.{i}.attn2.a_to_qkv.bias",
f"model.diffusion_model.joint_blocks.{i}.x_block.attn2.proj.weight": f"blocks.{i}.attn2.a_to_out.weight",
f"model.diffusion_model.joint_blocks.{i}.x_block.attn2.proj.bias": f"blocks.{i}.attn2.a_to_out.bias",
})
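The loop above generates the same rename entries programmatically for up to 40 joint blocks, so a single table can serve checkpoints of different depths; keys for blocks that a given checkpoint does not contain are simply never looked up. A hedged sketch of how the actual block count could be recovered from a checkpoint (the helper name is illustrative only; the code below uses `infer_architecture` for this purpose, whose implementation is not shown in this excerpt):

import re

def count_joint_blocks(state_dict):
    # Illustrative helper: infer how many joint_blocks a checkpoint actually has.
    pattern = re.compile(r"model\.diffusion_model\.joint_blocks\.(\d+)\.")
    indices = [int(m.group(1)) for k in state_dict for m in [pattern.match(k)] if m]
    return max(indices) + 1 if indices else 0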
state_dict_ = {}
for name in state_dict:
if name in rename_dict:
param = state_dict[name]
if name.startswith("model.diffusion_model.joint_blocks.23.context_block.adaLN_modulation.1."):
param = torch.concat([param[1536:], param[:1536]], axis=0)
elif name.startswith("model.diffusion_model.final_layer.adaLN_modulation.1."):
param = torch.concat([param[1536:], param[:1536]], axis=0)
elif name == "model.diffusion_model.pos_embed":
param = param.reshape((1, 192, 192, 1536))
if name == "model.diffusion_model.pos_embed":
pos_embed_max_size = int(param.shape[1] ** 0.5 + 0.4)
param = param.reshape((1, pos_embed_max_size, pos_embed_max_size, param.shape[-1]))
if isinstance(rename_dict[name], str):
state_dict_[rename_dict[name]] = param
else:
name_ = rename_dict[name][0].replace(".a_to_q.", ".a_to_qkv.").replace(".b_to_q.", ".b_to_qkv.")
state_dict_[name_] = param
return state_dict_
extra_kwargs = self.infer_architecture(state_dict_)
num_layers = extra_kwargs["num_layers"]
for name in [
f"blocks.{num_layers-1}.norm1_b.linear.weight", f"blocks.{num_layers-1}.norm1_b.linear.bias", "norm_out.linear.weight", "norm_out.linear.bias",
]:
param = state_dict_[name]
dim = param.shape[0] // 2
param = torch.concat([param[dim:], param[:dim]], axis=0)
state_dict_[name] = param
return state_dict_, self.infer_architecture(state_dict_)
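Two details of the conversion above are easy to miss. First, the adaLN modulation parameters of the last block's `norm1_b` and of `norm_out` have their two halves swapped (`torch.concat([param[dim:], param[:dim]])`), which suggests the source and target implementations order their shift/scale chunks differently. Second, the flattened positional embedding is reshaped back to a square grid, with the side length recovered from the token count. A small sketch of both operations on toy tensors (all shapes below are illustrative, not the real model sizes):

import torch

# Swapping the two halves of a modulation parameter.
param = torch.arange(6.0)                                   # pretend layout: [first half | second half]
dim = param.shape[0] // 2
swapped = torch.concat([param[dim:], param[:dim]], axis=0)  # halves exchanged

# Recovering the grid side of a flattened positional embedding.
pos_embed = torch.randn(1, 36, 8)                           # (1, grid*grid, channels), grid = 6
side = int(pos_embed.shape[1] ** 0.5 + 0.4)                 # the +0.4 guards against sqrt rounding down
pos_embed = pos_embed.reshape((1, side, side, pos_embed.shape[-1]))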
@@ -2,15 +2,18 @@ import torch
from transformers import T5EncoderModel, T5Config
from .sd_text_encoder import SDTextEncoder
from .sdxl_text_encoder import SDXLTextEncoder2, SDXLTextEncoder2StateDictConverter


class SD3TextEncoder1(SDTextEncoder):
def __init__(self, vocab_size=49408):
super().__init__(vocab_size=vocab_size)

def forward(self, input_ids, clip_skip=2):
embeds = self.token_embedding(input_ids) + self.position_embeds
def forward(self, input_ids, clip_skip=2, extra_mask=None):
embeds = self.token_embedding(input_ids)
embeds = embeds + self.position_embeds.to(dtype=embeds.dtype, device=input_ids.device)
attn_mask = self.attn_mask.to(device=embeds.device, dtype=embeds.dtype)
if extra_mask is not None:
attn_mask[:, extra_mask[0]==0] = float("-inf")
for encoder_id, encoder in enumerate(self.encoders):
embeds = encoder(embeds, attn_mask=attn_mask)
if encoder_id + clip_skip == len(self.encoders):
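The new `extra_mask` argument lets callers mask out padding positions: every column of the attention mask whose `extra_mask` entry is 0 is set to -inf before the encoder stack runs, so no token can attend to the padded slots. A minimal sketch of building such a mask (the sequence length of 77 is the usual CLIP context length and is only an assumption here):

import torch

seq_len = 77                                       # assumed CLIP context length
valid_tokens = 12                                  # tokens actually used by the prompt
extra_mask = torch.zeros(1, seq_len, dtype=torch.long)
extra_mask[0, :valid_tokens] = 1                   # 1 = keep, 0 = padding to be masked out

attn_mask = torch.zeros(seq_len, seq_len)
attn_mask[:, extra_mask[0] == 0] = float("-inf")   # same indexing as the forward pass above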
@@ -322,6 +325,11 @@ class SD3TextEncoder1StateDictConverter:
if name == "text_encoders.clip_l.transformer.text_model.embeddings.position_embedding.weight":
param = param.reshape((1, param.shape[0], param.shape[1]))
state_dict_[rename_dict[name]] = param
elif ("text_encoders.clip_l.transformer." + name) in rename_dict:
param = state_dict[name]
if name == "text_model.embeddings.position_embedding.weight":
param = param.reshape((1, param.shape[0], param.shape[1]))
state_dict_[rename_dict["text_encoders.clip_l.transformer." + name]] = param
return state_dict_
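The reshape above turns the 2-D position-embedding table of shape (num_positions, dim) into a (1, num_positions, dim) tensor so it can be broadcast-added to a batch of token embeddings, matching the `position_embeds` usage in the forward pass earlier in this diff. A toy example with assumed sizes:

import torch

param = torch.randn(77, 768)                                 # (num_positions, dim), sizes assumed
param = param.reshape((1, param.shape[0], param.shape[1]))   # (1, 77, 768)

token_embeds = torch.randn(4, 77, 768)                       # (batch, seq, dim)
embeds = token_embeds + param                                # broadcasts over the batch dimension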
@@ -860,6 +868,11 @@ class SD3TextEncoder2StateDictConverter(SDXLTextEncoder2StateDictConverter):
if name == "text_encoders.clip_g.transformer.text_model.embeddings.position_embedding.weight":
param = param.reshape((1, param.shape[0], param.shape[1]))
state_dict_[rename_dict[name]] = param
elif ("text_encoders.clip_g.transformer." + name) in rename_dict:
param = state_dict[name]
if name == "text_model.embeddings.position_embedding.weight":
param = param.reshape((1, param.shape[0], param.shape[1]))
state_dict_[rename_dict["text_encoders.clip_g.transformer." + name]] = param
return state_dict_
940
diffsynth/models/stepvideo_dit.py
Normal file
@@ -0,0 +1,940 @@
# Copyright 2025 StepFun Inc. All Rights Reserved.
|
||||
#
|
||||
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
# of this software and associated documentation files (the "Software"), to deal
|
||||
# in the Software without restriction, including without limitation the rights
|
||||
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
# copies of the Software, and to permit persons to whom the Software is
|
||||
# furnished to do so, subject to the following conditions:
|
||||
#
|
||||
# The above copyright notice and this permission notice shall be included in all
|
||||
# copies or substantial portions of the Software.
|
||||
# ==============================================================================
|
||||
from typing import Dict, Optional, Tuple
|
||||
import torch, math
|
||||
from torch import nn
|
||||
from einops import rearrange, repeat
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
class RMSNorm(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
elementwise_affine=True,
|
||||
eps: float = 1e-6,
|
||||
device=None,
|
||||
dtype=None,
|
||||
):
|
||||
"""
|
||||
Initialize the RMSNorm normalization layer.
|
||||
|
||||
Args:
|
||||
dim (int): The dimension of the input tensor.
|
||||
eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.
|
||||
|
||||
Attributes:
|
||||
eps (float): A small value added to the denominator for numerical stability.
|
||||
weight (nn.Parameter): Learnable scaling parameter.
|
||||
|
||||
"""
|
||||
factory_kwargs = {"device": device, "dtype": dtype}
|
||||
super().__init__()
|
||||
self.eps = eps
|
||||
if elementwise_affine:
|
||||
self.weight = nn.Parameter(torch.ones(dim, **factory_kwargs))
|
||||
|
||||
def _norm(self, x):
|
||||
"""
|
||||
Apply the RMSNorm normalization to the input tensor.
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): The input tensor.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: The normalized tensor.
|
||||
|
||||
"""
|
||||
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
|
||||
|
||||
def forward(self, x):
|
||||
"""
|
||||
Forward pass through the RMSNorm layer.
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): The input tensor.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: The output tensor after applying RMSNorm.
|
||||
|
||||
"""
|
||||
output = self._norm(x.float()).type_as(x)
|
||||
if hasattr(self, "weight"):
|
||||
output = output * self.weight
|
||||
return output
|
||||
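For reference, the class above implements standard RMSNorm, y = x / sqrt(mean(x^2) + eps), optionally scaled by a learned weight, with the normalization computed in float32 and cast back to the input dtype. A quick, purely illustrative equivalence check against that formula:

import torch

norm = RMSNorm(dim=8)                       # class defined above, default eps = 1e-6
x = torch.randn(2, 8)
expected = x / torch.sqrt(x.pow(2).mean(-1, keepdim=True) + 1e-6) * norm.weight
assert torch.allclose(norm(x), expected, atol=1e-5)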
|
||||
|
||||
ACTIVATION_FUNCTIONS = {
|
||||
"swish": nn.SiLU(),
|
||||
"silu": nn.SiLU(),
|
||||
"mish": nn.Mish(),
|
||||
"gelu": nn.GELU(),
|
||||
"relu": nn.ReLU(),
|
||||
}
|
||||
|
||||
|
||||
def get_activation(act_fn: str) -> nn.Module:
|
||||
"""Helper function to get activation function from string.
|
||||
|
||||
Args:
|
||||
act_fn (str): Name of activation function.
|
||||
|
||||
Returns:
|
||||
nn.Module: Activation function.
|
||||
"""
|
||||
|
||||
act_fn = act_fn.lower()
|
||||
if act_fn in ACTIVATION_FUNCTIONS:
|
||||
return ACTIVATION_FUNCTIONS[act_fn]
|
||||
else:
|
||||
raise ValueError(f"Unsupported activation function: {act_fn}")
|
||||
|
||||
|
||||
def get_timestep_embedding(
|
||||
timesteps: torch.Tensor,
|
||||
embedding_dim: int,
|
||||
flip_sin_to_cos: bool = False,
|
||||
downscale_freq_shift: float = 1,
|
||||
scale: float = 1,
|
||||
max_period: int = 10000,
|
||||
):
|
||||
"""
|
||||
This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings.
|
||||
|
||||
:param timesteps: a 1-D Tensor of N indices, one per batch element.
|
||||
These may be fractional.
|
||||
:param embedding_dim: the dimension of the output.
:param max_period: controls the minimum frequency of the embeddings.
:return: an [N x dim] Tensor of positional embeddings.
|
||||
"""
|
||||
assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"
|
||||
|
||||
half_dim = embedding_dim // 2
|
||||
exponent = -math.log(max_period) * torch.arange(
|
||||
start=0, end=half_dim, dtype=torch.float32, device=timesteps.device
|
||||
)
|
||||
exponent = exponent / (half_dim - downscale_freq_shift)
|
||||
|
||||
emb = torch.exp(exponent)
|
||||
emb = timesteps[:, None].float() * emb[None, :]
|
||||
|
||||
# scale embeddings
|
||||
emb = scale * emb
|
||||
|
||||
# concat sine and cosine embeddings
|
||||
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)
|
||||
|
||||
# flip sine and cosine embeddings
|
||||
if flip_sin_to_cos:
|
||||
emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)
|
||||
|
||||
# zero pad
|
||||
if embedding_dim % 2 == 1:
|
||||
emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
|
||||
return emb
|
||||
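The function above produces the usual DDPM-style sinusoidal embedding: the first half of the channels holds sin(t * omega_i), the second half cos(t * omega_i), with frequencies decaying geometrically down to roughly 1/max_period. A small shape sketch using the defaults:

import torch

timesteps = torch.tensor([0.0, 250.0, 999.0])            # one (possibly fractional) value per batch element
emb = get_timestep_embedding(timesteps, embedding_dim=256)
assert emb.shape == (3, 256)                              # first 128 channels are sin, last 128 are cos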
|
||||
|
||||
class Timesteps(nn.Module):
|
||||
def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float):
|
||||
super().__init__()
|
||||
self.num_channels = num_channels
|
||||
self.flip_sin_to_cos = flip_sin_to_cos
|
||||
self.downscale_freq_shift = downscale_freq_shift
|
||||
|
||||
def forward(self, timesteps):
|
||||
t_emb = get_timestep_embedding(
|
||||
timesteps,
|
||||
self.num_channels,
|
||||
flip_sin_to_cos=self.flip_sin_to_cos,
|
||||
downscale_freq_shift=self.downscale_freq_shift,
|
||||
)
|
||||
return t_emb
|
||||
|
||||
|
||||
class TimestepEmbedding(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
time_embed_dim: int,
|
||||
act_fn: str = "silu",
|
||||
out_dim: int = None,
|
||||
post_act_fn: Optional[str] = None,
|
||||
cond_proj_dim=None,
|
||||
sample_proj_bias=True
|
||||
):
|
||||
super().__init__()
|
||||
linear_cls = nn.Linear
|
||||
|
||||
self.linear_1 = linear_cls(
|
||||
in_channels,
|
||||
time_embed_dim,
|
||||
bias=sample_proj_bias,
|
||||
)
|
||||
|
||||
if cond_proj_dim is not None:
|
||||
self.cond_proj = linear_cls(
|
||||
cond_proj_dim,
|
||||
in_channels,
|
||||
bias=False,
|
||||
)
|
||||
else:
|
||||
self.cond_proj = None
|
||||
|
||||
self.act = get_activation(act_fn)
|
||||
|
||||
if out_dim is not None:
|
||||
time_embed_dim_out = out_dim
|
||||
else:
|
||||
time_embed_dim_out = time_embed_dim
|
||||
|
||||
self.linear_2 = linear_cls(
|
||||
time_embed_dim,
|
||||
time_embed_dim_out,
|
||||
bias=sample_proj_bias,
|
||||
)
|
||||
|
||||
if post_act_fn is None:
|
||||
self.post_act = None
|
||||
else:
|
||||
self.post_act = get_activation(post_act_fn)
|
||||
|
||||
def forward(self, sample, condition=None):
|
||||
if condition is not None:
|
||||
sample = sample + self.cond_proj(condition)
|
||||
sample = self.linear_1(sample)
|
||||
|
||||
if self.act is not None:
|
||||
sample = self.act(sample)
|
||||
|
||||
sample = self.linear_2(sample)
|
||||
|
||||
if self.post_act is not None:
|
||||
sample = self.post_act(sample)
|
||||
return sample
|
||||
|
||||
|
||||
class PixArtAlphaCombinedTimestepSizeEmbeddings(nn.Module):
|
||||
def __init__(self, embedding_dim, size_emb_dim, use_additional_conditions: bool = False):
|
||||
super().__init__()
|
||||
|
||||
self.outdim = size_emb_dim
|
||||
self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
|
||||
self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
|
||||
|
||||
self.use_additional_conditions = use_additional_conditions
|
||||
if self.use_additional_conditions:
|
||||
self.additional_condition_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
|
||||
self.resolution_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=size_emb_dim)
|
||||
self.nframe_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
|
||||
self.fps_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
|
||||
|
||||
def forward(self, timestep, resolution=None, nframe=None, fps=None):
|
||||
hidden_dtype = timestep.dtype
|
||||
|
||||
timesteps_proj = self.time_proj(timestep)
|
||||
timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype)) # (N, D)
|
||||
|
||||
if self.use_additional_conditions:
|
||||
batch_size = timestep.shape[0]
|
||||
resolution_emb = self.additional_condition_proj(resolution.flatten()).to(hidden_dtype)
|
||||
resolution_emb = self.resolution_embedder(resolution_emb).reshape(batch_size, -1)
|
||||
nframe_emb = self.additional_condition_proj(nframe.flatten()).to(hidden_dtype)
|
||||
nframe_emb = self.nframe_embedder(nframe_emb).reshape(batch_size, -1)
|
||||
conditioning = timesteps_emb + resolution_emb + nframe_emb
|
||||
|
||||
if fps is not None:
|
||||
fps_emb = self.additional_condition_proj(fps.flatten()).to(hidden_dtype)
|
||||
fps_emb = self.fps_embedder(fps_emb).reshape(batch_size, -1)
|
||||
conditioning = conditioning + fps_emb
|
||||
else:
|
||||
conditioning = timesteps_emb
|
||||
|
||||
return conditioning
|
||||
|
||||
|
||||
class AdaLayerNormSingle(nn.Module):
|
||||
r"""
|
||||
Norm layer adaptive layer norm single (adaLN-single).
|
||||
|
||||
As proposed in PixArt-Alpha (see: https://arxiv.org/abs/2310.00426; Section 2.3).
|
||||
|
||||
Parameters:
|
||||
embedding_dim (`int`): The size of each embedding vector.
|
||||
use_additional_conditions (`bool`): To use additional conditions for normalization or not.
|
||||
"""
|
||||
def __init__(self, embedding_dim: int, use_additional_conditions: bool = False, time_step_rescale=1000):
|
||||
super().__init__()
|
||||
|
||||
self.emb = PixArtAlphaCombinedTimestepSizeEmbeddings(
|
||||
embedding_dim, size_emb_dim=embedding_dim // 2, use_additional_conditions=use_additional_conditions
|
||||
)
|
||||
|
||||
self.silu = nn.SiLU()
|
||||
self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=True)
|
||||
|
||||
self.time_step_rescale = time_step_rescale  # timesteps are usually in [0, 1]; rescale to [0, 1000] for numerical stability
|
||||
|
||||
def forward(
|
||||
self,
|
||||
timestep: torch.Tensor,
|
||||
added_cond_kwargs: Dict[str, torch.Tensor] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
embedded_timestep = self.emb(timestep*self.time_step_rescale, **added_cond_kwargs)
|
||||
|
||||
out = self.linear(self.silu(embedded_timestep))
|
||||
|
||||
return out, embedded_timestep
|
||||
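AdaLayerNormSingle above returns one 6*dim modulation vector per sample; in the PixArt-style blocks that consume it (not shown in this excerpt) that vector is typically chunked into shift/scale/gate triples for the attention and MLP branches. A hedged sketch of that downstream chunking, with an assumed embedding size:

import torch

dim = 1536                                                  # assumed embedding size, illustration only
adaln = AdaLayerNormSingle(dim, use_additional_conditions=False)
t = torch.rand(2)                                           # timesteps in [0, 1]
out, embedded_t = adaln(t, added_cond_kwargs={})
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = out.chunk(6, dim=-1)
assert shift_msa.shape == (2, dim)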
|
||||
|
||||
class PixArtAlphaTextProjection(nn.Module):
|
||||
"""
|
||||
Projects caption embeddings. Also handles dropout for classifier-free guidance.
|
||||
|
||||
Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/nets/PixArt_blocks.py
|
||||
"""
|
||||
|
||||
def __init__(self, in_features, hidden_size):
|
||||
super().__init__()
|
||||
self.linear_1 = nn.Linear(
|
||||
in_features,
|
||||
hidden_size,
|
||||
bias=True,
|
||||
)
|
||||
self.act_1 = nn.GELU(approximate="tanh")
|
||||
self.linear_2 = nn.Linear(
|
||||
hidden_size,
|
||||
hidden_size,
|
||||
bias=True,
|
||||
)
|
||||
|
||||
def forward(self, caption):
|
||||
hidden_states = self.linear_1(caption)
|
||||
hidden_states = self.act_1(hidden_states)
|
||||
hidden_states = self.linear_2(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def attn_processor(self, attn_type):
|
||||
if attn_type == 'torch':
|
||||
return self.torch_attn_func
|
||||
elif attn_type == 'parallel':
|
||||
return self.parallel_attn_func
|
||||
else:
|
||||
raise ValueError(f'Unsupported attention type: {attn_type}')
|
||||
|
||||
def torch_attn_func(
|
||||
self,
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
attn_mask=None,
|
||||
causal=False,
|
||||
drop_rate=0.0,
|
||||
**kwargs
|
||||
):
|
||||
|
||||
if attn_mask is not None and attn_mask.dtype != torch.bool:
|
||||
attn_mask = attn_mask.to(q.dtype)
|
||||
|
||||
if attn_mask is not None and attn_mask.ndim == 3: ## no head
|
||||
n_heads = q.shape[2]
|
||||
attn_mask = attn_mask.unsqueeze(1).repeat(1, n_heads, 1, 1)
|
||||
|
||||
q, k, v = map(lambda x: rearrange(x, 'b s h d -> b h s d'), (q, k, v))
|
||||
if attn_mask is not None:
|
||||
attn_mask = attn_mask.to(q.device)
|
||||
x = torch.nn.functional.scaled_dot_product_attention(
|
||||
q, k, v, attn_mask=attn_mask, dropout_p=drop_rate, is_causal=causal
|
||||
)
|
||||
x = rearrange(x, 'b h s d -> b s h d')
|
||||
return x
|
||||
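torch_attn_func above expects q, k, v laid out as (batch, seq, heads, head_dim); it transposes them to the (batch, heads, seq, head_dim) layout that torch.nn.functional.scaled_dot_product_attention requires and transposes the result back. A minimal shape check with assumed sizes:

import torch

attn = Attention()                               # class defined above
q = k = v = torch.randn(2, 16, 8, 64)            # (batch, seq, heads, head_dim)
out = attn.torch_attn_func(q, k, v)
assert out.shape == (2, 16, 8, 64)               # same layout in and out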
|
||||
|
||||
class RoPE1D:
|
||||
def __init__(self, freq=1e4, F0=1.0, scaling_factor=1.0):
|
||||
self.base = freq
|
||||
self.F0 = F0
|
||||
self.scaling_factor = scaling_factor
|
||||
self.cache = {}
|
||||
|
||||
def get_cos_sin(self, D, seq_len, device, dtype):
|
||||
if (D, seq_len, device, dtype) not in self.cache:
|
||||
inv_freq = 1.0 / (self.base ** (torch.arange(0, D, 2).float().to(device) / D))
|
||||
t = torch.arange(seq_len, device=device, dtype=inv_freq.dtype)
|
||||
freqs = torch.einsum("i,j->ij", t, inv_freq).to(dtype)
|
||||
freqs = torch.cat((freqs, freqs), dim=-1)
|
||||
cos = freqs.cos() # (Seq, Dim)
|
||||
sin = freqs.sin()
|
||||
self.cache[D, seq_len, device, dtype] = (cos, sin)
|
||||
return self.cache[D, seq_len, device, dtype]
|
||||
|
||||
@staticmethod
|
||||
def rotate_half(x):
|
||||
x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2:]
|
||||
return torch.cat((-x2, x1), dim=-1)
|
||||
|
||||
def apply_rope1d(self, tokens, pos1d, cos, sin):
|
||||
assert pos1d.ndim == 2
|
||||
cos = torch.nn.functional.embedding(pos1d, cos)[:, :, None, :]
|
||||
sin = torch.nn.functional.embedding(pos1d, sin)[:, :, None, :]
|
||||
return (tokens * cos) + (self.rotate_half(tokens) * sin)
|
||||
|
||||
def __call__(self, tokens, positions):
|
||||
"""
|
||||
input:
|
||||
* tokens: batch_size x ntokens x nheads x dim
|
||||
* positions: batch_size x ntokens (t position of each token)
|
||||
output:
|
||||
* tokens after applying RoPE1D (batch_size x ntokens x nheads x dim)
|
||||
"""
|
||||
D = tokens.size(3)
|
||||
assert positions.ndim == 2 # Batch, Seq
|
||||
cos, sin = self.get_cos_sin(D, int(positions.max()) + 1, tokens.device, tokens.dtype)
|
||||
tokens = self.apply_rope1d(tokens, positions, cos, sin)
|
||||
return tokens
|
||||
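RoPE1D above caches per-(dim, seq_len) cos/sin tables and applies the standard rotary formula token * cos + rotate_half(token) * sin, i.e. each channel pair is rotated by an angle position * base^(-2i/D). A small usage sketch (tensor sizes are illustrative only):

import torch

rope = RoPE1D()                                            # class defined above
tokens = torch.randn(2, 10, 4, 64)                         # (batch, ntokens, nheads, dim)
positions = torch.arange(10).unsqueeze(0).expand(2, -1)    # (batch, ntokens) integer positions
rotated = rope(tokens, positions)
assert rotated.shape == tokens.shape                       # rotation preserves the shape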
|
||||
|
||||
class RoPE3D(RoPE1D):
|
||||
def __init__(self, freq=1e4, F0=1.0, scaling_factor=1.0):
|
||||
super(RoPE3D, self).__init__(freq, F0, scaling_factor)
|
||||
self.position_cache = {}
|
||||
|
||||
def get_mesh_3d(self, rope_positions, bsz):
|
||||
f, h, w = rope_positions
|
||||
|
||||
if f"{f}-{h}-{w}" not in self.position_cache:
|
||||
x = torch.arange(f, device='cpu')
|
||||
y = torch.arange(h, device='cpu')
|
||||
z = torch.arange(w, device='cpu')
|
||||
self.position_cache[f"{f}-{h}-{w}"] = torch.cartesian_prod(x, y, z).view(1, f*h*w, 3).expand(bsz, -1, 3)
|
||||
return self.position_cache[f"{f}-{h}-{w}"]
|
||||
|
||||
def __call__(self, tokens, rope_positions, ch_split, parallel=False):
|
||||
"""
|
||||
input:
|
||||
* tokens: batch_size x ntokens x nheads x dim
|
||||
* rope_positions: list of (f, h, w)
|
||||
output:
|
||||
* tokens after applying RoPE3D (batch_size x ntokens x nheads x dim)
|
||||
"""
|
||||
assert sum(ch_split) == tokens.size(-1)
|
||||
|
||||
mesh_grid = self.get_mesh_3d(rope_positions, bsz=tokens.shape[0])
|
||||
out = []
|
||||
for i, (D, x) in enumerate(zip(ch_split, torch.split(tokens, ch_split, dim=-1))):
|
||||
cos, sin = self.get_cos_sin(D, int(mesh_grid.max()) + 1, tokens.device, tokens.dtype)
|
||||
|
||||
if parallel:
|
||||
pass
|
||||
else:
|
||||
mesh = mesh_grid[:, :, i].clone()
|
||||
x = self.apply_rope1d(x, mesh.to(tokens.device), cos, sin)
|
||||
out.append(x)
|
||||
|
||||
tokens = torch.cat(out, dim=-1)
|
||||
return tokens
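# Hedged sketch of the 3D RoPE above: the head dimension is split into per-axis
# chunks (e.g. the [64, 32, 32] split used by SelfAttention below), and each chunk
# is rotated with the frame / height / width coordinate from the cartesian mesh.
# Shapes are illustrative only.
def _demo_rope3d():
    import torch
    rope = RoPE3D()
    f, h, w = 2, 4, 4
    tokens = torch.randn(1, f * h * w, 8, 128)  # (b, f*h*w, nheads, dim)
    out = rope(tokens, rope_positions=[f, h, w], ch_split=[64, 32, 32])
    assert out.shape == tokens.shape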
|
||||
|
||||
|
||||
class SelfAttention(Attention):
|
||||
def __init__(self, hidden_dim, head_dim, bias=False, with_rope=True, with_qk_norm=True, attn_type='torch'):
|
||||
super().__init__()
|
||||
self.head_dim = head_dim
|
||||
self.n_heads = hidden_dim // head_dim
|
||||
|
||||
self.wqkv = nn.Linear(hidden_dim, hidden_dim*3, bias=bias)
|
||||
self.wo = nn.Linear(hidden_dim, hidden_dim, bias=bias)
|
||||
|
||||
self.with_rope = with_rope
|
||||
self.with_qk_norm = with_qk_norm
|
||||
if self.with_qk_norm:
|
||||
self.q_norm = RMSNorm(head_dim, elementwise_affine=True)
|
||||
self.k_norm = RMSNorm(head_dim, elementwise_affine=True)
|
||||
|
||||
if self.with_rope:
|
||||
self.rope_3d = RoPE3D(freq=1e4, F0=1.0, scaling_factor=1.0)
|
||||
self.rope_ch_split = [64, 32, 32]
|
||||
|
||||
self.core_attention = self.attn_processor(attn_type=attn_type)
|
||||
self.parallel = attn_type=='parallel'
|
||||
|
||||
def apply_rope3d(self, x, fhw_positions, rope_ch_split, parallel=True):
|
||||
x = self.rope_3d(x, fhw_positions, rope_ch_split, parallel)
|
||||
return x
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x,
|
||||
cu_seqlens=None,
|
||||
max_seqlen=None,
|
||||
rope_positions=None,
|
||||
attn_mask=None
|
||||
):
|
||||
xqkv = self.wqkv(x)
|
||||
xqkv = xqkv.view(*x.shape[:-1], self.n_heads, 3*self.head_dim)
|
||||
|
||||
xq, xk, xv = torch.split(xqkv, [self.head_dim]*3, dim=-1) ## seq_len, n, dim
|
||||
|
||||
if self.with_qk_norm:
|
||||
xq = self.q_norm(xq)
|
||||
xk = self.k_norm(xk)
|
||||
|
||||
if self.with_rope:
|
||||
xq = self.apply_rope3d(xq, rope_positions, self.rope_ch_split, parallel=self.parallel)
|
||||
xk = self.apply_rope3d(xk, rope_positions, self.rope_ch_split, parallel=self.parallel)
|
||||
|
||||
output = self.core_attention(
|
||||
xq,
|
||||
xk,
|
||||
xv,
|
||||
cu_seqlens=cu_seqlens,
|
||||
max_seqlen=max_seqlen,
|
||||
attn_mask=attn_mask
|
||||
)
|
||||
output = rearrange(output, 'b s h d -> b s (h d)')
|
||||
output = self.wo(output)
|
||||
|
||||
return output
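# Illustrative usage sketch of SelfAttention with the 'torch' processor; the
# hidden_dim / head_dim values are made up (head_dim must equal the sum of
# rope_ch_split, i.e. 64 + 32 + 32 = 128).
def _demo_self_attention():
    import torch
    attn = SelfAttention(hidden_dim=512, head_dim=128, attn_type='torch')
    x = torch.randn(1, 2 * 4 * 4, 512)          # (b, f*h*w, hidden_dim)
    out = attn(x, rope_positions=[2, 4, 4])     # rope over (f, h, w) = (2, 4, 4)
    assert out.shape == x.shape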
|
||||
|
||||
|
||||
class CrossAttention(Attention):
|
||||
def __init__(self, hidden_dim, head_dim, bias=False, with_qk_norm=True, attn_type='torch'):
|
||||
super().__init__()
|
||||
self.head_dim = head_dim
|
||||
self.n_heads = hidden_dim // head_dim
|
||||
|
||||
self.wq = nn.Linear(hidden_dim, hidden_dim, bias=bias)
|
||||
self.wkv = nn.Linear(hidden_dim, hidden_dim*2, bias=bias)
|
||||
self.wo = nn.Linear(hidden_dim, hidden_dim, bias=bias)
|
||||
|
||||
self.with_qk_norm = with_qk_norm
|
||||
if self.with_qk_norm:
|
||||
self.q_norm = RMSNorm(head_dim, elementwise_affine=True)
|
||||
self.k_norm = RMSNorm(head_dim, elementwise_affine=True)
|
||||
|
||||
self.core_attention = self.attn_processor(attn_type=attn_type)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
attn_mask=None
|
||||
):
|
||||
xq = self.wq(x)
|
||||
xq = xq.view(*xq.shape[:-1], self.n_heads, self.head_dim)
|
||||
|
||||
xkv = self.wkv(encoder_hidden_states)
|
||||
xkv = xkv.view(*xkv.shape[:-1], self.n_heads, 2*self.head_dim)
|
||||
|
||||
xk, xv = torch.split(xkv, [self.head_dim]*2, dim=-1) ## seq_len, n, dim
|
||||
|
||||
if self.with_qk_norm:
|
||||
xq = self.q_norm(xq)
|
||||
xk = self.k_norm(xk)
|
||||
|
||||
output = self.core_attention(
|
||||
xq,
|
||||
xk,
|
||||
xv,
|
||||
attn_mask=attn_mask
|
||||
)
|
||||
|
||||
output = rearrange(output, 'b s h d -> b s (h d)')
|
||||
output = self.wo(output)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
class GELU(nn.Module):
|
||||
r"""
|
||||
GELU activation function, with optional tanh approximation enabled via `approximate="tanh"`.
|
||||
|
||||
Parameters:
|
||||
dim_in (`int`): The number of channels in the input.
|
||||
dim_out (`int`): The number of channels in the output.
|
||||
approximate (`str`, *optional*, defaults to `"none"`): If `"tanh"`, use tanh approximation.
|
||||
bias (`bool`, defaults to True): Whether to use a bias in the linear layer.
|
||||
"""
|
||||
|
||||
def __init__(self, dim_in: int, dim_out: int, approximate: str = "none", bias: bool = True):
|
||||
super().__init__()
|
||||
self.proj = nn.Linear(dim_in, dim_out, bias=bias)
|
||||
self.approximate = approximate
|
||||
|
||||
def gelu(self, gate: torch.Tensor) -> torch.Tensor:
|
||||
return torch.nn.functional.gelu(gate, approximate=self.approximate)
|
||||
|
||||
def forward(self, hidden_states):
|
||||
hidden_states = self.proj(hidden_states)
|
||||
hidden_states = self.gelu(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
inner_dim: Optional[int] = None,
|
||||
dim_out: Optional[int] = None,
|
||||
mult: int = 4,
|
||||
bias: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
inner_dim = dim*mult if inner_dim is None else inner_dim
|
||||
dim_out = dim if dim_out is None else dim_out
|
||||
self.net = nn.ModuleList([
|
||||
GELU(dim, inner_dim, approximate="tanh", bias=bias),
|
||||
nn.Identity(),
|
||||
nn.Linear(inner_dim, dim_out, bias=bias)
|
||||
])
|
||||
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor, *args, **kwargs) -> torch.Tensor:
|
||||
for module in self.net:
|
||||
hidden_states = module(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
def modulate(x, scale, shift):
|
||||
x = x * (1 + scale) + shift
|
||||
return x
|
||||
|
||||
|
||||
def gate(x, gate):
|
||||
x = gate * x
|
||||
return x
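# Small sketch of the adaLN modulate/gate pattern used by the transformer block
# below: normalized activations are scaled and shifted per timestep, and the gated
# branch output is added back to the residual stream. Toy tensors only.
def _demo_modulate_gate():
    import torch
    x = torch.randn(1, 32, 64)
    shift, scale, g = torch.zeros(1, 1, 64), torch.zeros(1, 1, 64), torch.ones(1, 1, 64)
    h = modulate(torch.nn.functional.layer_norm(x, (64,)), scale, shift)
    out = gate(h, g) + x                        # residual update
    assert out.shape == x.shape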
|
||||
|
||||
|
||||
class StepVideoTransformerBlock(nn.Module):
|
||||
r"""
|
||||
A basic Transformer block.
|
||||
|
||||
Parameters:
|
||||
dim (`int`): The number of channels in the input and output.
|
||||
num_attention_heads (`int`): The number of heads to use for multi-head attention.
|
||||
attention_head_dim (`int`): The number of channels in each head.
|
||||
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
|
||||
cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
|
||||
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
|
||||
num_embeds_ada_norm (:
|
||||
obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`.
|
||||
attention_bias (:
|
||||
obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
|
||||
only_cross_attention (`bool`, *optional*):
|
||||
Whether to use only cross-attention layers. In this case two cross attention layers are used.
|
||||
double_self_attention (`bool`, *optional*):
|
||||
Whether to use two self-attention layers. In this case no cross attention layers are used.
|
||||
upcast_attention (`bool`, *optional*):
|
||||
Whether to upcast the attention computation to float32. This is useful for mixed precision training.
|
||||
norm_elementwise_affine (`bool`, *optional*, defaults to `True`):
|
||||
Whether to use learnable elementwise affine parameters for normalization.
|
||||
norm_type (`str`, *optional*, defaults to `"layer_norm"`):
|
||||
The normalization layer to use. Can be `"layer_norm"`, `"ada_norm"` or `"ada_norm_zero"`.
|
||||
final_dropout (`bool` *optional*, defaults to False):
|
||||
Whether to apply a final dropout after the last feed-forward layer.
|
||||
attention_type (`str`, *optional*, defaults to `"default"`):
|
||||
The type of attention to use. Can be `"default"` or `"gated"` or `"gated-text-image"`.
|
||||
positional_embeddings (`str`, *optional*, defaults to `None`):
|
||||
The type of positional embeddings to apply.
|
||||
num_positional_embeddings (`int`, *optional*, defaults to `None`):
|
||||
The maximum number of positional embeddings to apply.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
attention_head_dim: int,
|
||||
norm_eps: float = 1e-5,
|
||||
ff_inner_dim: Optional[int] = None,
|
||||
ff_bias: bool = False,
|
||||
attention_type: str = 'parallel'
|
||||
):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.norm1 = nn.LayerNorm(dim, eps=norm_eps)
|
||||
self.attn1 = SelfAttention(dim, attention_head_dim, bias=False, with_rope=True, with_qk_norm=True, attn_type=attention_type)
|
||||
|
||||
self.norm2 = nn.LayerNorm(dim, eps=norm_eps)
|
||||
self.attn2 = CrossAttention(dim, attention_head_dim, bias=False, with_qk_norm=True, attn_type='torch')
|
||||
|
||||
self.ff = FeedForward(dim=dim, inner_dim=ff_inner_dim, dim_out=dim, bias=ff_bias)
|
||||
|
||||
self.scale_shift_table = nn.Parameter(torch.randn(6, dim) /dim**0.5)
|
||||
|
||||
@torch.no_grad()
|
||||
def forward(
|
||||
self,
|
||||
q: torch.Tensor,
|
||||
kv: Optional[torch.Tensor] = None,
|
||||
timestep: Optional[torch.LongTensor] = None,
|
||||
attn_mask = None,
|
||||
rope_positions: list = None,
|
||||
) -> torch.Tensor:
|
||||
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
|
||||
torch.clone(chunk) for chunk in (self.scale_shift_table[None].to(dtype=q.dtype, device=q.device) + timestep.reshape(-1, 6, self.dim)).chunk(6, dim=1)
|
||||
)
|
||||
|
||||
scale_shift_q = modulate(self.norm1(q), scale_msa, shift_msa)
|
||||
|
||||
attn_q = self.attn1(
|
||||
scale_shift_q,
|
||||
rope_positions=rope_positions
|
||||
)
|
||||
|
||||
q = gate(attn_q, gate_msa) + q
|
||||
|
||||
attn_q = self.attn2(
|
||||
q,
|
||||
kv,
|
||||
attn_mask
|
||||
)
|
||||
|
||||
q = attn_q + q
|
||||
|
||||
scale_shift_q = modulate(self.norm2(q), scale_mlp, shift_mlp)
|
||||
|
||||
ff_output = self.ff(scale_shift_q)
|
||||
|
||||
q = gate(ff_output, gate_mlp) + q
|
||||
|
||||
return q
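# Illustrative (non-authoritative) sketch of how the block is driven: `timestep`
# here is already the projected (batch, 6 * dim) modulation vector produced by the
# AdaLayerNormSingle in StepVideoModel, and is split into shift/scale/gate triplets
# for the attention and MLP branches. All sizes are made up for the example.
def _demo_transformer_block():
    import torch
    dim = 128
    block = StepVideoTransformerBlock(dim=dim, attention_head_dim=dim, attention_type='torch')
    q = torch.randn(1, 1 * 2 * 2, dim)          # video tokens, (f, h, w) = (1, 2, 2)
    kv = torch.randn(1, 7, dim)                 # text tokens
    t = torch.randn(1, 6 * dim)                 # projected timestep modulation
    out = block(q, kv, timestep=t, rope_positions=[1, 2, 2])
    assert out.shape == q.shape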
|
||||
|
||||
|
||||
class PatchEmbed(nn.Module):
|
||||
"""2D Image to Patch Embedding"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
patch_size=64,
|
||||
in_channels=3,
|
||||
embed_dim=768,
|
||||
layer_norm=False,
|
||||
flatten=True,
|
||||
bias=True,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.flatten = flatten
|
||||
self.layer_norm = layer_norm
# Define the norm used in forward(); without it, layer_norm=True would raise an
# AttributeError. The eps/affine settings here are an assumption.
self.norm = nn.LayerNorm(embed_dim, elementwise_affine=False, eps=1e-6) if layer_norm else None
|
||||
|
||||
self.proj = nn.Conv2d(
|
||||
in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias
|
||||
)
|
||||
|
||||
def forward(self, latent):
|
||||
latent = self.proj(latent).to(latent.dtype)
|
||||
if self.flatten:
|
||||
latent = latent.flatten(2).transpose(1, 2) # BCHW -> BNC
|
||||
if self.layer_norm:
|
||||
latent = self.norm(latent)
|
||||
|
||||
return latent
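# Minimal sketch (assumed sizes) of PatchEmbed: a Conv2d whose kernel and stride
# both equal patch_size turns each non-overlapping patch into one token, so an
# H x W latent becomes (H/patch) * (W/patch) tokens of width embed_dim.
def _demo_patch_embed():
    import torch
    pe = PatchEmbed(patch_size=2, in_channels=64, embed_dim=128)
    latent = torch.randn(1, 64, 16, 16)                      # (b, c, h, w)
    tokens = pe(latent)
    assert tokens.shape == (1, (16 // 2) * (16 // 2), 128)   # (b, n_patches, embed_dim)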
|
||||
|
||||
|
||||
class StepVideoModel(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
num_attention_heads: int = 48,
|
||||
attention_head_dim: int = 128,
|
||||
in_channels: int = 64,
|
||||
out_channels: Optional[int] = 64,
|
||||
num_layers: int = 48,
|
||||
dropout: float = 0.0,
|
||||
patch_size: int = 1,
|
||||
norm_type: str = "ada_norm_single",
|
||||
norm_elementwise_affine: bool = False,
|
||||
norm_eps: float = 1e-6,
|
||||
use_additional_conditions: Optional[bool] = False,
|
||||
caption_channels: Optional[int]|list|tuple = [6144, 1024],
|
||||
attention_type: Optional[str] = "torch",
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
# Set some common variables used across the board.
|
||||
self.inner_dim = num_attention_heads * attention_head_dim
|
||||
self.out_channels = in_channels if out_channels is None else out_channels
|
||||
|
||||
self.use_additional_conditions = use_additional_conditions
|
||||
|
||||
self.pos_embed = PatchEmbed(
|
||||
patch_size=patch_size,
|
||||
in_channels=in_channels,
|
||||
embed_dim=self.inner_dim,
|
||||
)
|
||||
|
||||
self.transformer_blocks = nn.ModuleList(
|
||||
[
|
||||
StepVideoTransformerBlock(
|
||||
dim=self.inner_dim,
|
||||
attention_head_dim=attention_head_dim,
|
||||
attention_type=attention_type
|
||||
)
|
||||
for _ in range(num_layers)
|
||||
]
|
||||
)
|
||||
|
||||
# 3. Output blocks.
|
||||
self.norm_out = nn.LayerNorm(self.inner_dim, eps=norm_eps, elementwise_affine=norm_elementwise_affine)
|
||||
self.scale_shift_table = nn.Parameter(torch.randn(2, self.inner_dim) / self.inner_dim**0.5)
|
||||
self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels)
|
||||
self.patch_size = patch_size
|
||||
|
||||
self.adaln_single = AdaLayerNormSingle(
|
||||
self.inner_dim, use_additional_conditions=self.use_additional_conditions
|
||||
)
|
||||
|
||||
if isinstance(caption_channels, int):
|
||||
caption_channel = caption_channels
|
||||
else:
|
||||
caption_channel, clip_channel = caption_channels
|
||||
self.clip_projection = nn.Linear(clip_channel, self.inner_dim)
|
||||
|
||||
self.caption_norm = nn.LayerNorm(caption_channel, eps=norm_eps, elementwise_affine=norm_elementwise_affine)
|
||||
|
||||
self.caption_projection = PixArtAlphaTextProjection(
|
||||
in_features=caption_channel, hidden_size=self.inner_dim
|
||||
)
|
||||
|
||||
self.parallel = attention_type=='parallel'
|
||||
|
||||
def patchfy(self, hidden_states):
|
||||
hidden_states = rearrange(hidden_states, 'b f c h w -> (b f) c h w')
|
||||
hidden_states = self.pos_embed(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
def prepare_attn_mask(self, encoder_attention_mask, encoder_hidden_states, q_seqlen):
|
||||
kv_seqlens = encoder_attention_mask.sum(dim=1).int()
|
||||
mask = torch.zeros([len(kv_seqlens), q_seqlen, max(kv_seqlens)], dtype=torch.bool, device=encoder_attention_mask.device)
|
||||
encoder_hidden_states = encoder_hidden_states[:,: max(kv_seqlens)]
|
||||
for i, kv_len in enumerate(kv_seqlens):
|
||||
mask[i, :, :kv_len] = 1
|
||||
return encoder_hidden_states, mask
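# Standalone re-statement (illustrative numbers) of the logic in prepare_attn_mask:
# text sequences with different valid lengths are truncated to the longest valid
# length and expanded into a boolean (batch, q_seqlen, kv_seqlen) cross-attention mask.
def _demo_prepare_attn_mask():
    import torch
    enc_mask = torch.tensor([[1, 1, 1, 0, 0],
                             [1, 1, 0, 0, 0]])               # valid lengths 3 and 2
    kv_seqlens = enc_mask.sum(dim=1).int()
    mask = torch.zeros([2, 4, int(kv_seqlens.max())], dtype=torch.bool)
    for i, kv_len in enumerate(kv_seqlens):
        mask[i, :, :kv_len] = 1
    assert mask.shape == (2, 4, 3) and mask[1, 0].tolist() == [True, True, False]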
|
||||
|
||||
|
||||
def block_forward(
|
||||
self,
|
||||
hidden_states,
|
||||
encoder_hidden_states=None,
|
||||
timestep=None,
|
||||
rope_positions=None,
|
||||
attn_mask=None,
|
||||
parallel=True
|
||||
):
|
||||
for block in tqdm(self.transformer_blocks, desc="Transformer blocks"):
|
||||
hidden_states = block(
|
||||
hidden_states,
|
||||
encoder_hidden_states,
|
||||
timestep=timestep,
|
||||
attn_mask=attn_mask,
|
||||
rope_positions=rope_positions
|
||||
)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||
encoder_hidden_states_2: Optional[torch.Tensor] = None,
|
||||
timestep: Optional[torch.LongTensor] = None,
|
||||
added_cond_kwargs: Dict[str, torch.Tensor] = None,
|
||||
encoder_attention_mask: Optional[torch.Tensor] = None,
|
||||
fps: torch.Tensor=None,
|
||||
return_dict: bool = False,
|
||||
):
|
||||
assert hidden_states.ndim == 5, "hidden_states should have shape (bsz, f, ch, h, w)"
|
||||
|
||||
bsz, frame, _, height, width = hidden_states.shape
|
||||
height, width = height // self.patch_size, width // self.patch_size
|
||||
|
||||
hidden_states = self.patchfy(hidden_states)
|
||||
len_frame = hidden_states.shape[1]
|
||||
|
||||
if self.use_additional_conditions:
|
||||
added_cond_kwargs = {
|
||||
"resolution": torch.tensor([(height, width)]*bsz, device=hidden_states.device, dtype=hidden_states.dtype),
|
||||
"nframe": torch.tensor([frame]*bsz, device=hidden_states.device, dtype=hidden_states.dtype),
|
||||
"fps": fps
|
||||
}
|
||||
else:
|
||||
added_cond_kwargs = {}
|
||||
|
||||
timestep, embedded_timestep = self.adaln_single(
|
||||
timestep, added_cond_kwargs=added_cond_kwargs
|
||||
)
|
||||
|
||||
encoder_hidden_states = self.caption_projection(self.caption_norm(encoder_hidden_states))
|
||||
|
||||
if encoder_hidden_states_2 is not None and hasattr(self, 'clip_projection'):
|
||||
clip_embedding = self.clip_projection(encoder_hidden_states_2)
|
||||
encoder_hidden_states = torch.cat([clip_embedding, encoder_hidden_states], dim=1)
|
||||
|
||||
hidden_states = rearrange(hidden_states, '(b f) l d-> b (f l) d', b=bsz, f=frame, l=len_frame).contiguous()
|
||||
encoder_hidden_states, attn_mask = self.prepare_attn_mask(encoder_attention_mask, encoder_hidden_states, q_seqlen=frame*len_frame)
|
||||
|
||||
hidden_states = self.block_forward(
|
||||
hidden_states,
|
||||
encoder_hidden_states,
|
||||
timestep=timestep,
|
||||
rope_positions=[frame, height, width],
|
||||
attn_mask=attn_mask,
|
||||
parallel=self.parallel
|
||||
)
|
||||
|
||||
hidden_states = rearrange(hidden_states, 'b (f l) d -> (b f) l d', b=bsz, f=frame, l=len_frame)
|
||||
|
||||
embedded_timestep = repeat(embedded_timestep, 'b d -> (b f) d', f=frame).contiguous()
|
||||
|
||||
shift, scale = (self.scale_shift_table[None].to(dtype=embedded_timestep.dtype, device=embedded_timestep.device) + embedded_timestep[:, None]).chunk(2, dim=1)
|
||||
hidden_states = self.norm_out(hidden_states)
|
||||
# Modulation
|
||||
hidden_states = hidden_states * (1 + scale) + shift
|
||||
hidden_states = self.proj_out(hidden_states)
|
||||
|
||||
# unpatchify
|
||||
hidden_states = hidden_states.reshape(
|
||||
shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels)
|
||||
)
|
||||
|
||||
hidden_states = rearrange(hidden_states, 'n h w p q c -> n c h p w q')
|
||||
output = hidden_states.reshape(
|
||||
shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size)
|
||||
)
|
||||
|
||||
output = rearrange(output, '(b f) c h w -> b f c h w', f=frame)
|
||||
|
||||
if return_dict:
|
||||
return {'x': output}
|
||||
return output
|
||||
|
||||
@staticmethod
|
||||
def state_dict_converter():
|
||||
return StepVideoDiTStateDictConverter()
|
||||
|
||||
|
||||
class StepVideoDiTStateDictConverter:
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def from_diffusers(self, state_dict):
|
||||
return state_dict
|
||||
|
||||
def from_civitai(self, state_dict):
|
||||
return state_dict
|
||||
|
||||
|
||||
|
||||
diffsynth/models/stepvideo_text_encoder.py (new file, 553 lines)
@@ -0,0 +1,553 @@
|
||||
# Copyright 2025 StepFun Inc. All Rights Reserved.
|
||||
#
|
||||
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
# of this software and associated documentation files (the "Software"), to deal
|
||||
# in the Software without restriction, including without limitation the rights
|
||||
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
# copies of the Software, and to permit persons to whom the Software is
|
||||
# furnished to do so, subject to the following conditions:
|
||||
#
|
||||
# The above copyright notice and this permission notice shall be included in all
|
||||
# copies or substantial portions of the Software.
|
||||
# ==============================================================================
|
||||
import os
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from .stepvideo_dit import RMSNorm
|
||||
from safetensors.torch import load_file
|
||||
from transformers.modeling_utils import PretrainedConfig, PreTrainedModel
|
||||
from einops import rearrange
|
||||
import json
|
||||
from typing import List
|
||||
from functools import wraps
|
||||
import warnings
|
||||
|
||||
|
||||
|
||||
class EmptyInitOnDevice(torch.overrides.TorchFunctionMode):
|
||||
def __init__(self, device=None):
|
||||
self.device = device
|
||||
|
||||
def __torch_function__(self, func, types, args=(), kwargs=None):
|
||||
kwargs = kwargs or {}
|
||||
if getattr(func, '__module__', None) == 'torch.nn.init':
|
||||
if 'tensor' in kwargs:
|
||||
return kwargs['tensor']
|
||||
else:
|
||||
return args[0]
|
||||
if self.device is not None and func in torch.utils._device._device_constructors() and kwargs.get('device') is None:
|
||||
kwargs['device'] = self.device
|
||||
return func(*args, **kwargs)
|
||||
|
||||
|
||||
def with_empty_init(func):
|
||||
@wraps(func)
|
||||
def wrapper(*args, **kwargs):
|
||||
with EmptyInitOnDevice('cpu'):
|
||||
return func(*args, **kwargs)
|
||||
return wrapper
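# Hedged usage sketch of EmptyInitOnDevice: inside the mode, torch.nn.init calls
# become no-ops and tensor constructors default to the chosen device, so modules
# can be built quickly and filled from a checkpoint afterwards.
def _demo_empty_init():
    import torch
    with EmptyInitOnDevice('cpu'):
        layer = torch.nn.Linear(8, 8)    # weight allocated on CPU, random init skipped
    assert layer.weight.device.type == 'cpu'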
|
||||
|
||||
|
||||
|
||||
class LLaMaEmbedding(nn.Module):
|
||||
"""Language model embeddings.
|
||||
|
||||
Arguments:
|
||||
hidden_size: hidden size
|
||||
vocab_size: vocabulary size
|
||||
max_sequence_length: maximum size of sequence. This
|
||||
is used for positional embedding
|
||||
embedding_dropout_prob: dropout probability for embeddings
|
||||
init_method: weight initialization method
|
||||
num_tokentypes: size of the token-type embeddings. 0 value
|
||||
will ignore this embedding
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
cfg,
|
||||
):
|
||||
super().__init__()
|
||||
self.hidden_size = cfg.hidden_size
|
||||
self.params_dtype = cfg.params_dtype
|
||||
self.fp32_residual_connection = cfg.fp32_residual_connection
|
||||
self.embedding_weights_in_fp32 = cfg.embedding_weights_in_fp32
|
||||
self.word_embeddings = torch.nn.Embedding(
|
||||
cfg.padded_vocab_size, self.hidden_size,
|
||||
)
|
||||
self.embedding_dropout = torch.nn.Dropout(cfg.hidden_dropout)
|
||||
|
||||
def forward(self, input_ids):
|
||||
# Embeddings.
|
||||
if self.embedding_weights_in_fp32:
|
||||
self.word_embeddings = self.word_embeddings.to(torch.float32)
|
||||
embeddings = self.word_embeddings(input_ids)
|
||||
if self.embedding_weights_in_fp32:
|
||||
embeddings = embeddings.to(self.params_dtype)
|
||||
self.word_embeddings = self.word_embeddings.to(self.params_dtype)
|
||||
|
||||
# Data format change to avoid explicit transposes: [b s h] --> [s b h].
|
||||
embeddings = embeddings.transpose(0, 1).contiguous()
|
||||
|
||||
# If the input flag for fp32 residual connection is set, convert for float.
|
||||
if self.fp32_residual_connection:
|
||||
embeddings = embeddings.float()
|
||||
|
||||
# Dropout.
|
||||
embeddings = self.embedding_dropout(embeddings)
|
||||
|
||||
return embeddings
|
||||
|
||||
|
||||
|
||||
class StepChatTokenizer:
|
||||
"""Step Chat Tokenizer"""
|
||||
|
||||
def __init__(
|
||||
self, model_file, name="StepChatTokenizer",
|
||||
bot_token="<|BOT|>", # Begin of Turn
|
||||
eot_token="<|EOT|>", # End of Turn
|
||||
call_start_token="<|CALL_START|>", # Call Start
|
||||
call_end_token="<|CALL_END|>", # Call End
|
||||
think_start_token="<|THINK_START|>", # Think Start
|
||||
think_end_token="<|THINK_END|>", # Think End
|
||||
mask_start_token="<|MASK_1e69f|>", # Mask start
|
||||
mask_end_token="<|UNMASK_1e69f|>", # Mask end
|
||||
):
|
||||
import sentencepiece
|
||||
|
||||
self._tokenizer = sentencepiece.SentencePieceProcessor(model_file=model_file)
|
||||
|
||||
self._vocab = {}
|
||||
self._inv_vocab = {}
|
||||
|
||||
self._special_tokens = {}
|
||||
self._inv_special_tokens = {}
|
||||
|
||||
self._t5_tokens = []
|
||||
|
||||
for idx in range(self._tokenizer.get_piece_size()):
|
||||
text = self._tokenizer.id_to_piece(idx)
|
||||
self._inv_vocab[idx] = text
|
||||
self._vocab[text] = idx
|
||||
|
||||
if self._tokenizer.is_control(idx) or self._tokenizer.is_unknown(idx):
|
||||
self._special_tokens[text] = idx
|
||||
self._inv_special_tokens[idx] = text
|
||||
|
||||
self._unk_id = self._tokenizer.unk_id()
|
||||
self._bos_id = self._tokenizer.bos_id()
|
||||
self._eos_id = self._tokenizer.eos_id()
|
||||
|
||||
for token in [
|
||||
bot_token, eot_token, call_start_token, call_end_token,
|
||||
think_start_token, think_end_token
|
||||
]:
|
||||
assert token in self._vocab, f"Token '{token}' not found in tokenizer"
|
||||
assert token in self._special_tokens, f"Token '{token}' is not a special token"
|
||||
|
||||
for token in [mask_start_token, mask_end_token]:
|
||||
assert token in self._vocab, f"Token '{token}' not found in tokenizer"
|
||||
|
||||
self._bot_id = self._tokenizer.piece_to_id(bot_token)
|
||||
self._eot_id = self._tokenizer.piece_to_id(eot_token)
|
||||
self._call_start_id = self._tokenizer.piece_to_id(call_start_token)
|
||||
self._call_end_id = self._tokenizer.piece_to_id(call_end_token)
|
||||
self._think_start_id = self._tokenizer.piece_to_id(think_start_token)
|
||||
self._think_end_id = self._tokenizer.piece_to_id(think_end_token)
|
||||
self._mask_start_id = self._tokenizer.piece_to_id(mask_start_token)
|
||||
self._mask_end_id = self._tokenizer.piece_to_id(mask_end_token)
|
||||
|
||||
self._underline_id = self._tokenizer.piece_to_id("\u2581")
|
||||
|
||||
@property
|
||||
def vocab(self):
|
||||
return self._vocab
|
||||
|
||||
@property
|
||||
def inv_vocab(self):
|
||||
return self._inv_vocab
|
||||
|
||||
@property
|
||||
def vocab_size(self):
|
||||
return self._tokenizer.vocab_size()
|
||||
|
||||
def tokenize(self, text: str) -> List[int]:
|
||||
return self._tokenizer.encode_as_ids(text)
|
||||
|
||||
def detokenize(self, token_ids: List[int]) -> str:
|
||||
return self._tokenizer.decode_ids(token_ids)
|
||||
|
||||
|
||||
class Tokens:
|
||||
def __init__(self, input_ids, cu_input_ids, attention_mask, cu_seqlens, max_seq_len) -> None:
|
||||
self.input_ids = input_ids
|
||||
self.attention_mask = attention_mask
|
||||
self.cu_input_ids = cu_input_ids
|
||||
self.cu_seqlens = cu_seqlens
|
||||
self.max_seq_len = max_seq_len
|
||||
def to(self, device):
|
||||
self.input_ids = self.input_ids.to(device)
|
||||
self.attention_mask = self.attention_mask.to(device)
|
||||
self.cu_input_ids = self.cu_input_ids.to(device)
|
||||
self.cu_seqlens = self.cu_seqlens.to(device)
|
||||
return self
|
||||
|
||||
class Wrapped_StepChatTokenizer(StepChatTokenizer):
|
||||
def __call__(self, text, max_length=320, padding="max_length", truncation=True, return_tensors="pt"):
|
||||
# [bos, ..., eos, pad, pad, ..., pad]
|
||||
self.BOS = 1
|
||||
self.EOS = 2
|
||||
self.PAD = 2
|
||||
out_tokens = []
|
||||
attn_mask = []
|
||||
if len(text) == 0:
|
||||
part_tokens = [self.BOS] + [self.EOS]
|
||||
valid_size = len(part_tokens)
|
||||
if len(part_tokens) < max_length:
|
||||
part_tokens += [self.PAD] * (max_length - valid_size)
|
||||
out_tokens.append(part_tokens)
|
||||
attn_mask.append([1]*valid_size+[0]*(max_length-valid_size))
|
||||
else:
|
||||
for part in text:
|
||||
part_tokens = self.tokenize(part)
|
||||
part_tokens = part_tokens[:(max_length - 2)] # leave 2 space for bos and eos
|
||||
part_tokens = [self.BOS] + part_tokens + [self.EOS]
|
||||
valid_size = len(part_tokens)
|
||||
if len(part_tokens) < max_length:
|
||||
part_tokens += [self.PAD] * (max_length - valid_size)
|
||||
out_tokens.append(part_tokens)
|
||||
attn_mask.append([1]*valid_size+[0]*(max_length-valid_size))
|
||||
|
||||
out_tokens = torch.tensor(out_tokens, dtype=torch.long)
|
||||
attn_mask = torch.tensor(attn_mask, dtype=torch.long)
|
||||
|
||||
# padding y based on tp size
|
||||
padded_len = 0
|
||||
padded_flag = padded_len > 0
|
||||
if padded_flag:
|
||||
pad_tokens = torch.tensor([[self.PAD] * max_length], device=out_tokens.device)
|
||||
pad_attn_mask = torch.tensor([[1]*padded_len+[0]*(max_length-padded_len)], device=attn_mask.device)
|
||||
out_tokens = torch.cat([out_tokens, pad_tokens], dim=0)
|
||||
attn_mask = torch.cat([attn_mask, pad_attn_mask], dim=0)
|
||||
|
||||
# cu_seqlens
|
||||
cu_out_tokens = out_tokens.masked_select(attn_mask != 0).unsqueeze(0)
|
||||
seqlen = attn_mask.sum(dim=1).tolist()
|
||||
cu_seqlens = torch.cumsum(torch.tensor([0]+seqlen), 0).to(device=out_tokens.device,dtype=torch.int32)
|
||||
max_seq_len = max(seqlen)
|
||||
return Tokens(out_tokens, cu_out_tokens, attn_mask, cu_seqlens, max_seq_len)
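# Standalone sketch of the cu_seqlens bookkeeping above (no tokenizer model file
# needed): per-sequence valid lengths become cumulative offsets, the layout
# expected by packed / varlen attention kernels.
def _demo_cu_seqlens():
    import torch
    attn_mask = torch.tensor([[1, 1, 1, 0],
                              [1, 1, 0, 0]])
    seqlen = attn_mask.sum(dim=1).tolist()                       # [3, 2]
    cu_seqlens = torch.cumsum(torch.tensor([0] + seqlen), 0).to(dtype=torch.int32)
    assert cu_seqlens.tolist() == [0, 3, 5] and max(seqlen) == 3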
|
||||
|
||||
|
||||
|
||||
def flash_attn_func(q, k, v, dropout_p=0.0, softmax_scale=None, causal=True,
|
||||
return_attn_probs=False, tp_group_rank=0, tp_group_size=1):
|
||||
softmax_scale = q.size(-1) ** (-0.5) if softmax_scale is None else softmax_scale
|
||||
if hasattr(torch.ops.Optimus, "fwd"):
|
||||
results = torch.ops.Optimus.fwd(q, k, v, None, dropout_p, softmax_scale, causal, return_attn_probs, None, tp_group_rank, tp_group_size)[0]
|
||||
else:
|
||||
warnings.warn("Cannot load `torch.ops.Optimus.fwd`. Using `torch.nn.functional.scaled_dot_product_attention` instead.")
|
||||
results = torch.nn.functional.scaled_dot_product_attention(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), is_causal=True, scale=softmax_scale).transpose(1, 2)
|
||||
return results
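# Hedged note on the fallback branch above: the Optimus kernel consumes
# (batch, seq, heads, dim) tensors, while scaled_dot_product_attention expects
# (batch, heads, seq, dim), hence the transposes around the call. Toy shape check:
def _demo_flash_fallback_shapes():
    import torch
    q = torch.randn(1, 6, 2, 8)                                  # (b, s, h, d)
    out = torch.nn.functional.scaled_dot_product_attention(
        q.transpose(1, 2), q.transpose(1, 2), q.transpose(1, 2), is_causal=True
    ).transpose(1, 2)
    assert out.shape == q.shape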
|
||||
|
||||
|
||||
class FlashSelfAttention(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
attention_dropout=0.0,
|
||||
):
|
||||
super().__init__()
|
||||
self.dropout_p = attention_dropout
|
||||
|
||||
|
||||
def forward(self, q, k, v, cu_seqlens=None, max_seq_len=None):
|
||||
if cu_seqlens is None:
|
||||
output = flash_attn_func(q, k, v, dropout_p=self.dropout_p)
|
||||
else:
|
||||
raise ValueError('cu_seqlens is not supported!')
|
||||
|
||||
return output
|
||||
|
||||
|
||||
|
||||
def safediv(n, d):
|
||||
q, r = divmod(n, d)
|
||||
assert r == 0
|
||||
return q
|
||||
|
||||
|
||||
class MultiQueryAttention(nn.Module):
|
||||
def __init__(self, cfg, layer_id=None):
|
||||
super().__init__()
|
||||
|
||||
self.head_dim = cfg.hidden_size // cfg.num_attention_heads
|
||||
self.max_seq_len = cfg.seq_length
|
||||
self.use_flash_attention = cfg.use_flash_attn
|
||||
assert self.use_flash_attention, 'FlashAttention is required!'
|
||||
|
||||
self.n_groups = cfg.num_attention_groups
|
||||
self.tp_size = 1
|
||||
self.n_local_heads = cfg.num_attention_heads
|
||||
self.n_local_groups = self.n_groups
|
||||
|
||||
self.wqkv = nn.Linear(
|
||||
cfg.hidden_size,
|
||||
cfg.hidden_size + self.head_dim * 2 * self.n_groups,
|
||||
bias=False,
|
||||
)
|
||||
self.wo = nn.Linear(
|
||||
cfg.hidden_size,
|
||||
cfg.hidden_size,
|
||||
bias=False,
|
||||
)
|
||||
|
||||
assert self.use_flash_attention, 'non-Flash attention not supported yet.'
|
||||
self.core_attention = FlashSelfAttention(attention_dropout=cfg.attention_dropout)
|
||||
|
||||
self.layer_id = layer_id
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
mask: Optional[torch.Tensor],
|
||||
cu_seqlens: Optional[torch.Tensor],
|
||||
max_seq_len: Optional[torch.Tensor],
|
||||
):
|
||||
seqlen, bsz, dim = x.shape
|
||||
xqkv = self.wqkv(x)
|
||||
|
||||
xq, xkv = torch.split(
|
||||
xqkv,
|
||||
(dim // self.tp_size,
|
||||
self.head_dim*2*self.n_groups // self.tp_size
|
||||
),
|
||||
dim=-1,
|
||||
)
|
||||
|
||||
# gather on 1st dimension
|
||||
xq = xq.view(seqlen, bsz, self.n_local_heads, self.head_dim)
|
||||
xkv = xkv.view(seqlen, bsz, self.n_local_groups, 2 * self.head_dim)
|
||||
xk, xv = xkv.chunk(2, -1)
|
||||
|
||||
# rotary embedding + flash attn
|
||||
xq = rearrange(xq, "s b h d -> b s h d")
|
||||
xk = rearrange(xk, "s b h d -> b s h d")
|
||||
xv = rearrange(xv, "s b h d -> b s h d")
|
||||
|
||||
q_per_kv = self.n_local_heads // self.n_local_groups
|
||||
if q_per_kv > 1:
|
||||
b, s, h, d = xk.size()
|
||||
if h == 1:
|
||||
xk = xk.expand(b, s, q_per_kv, d)
|
||||
xv = xv.expand(b, s, q_per_kv, d)
|
||||
else:
|
||||
''' To cover the cases where h > 1, we have
|
||||
the following implementation, which is equivalent to:
|
||||
xk = xk.repeat_interleave(q_per_kv, dim=-2)
|
||||
xv = xv.repeat_interleave(q_per_kv, dim=-2)
|
||||
but can avoid calling aten::item() that involves cpu.
|
||||
'''
|
||||
idx = torch.arange(q_per_kv * h, device=xk.device).reshape(q_per_kv, -1).permute(1, 0).flatten()
|
||||
xk = torch.index_select(xk.repeat(1, 1, q_per_kv, 1), 2, idx).contiguous()
|
||||
xv = torch.index_select(xv.repeat(1, 1, q_per_kv, 1), 2, idx).contiguous()
|
||||
|
||||
if self.use_flash_attention:
|
||||
output = self.core_attention(xq, xk, xv,
|
||||
cu_seqlens=cu_seqlens,
|
||||
max_seq_len=max_seq_len)
|
||||
# reduce-scatter only supports the first dimension now
|
||||
output = rearrange(output, "b s h d -> s b (h d)").contiguous()
|
||||
else:
|
||||
xq, xk, xv = [
|
||||
rearrange(x, "b s ... -> s b ...").contiguous()
|
||||
for x in (xq, xk, xv)
|
||||
]
|
||||
output = self.core_attention(xq, xk, xv, mask)
|
||||
output = self.wo(output)
|
||||
return output
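# Small sketch verifying the index_select trick used above: it reproduces
# repeat_interleave along the key/value head axis without the CPU sync that
# aten::item() would introduce. Sizes are arbitrary.
def _demo_kv_head_expansion():
    import torch
    b, s, h, d, q_per_kv = 1, 3, 2, 4, 3
    xk = torch.randn(b, s, h, d)
    idx = torch.arange(q_per_kv * h).reshape(q_per_kv, -1).permute(1, 0).flatten()
    expanded = torch.index_select(xk.repeat(1, 1, q_per_kv, 1), 2, idx)
    assert torch.equal(expanded, xk.repeat_interleave(q_per_kv, dim=-2))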
|
||||
|
||||
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
cfg,
|
||||
dim: int,
|
||||
hidden_dim: int,
|
||||
layer_id: int,
|
||||
multiple_of: int=256,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
|
||||
def swiglu(x):
|
||||
x = torch.chunk(x, 2, dim=-1)
|
||||
return F.silu(x[0]) * x[1]
|
||||
self.swiglu = swiglu
|
||||
|
||||
self.w1 = nn.Linear(
|
||||
dim,
|
||||
2 * hidden_dim,
|
||||
bias=False,
|
||||
)
|
||||
self.w2 = nn.Linear(
|
||||
hidden_dim,
|
||||
dim,
|
||||
bias=False,
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.swiglu(self.w1(x))
|
||||
output = self.w2(x)
|
||||
return output
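# Hedged sketch of the SwiGLU feed-forward above: w1 produces two hidden_dim-wide
# halves, the first gates the second through SiLU, then w2 projects back to the
# model dimension. Dimensions below are illustrative.
def _demo_swiglu():
    import torch
    import torch.nn.functional as F
    x = torch.randn(1, 4, 16)
    w1 = torch.nn.Linear(16, 2 * 32, bias=False)
    w2 = torch.nn.Linear(32, 16, bias=False)
    a, b = torch.chunk(w1(x), 2, dim=-1)
    out = w2(F.silu(a) * b)
    assert out.shape == x.shape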
|
||||
|
||||
|
||||
|
||||
class TransformerBlock(nn.Module):
|
||||
def __init__(
|
||||
self, cfg, layer_id: int
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.n_heads = cfg.num_attention_heads
|
||||
self.dim = cfg.hidden_size
|
||||
self.head_dim = cfg.hidden_size // cfg.num_attention_heads
|
||||
self.attention = MultiQueryAttention(
|
||||
cfg,
|
||||
layer_id=layer_id,
|
||||
)
|
||||
|
||||
self.feed_forward = FeedForward(
|
||||
cfg,
|
||||
dim=cfg.hidden_size,
|
||||
hidden_dim=cfg.ffn_hidden_size,
|
||||
layer_id=layer_id,
|
||||
)
|
||||
self.layer_id = layer_id
|
||||
self.attention_norm = RMSNorm(
|
||||
cfg.hidden_size,
|
||||
eps=cfg.layernorm_epsilon,
|
||||
)
|
||||
self.ffn_norm = RMSNorm(
|
||||
cfg.hidden_size,
|
||||
eps=cfg.layernorm_epsilon,
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
mask: Optional[torch.Tensor],
|
||||
cu_seqlens: Optional[torch.Tensor],
|
||||
max_seq_len: Optional[torch.Tensor],
|
||||
):
|
||||
residual = self.attention.forward(
|
||||
self.attention_norm(x), mask,
|
||||
cu_seqlens, max_seq_len
|
||||
)
|
||||
h = x + residual
|
||||
ffn_res = self.feed_forward.forward(self.ffn_norm(h))
|
||||
out = h + ffn_res
|
||||
return out
|
||||
|
||||
|
||||
class Transformer(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config,
|
||||
max_seq_size=8192,
|
||||
):
|
||||
super().__init__()
|
||||
self.num_layers = config.num_layers
|
||||
self.layers = self._build_layers(config)
|
||||
|
||||
def _build_layers(self, config):
|
||||
layers = torch.nn.ModuleList()
|
||||
for layer_id in range(self.num_layers):
|
||||
layers.append(
|
||||
TransformerBlock(
|
||||
config,
|
||||
layer_id=layer_id + 1 ,
|
||||
)
|
||||
)
|
||||
return layers
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
cu_seqlens=None,
|
||||
max_seq_len=None,
|
||||
):
|
||||
|
||||
if max_seq_len is not None and not isinstance(max_seq_len, torch.Tensor):
|
||||
max_seq_len = torch.tensor(max_seq_len, dtype=torch.int32, device="cpu")
|
||||
|
||||
for lid, layer in enumerate(self.layers):
|
||||
hidden_states = layer(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
cu_seqlens,
|
||||
max_seq_len,
|
||||
)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class Step1Model(PreTrainedModel):
|
||||
config_class=PretrainedConfig
|
||||
@with_empty_init
|
||||
def __init__(
|
||||
self,
|
||||
config,
|
||||
):
|
||||
super().__init__(config)
|
||||
self.tok_embeddings = LLaMaEmbedding(config)
|
||||
self.transformer = Transformer(config)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids=None,
|
||||
attention_mask=None,
|
||||
):
|
||||
|
||||
hidden_states = self.tok_embeddings(input_ids)
|
||||
|
||||
hidden_states = self.transformer(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
)
|
||||
return hidden_states
|
||||
|
||||
|
||||
|
||||
class STEP1TextEncoder(torch.nn.Module):
|
||||
def __init__(self, model_dir, max_length=320):
|
||||
super(STEP1TextEncoder, self).__init__()
|
||||
self.max_length = max_length
|
||||
self.text_tokenizer = Wrapped_StepChatTokenizer(os.path.join(model_dir, 'step1_chat_tokenizer.model'))
|
||||
text_encoder = Step1Model.from_pretrained(model_dir)
|
||||
self.text_encoder = text_encoder.eval().to(torch.bfloat16)
|
||||
|
||||
@staticmethod
|
||||
def from_pretrained(path, torch_dtype=torch.bfloat16):
|
||||
model = STEP1TextEncoder(path).to(torch_dtype)
|
||||
return model
|
||||
|
||||
@torch.no_grad
|
||||
def forward(self, prompts, with_mask=True, max_length=None, device="cuda"):
|
||||
self.device = device
|
||||
with torch.no_grad(), torch.amp.autocast(dtype=torch.bfloat16, device_type=device):
|
||||
if type(prompts) is str:
|
||||
prompts = [prompts]
|
||||
|
||||
txt_tokens = self.text_tokenizer(
|
||||
prompts, max_length=max_length or self.max_length, padding="max_length", truncation=True, return_tensors="pt"
|
||||
)
|
||||
y = self.text_encoder(
|
||||
txt_tokens.input_ids.to(self.device),
|
||||
attention_mask=txt_tokens.attention_mask.to(self.device) if with_mask else None
|
||||
)
|
||||
y_mask = txt_tokens.attention_mask
|
||||
return y.transpose(0,1), y_mask
|
||||
|
||||
diffsynth/models/stepvideo_vae.py (new file, 1132 lines; diff not shown because it is too large)
@@ -44,6 +44,7 @@ def get_timestep_embedding(
|
||||
downscale_freq_shift: float = 1,
|
||||
scale: float = 1,
|
||||
max_period: int = 10000,
|
||||
computation_device = None,
|
||||
):
|
||||
"""
|
||||
This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings.
|
||||
@@ -57,11 +58,11 @@ def get_timestep_embedding(
|
||||
|
||||
half_dim = embedding_dim // 2
|
||||
exponent = -math.log(max_period) * torch.arange(
|
||||
start=0, end=half_dim, dtype=torch.float32, device=timesteps.device
|
||||
start=0, end=half_dim, dtype=torch.float32, device=timesteps.device if computation_device is None else computation_device
|
||||
)
|
||||
exponent = exponent / (half_dim - downscale_freq_shift)
|
||||
|
||||
emb = torch.exp(exponent)
|
||||
emb = torch.exp(exponent).to(timesteps.device)
|
||||
emb = timesteps[:, None].float() * emb[None, :]
|
||||
|
||||
# scale embeddings
|
||||
@@ -81,11 +82,12 @@ def get_timestep_embedding(
|
||||
|
||||
|
||||
class TemporalTimesteps(torch.nn.Module):
|
||||
def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float):
|
||||
def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float, computation_device = None):
|
||||
super().__init__()
|
||||
self.num_channels = num_channels
|
||||
self.flip_sin_to_cos = flip_sin_to_cos
|
||||
self.downscale_freq_shift = downscale_freq_shift
|
||||
self.computation_device = computation_device
|
||||
|
||||
def forward(self, timesteps):
|
||||
t_emb = get_timestep_embedding(
|
||||
@@ -93,6 +95,7 @@ class TemporalTimesteps(torch.nn.Module):
|
||||
self.num_channels,
|
||||
flip_sin_to_cos=self.flip_sin_to_cos,
|
||||
downscale_freq_shift=self.downscale_freq_shift,
|
||||
computation_device=self.computation_device,
|
||||
)
|
||||
return t_emb
|
||||
|
||||
|
||||
@@ -107,6 +107,60 @@ class TileWorker:
|
||||
|
||||
|
||||
|
||||
class FastTileWorker:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
|
||||
def build_mask(self, data, is_bound):
|
||||
_, _, H, W = data.shape
|
||||
h = repeat(torch.arange(H), "H -> H W", H=H, W=W)
|
||||
w = repeat(torch.arange(W), "W -> H W", H=H, W=W)
|
||||
border_width = (H + W) // 4
|
||||
pad = torch.ones_like(h) * border_width
|
||||
mask = torch.stack([
|
||||
pad if is_bound[0] else h + 1,
|
||||
pad if is_bound[1] else H - h,
|
||||
pad if is_bound[2] else w + 1,
|
||||
pad if is_bound[3] else W - w
|
||||
]).min(dim=0).values
|
||||
mask = mask.clip(1, border_width)
|
||||
mask = (mask / border_width).to(dtype=data.dtype, device=data.device)
|
||||
mask = rearrange(mask, "H W -> 1 H W")
|
||||
return mask
|
||||
|
||||
|
||||
def tiled_forward(self, forward_fn, model_input, tile_size, tile_stride, tile_device="cpu", tile_dtype=torch.float32, border_width=None):
|
||||
# Prepare
|
||||
B, C, H, W = model_input.shape
|
||||
border_width = int(tile_stride*0.5) if border_width is None else border_width
|
||||
weight = torch.zeros((1, 1, H, W), dtype=tile_dtype, device=tile_device)
|
||||
values = torch.zeros((B, C, H, W), dtype=tile_dtype, device=tile_device)
|
||||
|
||||
# Split tasks
|
||||
tasks = []
|
||||
for h in range(0, H, tile_stride):
|
||||
for w in range(0, W, tile_stride):
|
||||
if (h-tile_stride >= 0 and h-tile_stride+tile_size >= H) or (w-tile_stride >= 0 and w-tile_stride+tile_size >= W):
|
||||
continue
|
||||
h_, w_ = h + tile_size, w + tile_size
|
||||
if h_ > H: h, h_ = H - tile_size, H
|
||||
if w_ > W: w, w_ = W - tile_size, W
|
||||
tasks.append((h, h_, w, w_))
|
||||
|
||||
# Run
|
||||
for hl, hr, wl, wr in tasks:
|
||||
# Forward
|
||||
hidden_states_batch = forward_fn(hl, hr, wl, wr).to(dtype=tile_dtype, device=tile_device)
|
||||
|
||||
mask = self.build_mask(hidden_states_batch, is_bound=(hl==0, hr>=H, wl==0, wr>=W))
|
||||
values[:, :, hl:hr, wl:wr] += hidden_states_batch * mask
|
||||
weight[:, :, hl:hr, wl:wr] += mask
|
||||
values /= weight
|
||||
return values
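# Illustrative sketch of the feathered blending above: build_mask ramps up from the
# tile border towards the interior, so overlapping tiles are averaged with
# distance-based weights instead of producing hard seams. Toy tile only.
def _demo_tile_mask():
    import torch
    worker = FastTileWorker()
    tile = torch.zeros(1, 3, 8, 8)
    mask = worker.build_mask(tile, is_bound=(False, False, False, False))
    assert mask.shape == (1, 8, 8) and float(mask.max()) <= 1.0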
|
||||
|
||||
|
||||
|
||||
class TileWorker2Dto3D:
|
||||
"""
|
||||
Process 3D tensors, but only enable TileWorker on 2D.
|
||||
|
||||
@@ -1,7 +1,56 @@
|
||||
import torch, os
|
||||
from safetensors import safe_open
|
||||
from contextlib import contextmanager
|
||||
import hashlib
|
||||
|
||||
@contextmanager
|
||||
def init_weights_on_device(device = torch.device("meta"), include_buffers :bool = False):
|
||||
|
||||
old_register_parameter = torch.nn.Module.register_parameter
|
||||
if include_buffers:
|
||||
old_register_buffer = torch.nn.Module.register_buffer
|
||||
|
||||
def register_empty_parameter(module, name, param):
|
||||
old_register_parameter(module, name, param)
|
||||
if param is not None:
|
||||
param_cls = type(module._parameters[name])
|
||||
kwargs = module._parameters[name].__dict__
|
||||
kwargs["requires_grad"] = param.requires_grad
|
||||
module._parameters[name] = param_cls(module._parameters[name].to(device), **kwargs)
|
||||
|
||||
def register_empty_buffer(module, name, buffer, persistent=True):
|
||||
old_register_buffer(module, name, buffer, persistent=persistent)
|
||||
if buffer is not None:
|
||||
module._buffers[name] = module._buffers[name].to(device)
|
||||
|
||||
def patch_tensor_constructor(fn):
|
||||
def wrapper(*args, **kwargs):
|
||||
kwargs["device"] = device
|
||||
return fn(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
if include_buffers:
|
||||
tensor_constructors_to_patch = {
|
||||
torch_function_name: getattr(torch, torch_function_name)
|
||||
for torch_function_name in ["empty", "zeros", "ones", "full"]
|
||||
}
|
||||
else:
|
||||
tensor_constructors_to_patch = {}
|
||||
|
||||
try:
|
||||
torch.nn.Module.register_parameter = register_empty_parameter
|
||||
if include_buffers:
|
||||
torch.nn.Module.register_buffer = register_empty_buffer
|
||||
for torch_function_name in tensor_constructors_to_patch.keys():
|
||||
setattr(torch, torch_function_name, patch_tensor_constructor(getattr(torch, torch_function_name)))
|
||||
yield
|
||||
finally:
|
||||
torch.nn.Module.register_parameter = old_register_parameter
|
||||
if include_buffers:
|
||||
torch.nn.Module.register_buffer = old_register_buffer
|
||||
for torch_function_name, old_torch_function in tensor_constructors_to_patch.items():
|
||||
setattr(torch, torch_function_name, old_torch_function)
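# Hedged usage sketch of init_weights_on_device: parameters registered inside the
# context live on the meta device, so a large model can be described without
# allocating real weights (they are filled in later from a state dict).
def _demo_init_on_meta():
    import torch
    with init_weights_on_device():
        layer = torch.nn.Linear(4, 4)
    assert layer.weight.device.type == "meta"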
|
||||
|
||||
def load_state_dict_from_folder(file_path, torch_dtype=None):
|
||||
state_dict = {}
|
||||
@@ -31,7 +80,7 @@ def load_state_dict_from_safetensors(file_path, torch_dtype=None):
|
||||
|
||||
|
||||
def load_state_dict_from_bin(file_path, torch_dtype=None):
|
||||
state_dict = torch.load(file_path, map_location="cpu")
|
||||
state_dict = torch.load(file_path, map_location="cpu", weights_only=True)
|
||||
if torch_dtype is not None:
|
||||
for i in state_dict:
|
||||
if isinstance(state_dict[i], torch.Tensor):
|
||||
@@ -94,3 +143,40 @@ def search_for_files(folder, extensions):
|
||||
files.append(folder)
|
||||
break
|
||||
return files
|
||||
|
||||
|
||||
def convert_state_dict_keys_to_single_str(state_dict, with_shape=True):
|
||||
keys = []
|
||||
for key, value in state_dict.items():
|
||||
if isinstance(key, str):
|
||||
if isinstance(value, torch.Tensor):
|
||||
if with_shape:
|
||||
shape = "_".join(map(str, list(value.shape)))
|
||||
keys.append(key + ":" + shape)
|
||||
keys.append(key)
|
||||
elif isinstance(value, dict):
|
||||
keys.append(key + "|" + convert_state_dict_keys_to_single_str(value, with_shape=with_shape))
|
||||
keys.sort()
|
||||
keys_str = ",".join(keys)
|
||||
return keys_str
|
||||
|
||||
|
||||
def split_state_dict_with_prefix(state_dict):
|
||||
keys = sorted([key for key in state_dict if isinstance(key, str)])
|
||||
prefix_dict = {}
|
||||
for key in keys:
|
||||
prefix = key if "." not in key else key.split(".")[0]
|
||||
if prefix not in prefix_dict:
|
||||
prefix_dict[prefix] = []
|
||||
prefix_dict[prefix].append(key)
|
||||
state_dicts = []
|
||||
for prefix, keys in prefix_dict.items():
|
||||
sub_state_dict = {key: state_dict[key] for key in keys}
|
||||
state_dicts.append(sub_state_dict)
|
||||
return state_dicts
|
||||
|
||||
|
||||
def hash_state_dict_keys(state_dict, with_shape=True):
|
||||
keys_str = convert_state_dict_keys_to_single_str(state_dict, with_shape=with_shape)
|
||||
keys_str = keys_str.encode(encoding="UTF-8")
|
||||
return hashlib.md5(keys_str).hexdigest()
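# Illustrative usage of the helpers above: a checkpoint's architecture is
# fingerprinted from its sorted "key:shape" strings, so two state dicts with the
# same layout hash identically regardless of tensor contents.
def _demo_hash_state_dict():
    import torch
    sd_a = {"layer.weight": torch.zeros(4, 4)}
    sd_b = {"layer.weight": torch.ones(4, 4)}
    assert hash_state_dict_keys(sd_a) == hash_state_dict_keys(sd_b)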
|
||||
diffsynth/models/wanx_text_encoder.py (new file, 254 lines)
@@ -0,0 +1,254 @@
|
||||
import math
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
def fp16_clamp(x):
|
||||
if x.dtype == torch.float16 and torch.isinf(x).any():
|
||||
clamp = torch.finfo(x.dtype).max - 1000
|
||||
x = torch.clamp(x, min=-clamp, max=clamp)
|
||||
return x
|
||||
|
||||
|
||||
class GELU(nn.Module):
|
||||
|
||||
def forward(self, x):
|
||||
return 0.5 * x * (1.0 + torch.tanh(
|
||||
math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))
|
||||
|
||||
|
||||
class T5LayerNorm(nn.Module):
|
||||
|
||||
def __init__(self, dim, eps=1e-6):
|
||||
super(T5LayerNorm, self).__init__()
|
||||
self.dim = dim
|
||||
self.eps = eps
|
||||
self.weight = nn.Parameter(torch.ones(dim))
|
||||
|
||||
def forward(self, x):
|
||||
x = x * torch.rsqrt(x.float().pow(2).mean(dim=-1, keepdim=True) +
|
||||
self.eps)
|
||||
if self.weight.dtype in [torch.float16, torch.bfloat16]:
|
||||
x = x.type_as(self.weight)
|
||||
return self.weight * x
|
||||
|
||||
|
||||
class T5Attention(nn.Module):
|
||||
|
||||
def __init__(self, dim, dim_attn, num_heads, dropout=0.1):
|
||||
assert dim_attn % num_heads == 0
|
||||
super(T5Attention, self).__init__()
|
||||
self.dim = dim
|
||||
self.dim_attn = dim_attn
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = dim_attn // num_heads
|
||||
|
||||
# layers
|
||||
self.q = nn.Linear(dim, dim_attn, bias=False)
|
||||
self.k = nn.Linear(dim, dim_attn, bias=False)
|
||||
self.v = nn.Linear(dim, dim_attn, bias=False)
|
||||
self.o = nn.Linear(dim_attn, dim, bias=False)
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
|
||||
def forward(self, x, context=None, mask=None, pos_bias=None):
|
||||
"""
|
||||
x: [B, L1, C].
|
||||
context: [B, L2, C] or None.
|
||||
mask: [B, L2] or [B, L1, L2] or None.
|
||||
"""
|
||||
# check inputs
|
||||
context = x if context is None else context
|
||||
b, n, c = x.size(0), self.num_heads, self.head_dim
|
||||
|
||||
# compute query, key, value
|
||||
q = self.q(x).view(b, -1, n, c)
|
||||
k = self.k(context).view(b, -1, n, c)
|
||||
v = self.v(context).view(b, -1, n, c)
|
||||
|
||||
# attention bias
|
||||
attn_bias = x.new_zeros(b, n, q.size(1), k.size(1))
|
||||
if pos_bias is not None:
|
||||
attn_bias += pos_bias
|
||||
if mask is not None:
|
||||
assert mask.ndim in [2, 3]
|
||||
mask = mask.view(b, 1, 1,
|
||||
-1) if mask.ndim == 2 else mask.unsqueeze(1)
|
||||
attn_bias.masked_fill_(mask == 0, torch.finfo(x.dtype).min)
|
||||
|
||||
# compute attention (T5 does not use scaling)
|
||||
attn = torch.einsum('binc,bjnc->bnij', q, k) + attn_bias
|
||||
attn = F.softmax(attn.float(), dim=-1).type_as(attn)
|
||||
x = torch.einsum('bnij,bjnc->binc', attn, v)
|
||||
|
||||
# output
|
||||
x = x.reshape(b, -1, n * c)
|
||||
x = self.o(x)
|
||||
x = self.dropout(x)
|
||||
return x
|
||||
|
||||
|
||||
class T5FeedForward(nn.Module):
|
||||
|
||||
def __init__(self, dim, dim_ffn, dropout=0.1):
|
||||
super(T5FeedForward, self).__init__()
|
||||
self.dim = dim
|
||||
self.dim_ffn = dim_ffn
|
||||
|
||||
# layers
|
||||
self.gate = nn.Sequential(nn.Linear(dim, dim_ffn, bias=False), GELU())
|
||||
self.fc1 = nn.Linear(dim, dim_ffn, bias=False)
|
||||
self.fc2 = nn.Linear(dim_ffn, dim, bias=False)
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.fc1(x) * self.gate(x)
|
||||
x = self.dropout(x)
|
||||
x = self.fc2(x)
|
||||
x = self.dropout(x)
|
||||
return x
|
||||
|
||||
|
||||
class T5SelfAttention(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
dim,
|
||||
dim_attn,
|
||||
dim_ffn,
|
||||
num_heads,
|
||||
num_buckets,
|
||||
shared_pos=True,
|
||||
dropout=0.1):
|
||||
super(T5SelfAttention, self).__init__()
|
||||
self.dim = dim
|
||||
self.dim_attn = dim_attn
|
||||
self.dim_ffn = dim_ffn
|
||||
self.num_heads = num_heads
|
||||
self.num_buckets = num_buckets
|
||||
self.shared_pos = shared_pos
|
||||
|
||||
# layers
|
||||
self.norm1 = T5LayerNorm(dim)
|
||||
self.attn = T5Attention(dim, dim_attn, num_heads, dropout)
|
||||
self.norm2 = T5LayerNorm(dim)
|
||||
self.ffn = T5FeedForward(dim, dim_ffn, dropout)
|
||||
self.pos_embedding = None if shared_pos else T5RelativeEmbedding(
|
||||
num_buckets, num_heads, bidirectional=True)
|
||||
|
||||
def forward(self, x, mask=None, pos_bias=None):
|
||||
e = pos_bias if self.shared_pos else self.pos_embedding(
|
||||
x.size(1), x.size(1))
|
||||
x = fp16_clamp(x + self.attn(self.norm1(x), mask=mask, pos_bias=e))
|
||||
x = fp16_clamp(x + self.ffn(self.norm2(x)))
|
||||
return x
|
||||
|
||||
|
||||
class T5RelativeEmbedding(nn.Module):
|
||||
|
||||
def __init__(self, num_buckets, num_heads, bidirectional, max_dist=128):
|
||||
super(T5RelativeEmbedding, self).__init__()
|
||||
self.num_buckets = num_buckets
|
||||
self.num_heads = num_heads
|
||||
self.bidirectional = bidirectional
|
||||
self.max_dist = max_dist
|
||||
|
||||
# layers
|
||||
self.embedding = nn.Embedding(num_buckets, num_heads)
|
||||
|
||||
def forward(self, lq, lk):
|
||||
device = self.embedding.weight.device
|
||||
# rel_pos = torch.arange(lk).unsqueeze(0).to(device) - \
|
||||
# torch.arange(lq).unsqueeze(1).to(device)
|
||||
rel_pos = torch.arange(lk, device=device).unsqueeze(0) - \
|
||||
torch.arange(lq, device=device).unsqueeze(1)
|
||||
rel_pos = self._relative_position_bucket(rel_pos)
|
||||
rel_pos_embeds = self.embedding(rel_pos)
|
||||
rel_pos_embeds = rel_pos_embeds.permute(2, 0, 1).unsqueeze(
|
||||
0) # [1, N, Lq, Lk]
|
||||
return rel_pos_embeds.contiguous()
|
||||
|
||||
def _relative_position_bucket(self, rel_pos):
|
||||
# preprocess
|
||||
if self.bidirectional:
|
||||
num_buckets = self.num_buckets // 2
|
||||
rel_buckets = (rel_pos > 0).long() * num_buckets
|
||||
rel_pos = torch.abs(rel_pos)
|
||||
else:
|
||||
num_buckets = self.num_buckets
|
||||
rel_buckets = 0
|
||||
rel_pos = -torch.min(rel_pos, torch.zeros_like(rel_pos))
|
||||
|
||||
# embeddings for small and large positions
|
||||
max_exact = num_buckets // 2
|
||||
rel_pos_large = max_exact + (torch.log(rel_pos.float() / max_exact) /
|
||||
math.log(self.max_dist / max_exact) *
|
||||
(num_buckets - max_exact)).long()
|
||||
rel_pos_large = torch.min(
|
||||
rel_pos_large, torch.full_like(rel_pos_large, num_buckets - 1))
|
||||
rel_buckets += torch.where(rel_pos < max_exact, rel_pos, rel_pos_large)
|
||||
return rel_buckets
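# Hedged sketch of T5 relative position bucketing: nearby offsets get individual
# buckets, distant offsets share logarithmically spaced ones, and (when
# bidirectional) positive and negative offsets use disjoint halves of the table.
def _demo_t5_relative_bias():
    import torch
    emb = T5RelativeEmbedding(num_buckets=32, num_heads=2, bidirectional=True)
    bias = emb(8, 8)
    assert bias.shape == (1, 2, 8, 8)            # [1, num_heads, Lq, Lk]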
|
||||
|
||||
def init_weights(m):
|
||||
if isinstance(m, T5LayerNorm):
|
||||
nn.init.ones_(m.weight)
|
||||
elif isinstance(m, T5FeedForward):
|
||||
nn.init.normal_(m.gate[0].weight, std=m.dim**-0.5)
|
||||
nn.init.normal_(m.fc1.weight, std=m.dim**-0.5)
|
||||
nn.init.normal_(m.fc2.weight, std=m.dim_ffn**-0.5)
|
||||
elif isinstance(m, T5Attention):
|
||||
nn.init.normal_(m.q.weight, std=(m.dim * m.dim_attn)**-0.5)
|
||||
nn.init.normal_(m.k.weight, std=m.dim**-0.5)
|
||||
nn.init.normal_(m.v.weight, std=m.dim**-0.5)
|
||||
nn.init.normal_(m.o.weight, std=(m.num_heads * m.dim_attn)**-0.5)
|
||||
elif isinstance(m, T5RelativeEmbedding):
|
||||
nn.init.normal_(
|
||||
m.embedding.weight, std=(2 * m.num_buckets * m.num_heads)**-0.5)
|
||||
|
||||
|
||||
class WanXTextEncoder(torch.nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
vocab=256384,
|
||||
dim=4096,
|
||||
dim_attn=4096,
|
||||
dim_ffn=10240,
|
||||
num_heads=64,
|
||||
num_layers=24,
|
||||
num_buckets=32,
|
||||
shared_pos=False,
|
||||
dropout=0.1):
|
||||
super(WanXTextEncoder, self).__init__()
|
||||
self.dim = dim
|
||||
self.dim_attn = dim_attn
|
||||
self.dim_ffn = dim_ffn
|
||||
self.num_heads = num_heads
|
||||
self.num_layers = num_layers
|
||||
self.num_buckets = num_buckets
|
||||
self.shared_pos = shared_pos
|
||||
|
||||
# layers
|
||||
self.token_embedding = vocab if isinstance(vocab, nn.Embedding) \
|
||||
else nn.Embedding(vocab, dim)
|
||||
self.pos_embedding = T5RelativeEmbedding(
|
||||
num_buckets, num_heads, bidirectional=True) if shared_pos else None
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
self.blocks = nn.ModuleList([
|
||||
T5SelfAttention(dim, dim_attn, dim_ffn, num_heads, num_buckets,
|
||||
shared_pos, dropout) for _ in range(num_layers)
|
||||
])
|
||||
self.norm = T5LayerNorm(dim)
|
||||
|
||||
# initialize weights
|
||||
self.apply(init_weights)
|
||||
|
||||
def forward(self, ids, mask=None):
|
||||
x = self.token_embedding(ids)
|
||||
x = self.dropout(x)
|
||||
e = self.pos_embedding(x.size(1),
|
||||
x.size(1)) if self.shared_pos else None
|
||||
for block in self.blocks:
|
||||
x = block(x, mask, pos_bias=e)
|
||||
x = self.norm(x)
|
||||
x = self.dropout(x)
|
||||
return x
|
||||
diffsynth/models/wanx_vae.py (new file, 794 lines)
@@ -0,0 +1,794 @@
|
||||
from einops import rearrange, repeat
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from tqdm import tqdm
|
||||
|
||||
CACHE_T = 2
|
||||
|
||||
def block_causal_mask(x, block_size):
|
||||
# params
|
||||
b, n, s, _, device = *x.size(), x.device
|
||||
assert s % block_size == 0
|
||||
num_blocks = s // block_size
|
||||
|
||||
# build mask
|
||||
mask = torch.zeros(b, n, s, s, dtype=torch.bool, device=device)
|
||||
for i in range(num_blocks):
|
||||
mask[:, :,
|
||||
i * block_size:(i + 1) * block_size, :(i + 1) * block_size] = 1
|
||||
return mask
|
||||
|
||||
|
||||
class CausalConv3d(nn.Conv3d):
|
||||
"""
|
||||
Causal 3D convolution.
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self._padding = (self.padding[2], self.padding[2], self.padding[1],
|
||||
self.padding[1], 2 * self.padding[0], 0)
|
||||
self.padding = (0, 0, 0)
|
||||
|
||||
def forward(self, x, cache_x=None):
|
||||
padding = list(self._padding)
|
||||
if cache_x is not None and self._padding[4] > 0:
|
||||
cache_x = cache_x.to(x.device)
|
||||
x = torch.cat([cache_x, x], dim=2)
|
||||
padding[4] -= cache_x.shape[2]
|
||||
x = F.pad(x, padding)
|
||||
|
||||
return super().forward(x)
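# Hedged sketch of the causal 3D convolution above: all temporal padding is placed
# on the past side, so the output at frame t never depends on later frames, while
# spatial padding stays symmetric. Toy shapes only.
def _demo_causal_conv3d():
    import torch
    conv = CausalConv3d(4, 4, kernel_size=3, padding=1)
    x = torch.randn(1, 4, 5, 8, 8)               # (b, c, t, h, w)
    out = conv(x)
    assert out.shape == x.shape                  # temporal length preserved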
|
||||
|
||||
|
||||
class RMS_norm(nn.Module):
|
||||
|
||||
def __init__(self, dim, channel_first=True, images=True, bias=False):
|
||||
super().__init__()
|
||||
broadcastable_dims = (1, 1, 1) if not images else (1, 1)
|
||||
shape = (dim, *broadcastable_dims) if channel_first else (dim,)
|
||||
|
||||
self.channel_first = channel_first
|
||||
self.scale = dim**0.5
|
||||
self.gamma = nn.Parameter(torch.ones(shape))
|
||||
self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.
|
||||
|
||||
def forward(self, x):
|
||||
return F.normalize(
|
||||
x, dim=(1 if self.channel_first else
|
||||
-1)) * self.scale * self.gamma + self.bias
|
||||
|
||||
|
||||
class Upsample(nn.Upsample):
|
||||
|
||||
def forward(self, x):
|
||||
"""
|
||||
Fix bfloat16 support for nearest neighbor interpolation.
|
||||
"""
|
||||
return super().forward(x.float()).type_as(x)
|
||||
|
||||
|
||||
class Resample(nn.Module):
|
||||
|
||||
def __init__(self, dim, mode):
|
||||
assert mode in ('none', 'upsample2d', 'upsample3d', 'downsample2d',
|
||||
'downsample3d')
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.mode = mode
|
||||
|
||||
# layers
|
||||
if mode == 'upsample2d':
|
||||
self.resample = nn.Sequential(
|
||||
Upsample(scale_factor=(2., 2.), mode='nearest-exact'),
|
||||
nn.Conv2d(dim, dim // 2, 3, padding=1))
|
||||
elif mode == 'upsample3d':
|
||||
self.resample = nn.Sequential(
|
||||
Upsample(scale_factor=(2., 2.), mode='nearest-exact'),
|
||||
nn.Conv2d(dim, dim // 2, 3, padding=1))
|
||||
self.time_conv = CausalConv3d(dim,
|
||||
dim * 2, (3, 1, 1),
|
||||
padding=(1, 0, 0))
|
||||
|
||||
elif mode == 'downsample2d':
|
||||
self.resample = nn.Sequential(
|
||||
nn.ZeroPad2d((0, 1, 0, 1)),
|
||||
nn.Conv2d(dim, dim, 3, stride=(2, 2)))
|
||||
elif mode == 'downsample3d':
|
||||
self.resample = nn.Sequential(
|
||||
nn.ZeroPad2d((0, 1, 0, 1)),
|
||||
nn.Conv2d(dim, dim, 3, stride=(2, 2)))
|
||||
self.time_conv = CausalConv3d(dim,
|
||||
dim, (3, 1, 1),
|
||||
stride=(2, 1, 1),
|
||||
padding=(0, 0, 0))
|
||||
|
||||
else:
|
||||
self.resample = nn.Identity()
|
||||
|
||||
def forward(self, x, feat_cache=None, feat_idx=[0]):
|
||||
b, c, t, h, w = x.size()
|
||||
if self.mode == 'upsample3d':
|
||||
if feat_cache is not None:
|
||||
idx = feat_idx[0]
|
||||
if feat_cache[idx] is None:
|
||||
feat_cache[idx] = 'Rep'
|
||||
feat_idx[0] += 1
|
||||
else:
|
||||
|
||||
cache_x = x[:, :, -CACHE_T:, :, :].clone()
|
||||
if cache_x.shape[2] < 2 and feat_cache[
|
||||
idx] is not None and feat_cache[idx] != 'Rep':
|
||||
# borrow the last frame of the previous chunk so the cache covers two frames
|
||||
cache_x = torch.cat([
|
||||
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
|
||||
cache_x.device), cache_x
|
||||
],
|
||||
dim=2)
|
||||
if cache_x.shape[2] < 2 and feat_cache[
|
||||
idx] is not None and feat_cache[idx] == 'Rep':
|
||||
cache_x = torch.cat([
|
||||
torch.zeros_like(cache_x).to(cache_x.device),
|
||||
cache_x
|
||||
],
|
||||
dim=2)
|
||||
if feat_cache[idx] == 'Rep':
|
||||
x = self.time_conv(x)
|
||||
else:
|
||||
x = self.time_conv(x, feat_cache[idx])
|
||||
feat_cache[idx] = cache_x
|
||||
feat_idx[0] += 1
|
||||
|
||||
x = x.reshape(b, 2, c, t, h, w)
|
||||
x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]),
|
||||
3)
|
||||
x = x.reshape(b, c, t * 2, h, w)
|
||||
t = x.shape[2]
|
||||
x = rearrange(x, 'b c t h w -> (b t) c h w')
|
||||
x = self.resample(x)
|
||||
x = rearrange(x, '(b t) c h w -> b c t h w', t=t)
|
||||
|
||||
if self.mode == 'downsample3d':
|
||||
if feat_cache is not None:
|
||||
idx = feat_idx[0]
|
||||
if feat_cache[idx] is None:
|
||||
feat_cache[idx] = x.clone()
|
||||
feat_idx[0] += 1
|
||||
else:
|
||||
cache_x = x[:, :, -1:, :, :].clone()
|
||||
x = self.time_conv(
|
||||
torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2))
|
||||
feat_cache[idx] = cache_x
|
||||
feat_idx[0] += 1
|
||||
return x
|
||||
|
||||
def init_weight(self, conv):
|
||||
conv_weight = conv.weight
|
||||
nn.init.zeros_(conv_weight)
|
||||
c1, c2, t, h, w = conv_weight.size()
|
||||
one_matrix = torch.eye(c1, c2)
|
||||
init_matrix = one_matrix
|
||||
nn.init.zeros_(conv_weight)
|
||||
conv_weight.data[:, :, 1, 0, 0] = init_matrix
|
||||
conv.weight.data.copy_(conv_weight)
|
||||
nn.init.zeros_(conv.bias.data)
|
||||
|
||||
def init_weight2(self, conv):
|
||||
conv_weight = conv.weight.data
|
||||
nn.init.zeros_(conv_weight)
|
||||
c1, c2, t, h, w = conv_weight.size()
|
||||
init_matrix = torch.eye(c1 // 2, c2)
|
||||
conv_weight[:c1 // 2, :, -1, 0, 0] = init_matrix
|
||||
conv_weight[c1 // 2:, :, -1, 0, 0] = init_matrix
|
||||
conv.weight.data.copy_(conv_weight)
|
||||
nn.init.zeros_(conv.bias.data)
|
||||
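# A small numeric check of the temporal interleaving used in 'upsample3d' mode
# above: time_conv doubles the channel dimension, and the reshape/stack pair
# interleaves the two halves along time, turning t frames into 2t frames.
def _sketch_upsample3d_interleave():
    b, c, t = 1, 1, 3
    y = torch.arange(2 * c * t, dtype=torch.float32).reshape(b, 2 * c, t, 1, 1)
    z = y.reshape(b, 2, c, t, 1, 1)
    z = torch.stack((z[:, 0], z[:, 1]), 3).reshape(b, c, t * 2, 1, 1)
    assert z.flatten().tolist() == [0.0, 3.0, 1.0, 4.0, 2.0, 5.0]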
|
||||
|
||||
class ResidualBlock(nn.Module):
|
||||
|
||||
def __init__(self, in_dim, out_dim, dropout=0.0):
|
||||
super().__init__()
|
||||
self.in_dim = in_dim
|
||||
self.out_dim = out_dim
|
||||
|
||||
# layers
|
||||
self.residual = nn.Sequential(
|
||||
RMS_norm(in_dim, images=False), nn.SiLU(),
|
||||
CausalConv3d(in_dim, out_dim, 3, padding=1),
|
||||
RMS_norm(out_dim, images=False), nn.SiLU(), nn.Dropout(dropout),
|
||||
CausalConv3d(out_dim, out_dim, 3, padding=1))
|
||||
self.shortcut = CausalConv3d(in_dim, out_dim, 1) \
|
||||
if in_dim != out_dim else nn.Identity()
|
||||
|
||||
def forward(self, x, feat_cache=None, feat_idx=[0]):
|
||||
h = self.shortcut(x)
|
||||
for layer in self.residual:
|
||||
if isinstance(layer, CausalConv3d) and feat_cache is not None:
|
||||
idx = feat_idx[0]
|
||||
cache_x = x[:, :, -CACHE_T:, :, :].clone()
|
||||
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
|
||||
# borrow the last frame of the previous chunk so the cache covers two frames
|
||||
cache_x = torch.cat([
|
||||
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
|
||||
cache_x.device), cache_x
|
||||
],
|
||||
dim=2)
|
||||
x = layer(x, feat_cache[idx])
|
||||
feat_cache[idx] = cache_x
|
||||
feat_idx[0] += 1
|
||||
else:
|
||||
x = layer(x)
|
||||
return x + h
|
||||
|
||||
|
||||
class AttentionBlock(nn.Module):
|
||||
"""
|
||||
Causal self-attention with a single head.
|
||||
"""
|
||||
|
||||
def __init__(self, dim):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
|
||||
# layers
|
||||
self.norm = RMS_norm(dim)
|
||||
self.to_qkv = nn.Conv2d(dim, dim * 3, 1)
|
||||
self.proj = nn.Conv2d(dim, dim, 1)
|
||||
|
||||
# zero out the last layer params
|
||||
nn.init.zeros_(self.proj.weight)
|
||||
|
||||
def forward(self, x):
|
||||
identity = x
|
||||
b, c, t, h, w = x.size()
|
||||
x = rearrange(x, 'b c t h w -> (b t) c h w')
|
||||
x = self.norm(x)
|
||||
# compute query, key, value
|
||||
q, k, v = self.to_qkv(x).reshape(b * t, 1, c * 3, -1).permute(
|
||||
0, 1, 3, 2).contiguous().chunk(3, dim=-1)
|
||||
|
||||
# apply attention
|
||||
x = F.scaled_dot_product_attention(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
#attn_mask=block_causal_mask(q, block_size=h * w)
|
||||
)
|
||||
x = x.squeeze(1).permute(0, 2, 1).reshape(b * t, c, h, w)
|
||||
|
||||
# output
|
||||
x = self.proj(x)
|
||||
x = rearrange(x, '(b t) c h w-> b c t h w', t=t)
|
||||
return x + identity
|
||||
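# Note: the attn_mask argument above is commented out, so in practice this is
# plain single-head self-attention over the h*w spatial positions of each
# frame, with frames folded into the batch dimension. A minimal shape check:
def _sketch_attention_block_shapes():
    block = AttentionBlock(dim=8)
    x = torch.randn(1, 8, 2, 4, 4)     # (b, c, t, h, w)
    assert block(x).shape == x.shape   # residual connection preserves the shape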
|
||||
|
||||
class Encoder3d(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
dim=128,
|
||||
z_dim=4,
|
||||
dim_mult=[1, 2, 4, 4],
|
||||
num_res_blocks=2,
|
||||
attn_scales=[],
|
||||
temperal_downsample=[True, True, False],
|
||||
dropout=0.0):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.z_dim = z_dim
|
||||
self.dim_mult = dim_mult
|
||||
self.num_res_blocks = num_res_blocks
|
||||
self.attn_scales = attn_scales
|
||||
self.temperal_downsample = temperal_downsample
|
||||
|
||||
# dimensions
|
||||
dims = [dim * u for u in [1] + dim_mult]
|
||||
scale = 1.0
|
||||
|
||||
# init block
|
||||
self.conv1 = CausalConv3d(3, dims[0], 3, padding=1)
|
||||
|
||||
# downsample blocks
|
||||
downsamples = []
|
||||
for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
|
||||
# residual (+attention) blocks
|
||||
for _ in range(num_res_blocks):
|
||||
downsamples.append(ResidualBlock(in_dim, out_dim, dropout))
|
||||
if scale in attn_scales:
|
||||
downsamples.append(AttentionBlock(out_dim))
|
||||
in_dim = out_dim
|
||||
|
||||
# downsample block
|
||||
if i != len(dim_mult) - 1:
|
||||
mode = 'downsample3d' if temperal_downsample[
|
||||
i] else 'downsample2d'
|
||||
downsamples.append(Resample(out_dim, mode=mode))
|
||||
scale /= 2.0
|
||||
self.downsamples = nn.Sequential(*downsamples)
|
||||
|
||||
# middle blocks
|
||||
self.middle = nn.Sequential(ResidualBlock(out_dim, out_dim, dropout),
|
||||
AttentionBlock(out_dim),
|
||||
ResidualBlock(out_dim, out_dim, dropout))
|
||||
|
||||
# output blocks
|
||||
self.head = nn.Sequential(RMS_norm(out_dim, images=False), nn.SiLU(),
|
||||
CausalConv3d(out_dim, z_dim, 3, padding=1))
|
||||
|
||||
def forward(self, x, feat_cache=None, feat_idx=[0]):
|
||||
if feat_cache is not None:
|
||||
idx = feat_idx[0]
|
||||
cache_x = x[:, :, -CACHE_T:, :, :].clone()
|
||||
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
|
||||
# borrow the last frame of the previous chunk so the cache covers two frames
|
||||
cache_x = torch.cat([
|
||||
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
|
||||
cache_x.device), cache_x
|
||||
],
|
||||
dim=2)
|
||||
x = self.conv1(x, feat_cache[idx])
|
||||
feat_cache[idx] = cache_x
|
||||
feat_idx[0] += 1
|
||||
else:
|
||||
x = self.conv1(x)
|
||||
|
||||
## downsamples
|
||||
for layer in self.downsamples:
|
||||
if feat_cache is not None:
|
||||
x = layer(x, feat_cache, feat_idx)
|
||||
else:
|
||||
x = layer(x)
|
||||
|
||||
## middle
|
||||
for layer in self.middle:
|
||||
if isinstance(layer, ResidualBlock) and feat_cache is not None:
|
||||
x = layer(x, feat_cache, feat_idx)
|
||||
else:
|
||||
x = layer(x)
|
||||
|
||||
## head
|
||||
for layer in self.head:
|
||||
if isinstance(layer, CausalConv3d) and feat_cache is not None:
|
||||
idx = feat_idx[0]
|
||||
cache_x = x[:, :, -CACHE_T:, :, :].clone()
|
||||
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
|
||||
# borrow the last frame of the previous chunk so the cache covers two frames
|
||||
cache_x = torch.cat([
|
||||
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
|
||||
cache_x.device), cache_x
|
||||
],
|
||||
dim=2)
|
||||
x = layer(x, feat_cache[idx])
|
||||
feat_cache[idx] = cache_x
|
||||
feat_idx[0] += 1
|
||||
else:
|
||||
x = layer(x)
|
||||
return x
|
||||
|
||||
|
||||
class Decoder3d(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
dim=128,
|
||||
z_dim=4,
|
||||
dim_mult=[1, 2, 4, 4],
|
||||
num_res_blocks=2,
|
||||
attn_scales=[],
|
||||
temperal_upsample=[False, True, True],
|
||||
dropout=0.0):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.z_dim = z_dim
|
||||
self.dim_mult = dim_mult
|
||||
self.num_res_blocks = num_res_blocks
|
||||
self.attn_scales = attn_scales
|
||||
self.temperal_upsample = temperal_upsample
|
||||
|
||||
# dimensions
|
||||
dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
|
||||
scale = 1.0 / 2**(len(dim_mult) - 2)
|
||||
|
||||
# init block
|
||||
self.conv1 = CausalConv3d(z_dim, dims[0], 3, padding=1)
|
||||
|
||||
# middle blocks
|
||||
self.middle = nn.Sequential(ResidualBlock(dims[0], dims[0], dropout),
|
||||
AttentionBlock(dims[0]),
|
||||
ResidualBlock(dims[0], dims[0], dropout))
|
||||
|
||||
# upsample blocks
|
||||
upsamples = []
|
||||
for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
|
||||
# residual (+attention) blocks
|
||||
if i == 1 or i == 2 or i == 3:
|
||||
in_dim = in_dim // 2
|
||||
for _ in range(num_res_blocks + 1):
|
||||
upsamples.append(ResidualBlock(in_dim, out_dim, dropout))
|
||||
if scale in attn_scales:
|
||||
upsamples.append(AttentionBlock(out_dim))
|
||||
in_dim = out_dim
|
||||
|
||||
# upsample block
|
||||
if i != len(dim_mult) - 1:
|
||||
mode = 'upsample3d' if temperal_upsample[i] else 'upsample2d'
|
||||
upsamples.append(Resample(out_dim, mode=mode))
|
||||
scale *= 2.0
|
||||
self.upsamples = nn.Sequential(*upsamples)
|
||||
|
||||
# output blocks
|
||||
self.head = nn.Sequential(RMS_norm(out_dim, images=False), nn.SiLU(),
|
||||
CausalConv3d(out_dim, 3, 3, padding=1))
|
||||
|
||||
def forward(self, x, feat_cache=None, feat_idx=[0]):
|
||||
## conv1
|
||||
if feat_cache is not None:
|
||||
idx = feat_idx[0]
|
||||
cache_x = x[:, :, -CACHE_T:, :, :].clone()
|
||||
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
|
||||
# borrow the last frame of the previous chunk so the cache covers two frames
|
||||
cache_x = torch.cat([
|
||||
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
|
||||
cache_x.device), cache_x
|
||||
],
|
||||
dim=2)
|
||||
x = self.conv1(x, feat_cache[idx])
|
||||
feat_cache[idx] = cache_x
|
||||
feat_idx[0] += 1
|
||||
else:
|
||||
x = self.conv1(x)
|
||||
|
||||
## middle
|
||||
for layer in self.middle:
|
||||
if isinstance(layer, ResidualBlock) and feat_cache is not None:
|
||||
x = layer(x, feat_cache, feat_idx)
|
||||
else:
|
||||
x = layer(x)
|
||||
|
||||
## upsamples
|
||||
for layer in self.upsamples:
|
||||
if feat_cache is not None:
|
||||
x = layer(x, feat_cache, feat_idx)
|
||||
else:
|
||||
x = layer(x)
|
||||
|
||||
## head
|
||||
for layer in self.head:
|
||||
if isinstance(layer, CausalConv3d) and feat_cache is not None:
|
||||
idx = feat_idx[0]
|
||||
cache_x = x[:, :, -CACHE_T:, :, :].clone()
|
||||
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
|
||||
# borrow the last frame of the previous chunk so the cache covers two frames
|
||||
cache_x = torch.cat([
|
||||
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
|
||||
cache_x.device), cache_x
|
||||
],
|
||||
dim=2)
|
||||
x = layer(x, feat_cache[idx])
|
||||
feat_cache[idx] = cache_x
|
||||
feat_idx[0] += 1
|
||||
else:
|
||||
x = layer(x)
|
||||
return x
|
||||
|
||||
|
||||
def count_conv3d(model):
|
||||
count = 0
|
||||
for m in model.modules():
|
||||
if isinstance(m, CausalConv3d):
|
||||
count += 1
|
||||
return count
|
||||
|
||||
|
||||
class VideoVAE_(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
dim=96,
|
||||
z_dim=16,
|
||||
dim_mult=[1, 2, 4, 4],
|
||||
num_res_blocks=2,
|
||||
attn_scales=[],
|
||||
temperal_downsample=[False, True, True],
|
||||
dropout=0.0):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.z_dim = z_dim
|
||||
self.dim_mult = dim_mult
|
||||
self.num_res_blocks = num_res_blocks
|
||||
self.attn_scales = attn_scales
|
||||
self.temperal_downsample = temperal_downsample
|
||||
self.temperal_upsample = temperal_downsample[::-1]
|
||||
|
||||
# modules
|
||||
self.encoder = Encoder3d(dim, z_dim * 2, dim_mult, num_res_blocks,
|
||||
attn_scales, self.temperal_downsample, dropout)
|
||||
self.conv1 = CausalConv3d(z_dim * 2, z_dim * 2, 1)
|
||||
self.conv2 = CausalConv3d(z_dim, z_dim, 1)
|
||||
self.decoder = Decoder3d(dim, z_dim, dim_mult, num_res_blocks,
|
||||
attn_scales, self.temperal_upsample, dropout)
|
||||
|
||||
def forward(self, x):
|
||||
mu, log_var = self.encode(x)
|
||||
z = self.reparameterize(mu, log_var)
|
||||
x_recon = self.decode(z)
|
||||
return x_recon, mu, log_var
|
||||
|
||||
def encode(self, x, scale):
|
||||
self.clear_cache()
|
||||
## cache
|
||||
t = x.shape[2]
|
||||
iter_ = 1 + (t - 1) // 4
|
||||
|
||||
for i in range(iter_):
|
||||
self._enc_conv_idx = [0]
|
||||
if i == 0:
|
||||
out = self.encoder(x[:, :, :1, :, :],
|
||||
feat_cache=self._enc_feat_map,
|
||||
feat_idx=self._enc_conv_idx)
|
||||
else:
|
||||
out_ = self.encoder(x[:, :, 1 + 4 * (i - 1):1 + 4 * i, :, :],
|
||||
feat_cache=self._enc_feat_map,
|
||||
feat_idx=self._enc_conv_idx)
|
||||
out = torch.cat([out, out_], 2)
|
||||
mu, log_var = self.conv1(out).chunk(2, dim=1)
|
||||
if isinstance(scale[0], torch.Tensor):
|
||||
scale = [s.to(dtype=mu.dtype, device=mu.device) for s in scale]
|
||||
mu = (mu - scale[0].view(1, self.z_dim, 1, 1, 1)) * scale[1].view(
|
||||
1, self.z_dim, 1, 1, 1)
|
||||
else:
|
||||
scale = scale.to(dtype=mu.dtype, device=mu.device)
|
||||
mu = (mu - scale[0]) * scale[1]
|
||||
return mu
|
||||
|
||||
def decode(self, z, scale):
|
||||
self.clear_cache()
|
||||
# z: [b,c,t,h,w]
|
||||
if isinstance(scale[0], torch.Tensor):
|
||||
scale = [s.to(dtype=z.dtype, device=z.device) for s in scale]
|
||||
z = z / scale[1].view(1, self.z_dim, 1, 1, 1) + scale[0].view(
|
||||
1, self.z_dim, 1, 1, 1)
|
||||
else:
|
||||
scale = scale.to(dtype=z.dtype, device=z.device)
|
||||
z = z / scale[1] + scale[0]
|
||||
iter_ = z.shape[2]
|
||||
x = self.conv2(z)
|
||||
for i in range(iter_):
|
||||
self._conv_idx = [0]
|
||||
if i == 0:
|
||||
out = self.decoder(x[:, :, i:i + 1, :, :],
|
||||
feat_cache=self._feat_map,
|
||||
feat_idx=self._conv_idx)
|
||||
else:
|
||||
out_ = self.decoder(x[:, :, i:i + 1, :, :],
|
||||
feat_cache=self._feat_map,
|
||||
feat_idx=self._conv_idx)
|
||||
out = torch.cat([out, out_], 2) # may add tensor offload
|
||||
return out
|
||||
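# Worked example of the temporal compression implied by the loops above: the
# encoder consumes the first frame alone and the rest in chunks of 4, so t
# video frames yield 1 + (t - 1) // 4 latent frames, and the decoder expands T
# latent frames back to 4 * T - 3 frames. For t = 17: T = 1 + 16 // 4 = 5 and
# 4 * 5 - 3 = 17 frames are recovered.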
|
||||
def reparameterize(self, mu, log_var):
|
||||
std = torch.exp(0.5 * log_var)
|
||||
eps = torch.randn_like(std)
|
||||
return eps * std + mu
|
||||
|
||||
def sample(self, imgs, deterministic=False):
|
||||
mu, log_var = self.encode(imgs)
|
||||
if deterministic:
|
||||
return mu
|
||||
std = torch.exp(0.5 * log_var.clamp(-30.0, 20.0))
|
||||
return mu + std * torch.randn_like(std)
|
||||
|
||||
def clear_cache(self):
|
||||
self._conv_num = count_conv3d(self.decoder)
|
||||
self._conv_idx = [0]
|
||||
self._feat_map = [None] * self._conv_num
|
||||
# cache encode
|
||||
self._enc_conv_num = count_conv3d(self.encoder)
|
||||
self._enc_conv_idx = [0]
|
||||
self._enc_feat_map = [None] * self._enc_conv_num
|
||||
|
||||
|
||||
class WanXVideoVAE(nn.Module):
|
||||
|
||||
def __init__(self, z_dim=16):
|
||||
super().__init__()
|
||||
|
||||
mean = [
|
||||
-0.7571, -0.7089, -0.9113, 0.1075, -0.1745, 0.9653, -0.1517, 1.5508,
|
||||
0.4134, -0.0715, 0.5517, -0.3632, -0.1922, -0.9497, 0.2503, -0.2921
|
||||
]
|
||||
std = [
|
||||
2.8184, 1.4541, 2.3275, 2.6558, 1.2196, 1.7708, 2.6052, 2.0743,
|
||||
3.2687, 2.1526, 2.8652, 1.5579, 1.6382, 1.1253, 2.8251, 1.9160
|
||||
]
|
||||
self.mean = torch.tensor(mean)
|
||||
self.std = torch.tensor(std)
|
||||
self.scale = [self.mean, 1.0 / self.std]
|
||||
|
||||
# init model
|
||||
self.model = VideoVAE_(z_dim=z_dim).eval().requires_grad_(False)
|
||||
self.upsampling_factor = 8
|
||||
|
||||
|
||||
def build_1d_mask(self, length, left_bound, right_bound, border_width):
|
||||
x = torch.ones((length,))
|
||||
if not left_bound:
|
||||
x[:border_width] = (torch.arange(border_width) + 1) / border_width
|
||||
if not right_bound:
|
||||
x[-border_width:] = torch.flip((torch.arange(border_width) + 1) / border_width, dims=(0,))
|
||||
return x
|
||||
|
||||
|
||||
def build_mask(self, data, is_bound, border_width):
|
||||
_, _, _, H, W = data.shape
|
||||
h = self.build_1d_mask(H, is_bound[0], is_bound[1], border_width[0])
|
||||
w = self.build_1d_mask(W, is_bound[2], is_bound[3], border_width[1])
|
||||
|
||||
h = repeat(h, "H -> H W", H=H, W=W)
|
||||
w = repeat(w, "W -> H W", H=H, W=W)
|
||||
|
||||
mask = torch.stack([h, w]).min(dim=0).values
|
||||
mask = rearrange(mask, "H W -> 1 1 1 H W")
|
||||
return mask
|
||||
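# Worked example: build_1d_mask ramps from 1/border_width up to 1 at any edge
# that is not a true image border, and build_mask takes the per-pixel minimum
# of the H and W ramps, so overlapping tile edges are feathered before
# blending. For instance, build_1d_mask(6, True, False, border_width=3)
# returns tensor([1.0000, 1.0000, 1.0000, 1.0000, 0.6667, 0.3333]).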
|
||||
|
||||
def tiled_decode(self, hidden_states, device, tile_size, tile_stride):
|
||||
_, _, T, H, W = hidden_states.shape
|
||||
size_h, size_w = tile_size
|
||||
stride_h, stride_w = tile_stride
|
||||
|
||||
# Split tasks
|
||||
tasks = []
|
||||
for h in range(0, H, stride_h):
|
||||
if (h-stride_h >= 0 and h-stride_h+size_h >= H): continue
|
||||
for w in range(0, W, stride_w):
|
||||
if (w-stride_w >= 0 and w-stride_w+size_w >= W): continue
|
||||
h_, w_ = h + size_h, w + size_w
|
||||
tasks.append((h, h_, w, w_))
|
||||
|
||||
data_device = "cpu"
|
||||
computation_device = device
|
||||
|
||||
out_T = T * 4 - 3
|
||||
weight = torch.zeros((1, 1, out_T, H * self.upsampling_factor, W * self.upsampling_factor), dtype=hidden_states.dtype, device=data_device)
|
||||
values = torch.zeros((1, 3, out_T, H * self.upsampling_factor, W * self.upsampling_factor), dtype=hidden_states.dtype, device=data_device)
|
||||
|
||||
for h, h_, w, w_ in tqdm(tasks, desc="VAE decoding"):
|
||||
hidden_states_batch = hidden_states[:, :, :, h:h_, w:w_].to(computation_device)
|
||||
hidden_states_batch = self.model.decode(hidden_states_batch, self.scale).to(data_device)
|
||||
|
||||
mask = self.build_mask(
|
||||
hidden_states_batch,
|
||||
is_bound=(h==0, h_>=H, w==0, w_>=W),
|
||||
border_width=((size_h - stride_h) * self.upsampling_factor, (size_w - stride_w) * self.upsampling_factor)
|
||||
).to(dtype=hidden_states.dtype, device=data_device)
|
||||
|
||||
target_h = h * self.upsampling_factor
|
||||
target_w = w * self.upsampling_factor
|
||||
values[
|
||||
:,
|
||||
:,
|
||||
:,
|
||||
target_h:target_h + hidden_states_batch.shape[3],
|
||||
target_w:target_w + hidden_states_batch.shape[4],
|
||||
] += hidden_states_batch * mask
|
||||
weight[
|
||||
:,
|
||||
:,
|
||||
:,
|
||||
target_h: target_h + hidden_states_batch.shape[3],
|
||||
target_w: target_w + hidden_states_batch.shape[4],
|
||||
] += mask
|
||||
values = values / weight
|
||||
values = values.float().clamp_(-1, 1)
|
||||
return values
|
||||
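# A 1-D sketch of the overlap blending performed above: each tile's decoded
# output is accumulated scaled by its feathering mask, the masks themselves are
# accumulated in `weight`, and the final division yields a weighted average in
# overlap regions:
#     values[a:b] += tile_output * mask
#     weight[a:b] += mask
#     values = values / weight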
|
||||
|
||||
def tiled_encode(self, video, device, tile_size, tile_stride):
|
||||
_, _, T, H, W = video.shape
|
||||
size_h, size_w = tile_size
|
||||
stride_h, stride_w = tile_stride
|
||||
|
||||
# Split tasks
|
||||
tasks = []
|
||||
for h in range(0, H, stride_h):
|
||||
if (h-stride_h >= 0 and h-stride_h+size_h >= H): continue
|
||||
for w in range(0, W, stride_w):
|
||||
if (w-stride_w >= 0 and w-stride_w+size_w >= W): continue
|
||||
h_, w_ = h + size_h, w + size_w
|
||||
tasks.append((h, h_, w, w_))
|
||||
|
||||
data_device = "cpu"
|
||||
computation_device = device
|
||||
|
||||
out_T = (T + 3) // 4
|
||||
weight = torch.zeros((1, 1, out_T, H // self.upsampling_factor, W // self.upsampling_factor), dtype=video.dtype, device=data_device)
|
||||
values = torch.zeros((1, 16, out_T, H // self.upsampling_factor, W // self.upsampling_factor), dtype=video.dtype, device=data_device)
|
||||
|
||||
for h, h_, w, w_ in tqdm(tasks, desc="VAE encoding"):
|
||||
hidden_states_batch = video[:, :, :, h:h_, w:w_].to(computation_device)
|
||||
hidden_states_batch = self.model.encode(hidden_states_batch, self.scale).to(data_device)
|
||||
|
||||
mask = self.build_mask(
|
||||
hidden_states_batch,
|
||||
is_bound=(h==0, h_>=H, w==0, w_>=W),
|
||||
border_width=((size_h - stride_h) // self.upsampling_factor, (size_w - stride_w) // self.upsampling_factor)
|
||||
).to(dtype=video.dtype, device=data_device)
|
||||
|
||||
target_h = h // self.upsampling_factor
|
||||
target_w = w // self.upsampling_factor
|
||||
values[
|
||||
:,
|
||||
:,
|
||||
:,
|
||||
target_h:target_h + hidden_states_batch.shape[3],
|
||||
target_w:target_w + hidden_states_batch.shape[4],
|
||||
] += hidden_states_batch * mask
|
||||
weight[
|
||||
:,
|
||||
:,
|
||||
:,
|
||||
target_h: target_h + hidden_states_batch.shape[3],
|
||||
target_w: target_w + hidden_states_batch.shape[4],
|
||||
] += mask
|
||||
values = values / weight
|
||||
values = values.float()
|
||||
return values
|
||||
|
||||
|
||||
def single_encode(self, video, device):
|
||||
video = video.to(device)
|
||||
x = self.model.encode(video, self.scale)
|
||||
return x.float()
|
||||
|
||||
|
||||
def single_decode(self, hidden_state, device):
|
||||
hidden_state = hidden_state.to(device)
|
||||
video = self.model.decode(hidden_state, self.scale)
|
||||
return video.float().clamp_(-1, 1)
|
||||
|
||||
|
||||
def encode(self, videos, device, tiled=False, tile_size=(272, 272), tile_stride=(144, 128)):
|
||||
|
||||
videos = [video.to("cpu") for video in videos]
|
||||
hidden_states = []
|
||||
for video in videos:
|
||||
video = video.unsqueeze(0)
|
||||
if tiled:
|
||||
assert tile_size[0] % self.upsampling_factor == 0 and tile_size[1] % self.upsampling_factor == 0, f"tile_size must be divisible by {self.upsampling_factor}"
|
||||
hidden_state = self.tiled_encode(video, device, tile_size, tile_stride)
|
||||
else:
|
||||
hidden_state = self.single_encode(video, device)
|
||||
hidden_state = hidden_state.squeeze(0)
|
||||
hidden_states.append(hidden_state)
|
||||
return hidden_states
|
||||
|
||||
|
||||
def decode(self, hidden_states, device, tiled=False, tile_size=(34, 34), tile_stride=(18, 16)):
|
||||
hidden_states = [hidden_state.to("cpu") for hidden_state in hidden_states]
|
||||
videos = []
|
||||
for hidden_state in hidden_states:
|
||||
hidden_state = hidden_state.unsqueeze(0)
|
||||
if tiled:
|
||||
video = self.tiled_decode(hidden_state, device, tile_size, tile_stride)
|
||||
else:
|
||||
video = self.single_decode(hidden_state, device)
|
||||
video = video.squeeze(0)
|
||||
videos.append(video)
|
||||
return videos
|
||||
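# A minimal usage sketch, assuming the VAE weights have already been loaded
# (e.g. via state_dict_converter and the ModelManager): encode/decode operate
# on lists of (c, t, h, w) tensors with values in [-1, 1].
#     vae = WanXVideoVAE(z_dim=16).to("cuda")
#     latents = vae.encode([video], device="cuda", tiled=True)  # (16, 5, 32, 32) for a (3, 17, 256, 256) video
#     frames = vae.decode(latents, device="cuda", tiled=True)   # back to (3, 17, 256, 256)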
|
||||
|
||||
@staticmethod
|
||||
def state_dict_converter():
|
||||
return WanXVideoVAEStateDictConverter()
|
||||
|
||||
|
||||
class WanXVideoVAEStateDictConverter:
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def from_civitai(self, state_dict):
|
||||
state_dict_ = {}
|
||||
for name in state_dict['model_state']:
|
||||
state_dict_['model.' + name] = state_dict['model_state'][name]
|
||||
return state_dict_
|
||||
@@ -7,5 +7,8 @@ from .hunyuan_image import HunyuanDiTImagePipeline
|
||||
from .svd_video import SVDVideoPipeline
|
||||
from .flux_image import FluxImagePipeline
|
||||
from .cog_video import CogVideoPipeline
|
||||
from .omnigen_image import OmnigenImagePipeline
|
||||
from .pipeline_runner import SDVideoPipelineRunner
|
||||
from .hunyuan_video import HunyuanVideoPipeline
|
||||
from .step_video import StepVideoPipeline
|
||||
KolorsImagePipeline = SDXLImagePipeline
|
||||
|
||||
@@ -1,19 +1,32 @@
|
||||
import torch
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
from torchvision.transforms import GaussianBlur
|
||||
|
||||
|
||||
|
||||
class BasePipeline(torch.nn.Module):
|
||||
|
||||
def __init__(self, device="cuda", torch_dtype=torch.float16):
|
||||
def __init__(self, device="cuda", torch_dtype=torch.float16, height_division_factor=64, width_division_factor=64):
|
||||
super().__init__()
|
||||
self.device = device
|
||||
self.torch_dtype = torch_dtype
|
||||
self.height_division_factor = height_division_factor
|
||||
self.width_division_factor = width_division_factor
|
||||
self.cpu_offload = False
|
||||
self.model_names = []
|
||||
|
||||
|
||||
def check_resize_height_width(self, height, width):
|
||||
if height % self.height_division_factor != 0:
|
||||
height = (height + self.height_division_factor - 1) // self.height_division_factor * self.height_division_factor
|
||||
print(f"The height cannot be evenly divided by {self.height_division_factor}. We round it up to {height}.")
|
||||
if width % self.width_division_factor != 0:
|
||||
width = (width + self.width_division_factor - 1) // self.width_division_factor * self.width_division_factor
|
||||
print(f"The width cannot be evenly divided by {self.width_division_factor}. We round it up to {width}.")
|
||||
return height, width
|
||||
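# Worked example: with the default division factor of 64, a requested height of
# 720 is rounded up to (720 + 63) // 64 * 64 = 768, while a width of 1280,
# already divisible by 64, is returned unchanged.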
|
||||
|
||||
def preprocess_image(self, image):
|
||||
image = torch.Tensor(np.array(image, dtype=np.float32) * (2 / 255) - 1).permute(2, 0, 1).unsqueeze(0)
|
||||
return image
|
||||
@@ -35,21 +48,30 @@ class BasePipeline(torch.nn.Module):
|
||||
return video
|
||||
|
||||
|
||||
def merge_latents(self, value, latents, masks, scales):
|
||||
height, width = value.shape[-2:]
|
||||
weight = torch.ones_like(value)
|
||||
for latent, mask, scale in zip(latents, masks, scales):
|
||||
mask = self.preprocess_image(mask.resize((width, height))).mean(dim=1, keepdim=True) > 0
|
||||
mask = mask.repeat(1, latent.shape[1], 1, 1)
|
||||
value[mask] += latent[mask] * scale
|
||||
weight[mask] += scale
|
||||
value /= weight
|
||||
def merge_latents(self, value, latents, masks, scales, blur_kernel_size=33, blur_sigma=10.0):
|
||||
if len(latents) > 0:
|
||||
blur = GaussianBlur(kernel_size=blur_kernel_size, sigma=blur_sigma)
|
||||
height, width = value.shape[-2:]
|
||||
weight = torch.ones_like(value)
|
||||
for latent, mask, scale in zip(latents, masks, scales):
|
||||
mask = self.preprocess_image(mask.resize((width, height))).mean(dim=1, keepdim=True) > 0
|
||||
mask = mask.repeat(1, latent.shape[1], 1, 1).to(dtype=latent.dtype, device=latent.device)
|
||||
mask = blur(mask)
|
||||
value += latent * mask * scale
|
||||
weight += mask * scale
|
||||
value /= weight
|
||||
return value
|
||||
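# Note on the blending above: the blurred masks turn a hard region swap into a
# soft weighted average,
#     value = (global + sum_i latent_i * blur(mask_i) * scale_i)
#             / (1 + sum_i blur(mask_i) * scale_i),
# so local-prompt predictions fade smoothly into the global prediction near
# region borders.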
|
||||
|
||||
def control_noise_via_local_prompts(self, prompt_emb_global, prompt_emb_locals, masks, mask_scales, inference_callback):
|
||||
noise_pred_global = inference_callback(prompt_emb_global)
|
||||
noise_pred_locals = [inference_callback(prompt_emb_local) for prompt_emb_local in prompt_emb_locals]
|
||||
def control_noise_via_local_prompts(self, prompt_emb_global, prompt_emb_locals, masks, mask_scales, inference_callback, special_kwargs=None, special_local_kwargs_list=None):
|
||||
if special_kwargs is None:
|
||||
noise_pred_global = inference_callback(prompt_emb_global)
|
||||
else:
|
||||
noise_pred_global = inference_callback(prompt_emb_global, special_kwargs)
|
||||
if special_local_kwargs_list is None:
|
||||
noise_pred_locals = [inference_callback(prompt_emb_local) for prompt_emb_local in prompt_emb_locals]
|
||||
else:
|
||||
noise_pred_locals = [inference_callback(prompt_emb_local, special_kwargs) for prompt_emb_local, special_kwargs in zip(prompt_emb_locals, special_local_kwargs_list)]
|
||||
noise_pred = self.merge_latents(noise_pred_global, noise_pred_locals, masks, mask_scales)
|
||||
return noise_pred
|
||||
|
||||
@@ -65,9 +87,11 @@ class BasePipeline(torch.nn.Module):
|
||||
mask_scales += [100.0] * len(extended_prompt_dict.get("masks", []))
|
||||
return prompt, local_prompts, masks, mask_scales
|
||||
|
||||
|
||||
def enable_cpu_offload(self):
|
||||
self.cpu_offload = True
|
||||
|
||||
|
||||
def load_models_to_device(self, loadmodel_names=[]):
|
||||
# only load models to device if cpu_offload is enabled
|
||||
if not self.cpu_offload:
|
||||
@@ -77,11 +101,27 @@ class BasePipeline(torch.nn.Module):
|
||||
if model_name not in loadmodel_names:
|
||||
model = getattr(self, model_name)
|
||||
if model is not None:
|
||||
model.cpu()
|
||||
if hasattr(model, "vram_management_enabled") and model.vram_management_enabled:
|
||||
for module in model.modules():
|
||||
if hasattr(module, "offload"):
|
||||
module.offload()
|
||||
else:
|
||||
model.cpu()
|
||||
# load the needed models to device
|
||||
for model_name in loadmodel_names:
|
||||
model = getattr(self, model_name)
|
||||
if model is not None:
|
||||
model.to(self.device)
|
||||
if hasattr(model, "vram_management_enabled") and model.vram_management_enabled:
|
||||
for module in model.modules():
|
||||
if hasattr(module, "onload"):
|
||||
module.onload()
|
||||
else:
|
||||
model.to(self.device)
|
||||
# free unused cached GPU memory
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
|
||||
def generate_noise(self, shape, seed=None, device="cpu", dtype=torch.float16):
|
||||
generator = None if seed is None else torch.Generator(device).manual_seed(seed)
|
||||
noise = torch.randn(shape, generator=generator, device=device, dtype=dtype)
|
||||
return noise
|
||||
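# Note: using a seeded generator makes sampling reproducible, e.g.
#     n1 = self.generate_noise((1, 16, 64, 64), seed=0)
#     n2 = self.generate_noise((1, 16, 64, 64), seed=0)
#     assert torch.equal(n1, n2)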
|
||||
@@ -13,7 +13,7 @@ from einops import rearrange
|
||||
class CogVideoPipeline(BasePipeline):
|
||||
|
||||
def __init__(self, device="cuda", torch_dtype=torch.float16):
|
||||
super().__init__(device=device, torch_dtype=torch_dtype)
|
||||
super().__init__(device=device, torch_dtype=torch_dtype, height_division_factor=16, width_division_factor=16)
|
||||
self.scheduler = EnhancedDDIMScheduler(rescale_zero_terminal_snr=True, prediction_type="v_prediction")
|
||||
self.prompter = CogPrompter()
|
||||
# models
|
||||
@@ -73,9 +73,12 @@ class CogVideoPipeline(BasePipeline):
|
||||
tiled=False,
|
||||
tile_size=(60, 90),
|
||||
tile_stride=(30, 45),
|
||||
seed=None,
|
||||
progress_bar_cmd=tqdm,
|
||||
progress_bar_st=None,
|
||||
):
|
||||
height, width = self.check_resize_height_width(height, width)
|
||||
|
||||
# Tiler parameters
|
||||
tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride}
|
||||
|
||||
@@ -83,7 +86,8 @@ class CogVideoPipeline(BasePipeline):
|
||||
self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength)
|
||||
|
||||
# Prepare latent tensors
|
||||
noise = torch.randn((1, 16, num_frames // 4 + 1, height//8, width//8), device="cpu", dtype=self.torch_dtype)
|
||||
noise = self.generate_noise((1, 16, num_frames // 4 + 1, height//8, width//8), seed=seed, device="cpu", dtype=self.torch_dtype)
|
||||
|
||||
if denoising_strength == 1.0:
|
||||
latents = noise.clone()
|
||||
else:
|
||||
|
||||
@@ -139,6 +139,8 @@ def lets_dance_xl(
|
||||
# 0. Text embedding alignment (only for video processing)
|
||||
if encoder_hidden_states.shape[0] != sample.shape[0]:
|
||||
encoder_hidden_states = encoder_hidden_states.repeat(sample.shape[0], 1, 1, 1)
|
||||
if add_text_embeds.shape[0] != sample.shape[0]:
|
||||
add_text_embeds = add_text_embeds.repeat(sample.shape[0], 1)
|
||||
|
||||
# 1. ControlNet
|
||||
controlnet_insert_block_id = 22
|
||||
@@ -204,7 +206,7 @@ def lets_dance_xl(
|
||||
batch_id_ = min(batch_id + unet_batch_size, sample.shape[0])
|
||||
hidden_states, _, _, _ = block(
|
||||
hidden_states_input[batch_id: batch_id_],
|
||||
time_emb,
|
||||
time_emb[batch_id: batch_id_],
|
||||
text_emb[batch_id: batch_id_],
|
||||
res_stack,
|
||||
cross_frame_attention=cross_frame_attention,
|
||||
|
||||
@@ -1,33 +1,144 @@
|
||||
from ..models import ModelManager, FluxDiT, FluxTextEncoder1, FluxTextEncoder2, FluxVAEDecoder, FluxVAEEncoder
|
||||
from ..models import ModelManager, FluxDiT, SD3TextEncoder1, FluxTextEncoder2, FluxVAEDecoder, FluxVAEEncoder, FluxIpAdapter
|
||||
from ..controlnets import FluxMultiControlNetManager, ControlNetUnit, ControlNetConfigUnit, Annotator
|
||||
from ..prompters import FluxPrompter
|
||||
from ..schedulers import FlowMatchScheduler
|
||||
from .base import BasePipeline
|
||||
from typing import List
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
from ..models.tiler import FastTileWorker
|
||||
from transformers import SiglipVisionModel
|
||||
from copy import deepcopy
|
||||
from transformers.models.t5.modeling_t5 import T5LayerNorm, T5DenseActDense, T5DenseGatedActDense
|
||||
from ..models.flux_dit import RMSNorm
|
||||
from ..vram_management import enable_vram_management, AutoWrappedModule, AutoWrappedLinear
|
||||
|
||||
|
||||
class FluxImagePipeline(BasePipeline):
|
||||
|
||||
def __init__(self, device="cuda", torch_dtype=torch.float16):
|
||||
super().__init__(device=device, torch_dtype=torch_dtype)
|
||||
super().__init__(device=device, torch_dtype=torch_dtype, height_division_factor=16, width_division_factor=16)
|
||||
self.scheduler = FlowMatchScheduler()
|
||||
self.prompter = FluxPrompter()
|
||||
# models
|
||||
self.text_encoder_1: FluxTextEncoder1 = None
|
||||
self.text_encoder_1: SD3TextEncoder1 = None
|
||||
self.text_encoder_2: FluxTextEncoder2 = None
|
||||
self.dit: FluxDiT = None
|
||||
self.vae_decoder: FluxVAEDecoder = None
|
||||
self.vae_encoder: FluxVAEEncoder = None
|
||||
self.model_names = ['text_encoder_1', 'text_encoder_2', 'dit', 'vae_decoder', 'vae_encoder']
|
||||
self.controlnet: FluxMultiControlNetManager = None
|
||||
self.ipadapter: FluxIpAdapter = None
|
||||
self.ipadapter_image_encoder: SiglipVisionModel = None
|
||||
self.model_names = ['text_encoder_1', 'text_encoder_2', 'dit', 'vae_decoder', 'vae_encoder', 'controlnet', 'ipadapter', 'ipadapter_image_encoder']
|
||||
|
||||
|
||||
def enable_vram_management(self, num_persistent_param_in_dit=None):
|
||||
dtype = next(iter(self.text_encoder_1.parameters())).dtype
|
||||
enable_vram_management(
|
||||
self.text_encoder_1,
|
||||
module_map = {
|
||||
torch.nn.Linear: AutoWrappedLinear,
|
||||
torch.nn.Embedding: AutoWrappedModule,
|
||||
torch.nn.LayerNorm: AutoWrappedModule,
|
||||
},
|
||||
module_config = dict(
|
||||
offload_dtype=dtype,
|
||||
offload_device="cpu",
|
||||
onload_dtype=dtype,
|
||||
onload_device="cpu",
|
||||
computation_dtype=self.torch_dtype,
|
||||
computation_device=self.device,
|
||||
),
|
||||
)
|
||||
dtype = next(iter(self.text_encoder_2.parameters())).dtype
|
||||
enable_vram_management(
|
||||
self.text_encoder_2,
|
||||
module_map = {
|
||||
torch.nn.Linear: AutoWrappedLinear,
|
||||
torch.nn.Embedding: AutoWrappedModule,
|
||||
T5LayerNorm: AutoWrappedModule,
|
||||
T5DenseActDense: AutoWrappedModule,
|
||||
T5DenseGatedActDense: AutoWrappedModule,
|
||||
},
|
||||
module_config = dict(
|
||||
offload_dtype=dtype,
|
||||
offload_device="cpu",
|
||||
onload_dtype=dtype,
|
||||
onload_device="cpu",
|
||||
computation_dtype=self.torch_dtype,
|
||||
computation_device=self.device,
|
||||
),
|
||||
)
|
||||
dtype = next(iter(self.dit.parameters())).dtype
|
||||
enable_vram_management(
|
||||
self.dit,
|
||||
module_map = {
|
||||
RMSNorm: AutoWrappedModule,
|
||||
torch.nn.Linear: AutoWrappedLinear,
|
||||
},
|
||||
module_config = dict(
|
||||
offload_dtype=dtype,
|
||||
offload_device="cpu",
|
||||
onload_dtype=dtype,
|
||||
onload_device="cuda",
|
||||
computation_dtype=self.torch_dtype,
|
||||
computation_device=self.device,
|
||||
),
|
||||
max_num_param=num_persistent_param_in_dit,
|
||||
overflow_module_config = dict(
|
||||
offload_dtype=dtype,
|
||||
offload_device="cpu",
|
||||
onload_dtype=dtype,
|
||||
onload_device="cpu",
|
||||
computation_dtype=self.torch_dtype,
|
||||
computation_device=self.device,
|
||||
),
|
||||
)
|
||||
dtype = next(iter(self.vae_decoder.parameters())).dtype
|
||||
enable_vram_management(
|
||||
self.vae_decoder,
|
||||
module_map = {
|
||||
torch.nn.Linear: AutoWrappedLinear,
|
||||
torch.nn.Conv2d: AutoWrappedModule,
|
||||
torch.nn.GroupNorm: AutoWrappedModule,
|
||||
},
|
||||
module_config = dict(
|
||||
offload_dtype=dtype,
|
||||
offload_device="cpu",
|
||||
onload_dtype=dtype,
|
||||
onload_device="cpu",
|
||||
computation_dtype=self.torch_dtype,
|
||||
computation_device=self.device,
|
||||
),
|
||||
)
|
||||
dtype = next(iter(self.vae_encoder.parameters())).dtype
|
||||
enable_vram_management(
|
||||
self.vae_encoder,
|
||||
module_map = {
|
||||
torch.nn.Linear: AutoWrappedLinear,
|
||||
torch.nn.Conv2d: AutoWrappedModule,
|
||||
torch.nn.GroupNorm: AutoWrappedModule,
|
||||
},
|
||||
module_config = dict(
|
||||
offload_dtype=dtype,
|
||||
offload_device="cpu",
|
||||
onload_dtype=dtype,
|
||||
onload_device="cpu",
|
||||
computation_dtype=self.torch_dtype,
|
||||
computation_device=self.device,
|
||||
),
|
||||
)
|
||||
self.enable_cpu_offload()
|
||||
|
||||
|
||||
def denoising_model(self):
|
||||
return self.dit
|
||||
|
||||
|
||||
def fetch_models(self, model_manager: ModelManager, prompt_refiner_classes=[], prompt_extender_classes=[]):
|
||||
self.text_encoder_1 = model_manager.fetch_model("flux_text_encoder_1")
|
||||
def fetch_models(self, model_manager: ModelManager, controlnet_config_units: List[ControlNetConfigUnit]=[], prompt_refiner_classes=[], prompt_extender_classes=[]):
|
||||
self.text_encoder_1 = model_manager.fetch_model("sd3_text_encoder_1")
|
||||
self.text_encoder_2 = model_manager.fetch_model("flux_text_encoder_2")
|
||||
self.dit = model_manager.fetch_model("flux_dit")
|
||||
self.vae_decoder = model_manager.fetch_model("flux_vae_decoder")
|
||||
@@ -36,14 +147,29 @@ class FluxImagePipeline(BasePipeline):
|
||||
self.prompter.load_prompt_refiners(model_manager, prompt_refiner_classes)
|
||||
self.prompter.load_prompt_extenders(model_manager, prompt_extender_classes)
|
||||
|
||||
# ControlNets
|
||||
controlnet_units = []
|
||||
for config in controlnet_config_units:
|
||||
controlnet_unit = ControlNetUnit(
|
||||
Annotator(config.processor_id, device=self.device, skip_processor=config.skip_processor),
|
||||
model_manager.fetch_model("flux_controlnet", config.model_path),
|
||||
config.scale
|
||||
)
|
||||
controlnet_units.append(controlnet_unit)
|
||||
self.controlnet = FluxMultiControlNetManager(controlnet_units)
|
||||
|
||||
# IP-Adapters
|
||||
self.ipadapter = model_manager.fetch_model("flux_ipadapter")
|
||||
self.ipadapter_image_encoder = model_manager.fetch_model("siglip_vision_model")
|
||||
|
||||
|
||||
@staticmethod
|
||||
def from_model_manager(model_manager: ModelManager, prompt_refiner_classes=[], prompt_extender_classes=[], device=None):
|
||||
def from_model_manager(model_manager: ModelManager, controlnet_config_units: List[ControlNetConfigUnit]=[], prompt_refiner_classes=[], prompt_extender_classes=[], device=None, torch_dtype=None):
|
||||
pipe = FluxImagePipeline(
|
||||
device=model_manager.device if device is None else device,
|
||||
torch_dtype=model_manager.torch_dtype,
|
||||
torch_dtype=model_manager.torch_dtype if torch_dtype is None else torch_dtype,
|
||||
)
|
||||
pipe.fetch_models(model_manager, prompt_refiner_classes,prompt_extender_classes)
|
||||
pipe.fetch_models(model_manager, controlnet_config_units, prompt_refiner_classes, prompt_extender_classes)
|
||||
return pipe
|
||||
|
||||
|
||||
@@ -58,40 +184,216 @@ class FluxImagePipeline(BasePipeline):
|
||||
return image
|
||||
|
||||
|
||||
def encode_prompt(self, prompt, positive=True):
|
||||
def encode_prompt(self, prompt, positive=True, t5_sequence_length=512):
|
||||
prompt_emb, pooled_prompt_emb, text_ids = self.prompter.encode_prompt(
|
||||
prompt, device=self.device, positive=positive
|
||||
prompt, device=self.device, positive=positive, t5_sequence_length=t5_sequence_length
|
||||
)
|
||||
return {"prompt_emb": prompt_emb, "pooled_prompt_emb": pooled_prompt_emb, "text_ids": text_ids}
|
||||
|
||||
|
||||
def prepare_extra_input(self, latents=None, guidance=0.0):
|
||||
def prepare_extra_input(self, latents=None, guidance=1.0):
|
||||
latent_image_ids = self.dit.prepare_image_ids(latents)
|
||||
guidance = torch.Tensor([guidance] * latents.shape[0]).to(device=latents.device, dtype=latents.dtype)
|
||||
return {"image_ids": latent_image_ids, "guidance": guidance}
|
||||
|
||||
|
||||
def apply_controlnet_mask_on_latents(self, latents, mask):
|
||||
mask = (self.preprocess_image(mask) + 1) / 2
|
||||
mask = mask.mean(dim=1, keepdim=True)
|
||||
mask = mask.to(dtype=self.torch_dtype, device=self.device)
|
||||
mask = 1 - torch.nn.functional.interpolate(mask, size=latents.shape[-2:])
|
||||
latents = torch.concat([latents, mask], dim=1)
|
||||
return latents
|
||||
|
||||
|
||||
def apply_controlnet_mask_on_image(self, image, mask):
|
||||
mask = mask.resize(image.size)
|
||||
mask = self.preprocess_image(mask).mean(dim=[0, 1])
|
||||
image = np.array(image)
|
||||
image[mask > 0] = 0
|
||||
image = Image.fromarray(image)
|
||||
return image
|
||||
|
||||
|
||||
def prepare_controlnet_input(self, controlnet_image, controlnet_inpaint_mask, tiler_kwargs):
|
||||
if isinstance(controlnet_image, Image.Image):
|
||||
controlnet_image = [controlnet_image] * len(self.controlnet.processors)
|
||||
|
||||
controlnet_frames = []
|
||||
for i in range(len(self.controlnet.processors)):
|
||||
# image annotator
|
||||
image = self.controlnet.process_image(controlnet_image[i], processor_id=i)[0]
|
||||
if controlnet_inpaint_mask is not None and self.controlnet.processors[i].processor_id == "inpaint":
|
||||
image = self.apply_controlnet_mask_on_image(image, controlnet_inpaint_mask)
|
||||
|
||||
# image to tensor
|
||||
image = self.preprocess_image(image).to(device=self.device, dtype=self.torch_dtype)
|
||||
|
||||
# vae encoder
|
||||
image = self.encode_image(image, **tiler_kwargs)
|
||||
if controlnet_inpaint_mask is not None and self.controlnet.processors[i].processor_id == "inpaint":
|
||||
image = self.apply_controlnet_mask_on_latents(image, controlnet_inpaint_mask)
|
||||
|
||||
# store it
|
||||
controlnet_frames.append(image)
|
||||
return controlnet_frames
|
||||
|
||||
|
||||
def prepare_ipadapter_inputs(self, images, height=384, width=384):
|
||||
images = [image.convert("RGB").resize((width, height), resample=3) for image in images]
|
||||
images = [self.preprocess_image(image).to(device=self.device, dtype=self.torch_dtype) for image in images]
|
||||
return torch.cat(images, dim=0)
|
||||
|
||||
|
||||
def inpaint_fusion(self, latents, inpaint_latents, pred_noise, fg_mask, bg_mask, progress_id, background_weight=0.):
|
||||
# inpaint noise
|
||||
inpaint_noise = (latents - inpaint_latents) / self.scheduler.sigmas[progress_id]
|
||||
# merge noise
|
||||
weight = torch.ones_like(inpaint_noise)
|
||||
inpaint_noise[fg_mask] = pred_noise[fg_mask]
|
||||
inpaint_noise[bg_mask] += pred_noise[bg_mask] * background_weight
|
||||
weight[bg_mask] += background_weight
|
||||
inpaint_noise /= weight
|
||||
return inpaint_noise
|
||||
|
||||
|
||||
def preprocess_masks(self, masks, height, width, dim):
|
||||
out_masks = []
|
||||
for mask in masks:
|
||||
mask = self.preprocess_image(mask.resize((width, height), resample=Image.NEAREST)).mean(dim=1, keepdim=True) > 0
|
||||
mask = mask.repeat(1, dim, 1, 1).to(device=self.device, dtype=self.torch_dtype)
|
||||
out_masks.append(mask)
|
||||
return out_masks
|
||||
|
||||
|
||||
def prepare_entity_inputs(self, entity_prompts, entity_masks, width, height, t5_sequence_length=512, enable_eligen_inpaint=False):
|
||||
fg_mask, bg_mask = None, None
|
||||
if enable_eligen_inpaint:
|
||||
masks_ = deepcopy(entity_masks)
|
||||
fg_masks = torch.cat([self.preprocess_image(mask.resize((width//8, height//8))).mean(dim=1, keepdim=True) for mask in masks_])
|
||||
fg_masks = (fg_masks > 0).float()
|
||||
fg_mask = fg_masks.sum(dim=0, keepdim=True).repeat(1, 16, 1, 1) > 0
|
||||
bg_mask = ~fg_mask
|
||||
entity_masks = self.preprocess_masks(entity_masks, height//8, width//8, 1)
|
||||
entity_masks = torch.cat(entity_masks, dim=0).unsqueeze(0) # b, n_mask, c, h, w
|
||||
entity_prompts = self.encode_prompt(entity_prompts, t5_sequence_length=t5_sequence_length)['prompt_emb'].unsqueeze(0)
|
||||
return entity_prompts, entity_masks, fg_mask, bg_mask
|
||||
|
||||
|
||||
def prepare_latents(self, input_image, height, width, seed, tiled, tile_size, tile_stride):
|
||||
if input_image is not None:
|
||||
self.load_models_to_device(['vae_encoder'])
|
||||
image = self.preprocess_image(input_image).to(device=self.device, dtype=self.torch_dtype)
|
||||
input_latents = self.encode_image(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
|
||||
noise = self.generate_noise((1, 16, height//8, width//8), seed=seed, device=self.device, dtype=self.torch_dtype)
|
||||
latents = self.scheduler.add_noise(input_latents, noise, timestep=self.scheduler.timesteps[0])
|
||||
else:
|
||||
latents = self.generate_noise((1, 16, height//8, width//8), seed=seed, device=self.device, dtype=self.torch_dtype)
|
||||
input_latents = None
|
||||
return latents, input_latents
|
||||
|
||||
|
||||
def prepare_ipadapter(self, ipadapter_images, ipadapter_scale):
|
||||
if ipadapter_images is not None:
|
||||
self.load_models_to_device(['ipadapter_image_encoder'])
|
||||
ipadapter_images = self.prepare_ipadapter_inputs(ipadapter_images)
|
||||
ipadapter_image_encoding = self.ipadapter_image_encoder(ipadapter_images).pooler_output
|
||||
self.load_models_to_device(['ipadapter'])
|
||||
ipadapter_kwargs_list_posi = {"ipadapter_kwargs_list": self.ipadapter(ipadapter_image_encoding, scale=ipadapter_scale)}
|
||||
ipadapter_kwargs_list_nega = {"ipadapter_kwargs_list": self.ipadapter(torch.zeros_like(ipadapter_image_encoding))}
|
||||
else:
|
||||
ipadapter_kwargs_list_posi, ipadapter_kwargs_list_nega = {"ipadapter_kwargs_list": {}}, {"ipadapter_kwargs_list": {}}
|
||||
return ipadapter_kwargs_list_posi, ipadapter_kwargs_list_nega
|
||||
|
||||
|
||||
def prepare_controlnet(self, controlnet_image, masks, controlnet_inpaint_mask, tiler_kwargs, enable_controlnet_on_negative):
|
||||
if controlnet_image is not None:
|
||||
self.load_models_to_device(['vae_encoder'])
|
||||
controlnet_kwargs_posi = {"controlnet_frames": self.prepare_controlnet_input(controlnet_image, controlnet_inpaint_mask, tiler_kwargs)}
|
||||
if len(masks) > 0 and controlnet_inpaint_mask is not None:
|
||||
print("The controlnet_inpaint_mask will be overridden by masks.")
|
||||
local_controlnet_kwargs = [{"controlnet_frames": self.prepare_controlnet_input(controlnet_image, mask, tiler_kwargs)} for mask in masks]
|
||||
else:
|
||||
local_controlnet_kwargs = None
|
||||
else:
|
||||
controlnet_kwargs_posi, local_controlnet_kwargs = {"controlnet_frames": None}, [{}] * len(masks)
|
||||
controlnet_kwargs_nega = controlnet_kwargs_posi if enable_controlnet_on_negative else {}
|
||||
return controlnet_kwargs_posi, controlnet_kwargs_nega, local_controlnet_kwargs
|
||||
|
||||
|
||||
def prepare_eligen(self, prompt_emb_nega, eligen_entity_prompts, eligen_entity_masks, width, height, t5_sequence_length, enable_eligen_inpaint, enable_eligen_on_negative, cfg_scale):
|
||||
if eligen_entity_masks is not None:
|
||||
entity_prompt_emb_posi, entity_masks_posi, fg_mask, bg_mask = self.prepare_entity_inputs(eligen_entity_prompts, eligen_entity_masks, width, height, t5_sequence_length, enable_eligen_inpaint)
|
||||
if enable_eligen_on_negative and cfg_scale != 1.0:
|
||||
entity_prompt_emb_nega = prompt_emb_nega['prompt_emb'].unsqueeze(1).repeat(1, entity_masks_posi.shape[1], 1, 1)
|
||||
entity_masks_nega = entity_masks_posi
|
||||
else:
|
||||
entity_prompt_emb_nega, entity_masks_nega = None, None
|
||||
else:
|
||||
entity_prompt_emb_posi, entity_masks_posi, entity_prompt_emb_nega, entity_masks_nega = None, None, None, None
|
||||
fg_mask, bg_mask = None, None
|
||||
eligen_kwargs_posi = {"entity_prompt_emb": entity_prompt_emb_posi, "entity_masks": entity_masks_posi}
|
||||
eligen_kwargs_nega = {"entity_prompt_emb": entity_prompt_emb_nega, "entity_masks": entity_masks_nega}
|
||||
return eligen_kwargs_posi, eligen_kwargs_nega, fg_mask, bg_mask
|
||||
|
||||
|
||||
def prepare_prompts(self, prompt, local_prompts, masks, mask_scales, t5_sequence_length, negative_prompt, cfg_scale):
|
||||
# Extend prompt
|
||||
self.load_models_to_device(['text_encoder_1', 'text_encoder_2'])
|
||||
prompt, local_prompts, masks, mask_scales = self.extend_prompt(prompt, local_prompts, masks, mask_scales)
|
||||
|
||||
# Encode prompts
|
||||
prompt_emb_posi = self.encode_prompt(prompt, t5_sequence_length=t5_sequence_length)
|
||||
prompt_emb_nega = self.encode_prompt(negative_prompt, positive=False, t5_sequence_length=t5_sequence_length) if cfg_scale != 1.0 else None
|
||||
prompt_emb_locals = [self.encode_prompt(prompt_local, t5_sequence_length=t5_sequence_length) for prompt_local in local_prompts]
|
||||
return prompt_emb_posi, prompt_emb_nega, prompt_emb_locals
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(
|
||||
self,
|
||||
# Prompt
|
||||
prompt,
|
||||
local_prompts= None,
|
||||
masks= None,
|
||||
mask_scales= None,
|
||||
negative_prompt="",
|
||||
cfg_scale=1.0,
|
||||
embedded_guidance=0.0,
|
||||
embedded_guidance=3.5,
|
||||
t5_sequence_length=512,
|
||||
# Image
|
||||
input_image=None,
|
||||
denoising_strength=1.0,
|
||||
height=1024,
|
||||
width=1024,
|
||||
seed=None,
|
||||
# Steps
|
||||
num_inference_steps=30,
|
||||
# local prompts
|
||||
local_prompts=(),
|
||||
masks=(),
|
||||
mask_scales=(),
|
||||
# ControlNet
|
||||
controlnet_image=None,
|
||||
controlnet_inpaint_mask=None,
|
||||
enable_controlnet_on_negative=False,
|
||||
# IP-Adapter
|
||||
ipadapter_images=None,
|
||||
ipadapter_scale=1.0,
|
||||
# EliGen
|
||||
eligen_entity_prompts=None,
|
||||
eligen_entity_masks=None,
|
||||
enable_eligen_on_negative=False,
|
||||
enable_eligen_inpaint=False,
|
||||
# TeaCache
|
||||
tea_cache_l1_thresh=None,
|
||||
# Tile
|
||||
tiled=False,
|
||||
tile_size=128,
|
||||
tile_stride=64,
|
||||
# Progress bar
|
||||
progress_bar_cmd=tqdm,
|
||||
progress_bar_st=None,
|
||||
):
|
||||
height, width = self.check_resize_height_width(height, width)
|
||||
|
||||
# Tiler parameters
|
||||
tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride}
|
||||
|
||||
@@ -99,41 +401,53 @@ class FluxImagePipeline(BasePipeline):
|
||||
self.scheduler.set_timesteps(num_inference_steps, denoising_strength)
|
||||
|
||||
# Prepare latent tensors
|
||||
if input_image is not None:
|
||||
self.load_models_to_device(['vae_encoder'])
|
||||
image = self.preprocess_image(input_image).to(device=self.device, dtype=self.torch_dtype)
|
||||
latents = self.encode_image(image, **tiler_kwargs)
|
||||
noise = torch.randn((1, 16, height//8, width//8), device=self.device, dtype=self.torch_dtype)
|
||||
latents = self.scheduler.add_noise(latents, noise, timestep=self.scheduler.timesteps[0])
|
||||
else:
|
||||
latents = torch.randn((1, 16, height//8, width//8), device=self.device, dtype=self.torch_dtype)
|
||||
latents, input_latents = self.prepare_latents(input_image, height, width, seed, tiled, tile_size, tile_stride)
|
||||
|
||||
# Extend prompt
|
||||
self.load_models_to_device(['text_encoder_1', 'text_encoder_2'])
|
||||
prompt, local_prompts, masks, mask_scales = self.extend_prompt(prompt, local_prompts, masks, mask_scales)
|
||||
|
||||
# Encode prompts
|
||||
prompt_emb_posi = self.encode_prompt(prompt, positive=True)
|
||||
if cfg_scale != 1.0:
|
||||
prompt_emb_nega = self.encode_prompt(negative_prompt, positive=False)
|
||||
prompt_emb_locals = [self.encode_prompt(prompt_local) for prompt_local in local_prompts]
|
||||
# Prompt
|
||||
prompt_emb_posi, prompt_emb_nega, prompt_emb_locals = self.prepare_prompts(prompt, local_prompts, masks, mask_scales, t5_sequence_length, negative_prompt, cfg_scale)
|
||||
|
||||
# Extra input
|
||||
extra_input = self.prepare_extra_input(latents, guidance=embedded_guidance)
|
||||
|
||||
# Entity control
|
||||
eligen_kwargs_posi, eligen_kwargs_nega, fg_mask, bg_mask = self.prepare_eligen(prompt_emb_nega, eligen_entity_prompts, eligen_entity_masks, width, height, t5_sequence_length, enable_eligen_inpaint, enable_eligen_on_negative, cfg_scale)
|
||||
|
||||
# IP-Adapter
|
||||
ipadapter_kwargs_list_posi, ipadapter_kwargs_list_nega = self.prepare_ipadapter(ipadapter_images, ipadapter_scale)
|
||||
|
||||
# ControlNets
|
||||
controlnet_kwargs_posi, controlnet_kwargs_nega, local_controlnet_kwargs = self.prepare_controlnet(controlnet_image, masks, controlnet_inpaint_mask, tiler_kwargs, enable_controlnet_on_negative)
|
||||
|
||||
# TeaCache
|
||||
tea_cache_kwargs = {"tea_cache": TeaCache(num_inference_steps, rel_l1_thresh=tea_cache_l1_thresh) if tea_cache_l1_thresh is not None else None}
|
||||
|
||||
# Denoise
|
||||
self.load_models_to_device(['dit'])
|
||||
self.load_models_to_device(['dit', 'controlnet'])
|
||||
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
|
||||
timestep = timestep.unsqueeze(0).to(self.device)
|
||||
|
||||
# Classifier-free guidance
|
||||
inference_callback = lambda prompt_emb_posi: self.dit(
|
||||
latents, timestep=timestep, **prompt_emb_posi, **tiler_kwargs, **extra_input
|
||||
# Positive side
|
||||
inference_callback = lambda prompt_emb_posi, controlnet_kwargs: lets_dance_flux(
|
||||
dit=self.dit, controlnet=self.controlnet,
|
||||
hidden_states=latents, timestep=timestep,
|
||||
**prompt_emb_posi, **tiler_kwargs, **extra_input, **controlnet_kwargs, **ipadapter_kwargs_list_posi, **eligen_kwargs_posi, **tea_cache_kwargs,
|
||||
)
|
||||
noise_pred_posi = self.control_noise_via_local_prompts(prompt_emb_posi, prompt_emb_locals, masks, mask_scales, inference_callback)
|
||||
noise_pred_posi = self.control_noise_via_local_prompts(
|
||||
prompt_emb_posi, prompt_emb_locals, masks, mask_scales, inference_callback,
|
||||
special_kwargs=controlnet_kwargs_posi, special_local_kwargs_list=local_controlnet_kwargs
|
||||
)
|
||||
|
||||
# Inpaint
|
||||
if enable_eligen_inpaint:
|
||||
noise_pred_posi = self.inpaint_fusion(latents, input_latents, noise_pred_posi, fg_mask, bg_mask, progress_id)
|
||||
|
||||
# Classifier-free guidance
|
||||
if cfg_scale != 1.0:
|
||||
noise_pred_nega = self.dit(
|
||||
latents, timestep=timestep, **prompt_emb_nega, **tiler_kwargs, **extra_input
|
||||
# Negative side
|
||||
noise_pred_nega = lets_dance_flux(
|
||||
dit=self.dit, controlnet=self.controlnet,
|
||||
hidden_states=latents, timestep=timestep,
|
||||
**prompt_emb_nega, **tiler_kwargs, **extra_input, **controlnet_kwargs_nega, **ipadapter_kwargs_list_nega, **eligen_kwargs_nega,
|
||||
)
|
||||
noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
|
||||
else:
|
||||
@@ -148,8 +462,185 @@ class FluxImagePipeline(BasePipeline):
|
||||
|
||||
# Decode image
|
||||
self.load_models_to_device(['vae_decoder'])
|
||||
image = self.decode_image(latents, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
|
||||
image = self.decode_image(latents, **tiler_kwargs)
|
||||
|
||||
# Offload all models
|
||||
self.load_models_to_device([])
|
||||
return image
|
||||
|
||||
|
||||
class TeaCache:
    def __init__(self, num_inference_steps, rel_l1_thresh):
        self.num_inference_steps = num_inference_steps
        self.step = 0
        self.accumulated_rel_l1_distance = 0
        self.previous_modulated_input = None
        self.rel_l1_thresh = rel_l1_thresh
        self.previous_residual = None
        self.previous_hidden_states = None

    def check(self, dit: FluxDiT, hidden_states, conditioning):
        inp = hidden_states.clone()
        temb_ = conditioning.clone()
        modulated_inp, _, _, _, _ = dit.blocks[0].norm1_a(inp, emb=temb_)
        if self.step == 0 or self.step == self.num_inference_steps - 1:
            should_calc = True
            self.accumulated_rel_l1_distance = 0
        else:
            coefficients = [4.98651651e+02, -2.83781631e+02, 5.58554382e+01, -3.82021401e+00, 2.64230861e-01]
            rescale_func = np.poly1d(coefficients)
            self.accumulated_rel_l1_distance += rescale_func(((modulated_inp - self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean()).cpu().item())
            if self.accumulated_rel_l1_distance < self.rel_l1_thresh:
                should_calc = False
            else:
                should_calc = True
                self.accumulated_rel_l1_distance = 0
        self.previous_modulated_input = modulated_inp
        self.step += 1
        if self.step == self.num_inference_steps:
            self.step = 0
        if should_calc:
            self.previous_hidden_states = hidden_states.clone()
        return not should_calc

    def store(self, hidden_states):
        self.previous_residual = hidden_states - self.previous_hidden_states
        self.previous_hidden_states = None

    def update(self, hidden_states):
        hidden_states = hidden_states + self.previous_residual
        return hidden_states
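The class above is easiest to read as a three-call protocol. A hedged sketch of the intended call order, written against the names used in lets_dance_flux below; dit, hidden_states and conditioning are assumed to come from an in-progress forward pass, and run_all_blocks is a hypothetical stand-in for the joint and single block stacks:

# Sketch only, not part of flux_image.py.
if tea_cache.check(dit, hidden_states, conditioning):
    # Accumulated relative change since the last full pass is still below
    # rel_l1_thresh: reuse the cached residual instead of running the blocks.
    hidden_states = tea_cache.update(hidden_states)
else:
    hidden_states = run_all_blocks(hidden_states)  # hypothetical full transformer pass
    tea_cache.store(hidden_states)                 # residual = output - snapshot taken in check()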
def lets_dance_flux(
|
||||
dit: FluxDiT,
|
||||
controlnet: FluxMultiControlNetManager = None,
|
||||
hidden_states=None,
|
||||
timestep=None,
|
||||
prompt_emb=None,
|
||||
pooled_prompt_emb=None,
|
||||
guidance=None,
|
||||
text_ids=None,
|
||||
image_ids=None,
|
||||
controlnet_frames=None,
|
||||
tiled=False,
|
||||
tile_size=128,
|
||||
tile_stride=64,
|
||||
entity_prompt_emb=None,
|
||||
entity_masks=None,
|
||||
ipadapter_kwargs_list={},
|
||||
tea_cache: TeaCache = None,
|
||||
**kwargs
|
||||
):
|
||||
if tiled:
|
||||
def flux_forward_fn(hl, hr, wl, wr):
|
||||
tiled_controlnet_frames = [f[:, :, hl: hr, wl: wr] for f in controlnet_frames] if controlnet_frames is not None else None
|
||||
return lets_dance_flux(
|
||||
dit=dit,
|
||||
controlnet=controlnet,
|
||||
hidden_states=hidden_states[:, :, hl: hr, wl: wr],
|
||||
timestep=timestep,
|
||||
prompt_emb=prompt_emb,
|
||||
pooled_prompt_emb=pooled_prompt_emb,
|
||||
guidance=guidance,
|
||||
text_ids=text_ids,
|
||||
image_ids=None,
|
||||
controlnet_frames=tiled_controlnet_frames,
|
||||
tiled=False,
|
||||
**kwargs
|
||||
)
|
||||
return FastTileWorker().tiled_forward(
|
||||
flux_forward_fn,
|
||||
hidden_states,
|
||||
tile_size=tile_size,
|
||||
tile_stride=tile_stride,
|
||||
tile_device=hidden_states.device,
|
||||
tile_dtype=hidden_states.dtype
|
||||
)
|
||||
|
||||
|
||||
# ControlNet
|
||||
if controlnet is not None and controlnet_frames is not None:
|
||||
controlnet_extra_kwargs = {
|
||||
"hidden_states": hidden_states,
|
||||
"timestep": timestep,
|
||||
"prompt_emb": prompt_emb,
|
||||
"pooled_prompt_emb": pooled_prompt_emb,
|
||||
"guidance": guidance,
|
||||
"text_ids": text_ids,
|
||||
"image_ids": image_ids,
|
||||
"tiled": tiled,
|
||||
"tile_size": tile_size,
|
||||
"tile_stride": tile_stride,
|
||||
}
|
||||
controlnet_res_stack, controlnet_single_res_stack = controlnet(
|
||||
controlnet_frames, **controlnet_extra_kwargs
|
||||
)
|
||||
|
||||
if image_ids is None:
|
||||
image_ids = dit.prepare_image_ids(hidden_states)
|
||||
|
||||
conditioning = dit.time_embedder(timestep, hidden_states.dtype) + dit.pooled_text_embedder(pooled_prompt_emb)
|
||||
if dit.guidance_embedder is not None:
|
||||
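# Added note: the guidance embedder appears to share the sinusoidal timestep embedding,
# so the user-facing embedded guidance value (e.g. 3.5) is rescaled to the 0-1000 range
# used for timesteps before being embedded.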
guidance = guidance * 1000
|
||||
conditioning = conditioning + dit.guidance_embedder(guidance, hidden_states.dtype)
|
||||
|
||||
height, width = hidden_states.shape[-2:]
|
||||
hidden_states = dit.patchify(hidden_states)
|
||||
hidden_states = dit.x_embedder(hidden_states)
|
||||
|
||||
if entity_prompt_emb is not None and entity_masks is not None:
|
||||
prompt_emb, image_rotary_emb, attention_mask = dit.process_entity_masks(hidden_states, prompt_emb, entity_prompt_emb, entity_masks, text_ids, image_ids)
|
||||
else:
|
||||
prompt_emb = dit.context_embedder(prompt_emb)
|
||||
image_rotary_emb = dit.pos_embedder(torch.cat((text_ids, image_ids), dim=1))
|
||||
attention_mask = None
|
||||
|
||||
# TeaCache
|
||||
if tea_cache is not None:
|
||||
tea_cache_update = tea_cache.check(dit, hidden_states, conditioning)
|
||||
else:
|
||||
tea_cache_update = False
|
||||
|
||||
if tea_cache_update:
|
||||
hidden_states = tea_cache.update(hidden_states)
|
||||
else:
|
||||
# Joint Blocks
|
||||
for block_id, block in enumerate(dit.blocks):
|
||||
hidden_states, prompt_emb = block(
|
||||
hidden_states,
|
||||
prompt_emb,
|
||||
conditioning,
|
||||
image_rotary_emb,
|
||||
attention_mask,
|
||||
ipadapter_kwargs_list=ipadapter_kwargs_list.get(block_id, None)
|
||||
)
|
||||
# ControlNet
|
||||
if controlnet is not None and controlnet_frames is not None:
|
||||
hidden_states = hidden_states + controlnet_res_stack[block_id]
|
||||
|
||||
# Single Blocks
|
||||
hidden_states = torch.cat([prompt_emb, hidden_states], dim=1)
|
||||
num_joint_blocks = len(dit.blocks)
|
||||
for block_id, block in enumerate(dit.single_blocks):
|
||||
hidden_states, prompt_emb = block(
|
||||
hidden_states,
|
||||
prompt_emb,
|
||||
conditioning,
|
||||
image_rotary_emb,
|
||||
attention_mask,
|
||||
ipadapter_kwargs_list=ipadapter_kwargs_list.get(block_id + num_joint_blocks, None)
|
||||
)
|
||||
# ControlNet
|
||||
if controlnet is not None and controlnet_frames is not None:
|
||||
hidden_states[:, prompt_emb.shape[1]:] = hidden_states[:, prompt_emb.shape[1]:] + controlnet_single_res_stack[block_id]
|
||||
hidden_states = hidden_states[:, prompt_emb.shape[1]:]
|
||||
|
||||
if tea_cache is not None:
|
||||
tea_cache.store(hidden_states)
|
||||
|
||||
hidden_states = dit.final_norm_out(hidden_states, conditioning)
|
||||
hidden_states = dit.final_proj_out(hidden_states)
|
||||
hidden_states = dit.unpatchify(hidden_states, height, width)
|
||||
|
||||
return hidden_states
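For context, a hedged end-to-end usage sketch of the FLUX path with the new tea_cache_l1_thresh switch. The import path and file layout are assumptions based on the repository's usual examples, and the model paths are placeholders:

# Usage sketch (placeholder paths, not part of the diff).
import torch
from diffsynth import ModelManager, FluxImagePipeline  # import path assumed

model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cuda")
model_manager.load_models([
    "models/FLUX/FLUX.1-dev/text_encoder/model.safetensors",  # placeholder paths for the
    "models/FLUX/FLUX.1-dev/text_encoder_2",                  # FLUX component files
    "models/FLUX/FLUX.1-dev/ae.safetensors",
    "models/FLUX/FLUX.1-dev/flux1-dev.safetensors",
])
pipe = FluxImagePipeline.from_model_manager(model_manager)

image = pipe(
    prompt="a photo of a cat sitting on a windowsill at sunset",
    num_inference_steps=30,
    embedded_guidance=3.5,
    tea_cache_l1_thresh=0.25,  # None disables TeaCache; larger values skip more steps
)
image.save("flux_teacache.png")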
@@ -125,7 +125,7 @@ class ImageSizeManager:
|
||||
class HunyuanDiTImagePipeline(BasePipeline):
|
||||
|
||||
def __init__(self, device="cuda", torch_dtype=torch.float16):
|
||||
super().__init__(device=device, torch_dtype=torch_dtype)
|
||||
super().__init__(device=device, torch_dtype=torch_dtype, height_division_factor=16, width_division_factor=16)
|
||||
self.scheduler = EnhancedDDIMScheduler(prediction_type="v_prediction", beta_start=0.00085, beta_end=0.03)
|
||||
self.prompter = HunyuanDiTPrompter()
|
||||
self.image_size_manager = ImageSizeManager()
|
||||
@@ -226,14 +226,17 @@ class HunyuanDiTImagePipeline(BasePipeline):
|
||||
tiled=False,
|
||||
tile_size=64,
|
||||
tile_stride=32,
|
||||
seed=None,
|
||||
progress_bar_cmd=tqdm,
|
||||
progress_bar_st=None,
|
||||
):
|
||||
height, width = self.check_resize_height_width(height, width)
|
||||
|
||||
# Prepare scheduler
|
||||
self.scheduler.set_timesteps(num_inference_steps, denoising_strength)
|
||||
|
||||
# Prepare latent tensors
|
||||
noise = torch.randn((1, 4, height//8, width//8), device=self.device, dtype=self.torch_dtype)
|
||||
noise = self.generate_noise((1, 4, height//8, width//8), seed=seed, device=self.device, dtype=self.torch_dtype)
|
||||
if input_image is not None:
|
||||
self.load_models_to_device(['vae_encoder'])
|
||||
image = self.preprocess_image(input_image).to(device=self.device, dtype=torch.float32)
|
||||
|
||||
diffsynth/pipelines/hunyuan_video.py (new file)
@@ -0,0 +1,265 @@
from ..models import ModelManager, SD3TextEncoder1, HunyuanVideoVAEDecoder, HunyuanVideoVAEEncoder
|
||||
from ..models.hunyuan_video_dit import HunyuanVideoDiT
|
||||
from ..models.hunyuan_video_text_encoder import HunyuanVideoLLMEncoder
|
||||
from ..schedulers.flow_match import FlowMatchScheduler
|
||||
from .base import BasePipeline
|
||||
from ..prompters import HunyuanVideoPrompter
|
||||
import torch
|
||||
from einops import rearrange
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
from tqdm import tqdm
class HunyuanVideoPipeline(BasePipeline):
|
||||
|
||||
def __init__(self, device="cuda", torch_dtype=torch.float16):
|
||||
super().__init__(device=device, torch_dtype=torch_dtype)
|
||||
self.scheduler = FlowMatchScheduler(shift=7.0, sigma_min=0.0, extra_one_step=True)
|
||||
self.prompter = HunyuanVideoPrompter()
|
||||
self.text_encoder_1: SD3TextEncoder1 = None
|
||||
self.text_encoder_2: HunyuanVideoLLMEncoder = None
|
||||
self.dit: HunyuanVideoDiT = None
|
||||
self.vae_decoder: HunyuanVideoVAEDecoder = None
|
||||
self.vae_encoder: HunyuanVideoVAEEncoder = None
|
||||
self.model_names = ['text_encoder_1', 'text_encoder_2', 'dit', 'vae_decoder', 'vae_encoder']
|
||||
self.vram_management = False
|
||||
|
||||
|
||||
def enable_vram_management(self):
|
||||
self.vram_management = True
|
||||
self.enable_cpu_offload()
|
||||
self.text_encoder_2.enable_auto_offload(dtype=self.torch_dtype, device=self.device)
|
||||
self.dit.enable_auto_offload(dtype=self.torch_dtype, device=self.device)
|
||||
|
||||
|
||||
def fetch_models(self, model_manager: ModelManager):
|
||||
self.text_encoder_1 = model_manager.fetch_model("sd3_text_encoder_1")
|
||||
self.text_encoder_2 = model_manager.fetch_model("hunyuan_video_text_encoder_2")
|
||||
self.dit = model_manager.fetch_model("hunyuan_video_dit")
|
||||
self.vae_decoder = model_manager.fetch_model("hunyuan_video_vae_decoder")
|
||||
self.vae_encoder = model_manager.fetch_model("hunyuan_video_vae_encoder")
|
||||
self.prompter.fetch_models(self.text_encoder_1, self.text_encoder_2)
|
||||
|
||||
|
||||
@staticmethod
|
||||
def from_model_manager(model_manager: ModelManager, torch_dtype=None, device=None, enable_vram_management=True):
|
||||
if device is None: device = model_manager.device
|
||||
if torch_dtype is None: torch_dtype = model_manager.torch_dtype
|
||||
pipe = HunyuanVideoPipeline(device=device, torch_dtype=torch_dtype)
|
||||
pipe.fetch_models(model_manager)
|
||||
if enable_vram_management:
|
||||
pipe.enable_vram_management()
|
||||
return pipe
|
||||
|
||||
|
||||
def encode_prompt(self, prompt, positive=True, clip_sequence_length=77, llm_sequence_length=256):
|
||||
prompt_emb, pooled_prompt_emb, text_mask = self.prompter.encode_prompt(
|
||||
prompt, device=self.device, positive=positive, clip_sequence_length=clip_sequence_length, llm_sequence_length=llm_sequence_length
|
||||
)
|
||||
return {"prompt_emb": prompt_emb, "pooled_prompt_emb": pooled_prompt_emb, "text_mask": text_mask}
|
||||
|
||||
|
||||
def prepare_extra_input(self, latents=None, guidance=1.0):
|
||||
freqs_cos, freqs_sin = self.dit.prepare_freqs(latents)
|
||||
guidance = torch.Tensor([guidance] * latents.shape[0]).to(device=latents.device, dtype=latents.dtype)
|
||||
return {"freqs_cos": freqs_cos, "freqs_sin": freqs_sin, "guidance": guidance}
|
||||
|
||||
|
||||
def tensor2video(self, frames):
|
||||
frames = rearrange(frames, "C T H W -> T H W C")
|
||||
frames = ((frames.float() + 1) * 127.5).clip(0, 255).cpu().numpy().astype(np.uint8)
|
||||
frames = [Image.fromarray(frame) for frame in frames]
|
||||
return frames
|
||||
|
||||
|
||||
def encode_video(self, frames, tile_size=(17, 30, 30), tile_stride=(12, 20, 20)):
|
||||
tile_size = ((tile_size[0] - 1) * 4 + 1, tile_size[1] * 8, tile_size[2] * 8)
|
||||
tile_stride = (tile_stride[0] * 4, tile_stride[1] * 8, tile_stride[2] * 8)
|
||||
latents = self.vae_encoder.encode_video(frames, tile_size=tile_size, tile_stride=tile_stride)
|
||||
return latents
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(
|
||||
self,
|
||||
prompt,
|
||||
negative_prompt="",
|
||||
input_video=None,
|
||||
denoising_strength=1.0,
|
||||
seed=None,
|
||||
rand_device=None,
|
||||
height=720,
|
||||
width=1280,
|
||||
num_frames=129,
|
||||
embedded_guidance=6.0,
|
||||
cfg_scale=1.0,
|
||||
num_inference_steps=30,
|
||||
tea_cache_l1_thresh=None,
|
||||
tile_size=(17, 30, 30),
|
||||
tile_stride=(12, 20, 20),
|
||||
step_processor=None,
|
||||
progress_bar_cmd=lambda x: x,
|
||||
progress_bar_st=None,
|
||||
):
|
||||
# Tiler parameters
|
||||
tiler_kwargs = {"tile_size": tile_size, "tile_stride": tile_stride}
|
||||
|
||||
# Scheduler
|
||||
self.scheduler.set_timesteps(num_inference_steps, denoising_strength)
|
||||
|
||||
# Initialize noise
|
||||
rand_device = self.device if rand_device is None else rand_device
|
||||
noise = self.generate_noise((1, 16, (num_frames - 1) // 4 + 1, height//8, width//8), seed=seed, device=rand_device, dtype=self.torch_dtype).to(self.device)
|
||||
if input_video is not None:
|
||||
self.load_models_to_device(['vae_encoder'])
|
||||
input_video = self.preprocess_images(input_video)
|
||||
input_video = torch.stack(input_video, dim=2)
|
||||
latents = self.encode_video(input_video, **tiler_kwargs).to(dtype=self.torch_dtype, device=self.device)
|
||||
latents = self.scheduler.add_noise(latents, noise, timestep=self.scheduler.timesteps[0])
|
||||
else:
|
||||
latents = noise
|
||||
|
||||
# Encode prompts
|
||||
self.load_models_to_device(["text_encoder_1"] if self.vram_management else ["text_encoder_1", "text_encoder_2"])
|
||||
prompt_emb_posi = self.encode_prompt(prompt, positive=True)
|
||||
if cfg_scale != 1.0:
|
||||
prompt_emb_nega = self.encode_prompt(negative_prompt, positive=False)
|
||||
|
||||
# Extra input
|
||||
extra_input = self.prepare_extra_input(latents, guidance=embedded_guidance)
|
||||
|
||||
# TeaCache
|
||||
tea_cache_kwargs = {"tea_cache": TeaCache(num_inference_steps, rel_l1_thresh=tea_cache_l1_thresh) if tea_cache_l1_thresh is not None else None}
|
||||
|
||||
# Denoise
|
||||
self.load_models_to_device([] if self.vram_management else ["dit"])
|
||||
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
|
||||
timestep = timestep.unsqueeze(0).to(self.device)
|
||||
print(f"Step {progress_id + 1} / {len(self.scheduler.timesteps)}")
|
||||
|
||||
# Inference
|
||||
with torch.autocast(device_type=self.device, dtype=self.torch_dtype):
|
||||
noise_pred_posi = lets_dance_hunyuan_video(self.dit, latents, timestep, **prompt_emb_posi, **extra_input, **tea_cache_kwargs)
|
||||
if cfg_scale != 1.0:
|
||||
noise_pred_nega = lets_dance_hunyuan_video(self.dit, latents, timestep, **prompt_emb_nega, **extra_input)
|
||||
noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
|
||||
else:
|
||||
noise_pred = noise_pred_posi
|
||||
|
||||
# (Experimental feature, may be removed in the future)
|
||||
if step_processor is not None:
|
||||
self.load_models_to_device(['vae_decoder'])
|
||||
rendered_frames = self.scheduler.step(noise_pred, self.scheduler.timesteps[progress_id], latents, to_final=True)
|
||||
rendered_frames = self.vae_decoder.decode_video(rendered_frames, **tiler_kwargs)
|
||||
rendered_frames = self.tensor2video(rendered_frames[0])
|
||||
rendered_frames = step_processor(rendered_frames, original_frames=input_video)
|
||||
self.load_models_to_device(['vae_encoder'])
|
||||
rendered_frames = self.preprocess_images(rendered_frames)
|
||||
rendered_frames = torch.stack(rendered_frames, dim=2)
|
||||
target_latents = self.encode_video(rendered_frames).to(dtype=self.torch_dtype, device=self.device)
|
||||
noise_pred = self.scheduler.return_to_timestep(self.scheduler.timesteps[progress_id], latents, target_latents)
|
||||
self.load_models_to_device([] if self.vram_management else ["dit"])
|
||||
|
||||
# Scheduler
|
||||
latents = self.scheduler.step(noise_pred, self.scheduler.timesteps[progress_id], latents)
|
||||
|
||||
# Decode
|
||||
self.load_models_to_device(['vae_decoder'])
|
||||
frames = self.vae_decoder.decode_video(latents, **tiler_kwargs)
|
||||
self.load_models_to_device([])
|
||||
frames = self.tensor2video(frames[0])
|
||||
|
||||
return frames
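A quick sanity check on the noise shape used in __call__ above: the HunyuanVideo VAE compresses time by a factor of 4 (plus one leading frame) and space by a factor of 8, which is exactly what the shape expression encodes. A small sketch with the default arguments:

# (1, 16, (num_frames - 1) // 4 + 1, height // 8, width // 8)
num_frames, height, width = 129, 720, 1280
latent_shape = (1, 16, (num_frames - 1) // 4 + 1, height // 8, width // 8)
print(latent_shape)  # (1, 16, 33, 90, 160)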
class TeaCache:
    def __init__(self, num_inference_steps, rel_l1_thresh):
        self.num_inference_steps = num_inference_steps
        self.step = 0
        self.accumulated_rel_l1_distance = 0
        self.previous_modulated_input = None
        self.rel_l1_thresh = rel_l1_thresh
        self.previous_residual = None
        self.previous_hidden_states = None

    def check(self, dit: HunyuanVideoDiT, img, vec):
        img_ = img.clone()
        vec_ = vec.clone()
        img_mod1_shift, img_mod1_scale, _, _, _, _ = dit.double_blocks[0].component_a.mod(vec_).chunk(6, dim=-1)
        normed_inp = dit.double_blocks[0].component_a.norm1(img_)
        modulated_inp = normed_inp * (1 + img_mod1_scale.unsqueeze(1)) + img_mod1_shift.unsqueeze(1)
        if self.step == 0 or self.step == self.num_inference_steps - 1:
            should_calc = True
            self.accumulated_rel_l1_distance = 0
        else:
            coefficients = [7.33226126e+02, -4.01131952e+02, 6.75869174e+01, -3.14987800e+00, 9.61237896e-02]
            rescale_func = np.poly1d(coefficients)
            self.accumulated_rel_l1_distance += rescale_func(((modulated_inp - self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean()).cpu().item())
            if self.accumulated_rel_l1_distance < self.rel_l1_thresh:
                should_calc = False
            else:
                should_calc = True
                self.accumulated_rel_l1_distance = 0
        self.previous_modulated_input = modulated_inp
        self.step += 1
        if self.step == self.num_inference_steps:
            self.step = 0
        if should_calc:
            self.previous_hidden_states = img.clone()
        return not should_calc

    def store(self, hidden_states):
        self.previous_residual = hidden_states - self.previous_hidden_states
        self.previous_hidden_states = None

    def update(self, hidden_states):
        hidden_states = hidden_states + self.previous_residual
        return hidden_states
def lets_dance_hunyuan_video(
|
||||
dit: HunyuanVideoDiT,
|
||||
x: torch.Tensor,
|
||||
t: torch.Tensor,
|
||||
prompt_emb: torch.Tensor = None,
|
||||
text_mask: torch.Tensor = None,
|
||||
pooled_prompt_emb: torch.Tensor = None,
|
||||
freqs_cos: torch.Tensor = None,
|
||||
freqs_sin: torch.Tensor = None,
|
||||
guidance: torch.Tensor = None,
|
||||
tea_cache: TeaCache = None,
|
||||
**kwargs
|
||||
):
|
||||
B, C, T, H, W = x.shape
|
||||
|
||||
vec = dit.time_in(t, dtype=torch.float32) + dit.vector_in(pooled_prompt_emb) + dit.guidance_in(guidance * 1000, dtype=torch.float32)
|
||||
img = dit.img_in(x)
|
||||
txt = dit.txt_in(prompt_emb, t, text_mask)
|
||||
|
||||
# TeaCache
|
||||
if tea_cache is not None:
|
||||
tea_cache_update = tea_cache.check(dit, img, vec)
|
||||
else:
|
||||
tea_cache_update = False
|
||||
|
||||
if tea_cache_update:
|
||||
print("TeaCache skip forward.")
|
||||
img = tea_cache.update(img)
|
||||
else:
|
||||
for block in tqdm(dit.double_blocks, desc="Double stream blocks"):
|
||||
img, txt = block(img, txt, vec, (freqs_cos, freqs_sin))
|
||||
|
||||
x = torch.concat([img, txt], dim=1)
|
||||
for block in tqdm(dit.single_blocks, desc="Single stream blocks"):
|
||||
x = block(x, vec, (freqs_cos, freqs_sin))
|
||||
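# Added note: keep only the image tokens here; the trailing 256 positions are the text
# tokens concatenated after the double-stream blocks, assuming the default
# llm_sequence_length of 256.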
img = x[:, :-256]
|
||||
|
||||
if tea_cache is not None:
|
||||
tea_cache.store(img)
|
||||
img = dit.final_layer(img, vec)
|
||||
img = dit.unpatchify(img, T=T//1, H=H//2, W=W//2)
|
||||
return img
diffsynth/pipelines/omnigen_image.py (new file)
@@ -0,0 +1,289 @@
from ..models.omnigen import OmniGenTransformer
|
||||
from ..models.sdxl_vae_encoder import SDXLVAEEncoder
|
||||
from ..models.sdxl_vae_decoder import SDXLVAEDecoder
|
||||
from ..models.model_manager import ModelManager
|
||||
from ..prompters.omnigen_prompter import OmniGenPrompter
|
||||
from ..schedulers import FlowMatchScheduler
|
||||
from .base import BasePipeline
|
||||
from typing import Optional, Dict, Any, Tuple, List
|
||||
from transformers.cache_utils import DynamicCache
|
||||
import torch, os
|
||||
from tqdm import tqdm
class OmniGenCache(DynamicCache):
|
||||
def __init__(self,
|
||||
num_tokens_for_img: int, offload_kv_cache: bool=False) -> None:
|
||||
if not torch.cuda.is_available():
|
||||
print("No avaliable GPU, offload_kv_cache wiil be set to False, which will result in large memory usage and time cost when input multiple images!!!")
|
||||
offload_kv_cache = False
|
||||
raise RuntimeError("OffloadedCache can only be used with a GPU")
|
||||
super().__init__()
|
||||
self.original_device = []
|
||||
self.prefetch_stream = torch.cuda.Stream()
|
||||
self.num_tokens_for_img = num_tokens_for_img
|
||||
self.offload_kv_cache = offload_kv_cache
|
||||
|
||||
def prefetch_layer(self, layer_idx: int):
|
||||
"Starts prefetching the next layer cache"
|
||||
if layer_idx < len(self):
|
||||
with torch.cuda.stream(self.prefetch_stream):
|
||||
# Prefetch next layer tensors to GPU
|
||||
device = self.original_device[layer_idx]
|
||||
self.key_cache[layer_idx] = self.key_cache[layer_idx].to(device, non_blocking=True)
|
||||
self.value_cache[layer_idx] = self.value_cache[layer_idx].to(device, non_blocking=True)
|
||||
|
||||
|
||||
def evict_previous_layer(self, layer_idx: int):
|
||||
"Moves the previous layer cache to the CPU"
|
||||
if len(self) > 2:
|
||||
# We do it on the default stream so it occurs after all earlier computations on these tensors are done
|
||||
if layer_idx == 0:
|
||||
prev_layer_idx = -1
|
||||
else:
|
||||
prev_layer_idx = (layer_idx - 1) % len(self)
|
||||
self.key_cache[prev_layer_idx] = self.key_cache[prev_layer_idx].to("cpu", non_blocking=True)
|
||||
self.value_cache[prev_layer_idx] = self.value_cache[prev_layer_idx].to("cpu", non_blocking=True)
|
||||
|
||||
|
||||
def __getitem__(self, layer_idx: int) -> List[Tuple[torch.Tensor]]:
|
||||
"Gets the cache for this layer to the device. Prefetches the next and evicts the previous layer."
|
||||
if layer_idx < len(self):
|
||||
if self.offload_kv_cache:
|
||||
# Evict the previous layer if necessary
|
||||
torch.cuda.current_stream().synchronize()
|
||||
self.evict_previous_layer(layer_idx)
|
||||
# Load current layer cache to its original device if not already there
|
||||
original_device = self.original_device[layer_idx]
|
||||
# self.prefetch_stream.synchronize(original_device)
|
||||
torch.cuda.synchronize(self.prefetch_stream)
|
||||
key_tensor = self.key_cache[layer_idx]
|
||||
value_tensor = self.value_cache[layer_idx]
|
||||
|
||||
# Prefetch the next layer
|
||||
self.prefetch_layer((layer_idx + 1) % len(self))
|
||||
else:
|
||||
key_tensor = self.key_cache[layer_idx]
|
||||
value_tensor = self.value_cache[layer_idx]
|
||||
return (key_tensor, value_tensor)
|
||||
else:
|
||||
raise KeyError(f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}")
|
||||
|
||||
|
||||
def update(
|
||||
self,
|
||||
key_states: torch.Tensor,
|
||||
value_states: torch.Tensor,
|
||||
layer_idx: int,
|
||||
cache_kwargs: Optional[Dict[str, Any]] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.
|
||||
Parameters:
|
||||
key_states (`torch.Tensor`):
|
||||
The new key states to cache.
|
||||
value_states (`torch.Tensor`):
|
||||
The new value states to cache.
|
||||
layer_idx (`int`):
|
||||
The index of the layer to cache the states for.
|
||||
cache_kwargs (`Dict[str, Any]`, `optional`):
|
||||
Additional arguments for the cache subclass. No additional arguments are used in `OffloadedCache`.
|
||||
Return:
|
||||
A tuple containing the updated key and value states.
|
||||
"""
|
||||
# Update the cache
|
||||
if len(self.key_cache) < layer_idx:
|
||||
raise ValueError("OffloadedCache does not support model usage where layers are skipped. Use DynamicCache.")
|
||||
elif len(self.key_cache) == layer_idx:
|
||||
# only cache the states for condition tokens
|
||||
key_states = key_states[..., :-(self.num_tokens_for_img+1), :]
|
||||
value_states = value_states[..., :-(self.num_tokens_for_img+1), :]
|
||||
|
||||
# Update the number of seen tokens
|
||||
if layer_idx == 0:
|
||||
self._seen_tokens += key_states.shape[-2]
|
||||
|
||||
self.key_cache.append(key_states)
|
||||
self.value_cache.append(value_states)
|
||||
self.original_device.append(key_states.device)
|
||||
if self.offload_kv_cache:
|
||||
self.evict_previous_layer(layer_idx)
|
||||
return self.key_cache[layer_idx], self.value_cache[layer_idx]
|
||||
else:
|
||||
# only cache the states for condition tokens
|
||||
key_tensor, value_tensor = self[layer_idx]
|
||||
k = torch.cat([key_tensor, key_states], dim=-2)
|
||||
v = torch.cat([value_tensor, value_states], dim=-2)
|
||||
return k, v
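The offloaded cache above keeps only a couple of layers' key/value tensors on the GPU at any time: __getitem__ synchronously evicts the previous layer to CPU, returns the current layer, and prefetches the next one on a side CUDA stream. A hedged sketch of the resulting per-layer rotation; attention_layer, hidden and num_layers are hypothetical names:

# Sketch only, not part of omnigen_image.py.
for layer_idx in range(num_layers):
    key, value = cache[layer_idx]                  # evicts layer_idx - 1, prefetches layer_idx + 1
    hidden = attention_layer(hidden, key, value)   # hypothetical per-layer attention call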
class OmnigenImagePipeline(BasePipeline):
|
||||
|
||||
def __init__(self, device="cuda", torch_dtype=torch.float16):
|
||||
super().__init__(device=device, torch_dtype=torch_dtype)
|
||||
self.scheduler = FlowMatchScheduler(num_train_timesteps=1, shift=1, inverse_timesteps=True, sigma_min=0, sigma_max=1)
|
||||
# models
|
||||
self.vae_decoder: SDXLVAEDecoder = None
|
||||
self.vae_encoder: SDXLVAEEncoder = None
|
||||
self.transformer: OmniGenTransformer = None
|
||||
self.prompter: OmniGenPrompter = None
|
||||
self.model_names = ['transformer', 'vae_decoder', 'vae_encoder']
|
||||
|
||||
|
||||
def denoising_model(self):
|
||||
return self.transformer
|
||||
|
||||
|
||||
def fetch_models(self, model_manager: ModelManager, prompt_refiner_classes=[]):
|
||||
# Main models
|
||||
self.transformer, model_path = model_manager.fetch_model("omnigen_transformer", require_model_path=True)
|
||||
self.vae_decoder = model_manager.fetch_model("sdxl_vae_decoder")
|
||||
self.vae_encoder = model_manager.fetch_model("sdxl_vae_encoder")
|
||||
self.prompter = OmniGenPrompter.from_pretrained(os.path.dirname(model_path))
|
||||
|
||||
|
||||
@staticmethod
|
||||
def from_model_manager(model_manager: ModelManager, prompt_refiner_classes=[], device=None):
|
||||
pipe = OmnigenImagePipeline(
|
||||
device=model_manager.device if device is None else device,
|
||||
torch_dtype=model_manager.torch_dtype,
|
||||
)
|
||||
pipe.fetch_models(model_manager, prompt_refiner_classes=[])
|
||||
return pipe
|
||||
|
||||
|
||||
def encode_image(self, image, tiled=False, tile_size=64, tile_stride=32):
|
||||
latents = self.vae_encoder(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
|
||||
return latents
|
||||
|
||||
|
||||
def encode_images(self, images, tiled=False, tile_size=64, tile_stride=32):
|
||||
latents = [self.encode_image(image.to(device=self.device), tiled, tile_size, tile_stride).to(self.torch_dtype) for image in images]
|
||||
return latents
|
||||
|
||||
|
||||
def decode_image(self, latent, tiled=False, tile_size=64, tile_stride=32):
|
||||
image = self.vae_decoder(latent.to(self.device), tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
|
||||
image = self.vae_output_to_image(image)
|
||||
return image
|
||||
|
||||
|
||||
def encode_prompt(self, prompt, clip_skip=1, positive=True):
|
||||
prompt_emb = self.prompter.encode_prompt(prompt, clip_skip=clip_skip, device=self.device, positive=positive)
|
||||
return {"encoder_hidden_states": prompt_emb}
|
||||
|
||||
|
||||
def prepare_extra_input(self, latents=None):
|
||||
return {}
|
||||
|
||||
|
||||
def crop_position_ids_for_cache(self, position_ids, num_tokens_for_img):
|
||||
if isinstance(position_ids, list):
|
||||
for i in range(len(position_ids)):
|
||||
position_ids[i] = position_ids[i][:, -(num_tokens_for_img+1):]
|
||||
else:
|
||||
position_ids = position_ids[:, -(num_tokens_for_img+1):]
|
||||
return position_ids
|
||||
|
||||
|
||||
def crop_attention_mask_for_cache(self, attention_mask, num_tokens_for_img):
|
||||
if isinstance(attention_mask, list):
|
||||
return [x[..., -(num_tokens_for_img+1):, :] for x in attention_mask]
|
||||
return attention_mask[..., -(num_tokens_for_img+1):, :]
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(
|
||||
self,
|
||||
prompt,
|
||||
reference_images=[],
|
||||
cfg_scale=2.0,
|
||||
image_cfg_scale=2.0,
|
||||
use_kv_cache=True,
|
||||
offload_kv_cache=True,
|
||||
input_image=None,
|
||||
denoising_strength=1.0,
|
||||
height=1024,
|
||||
width=1024,
|
||||
num_inference_steps=20,
|
||||
tiled=False,
|
||||
tile_size=64,
|
||||
tile_stride=32,
|
||||
seed=None,
|
||||
progress_bar_cmd=tqdm,
|
||||
progress_bar_st=None,
|
||||
):
|
||||
height, width = self.check_resize_height_width(height, width)
|
||||
|
||||
# Tiler parameters
|
||||
tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride}
|
||||
|
||||
# Prepare scheduler
|
||||
self.scheduler.set_timesteps(num_inference_steps, denoising_strength)
|
||||
|
||||
# Prepare latent tensors
|
||||
if input_image is not None:
|
||||
self.load_models_to_device(['vae_encoder'])
|
||||
image = self.preprocess_image(input_image).to(device=self.device, dtype=self.torch_dtype)
|
||||
latents = self.encode_image(image, **tiler_kwargs)
|
||||
noise = self.generate_noise((1, 4, height//8, width//8), seed=seed, device=self.device, dtype=self.torch_dtype)
|
||||
latents = self.scheduler.add_noise(latents, noise, timestep=self.scheduler.timesteps[0])
|
||||
else:
|
||||
latents = self.generate_noise((1, 4, height//8, width//8), seed=seed, device=self.device, dtype=self.torch_dtype)
|
||||
latents = latents.repeat(3, 1, 1, 1)
|
||||
|
||||
# Encode prompts
|
||||
input_data = self.prompter(prompt, reference_images, height=height, width=width, use_img_cfg=True, separate_cfg_input=True, use_input_image_size_as_output=False)
|
||||
|
||||
# Encode images
|
||||
reference_latents = [self.encode_images(images, **tiler_kwargs) for images in input_data['input_pixel_values']]
|
||||
|
||||
# Pack all parameters
|
||||
model_kwargs = dict(input_ids=[input_ids.to(self.device) for input_ids in input_data['input_ids']],
|
||||
input_img_latents=reference_latents,
|
||||
input_image_sizes=input_data['input_image_sizes'],
|
||||
attention_mask=[attention_mask.to(self.device) for attention_mask in input_data["attention_mask"]],
|
||||
position_ids=[position_ids.to(self.device) for position_ids in input_data["position_ids"]],
|
||||
cfg_scale=cfg_scale,
|
||||
img_cfg_scale=image_cfg_scale,
|
||||
use_img_cfg=True,
|
||||
use_kv_cache=use_kv_cache,
|
||||
offload_model=False,
|
||||
)
|
||||
|
||||
# Denoise
|
||||
self.load_models_to_device(['transformer'])
|
||||
cache = [OmniGenCache(latents.size(-1)*latents.size(-2) // 4, offload_kv_cache) for _ in range(len(model_kwargs['input_ids']))] if use_kv_cache else None
|
||||
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
|
||||
timestep = timestep.unsqueeze(0).repeat(latents.shape[0]).to(self.device)
|
||||
|
||||
# Forward
|
||||
noise_pred, cache = self.transformer.forward_with_separate_cfg(latents, timestep, past_key_values=cache, **model_kwargs)
|
||||
|
||||
# Scheduler
|
||||
latents = self.scheduler.step(noise_pred, self.scheduler.timesteps[progress_id], latents)
|
||||
|
||||
# Update KV cache
|
||||
if progress_id == 0 and use_kv_cache:
|
||||
num_tokens_for_img = latents.size(-1)*latents.size(-2) // 4
|
||||
if isinstance(cache, list):
|
||||
model_kwargs['input_ids'] = [None] * len(cache)
|
||||
else:
|
||||
model_kwargs['input_ids'] = None
|
||||
model_kwargs['position_ids'] = self.crop_position_ids_for_cache(model_kwargs['position_ids'], num_tokens_for_img)
|
||||
model_kwargs['attention_mask'] = self.crop_attention_mask_for_cache(model_kwargs['attention_mask'], num_tokens_for_img)
|
||||
|
||||
# UI
|
||||
if progress_bar_st is not None:
|
||||
progress_bar_st.progress(progress_id / len(self.scheduler.timesteps))
|
||||
|
||||
# Decode image
|
||||
del cache
|
||||
self.load_models_to_device(['vae_decoder'])
|
||||
image = self.decode_image(latents, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
|
||||
|
||||
# Offload all models
|
||||
self.load_models_to_device([])
|
||||
return image
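A hedged usage sketch of the new OmniGen pipeline with a reference image. The import path, model location, PIL input type and the <img><|image_1|></img> placeholder syntax are assumptions based on OmniGen's usual conventions:

# Usage sketch (placeholder paths, not part of the diff).
import torch
from PIL import Image
from diffsynth import ModelManager, OmnigenImagePipeline  # import path assumed

model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cuda")
model_manager.load_models(["models/OmniGen-v1/model.safetensors"])  # placeholder path
pipe = OmnigenImagePipeline.from_model_manager(model_manager)

image = pipe(
    prompt="The woman in <img><|image_1|></img> is reading a book in a cafe.",
    reference_images=[Image.open("reference.jpg")],
    cfg_scale=2.0,
    image_cfg_scale=1.6,
    use_kv_cache=True,       # caches the condition tokens after the first step
    offload_kv_cache=True,   # keeps the cached keys/values on CPU between layers
    seed=0,
)
image.save("omnigen_edit.png")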
@@ -10,7 +10,7 @@ from tqdm import tqdm
|
||||
class SD3ImagePipeline(BasePipeline):
|
||||
|
||||
def __init__(self, device="cuda", torch_dtype=torch.float16):
|
||||
super().__init__(device=device, torch_dtype=torch_dtype)
|
||||
super().__init__(device=device, torch_dtype=torch_dtype, height_division_factor=16, width_division_factor=16)
|
||||
self.scheduler = FlowMatchScheduler()
|
||||
self.prompter = SD3Prompter()
|
||||
# models
|
||||
@@ -59,9 +59,9 @@ class SD3ImagePipeline(BasePipeline):
|
||||
return image
|
||||
|
||||
|
||||
def encode_prompt(self, prompt, positive=True):
|
||||
def encode_prompt(self, prompt, positive=True, t5_sequence_length=77):
|
||||
prompt_emb, pooled_prompt_emb = self.prompter.encode_prompt(
|
||||
prompt, device=self.device, positive=positive
|
||||
prompt, device=self.device, positive=positive, t5_sequence_length=t5_sequence_length
|
||||
)
|
||||
return {"prompt_emb": prompt_emb, "pooled_prompt_emb": pooled_prompt_emb}
|
||||
|
||||
@@ -84,12 +84,16 @@ class SD3ImagePipeline(BasePipeline):
|
||||
height=1024,
|
||||
width=1024,
|
||||
num_inference_steps=20,
|
||||
t5_sequence_length=77,
|
||||
tiled=False,
|
||||
tile_size=128,
|
||||
tile_stride=64,
|
||||
seed=None,
|
||||
progress_bar_cmd=tqdm,
|
||||
progress_bar_st=None,
|
||||
):
|
||||
height, width = self.check_resize_height_width(height, width)
|
||||
|
||||
# Tiler parameters
|
||||
tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride}
|
||||
|
||||
@@ -101,16 +105,16 @@ class SD3ImagePipeline(BasePipeline):
|
||||
self.load_models_to_device(['vae_encoder'])
|
||||
image = self.preprocess_image(input_image).to(device=self.device, dtype=self.torch_dtype)
|
||||
latents = self.encode_image(image, **tiler_kwargs)
|
||||
noise = torch.randn((1, 16, height//8, width//8), device=self.device, dtype=self.torch_dtype)
|
||||
noise = self.generate_noise((1, 16, height//8, width//8), seed=seed, device=self.device, dtype=self.torch_dtype)
|
||||
latents = self.scheduler.add_noise(latents, noise, timestep=self.scheduler.timesteps[0])
|
||||
else:
|
||||
latents = torch.randn((1, 16, height//8, width//8), device=self.device, dtype=self.torch_dtype)
|
||||
latents = self.generate_noise((1, 16, height//8, width//8), seed=seed, device=self.device, dtype=self.torch_dtype)
|
||||
|
||||
# Encode prompts
|
||||
self.load_models_to_device(['text_encoder_1', 'text_encoder_2', 'text_encoder_3'])
|
||||
prompt_emb_posi = self.encode_prompt(prompt, positive=True)
|
||||
prompt_emb_nega = self.encode_prompt(negative_prompt, positive=False)
|
||||
prompt_emb_locals = [self.encode_prompt(prompt_local) for prompt_local in local_prompts]
|
||||
prompt_emb_posi = self.encode_prompt(prompt, positive=True, t5_sequence_length=t5_sequence_length)
|
||||
prompt_emb_nega = self.encode_prompt(negative_prompt, positive=False, t5_sequence_length=t5_sequence_length)
|
||||
prompt_emb_locals = [self.encode_prompt(prompt_local, t5_sequence_length=t5_sequence_length) for prompt_local in local_prompts]
|
||||
|
||||
# Denoise
|
||||
self.load_models_to_device(['dit'])
|
||||
|
||||
@@ -108,9 +108,12 @@ class SDImagePipeline(BasePipeline):
|
||||
tiled=False,
|
||||
tile_size=64,
|
||||
tile_stride=32,
|
||||
seed=None,
|
||||
progress_bar_cmd=tqdm,
|
||||
progress_bar_st=None,
|
||||
):
|
||||
height, width = self.check_resize_height_width(height, width)
|
||||
|
||||
# Tiler parameters
|
||||
tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride}
|
||||
|
||||
@@ -122,10 +125,10 @@ class SDImagePipeline(BasePipeline):
|
||||
self.load_models_to_device(['vae_encoder'])
|
||||
image = self.preprocess_image(input_image).to(device=self.device, dtype=self.torch_dtype)
|
||||
latents = self.encode_image(image, **tiler_kwargs)
|
||||
noise = torch.randn((1, 4, height//8, width//8), device=self.device, dtype=self.torch_dtype)
|
||||
noise = self.generate_noise((1, 4, height//8, width//8), seed=seed, device=self.device, dtype=self.torch_dtype)
|
||||
latents = self.scheduler.add_noise(latents, noise, timestep=self.scheduler.timesteps[0])
|
||||
else:
|
||||
latents = torch.randn((1, 4, height//8, width//8), device=self.device, dtype=self.torch_dtype)
|
||||
latents = self.generate_noise((1, 4, height//8, width//8), seed=seed, device=self.device, dtype=self.torch_dtype)
|
||||
|
||||
# Encode prompts
|
||||
self.load_models_to_device(['text_encoder'])
|
||||
|
||||
@@ -166,9 +166,12 @@ class SDVideoPipeline(SDImagePipeline):
|
||||
tiled=False,
|
||||
tile_size=64,
|
||||
tile_stride=32,
|
||||
seed=None,
|
||||
progress_bar_cmd=tqdm,
|
||||
progress_bar_st=None,
|
||||
):
|
||||
height, width = self.check_resize_height_width(height, width)
|
||||
|
||||
# Tiler parameters, batch size ...
|
||||
tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride}
|
||||
other_kwargs = {
|
||||
@@ -182,9 +185,9 @@ class SDVideoPipeline(SDImagePipeline):
|
||||
|
||||
# Prepare latent tensors
|
||||
if self.motion_modules is None:
|
||||
noise = torch.randn((1, 4, height//8, width//8), device="cpu", dtype=self.torch_dtype).repeat(num_frames, 1, 1, 1)
|
||||
noise = self.generate_noise((1, 4, height//8, width//8), seed=seed, device="cpu", dtype=self.torch_dtype).repeat(num_frames, 1, 1, 1)
|
||||
else:
|
||||
noise = torch.randn((num_frames, 4, height//8, width//8), device="cpu", dtype=self.torch_dtype)
|
||||
noise = self.generate_noise((num_frames, 4, height//8, width//8), seed=seed, device="cpu", dtype=self.torch_dtype)
|
||||
if input_frames is None or denoising_strength == 1.0:
|
||||
latents = noise
|
||||
else:
|
||||
|
||||
@@ -9,6 +9,7 @@ from .dancer import lets_dance_xl
|
||||
from typing import List
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
from einops import repeat
|
||||
|
||||
|
||||
|
||||
@@ -103,7 +104,8 @@ class SDXLImagePipeline(BasePipeline):
|
||||
|
||||
def prepare_extra_input(self, latents=None):
|
||||
height, width = latents.shape[2] * 8, latents.shape[3] * 8
|
||||
return {"add_time_id": torch.tensor([height, width, 0, 0, height, width], device=self.device)}
|
||||
add_time_id = torch.tensor([height, width, 0, 0, height, width], device=self.device).repeat(latents.shape[0])
|
||||
return {"add_time_id": add_time_id}
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
@@ -129,9 +131,12 @@ class SDXLImagePipeline(BasePipeline):
|
||||
tiled=False,
|
||||
tile_size=64,
|
||||
tile_stride=32,
|
||||
seed=None,
|
||||
progress_bar_cmd=tqdm,
|
||||
progress_bar_st=None,
|
||||
):
|
||||
height, width = self.check_resize_height_width(height, width)
|
||||
|
||||
# Tiler parameters
|
||||
tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride}
|
||||
|
||||
@@ -143,10 +148,10 @@ class SDXLImagePipeline(BasePipeline):
|
||||
self.load_models_to_device(['vae_encoder'])
|
||||
image = self.preprocess_image(input_image).to(device=self.device, dtype=self.torch_dtype)
|
||||
latents = self.encode_image(image, **tiler_kwargs)
|
||||
noise = torch.randn((1, 4, height//8, width//8), device=self.device, dtype=self.torch_dtype)
|
||||
noise = self.generate_noise((1, 4, height//8, width//8), seed=seed, device=self.device, dtype=self.torch_dtype)
|
||||
latents = self.scheduler.add_noise(latents, noise, timestep=self.scheduler.timesteps[0])
|
||||
else:
|
||||
latents = torch.randn((1, 4, height//8, width//8), device=self.device, dtype=self.torch_dtype)
|
||||
latents = self.generate_noise((1, 4, height//8, width//8), seed=seed, device=self.device, dtype=self.torch_dtype)
|
||||
|
||||
# Encode prompts
|
||||
self.load_models_to_device(['text_encoder', 'text_encoder_2', 'text_encoder_kolors'])
|
||||
|
||||
@@ -120,9 +120,12 @@ class SDXLVideoPipeline(SDXLImagePipeline):
|
||||
tiled=False,
|
||||
tile_size=64,
|
||||
tile_stride=32,
|
||||
seed=None,
|
||||
progress_bar_cmd=tqdm,
|
||||
progress_bar_st=None,
|
||||
):
|
||||
height, width = self.check_resize_height_width(height, width)
|
||||
|
||||
# Tiler parameters, batch size ...
|
||||
tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride}
|
||||
|
||||
@@ -131,9 +134,9 @@ class SDXLVideoPipeline(SDXLImagePipeline):
|
||||
|
||||
# Prepare latent tensors
|
||||
if self.motion_modules is None:
|
||||
noise = torch.randn((1, 4, height//8, width//8), device="cpu", dtype=self.torch_dtype).repeat(num_frames, 1, 1, 1)
|
||||
noise = self.generate_noise((1, 4, height//8, width//8), seed=seed, device="cpu", dtype=self.torch_dtype).repeat(num_frames, 1, 1, 1)
|
||||
else:
|
||||
noise = torch.randn((num_frames, 4, height//8, width//8), device="cpu", dtype=self.torch_dtype)
|
||||
noise = self.generate_noise((num_frames, 4, height//8, width//8), seed=seed, device="cpu", dtype=self.torch_dtype)
|
||||
if input_frames is None or denoising_strength == 1.0:
|
||||
latents = noise
|
||||
else:
|
||||
|
||||
diffsynth/pipelines/step_video.py (new file)
@@ -0,0 +1,209 @@
from ..models import ModelManager
|
||||
from ..models.hunyuan_dit_text_encoder import HunyuanDiTCLIPTextEncoder
|
||||
from ..models.stepvideo_text_encoder import STEP1TextEncoder
|
||||
from ..models.stepvideo_dit import StepVideoModel
|
||||
from ..models.stepvideo_vae import StepVideoVAE
|
||||
from ..schedulers.flow_match import FlowMatchScheduler
|
||||
from .base import BasePipeline
|
||||
from ..prompters import StepVideoPrompter
|
||||
import torch
|
||||
from einops import rearrange
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
from ..vram_management import enable_vram_management, AutoWrappedModule, AutoWrappedLinear
|
||||
from transformers.models.bert.modeling_bert import BertEmbeddings
|
||||
from ..models.stepvideo_dit import RMSNorm
|
||||
from ..models.stepvideo_vae import CausalConv, CausalConvAfterNorm, Upsample2D, BaseGroupNorm
class StepVideoPipeline(BasePipeline):
|
||||
|
||||
def __init__(self, device="cuda", torch_dtype=torch.float16):
|
||||
super().__init__(device=device, torch_dtype=torch_dtype)
|
||||
self.scheduler = FlowMatchScheduler(sigma_min=0.0, extra_one_step=True, shift=13.0, reverse_sigmas=True, num_train_timesteps=1)
|
||||
self.prompter = StepVideoPrompter()
|
||||
self.text_encoder_1: HunyuanDiTCLIPTextEncoder = None
|
||||
self.text_encoder_2: STEP1TextEncoder = None
|
||||
self.dit: StepVideoModel = None
|
||||
self.vae: StepVideoVAE = None
|
||||
self.model_names = ['text_encoder_1', 'text_encoder_2', 'dit', 'vae']
|
||||
|
||||
|
||||
def enable_vram_management(self, num_persistent_param_in_dit=None):
|
||||
dtype = next(iter(self.text_encoder_1.parameters())).dtype
|
||||
enable_vram_management(
|
||||
self.text_encoder_1,
|
||||
module_map = {
|
||||
torch.nn.Linear: AutoWrappedLinear,
|
||||
BertEmbeddings: AutoWrappedModule,
|
||||
torch.nn.LayerNorm: AutoWrappedModule,
|
||||
},
|
||||
module_config = dict(
|
||||
offload_dtype=dtype,
|
||||
offload_device="cpu",
|
||||
onload_dtype=dtype,
|
||||
onload_device="cpu",
|
||||
computation_dtype=torch.float32,
|
||||
computation_device=self.device,
|
||||
),
|
||||
)
|
||||
dtype = next(iter(self.text_encoder_2.parameters())).dtype
|
||||
enable_vram_management(
|
||||
self.text_encoder_2,
|
||||
module_map = {
|
||||
torch.nn.Linear: AutoWrappedLinear,
|
||||
RMSNorm: AutoWrappedModule,
|
||||
torch.nn.Embedding: AutoWrappedModule,
|
||||
},
|
||||
module_config = dict(
|
||||
offload_dtype=dtype,
|
||||
offload_device="cpu",
|
||||
onload_dtype=dtype,
|
||||
onload_device="cpu",
|
||||
computation_dtype=self.torch_dtype,
|
||||
computation_device=self.device,
|
||||
),
|
||||
)
|
||||
dtype = next(iter(self.dit.parameters())).dtype
|
||||
enable_vram_management(
|
||||
self.dit,
|
||||
module_map = {
|
||||
torch.nn.Linear: AutoWrappedLinear,
|
||||
torch.nn.Conv2d: AutoWrappedModule,
|
||||
torch.nn.LayerNorm: AutoWrappedModule,
|
||||
RMSNorm: AutoWrappedModule,
|
||||
},
|
||||
module_config = dict(
|
||||
offload_dtype=dtype,
|
||||
offload_device="cpu",
|
||||
onload_dtype=dtype,
|
||||
onload_device=self.device,
|
||||
computation_dtype=self.torch_dtype,
|
||||
computation_device=self.device,
|
||||
),
|
||||
max_num_param=num_persistent_param_in_dit,
|
||||
overflow_module_config = dict(
|
||||
offload_dtype=dtype,
|
||||
offload_device="cpu",
|
||||
onload_dtype=dtype,
|
||||
onload_device="cpu",
|
||||
computation_dtype=self.torch_dtype,
|
||||
computation_device=self.device,
|
||||
),
|
||||
)
|
||||
dtype = next(iter(self.vae.parameters())).dtype
|
||||
enable_vram_management(
|
||||
self.vae,
|
||||
module_map = {
|
||||
torch.nn.Linear: AutoWrappedLinear,
|
||||
torch.nn.Conv3d: AutoWrappedModule,
|
||||
CausalConv: AutoWrappedModule,
|
||||
CausalConvAfterNorm: AutoWrappedModule,
|
||||
Upsample2D: AutoWrappedModule,
|
||||
BaseGroupNorm: AutoWrappedModule,
|
||||
},
|
||||
module_config = dict(
|
||||
offload_dtype=dtype,
|
||||
offload_device="cpu",
|
||||
onload_dtype=dtype,
|
||||
onload_device="cpu",
|
||||
computation_dtype=self.torch_dtype,
|
||||
computation_device=self.device,
|
||||
),
|
||||
)
|
||||
self.enable_cpu_offload()
|
||||
|
||||
|
||||
def fetch_models(self, model_manager: ModelManager):
|
||||
self.text_encoder_1 = model_manager.fetch_model("hunyuan_dit_clip_text_encoder")
|
||||
self.text_encoder_2 = model_manager.fetch_model("stepvideo_text_encoder_2")
|
||||
self.dit = model_manager.fetch_model("stepvideo_dit")
|
||||
self.vae = model_manager.fetch_model("stepvideo_vae")
|
||||
self.prompter.fetch_models(self.text_encoder_1, self.text_encoder_2)
|
||||
|
||||
|
||||
@staticmethod
|
||||
def from_model_manager(model_manager: ModelManager, torch_dtype=None, device=None):
|
||||
if device is None: device = model_manager.device
|
||||
if torch_dtype is None: torch_dtype = model_manager.torch_dtype
|
||||
pipe = StepVideoPipeline(device=device, torch_dtype=torch_dtype)
|
||||
pipe.fetch_models(model_manager)
|
||||
return pipe
|
||||
|
||||
|
||||
def encode_prompt(self, prompt, positive=True):
|
||||
clip_embeds, llm_embeds, llm_mask = self.prompter.encode_prompt(prompt, device=self.device, positive=positive)
|
||||
clip_embeds = clip_embeds.to(dtype=self.torch_dtype, device=self.device)
|
||||
llm_embeds = llm_embeds.to(dtype=self.torch_dtype, device=self.device)
|
||||
llm_mask = llm_mask.to(dtype=self.torch_dtype, device=self.device)
|
||||
return {"encoder_hidden_states_2": clip_embeds, "encoder_hidden_states": llm_embeds, "encoder_attention_mask": llm_mask}
|
||||
|
||||
|
||||
def tensor2video(self, frames):
|
||||
frames = rearrange(frames, "C T H W -> T H W C")
|
||||
frames = ((frames.float() + 1) * 127.5).clip(0, 255).cpu().numpy().astype(np.uint8)
|
||||
frames = [Image.fromarray(frame) for frame in frames]
|
||||
return frames
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(
|
||||
self,
|
||||
prompt,
|
||||
negative_prompt="",
|
||||
input_video=None,
|
||||
denoising_strength=1.0,
|
||||
seed=None,
|
||||
rand_device="cpu",
|
||||
height=544,
|
||||
width=992,
|
||||
num_frames=204,
|
||||
cfg_scale=9.0,
|
||||
num_inference_steps=30,
|
||||
tiled=True,
|
||||
tile_size=(34, 34),
|
||||
tile_stride=(16, 16),
|
||||
smooth_scale=0.6,
|
||||
progress_bar_cmd=lambda x: x,
|
||||
progress_bar_st=None,
|
||||
):
|
||||
# Tiler parameters
|
||||
tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride}
|
||||
|
||||
# Scheduler
|
||||
self.scheduler.set_timesteps(num_inference_steps, denoising_strength)
|
||||
|
||||
# Initialize noise
|
||||
latents = self.generate_noise((1, max(num_frames//17*3, 1), 64, height//16, width//16), seed=seed, device=rand_device, dtype=self.torch_dtype).to(self.device)
|
||||
|
||||
# Encode prompts
|
||||
self.load_models_to_device(["text_encoder_1", "text_encoder_2"])
|
||||
prompt_emb_posi = self.encode_prompt(prompt, positive=True)
|
||||
if cfg_scale != 1.0:
|
||||
prompt_emb_nega = self.encode_prompt(negative_prompt, positive=False)
|
||||
|
||||
# Denoise
|
||||
self.load_models_to_device(["dit"])
|
||||
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
|
||||
timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device)
|
||||
print(f"Step {progress_id + 1} / {len(self.scheduler.timesteps)}")
|
||||
|
||||
# Inference
|
||||
noise_pred_posi = self.dit(latents, timestep=timestep, **prompt_emb_posi)
|
||||
if cfg_scale != 1.0:
|
||||
noise_pred_nega = self.dit(latents, timestep=timestep, **prompt_emb_nega)
|
||||
noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
|
||||
else:
|
||||
noise_pred = noise_pred_posi
|
||||
|
||||
# Scheduler
|
||||
latents = self.scheduler.step(noise_pred, self.scheduler.timesteps[progress_id], latents)
|
||||
|
||||
# Decode
|
||||
self.load_models_to_device(['vae'])
|
||||
frames = self.vae.decode(latents, device=self.device, smooth_scale=smooth_scale, **tiler_kwargs)
|
||||
self.load_models_to_device([])
|
||||
frames = self.tensor2video(frames[0])
|
||||
|
||||
return frames
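A hedged usage sketch of the Step-Video pipeline, including the num_persistent_param_in_dit knob added by enable_vram_management above. The import path, model layout and frame-rate choice are assumptions:

# Usage sketch (placeholder paths, not part of the diff).
import torch
from diffsynth import ModelManager, StepVideoPipeline, save_video  # import path assumed

model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu")  # load weights to CPU first
model_manager.load_models(["models/stepfun-ai/stepvideo-t2v"])  # placeholder path
pipe = StepVideoPipeline.from_model_manager(model_manager, torch_dtype=torch.bfloat16, device="cuda")
pipe.enable_vram_management(num_persistent_param_in_dit=0)  # 0 keeps no DiT weights resident; raise if VRAM allows

video = pipe(
    prompt="an astronaut riding a horse on the moon",
    num_inference_steps=30,
    num_frames=102,
    cfg_scale=9.0,
    seed=1,
)
save_video(video, "stepvideo.mp4", fps=25)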
@@ -49,9 +49,9 @@ class SVDVideoPipeline(BasePipeline):
|
||||
return image_emb
|
||||
|
||||
|
||||
def encode_image_with_vae(self, image, noise_aug_strength):
|
||||
def encode_image_with_vae(self, image, noise_aug_strength, seed=None):
|
||||
image = self.preprocess_image(image).to(device=self.device, dtype=self.torch_dtype)
|
||||
noise = torch.randn(image.shape, device="cpu", dtype=self.torch_dtype).to(self.device)
|
||||
noise = self.generate_noise(image.shape, seed=seed, device=self.device, dtype=self.torch_dtype)
|
||||
image = image + noise_aug_strength * noise
|
||||
image_emb = self.vae_encoder(image) / self.vae_encoder.scaling_factor
|
||||
return image_emb
|
||||
@@ -126,14 +126,17 @@ class SVDVideoPipeline(BasePipeline):
|
||||
num_inference_steps=20,
|
||||
post_normalize=True,
|
||||
contrast_enhance_scale=1.2,
|
||||
seed=None,
|
||||
progress_bar_cmd=tqdm,
|
||||
progress_bar_st=None,
|
||||
):
|
||||
height, width = self.check_resize_height_width(height, width)
|
||||
|
||||
# Prepare scheduler
|
||||
self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength)
|
||||
|
||||
# Prepare latent tensors
|
||||
noise = torch.randn((num_frames, 4, height//8, width//8), device="cpu", dtype=self.torch_dtype).to(self.device)
|
||||
noise = self.generate_noise((num_frames, 4, height//8, width//8), seed=seed, device=self.device, dtype=self.torch_dtype)
|
||||
if denoising_strength == 1.0:
|
||||
latents = noise.clone()
|
||||
else:
|
||||
@@ -147,7 +150,7 @@ class SVDVideoPipeline(BasePipeline):
|
||||
# Encode image
|
||||
image_emb_clip_posi = self.encode_image_with_clip(input_image)
|
||||
image_emb_clip_nega = torch.zeros_like(image_emb_clip_posi)
|
||||
image_emb_vae_posi = repeat(self.encode_image_with_vae(input_image, noise_aug_strength), "B C H W -> (B T) C H W", T=num_frames)
|
||||
image_emb_vae_posi = repeat(self.encode_image_with_vae(input_image, noise_aug_strength, seed=seed), "B C H W -> (B T) C H W", T=num_frames)
|
||||
image_emb_vae_nega = torch.zeros_like(image_emb_vae_posi)
|
||||
|
||||
# Prepare classifier-free guidance
|
||||
|
||||
@@ -7,3 +7,6 @@ from .kolors_prompter import KolorsPrompter
|
||||
from .flux_prompter import FluxPrompter
|
||||
from .omost import OmostPromter
|
||||
from .cog_prompter import CogPrompter
|
||||
from .hunyuan_video_prompter import HunyuanVideoPrompter
|
||||
from .stepvideo_prompter import StepVideoPrompter
|
||||
from .wanx_prompter import WanXPrompter
|
||||
@@ -1,5 +1,6 @@
|
||||
from .base_prompter import BasePrompter
|
||||
from ..models.flux_text_encoder import FluxTextEncoder1, FluxTextEncoder2
|
||||
from ..models.flux_text_encoder import FluxTextEncoder2
|
||||
from ..models.sd3_text_encoder import SD3TextEncoder1
|
||||
from transformers import CLIPTokenizer, T5TokenizerFast
|
||||
import os, torch
|
||||
|
||||
@@ -19,11 +20,11 @@ class FluxPrompter(BasePrompter):
|
||||
super().__init__()
|
||||
self.tokenizer_1 = CLIPTokenizer.from_pretrained(tokenizer_1_path)
|
||||
self.tokenizer_2 = T5TokenizerFast.from_pretrained(tokenizer_2_path)
|
||||
self.text_encoder_1: FluxTextEncoder1 = None
|
||||
self.text_encoder_1: SD3TextEncoder1 = None
|
||||
self.text_encoder_2: FluxTextEncoder2 = None
|
||||
|
||||
|
||||
def fetch_models(self, text_encoder_1: FluxTextEncoder1 = None, text_encoder_2: FluxTextEncoder2 = None):
|
||||
def fetch_models(self, text_encoder_1: SD3TextEncoder1 = None, text_encoder_2: FluxTextEncoder2 = None):
|
||||
self.text_encoder_1 = text_encoder_1
|
||||
self.text_encoder_2 = text_encoder_2
|
||||
|
||||
@@ -36,7 +37,7 @@ class FluxPrompter(BasePrompter):
|
||||
max_length=max_length,
|
||||
truncation=True
|
||||
).input_ids.to(device)
|
||||
_, pooled_prompt_emb = text_encoder(input_ids)
|
||||
pooled_prompt_emb, _ = text_encoder(input_ids)
|
||||
return pooled_prompt_emb
|
||||
|
||||
|
||||
@@ -49,8 +50,6 @@ class FluxPrompter(BasePrompter):
|
||||
truncation=True,
|
||||
).input_ids.to(device)
|
||||
prompt_emb = text_encoder(input_ids)
|
||||
prompt_emb = prompt_emb.reshape((1, prompt_emb.shape[0]*prompt_emb.shape[1], -1))
|
||||
|
||||
return prompt_emb
|
||||
|
||||
|
||||
@@ -58,7 +57,8 @@ class FluxPrompter(BasePrompter):
|
||||
self,
|
||||
prompt,
|
||||
positive=True,
|
||||
device="cuda"
|
||||
device="cuda",
|
||||
t5_sequence_length=512,
|
||||
):
|
||||
prompt = self.process_prompt(prompt, positive=positive)
|
||||
|
||||
@@ -66,7 +66,7 @@ class FluxPrompter(BasePrompter):
|
||||
pooled_prompt_emb = self.encode_prompt_using_clip(prompt, self.text_encoder_1, self.tokenizer_1, 77, device)
|
||||
|
||||
# T5
|
||||
prompt_emb = self.encode_prompt_using_t5(prompt, self.text_encoder_2, self.tokenizer_2, 256, device)
|
||||
prompt_emb = self.encode_prompt_using_t5(prompt, self.text_encoder_2, self.tokenizer_2, t5_sequence_length, device)
|
||||
|
||||
# text_ids
|
||||
text_ids = torch.zeros(prompt_emb.shape[0], prompt_emb.shape[1], 3).to(device=device, dtype=prompt_emb.dtype)
|
||||
|
||||
diffsynth/prompters/hunyuan_video_prompter.py (new file)
@@ -0,0 +1,143 @@
from .base_prompter import BasePrompter
|
||||
from ..models.sd3_text_encoder import SD3TextEncoder1
|
||||
from ..models.hunyuan_video_text_encoder import HunyuanVideoLLMEncoder
|
||||
from transformers import CLIPTokenizer, LlamaTokenizerFast
|
||||
import os, torch
|
||||
|
||||
PROMPT_TEMPLATE_ENCODE = (
|
||||
"<|start_header_id|>system<|end_header_id|>\n\nDescribe the image by detailing the color, shape, size, texture, "
|
||||
"quantity, text, spatial relationships of the objects and background:<|eot_id|>"
|
||||
"<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>")
|
||||
|
||||
PROMPT_TEMPLATE_ENCODE_VIDEO = (
|
||||
"<|start_header_id|>system<|end_header_id|>\n\nDescribe the video by detailing the following aspects: "
|
||||
"1. The main content and theme of the video."
|
||||
"2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects."
|
||||
"3. Actions, events, behaviors temporal relationships, physical movement changes of the objects."
|
||||
"4. background environment, light, style and atmosphere."
|
||||
"5. camera angles, movements, and transitions used in the video:<|eot_id|>"
|
||||
"<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>")
|
||||
|
||||
PROMPT_TEMPLATE = {
|
||||
"dit-llm-encode": {
|
||||
"template": PROMPT_TEMPLATE_ENCODE,
|
||||
"crop_start": 36,
|
||||
},
|
||||
"dit-llm-encode-video": {
|
||||
"template": PROMPT_TEMPLATE_ENCODE_VIDEO,
|
||||
"crop_start": 95,
|
||||
},
|
||||
}
|
||||
|
||||
NEGATIVE_PROMPT = "Aerial view, aerial view, overexposed, low quality, deformation, a poor composition, bad hands, bad teeth, bad eyes, bad limbs, distortion"
|
||||
|
||||
|
||||
class HunyuanVideoPrompter(BasePrompter):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
tokenizer_1_path=None,
|
||||
tokenizer_2_path=None,
|
||||
):
|
||||
if tokenizer_1_path is None:
|
||||
base_path = os.path.dirname(os.path.dirname(__file__))
|
||||
tokenizer_1_path = os.path.join(
|
||||
base_path, "tokenizer_configs/hunyuan_video/tokenizer_1")
|
||||
if tokenizer_2_path is None:
|
||||
base_path = os.path.dirname(os.path.dirname(__file__))
|
||||
tokenizer_2_path = os.path.join(
|
||||
base_path, "tokenizer_configs/hunyuan_video/tokenizer_2")
|
||||
super().__init__()
|
||||
self.tokenizer_1 = CLIPTokenizer.from_pretrained(tokenizer_1_path)
|
||||
self.tokenizer_2 = LlamaTokenizerFast.from_pretrained(tokenizer_2_path, padding_side='right')
|
||||
self.text_encoder_1: SD3TextEncoder1 = None
|
||||
self.text_encoder_2: HunyuanVideoLLMEncoder = None
|
||||
|
||||
self.prompt_template = PROMPT_TEMPLATE['dit-llm-encode']
|
||||
self.prompt_template_video = PROMPT_TEMPLATE['dit-llm-encode-video']
|
||||
|
||||
def fetch_models(self, text_encoder_1: SD3TextEncoder1 = None, text_encoder_2: HunyuanVideoLLMEncoder = None):
|
||||
self.text_encoder_1 = text_encoder_1
|
||||
self.text_encoder_2 = text_encoder_2
|
||||
|
||||
def apply_text_to_template(self, text, template):
|
||||
assert isinstance(template, str)
|
||||
if isinstance(text, list):
|
||||
return [self.apply_text_to_template(text_, template) for text_ in text]
|
||||
elif isinstance(text, str):
|
||||
# Will send string to tokenizer. Used for llm
|
||||
return template.format(text)
|
||||
else:
|
||||
raise TypeError(f"Unsupported prompt type: {type(text)}")
|
||||
|
||||
def encode_prompt_using_clip(self, prompt, max_length, device):
|
||||
tokenized_result = self.tokenizer_1(
|
||||
prompt,
|
||||
return_tensors="pt",
|
||||
padding="max_length",
|
||||
max_length=max_length,
|
||||
truncation=True,
|
||||
return_attention_mask=True
|
||||
)
|
||||
input_ids = tokenized_result.input_ids.to(device)
|
||||
attention_mask = tokenized_result.attention_mask.to(device)
|
||||
return self.text_encoder_1(input_ids=input_ids, extra_mask=attention_mask)[0]
|
||||
|
||||
def encode_prompt_using_llm(self,
|
||||
prompt,
|
||||
max_length,
|
||||
device,
|
||||
crop_start,
|
||||
hidden_state_skip_layer=2,
|
||||
use_attention_mask=True):
|
||||
max_length += crop_start
|
||||
inputs = self.tokenizer_2(prompt,
|
||||
return_tensors="pt",
|
||||
padding="max_length",
|
||||
max_length=max_length,
|
||||
truncation=True)
|
||||
input_ids = inputs.input_ids.to(device)
|
||||
attention_mask = inputs.attention_mask.to(device)
|
||||
last_hidden_state = self.text_encoder_2(input_ids, attention_mask, hidden_state_skip_layer)
|
||||
|
||||
# crop out
|
||||
if crop_start > 0:
|
||||
last_hidden_state = last_hidden_state[:, crop_start:]
|
||||
attention_mask = (attention_mask[:, crop_start:] if use_attention_mask else None)
|
||||
|
||||
return last_hidden_state, attention_mask
|
||||
|
||||
def encode_prompt(self,
|
||||
prompt,
|
||||
positive=True,
|
||||
device="cuda",
|
||||
clip_sequence_length=77,
|
||||
llm_sequence_length=256,
|
||||
data_type='video',
|
||||
use_template=True,
|
||||
hidden_state_skip_layer=2,
|
||||
use_attention_mask=True):
|
||||
|
||||
prompt = self.process_prompt(prompt, positive=positive)
|
||||
|
||||
# apply template
|
||||
if use_template:
|
||||
template = self.prompt_template_video if data_type == 'video' else self.prompt_template
|
||||
prompt_formated = self.apply_text_to_template(prompt, template['template'])
|
||||
else:
|
||||
prompt_formated = prompt
|
||||
# Text encoder
|
||||
if data_type == 'video':
|
||||
crop_start = self.prompt_template_video.get("crop_start", 0)
|
||||
else:
|
||||
crop_start = self.prompt_template.get("crop_start", 0)
|
||||
|
||||
# CLIP
|
||||
pooled_prompt_emb = self.encode_prompt_using_clip(prompt, clip_sequence_length, device)
|
||||
|
||||
# LLM
|
||||
prompt_emb, attention_mask = self.encode_prompt_using_llm(
|
||||
prompt_formated, llm_sequence_length, device, crop_start,
|
||||
hidden_state_skip_layer, use_attention_mask)
|
||||
|
||||
return prompt_emb, pooled_prompt_emb, attention_mask
|
||||
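# Illustrative aside (not part of the diff): a minimal sketch of how the template and crop_start
# defined in this new file interact; the import path is assumed from the file location, and the
# tokenizer/text encoders are omitted.
from diffsynth.prompters.hunyuan_video_prompter import PROMPT_TEMPLATE
user_prompt = "a cat running on the grass"
template = PROMPT_TEMPLATE["dit-llm-encode-video"]
formatted_prompt = template["template"].format(user_prompt)  # system instruction + user text
crop_start = template["crop_start"]  # the first 95 tokens (the instruction) are later cropped from the LLM hidden states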
@@ -245,6 +245,7 @@ class ChatGLMTokenizer(PreTrainedTokenizer):
|
||||
padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
|
||||
pad_to_multiple_of: Optional[int] = None,
|
||||
return_attention_mask: Optional[bool] = None,
|
||||
padding_side: Optional[str] = None,
|
||||
) -> dict:
|
||||
"""
|
||||
Pad encoded inputs (on left/right and up to predefined length or max length in the batch)
|
||||
|
||||
diffsynth/prompters/omnigen_prompter.py (new file, 356 lines)
@@ -0,0 +1,356 @@
|
||||
import os
|
||||
import re
|
||||
from typing import Dict, List
|
||||
|
||||
import torch
|
||||
from PIL import Image
|
||||
from torchvision import transforms
|
||||
from transformers import AutoTokenizer
|
||||
from huggingface_hub import snapshot_download
|
||||
import numpy as np
|
||||
|
||||
|
||||
|
||||
def crop_arr(pil_image, max_image_size):
|
||||
while min(*pil_image.size) >= 2 * max_image_size:
|
||||
pil_image = pil_image.resize(
|
||||
tuple(x // 2 for x in pil_image.size), resample=Image.BOX
|
||||
)
|
||||
|
||||
if max(*pil_image.size) > max_image_size:
|
||||
scale = max_image_size / max(*pil_image.size)
|
||||
pil_image = pil_image.resize(
|
||||
tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC
|
||||
)
|
||||
|
||||
if min(*pil_image.size) < 16:
|
||||
scale = 16 / min(*pil_image.size)
|
||||
pil_image = pil_image.resize(
|
||||
tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC
|
||||
)
|
||||
|
||||
arr = np.array(pil_image)
|
||||
crop_y1 = (arr.shape[0] % 16) // 2
|
||||
crop_y2 = arr.shape[0] % 16 - crop_y1
|
||||
|
||||
crop_x1 = (arr.shape[1] % 16) // 2
|
||||
crop_x2 = arr.shape[1] % 16 - crop_x1
|
||||
|
||||
arr = arr[crop_y1:arr.shape[0]-crop_y2, crop_x1:arr.shape[1]-crop_x2]
|
||||
return Image.fromarray(arr)
|
||||
|
||||
|
||||
|
||||
class OmniGenPrompter:
|
||||
def __init__(self,
|
||||
text_tokenizer,
|
||||
max_image_size: int=1024):
|
||||
self.text_tokenizer = text_tokenizer
|
||||
self.max_image_size = max_image_size
|
||||
|
||||
self.image_transform = transforms.Compose([
|
||||
transforms.Lambda(lambda pil_image: crop_arr(pil_image, max_image_size)),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)
|
||||
])
|
||||
|
||||
self.collator = OmniGenCollator()
|
||||
self.separate_collator = OmniGenSeparateCollator()
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, model_name):
|
||||
if not os.path.exists(model_name):
|
||||
cache_folder = os.getenv('HF_HUB_CACHE')
|
||||
model_name = snapshot_download(repo_id=model_name,
|
||||
cache_dir=cache_folder,
|
||||
allow_patterns="*.json")
|
||||
text_tokenizer = AutoTokenizer.from_pretrained(model_name)
|
||||
|
||||
return cls(text_tokenizer)
|
||||
|
||||
|
||||
def process_image(self, image):
|
||||
return self.image_transform(image)
|
||||
|
||||
def process_multi_modal_prompt(self, text, input_images):
|
||||
text = self.add_prefix_instruction(text)
|
||||
if input_images is None or len(input_images) == 0:
|
||||
model_inputs = self.text_tokenizer(text)
|
||||
return {"input_ids": model_inputs.input_ids, "pixel_values": None, "image_sizes": None}
|
||||
|
||||
pattern = r"<\|image_\d+\|>"
|
||||
prompt_chunks = [self.text_tokenizer(chunk).input_ids for chunk in re.split(pattern, text)]
|
||||
|
||||
for i in range(1, len(prompt_chunks)):
|
||||
if prompt_chunks[i][0] == 1:
|
||||
prompt_chunks[i] = prompt_chunks[i][1:]
|
||||
|
||||
image_tags = re.findall(pattern, text)
|
||||
image_ids = [int(s.split("|")[1].split("_")[-1]) for s in image_tags]
|
||||
|
||||
unique_image_ids = sorted(list(set(image_ids)))
|
||||
assert unique_image_ids == list(range(1, len(unique_image_ids)+1)), f"image_ids must start from 1, and must be continuous int, e.g. [1, 2, 3], cannot be {unique_image_ids}"
|
||||
# total images must be the same as the number of image tags
|
||||
assert len(unique_image_ids) == len(input_images), f"total images must be the same as the number of image tags, got {len(unique_image_ids)} image tags and {len(input_images)} images"
|
||||
|
||||
input_images = [input_images[x-1] for x in image_ids]
|
||||
|
||||
all_input_ids = []
|
||||
img_inx = []
|
||||
idx = 0
|
||||
for i in range(len(prompt_chunks)):
|
||||
all_input_ids.extend(prompt_chunks[i])
|
||||
if i != len(prompt_chunks) -1:
|
||||
start_inx = len(all_input_ids)
|
||||
size = input_images[i].size(-2) * input_images[i].size(-1) // 16 // 16
|
||||
img_inx.append([start_inx, start_inx+size])
|
||||
all_input_ids.extend([0]*size)
|
||||
|
||||
return {"input_ids": all_input_ids, "pixel_values": input_images, "image_sizes": img_inx}
|
||||
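# Illustrative aside (not part of the diff): a standalone sketch of how the image-tag pattern above
# splits a prompt and recovers the 1-based image indices; runnable with only the standard library.
import re
pattern = r"<\|image_\d+\|>"
text = "Put <|image_1|> on the table next to <|image_2|>."
chunks = re.split(pattern, text)  # text pieces around the placeholders
image_ids = [int(s.split("|")[1].split("_")[-1]) for s in re.findall(pattern, text)]
print(chunks)     # ['Put ', ' on the table next to ', '.']
print(image_ids)  # [1, 2]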
|
||||
|
||||
def add_prefix_instruction(self, prompt):
|
||||
user_prompt = '<|user|>\n'
|
||||
generation_prompt = 'Generate an image according to the following instructions\n'
|
||||
assistant_prompt = '<|assistant|>\n<|diffusion|>'
|
||||
prompt_suffix = "<|end|>\n"
|
||||
prompt = f"{user_prompt}{generation_prompt}{prompt}{prompt_suffix}{assistant_prompt}"
|
||||
return prompt
|
||||
|
||||
|
||||
def __call__(self,
|
||||
instructions: List[str],
|
||||
input_images: List[List[str]] = None,
|
||||
height: int = 1024,
|
||||
width: int = 1024,
|
||||
negative_prompt: str = "low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, cloned face, disfigured, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, fused fingers, too many fingers.",
|
||||
use_img_cfg: bool = True,
|
||||
separate_cfg_input: bool = False,
|
||||
use_input_image_size_as_output: bool=False,
|
||||
) -> Dict:
|
||||
|
||||
if input_images is None:
|
||||
use_img_cfg = False
|
||||
if isinstance(instructions, str):
|
||||
instructions = [instructions]
|
||||
input_images = [input_images]
|
||||
|
||||
input_data = []
|
||||
for i in range(len(instructions)):
|
||||
cur_instruction = instructions[i]
|
||||
cur_input_images = None if input_images is None else input_images[i]
|
||||
if cur_input_images is not None and len(cur_input_images) > 0:
|
||||
cur_input_images = [self.process_image(x) for x in cur_input_images]
|
||||
else:
|
||||
cur_input_images = None
|
||||
assert "<img><|image_1|></img>" not in cur_instruction
|
||||
|
||||
mllm_input = self.process_multi_modal_prompt(cur_instruction, cur_input_images)
|
||||
|
||||
|
||||
neg_mllm_input, img_cfg_mllm_input = None, None
|
||||
neg_mllm_input = self.process_multi_modal_prompt(negative_prompt, None)
|
||||
if use_img_cfg:
|
||||
if cur_input_images is not None and len(cur_input_images) >= 1:
|
||||
img_cfg_prompt = [f"<img><|image_{i+1}|></img>" for i in range(len(cur_input_images))]
|
||||
img_cfg_mllm_input = self.process_multi_modal_prompt(" ".join(img_cfg_prompt), cur_input_images)
|
||||
else:
|
||||
img_cfg_mllm_input = neg_mllm_input
|
||||
|
||||
if use_input_image_size_as_output:
|
||||
input_data.append((mllm_input, neg_mllm_input, img_cfg_mllm_input, [mllm_input['pixel_values'][0].size(-2), mllm_input['pixel_values'][0].size(-1)]))
|
||||
else:
|
||||
input_data.append((mllm_input, neg_mllm_input, img_cfg_mllm_input, [height, width]))
|
||||
|
||||
if separate_cfg_input:
|
||||
return self.separate_collator(input_data)
|
||||
return self.collator(input_data)
|
||||
|
||||
|
||||
|
||||
|
||||
class OmniGenCollator:
|
||||
def __init__(self, pad_token_id=2, hidden_size=3072):
|
||||
self.pad_token_id = pad_token_id
|
||||
self.hidden_size = hidden_size
|
||||
|
||||
def create_position(self, attention_mask, num_tokens_for_output_images):
|
||||
position_ids = []
|
||||
text_length = attention_mask.size(-1)
|
||||
img_length = max(num_tokens_for_output_images)
|
||||
for mask in attention_mask:
|
||||
temp_l = torch.sum(mask)
|
||||
temp_position = [0]*(text_length-temp_l) + [i for i in range(temp_l+img_length+1)] # we add a time embedding into the sequence, so add one more token
|
||||
position_ids.append(temp_position)
|
||||
return torch.LongTensor(position_ids)
|
||||
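# Illustrative aside (not part of the diff): a toy call showing the position layout produced above.
# Left padding stays at position 0, and the extra "+1" slot is the time-embedding token.
collator = OmniGenCollator()
attention_mask = torch.tensor([[0, 1, 1], [1, 1, 1]])  # batch of 2; the first prompt is left-padded
position_ids = collator.create_position(attention_mask, num_tokens_for_output_images=[4, 4])
# position_ids.shape == (2, 3 + 4 + 1)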
|
||||
def create_mask(self, attention_mask, num_tokens_for_output_images):
|
||||
extended_mask = []
|
||||
padding_images = []
|
||||
text_length = attention_mask.size(-1)
|
||||
img_length = max(num_tokens_for_output_images)
|
||||
seq_len = text_length + img_length + 1 # we add a time embedding into the sequence, so add one more token
|
||||
inx = 0
|
||||
for mask in attention_mask:
|
||||
temp_l = torch.sum(mask)
|
||||
pad_l = text_length - temp_l
|
||||
|
||||
temp_mask = torch.tril(torch.ones(size=(temp_l+1, temp_l+1)))
|
||||
|
||||
image_mask = torch.zeros(size=(temp_l+1, img_length))
|
||||
temp_mask = torch.cat([temp_mask, image_mask], dim=-1)
|
||||
|
||||
image_mask = torch.ones(size=(img_length, temp_l+img_length+1))
|
||||
temp_mask = torch.cat([temp_mask, image_mask], dim=0)
|
||||
|
||||
if pad_l > 0:
|
||||
pad_mask = torch.zeros(size=(temp_l+1+img_length, pad_l))
|
||||
temp_mask = torch.cat([pad_mask, temp_mask], dim=-1)
|
||||
|
||||
pad_mask = torch.ones(size=(pad_l, seq_len))
|
||||
temp_mask = torch.cat([pad_mask, temp_mask], dim=0)
|
||||
|
||||
true_img_length = num_tokens_for_output_images[inx]
|
||||
pad_img_length = img_length - true_img_length
|
||||
if pad_img_length > 0:
|
||||
temp_mask[:, -pad_img_length:] = 0
|
||||
temp_padding_imgs = torch.zeros(size=(1, pad_img_length, self.hidden_size))
|
||||
else:
|
||||
temp_padding_imgs = None
|
||||
|
||||
extended_mask.append(temp_mask.unsqueeze(0))
|
||||
padding_images.append(temp_padding_imgs)
|
||||
inx += 1
|
||||
return torch.cat(extended_mask, dim=0), padding_images
|
||||
|
||||
def adjust_attention_for_input_images(self, attention_mask, image_sizes):
|
||||
for b_inx in image_sizes.keys():
|
||||
for start_inx, end_inx in image_sizes[b_inx]:
|
||||
attention_mask[b_inx][start_inx:end_inx, start_inx:end_inx] = 1
|
||||
|
||||
return attention_mask
|
||||
|
||||
def pad_input_ids(self, input_ids, image_sizes):
|
||||
max_l = max([len(x) for x in input_ids])
|
||||
padded_ids = []
|
||||
attention_mask = []
|
||||
new_image_sizes = []
|
||||
|
||||
for i in range(len(input_ids)):
|
||||
temp_ids = input_ids[i]
|
||||
temp_l = len(temp_ids)
|
||||
pad_l = max_l - temp_l
|
||||
if pad_l == 0:
|
||||
attention_mask.append([1]*max_l)
|
||||
padded_ids.append(temp_ids)
|
||||
else:
|
||||
attention_mask.append([0]*pad_l+[1]*temp_l)
|
||||
padded_ids.append([self.pad_token_id]*pad_l+temp_ids)
|
||||
|
||||
if i in image_sizes:
|
||||
new_inx = []
|
||||
for old_inx in image_sizes[i]:
|
||||
new_inx.append([x+pad_l for x in old_inx])
|
||||
image_sizes[i] = new_inx
|
||||
|
||||
return torch.LongTensor(padded_ids), torch.LongTensor(attention_mask), image_sizes
|
||||
|
||||
|
||||
def process_mllm_input(self, mllm_inputs, target_img_size):
|
||||
num_tokens_for_output_images = []
|
||||
for img_size in target_img_size:
|
||||
num_tokens_for_output_images.append(img_size[0]*img_size[1]//16//16)
|
||||
|
||||
pixel_values, image_sizes = [], {}
|
||||
b_inx = 0
|
||||
for x in mllm_inputs:
|
||||
if x['pixel_values'] is not None:
|
||||
pixel_values.extend(x['pixel_values'])
|
||||
for size in x['image_sizes']:
|
||||
if b_inx not in image_sizes:
|
||||
image_sizes[b_inx] = [size]
|
||||
else:
|
||||
image_sizes[b_inx].append(size)
|
||||
b_inx += 1
|
||||
pixel_values = [x.unsqueeze(0) for x in pixel_values]
|
||||
|
||||
|
||||
input_ids = [x['input_ids'] for x in mllm_inputs]
|
||||
padded_input_ids, attention_mask, image_sizes = self.pad_input_ids(input_ids, image_sizes)
|
||||
position_ids = self.create_position(attention_mask, num_tokens_for_output_images)
|
||||
attention_mask, padding_images = self.create_mask(attention_mask, num_tokens_for_output_images)
|
||||
attention_mask = self.adjust_attention_for_input_images(attention_mask, image_sizes)
|
||||
|
||||
return padded_input_ids, position_ids, attention_mask, padding_images, pixel_values, image_sizes
|
||||
|
||||
|
||||
def __call__(self, features):
|
||||
mllm_inputs = [f[0] for f in features]
|
||||
cfg_mllm_inputs = [f[1] for f in features]
|
||||
img_cfg_mllm_input = [f[2] for f in features]
|
||||
target_img_size = [f[3] for f in features]
|
||||
|
||||
|
||||
if img_cfg_mllm_input[0] is not None:
|
||||
mllm_inputs = mllm_inputs + cfg_mllm_inputs + img_cfg_mllm_input
|
||||
target_img_size = target_img_size + target_img_size + target_img_size
|
||||
else:
|
||||
mllm_inputs = mllm_inputs + cfg_mllm_inputs
|
||||
target_img_size = target_img_size + target_img_size
|
||||
|
||||
|
||||
all_padded_input_ids, all_position_ids, all_attention_mask, all_padding_images, all_pixel_values, all_image_sizes = self.process_mllm_input(mllm_inputs, target_img_size)
|
||||
|
||||
data = {"input_ids": all_padded_input_ids,
|
||||
"attention_mask": all_attention_mask,
|
||||
"position_ids": all_position_ids,
|
||||
"input_pixel_values": all_pixel_values,
|
||||
"input_image_sizes": all_image_sizes,
|
||||
"padding_images": all_padding_images,
|
||||
}
|
||||
return data
|
||||
|
||||
|
||||
class OmniGenSeparateCollator(OmniGenCollator):
|
||||
def __call__(self, features):
|
||||
mllm_inputs = [f[0] for f in features]
|
||||
cfg_mllm_inputs = [f[1] for f in features]
|
||||
img_cfg_mllm_input = [f[2] for f in features]
|
||||
target_img_size = [f[3] for f in features]
|
||||
|
||||
all_padded_input_ids, all_attention_mask, all_position_ids, all_pixel_values, all_image_sizes, all_padding_images = [], [], [], [], [], []
|
||||
|
||||
|
||||
padded_input_ids, position_ids, attention_mask, padding_images, pixel_values, image_sizes = self.process_mllm_input(mllm_inputs, target_img_size)
|
||||
all_padded_input_ids.append(padded_input_ids)
|
||||
all_attention_mask.append(attention_mask)
|
||||
all_position_ids.append(position_ids)
|
||||
all_pixel_values.append(pixel_values)
|
||||
all_image_sizes.append(image_sizes)
|
||||
all_padding_images.append(padding_images)
|
||||
|
||||
if cfg_mllm_inputs[0] is not None:
|
||||
padded_input_ids, position_ids, attention_mask, padding_images, pixel_values, image_sizes = self.process_mllm_input(cfg_mllm_inputs, target_img_size)
|
||||
all_padded_input_ids.append(padded_input_ids)
|
||||
all_attention_mask.append(attention_mask)
|
||||
all_position_ids.append(position_ids)
|
||||
all_pixel_values.append(pixel_values)
|
||||
all_image_sizes.append(image_sizes)
|
||||
all_padding_images.append(padding_images)
|
||||
if img_cfg_mllm_input[0] is not None:
|
||||
padded_input_ids, position_ids, attention_mask, padding_images, pixel_values, image_sizes = self.process_mllm_input(img_cfg_mllm_input, target_img_size)
|
||||
all_padded_input_ids.append(padded_input_ids)
|
||||
all_attention_mask.append(attention_mask)
|
||||
all_position_ids.append(position_ids)
|
||||
all_pixel_values.append(pixel_values)
|
||||
all_image_sizes.append(image_sizes)
|
||||
all_padding_images.append(padding_images)
|
||||
|
||||
data = {"input_ids": all_padded_input_ids,
|
||||
"attention_mask": all_attention_mask,
|
||||
"position_ids": all_position_ids,
|
||||
"input_pixel_values": all_pixel_values,
|
||||
"input_image_sizes": all_image_sizes,
|
||||
"padding_images": all_padding_images,
|
||||
}
|
||||
return data
|
||||
@@ -67,7 +67,8 @@ class SD3Prompter(BasePrompter):
|
||||
self,
|
||||
prompt,
|
||||
positive=True,
|
||||
device="cuda"
|
||||
device="cuda",
|
||||
t5_sequence_length=77,
|
||||
):
|
||||
prompt = self.process_prompt(prompt, positive=positive)
|
||||
|
||||
@@ -77,9 +78,9 @@ class SD3Prompter(BasePrompter):
|
||||
|
||||
# T5
|
||||
if self.text_encoder_3 is None:
|
||||
prompt_emb_3 = torch.zeros((prompt_emb_1.shape[0], 256, 4096), dtype=prompt_emb_1.dtype, device=device)
|
||||
prompt_emb_3 = torch.zeros((prompt_emb_1.shape[0], t5_sequence_length, 4096), dtype=prompt_emb_1.dtype, device=device)
|
||||
else:
|
||||
prompt_emb_3 = self.encode_prompt_using_t5(prompt, self.text_encoder_3, self.tokenizer_3, 256, device)
|
||||
prompt_emb_3 = self.encode_prompt_using_t5(prompt, self.text_encoder_3, self.tokenizer_3, t5_sequence_length, device)
|
||||
prompt_emb_3 = prompt_emb_3.to(prompt_emb_1.dtype) # float32 -> float16
|
||||
|
||||
# Merge
|
||||
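# Illustrative aside (not part of the diff): with t5_sequence_length exposed, the encoded T5 embedding
# and its zero placeholder keep matching shapes, e.g. for the default of 77:
#   prompt_emb_3 = torch.zeros((batch_size, 77, 4096))  # instead of the previously hard-coded 256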
|
||||
diffsynth/prompters/stepvideo_prompter.py (new file, 56 lines)
@@ -0,0 +1,56 @@
|
||||
from .base_prompter import BasePrompter
|
||||
from ..models.hunyuan_dit_text_encoder import HunyuanDiTCLIPTextEncoder
|
||||
from ..models.stepvideo_text_encoder import STEP1TextEncoder
|
||||
from transformers import BertTokenizer
|
||||
import os, torch
|
||||
|
||||
|
||||
class StepVideoPrompter(BasePrompter):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
tokenizer_1_path=None,
|
||||
):
|
||||
if tokenizer_1_path is None:
|
||||
base_path = os.path.dirname(os.path.dirname(__file__))
|
||||
tokenizer_1_path = os.path.join(
|
||||
base_path, "tokenizer_configs/hunyuan_dit/tokenizer")
|
||||
super().__init__()
|
||||
self.tokenizer_1 = BertTokenizer.from_pretrained(tokenizer_1_path)
|
||||
|
||||
def fetch_models(self, text_encoder_1: HunyuanDiTCLIPTextEncoder = None, text_encoder_2: STEP1TextEncoder = None):
|
||||
self.text_encoder_1 = text_encoder_1
|
||||
self.text_encoder_2 = text_encoder_2
|
||||
|
||||
def encode_prompt_using_clip(self, prompt, max_length, device):
|
||||
text_inputs = self.tokenizer_1(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
max_length=max_length,
|
||||
truncation=True,
|
||||
return_attention_mask=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
prompt_embeds = self.text_encoder_1(
|
||||
text_inputs.input_ids.to(device),
|
||||
attention_mask=text_inputs.attention_mask.to(device),
|
||||
)
|
||||
return prompt_embeds
|
||||
|
||||
def encode_prompt_using_llm(self, prompt, max_length, device):
|
||||
y, y_mask = self.text_encoder_2(prompt, max_length=max_length, device=device)
|
||||
return y, y_mask
|
||||
|
||||
def encode_prompt(self,
|
||||
prompt,
|
||||
positive=True,
|
||||
device="cuda"):
|
||||
|
||||
prompt = self.process_prompt(prompt, positive=positive)
|
||||
|
||||
clip_embeds = self.encode_prompt_using_clip(prompt, max_length=77, device=device)
|
||||
llm_embeds, llm_mask = self.encode_prompt_using_llm(prompt, max_length=320, device=device)
|
||||
|
||||
llm_mask = torch.nn.functional.pad(llm_mask, (clip_embeds.shape[1], 0), value=1)  # prepend ones so the mask also covers the CLIP tokens placed before the LLM tokens
|
||||
|
||||
return clip_embeds, llm_embeds, llm_mask
|
||||
diffsynth/prompters/wanx_prompter.py (new file, 103 lines)
@@ -0,0 +1,103 @@
|
||||
from .base_prompter import BasePrompter
|
||||
from ..models.wanx_text_encoder import WanXTextEncoder
|
||||
from transformers import AutoTokenizer
|
||||
import os, torch
|
||||
import ftfy
|
||||
import html
|
||||
import string
|
||||
|
||||
import regex as re
|
||||
|
||||
|
||||
def basic_clean(text):
|
||||
text = ftfy.fix_text(text)
|
||||
text = html.unescape(html.unescape(text))
|
||||
return text.strip()
|
||||
|
||||
def whitespace_clean(text):
|
||||
text = re.sub(r'\s+', ' ', text)
|
||||
text = text.strip()
|
||||
return text
|
||||
|
||||
def canonicalize(text, keep_punctuation_exact_string=None):
|
||||
text = text.replace('_', ' ')
|
||||
if keep_punctuation_exact_string:
|
||||
text = keep_punctuation_exact_string.join(
|
||||
part.translate(str.maketrans('', '', string.punctuation))
|
||||
for part in text.split(keep_punctuation_exact_string))
|
||||
else:
|
||||
text = text.translate(str.maketrans('', '', string.punctuation))
|
||||
text = text.lower()
|
||||
text = re.sub(r'\s+', ' ', text)
|
||||
return text.strip()
|
||||
|
||||
class HuggingfaceTokenizer:
|
||||
|
||||
def __init__(self, name, seq_len=None, clean=None, **kwargs):
|
||||
assert clean in (None, 'whitespace', 'lower', 'canonicalize')
|
||||
self.name = name
|
||||
self.seq_len = seq_len
|
||||
self.clean = clean
|
||||
|
||||
# init tokenizer
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(name, **kwargs)
|
||||
self.vocab_size = self.tokenizer.vocab_size
|
||||
|
||||
def __call__(self, sequence, **kwargs):
|
||||
return_mask = kwargs.pop('return_mask', False)
|
||||
|
||||
# arguments
|
||||
_kwargs = {'return_tensors': 'pt'}
|
||||
if self.seq_len is not None:
|
||||
_kwargs.update({
|
||||
'padding': 'max_length',
|
||||
'truncation': True,
|
||||
'max_length': self.seq_len
|
||||
})
|
||||
_kwargs.update(**kwargs)
|
||||
|
||||
# tokenization
|
||||
if isinstance(sequence, str):
|
||||
sequence = [sequence]
|
||||
if self.clean:
|
||||
sequence = [self._clean(u) for u in sequence]
|
||||
ids = self.tokenizer(sequence, **_kwargs)
|
||||
|
||||
# output
|
||||
if return_mask:
|
||||
return ids.input_ids, ids.attention_mask
|
||||
else:
|
||||
return ids.input_ids
|
||||
|
||||
def _clean(self, text):
|
||||
if self.clean == 'whitespace':
|
||||
text = whitespace_clean(basic_clean(text))
|
||||
elif self.clean == 'lower':
|
||||
text = whitespace_clean(basic_clean(text)).lower()
|
||||
elif self.clean == 'canonicalize':
|
||||
text = canonicalize(basic_clean(text))
|
||||
return text
|
||||
|
||||
class WanXPrompter(BasePrompter):
|
||||
|
||||
def __init__(self, tokenizer_path=None, text_len=512):
|
||||
if tokenizer_path is None:
|
||||
base_path = os.path.dirname(os.path.dirname(__file__))
|
||||
tokenizer_path = os.path.join(
|
||||
base_path, "tokenizer_configs/hunyuan_dit/tokenizer")
|
||||
super().__init__()
|
||||
self.tokenizer = HuggingfaceTokenizer(name=tokenizer_path, seq_len=text_len, clean='whitespace')
|
||||
self.text_encoder = None
|
||||
|
||||
def fetch_models(self, text_encoder: WanXTextEncoder = None):
|
||||
self.text_encoder = text_encoder
|
||||
|
||||
def encode_prompt(self, prompt, device="cuda"):
|
||||
ids, mask = self.tokenizer(prompt, return_mask=True, add_special_tokens=True)
|
||||
ids = ids.to(device)
|
||||
mask = mask.to(device)
|
||||
seq_lens = mask.gt(0).sum(dim=1).long()
|
||||
prompt_emb = self.text_encoder(ids, mask)
|
||||
prompt_emb = [u[:v] for u, v in zip(prompt_emb, seq_lens)]
|
||||
return prompt_emb
|
||||
|
||||
@@ -10,7 +10,7 @@ class ContinuousODEScheduler():
|
||||
self.set_timesteps(num_inference_steps)
|
||||
|
||||
|
||||
def set_timesteps(self, num_inference_steps=100, denoising_strength=1.0):
|
||||
def set_timesteps(self, num_inference_steps=100, denoising_strength=1.0, **kwargs):
|
||||
ramp = torch.linspace(1-denoising_strength, 1, num_inference_steps)
|
||||
min_inv_rho = torch.pow(torch.tensor((self.sigma_min,)), (1 / self.rho))
|
||||
max_inv_rho = torch.pow(torch.tensor((self.sigma_max,)), (1 / self.rho))
|
||||
|
||||
@@ -38,7 +38,7 @@ class EnhancedDDIMScheduler():
|
||||
return alphas_bar
|
||||
|
||||
|
||||
def set_timesteps(self, num_inference_steps, denoising_strength=1.0):
|
||||
def set_timesteps(self, num_inference_steps, denoising_strength=1.0, **kwargs):
|
||||
# The timesteps are aligned to 999...0, which is different from other implementations,
|
||||
# but I think this implementation is more reasonable in theory.
|
||||
max_timestep = max(round(self.num_train_timesteps * denoising_strength) - 1, 0)
|
||||
@@ -99,3 +99,7 @@ class EnhancedDDIMScheduler():
|
||||
sqrt_one_minus_alpha_prod = math.sqrt(1 - self.alphas_cumprod[int(timestep.flatten().tolist()[0])])
|
||||
target = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample
|
||||
return target
|
||||
|
||||
|
||||
def training_weight(self, timestep):
|
||||
return 1.0
|
||||
|
||||
@@ -4,19 +4,35 @@ import torch
|
||||
|
||||
class FlowMatchScheduler():
|
||||
|
||||
def __init__(self, num_inference_steps=100, num_train_timesteps=1000, shift=3.0, sigma_max=1.0, sigma_min=0.003/1.002):
|
||||
def __init__(self, num_inference_steps=100, num_train_timesteps=1000, shift=3.0, sigma_max=1.0, sigma_min=0.003/1.002, inverse_timesteps=False, extra_one_step=False, reverse_sigmas=False):
|
||||
self.num_train_timesteps = num_train_timesteps
|
||||
self.shift = shift
|
||||
self.sigma_max = sigma_max
|
||||
self.sigma_min = sigma_min
|
||||
self.inverse_timesteps = inverse_timesteps
|
||||
self.extra_one_step = extra_one_step
|
||||
self.reverse_sigmas = reverse_sigmas
|
||||
self.set_timesteps(num_inference_steps)
|
||||
|
||||
|
||||
def set_timesteps(self, num_inference_steps=100, denoising_strength=1.0):
|
||||
def set_timesteps(self, num_inference_steps=100, denoising_strength=1.0, training=False):
|
||||
sigma_start = self.sigma_min + (self.sigma_max - self.sigma_min) * denoising_strength
|
||||
self.sigmas = torch.linspace(sigma_start, self.sigma_min, num_inference_steps)
|
||||
if self.extra_one_step:
|
||||
self.sigmas = torch.linspace(sigma_start, self.sigma_min, num_inference_steps + 1)[:-1]
|
||||
else:
|
||||
self.sigmas = torch.linspace(sigma_start, self.sigma_min, num_inference_steps)
|
||||
if self.inverse_timesteps:
|
||||
self.sigmas = torch.flip(self.sigmas, dims=[0])
|
||||
self.sigmas = self.shift * self.sigmas / (1 + (self.shift - 1) * self.sigmas)
|
||||
if self.reverse_sigmas:
|
||||
self.sigmas = 1 - self.sigmas
|
||||
self.timesteps = self.sigmas * self.num_train_timesteps
|
||||
if training:
|
||||
x = self.timesteps
|
||||
y = torch.exp(-2 * ((x - num_inference_steps / 2) / num_inference_steps) ** 2)
|
||||
y_shifted = y - y.min()
|
||||
bsmntw_weighing = y_shifted * (num_inference_steps / y_shifted.sum())
|
||||
self.linear_timesteps_weights = bsmntw_weighing
|
||||
|
||||
|
||||
def step(self, model_output, timestep, sample, to_final=False):
|
||||
@@ -25,7 +41,7 @@ class FlowMatchScheduler():
|
||||
timestep_id = torch.argmin((self.timesteps - timestep).abs())
|
||||
sigma = self.sigmas[timestep_id]
|
||||
if to_final or timestep_id + 1 >= len(self.timesteps):
|
||||
sigma_ = 0
|
||||
sigma_ = 1 if (self.inverse_timesteps or self.reverse_sigmas) else 0
|
||||
else:
|
||||
sigma_ = self.sigmas[timestep_id + 1]
|
||||
prev_sample = sample + model_output * (sigma_ - sigma)
|
||||
@@ -33,8 +49,12 @@ class FlowMatchScheduler():
|
||||
|
||||
|
||||
def return_to_timestep(self, timestep, sample, sample_stablized):
|
||||
# This scheduler doesn't support this function.
|
||||
pass
|
||||
if isinstance(timestep, torch.Tensor):
|
||||
timestep = timestep.cpu()
|
||||
timestep_id = torch.argmin((self.timesteps - timestep).abs())
|
||||
sigma = self.sigmas[timestep_id]
|
||||
model_output = (sample - sample_stablized) / sigma
|
||||
return model_output
|
||||
|
||||
|
||||
def add_noise(self, original_samples, noise, timestep):
|
||||
@@ -49,3 +69,9 @@ class FlowMatchScheduler():
|
||||
def training_target(self, sample, noise, timestep):
|
||||
target = noise - sample
|
||||
return target
|
||||
|
||||
|
||||
def training_weight(self, timestep):
|
||||
timestep_id = torch.argmin((self.timesteps - timestep.to(self.timesteps.device)).abs())
|
||||
weights = self.linear_timesteps_weights[timestep_id]
|
||||
return weights
|
||||
|
||||
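# Illustrative aside (not part of the diff): how the new FlowMatchScheduler options are used together
# during training; the import path is an assumption based on the repository layout.
from diffsynth.schedulers.flow_match import FlowMatchScheduler
scheduler = FlowMatchScheduler(shift=3.0, extra_one_step=True)
scheduler.set_timesteps(num_inference_steps=1000, training=True)  # also builds the bell-shaped loss weights
timestep = scheduler.timesteps[500]
weight = scheduler.training_weight(timestep)  # per-timestep weight multiplied into the MSE loss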
diffsynth/tokenizer_configs/hunyuan_video/tokenizer_1/merges.txt (new file, 48895 lines; diff suppressed because it is too large)
@@ -0,0 +1,30 @@
|
||||
{
|
||||
"bos_token": {
|
||||
"content": "<|startoftext|>",
|
||||
"lstrip": false,
|
||||
"normalized": true,
|
||||
"rstrip": false,
|
||||
"single_word": false
|
||||
},
|
||||
"eos_token": {
|
||||
"content": "<|endoftext|>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false
|
||||
},
|
||||
"pad_token": {
|
||||
"content": "<|endoftext|>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false
|
||||
},
|
||||
"unk_token": {
|
||||
"content": "<|endoftext|>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,30 @@
|
||||
{
|
||||
"add_prefix_space": false,
|
||||
"added_tokens_decoder": {
|
||||
"49406": {
|
||||
"content": "<|startoftext|>",
|
||||
"lstrip": false,
|
||||
"normalized": true,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"49407": {
|
||||
"content": "<|endoftext|>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
}
|
||||
},
|
||||
"bos_token": "<|startoftext|>",
|
||||
"clean_up_tokenization_spaces": true,
|
||||
"do_lower_case": true,
|
||||
"eos_token": "<|endoftext|>",
|
||||
"errors": "replace",
|
||||
"model_max_length": 77,
|
||||
"pad_token": "<|endoftext|>",
|
||||
"tokenizer_class": "CLIPTokenizer",
|
||||
"unk_token": "<|endoftext|>"
|
||||
}
|
||||
diffsynth/tokenizer_configs/hunyuan_video/tokenizer_1/vocab.json (new file, 49410 lines; diff suppressed because it is too large)
@@ -0,0 +1,30 @@
|
||||
{
|
||||
"bos_token": {
|
||||
"content": "<|begin_of_text|>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false
|
||||
},
|
||||
"eos_token": {
|
||||
"content": "<|end_of_text|>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false
|
||||
},
|
||||
"pad_token": {
|
||||
"content": "<pad>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false
|
||||
},
|
||||
"unk_token": {
|
||||
"content": "<unk>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false
|
||||
}
|
||||
}
|
||||
diffsynth/tokenizer_configs/hunyuan_video/tokenizer_2/tokenizer.json (new file, 1251020 lines; diff suppressed because it is too large)
(one additional large file diff suppressed)
@@ -3,6 +3,7 @@ from peft import LoraConfig, inject_adapter_in_model
|
||||
import torch, os
|
||||
from ..data.simple_text_image import TextImageDataset
|
||||
from modelscope.hub.api import HubApi
|
||||
from ..models.utils import load_state_dict
|
||||
|
||||
|
||||
|
||||
@@ -11,11 +12,14 @@ class LightningModelForT2ILoRA(pl.LightningModule):
|
||||
self,
|
||||
learning_rate=1e-4,
|
||||
use_gradient_checkpointing=True,
|
||||
state_dict_converter=None,
|
||||
):
|
||||
super().__init__()
|
||||
# Set parameters
|
||||
self.learning_rate = learning_rate
|
||||
self.use_gradient_checkpointing = use_gradient_checkpointing
|
||||
self.state_dict_converter = state_dict_converter
|
||||
self.lora_alpha = None
|
||||
|
||||
|
||||
def load_models(self):
|
||||
@@ -30,12 +34,16 @@ class LightningModelForT2ILoRA(pl.LightningModule):
|
||||
self.pipe.denoising_model().train()
|
||||
|
||||
|
||||
def add_lora_to_model(self, model, lora_rank=4, lora_alpha=4, lora_target_modules="to_q,to_k,to_v,to_out"):
|
||||
def add_lora_to_model(self, model, lora_rank=4, lora_alpha=4, lora_target_modules="to_q,to_k,to_v,to_out", init_lora_weights="gaussian", pretrained_lora_path=None, state_dict_converter=None):
|
||||
# Add LoRA to UNet
|
||||
self.lora_alpha = lora_alpha
|
||||
if init_lora_weights == "kaiming":
|
||||
init_lora_weights = True
|
||||
|
||||
lora_config = LoraConfig(
|
||||
r=lora_rank,
|
||||
lora_alpha=lora_alpha,
|
||||
init_lora_weights="gaussian",
|
||||
init_lora_weights=init_lora_weights,
|
||||
target_modules=lora_target_modules.split(","),
|
||||
)
|
||||
model = inject_adapter_in_model(lora_config, model)
|
||||
@@ -44,6 +52,17 @@ class LightningModelForT2ILoRA(pl.LightningModule):
|
||||
if param.requires_grad:
|
||||
param.data = param.to(torch.float32)
|
||||
|
||||
# Load pretrained LoRA weights
|
||||
if pretrained_lora_path is not None:
|
||||
state_dict = load_state_dict(pretrained_lora_path)
|
||||
if state_dict_converter is not None:
|
||||
state_dict = state_dict_converter(state_dict)
|
||||
missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
|
||||
all_keys = [i for i, _ in model.named_parameters()]
|
||||
num_updated_keys = len(all_keys) - len(missing_keys)
|
||||
num_unexpected_keys = len(unexpected_keys)
|
||||
print(f"{num_updated_keys} parameters are loaded from {pretrained_lora_path}. {num_unexpected_keys} parameters are unexpected.")
|
||||
|
||||
|
||||
def training_step(self, batch, batch_idx):
|
||||
# Data
|
||||
@@ -52,7 +71,10 @@ class LightningModelForT2ILoRA(pl.LightningModule):
|
||||
# Prepare input parameters
|
||||
self.pipe.device = self.device
|
||||
prompt_emb = self.pipe.encode_prompt(text, positive=True)
|
||||
latents = self.pipe.vae_encoder(image.to(dtype=self.pipe.torch_dtype, device=self.device))
|
||||
if "latents" in batch:
|
||||
latents = batch["latents"].to(dtype=self.pipe.torch_dtype, device=self.device)
|
||||
else:
|
||||
latents = self.pipe.vae_encoder(image.to(dtype=self.pipe.torch_dtype, device=self.device))
|
||||
noise = torch.randn_like(latents)
|
||||
timestep_id = torch.randint(0, self.pipe.scheduler.num_train_timesteps, (1,))
|
||||
timestep = self.pipe.scheduler.timesteps[timestep_id].to(self.device)
|
||||
@@ -65,7 +87,8 @@ class LightningModelForT2ILoRA(pl.LightningModule):
|
||||
noisy_latents, timestep=timestep, **prompt_emb, **extra_input,
|
||||
use_gradient_checkpointing=self.use_gradient_checkpointing
|
||||
)
|
||||
loss = torch.nn.functional.mse_loss(noise_pred, training_target)
|
||||
loss = torch.nn.functional.mse_loss(noise_pred.float(), training_target.float())
|
||||
loss = loss * self.pipe.scheduler.training_weight(timestep)
|
||||
|
||||
# Record log
|
||||
self.log("train_loss", loss, prog_bar=True)
|
||||
@@ -83,9 +106,13 @@ class LightningModelForT2ILoRA(pl.LightningModule):
|
||||
trainable_param_names = list(filter(lambda named_param: named_param[1].requires_grad, self.pipe.denoising_model().named_parameters()))
|
||||
trainable_param_names = set([named_param[0] for named_param in trainable_param_names])
|
||||
state_dict = self.pipe.denoising_model().state_dict()
|
||||
lora_state_dict = {}
|
||||
for name, param in state_dict.items():
|
||||
if name in trainable_param_names:
|
||||
checkpoint[name] = param
|
||||
lora_state_dict[name] = param
|
||||
if self.state_dict_converter is not None:
|
||||
lora_state_dict = self.state_dict_converter(lora_state_dict, alpha=self.lora_alpha)
|
||||
checkpoint.update(lora_state_dict)
|
||||
|
||||
|
||||
|
||||
@@ -173,6 +200,13 @@ def add_general_parsers(parser):
|
||||
default=4.0,
|
||||
help="The weight of the LoRA update matrices.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--init_lora_weights",
|
||||
type=str,
|
||||
default="kaiming",
|
||||
choices=["gaussian", "kaiming"],
|
||||
help="The initializing method of LoRA weight.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--use_gradient_checkpointing",
|
||||
default=False,
|
||||
@@ -210,6 +244,12 @@ def add_general_parsers(parser):
|
||||
default=None,
|
||||
help="Access key on ModelScope (https://www.modelscope.cn/). Required if you want to upload the model to ModelScope.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pretrained_lora_path",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Pretrained LoRA path. Required if the training is resumed.",
|
||||
)
|
||||
return parser
|
||||
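# Illustrative aside (not part of the diff): the two new arguments added above (--init_lora_weights,
# --pretrained_lora_path) let a run resume from an existing LoRA, e.g. by appending
#   --init_lora_weights kaiming --pretrained_lora_path models/lora/epoch=9.ckpt
# to a training script that calls add_general_parsers(); the checkpoint path here is hypothetical.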
|
||||
|
||||
|
||||
diffsynth/vram_management/__init__.py (new file, 1 line)
@@ -0,0 +1 @@
|
||||
from .layers import *
|
||||
diffsynth/vram_management/layers.py (new file, 95 lines)
@@ -0,0 +1,95 @@
|
||||
import torch, copy
|
||||
from ..models.utils import init_weights_on_device
|
||||
|
||||
|
||||
def cast_to(weight, dtype, device):
|
||||
r = torch.empty_like(weight, dtype=dtype, device=device)
|
||||
r.copy_(weight)
|
||||
return r
|
||||
|
||||
|
||||
class AutoWrappedModule(torch.nn.Module):
|
||||
def __init__(self, module: torch.nn.Module, offload_dtype, offload_device, onload_dtype, onload_device, computation_dtype, computation_device):
|
||||
super().__init__()
|
||||
self.module = module.to(dtype=offload_dtype, device=offload_device)
|
||||
self.offload_dtype = offload_dtype
|
||||
self.offload_device = offload_device
|
||||
self.onload_dtype = onload_dtype
|
||||
self.onload_device = onload_device
|
||||
self.computation_dtype = computation_dtype
|
||||
self.computation_device = computation_device
|
||||
self.state = 0
|
||||
|
||||
def offload(self):
|
||||
if self.state == 1 and (self.offload_dtype != self.onload_dtype or self.offload_device != self.onload_device):
|
||||
self.module.to(dtype=self.offload_dtype, device=self.offload_device)
|
||||
self.state = 0
|
||||
|
||||
def onload(self):
|
||||
if self.state == 0 and (self.offload_dtype != self.onload_dtype or self.offload_device != self.onload_device):
|
||||
self.module.to(dtype=self.onload_dtype, device=self.onload_device)
|
||||
self.state = 1
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
if self.onload_dtype == self.computation_dtype and self.onload_device == self.computation_device:
|
||||
module = self.module
|
||||
else:
|
||||
module = copy.deepcopy(self.module).to(dtype=self.computation_dtype, device=self.computation_device)
|
||||
return module(*args, **kwargs)
|
||||
|
||||
|
||||
class AutoWrappedLinear(torch.nn.Linear):
|
||||
def __init__(self, module: torch.nn.Linear, offload_dtype, offload_device, onload_dtype, onload_device, computation_dtype, computation_device):
|
||||
with init_weights_on_device(device=torch.device("meta")):
|
||||
super().__init__(in_features=module.in_features, out_features=module.out_features, bias=module.bias is not None, dtype=offload_dtype, device=offload_device)
|
||||
self.weight = module.weight
|
||||
self.bias = module.bias
|
||||
self.offload_dtype = offload_dtype
|
||||
self.offload_device = offload_device
|
||||
self.onload_dtype = onload_dtype
|
||||
self.onload_device = onload_device
|
||||
self.computation_dtype = computation_dtype
|
||||
self.computation_device = computation_device
|
||||
self.state = 0
|
||||
|
||||
def offload(self):
|
||||
if self.state == 1 and (self.offload_dtype != self.onload_dtype or self.offload_device != self.onload_device):
|
||||
self.to(dtype=self.offload_dtype, device=self.offload_device)
|
||||
self.state = 0
|
||||
|
||||
def onload(self):
|
||||
if self.state == 0 and (self.offload_dtype != self.onload_dtype or self.offload_device != self.onload_device):
|
||||
self.to(dtype=self.onload_dtype, device=self.onload_device)
|
||||
self.state = 1
|
||||
|
||||
def forward(self, x, *args, **kwargs):
|
||||
if self.onload_dtype == self.computation_dtype and self.onload_device == self.computation_device:
|
||||
weight, bias = self.weight, self.bias
|
||||
else:
|
||||
weight = cast_to(self.weight, self.computation_dtype, self.computation_device)
|
||||
bias = None if self.bias is None else cast_to(self.bias, self.computation_dtype, self.computation_device)
|
||||
return torch.nn.functional.linear(x, weight, bias)
|
||||
|
||||
|
||||
def enable_vram_management_recursively(model: torch.nn.Module, module_map: dict, module_config: dict, max_num_param=None, overflow_module_config: dict = None, total_num_param=0):
|
||||
for name, module in model.named_children():
|
||||
for source_module, target_module in module_map.items():
|
||||
if isinstance(module, source_module):
|
||||
num_param = sum(p.numel() for p in module.parameters())
|
||||
if max_num_param is not None and total_num_param + num_param > max_num_param:
|
||||
module_config_ = overflow_module_config
|
||||
else:
|
||||
module_config_ = module_config
|
||||
module_ = target_module(module, **module_config_)
|
||||
setattr(model, name, module_)
|
||||
total_num_param += num_param
|
||||
break
|
||||
else:
|
||||
total_num_param = enable_vram_management_recursively(module, module_map, module_config, max_num_param, overflow_module_config, total_num_param)
|
||||
return total_num_param
|
||||
|
||||
|
||||
def enable_vram_management(model: torch.nn.Module, module_map: dict, module_config: dict, max_num_param=None, overflow_module_config: dict = None):
|
||||
enable_vram_management_recursively(model, module_map, module_config, max_num_param, overflow_module_config, total_num_param=0)
|
||||
model.vram_management_enabled = True
|
||||
|
||||
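# Illustrative aside (not part of the diff): a minimal sketch of wrapping a small model with the layers
# above; any torch.nn.Module works, and the dtype/device choices here are placeholder assumptions.
import torch
from diffsynth.vram_management import enable_vram_management, AutoWrappedModule, AutoWrappedLinear

model = torch.nn.Sequential(torch.nn.Linear(64, 64), torch.nn.LayerNorm(64))
enable_vram_management(
    model,
    module_map={torch.nn.Linear: AutoWrappedLinear, torch.nn.LayerNorm: AutoWrappedModule},
    module_config=dict(
        offload_dtype=torch.float32, offload_device="cpu",           # where weights rest
        onload_dtype=torch.float32, onload_device="cpu",             # where weights are staged
        computation_dtype=torch.float32, computation_device="cpu",   # cast target for forward passes
    ),
)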
examples/ArtAug/README.md (new file, 43 lines)
@@ -0,0 +1,43 @@
|
||||
# FLUX Aesthetics Enhancement LoRA
|
||||
|
||||
## Introduction
|
||||
|
||||
This is a LoRA model trained for FLUX.1-dev, which enhances the aesthetic quality of images generated by the model. The improvements include, but are not limited to: rich details, beautiful lighting and shadows, aesthetic composition, and clear visuals. This model does not require any trigger words.
|
||||
|
||||
* Paper: https://arxiv.org/abs/2412.12888
|
||||
* Github: https://github.com/modelscope/DiffSynth-Studio
|
||||
* Model: [ModelScope](https://www.modelscope.cn/models/DiffSynth-Studio/ArtAug-lora-FLUX.1dev-v1), [HuggingFace](https://huggingface.co/ECNU-CILab/ArtAug-lora-FLUX.1dev-v1)
|
||||
* Demo: [ModelScope](https://modelscope.cn/aigc/imageGeneration?tab=advanced&versionId=7228&modelType=LoRA&sdVersion=FLUX_1&modelUrl=modelscope%3A%2F%2FDiffSynth-Studio%2FArtAug-lora-FLUX.1dev-v1%3Frevision%3Dv1.0), HuggingFace (Coming soon)
|
||||
|
||||
## Methodology
|
||||
|
||||

|
||||
|
||||
The ArtAug project is inspired by reasoning approaches like GPT-o1, which rely on model interaction and self-correction. We developed a framework aimed at enhancing the capabilities of image generation models through interaction with image understanding models. The training process of ArtAug consists of the following steps:
|
||||
|
||||
1. **Synthesis-Understanding Interaction**: After generating an image using the image generation model, we employ a multimodal large language model (Qwen2-VL-72B) to analyze the image content and provide suggestions for modifications, which then lead to the regeneration of a higher quality image.
|
||||
|
||||
2. **Data Generation and Filtering**: Interactive generation involves long inference times and sometimes produces poor image content. Therefore, we generate a large batch of image pairs offline, filter them, and use them for subsequent training.
|
||||
|
||||
3. **Differential Training**: We apply differential training techniques to train a LoRA model, enabling it to learn the differences between images before and after enhancement, rather than directly training on the dataset of enhanced images.
|
||||
|
||||
4. **Iterative Enhancement**: The trained LoRA model is fused into the base model, and the entire process is repeated multiple times with the fused model until the interaction algorithm no longer provides significant enhancements. The LoRA models produced in each iteration are combined to produce this final model.
|
||||
|
||||
This model integrates the aesthetic understanding of Qwen2-VL-72B into FLUX.1[dev], leading to an improvement in the quality of generated images.
|
||||
|
||||
## Usage
|
||||
|
||||
Please see [./artaug_flux.py](./artaug_flux.py) for more details.
|
||||
|
||||
Since this model is encapsulated in the universal FLUX LoRA format, it can be loaded by most LoRA loaders, allowing you to integrate this LoRA model into your own workflow.
|
||||
|
||||
## Examples
|
||||
|
||||
|FLUX.1-dev|FLUX.1-dev + ArtAug LoRA|
|
||||
|-|-|
|
||||
|||
|
||||
|||
|
||||
|||
|
||||
|||
|
||||
|||
|
||||
|||
|
||||
examples/ArtAug/artaug_flux.py (new file, 14 lines)
@@ -0,0 +1,14 @@
|
||||
import torch
|
||||
from diffsynth import ModelManager, FluxImagePipeline, download_customized_models
|
||||
|
||||
lora_path = download_customized_models(
|
||||
model_id="DiffSynth-Studio/ArtAug-lora-FLUX.1dev-v1",
|
||||
origin_file_path="merged_lora.safetensors",
|
||||
local_dir="models/lora"
|
||||
)[0]
|
||||
model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cuda", model_id_list=["FLUX.1-dev"])
|
||||
model_manager.load_lora(lora_path, lora_alpha=1.0)
|
||||
pipe = FluxImagePipeline.from_model_manager(model_manager)
|
||||
|
||||
image = pipe(prompt="a house", seed=0)
|
||||
image.save("image_artaug.jpg")
|
||||
examples/ControlNet/README.md (new file, 91 lines)
@@ -0,0 +1,91 @@
|
||||
# ControlNet
|
||||
|
||||
We provide extensive ControlNet support. Taking the FLUX model as an example, we support many different ControlNet models that can be freely combined, even if their structures differ. Additionally, ControlNet models are compatible with high-resolution refinement and partition control techniques, enabling very powerful controllable image generation.
|
||||
|
||||
These examples are in [`flux_controlnet.py`](./flux_controlnet.py).
|
||||
|
||||
## Canny/Depth/Normal: Structure Control
|
||||
|
||||
Structural control is the most fundamental capability of the ControlNet model. By using Canny to extract edge information, or by utilizing depth maps and normal maps, we can extract the structure of an image, which can then serve as control information during the image generation process.
|
||||
|
||||
Model link: https://modelscope.cn/models/InstantX/FLUX.1-dev-Controlnet-Union-alpha
|
||||
|
||||
For example, if we generate an image of a cat and use a model like InstantX/FLUX.1-dev-Controlnet-Union-alpha that supports multiple control conditions, we can simultaneously enable both Canny and Depth controls to transform the environment into a twilight setting.
|
||||
|
||||
|||
|
||||
|-|-|
|
||||
|
||||
The control strength of ControlNet for structure can be adjusted. For example, in the case below, when we move the girl from summer to winter, we can appropriately lower the control strength of ControlNet so that the model will adapt to the content of the image and change her into warm clothes.
|
||||
|
||||
|||
|
||||
|-|-|
|
||||
|
||||
## Upscaler/Tile/Blur: High-Resolution Image Synthesis
|
||||
|
||||
There are many ControlNet models that support high-resolution refinement, such as:
|
||||
|
||||
Model links: https://modelscope.cn/models/jasperai/Flux.1-dev-Controlnet-Upscaler, https://modelscope.cn/models/InstantX/FLUX.1-dev-Controlnet-Union-alpha, https://modelscope.cn/models/Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro
|
||||
|
||||
These models can transform blurry, noisy, low-quality images into clear ones. DiffSynth-Studio's native tiled high-resolution processing overcomes the resolution limits of these models, enabling image generation at resolutions of 2048 or even higher and significantly extending their capabilities. In the example below, in the high-definition image enlarged to 2048 resolution, the cat's fur is rendered in exquisite detail and the skin texture of the characters is delicate and realistic.
|
||||
|
||||
|||
|
||||
|-|-|
|
||||
|
||||
|||
|
||||
|-|-|
|
||||
|
||||
## Inpaint: Image Restoration
|
||||
|
||||
The Inpaint ControlNet model can repaint specific areas in an image. For example, we can put sunglasses on a cat.
|
||||
|
||||
Model link: https://modelscope.cn/models/alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta
|
||||
|
||||
||||
|
||||
|-|-|-|
|
||||
|
||||
However, we noticed that the pose of the cat's head has changed. If we want to preserve the original structural features, we can use the Canny, Depth, and Normal models. DiffSynth-Studio seamlessly supports ControlNets with different architectures. By using a Normal ControlNet, we can ensure that the structure of the image remains unchanged during local repainting.
|
||||
|
||||
Model link: https://modelscope.cn/models/jasperai/Flux.1-dev-Controlnet-Surface-Normals
|
||||
|
||||
||||
|
||||
|-|-|-|
|
||||
|
||||
## MultiControlNet+MultiDiffusion: Fine-Grained Control
|
||||
|
||||
DiffSynth-Studio not only supports activating multiple ControlNets of different architectures simultaneously, but also allows partitioned control of the image content with different prompts. Additionally, it supports tiled processing of ultra-high-resolution images, enabling extremely fine-grained control. Next, we will showcase the creative process behind a beautiful image.
|
||||
|
||||
First, use the prompt "a beautiful Asian woman and a cat on a bed. The woman wears a dress" to generate a cat and a young girl.
|
||||
|
||||

|
||||
|
||||
Then, enable Inpaint ControlNet and Canny ControlNet.
|
||||
|
||||
Model links: https://modelscope.cn/models/alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta, https://modelscope.cn/models/InstantX/FLUX.1-dev-Controlnet-Union-alpha
|
||||
|
||||
We control the image using two regional prompts.
|
||||
|
||||
|Prompt: an orange cat, highly detailed|Prompt: a girl wearing a red camisole|
|
||||
|-|-|
|
||||
|||
|
||||
|
||||
Generate!
|
||||
|
||||

|
||||
|
||||
The background is a bit blurry, so we use deblurring LoRA for image-to-image generation.
|
||||
|
||||
Model link: https://modelscope.cn/models/LiblibAI/FLUX.1-dev-LoRA-AntiBlur
|
||||
|
||||

|
||||
|
||||
The entire image is much clearer now. Next, let's use the high-definition model to increase the resolution to 4096*4096!
|
||||
|
||||
Model link: https://modelscope.cn/models/jasperai/Flux.1-dev-Controlnet-Upscaler
|
||||
|
||||

|
||||
|
||||
Zoom in to see details.
|
||||
|
||||

|
||||
|
||||
Enjoy!
|
||||
examples/ControlNet/flux_controlnet.py (new file, 299 lines)
@@ -0,0 +1,299 @@
|
||||
from diffsynth import ModelManager, FluxImagePipeline, ControlNetConfigUnit, download_models, download_customized_models
|
||||
import torch
|
||||
from PIL import Image
|
||||
import numpy as np
|
||||
|
||||
|
||||
|
||||
def example_1():
|
||||
model_manager = ModelManager(torch_dtype=torch.bfloat16, model_id_list=["FLUX.1-dev", "jasperai/Flux.1-dev-Controlnet-Upscaler"])
|
||||
pipe = FluxImagePipeline.from_model_manager(model_manager, controlnet_config_units=[
|
||||
ControlNetConfigUnit(
|
||||
processor_id="tile",
|
||||
model_path="models/ControlNet/jasperai/Flux.1-dev-Controlnet-Upscaler/diffusion_pytorch_model.safetensors",
|
||||
scale=0.7
|
||||
),
|
||||
])
|
||||
|
||||
image_1 = pipe(
|
||||
prompt="a photo of a cat, highly detailed",
|
||||
height=768, width=768,
|
||||
seed=0
|
||||
)
|
||||
image_1.save("image_1.jpg")
|
||||
|
||||
image_2 = pipe(
|
||||
prompt="a photo of a cat, highly detailed",
|
||||
controlnet_image=image_1.resize((2048, 2048)),
|
||||
input_image=image_1.resize((2048, 2048)), denoising_strength=0.99,
|
||||
height=2048, width=2048, tiled=True,
|
||||
seed=1
|
||||
)
|
||||
image_2.save("image_2.jpg")
|
||||
|
||||
|
||||
|
||||
def example_2():
|
||||
model_manager = ModelManager(torch_dtype=torch.bfloat16, model_id_list=["FLUX.1-dev", "jasperai/Flux.1-dev-Controlnet-Upscaler"])
|
||||
pipe = FluxImagePipeline.from_model_manager(model_manager, controlnet_config_units=[
|
||||
ControlNetConfigUnit(
|
||||
processor_id="tile",
|
||||
model_path="models/ControlNet/jasperai/Flux.1-dev-Controlnet-Upscaler/diffusion_pytorch_model.safetensors",
|
||||
scale=0.7
|
||||
),
|
||||
])
|
||||
|
||||
image_1 = pipe(
|
||||
prompt="a beautiful Chinese girl, delicate skin texture",
|
||||
height=768, width=768,
|
||||
seed=2
|
||||
)
|
||||
image_1.save("image_3.jpg")
|
||||
|
||||
image_2 = pipe(
|
||||
prompt="a beautiful Chinese girl, delicate skin texture",
|
||||
controlnet_image=image_1.resize((2048, 2048)),
|
||||
input_image=image_1.resize((2048, 2048)), denoising_strength=0.99,
|
||||
height=2048, width=2048, tiled=True,
|
||||
seed=3
|
||||
)
|
||||
image_2.save("image_4.jpg")
|
||||
|
||||
|
||||
def example_3():
|
||||
model_manager = ModelManager(torch_dtype=torch.bfloat16, model_id_list=["FLUX.1-dev", "InstantX/FLUX.1-dev-Controlnet-Union-alpha"])
|
||||
pipe = FluxImagePipeline.from_model_manager(model_manager, controlnet_config_units=[
|
||||
ControlNetConfigUnit(
|
||||
processor_id="canny",
|
||||
model_path="models/ControlNet/InstantX/FLUX.1-dev-Controlnet-Union-alpha/diffusion_pytorch_model.safetensors",
|
||||
scale=0.3
|
||||
),
|
||||
ControlNetConfigUnit(
|
||||
processor_id="depth",
|
||||
model_path="models/ControlNet/InstantX/FLUX.1-dev-Controlnet-Union-alpha/diffusion_pytorch_model.safetensors",
|
||||
scale=0.3
|
||||
),
|
||||
])
|
||||
|
||||
image_1 = pipe(
|
||||
prompt="a cat is running",
|
||||
height=1024, width=1024,
|
||||
seed=4
|
||||
)
|
||||
image_1.save("image_5.jpg")
|
||||
|
||||
image_2 = pipe(
|
||||
prompt="sunshine, a cat is running",
|
||||
controlnet_image=image_1,
|
||||
height=1024, width=1024,
|
||||
seed=5
|
||||
)
|
||||
image_2.save("image_6.jpg")
|
||||
|
||||
|
||||
def example_4():
|
||||
model_manager = ModelManager(torch_dtype=torch.bfloat16, model_id_list=["FLUX.1-dev", "InstantX/FLUX.1-dev-Controlnet-Union-alpha"])
|
||||
pipe = FluxImagePipeline.from_model_manager(model_manager, controlnet_config_units=[
|
||||
ControlNetConfigUnit(
|
||||
processor_id="canny",
|
||||
model_path="models/ControlNet/InstantX/FLUX.1-dev-Controlnet-Union-alpha/diffusion_pytorch_model.safetensors",
|
||||
scale=0.3
|
||||
),
|
||||
ControlNetConfigUnit(
|
||||
processor_id="depth",
|
||||
model_path="models/ControlNet/InstantX/FLUX.1-dev-Controlnet-Union-alpha/diffusion_pytorch_model.safetensors",
|
||||
scale=0.3
|
||||
),
|
||||
])
|
||||
|
||||
image_1 = pipe(
|
||||
prompt="a beautiful Asian girl, full body, red dress, summer",
|
||||
height=1024, width=1024,
|
||||
seed=6
|
||||
)
|
||||
image_1.save("image_7.jpg")
|
||||
|
||||
image_2 = pipe(
|
||||
prompt="a beautiful Asian girl, full body, red dress, winter",
|
||||
controlnet_image=image_1,
|
||||
height=1024, width=1024,
|
||||
seed=7
|
||||
)
|
||||
image_2.save("image_8.jpg")
|
||||
|
||||
|
||||
|
||||
def example_5():
|
||||
model_manager = ModelManager(torch_dtype=torch.bfloat16, model_id_list=["FLUX.1-dev", "alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta"])
|
||||
pipe = FluxImagePipeline.from_model_manager(model_manager, controlnet_config_units=[
|
||||
ControlNetConfigUnit(
|
||||
processor_id="inpaint",
|
||||
model_path="models/ControlNet/alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta/diffusion_pytorch_model.safetensors",
|
||||
scale=0.9
|
||||
),
|
||||
])
|
||||
|
||||
image_1 = pipe(
|
||||
prompt="a cat sitting on a chair",
|
||||
height=1024, width=1024,
|
||||
seed=8
|
||||
)
|
||||
image_1.save("image_9.jpg")
|
||||
|
||||
mask = np.zeros((1024, 1024, 3), dtype=np.uint8)
|
||||
mask[100:350, 350: -300] = 255
|
||||
mask = Image.fromarray(mask)
|
||||
mask.save("mask_9.jpg")
|
||||
|
||||
image_2 = pipe(
|
||||
prompt="a cat sitting on a chair, wearing sunglasses",
|
||||
controlnet_image=image_1, controlnet_inpaint_mask=mask,
|
||||
height=1024, width=1024,
|
||||
seed=9
|
||||
)
|
||||
image_2.save("image_10.jpg")
|
||||
|
||||
|
||||
|
||||
def example_6():
|
||||
model_manager = ModelManager(torch_dtype=torch.bfloat16, model_id_list=[
|
||||
"FLUX.1-dev",
|
||||
"jasperai/Flux.1-dev-Controlnet-Surface-Normals",
|
||||
"alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta"
|
||||
])
|
||||
pipe = FluxImagePipeline.from_model_manager(model_manager, controlnet_config_units=[
|
||||
ControlNetConfigUnit(
|
||||
processor_id="inpaint",
|
||||
model_path="models/ControlNet/alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta/diffusion_pytorch_model.safetensors",
|
||||
scale=0.9
|
||||
),
|
||||
ControlNetConfigUnit(
|
||||
processor_id="normal",
|
||||
model_path="models/ControlNet/jasperai/Flux.1-dev-Controlnet-Surface-Normals/diffusion_pytorch_model.safetensors",
|
||||
scale=0.6
|
||||
),
|
||||
])
|
||||
|
||||
image_1 = pipe(
|
||||
prompt="a beautiful Asian woman looking at the sky, wearing a blue t-shirt.",
|
||||
height=1024, width=1024,
|
||||
seed=10
|
||||
)
|
||||
image_1.save("image_11.jpg")
|
||||
|
||||
mask = np.zeros((1024, 1024, 3), dtype=np.uint8)
|
||||
mask[-400:, 10:-40] = 255
|
||||
mask = Image.fromarray(mask)
|
||||
mask.save("mask_11.jpg")
|
||||
|
||||
image_2 = pipe(
|
||||
prompt="a beautiful Asian woman looking at the sky, wearing a yellow t-shirt.",
|
||||
controlnet_image=image_1, controlnet_inpaint_mask=mask,
|
||||
height=1024, width=1024,
|
||||
seed=11
|
||||
)
|
||||
image_2.save("image_12.jpg")
|
||||
|
||||
|
||||
def example_7():
|
||||
model_manager = ModelManager(torch_dtype=torch.bfloat16, model_id_list=[
|
||||
"FLUX.1-dev",
|
||||
"InstantX/FLUX.1-dev-Controlnet-Union-alpha",
|
||||
"alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta",
|
||||
"jasperai/Flux.1-dev-Controlnet-Upscaler",
|
||||
])
|
||||
pipe = FluxImagePipeline.from_model_manager(model_manager, controlnet_config_units=[
|
||||
ControlNetConfigUnit(
|
||||
processor_id="inpaint",
|
||||
model_path="models/ControlNet/alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta/diffusion_pytorch_model.safetensors",
|
||||
scale=0.9
|
||||
),
|
||||
ControlNetConfigUnit(
|
||||
processor_id="canny",
|
||||
model_path="models/ControlNet/InstantX/FLUX.1-dev-Controlnet-Union-alpha/diffusion_pytorch_model.safetensors",
|
||||
scale=0.5
|
||||
),
|
||||
])
|
||||
|
||||
image_1 = pipe(
|
||||
prompt="a beautiful Asian woman and a cat on a bed. The woman wears a dress.",
|
||||
height=1024, width=1024,
|
||||
seed=100
|
||||
)
|
||||
image_1.save("image_13.jpg")
|
||||
|
||||
mask_global = np.zeros((1024, 1024, 3), dtype=np.uint8)
|
||||
mask_global = Image.fromarray(mask_global)
|
||||
mask_global.save("mask_13_global.jpg")
|
||||
|
||||
mask_1 = np.zeros((1024, 1024, 3), dtype=np.uint8)
|
||||
mask_1[300:-100, 30: 450] = 255
|
||||
mask_1 = Image.fromarray(mask_1)
|
||||
mask_1.save("mask_13_1.jpg")
|
||||
|
||||
mask_2 = np.zeros((1024, 1024, 3), dtype=np.uint8)
|
||||
mask_2[500:-100, -400:] = 255
|
||||
mask_2[-200:-100, -500:-400] = 255
|
||||
mask_2 = Image.fromarray(mask_2)
|
||||
mask_2.save("mask_13_2.jpg")
|
||||
|
||||
image_2 = pipe(
|
||||
prompt="a beautiful Asian woman and a cat on a bed. The woman wears a dress.",
|
||||
controlnet_image=image_1, controlnet_inpaint_mask=mask_global,
|
||||
local_prompts=["an orange cat, highly detailed", "a girl wearing a red camisole"], masks=[mask_1, mask_2], mask_scales=[10.0, 10.0],
|
||||
height=1024, width=1024,
|
||||
seed=101
|
||||
)
|
||||
image_2.save("image_14.jpg")
|
||||
|
||||
model_manager.load_lora("models/lora/FLUX-dev-lora-AntiBlur.safetensors", lora_alpha=2)
|
||||
image_3 = pipe(
|
||||
prompt="a beautiful Asian woman wearing a red camisole and an orange cat on a bed. clear background.",
|
||||
negative_prompt="blur, blurry",
|
||||
input_image=image_2, denoising_strength=0.7,
|
||||
height=1024, width=1024,
|
||||
cfg_scale=2.0, num_inference_steps=50,
|
||||
seed=102
|
||||
)
|
||||
image_3.save("image_15.jpg")
|
||||
|
||||
pipe = FluxImagePipeline.from_model_manager(model_manager, controlnet_config_units=[
|
||||
ControlNetConfigUnit(
|
||||
processor_id="tile",
|
||||
model_path="models/ControlNet/jasperai/Flux.1-dev-Controlnet-Upscaler/diffusion_pytorch_model.safetensors",
|
||||
scale=0.7
|
||||
),
|
||||
])
|
||||
image_4 = pipe(
|
||||
prompt="a beautiful Asian woman wearing a red camisole and an orange cat on a bed. highly detailed, delicate skin texture, clear background.",
|
||||
controlnet_image=image_3.resize((2048, 2048)),
|
||||
input_image=image_3.resize((2048, 2048)), denoising_strength=0.99,
|
||||
height=2048, width=2048, tiled=True,
|
||||
seed=103
|
||||
)
|
||||
image_4.save("image_16.jpg")
|
||||
|
||||
image_5 = pipe(
|
||||
prompt="a beautiful Asian woman wearing a red camisole and an orange cat on a bed. highly detailed, delicate skin texture, clear background.",
|
||||
controlnet_image=image_4.resize((4096, 4096)),
|
||||
input_image=image_4.resize((4096, 4096)), denoising_strength=0.99,
|
||||
height=4096, width=4096, tiled=True,
|
||||
seed=104
|
||||
)
|
||||
image_5.save("image_17.jpg")
|
||||
|
||||
|
||||
|
||||
download_models(["Annotators:Depth", "Annotators:Normal"])
|
||||
download_customized_models(
|
||||
model_id="LiblibAI/FLUX.1-dev-LoRA-AntiBlur",
|
||||
origin_file_path="FLUX-dev-lora-AntiBlur.safetensors",
|
||||
local_dir="models/lora"
|
||||
)
|
||||
example_1()
|
||||
example_2()
|
||||
example_3()
|
||||
example_4()
|
||||
example_5()
|
||||
example_6()
|
||||
example_7()
|
||||
447  examples/ControlNet/flux_controlnet_quantization.py  Normal file
@@ -0,0 +1,447 @@
|
||||
from diffsynth import ModelManager, FluxImagePipeline, ControlNetConfigUnit, download_models, download_customized_models
|
||||
import torch
|
||||
from PIL import Image
|
||||
import numpy as np
|
||||
|
||||
|
||||
|
||||
def example_1():
|
||||
download_models(["FLUX.1-dev", "jasperai/Flux.1-dev-Controlnet-Upscaler"])
|
||||
model_manager = ModelManager(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cpu"
|
||||
)
|
||||
model_manager.load_models([
|
||||
"models/FLUX/FLUX.1-dev/text_encoder/model.safetensors",
|
||||
"models/FLUX/FLUX.1-dev/text_encoder_2",
|
||||
"models/FLUX/FLUX.1-dev/ae.safetensors",
|
||||
])
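# The text encoders and VAE stay in bfloat16; the DiT and ControlNet below are loaded in float8_e4m3fn for quantized inference.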
|
||||
model_manager.load_models(
|
||||
["models/FLUX/FLUX.1-dev/flux1-dev.safetensors"],
|
||||
torch_dtype=torch.float8_e4m3fn
|
||||
)
|
||||
model_manager.load_models(
|
||||
["models/ControlNet/jasperai/Flux.1-dev-Controlnet-Upscaler/diffusion_pytorch_model.safetensors"],
|
||||
torch_dtype=torch.float8_e4m3fn
|
||||
)
|
||||
pipe = FluxImagePipeline.from_model_manager(model_manager, controlnet_config_units=[
|
||||
ControlNetConfigUnit(
|
||||
processor_id="tile",
|
||||
model_path="models/ControlNet/jasperai/Flux.1-dev-Controlnet-Upscaler/diffusion_pytorch_model.safetensors",
|
||||
scale=0.7
|
||||
),
|
||||
],device="cuda")
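# Computation runs on CUDA; enable_cpu_offload() keeps idle modules on the CPU, and the quantize() calls below enable quantized inference on the float8 weights loaded above.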
|
||||
pipe.enable_cpu_offload()
|
||||
pipe.dit.quantize()
|
||||
for model in pipe.controlnet.models:
|
||||
model.quantize()
|
||||
|
||||
image_1 = pipe(
|
||||
prompt="a photo of a cat, highly detailed",
|
||||
height=768, width=768,
|
||||
seed=0
|
||||
)
|
||||
image_1.save("image_1.jpg")
|
||||
|
||||
image_2 = pipe(
|
||||
prompt="a photo of a cat, highly detailed",
|
||||
controlnet_image=image_1.resize((2048, 2048)),
|
||||
input_image=image_1.resize((2048, 2048)), denoising_strength=0.99,
|
||||
height=2048, width=2048, tiled=True,
|
||||
seed=1
|
||||
)
|
||||
image_2.save("image_2.jpg")
|
||||
|
||||
|
||||
|
||||
def example_2():
|
||||
download_models(["FLUX.1-dev", "jasperai/Flux.1-dev-Controlnet-Upscaler"])
|
||||
model_manager = ModelManager(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cpu"
|
||||
)
|
||||
model_manager.load_models([
|
||||
"models/FLUX/FLUX.1-dev/text_encoder/model.safetensors",
|
||||
"models/FLUX/FLUX.1-dev/text_encoder_2",
|
||||
"models/FLUX/FLUX.1-dev/ae.safetensors",
|
||||
])
|
||||
model_manager.load_models(
|
||||
["models/FLUX/FLUX.1-dev/flux1-dev.safetensors"],
|
||||
torch_dtype=torch.float8_e4m3fn
|
||||
)
|
||||
model_manager.load_models(
|
||||
["models/ControlNet/jasperai/Flux.1-dev-Controlnet-Upscaler/diffusion_pytorch_model.safetensors"],
|
||||
torch_dtype=torch.float8_e4m3fn
|
||||
)
|
||||
pipe = FluxImagePipeline.from_model_manager(model_manager, controlnet_config_units=[
|
||||
ControlNetConfigUnit(
|
||||
processor_id="tile",
|
||||
model_path="models/ControlNet/jasperai/Flux.1-dev-Controlnet-Upscaler/diffusion_pytorch_model.safetensors",
|
||||
scale=0.7
|
||||
),
|
||||
],device="cuda")
|
||||
pipe.enable_cpu_offload()
|
||||
pipe.dit.quantize()
|
||||
for model in pipe.controlnet.models:
|
||||
model.quantize()
|
||||
image_1 = pipe(
|
||||
prompt="a beautiful Chinese girl, delicate skin texture",
|
||||
height=768, width=768,
|
||||
seed=2
|
||||
)
|
||||
image_1.save("image_3.jpg")
|
||||
|
||||
image_2 = pipe(
|
||||
prompt="a beautiful Chinese girl, delicate skin texture",
|
||||
controlnet_image=image_1.resize((2048, 2048)),
|
||||
input_image=image_1.resize((2048, 2048)), denoising_strength=0.99,
|
||||
height=2048, width=2048, tiled=True,
|
||||
seed=3
|
||||
)
|
||||
image_2.save("image_4.jpg")
|
||||
|
||||
|
||||
def example_3():
|
||||
download_models(["FLUX.1-dev", "InstantX/FLUX.1-dev-Controlnet-Union-alpha"])
|
||||
model_manager = ModelManager(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cpu"
|
||||
)
|
||||
model_manager.load_models([
|
||||
"models/FLUX/FLUX.1-dev/text_encoder/model.safetensors",
|
||||
"models/FLUX/FLUX.1-dev/text_encoder_2",
|
||||
"models/FLUX/FLUX.1-dev/ae.safetensors",
|
||||
])
|
||||
model_manager.load_models(
|
||||
["models/FLUX/FLUX.1-dev/flux1-dev.safetensors"],
|
||||
torch_dtype=torch.float8_e4m3fn
|
||||
)
|
||||
model_manager.load_models(
|
||||
["models/ControlNet/InstantX/FLUX.1-dev-Controlnet-Union-alpha/diffusion_pytorch_model.safetensors"],
|
||||
torch_dtype=torch.float8_e4m3fn
|
||||
)
|
||||
pipe = FluxImagePipeline.from_model_manager(model_manager, controlnet_config_units=[
|
||||
ControlNetConfigUnit(
|
||||
processor_id="canny",
|
||||
model_path="models/ControlNet/InstantX/FLUX.1-dev-Controlnet-Union-alpha/diffusion_pytorch_model.safetensors",
|
||||
scale=0.3
|
||||
),
|
||||
ControlNetConfigUnit(
|
||||
processor_id="depth",
|
||||
model_path="models/ControlNet/InstantX/FLUX.1-dev-Controlnet-Union-alpha/diffusion_pytorch_model.safetensors",
|
||||
scale=0.3
|
||||
),
|
||||
],device="cuda")
|
||||
pipe.enable_cpu_offload()
|
||||
pipe.dit.quantize()
|
||||
for model in pipe.controlnet.models:
|
||||
model.quantize()
|
||||
image_1 = pipe(
|
||||
prompt="a cat is running",
|
||||
height=1024, width=1024,
|
||||
seed=4
|
||||
)
|
||||
image_1.save("image_5.jpg")
|
||||
|
||||
image_2 = pipe(
|
||||
prompt="sunshine, a cat is running",
|
||||
controlnet_image=image_1,
|
||||
height=1024, width=1024,
|
||||
seed=5
|
||||
)
|
||||
image_2.save("image_6.jpg")
|
||||
|
||||
|
||||
def example_4():
|
||||
download_models(["FLUX.1-dev", "InstantX/FLUX.1-dev-Controlnet-Union-alpha"])
|
||||
model_manager = ModelManager(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cpu"
|
||||
)
|
||||
model_manager.load_models([
|
||||
"models/FLUX/FLUX.1-dev/text_encoder/model.safetensors",
|
||||
"models/FLUX/FLUX.1-dev/text_encoder_2",
|
||||
"models/FLUX/FLUX.1-dev/ae.safetensors",
|
||||
])
|
||||
model_manager.load_models(
|
||||
["models/FLUX/FLUX.1-dev/flux1-dev.safetensors"],
|
||||
torch_dtype=torch.float8_e4m3fn
|
||||
)
|
||||
model_manager.load_models(
|
||||
["models/ControlNet/InstantX/FLUX.1-dev-Controlnet-Union-alpha/diffusion_pytorch_model.safetensors"],
|
||||
torch_dtype=torch.float8_e4m3fn
|
||||
)
|
||||
pipe = FluxImagePipeline.from_model_manager(model_manager, controlnet_config_units=[
|
||||
ControlNetConfigUnit(
|
||||
processor_id="canny",
|
||||
model_path="models/ControlNet/InstantX/FLUX.1-dev-Controlnet-Union-alpha/diffusion_pytorch_model.safetensors",
|
||||
scale=0.3
|
||||
),
|
||||
ControlNetConfigUnit(
|
||||
processor_id="depth",
|
||||
model_path="models/ControlNet/InstantX/FLUX.1-dev-Controlnet-Union-alpha/diffusion_pytorch_model.safetensors",
|
||||
scale=0.3
|
||||
),
|
||||
],device="cuda")
|
||||
pipe.enable_cpu_offload()
|
||||
pipe.dit.quantize()
|
||||
for model in pipe.controlnet.models:
|
||||
model.quantize()
|
||||
image_1 = pipe(
|
||||
prompt="a beautiful Asian girl, full body, red dress, summer",
|
||||
height=1024, width=1024,
|
||||
seed=6
|
||||
)
|
||||
image_1.save("image_7.jpg")
|
||||
|
||||
image_2 = pipe(
|
||||
prompt="a beautiful Asian girl, full body, red dress, winter",
|
||||
controlnet_image=image_1,
|
||||
height=1024, width=1024,
|
||||
seed=7
|
||||
)
|
||||
image_2.save("image_8.jpg")
|
||||
|
||||
|
||||
|
||||
def example_5():
|
||||
download_models(["FLUX.1-dev", "alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta"])
|
||||
model_manager = ModelManager(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cpu"
|
||||
)
|
||||
model_manager.load_models([
|
||||
"models/FLUX/FLUX.1-dev/text_encoder/model.safetensors",
|
||||
"models/FLUX/FLUX.1-dev/text_encoder_2",
|
||||
"models/FLUX/FLUX.1-dev/ae.safetensors",
|
||||
])
|
||||
model_manager.load_models(
|
||||
["models/FLUX/FLUX.1-dev/flux1-dev.safetensors"],
|
||||
torch_dtype=torch.float8_e4m3fn
|
||||
)
|
||||
model_manager.load_models(
|
||||
["models/ControlNet/alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta/diffusion_pytorch_model.safetensors"],
|
||||
torch_dtype=torch.float8_e4m3fn
|
||||
)
|
||||
pipe = FluxImagePipeline.from_model_manager(model_manager, controlnet_config_units=[
|
||||
ControlNetConfigUnit(
|
||||
processor_id="inpaint",
|
||||
model_path="models/ControlNet/alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta/diffusion_pytorch_model.safetensors",
|
||||
scale=0.9
|
||||
),
|
||||
],device="cuda")
|
||||
pipe.enable_cpu_offload()
|
||||
pipe.dit.quantize()
|
||||
for model in pipe.controlnet.models:
|
||||
model.quantize()
|
||||
image_1 = pipe(
|
||||
prompt="a cat sitting on a chair",
|
||||
height=1024, width=1024,
|
||||
seed=8
|
||||
)
|
||||
image_1.save("image_9.jpg")
|
||||
|
||||
mask = np.zeros((1024, 1024, 3), dtype=np.uint8)
|
||||
mask[100:350, 350: -300] = 255
|
||||
mask = Image.fromarray(mask)
|
||||
mask.save("mask_9.jpg")
|
||||
|
||||
image_2 = pipe(
|
||||
prompt="a cat sitting on a chair, wearing sunglasses",
|
||||
controlnet_image=image_1, controlnet_inpaint_mask=mask,
|
||||
height=1024, width=1024,
|
||||
seed=9
|
||||
)
|
||||
image_2.save("image_10.jpg")
|
||||
|
||||
|
||||
|
||||
def example_6():
|
||||
download_models([
|
||||
"FLUX.1-dev",
|
||||
"jasperai/Flux.1-dev-Controlnet-Surface-Normals",
|
||||
"alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta"
|
||||
])
|
||||
model_manager = ModelManager(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cpu"
|
||||
)
|
||||
model_manager.load_models([
|
||||
"models/FLUX/FLUX.1-dev/text_encoder/model.safetensors",
|
||||
"models/FLUX/FLUX.1-dev/text_encoder_2",
|
||||
"models/FLUX/FLUX.1-dev/ae.safetensors",
|
||||
])
|
||||
model_manager.load_models(
|
||||
["models/FLUX/FLUX.1-dev/flux1-dev.safetensors"],
|
||||
torch_dtype=torch.float8_e4m3fn
|
||||
)
|
||||
model_manager.load_models(
|
||||
["models/ControlNet/alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta/diffusion_pytorch_model.safetensors",
|
||||
"models/ControlNet/jasperai/Flux.1-dev-Controlnet-Surface-Normals/diffusion_pytorch_model.safetensors"],
|
||||
torch_dtype=torch.float8_e4m3fn
|
||||
)
|
||||
pipe = FluxImagePipeline.from_model_manager(model_manager, controlnet_config_units=[
|
||||
ControlNetConfigUnit(
|
||||
processor_id="inpaint",
|
||||
model_path="models/ControlNet/alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta/diffusion_pytorch_model.safetensors",
|
||||
scale=0.9
|
||||
),
|
||||
ControlNetConfigUnit(
|
||||
processor_id="normal",
|
||||
model_path="models/ControlNet/jasperai/Flux.1-dev-Controlnet-Surface-Normals/diffusion_pytorch_model.safetensors",
|
||||
scale=0.6
|
||||
),
|
||||
],device="cuda")
|
||||
pipe.enable_cpu_offload()
|
||||
pipe.dit.quantize()
|
||||
for model in pipe.controlnet.models:
|
||||
model.quantize()
|
||||
image_1 = pipe(
|
||||
prompt="a beautiful Asian woman looking at the sky, wearing a blue t-shirt.",
|
||||
height=1024, width=1024,
|
||||
seed=10
|
||||
)
|
||||
image_1.save("image_11.jpg")
|
||||
|
||||
mask = np.zeros((1024, 1024, 3), dtype=np.uint8)
|
||||
mask[-400:, 10:-40] = 255
|
||||
mask = Image.fromarray(mask)
|
||||
mask.save("mask_11.jpg")
|
||||
|
||||
image_2 = pipe(
|
||||
prompt="a beautiful Asian woman looking at the sky, wearing a yellow t-shirt.",
|
||||
controlnet_image=image_1, controlnet_inpaint_mask=mask,
|
||||
height=1024, width=1024,
|
||||
seed=11
|
||||
)
|
||||
image_2.save("image_12.jpg")
|
||||
|
||||
|
||||
def example_7():
|
||||
download_models([
|
||||
"FLUX.1-dev",
|
||||
"InstantX/FLUX.1-dev-Controlnet-Union-alpha",
|
||||
"alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta",
|
||||
"jasperai/Flux.1-dev-Controlnet-Upscaler",
|
||||
])
|
||||
model_manager = ModelManager(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cpu"
|
||||
)
|
||||
model_manager.load_models([
|
||||
"models/FLUX/FLUX.1-dev/text_encoder/model.safetensors",
|
||||
"models/FLUX/FLUX.1-dev/text_encoder_2",
|
||||
"models/FLUX/FLUX.1-dev/ae.safetensors",
|
||||
])
|
||||
model_manager.load_models(
|
||||
["models/FLUX/FLUX.1-dev/flux1-dev.safetensors"],
|
||||
torch_dtype=torch.float8_e4m3fn
|
||||
)
|
||||
model_manager.load_models(
|
||||
["models/ControlNet/alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta/diffusion_pytorch_model.safetensors",
|
||||
"models/ControlNet/InstantX/FLUX.1-dev-Controlnet-Union-alpha/diffusion_pytorch_model.safetensors",
|
||||
"models/ControlNet/jasperai/Flux.1-dev-Controlnet-Upscaler/diffusion_pytorch_model.safetensors"],
|
||||
torch_dtype=torch.float8_e4m3fn
|
||||
)
|
||||
pipe = FluxImagePipeline.from_model_manager(model_manager, controlnet_config_units=[
|
||||
ControlNetConfigUnit(
|
||||
processor_id="inpaint",
|
||||
model_path="models/ControlNet/alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta/diffusion_pytorch_model.safetensors",
|
||||
scale=0.9
|
||||
),
|
||||
ControlNetConfigUnit(
|
||||
processor_id="canny",
|
||||
model_path="models/ControlNet/InstantX/FLUX.1-dev-Controlnet-Union-alpha/diffusion_pytorch_model.safetensors",
|
||||
scale=0.5
|
||||
),
|
||||
],device="cuda")
|
||||
pipe.enable_cpu_offload()
|
||||
pipe.dit.quantize()
|
||||
for model in pipe.controlnet.models:
|
||||
model.quantize()
|
||||
image_1 = pipe(
|
||||
prompt="a beautiful Asian woman and a cat on a bed. The woman wears a dress.",
|
||||
height=1024, width=1024,
|
||||
seed=100
|
||||
)
|
||||
image_1.save("image_13.jpg")
|
||||
|
||||
mask_global = np.zeros((1024, 1024, 3), dtype=np.uint8)
|
||||
mask_global = Image.fromarray(mask_global)
|
||||
mask_global.save("mask_13_global.jpg")
|
||||
|
||||
mask_1 = np.zeros((1024, 1024, 3), dtype=np.uint8)
|
||||
mask_1[300:-100, 30: 450] = 255
|
||||
mask_1 = Image.fromarray(mask_1)
|
||||
mask_1.save("mask_13_1.jpg")
|
||||
|
||||
mask_2 = np.zeros((1024, 1024, 3), dtype=np.uint8)
|
||||
mask_2[500:-100, -400:] = 255
|
||||
mask_2[-200:-100, -500:-400] = 255
|
||||
mask_2 = Image.fromarray(mask_2)
|
||||
mask_2.save("mask_13_2.jpg")
|
||||
|
||||
image_2 = pipe(
|
||||
prompt="a beautiful Asian woman and a cat on a bed. The woman wears a dress.",
|
||||
controlnet_image=image_1, controlnet_inpaint_mask=mask_global,
|
||||
local_prompts=["an orange cat, highly detailed", "a girl wearing a red camisole"], masks=[mask_1, mask_2], mask_scales=[10.0, 10.0],
|
||||
height=1024, width=1024,
|
||||
seed=101
|
||||
)
|
||||
image_2.save("image_14.jpg")
|
||||
|
||||
model_manager.load_lora("models/lora/FLUX-dev-lora-AntiBlur.safetensors", lora_alpha=2)
|
||||
image_3 = pipe(
|
||||
prompt="a beautiful Asian woman wearing a red camisole and an orange cat on a bed. clear background.",
|
||||
negative_prompt="blur, blurry",
|
||||
input_image=image_2, denoising_strength=0.7,
|
||||
height=1024, width=1024,
|
||||
cfg_scale=2.0, num_inference_steps=50,
|
||||
seed=102
|
||||
)
|
||||
image_3.save("image_15.jpg")
|
||||
|
||||
pipe = FluxImagePipeline.from_model_manager(model_manager, controlnet_config_units=[
|
||||
ControlNetConfigUnit(
|
||||
processor_id="tile",
|
||||
model_path="models/ControlNet/jasperai/Flux.1-dev-Controlnet-Upscaler/diffusion_pytorch_model.safetensors",
|
||||
scale=0.7
|
||||
),
|
||||
],device="cuda")
|
||||
pipe.enable_cpu_offload()
|
||||
pipe.dit.quantize()
|
||||
for model in pipe.controlnet.models:
|
||||
model.quantize()
|
||||
image_4 = pipe(
|
||||
prompt="a beautiful Asian woman wearing a red camisole and an orange cat on a bed. highly detailed, delicate skin texture, clear background.",
|
||||
controlnet_image=image_3.resize((2048, 2048)),
|
||||
input_image=image_3.resize((2048, 2048)), denoising_strength=0.99,
|
||||
height=2048, width=2048, tiled=True,
|
||||
seed=103
|
||||
)
|
||||
image_4.save("image_16.jpg")
|
||||
|
||||
image_5 = pipe(
|
||||
prompt="a beautiful Asian woman wearing a red camisole and an orange cat on a bed. highly detailed, delicate skin texture, clear background.",
|
||||
controlnet_image=image_4.resize((4096, 4096)),
|
||||
input_image=image_4.resize((4096, 4096)), denoising_strength=0.99,
|
||||
height=4096, width=4096, tiled=True,
|
||||
seed=104
|
||||
)
|
||||
image_5.save("image_17.jpg")
|
||||
|
||||
|
||||
|
||||
download_models(["Annotators:Depth", "Annotators:Normal"])
|
||||
download_customized_models(
|
||||
model_id="LiblibAI/FLUX.1-dev-LoRA-AntiBlur",
|
||||
origin_file_path="FLUX-dev-lora-AntiBlur.safetensors",
|
||||
local_dir="models/lora"
|
||||
)
|
||||
example_1()
|
||||
example_2()
|
||||
example_3()
|
||||
example_4()
|
||||
example_5()
|
||||
example_6()
|
||||
example_7()
|
||||
85  examples/EntityControl/README.md  Normal file
@@ -0,0 +1,85 @@
|
||||
# EliGen: Entity-Level Controlled Image Generation
|
||||
|
||||
## Introduction
|
||||
|
||||
We propose EliGen, a novel approach that leverages fine-grained entity-level information to enable precise and controllable text-to-image generation. EliGen excels at tasks such as entity-level controlled image generation and image inpainting, though its applicability is not limited to these areas. Additionally, it can be seamlessly integrated with existing community models, such as the IP-Adapter and In-Context LoRA.
|
||||
|
||||
* Paper: [EliGen: Entity-Level Controlled Image Generation with Regional Attention](https://arxiv.org/abs/2501.01097)
|
||||
* Github: [DiffSynth-Studio](https://github.com/modelscope/DiffSynth-Studio)
|
||||
* Model: [ModelScope](https://www.modelscope.cn/models/DiffSynth-Studio/Eligen)
|
||||
* Online Demo: [ModelScope EliGen Studio](https://www.modelscope.cn/studios/DiffSynth-Studio/EliGen)
|
||||
* Training Dataset: [EliGen Train Set](https://www.modelscope.cn/datasets/DiffSynth-Studio/EliGenTrainSet)
|
||||
|
||||
|
||||
## Methodology
|
||||
|
||||

|
||||
|
||||
We introduce a regional attention mechanism within the DiT framework to effectively process the conditions of each entity. This mechanism enables the local prompt associated with each entity to semantically influence specific regions through regional attention. To further enhance the layout control capabilities of EliGen, we contribute a meticulously annotated entity dataset and fine-tune the model using the LoRA framework.
|
||||
|
||||
1. **Regional Attention**: Regional attention is illustrated in the figure above and can be easily applied to other text-to-image models. Its core principle is to transform the positional information of each entity into an attention mask, ensuring that the mechanism only affects the designated regions (a minimal code sketch follows this list).
|
||||
|
||||
2. **Dataset with Entity Annotation**: To construct a dedicated entity control dataset, we start by randomly selecting captions from DiffusionDB and generating the corresponding source images using FLUX. Next, we employ Qwen2-VL 72B, recognized for its advanced grounding capabilities among MLLMs, to randomly identify entities within each image. These entities are annotated with local prompts and bounding boxes for precise localization, forming the foundation of our dataset for further training.
|
||||
|
||||
3. **Training**: We utilize LoRA (Low-Rank Adaptation) and DeepSpeed to fine-tune the regional attention mechanism on the curated dataset, enabling our EliGen model to achieve effective entity-level control.
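Below is a minimal, hypothetical sketch of the regional attention masking described in point 1. The function name, tensor layout, and token-span convention are illustrative assumptions rather than the repository's actual implementation: image tokens that fall inside an entity's mask may attend to that entity's local-prompt tokens, and all other image-to-local-prompt attention is blocked.

```python
import torch

def build_regional_attention_mask(entity_masks, prompt_token_spans, num_image_tokens, num_text_tokens):
    # entity_masks: list of boolean tensors of shape [num_image_tokens] (flattened latent grid)
    # prompt_token_spans: list of (start, end) index pairs into the concatenated local-prompt tokens
    attn_mask = torch.zeros(num_image_tokens, num_text_tokens, dtype=torch.bool)
    for region, (start, end) in zip(entity_masks, prompt_token_spans):
        attn_mask[region, start:end] = True  # allow attention only inside the entity's region
    return attn_mask  # combined later with the ordinary global-prompt attention in the DiT blocks

# Toy example: two 3-token local prompts on a flattened 4x4 latent grid.
region_a = torch.zeros(16, dtype=torch.bool); region_a[:8] = True   # top half
region_b = torch.zeros(16, dtype=torch.bool); region_b[8:] = True   # bottom half
mask = build_regional_attention_mask([region_a, region_b], [(0, 3), (3, 6)], 16, 6)
print(mask.shape)  # torch.Size([16, 6])
```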
|
||||
|
||||
## Usage
|
||||
1. **Entity-Level Controlled Image Generation**
|
||||
EliGen achieves effective entity-level control results. See [./entity_control.py](./entity_control.py) for usage.
|
||||
2. **Image Inpainting**
|
||||
To apply EliGen to the image inpainting task, we propose an inpainting fusion pipeline that preserves the non-inpainting areas while enabling precise, entity-level modifications over the inpainting regions.
|
||||
See [./entity_inpaint.py](./entity_inpaint.py) for usage.
|
||||
3. **Styled Entity Control**
|
||||
EliGen can be seamlessly integrated with existing community models. We have provided an example of how to integrate it with the IP-Adapter. See [./entity_control_ipadapter.py](./entity_control_ipadapter.py) for usage.
|
||||
4. **Entity Transfer**
|
||||
We have provided an example of how to integrate EliGen with In-Context LoRA, which achieves interesting entity transfer results. See [./entity_transfer.py](./entity_transfer.py) for usage.
|
||||
5. **Play with EliGen using UI**
|
||||
Run the following command to try the interactive UI:
|
||||
```bash
|
||||
python apps/gradio/entity_level_control.py
|
||||
```
|
||||
## Examples
|
||||
### Entity-Level Controlled Image Generation
|
||||
|
||||
1. The effect of generating images with continuously changing entity positions.
|
||||
|
||||
https://github.com/user-attachments/assets/54a048c8-b663-4262-8c40-43c87c266d4b
|
||||
|
||||
2. Image generation results for complex entity combinations, demonstrating the strong generalization of EliGen. See [./entity_control.py](./entity_control.py) `example_1-6` for generation prompts.
|
||||
|
||||
|Entity Conditions|Generated Image|
|
||||
|-|-|
|
||||
|||
|
||||
|||
|
||||
|||
|
||||
|||
|
||||
|||
|
||||
|||
|
||||
|
||||
3. Demonstration of the robustness of EliGen. The following examples are generated using the same prompt but different seeds. Refer to [./entity_control.py](./entity_control.py) `example_7` for the prompts.
|
||||
|
||||
|Entity Conditions|Generated Image|
|
||||
|-|-|
|
||||
|||
|
||||
|||
|
||||
|||
|
||||
|||
|
||||
|
||||
### Image Inpainting
|
||||
Demonstration of the inpainting mode of EliGen; see [./entity_inpaint.py](./entity_inpaint.py) for generation prompts.
|
||||
|Inpainting Input|Inpainting Output|
|
||||
|-|-|
|
||||
|||
|
||||
|||
|
||||
### Styled Entity Control
|
||||
Demonstration of the styled entity control results with EliGen and IP-Adapter; see [./entity_control_ipadapter.py](./entity_control_ipadapter.py) for generation prompts.
|
||||
|Style Reference|Entity Control Variance 1|Entity Control Variance 2|Entity Control Variance 3|
|
||||
|-|-|-|-|
|
||||
|||||
|
||||
|
||||
### Entity Transfer
|
||||
Demonstration of the entity transfer results with EliGen and In-Context LoRA; see [./entity_transfer.py](./entity_transfer.py) for generation prompts.
|
||||
|
||||
|Entity to Transfer|Transfer Target Image|Transfer Example 1|Transfer Example 2|
|
||||
|-|-|-|-|
|
||||
|||||
|
||||
74  examples/EntityControl/entity_control.py  Normal file
@@ -0,0 +1,74 @@
|
||||
from diffsynth import ModelManager, FluxImagePipeline, download_customized_models
|
||||
from modelscope import dataset_snapshot_download
|
||||
from examples.EntityControl.utils import visualize_masks
|
||||
from PIL import Image
|
||||
import torch
|
||||
|
||||
def example(pipe, seeds, example_id, global_prompt, entity_prompts):
|
||||
dataset_snapshot_download(dataset_id="DiffSynth-Studio/examples_in_diffsynth", local_dir="./", allow_file_pattern=f"data/examples/eligen/entity_control/example_{example_id}/*.png")
|
||||
masks = [Image.open(f"./data/examples/eligen/entity_control/example_{example_id}/{i}.png").convert('RGB') for i in range(len(entity_prompts))]
|
||||
negative_prompt = "worst quality, low quality, monochrome, zombie, interlocked fingers, Aissist, cleavage, nsfw,"
|
||||
for seed in seeds:
|
||||
# generate image
|
||||
image = pipe(
|
||||
prompt=global_prompt,
|
||||
cfg_scale=3.0,
|
||||
negative_prompt=negative_prompt,
|
||||
num_inference_steps=50,
|
||||
embedded_guidance=3.5,
|
||||
seed=seed,
|
||||
height=1024,
|
||||
width=1024,
|
||||
eligen_entity_prompts=entity_prompts,
|
||||
eligen_entity_masks=masks,
|
||||
)
|
||||
image.save(f"eligen_example_{example_id}_{seed}.png")
|
||||
visualize_masks(image, masks, entity_prompts, f"eligen_example_{example_id}_mask_{seed}.png")
|
||||
|
||||
# download and load model
|
||||
model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cuda", model_id_list=["FLUX.1-dev"])
|
||||
model_manager.load_lora(
|
||||
download_customized_models(
|
||||
model_id="DiffSynth-Studio/Eligen",
|
||||
origin_file_path="model_bf16.safetensors",
|
||||
local_dir="models/lora/entity_control"
|
||||
),
|
||||
lora_alpha=1
|
||||
)
|
||||
pipe = FluxImagePipeline.from_model_manager(model_manager)
|
||||
|
||||
# example 1
|
||||
global_prompt = "A breathtaking beauty of Raja Ampat by the late-night moonlight , one beautiful woman from behind wearing a pale blue long dress with soft glow, sitting at the top of a cliff looking towards the beach,pastell light colors, a group of small distant birds flying in far sky, a boat sailing on the sea, best quality, realistic, whimsical, fantastic, splash art, intricate detailed, hyperdetailed, maximalist style, photorealistic, concept art, sharp focus, harmony, serenity, tranquility, soft pastell colors,ambient occlusion, cozy ambient lighting, masterpiece, liiv1, linquivera, metix, mentixis, masterpiece, award winning, view from above\n"
|
||||
entity_prompts = ["cliff", "sea", "moon", "sailing boat", "a seated beautiful woman", "pale blue long dress with soft glow"]
|
||||
example(pipe, [0], 1, global_prompt, entity_prompts)
|
||||
|
||||
# example 2
|
||||
global_prompt = "samurai girl wearing a kimono, she's holding a sword glowing with red flame, her long hair is flowing in the wind, she is looking at a small bird perched on the back of her hand. ultra realist style. maximum image detail. maximum realistic render."
|
||||
entity_prompts = ["flowing hair", "sword glowing with red flame", "A cute bird", "blue belt"]
|
||||
example(pipe, [0], 2, global_prompt, entity_prompts)
|
||||
|
||||
# example 3
|
||||
global_prompt = "Image of a neverending staircase up to a mysterious palace in the sky, The ancient palace stood majestically atop a mist-shrouded mountain, sunrise, two traditional monk walk in the stair looking at the sunrise, fog,see-through, best quality, whimsical, fantastic, splash art, intricate detailed, hyperdetailed, photorealistic, concept art, harmony, serenity, tranquility, ambient occlusion, halation, cozy ambient lighting, dynamic lighting,masterpiece, liiv1, linquivera, metix, mentixis, masterpiece, award winning,"
|
||||
entity_prompts = ["ancient palace", "stone staircase with railings", "a traditional monk", "a traditional monk"]
|
||||
example(pipe, [27], 3, global_prompt, entity_prompts)
|
||||
|
||||
# example 4
|
||||
global_prompt = "A beautiful girl wearing shirt and shorts in the street, holding a sign 'Entity Control'"
|
||||
entity_prompts = ["A beautiful girl", "sign 'Entity Control'", "shorts", "shirt"]
|
||||
example(pipe, [21], 4, global_prompt, entity_prompts)
|
||||
|
||||
# example 5
|
||||
global_prompt = "A captivating, dramatic scene in a painting that exudes mystery and foreboding. A white sky, swirling blue clouds, and a crescent yellow moon illuminate a solitary woman standing near the water's edge. Her long dress flows in the wind, silhouetted against the eerie glow. The water mirrors the fiery sky and moonlight, amplifying the uneasy atmosphere."
|
||||
entity_prompts = ["crescent yellow moon", "a solitary woman", "water", "swirling blue clouds"]
|
||||
example(pipe, [0], 5, global_prompt, entity_prompts)
|
||||
|
||||
# example 6
|
||||
global_prompt = "Snow White and the 6 Dwarfs."
|
||||
entity_prompts = ["Dwarf 1", "Dwarf 2", "Dwarf 3", "Snow White", "Dwarf 4", "Dwarf 5", "Dwarf 6"]
|
||||
example(pipe, [8], 6, global_prompt, entity_prompts)
|
||||
|
||||
# example 7, same prompt with different seeds
|
||||
seeds = range(5, 9)
|
||||
global_prompt = "A beautiful woman wearing white dress, holding a mirror, with a warm light background;"
|
||||
entity_prompts = ["A beautiful woman", "mirror", "necklace", "glasses", "earring", "white dress", "jewelry headpiece"]
|
||||
example(pipe, seeds, 7, global_prompt, entity_prompts)
|
||||
46  examples/EntityControl/entity_control_ipadapter.py  Normal file
@@ -0,0 +1,46 @@
|
||||
from diffsynth import ModelManager, FluxImagePipeline, download_customized_models
|
||||
from modelscope import dataset_snapshot_download
|
||||
from examples.EntityControl.utils import visualize_masks
|
||||
from PIL import Image
|
||||
import torch
|
||||
|
||||
|
||||
# download and load model
|
||||
model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cuda", model_id_list=["FLUX.1-dev", "InstantX/FLUX.1-dev-IP-Adapter"])
|
||||
model_manager.load_lora(
|
||||
download_customized_models(
|
||||
model_id="DiffSynth-Studio/Eligen",
|
||||
origin_file_path="model_bf16.safetensors",
|
||||
local_dir="models/lora/entity_control"
|
||||
),
|
||||
lora_alpha=1
|
||||
)
|
||||
pipe = FluxImagePipeline.from_model_manager(model_manager)
|
||||
|
||||
# download and load mask images
|
||||
dataset_snapshot_download(dataset_id="DiffSynth-Studio/examples_in_diffsynth", local_dir="./", allow_file_pattern="data/examples/eligen/ipadapter/*")
|
||||
masks = [Image.open(f"./data/examples/eligen/ipadapter/ipadapter_mask_{i}.png") for i in range(1, 4)]
|
||||
|
||||
entity_prompts = ['A girl', 'hat', 'sunset']
|
||||
global_prompt = "A girl wearing a hat, looking at the sunset"
|
||||
negative_prompt = "worst quality, low quality, monochrome, zombie, interlocked fingers, Aissist, cleavage, nsfw"
|
||||
reference_img = Image.open("./data/examples/eligen/ipadapter/ipadapter_image.png")
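# The IP-Adapter reference image steers the overall style, while the EliGen entity masks control where each entity appears.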
|
||||
|
||||
# generate image
|
||||
image = pipe(
|
||||
prompt=global_prompt,
|
||||
cfg_scale=3.0,
|
||||
negative_prompt=negative_prompt,
|
||||
num_inference_steps=50,
|
||||
embedded_guidance=3.5,
|
||||
seed=4,
|
||||
height=1024,
|
||||
width=1024,
|
||||
eligen_entity_prompts=entity_prompts,
|
||||
eligen_entity_masks=masks,
|
||||
enable_eligen_on_negative=False,
|
||||
ipadapter_images=[reference_img],
|
||||
ipadapter_scale=0.7
|
||||
)
|
||||
image.save(f"styled_entity_control.png")
|
||||
visualize_masks(image, masks, entity_prompts, f"styled_entity_control_with_mask.png")
|
||||
45  examples/EntityControl/entity_inpaint.py  Normal file
@@ -0,0 +1,45 @@
|
||||
from diffsynth import ModelManager, FluxImagePipeline, download_customized_models
|
||||
from modelscope import dataset_snapshot_download
|
||||
from examples.EntityControl.utils import visualize_masks
|
||||
from PIL import Image
|
||||
import torch
|
||||
|
||||
# download and load model
|
||||
model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cuda", model_id_list=["FLUX.1-dev"])
|
||||
model_manager.load_lora(
|
||||
download_customized_models(
|
||||
model_id="DiffSynth-Studio/Eligen",
|
||||
origin_file_path="model_bf16.safetensors",
|
||||
local_dir="models/lora/entity_control"
|
||||
),
|
||||
lora_alpha=1
|
||||
)
|
||||
pipe = FluxImagePipeline.from_model_manager(model_manager)
|
||||
|
||||
# download and load mask images
|
||||
dataset_snapshot_download(dataset_id="DiffSynth-Studio/examples_in_diffsynth", local_dir="./", allow_file_pattern="data/examples/eligen/inpaint/*")
|
||||
masks = [Image.open(f"./data/examples/eligen/inpaint/inpaint_mask_{i}.png") for i in range(1, 3)]
|
||||
input_image = Image.open("./data/examples/eligen/inpaint/inpaint_image.jpg")
|
||||
|
||||
entity_prompts = ["A person wear red shirt", "Airplane"]
|
||||
global_prompt = "A person walking on the path in front of a house; An airplane in the sky"
|
||||
negative_prompt = "worst quality, low quality, monochrome, zombie, interlocked fingers, Aissist, cleavage, nsfw, blur"
|
||||
|
||||
# generate image
|
||||
image = pipe(
|
||||
prompt=global_prompt,
|
||||
input_image=input_image,
|
||||
cfg_scale=3.0,
|
||||
negative_prompt=negative_prompt,
|
||||
num_inference_steps=50,
|
||||
embedded_guidance=3.5,
|
||||
seed=0,
|
||||
height=1024,
|
||||
width=1024,
|
||||
eligen_entity_prompts=entity_prompts,
|
||||
eligen_entity_masks=masks,
|
||||
enable_eligen_on_negative=False,
|
||||
enable_eligen_inpaint=True,
|
||||
)
|
||||
image.save(f"entity_inpaint.png")
|
||||
visualize_masks(image, masks, entity_prompts, f"entity_inpaint_with_mask.png")
|
||||
84  examples/EntityControl/entity_transfer.py  Normal file
@@ -0,0 +1,84 @@
|
||||
from diffsynth import ModelManager, FluxImagePipeline, download_customized_models
|
||||
from modelscope import dataset_snapshot_download
|
||||
from examples.EntityControl.utils import visualize_masks
|
||||
from PIL import Image
|
||||
import torch
|
||||
|
||||
|
||||
def build_pipeline():
|
||||
model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cuda", model_id_list=["FLUX.1-dev"])
|
||||
model_manager.load_lora(
|
||||
download_customized_models(
|
||||
model_id="DiffSynth-Studio/Eligen",
|
||||
origin_file_path="model_bf16.safetensors",
|
||||
local_dir="models/lora/entity_control"
|
||||
),
|
||||
lora_alpha=1
|
||||
)
|
||||
model_manager.load_lora(
|
||||
download_customized_models(
|
||||
model_id="iic/In-Context-LoRA",
|
||||
origin_file_path="visual-identity-design.safetensors",
|
||||
local_dir="models/lora/In-Context-LoRA"
|
||||
),
|
||||
lora_alpha=1
|
||||
)
|
||||
pipe = FluxImagePipeline.from_model_manager(model_manager)
|
||||
return pipe
|
||||
|
||||
|
||||
def generate(pipe: FluxImagePipeline, source_image, target_image, mask, height, width, prompt, entity_prompt, image_save_path, mask_save_path, seed=0):
|
||||
input_mask = Image.new('RGB', (width * 2, height))
|
||||
input_mask.paste(mask.resize((width, height), resample=Image.NEAREST).convert('RGB'), (width, 0))
|
||||
|
||||
input_image = Image.new('RGB', (width * 2, height))
|
||||
input_image.paste(source_image.resize((width, height)).convert('RGB'), (0, 0))
|
||||
input_image.paste(target_image.resize((width, height)).convert('RGB'), (width, 0))
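# The In-Context LoRA example uses a two-panel canvas: the source (logo) image goes on the left, the transfer target on the right, and the entity mask covers only the right panel.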
|
||||
|
||||
image = pipe(
|
||||
prompt=prompt,
|
||||
input_image=input_image,
|
||||
cfg_scale=3.0,
|
||||
negative_prompt="",
|
||||
num_inference_steps=50,
|
||||
embedded_guidance=3.5,
|
||||
seed=seed,
|
||||
height=height,
|
||||
width=width * 2,
|
||||
eligen_entity_prompts=[entity_prompt],
|
||||
eligen_entity_masks=[input_mask],
|
||||
enable_eligen_on_negative=False,
|
||||
enable_eligen_inpaint=True,
|
||||
)
|
||||
target_image = image.crop((width, 0, 2 * width, height))
|
||||
target_image.save(image_save_path)
|
||||
visualize_masks(target_image, [mask], [entity_prompt], mask_save_path)
|
||||
return target_image
|
||||
|
||||
|
||||
pipe = build_pipeline()
|
||||
|
||||
dataset_snapshot_download(dataset_id="DiffSynth-Studio/examples_in_diffsynth", local_dir="./", allow_file_pattern="data/examples/eligen/logo_transfer/*")
|
||||
|
||||
prompt="The two-panel image showcases the joyful identity, with the left panel showing a rabbit graphic; [LEFT] while the right panel translates the design onto a shopping tote with the rabbit logo in black, held by a person in a market setting, emphasizing the brand's approachable and eco-friendly vibe."
|
||||
logo_prompt="a rabbit logo"
|
||||
|
||||
logo_image = Image.open("data/examples/eligen/logo_transfer/source_image.png")
|
||||
target_image = Image.open("data/examples/eligen/logo_transfer/target_image.png")
|
||||
mask = Image.open("data/examples/eligen/logo_transfer/mask_1.png")
|
||||
generate(
|
||||
pipe, logo_image, target_image, mask,
|
||||
height=1024, width=1024,
|
||||
prompt=prompt, entity_prompt=logo_prompt,
|
||||
image_save_path="entity_transfer_1.png",
|
||||
mask_save_path="entity_transfer_with_mask_1.png"
|
||||
)
|
||||
|
||||
mask = Image.open("data/examples/eligen/logo_transfer/mask_2.png")
|
||||
generate(
|
||||
pipe, logo_image, target_image, mask,
|
||||
height=1024, width=1024,
|
||||
prompt=prompt, entity_prompt=logo_prompt,
|
||||
image_save_path="entity_transfer_2.png",
|
||||
mask_save_path="entity_transfer_with_mask_2.png"
|
||||
)
|
||||
59  examples/EntityControl/utils.py  Normal file
@@ -0,0 +1,59 @@
|
||||
from PIL import Image, ImageDraw, ImageFont
|
||||
import random
|
||||
|
||||
def visualize_masks(image, masks, mask_prompts, output_path, font_size=35, use_random_colors=False):
|
||||
# Create a blank image for overlays
|
||||
overlay = Image.new('RGBA', image.size, (0, 0, 0, 0))
|
||||
|
||||
colors = [
|
||||
(165, 238, 173, 80),
|
||||
(76, 102, 221, 80),
|
||||
(221, 160, 77, 80),
|
||||
(204, 93, 71, 80),
|
||||
(145, 187, 149, 80),
|
||||
(134, 141, 172, 80),
|
||||
(157, 137, 109, 80),
|
||||
(153, 104, 95, 80),
|
||||
(165, 238, 173, 80),
|
||||
(76, 102, 221, 80),
|
||||
(221, 160, 77, 80),
|
||||
(204, 93, 71, 80),
|
||||
(145, 187, 149, 80),
|
||||
(134, 141, 172, 80),
|
||||
(157, 137, 109, 80),
|
||||
(153, 104, 95, 80),
|
||||
]
|
||||
# Generate random colors for each mask
|
||||
if use_random_colors:
|
||||
colors = [(random.randint(0, 255), random.randint(0, 255), random.randint(0, 255), 80) for _ in range(len(masks))]
|
||||
|
||||
# Font settings
|
||||
try:
|
||||
font = ImageFont.truetype("arial", font_size) # Adjust as needed
|
||||
except IOError:
|
||||
font = ImageFont.load_default(font_size)
|
||||
|
||||
# Overlay each mask onto the overlay image
|
||||
for mask, mask_prompt, color in zip(masks, mask_prompts, colors):
|
||||
# Convert mask to RGBA mode
|
||||
mask_rgba = mask.convert('RGBA')
|
||||
mask_data = mask_rgba.getdata()
|
||||
new_data = [(color if item[:3] == (255, 255, 255) else (0, 0, 0, 0)) for item in mask_data]
|
||||
mask_rgba.putdata(new_data)
|
||||
|
||||
# Draw the mask prompt text on the mask
|
||||
draw = ImageDraw.Draw(mask_rgba)
|
||||
mask_bbox = mask.getbbox() # Get the bounding box of the mask
|
||||
text_position = (mask_bbox[0] + 10, mask_bbox[1] + 10) # Adjust text position based on mask position
|
||||
draw.text(text_position, mask_prompt, fill=(255, 255, 255, 255), font=font)
|
||||
|
||||
# Alpha composite the overlay with this mask
|
||||
overlay = Image.alpha_composite(overlay, mask_rgba)
|
||||
|
||||
# Composite the overlay onto the original image
|
||||
result = Image.alpha_composite(image.convert('RGBA'), overlay)
|
||||
|
||||
# Save or display the resulting image
|
||||
result.save(output_path)
|
||||
|
||||
return result
|
||||
21  examples/ExVideo/ExVideo_cogvideox_test.py  Normal file
@@ -0,0 +1,21 @@
|
||||
from diffsynth import ModelManager, CogVideoPipeline, save_video, download_models
|
||||
import torch
|
||||
|
||||
|
||||
download_models(["CogVideoX-5B", "ExVideo-CogVideoX-LoRA-129f-v1"])
|
||||
model_manager = ModelManager(torch_dtype=torch.bfloat16)
|
||||
model_manager.load_models([
|
||||
"models/CogVideo/CogVideoX-5b/text_encoder",
|
||||
"models/CogVideo/CogVideoX-5b/transformer",
|
||||
"models/CogVideo/CogVideoX-5b/vae/diffusion_pytorch_model.safetensors",
|
||||
])
|
||||
model_manager.load_lora("models/lora/ExVideo-CogVideoX-LoRA-129f-v1.safetensors")
|
||||
pipe = CogVideoPipeline.from_model_manager(model_manager)
|
||||
|
||||
torch.manual_seed(6)
|
||||
video = pipe(
|
||||
prompt="an astronaut riding a horse on Mars.",
|
||||
height=480, width=720, num_frames=129,
|
||||
cfg_scale=7.0, num_inference_steps=100,
|
||||
)
|
||||
save_video(video, "video_with_lora.mp4", fps=8, quality=5)
|
||||
@@ -4,11 +4,19 @@ ExVideo is a post-tuning technique aimed at enhancing the capability of video ge
|
||||
|
||||
* [Project Page](https://ecnu-cilab.github.io/ExVideoProjectPage/)
|
||||
* [Technical report](https://arxiv.org/abs/2406.14130)
|
||||
* [Demo](https://huggingface.co/spaces/modelscope/ExVideo-SVD-128f-v1)
|
||||
* Extended models
|
||||
* **[New]** Extended models (ExVideo-CogVideoX)
|
||||
* [HuggingFace](https://huggingface.co/ECNU-CILab/ExVideo-CogVideoX-LoRA-129f-v1)
|
||||
* [ModelScope](https://modelscope.cn/models/ECNU-CILab/ExVideo-CogVideoX-LoRA-129f-v1)
|
||||
* Extended models (ExVideo-SVD)
|
||||
* [HuggingFace](https://huggingface.co/ECNU-CILab/ExVideo-SVD-128f-v1)
|
||||
* [ModelScope](https://modelscope.cn/models/ECNU-CILab/ExVideo-SVD-128f-v1)
|
||||
|
||||
## Example: Text-to-video via extended CogVideoX-5B
|
||||
|
||||
Generate a video using CogVideoX-5B and our extension module. See [ExVideo_cogvideox_test.py](./ExVideo_cogvideox_test.py).
|
||||
|
||||
https://github.com/user-attachments/assets/321ee04b-8c17-479e-8a95-8cbcf21f8d7e
|
||||
|
||||
## Example: Text-to-video via extended Stable Video Diffusion
|
||||
|
||||
Generate a video using a text-to-image model and our image-to-video model. See [ExVideo_svd_test.py](./ExVideo_svd_test.py).
|
||||
|
||||
23  examples/HunyuanVideo/README.md  Normal file
@@ -0,0 +1,23 @@
|
||||
# HunyuanVideo
|
||||
|
||||
[HunyuanVideo](https://github.com/Tencent/HunyuanVideo) is a video generation model trained by Tencent. We provide advanced VRAM management for this model, with three configurations (a condensed code sketch follows the table):
|
||||
|
||||
|VRAM required|Example script|Frames|Resolution|Note|
|
||||
|-|-|-|-|-|
|
||||
|80G|[hunyuanvideo_80G.py](hunyuanvideo_80G.py)|129|720*1280|No VRAM management.|
|
||||
|24G|[hunyuanvideo_24G.py](hunyuanvideo_24G.py)|129|720*1280|The video is consistent with the original implementation, but it requires 5%~10% more time than [hunyuanvideo_80G.py](hunyuanvideo_80G.py)|
|
||||
|6G|[hunyuanvideo_6G.py](hunyuanvideo_6G.py)|129|512*384|The base model doesn't support low resolutions. We recommend using a LoRA ([example](https://civitai.com/models/1032126/walking-animation-hunyuan-video)) trained at low resolutions.|
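The tiers differ mainly in where the weights are kept: [hunyuanvideo_80G.py](hunyuanvideo_80G.py) loads everything onto the GPU, while the 24G and 6G scripts load the weights onto the CPU and let the pipeline's VRAM management move them to CUDA as needed (the 6G script additionally lowers the resolution and relies on a LoRA). Below is a condensed sketch of the 24G-style setup; the full scripts below remain the reference, and the prompt here is only illustrative.

```python
import torch
from diffsynth import ModelManager, HunyuanVideoPipeline, download_models, save_video

download_models(["HunyuanVideo"])
model_manager = ModelManager()

# Keep the DiT weights on the CPU in bfloat16 (switch to torch.float8_e4m3fn for further savings).
model_manager.load_models(
    ["models/HunyuanVideo/transformers/mp_rank_00_model_states.pt"],
    torch_dtype=torch.bfloat16, device="cpu",
)
# Text encoders and VAE in float16, also kept on the CPU.
model_manager.load_models(
    [
        "models/HunyuanVideo/text_encoder/model.safetensors",
        "models/HunyuanVideo/text_encoder_2",
        "models/HunyuanVideo/vae/pytorch_model.pt",
    ],
    torch_dtype=torch.float16, device="cpu",
)

# The pipeline computes on CUDA and moves weights in as needed, which is what keeps VRAM low.
pipe = HunyuanVideoPipeline.from_model_manager(model_manager, torch_dtype=torch.bfloat16, device="cuda")
video = pipe("a cat is running in a garden", seed=0)  # illustrative prompt
save_video(video, "video.mp4", fps=30, quality=6)
```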
|
||||
|
||||
## Gallery
|
||||
|
||||
Video generated by [hunyuanvideo_80G.py](hunyuanvideo_80G.py) and [hunyuanvideo_24G.py](hunyuanvideo_24G.py):
|
||||
|
||||
https://github.com/user-attachments/assets/48dd24bb-0cc6-40d2-88c3-10feed3267e9
|
||||
|
||||
Video generated by [hunyuanvideo_6G.py](hunyuanvideo_6G.py) using [this LoRA](https://civitai.com/models/1032126/walking-animation-hunyuan-video):
|
||||
|
||||
https://github.com/user-attachments/assets/2997f107-d02d-4ecb-89bb-5ce1a7f93817
|
||||
|
||||
Video to video generated by [hunyuanvideo_v2v_6G.py](./hunyuanvideo_v2v_6G.py) using [this LoRA](https://civitai.com/models/1032126/walking-animation-hunyuan-video):
|
||||
|
||||
https://github.com/user-attachments/assets/4b89e52e-ce42-434e-aa57-08f09dfa2b10
|
||||
42  examples/HunyuanVideo/hunyuanvideo_24G.py  Normal file
@@ -0,0 +1,42 @@
|
||||
import torch
|
||||
torch.cuda.set_per_process_memory_fraction(1.0, 0)
|
||||
from diffsynth import ModelManager, HunyuanVideoPipeline, download_models, save_video
|
||||
|
||||
|
||||
download_models(["HunyuanVideo"])
|
||||
model_manager = ModelManager()
|
||||
|
||||
# The DiT model is loaded in bfloat16.
|
||||
model_manager.load_models(
|
||||
[
|
||||
"models/HunyuanVideo/transformers/mp_rank_00_model_states.pt"
|
||||
],
|
||||
torch_dtype=torch.bfloat16, # you can use torch_dtype=torch.float8_e4m3fn to enable quantization.
|
||||
device="cpu"
|
||||
)
|
||||
|
||||
# The other modules are loaded in float16.
|
||||
model_manager.load_models(
|
||||
[
|
||||
"models/HunyuanVideo/text_encoder/model.safetensors",
|
||||
"models/HunyuanVideo/text_encoder_2",
|
||||
"models/HunyuanVideo/vae/pytorch_model.pt",
|
||||
],
|
||||
torch_dtype=torch.float16,
|
||||
device="cpu"
|
||||
)
|
||||
|
||||
# We support LoRA inference. You can use the following code to load your LoRA model.
|
||||
# model_manager.load_lora("models/lora/xxx.safetensors", lora_alpha=1.0)
|
||||
|
||||
# The computation device is "cuda".
|
||||
pipe = HunyuanVideoPipeline.from_model_manager(
|
||||
model_manager,
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda"
|
||||
)
|
||||
|
||||
# Enjoy!
|
||||
prompt = "CG, masterpiece, best quality, solo, long hair, wavy hair, silver hair, blue eyes, blue dress, medium breasts, dress, underwater, air bubble, floating hair, refraction, portrait. The girl's flowing silver hair shimmers with every color of the rainbow and cascades down, merging with the floating flora around her."
|
||||
video = pipe(prompt, seed=0)
|
||||
save_video(video, "video_girl.mp4", fps=30, quality=6)
|
||||
52  examples/HunyuanVideo/hunyuanvideo_6G.py  Normal file
@@ -0,0 +1,52 @@
|
||||
import torch
|
||||
torch.cuda.set_per_process_memory_fraction(1.0, 0)
|
||||
from diffsynth import ModelManager, HunyuanVideoPipeline, download_models, save_video, FlowMatchScheduler, download_customized_models
|
||||
|
||||
|
||||
download_models(["HunyuanVideo"])
|
||||
model_manager = ModelManager()
|
||||
|
||||
# The DiT model is loaded in bfloat16.
|
||||
model_manager.load_models(
|
||||
[
|
||||
"models/HunyuanVideo/transformers/mp_rank_00_model_states.pt"
|
||||
],
|
||||
torch_dtype=torch.bfloat16, # you can use torch_dtype=torch.float8_e4m3fn to enable quantization.
|
||||
device="cpu"
|
||||
)
|
||||
|
||||
# The other modules are loaded in float16.
|
||||
model_manager.load_models(
|
||||
[
|
||||
"models/HunyuanVideo/text_encoder/model.safetensors",
|
||||
"models/HunyuanVideo/text_encoder_2",
|
||||
"models/HunyuanVideo/vae/pytorch_model.pt",
|
||||
],
|
||||
torch_dtype=torch.float16,
|
||||
device="cpu"
|
||||
)
|
||||
|
||||
# We support LoRA inference. You can use the following code to load your LoRA model.
|
||||
# Example LoRA: https://civitai.com/models/1032126/walking-animation-hunyuan-video
|
||||
download_customized_models(
|
||||
model_id="AI-ModelScope/walking_animation_hunyuan_video",
|
||||
origin_file_path="kxsr_walking_anim_v1-5.safetensors",
|
||||
local_dir="models/lora"
|
||||
)[0]
|
||||
model_manager.load_lora("models/lora/kxsr_walking_anim_v1-5.safetensors", lora_alpha=1.0)
|
||||
|
||||
# The computation device is "cuda".
|
||||
pipe = HunyuanVideoPipeline.from_model_manager(
|
||||
model_manager,
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda"
|
||||
)
|
||||
# This LoRA requires shift=9.0.
|
||||
pipe.scheduler = FlowMatchScheduler(shift=9.0, sigma_min=0.0, extra_one_step=True)
|
||||
|
||||
# Enjoy!
|
||||
for clothes_up in ["white t-shirt", "black t-shirt", "orange t-shirt"]:
|
||||
for clothes_down in ["blue sports skirt", "red sports skirt", "white sports skirt"]:
|
||||
prompt = f"kxsr, full body, no crop, A 3D-rendered CG animation video featuring a Gorgeous, mature, curvaceous, fair-skinned female girl with long silver hair and blue eyes. She wears a {clothes_up} and a {clothes_down}, walking offering a sense of fluid movement and vivid animation."
|
||||
video = pipe(prompt, seed=0, height=512, width=384, num_frames=129, num_inference_steps=18, tile_size=(17, 16, 16), tile_stride=(12, 12, 12))
|
||||
save_video(video, f"video-{clothes_up}-{clothes_down}.mp4", fps=30, quality=6)
|
||||
45  examples/HunyuanVideo/hunyuanvideo_80G.py  Normal file
@@ -0,0 +1,45 @@
|
||||
import torch
|
||||
torch.cuda.set_per_process_memory_fraction(1.0, 0)
|
||||
from diffsynth import ModelManager, HunyuanVideoPipeline, download_models, save_video
|
||||
|
||||
|
||||
download_models(["HunyuanVideo"])
|
||||
model_manager = ModelManager()
|
||||
|
||||
# The DiT model is loaded in bfloat16.
|
||||
model_manager.load_models(
|
||||
[
|
||||
"models/HunyuanVideo/transformers/mp_rank_00_model_states.pt"
|
||||
],
|
||||
torch_dtype=torch.bfloat16, # you can use torch_dtype=torch.float8_e4m3fn to enable quantization.
|
||||
device="cuda"
|
||||
)
|
||||
|
||||
# The other modules are loaded in float16.
|
||||
model_manager.load_models(
|
||||
[
|
||||
"models/HunyuanVideo/text_encoder/model.safetensors",
|
||||
"models/HunyuanVideo/text_encoder_2",
|
||||
"models/HunyuanVideo/vae/pytorch_model.pt",
|
||||
],
|
||||
torch_dtype=torch.float16,
|
||||
device="cuda"
|
||||
)
|
||||
|
||||
# We support LoRA inference. You can use the following code to load your LoRA model.
|
||||
# model_manager.load_lora("models/lora/xxx.safetensors", lora_alpha=1.0)
|
||||
|
||||
# The computation device is "cuda".
|
||||
pipe = HunyuanVideoPipeline.from_model_manager(
|
||||
model_manager,
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
enable_vram_management=False
|
||||
)
|
||||
# Even if you have enough VRAM, we still recommend enabling offload.
|
||||
pipe.enable_cpu_offload()
|
||||
|
||||
# Enjoy!
|
||||
prompt = "CG, masterpiece, best quality, solo, long hair, wavy hair, silver hair, blue eyes, blue dress, medium breasts, dress, underwater, air bubble, floating hair, refraction, portrait. The girl's flowing silver hair shimmers with every color of the rainbow and cascades down, merging with the floating flora around her."
|
||||
video = pipe(prompt, seed=0)
|
||||
save_video(video, "video.mp4", fps=30, quality=6)
|
||||
55  examples/HunyuanVideo/hunyuanvideo_v2v_6G.py  Normal file
@@ -0,0 +1,55 @@
import torch
torch.cuda.set_per_process_memory_fraction(1.0, 0)
from diffsynth import ModelManager, HunyuanVideoPipeline, download_models, save_video, FlowMatchScheduler, download_customized_models


download_models(["HunyuanVideo"])
model_manager = ModelManager()

# The DiT model is loaded in bfloat16.
model_manager.load_models(
    [
        "models/HunyuanVideo/transformers/mp_rank_00_model_states.pt"
    ],
    torch_dtype=torch.bfloat16, # You can use torch_dtype=torch.float8_e4m3fn to enable FP8 quantization.
    device="cpu"
)

# The other modules are loaded in float16.
model_manager.load_models(
    [
        "models/HunyuanVideo/text_encoder/model.safetensors",
        "models/HunyuanVideo/text_encoder_2",
        "models/HunyuanVideo/vae/pytorch_model.pt",
    ],
    torch_dtype=torch.float16,
    device="cpu"
)

# We support LoRA inference. You can use the following code to load your LoRA model.
# Example LoRA: https://civitai.com/models/1032126/walking-animation-hunyuan-video
download_customized_models(
    model_id="AI-ModelScope/walking_animation_hunyuan_video",
    origin_file_path="kxsr_walking_anim_v1-5.safetensors",
    local_dir="models/lora"
)[0]
model_manager.load_lora("models/lora/kxsr_walking_anim_v1-5.safetensors", lora_alpha=1.0)

# The computation device is "cuda".
pipe = HunyuanVideoPipeline.from_model_manager(
    model_manager,
    torch_dtype=torch.bfloat16,
    device="cuda"
)
# This LoRA requires shift=9.0.
pipe.scheduler = FlowMatchScheduler(shift=9.0, sigma_min=0.0, extra_one_step=True)

# Text-to-video
prompt = "kxsr, full body, no crop. A girl is walking. CG, masterpiece, best quality, solo, long hair, wavy hair, silver hair, blue eyes, blue dress, medium breasts, dress, underwater, air bubble, floating hair, refraction, portrait. The girl's flowing silver hair shimmers with every color of the rainbow and cascades down, merging with the floating flora around her."
video = pipe(prompt, seed=1, height=512, width=384, num_frames=129, num_inference_steps=18, tile_size=(17, 16, 16), tile_stride=(12, 12, 12))
save_video(video, "video.mp4", fps=30, quality=6)

# Video-to-video
prompt = "kxsr, full body, no crop. A girl is walking. CG, masterpiece, best quality, solo, long hair, wavy hair, silver hair, blue eyes, purple dress, medium breasts, dress, underwater, air bubble, floating hair, refraction, portrait. The girl's flowing silver hair shimmers with every color of the rainbow and cascades down, merging with the floating flora around her."
video = pipe(prompt, seed=1, height=512, width=384, num_frames=129, num_inference_steps=18, tile_size=(17, 16, 16), tile_stride=(12, 12, 12), input_video=video, denoising_strength=0.85)
save_video(video, "video_edited.mp4", fps=30, quality=6)
38
examples/Ip-Adapter/flux_ipadapter.py
Normal file
@@ -0,0 +1,38 @@
from diffsynth import ModelManager, download_models, FluxImagePipeline
import torch

# Download models (automatically)
# `models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/ip-adapter.bin`: [link](https://huggingface.co/InstantX/FLUX.1-dev-IP-Adapter/blob/main/ip-adapter.bin)
# `models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/image_encoder`: [link](https://huggingface.co/google/siglip-so400m-patch14-384)
download_models(["InstantX/FLUX.1-dev-IP-Adapter", "FLUX.1-dev"])

# Load models
model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cuda")
model_manager.load_models([
    "models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/ip-adapter.bin",
    "models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/image_encoder",
    "models/FLUX/FLUX.1-dev/text_encoder/model.safetensors",
    "models/FLUX/FLUX.1-dev/text_encoder_2",
    "models/FLUX/FLUX.1-dev/ae.safetensors",
    "models/FLUX/FLUX.1-dev/flux1-dev.safetensors",
])
seed = 42
pipe = FluxImagePipeline.from_model_manager(model_manager)
torch.manual_seed(seed)
origin_prompt = "a rabbit in a garden, colorful flowers"
image = pipe(
    prompt=origin_prompt,
    cfg_scale=1.0, embedded_guidance=3.5,
    height=1280, width=960, num_inference_steps=30
)
image.save("style image.jpg")

torch.manual_seed(seed)
image = pipe(
    prompt="A piggy",
    cfg_scale=1.0, embedded_guidance=3.5,
    height=1280, width=960, num_inference_steps=30,
    ipadapter_images=[image], ipadapter_scale=0.7
)
image.save("A piggy.jpg")
34
examples/TeaCache/README.md
Normal file
@@ -0,0 +1,34 @@
# TeaCache

TeaCache ([Timestep Embedding Aware Cache](https://github.com/ali-vilab/TeaCache)) is a training-free caching approach that estimates and leverages the fluctuating differences among model outputs across timesteps, thereby accelerating inference.

## Examples

### FLUX

Script: [./flux_teacache.py](./flux_teacache.py)

Model: FLUX.1-dev

Steps: 50

GPU: A100

|TeaCache is disabled|tea_cache_l1_thresh=0.2|tea_cache_l1_thresh=0.8|
|-|-|-|
|23s|13s|5s|
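
In these examples, TeaCache is controlled through the `tea_cache_l1_thresh` argument of the pipeline call: a larger threshold lets more timesteps reuse cached results, trading some quality for speed, and `None` disables caching. A minimal sketch based on [./flux_teacache.py](./flux_teacache.py) (the short prompt is only illustrative):

```python
import torch
from diffsynth import ModelManager, FluxImagePipeline

# Load FLUX.1-dev and build the image pipeline.
model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cuda", model_id_list=["FLUX.1-dev"])
pipe = FluxImagePipeline.from_model_manager(model_manager)

# Pass tea_cache_l1_thresh to enable TeaCache (None disables it).
image = pipe(
    prompt="a rabbit in a garden, colorful flowers",
    embedded_guidance=3.5, seed=0,
    num_inference_steps=50, tea_cache_l1_thresh=0.2,
)
image.save("image_teacache.png")
```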

### Hunyuan Video

Script: [./hunyuanvideo_teacache.py](./hunyuanvideo_teacache.py)

Model: Hunyuan Video

Steps: 30

GPU: A100

The following video was generated with TeaCache enabled. It is nearly identical to [the video generated without TeaCache](https://github.com/user-attachments/assets/48dd24bb-0cc6-40d2-88c3-10feed3267e9), but took roughly half the time to generate.

https://github.com/user-attachments/assets/cd9801c5-88ce-4efc-b055-2c7737166f34
15
examples/TeaCache/flux_teacache.py
Normal file
@@ -0,0 +1,15 @@
import torch
from diffsynth import ModelManager, FluxImagePipeline


model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cuda", model_id_list=["FLUX.1-dev"])
pipe = FluxImagePipeline.from_model_manager(model_manager)

prompt = "CG, masterpiece, best quality, solo, long hair, wavy hair, silver hair, blue eyes, blue dress, medium breasts, dress, underwater, air bubble, floating hair, refraction, portrait. The girl's flowing silver hair shimmers with every color of the rainbow and cascades down, merging with the floating flora around her."

for tea_cache_l1_thresh in [None, 0.2, 0.4, 0.6, 0.8]:
    image = pipe(
        prompt=prompt, embedded_guidance=3.5, seed=0,
        num_inference_steps=50, tea_cache_l1_thresh=tea_cache_l1_thresh
    )
    image.save(f"image_{tea_cache_l1_thresh}.png")
42
examples/TeaCache/hunyuanvideo_teacache.py
Normal file
@@ -0,0 +1,42 @@
import torch
torch.cuda.set_per_process_memory_fraction(1.0, 0)
from diffsynth import ModelManager, HunyuanVideoPipeline, download_models, save_video


download_models(["HunyuanVideo"])
model_manager = ModelManager()

# The DiT model is loaded in bfloat16.
model_manager.load_models(
    [
        "models/HunyuanVideo/transformers/mp_rank_00_model_states.pt"
    ],
    torch_dtype=torch.bfloat16, # You can use torch_dtype=torch.float8_e4m3fn to enable FP8 quantization.
    device="cpu"
)

# The other modules are loaded in float16.
model_manager.load_models(
    [
        "models/HunyuanVideo/text_encoder/model.safetensors",
        "models/HunyuanVideo/text_encoder_2",
        "models/HunyuanVideo/vae/pytorch_model.pt",
    ],
    torch_dtype=torch.float16,
    device="cpu"
)

# We support LoRA inference. You can use the following code to load your LoRA model.
# model_manager.load_lora("models/lora/xxx.safetensors", lora_alpha=1.0)

# The computation device is "cuda".
pipe = HunyuanVideoPipeline.from_model_manager(
    model_manager,
    torch_dtype=torch.bfloat16,
    device="cuda"
)

# Enjoy!
prompt = "CG, masterpiece, best quality, solo, long hair, wavy hair, silver hair, blue eyes, blue dress, medium breasts, dress, underwater, air bubble, floating hair, refraction, portrait. The girl's flowing silver hair shimmers with every color of the rainbow and cascades down, merging with the floating flora around her."
video = pipe(prompt, seed=0, tea_cache_l1_thresh=0.15)
save_video(video, "video_girl.mp4", fps=30, quality=6)
18
examples/WanX/test_prompter.py
Normal file
@@ -0,0 +1,18 @@
import torch
from diffsynth.prompters import WanXPrompter
from diffsynth.models.wanx_text_encoder import WanXTextEncoder

prompter = WanXPrompter('models/WanX/google/umt5-xxl')
text_encoder = WanXTextEncoder()
text_encoder.load_state_dict(torch.load('models/WanX/models_t5_umt5-xxl-enc-bf16.pth', map_location='cpu'))
text_encoder = text_encoder.eval().requires_grad_(False).to(dtype=torch.bfloat16, device='cuda')

prompter.fetch_models(text_encoder)

# Prompt: "A Viking warrior swings a great axe with both hands against a mammoth, at dusk, in a snowfield, with snow filling the sky."
prompt = '维京战士双手挥舞着大斧,对抗猛犸象,黄昏,雪地中,漫天飞雪'
# Negative prompt: the usual quality/artifact tags (overexposed, blurry details, extra fingers, deformed limbs, cluttered background, etc.).
neg_prompt = '色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走'

prompt_emb = prompter.encode_prompt(prompt)
neg_prompt_emb = prompter.encode_prompt(neg_prompt)
print(prompt_emb[0]) # torch.Size([31, 4096])
print(neg_prompt_emb[0]) # torch.Size([126, 4096])
46
examples/WanX/test_vae.py
Normal file
@@ -0,0 +1,46 @@
import torch
import torchvision
import imageio
from diffsynth import ModelManager

def save_video(tensor,
               save_file=None,
               fps=30,
               nrow=8,
               normalize=True,
               value_range=(-1, 1)):

    tensor = tensor.clamp(min(value_range), max(value_range))
    tensor = torch.stack([
        torchvision.utils.make_grid(
            u, nrow=nrow, normalize=normalize, value_range=value_range)
        for u in tensor.unbind(2)
    ],
    dim=1).permute(1, 2, 3, 0)  # frame, h, w, 3
    tensor = (tensor * 255).type(torch.uint8).cpu()

    # write video
    writer = imageio.get_writer(
        save_file, fps=fps, codec='libx264', quality=8)
    for frame in tensor.numpy():
        writer.append_data(frame)
    writer.close()

torch.cuda.memory._record_memory_history()

model_manager = ModelManager(torch_dtype=torch.float, device="cuda")
model_manager.load_models([
    "models/WanX/vae.pth",
])

vae = model_manager.fetch_model('wanxvideo_vae')

latents = [torch.load('sample.pt')]
videos = vae.decode(latents, device=latents[0].device, tiled=True)
back_encode = vae.encode(videos, device=latents[0].device, tiled=True)

videos_back_encode = vae.decode(back_encode, device=latents[0].device, tiled=False)
torch.cuda.memory._dump_snapshot("my_snapshot.pickle")

save_video(videos[0][None], save_file='example.mp4', fps=16, nrow=1)
save_video(videos_back_encode[0][None], save_file='example_backencode.mp4', fps=16, nrow=1)
@@ -2,15 +2,23 @@

Image synthesis is the base feature of DiffSynth Studio. We can generate images with very high resolution.

### OmniGen

OmniGen is a text-image-to-image model; you can synthesize an image based on several given reference images.

|Reference image 1|Reference image 2|Synthesized image|
|-|-|-|
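
Reference images are passed through the `reference_images` argument and referred to in the prompt with `<img><|image_1|></img>`-style placeholders. A minimal sketch, following [`omnigen_text_to_image.py`](./omnigen_text_to_image.py) (the two input image files are assumed to exist locally):

```python
import torch
from diffsynth import ModelManager, OmnigenImagePipeline
from PIL import Image

# Load OmniGen-v1 and build the pipeline.
model_manager = ModelManager(torch_dtype=torch.bfloat16, model_id_list=["OmniGen-v1"])
pipe = OmnigenImagePipeline.from_model_manager(model_manager)

# Combine two reference images into one synthesized image.
image_merged = pipe(
    prompt="a man and a woman. The man is the man in <img><|image_1|></img>. The woman is the woman in <img><|image_2|></img>.",
    reference_images=[Image.open("image_man.jpg"), Image.open("image_woman.jpg")],
    cfg_scale=2.5, image_cfg_scale=2.5, num_inference_steps=50, seed=2
)
image_merged.save("image_merged.jpg")
```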

### Example: FLUX

Example script: [`flux_text_to_image.py`](./flux_text_to_image.py)
Example scripts: [`flux_text_to_image.py`](./flux_text_to_image.py) and [`flux_text_to_image_low_vram.py`](./flux_text_to_image_low_vram.py) (low VRAM).

The original version of FLUX doesn't support classifier-free guidance; however, we believe that this guidance mechanism is an important feature for synthesizing beautiful images. You can enable it using the parameter `cfg_scale`; the extra guidance scale introduced by FLUX is `embedded_guidance`.

|1024*1024 (original)|1024*1024 (classifier-free guidance)|2048*2048 (highres-fix)|
|-|-|-|
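
For reference, enabling classifier-free guidance only requires passing a `cfg_scale` greater than 1 (and optionally a negative prompt) to the pipeline call. The snippet below mirrors [`flux_text_to_image.py`](./flux_text_to_image.py) and assumes `pipe`, `prompt`, and `negative_prompt` are already defined as in that script:

```python
import torch

# FLUX's own guidance is embedded_guidance; classifier-free guidance is enabled by cfg_scale > 1.
torch.manual_seed(9)
image = pipe(
    prompt=prompt, negative_prompt=negative_prompt,
    num_inference_steps=50, cfg_scale=2.0, embedded_guidance=3.5
)
image.save("image_1024_cfg.jpg")
```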
### Example: Stable Diffusion

@@ -12,30 +12,30 @@ model_manager.load_models([
])
pipe = FluxImagePipeline.from_model_manager(model_manager)

prompt = "CG. Full body. A captivating fantasy magic woman portrait in the deep sea. The woman, with blue spaghetti strap silk dress, swims in the sea. Her flowing silver hair shimmers with every color of the rainbow and cascades down, merging with the floating flora around her. Smooth, delicate and fair skin."
negative_prompt = "dark, worst quality, low quality, monochrome, zombie, interlocked fingers, Aissist, dim, fuzzy, depth of Field, nsfw,"
prompt = "CG, masterpiece, best quality, solo, long hair, wavy hair, silver hair, blue eyes, blue dress, medium breasts, dress, underwater, air bubble, floating hair, refraction, portrait. The girl's flowing silver hair shimmers with every color of the rainbow and cascades down, merging with the floating flora around her."
negative_prompt = "worst quality, low quality, monochrome, zombie, interlocked fingers, Aissist, cleavage, nsfw,"

# Disable classifier-free guidance (consistent with the original implementation of FLUX.1)
torch.manual_seed(6)
torch.manual_seed(9)
image = pipe(
    prompt=prompt,
    num_inference_steps=30, embedded_guidance=3.5
    num_inference_steps=50, embedded_guidance=3.5
)
image.save("image_1024.jpg")

# Enable classifier-free guidance
torch.manual_seed(6)
torch.manual_seed(9)
image = pipe(
    prompt=prompt, negative_prompt=negative_prompt,
    num_inference_steps=30, cfg_scale=2.0, embedded_guidance=3.5
    num_inference_steps=50, cfg_scale=2.0, embedded_guidance=3.5
)
image.save("image_1024_cfg.jpg")

# Highres-fix
torch.manual_seed(7)
torch.manual_seed(10)
image = pipe(
    prompt=prompt,
    num_inference_steps=30, embedded_guidance=3.5,
    num_inference_steps=50, embedded_guidance=3.5,
    input_image=image.resize((2048, 2048)), height=2048, width=2048, denoising_strength=0.6, tiled=True
)
image.save("image_2048_highres.jpg")
@@ -1,42 +0,0 @@
import torch
from diffsynth import ModelManager, FluxImagePipeline, download_models


download_models(["FLUX.1-dev"])
model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu")
model_manager.load_models([
    "models/FLUX/FLUX.1-dev/text_encoder/model.safetensors",
    "models/FLUX/FLUX.1-dev/text_encoder_2",
    "models/FLUX/FLUX.1-dev/ae.safetensors",
    "models/FLUX/FLUX.1-dev/flux1-dev.safetensors"
])
pipe = FluxImagePipeline.from_model_manager(model_manager, device='cuda')
pipe.enable_cpu_offload()

prompt = "CG. Full body. A captivating fantasy magic woman portrait in the deep sea. The woman, with blue spaghetti strap silk dress, swims in the sea. Her flowing silver hair shimmers with every color of the rainbow and cascades down, merging with the floating flora around her. Smooth, delicate and fair skin."
negative_prompt = "dark, worst quality, low quality, monochrome, zombie, interlocked fingers, Aissist, dim, fuzzy, depth of Field, nsfw,"

# Disable classifier-free guidance (consistent with the original implementation of FLUX.1)
torch.manual_seed(6)
image = pipe(
    prompt=prompt,
    num_inference_steps=30, embedded_guidance=3.5
)
image.save("image_1024.jpg")

# Enable classifier-free guidance
torch.manual_seed(6)
image = pipe(
    prompt=prompt, negative_prompt=negative_prompt,
    num_inference_steps=30, cfg_scale=2.0, embedded_guidance=3.5
)
image.save("image_1024_cfg.jpg")

# Highres-fix
torch.manual_seed(7)
image = pipe(
    prompt=prompt,
    num_inference_steps=30, embedded_guidance=3.5,
    input_image=image.resize((2048, 2048)), height=2048, width=2048, denoising_strength=0.6, tiled=True
)
image.save("image_2048_highres.jpg")
51
examples/image_synthesis/flux_text_to_image_low_vram.py
Normal file
@@ -0,0 +1,51 @@
import torch
from diffsynth import download_models, ModelManager, FluxImagePipeline


download_models(["FLUX.1-dev"])

model_manager = ModelManager(
    torch_dtype=torch.bfloat16,
    device="cpu" # To reduce the VRAM required, we load the models to RAM.
)
model_manager.load_models([
    "models/FLUX/FLUX.1-dev/text_encoder/model.safetensors",
    "models/FLUX/FLUX.1-dev/text_encoder_2",
    "models/FLUX/FLUX.1-dev/ae.safetensors",
])
model_manager.load_models(
    ["models/FLUX/FLUX.1-dev/flux1-dev.safetensors"],
    torch_dtype=torch.float8_e4m3fn # Load the DiT model in FP8 format.
)

pipe = FluxImagePipeline.from_model_manager(model_manager, device="cuda")
pipe.enable_cpu_offload()
pipe.dit.quantize()

prompt = "CG, masterpiece, best quality, solo, long hair, wavy hair, silver hair, blue eyes, blue dress, medium breasts, dress, underwater, air bubble, floating hair, refraction, portrait. The girl's flowing silver hair shimmers with every color of the rainbow and cascades down, merging with the floating flora around her."
negative_prompt = "worst quality, low quality, monochrome, zombie, interlocked fingers, Aissist, cleavage, nsfw,"

# Disable classifier-free guidance (consistent with the original implementation of FLUX.1)
torch.manual_seed(9)
image = pipe(
    prompt=prompt,
    num_inference_steps=50, embedded_guidance=3.5
)
image.save("image_1024.jpg")

# Enable classifier-free guidance
torch.manual_seed(9)
image = pipe(
    prompt=prompt, negative_prompt=negative_prompt,
    num_inference_steps=50, cfg_scale=2.0, embedded_guidance=3.5
)
image.save("image_1024_cfg.jpg")

# Highres-fix
torch.manual_seed(10)
image = pipe(
    prompt=prompt,
    num_inference_steps=50, embedded_guidance=3.5,
    input_image=image.resize((2048, 2048)), height=2048, width=2048, denoising_strength=0.6, tiled=True
)
image.save("image_2048_highres.jpg")
25
examples/image_synthesis/omnigen_text_to_image.py
Normal file
@@ -0,0 +1,25 @@
import torch
from diffsynth import ModelManager, OmnigenImagePipeline


model_manager = ModelManager(torch_dtype=torch.bfloat16, model_id_list=["OmniGen-v1"])
pipe = OmnigenImagePipeline.from_model_manager(model_manager)

image_man = pipe(
    prompt="A portrait of a man.",
    cfg_scale=2.5, num_inference_steps=50, seed=0
)
image_man.save("image_man.jpg")

image_woman = pipe(
    prompt="A portrait of an Asian woman with a white t-shirt.",
    cfg_scale=2.5, num_inference_steps=50, seed=1
)
image_woman.save("image_woman.jpg")

image_merged = pipe(
    prompt="a man and a woman. The man is the man in <img><|image_1|></img>. The woman is the woman in <img><|image_2|></img>.",
    reference_images=[image_man, image_woman],
    cfg_scale=2.5, image_cfg_scale=2.5, num_inference_steps=50, seed=2
)
image_merged.save("image_merged.jpg")
28
examples/image_synthesis/sd35_text_to_image.py
Normal file
@@ -0,0 +1,28 @@
from diffsynth import ModelManager, SD3ImagePipeline
import torch


model_manager = ModelManager(torch_dtype=torch.float16, device="cuda", model_id_list=["StableDiffusion3.5-large"])
pipe = SD3ImagePipeline.from_model_manager(model_manager)

prompt = "a full body photo of a beautiful Asian girl. CG, masterpiece, best quality, solo, long hair, wavy hair, silver hair, blue eyes, blue dress, medium breasts, dress, underwater, air bubble, floating hair, refraction, portrait. The girl's flowing silver hair shimmers with every color of the rainbow and cascades down, merging with the floating flora around her."
negative_prompt = "worst quality, low quality, monochrome, zombie, interlocked fingers, Aissist, cleavage, nsfw,"

torch.manual_seed(1)
image = pipe(
    prompt=prompt,
    negative_prompt=negative_prompt,
    cfg_scale=5,
    num_inference_steps=100, width=1024, height=1024,
)
image.save("image_1024.jpg")

image = pipe(
    prompt=prompt,
    negative_prompt=negative_prompt,
    cfg_scale=5,
    input_image=image.resize((2048, 2048)), denoising_strength=0.5,
    num_inference_steps=50, width=2048, height=2048,
    tiled=True
)
image.save("image_2048.jpg")
19
examples/stepvideo/README.md
Normal file
@@ -0,0 +1,19 @@
# Stepvideo

StepVideo is a state-of-the-art (SoTA) text-to-video pre-trained model with 30 billion parameters and the capability to generate videos of up to 204 frames.

* Model: https://modelscope.cn/models/stepfun-ai/stepvideo-t2v/summary
* GitHub: https://github.com/stepfun-ai/Step-Video-T2V
* Technical report: https://arxiv.org/abs/2502.10248

## Examples

For the original BF16 version, please see [`./stepvideo_text_to_video.py`](./stepvideo_text_to_video.py). 80 GB of VRAM is required.

We also support auto-offload, which reduces the VRAM requirement to **24 GB** at the cost of roughly 2x inference time. Please see [`./stepvideo_text_to_video_low_vram.py`](./stepvideo_text_to_video_low_vram.py).
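
Both variants share the same pipeline construction; the offload behaviour is configured through the pipeline's VRAM management. A minimal sketch, assuming the models have already been loaded into a `ModelManager` as in [`./stepvideo_text_to_video.py`](./stepvideo_text_to_video.py) (the value `0` for `num_persistent_param_in_dit` is only illustrative; smaller values keep fewer DiT parameters resident in VRAM):

```python
import torch
from diffsynth import StepVideoPipeline

pipe = StepVideoPipeline.from_model_manager(model_manager, torch_dtype=torch.bfloat16, device="cuda")
# None keeps all DiT parameters in VRAM (~80 GB); a small number enables offloading.
pipe.enable_vram_management(num_persistent_param_in_dit=0)
```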

https://github.com/user-attachments/assets/5954fdaa-a3cf-45a3-bd35-886e3cc4581b

For the FP8 quantized version, please see [`./stepvideo_text_to_video_quantized.py`](./stepvideo_text_to_video_quantized.py). 40 GB of VRAM is required.

https://github.com/user-attachments/assets/f3697f4e-bc08-47d2-b00a-32d7dfa272ad
50
examples/stepvideo/stepvideo_text_to_video.py
Normal file
@@ -0,0 +1,50 @@
from modelscope import snapshot_download
from diffsynth import ModelManager, StepVideoPipeline, save_video
import torch


# Download models
snapshot_download(model_id="stepfun-ai/stepvideo-t2v", cache_dir="models")

# Load the compiled attention for the LLM text encoder.
# If you encounter errors here, please select another compiled file that matches your environment, or delete this line.
torch.ops.load_library("models/stepfun-ai/stepvideo-t2v/lib/liboptimus_ths-torch2.5-cu124.cpython-310-x86_64-linux-gnu.so")

# Load models
model_manager = ModelManager()
model_manager.load_models(
    ["models/stepfun-ai/stepvideo-t2v/hunyuan_clip/clip_text_encoder/pytorch_model.bin"],
    torch_dtype=torch.float32, device="cpu"
)
model_manager.load_models(
    [
        "models/stepfun-ai/stepvideo-t2v/step_llm",
        "models/stepfun-ai/stepvideo-t2v/vae/vae_v2.safetensors",
        [
            "models/stepfun-ai/stepvideo-t2v/transformer/diffusion_pytorch_model-00001-of-00006.safetensors",
            "models/stepfun-ai/stepvideo-t2v/transformer/diffusion_pytorch_model-00002-of-00006.safetensors",
            "models/stepfun-ai/stepvideo-t2v/transformer/diffusion_pytorch_model-00003-of-00006.safetensors",
            "models/stepfun-ai/stepvideo-t2v/transformer/diffusion_pytorch_model-00004-of-00006.safetensors",
            "models/stepfun-ai/stepvideo-t2v/transformer/diffusion_pytorch_model-00005-of-00006.safetensors",
            "models/stepfun-ai/stepvideo-t2v/transformer/diffusion_pytorch_model-00006-of-00006.safetensors",
        ]
    ],
    torch_dtype=torch.bfloat16, device="cpu"
)
pipe = StepVideoPipeline.from_model_manager(model_manager, torch_dtype=torch.bfloat16, device="cuda")

# Enable VRAM management
# This model requires 80 GB of VRAM.
# To reduce the VRAM required, set `num_persistent_param_in_dit` to a small number.
pipe.enable_vram_management(num_persistent_param_in_dit=None)

# Run!
# Prompt: an astronaut discovers a glittering stone tablet printed with "stepfun" on the moon, followed by quality tags.
# Negative prompt: the usual quality/artifact tags (dark frame, low resolution, bad hands, watermark, blurry, etc.).
video = pipe(
    prompt="一名宇航员在月球上发现一块石碑,上面印有“stepfun”字样,闪闪发光。超高清、HDR 视频、环境光、杜比全景声、画面稳定、流畅动作、逼真的细节、专业级构图、超现实主义、自然、生动、超细节、清晰。",
    negative_prompt="画面暗、低分辨率、不良手、文本、缺少手指、多余的手指、裁剪、低质量、颗粒状、签名、水印、用户名、模糊。",
    num_inference_steps=30, cfg_scale=9, num_frames=51, seed=1
)
save_video(
    video, "video.mp4", fps=25, quality=5,
    ffmpeg_params=["-vf", "atadenoise=0a=0.1:0b=0.1:1a=0.1:1b=0.1"]
)
Some files were not shown because too many files have changed in this diff.